In [14]:
import tensorflow as tf
from keras.datasets import mnist
from tensorflow.keras import layers, models, optimizers, mixed_precision

In [6]:
# Enable mixed precision training
mixed_precision.set_global_policy('mixed_float16')

# Load and preprocess the MNIST dataset
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train, X_test = X_train.astype('float32') / 255.0, X_test.astype('float32') / 255.0

In [10]:
# Define a simple model
teacher_model = models.Sequential([
    layers.Input(shape=(28, 28,)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax'),
])

In [11]:
# Compile the teacher model with an optimizer and loss function
optimizer = optimizers.Adam(learning_rate=1e-3)
teacher_model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

In [39]:
student_model = models.Sequential([
    layers.Input(shape=(28, 28,)),
    layers.Flatten(),
    layers.Dense(32, activation='relu'),
    layers.Dense(10, activation='softmax'),
])

In [40]:
# Compile the student model with an optimizer and loss function
optimizer = optimizers.Adam(learning_rate=1e-3)
student_model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

In [41]:
# Distillation loss function using TensorFlow operations
def distillation_loss(teacher_logits, student_logits, temperature=3):
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    student_probs = tf.nn.softmax(student_logits / temperature)
    return tf.reduce_mean(tf.keras.losses.categorical_crossentropy(teacher_probs, student_probs))

In [51]:
# Train the student model using knowledge distillation with smaller data and fewer epochs
def train_student(student: models.Sequential, teacher: models.Sequential, X, y, batch_size=32, epochs=2, temperature=3):
    print(type(student))
    for epoch in range(epochs):
        num_batches = len(X) // batch_size
        for batch in range(num_batches):
            X_batch = X[batch * batch_size : (batch + 1) * batch_size]
            y_batch = y[batch * batch_size : (batch + 1) * batch_size]

            # Predict teacher logits for the batch
            teacher_logits = teacher.predict(X_batch)

            with tf.GradientTape() as tape:
                # Predict student logits for the batch
                student_logits = student(X_batch)
                # Calculate distillation loss
                student_loss = distillation_loss(teacher_logits, student_logits, temperature=temperature)

            # Apply gradients
            grads = tape.gradient(student_loss, student.trainable_variables)
            student.optimizer.apply_gradients(zip(grads, student.trainable_variables))

            print(f"Epoch {epoch + 1}/{epochs}. Loss: {student_loss.numpy()}")

In [52]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train, X_test = X_train[:1000], X_test[:1000]

X_train, X_test = X_train.astype('float32') / 255., X_test.astype('float32') / 255.

# Train the student model
train_student(student_model, teacher_model, X_train, y_train, batch_size=32, epochs=3)

<class 'keras.src.models.sequential.Sequential'>
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
Epoch 1/3. Loss: 2.302734375
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
Epoch 1/3. Loss: 2.302734375
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
Epoch 1/3. Loss: 2.302734375
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
Epoch 1/3. Loss: 2.302734375
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
Epoch 1/3. Loss: 2.302734375
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
Epoch 1/3. Loss: 2.30078125
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
Epoch 1/3. Loss: 2.302734375
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
Epoch 1/3. Loss: 2.30078125
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
Epoch 1/3. Loss: 2.302734375
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[3