<a href="https://colab.research.google.com/github/saniya1027108/Research---Knowledge-Distillation-/blob/main/KD_imdb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing import sequence

In [None]:
# IMDB sentiment data: keep only the `max_features` most frequent words and
# pad/truncate every review to exactly `max_len` tokens.
max_features = 20000
max_len = 200

train_split, test_split = imdb.load_data(num_words=max_features)
x_train, y_train = train_split
x_test, y_test = test_split

x_train = sequence.pad_sequences(x_train, maxlen=max_len)
x_test = sequence.pad_sequences(x_test, maxlen=max_len)


In [None]:
# Teacher network: the larger model whose predictions will guide the student.
# Embedding (128-d) -> LSTM (64 units) -> single sigmoid unit (P(positive)).
teacher = models.Sequential([
    layers.Embedding(max_features, 128, input_length=max_len),
    layers.LSTM(64),
    layers.Dense(1, activation='sigmoid'),
])

teacher.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

teacher.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       (None, 200, 128)          2560000   
                                                                 
 lstm (LSTM)                 (None, 64)                49408     
                                                                 
 dense (Dense)               (None, 1)                 65        
                                                                 
Total params: 2609473 (9.95 MB)
Trainable params: 2609473 (9.95 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [None]:
# Fit the teacher on the training reviews; the last 20% of the training
# data is held out for validation.
teacher.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=3,
    validation_split=0.2,
)


Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.src.callbacks.History at 0x7f2df4671810>

In [None]:
# Student network: same layout as the teacher but with half-sized layers.
# Embedding (64-d) -> LSTM (32 units) -> single sigmoid unit (P(positive)).
student = models.Sequential([
    layers.Embedding(max_features, 64, input_length=max_len),
    layers.LSTM(32),
    layers.Dense(1, activation='sigmoid'),
])

student.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

student.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_1 (Embedding)     (None, 200, 64)           1280000   
                                                                 
 lstm_1 (LSTM)               (None, 32)                12416     
                                                                 
 dense_1 (Dense)             (None, 1)                 33        
                                                                 
