In [1]:
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow.keras.backend as K

# 設置數據生成器
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=45,
    brightness_range=[0.9, 1.1],
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    shear_range=0.1,
)

# 訓練數據生成器
train_generator = train_datagen.flow_from_directory(
    r'E:\Codes\CV train\回收系統dataset\train',  # 替換為您的資料路徑
    target_size=(224,224),
    batch_size=16,
    class_mode='categorical',
)


validation_datagen = ImageDataGenerator(
    rescale=1./255,
)
# 驗證數據生成器
validation_generator = validation_datagen.flow_from_directory(
    r'E:\Codes\CV train\回收系統dataset\TrashBox',
    target_size=(224, 224),
    batch_size=16,
    class_mode='categorical',
)

# 教師模型
teacher_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = teacher_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(64,activation='relu')(x)
x = Dense(3, activation='softmax')(x)
teacher_model = Model(teacher_model.input, x)
teacher_model.trainable = False

# 學生模型
student_model = MobileNetV2(alpha=0.35, weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = student_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(64,activation='relu')(x)
x = Dense(3, activation='softmax')(x)
student_model = Model(student_model.input, x)
for layer in student_model.layers:
    layer.trainable = False  # 可選擇解凍部分層

# 蒸餾損失
def distillation_loss(y_true, y_pred, teacher_logits, temperature=5.0, alpha=0.5):
    teacher_soft = K.softmax(teacher_logits / temperature)
    student_soft = K.softmax(y_pred / temperature)
    distill_loss = K.categorical_crossentropy(teacher_soft, student_soft) * (temperature ** 2)
    student_loss = K.categorical_crossentropy(y_true, y_pred)
    return alpha * distill_loss + (1 - alpha) * student_loss

# 訓練
optimizer = tf.keras.optimizers.Adam()
student_model.compile(optimizer=optimizer, loss=lambda y_true, y_pred: distillation_loss(y_true, y_pred, teacher_model.output))

def train_with_distillation(epochs=10):
    steps_per_epoch = train_generator.samples // 16
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        for step, (x_batch, y_batch) in enumerate(train_generator):
            if step >= steps_per_epoch:  # 限制步數
                break
            with tf.GradientTape() as tape:
                student_logits = student_model(x_batch, training=True)
                teacher_logits = teacher_model(x_batch, training=False)
                loss = distillation_loss(y_batch, student_logits, teacher_logits)
            grads = tape.gradient(loss, student_model.trainable_weights)
            optimizer.apply_gradients(zip(grads, student_model.trainable_weights))
            if step % 10 == 0:
                print(f"Step {step}, Loss: {loss.numpy()}")

train_with_distillation(epochs=10)
student_model.save('distilled_model.h5')

Found 6370 images belonging to 3 classes.
Found 9394 images belonging to 3 classes.
Epoch 1/10
Step 0, Loss: [14.489811 14.579288 13.993038 14.33701  14.089329 14.535907 14.314287
 13.984363 14.355761 14.192928 14.199024 14.429907 14.08404  14.486324
 14.307125 14.005787]
Step 10, Loss: [14.013374  14.283559  14.2919855 14.093738  14.576868  14.231466
 14.109308  14.2658205 14.449     14.280907  14.3002205 14.599022
 14.652584  13.82783   15.019039  14.262299 ]
Step 20, Loss: [14.350805  14.108237  14.687073  13.890065  14.574833  14.845942
 14.206521  14.189568  14.3220215 14.425041  14.208998  14.265815
 14.41795   14.395954  14.309627  14.465677 ]
Step 30, Loss: [14.553475 14.521157 14.759769 14.206756 14.267199 14.225507 14.215296
 14.347721 14.050127 14.148804 14.269066 14.02713  14.086663 14.281988
 14.49991  14.056877]
Step 40, Loss: [14.211646  14.382012  14.287823  14.59356   14.315852  14.37658
 14.458922  14.451364  14.3333435 14.67121   14.255772  14.06497
 14.457702  14.51

