In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

# Load and preprocess the MNIST dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = np.expand_dims(x_train, axis=-1)
x_test = np.expand_dims(x_test, axis=-1)

# One-hot encode the labels
y_train_onehot = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test_onehot = tf.keras.utils.to_categorical(y_test, num_classes=10)


In [2]:
# Define the Teacher Model
teacher_model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])
teacher_model.summary()

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [3]:
teacher_model.compile(optimizer='adam',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])


In [4]:
# Train the Teacher Model
teacher_model.fit(x_train, y_train_onehot, epochs=5, validation_split=0.1)


Epoch 1/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 5ms/step - accuracy: 0.9046 - loss: 0.3118 - val_accuracy: 0.9838 - val_loss: 0.0526
Epoch 2/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 5ms/step - accuracy: 0.9857 - loss: 0.0457 - val_accuracy: 0.9893 - val_loss: 0.0412
Epoch 3/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.9922 - loss: 0.0263 - val_accuracy: 0.9893 - val_loss: 0.0358
Epoch 4/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 6ms/step - accuracy: 0.9936 - loss: 0.0187 - val_accuracy: 0.9882 - val_loss: 0.0436
Epoch 5/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 6ms/step - accuracy: 0.9951 - loss: 0.0147 - val_accuracy: 0.9917 - val_loss: 0.0341


<keras.src.callbacks.history.History at 0x323d2bdf0>

In [5]:
# Generate Soft Labels from the Teacher Model
soft_labels = teacher_model.predict(x_train)


[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step


In [6]:
# Define the Student Model
student_model = models.Sequential([
    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax')
])
student_model.summary()

In [7]:
# Define Knowledge Distillation Loss
def distillation_loss(y_true, y_pred, teacher_pred, temperature=5):
    # Scale predictions by temperature
    teacher_pred_scaled = tf.nn.softmax(teacher_pred / temperature)
    y_pred_scaled = tf.nn.softmax(y_pred / temperature)
    
    # Cross-entropy between teacher and student predictions
    kd_loss = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(teacher_pred_scaled, y_pred_scaled)
    )
    # Add standard cross-entropy loss with true labels
    ce_loss = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    )
    return kd_loss * 0.5 + ce_loss * 0.5


In [8]:
epochs = 1
batch_size = 32
num_batches = len(x_train) // batch_size

print("num_batches : ",num_batches)

num_batches :  1875


In [9]:
optimizer = tf.keras.optimizers.Adam()

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    for i in range(num_batches):
        # Get a batch of data
        start = i * batch_size
        end = start + batch_size
        x_batch = x_train[start:end]
        y_batch = y_train_onehot[start:end]
        
        with tf.GradientTape() as tape:
            predictions = student_model(x_batch, training=True)
            teacher_predictions = teacher_model(x_batch, training=True)
            loss = distillation_loss(y_batch, predictions, teacher_predictions)


        gradients = tape.gradient(loss, student_model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, student_model.trainable_weights))

        if i % 200 == 0:  # Print progress every 200 batches
            print(f"Batch {i}/{num_batches}, Loss: {loss.numpy():.4f}")


Epoch 1/1
Batch 0/1875, Loss: 2.2985
Batch 200/1875, Loss: 1.3579
Batch 400/1875, Loss: 1.2145
Batch 600/1875, Loss: 1.2190
Batch 800/1875, Loss: 1.2070
Batch 1000/1875, Loss: 1.2568
Batch 1200/1875, Loss: 1.2126
Batch 1400/1875, Loss: 1.1862
Batch 1600/1875, Loss: 1.2304
Batch 1800/1875, Loss: 1.1770


In [10]:

# Evaluate the Student Model
student_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
student_model.evaluate(x_test, y_test_onehot, verbose=2)

313/313 - 0s - 1ms/step - accuracy: 0.9700 - loss: 0.0970


[0.09695473313331604, 0.9700000286102295]