Total params: 1292449 (4.93 MB)
Trainable params: 1292449 (4.93 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [None]:
# Define Distiller class
class Distiller(models.Model):
    """Train a student network to mimic a teacher network (knowledge distillation).

    The training objective is

        alpha * student_loss(y, student(x))
        + (1 - alpha) * T**2 * distillation_loss(soft(teacher(x)), soft(student(x)))

    where ``soft`` produces temperature-softened class distributions and ``T``
    is the temperature. Only the student's weights are updated.
    """

    def __init__(self, student, teacher):
        """Store the two networks.

        Args:
            student: Model to be trained.
            teacher: Model whose predictions are used as soft targets. It is
                only ever called with ``training=False``.
        """
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student

    def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn, alpha=0.1, temperature=3):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer used to update the student's weights.
            metrics: Keras metrics evaluated on the student's predictions.
            student_loss_fn: Loss between the true labels and the student's
                predictions.
            distillation_loss_fn: Loss between the softened teacher and
                softened student distributions (e.g. KL divergence).
            alpha: Weight of ``student_loss_fn``; the distillation loss is
                weighted by ``1 - alpha``.
            temperature: Softening temperature; larger values give flatter
                distributions.
        """
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.temperature = temperature
        self.alpha = alpha

    def call(self, inputs, training=None):
        """Run inference with the student so the distiller can be called directly."""
        return self.student(inputs, training=training)

    def _soften(self, predictions):
        """Convert model outputs into temperature-softened class distributions.

        A softmax over a single-unit output is always exactly 1.0, which would
        make the distillation loss identically zero. Single-unit outputs are
        therefore treated as sigmoid probabilities of the positive class:
        the logit is recovered, divided by the temperature, squashed again and
        expanded into an explicit two-class distribution ``[1 - p, p]``.
        Outputs with more than one unit are treated as logits and passed
        through a temperature-scaled softmax.
        """
        if predictions.shape[-1] == 1:
            # NOTE(review): assumes a single-unit output is already a sigmoid
            # probability (as for the Dense(1, activation='sigmoid') heads) -
            # confirm before reusing with models that emit a raw logit.
            eps = 1e-7  # keeps log() finite when a probability saturates at 0 or 1
            probs = tf.clip_by_value(predictions, eps, 1.0 - eps)
            logits = tf.math.log(probs) - tf.math.log(1.0 - probs)
            positive = tf.sigmoid(logits / self.temperature)
            return tf.concat([1.0 - positive, positive], axis=-1)
        return tf.nn.softmax(predictions / self.temperature, axis=1)

    def train_step(self, data):
        """Run one optimisation step on the student.

        Args:
            data: Tuple ``(x, y)`` of inputs and true labels.

        Returns:
            Dict with the compiled metrics plus ``student_loss`` and
            ``distillation_loss`` (the latter already scaled by T**2).
        """
        x, y = data

        # Forward pass of teacher (outside the tape: no gradients are needed)
        teacher_prediction = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_prediction = self.student(x, training=True)

            # Hard-label loss
            student_loss = self.student_loss_fn(y, student_prediction)

            # Soft-label loss. Its gradients shrink by 1/T**2 when the logits
            # are divided by T, so it is scaled by T**2 to keep the two terms
            # balanced when the temperature changes.
            distillation_loss = self.distillation_loss_fn(
                self._soften(teacher_prediction),
                self._soften(student_prediction),
            ) * (self.temperature ** 2)

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients and update the student only (after leaving the
        # tape context so the update itself is not recorded).
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`
        self.compiled_metrics.update_state(y, student_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss, "distillation_loss": distillation_loss})
        return results

    def test_step(self, data):
        """Evaluate the student on one batch.

        Args:
            data: Tuple ``(x, y)`` of inputs and true labels.

        Returns:
            Dict with the compiled metrics plus ``student_loss``.
        """
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

# Pair the trained teacher with the student that will learn from it
distiller = Distiller(
    teacher=teacher,
    student=student,
)

In [None]:
# Compile Distiller
# The student ends in a sigmoid layer, so its outputs are probabilities, not
# logits: the hard-label loss must use from_logits=False (with True, Keras
# would apply a second sigmoid to values already in [0, 1]).
distiller.compile(optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'],
                  student_loss_fn=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                  distillation_loss_fn=tf.keras.losses.KLDivergence(),
                  alpha=0.3,       # weight of the hard-label loss
                  temperature=5)   # softening applied to teacher/student outputs

In [None]:
# Knowledge distillation: train the student against the teacher's outputs,
# holding out the last 20% of the training data for validation.
distiller.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=3,
    validation_split=0.2,
)

# Score the distilled student on the held-out test reviews
distiller.evaluate(x_test, y_test)

Epoch 1/3
Train... {'accuracy': <tf.Tensor 'Identity_6:0' shape=() dtype=float32>, 'student_loss': <tf.Tensor 'binary_crossentropy/weighted_loss/value:0' shape=() dtype=float32>, 'distillation_loss': <tf.Tensor 'kl_divergence/weighted_loss/value:0' shape=() dtype=float32>}
Train... {'accuracy': <tf.Tensor 'Identity_6:0' shape=() dtype=float32>, 'student_loss': <tf.Tensor 'binary_crossentropy/weighted_loss/value:0' shape=() dtype=float32>, 'distillation_loss': <tf.Tensor 'kl_divergence/weighted_loss/value:0' shape=() dtype=float32>}
Epoch 2/3
Epoch 3/3


[0.8465200066566467, 0.5559797286987305]

## Hyperparameter tuning


In [None]:
# Start from a freshly initialised copy of the student architecture so that
# this run is an independent experiment rather than a continuation of the
# student already distilled above, and so that re-running this cell gives a
# comparable result. clone_model() rebuilds the layers with new weights.
# NOTE(review): `student` itself is no longer updated by this cell; use
# `student_fresh` (or `distiller.student`) to access the tuned model.
student_fresh = tf.keras.models.clone_model(student)

# Initialize Distiller
distiller = Distiller(student=student_fresh, teacher=teacher)

# Hyperparameter tuning
learning_rate = 0.001
batch_size = 32
epochs = 10

# Compile Distiller with tuned hyperparameters
# The student ends in a sigmoid layer, so its outputs are probabilities, not
# logits: the hard-label loss must use from_logits=False.
distiller.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  metrics=['accuracy'],
                  student_loss_fn=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                  distillation_loss_fn=tf.keras.losses.KLDivergence(),
                  alpha=0.6,        # weight of the hard-label loss
                  temperature=15)   # softening applied to teacher/student outputs

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_split=0.2)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)

Epoch 1/10
Train... {'accuracy': <tf.Tensor 'Identity_6:0' shape=() dtype=float32>, 'student_loss': <tf.Tensor 'binary_crossentropy/weighted_loss/value:0' shape=() dtype=float32>, 'distillation_loss': <tf.Tensor 'kl_divergence/weighted_loss/value:0' shape=() dtype=float32>}
Train... {'accuracy': <tf.Tensor 'Identity_6:0' shape=() dtype=float32>, 'student_loss': <tf.Tensor 'binary_crossentropy/weighted_loss/value:0' shape=() dtype=float32>, 'distillation_loss': <tf.Tensor 'kl_divergence/weighted_loss/value:0' shape=() dtype=float32>}
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


[0.8423600196838379, 0.6283491849899292]