<a href="https://colab.research.google.com/github/shreyash53/SMAI-Knowledge-Distilation/blob/main/KD_CIFAR10_02.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np

In [2]:
# Load the CIFAR-10 train/test splits.
batch_size = 64
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Cast the images to float32, divide by 255 and pin the (N, 32, 32, 3) layout.
x_train = (x_train.astype("float32") / 255.0).reshape(-1, 32, 32, 3)
x_test = (x_test.astype("float32") / 255.0).reshape(-1, 32, 32, 3)

# Print the array shapes as a quick sanity check.
print("Input Train data  ", x_train.shape)
print("Train data Labels ", y_train.shape)
print("Input Test data   ", x_test.shape)
print("Test data Labels  ", y_test.shape)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Input Train data   (50000, 32, 32, 3)
Train data Labels  (50000, 1)
Input Test data    (10000, 32, 32, 3)
Test data Labels   (10000, 1)


In [3]:
# Teacher network: two strided conv stages (256 then 512 filters), each
# followed by LeakyReLU and a stride-1 max-pool, then a 10-unit Dense layer
# with no activation.
teacher = keras.Sequential(
    name="teacher",
    layers=[
        keras.Input(shape=(32, 32, 3)),
        keras.layers.Conv2D(filters=256, kernel_size=3, strides=2, padding="same"),
        keras.layers.LeakyReLU(alpha=0.2),
        keras.layers.MaxPooling2D(pool_size=2, strides=1, padding="same"),
        keras.layers.Conv2D(filters=512, kernel_size=3, strides=2, padding="same"),
        keras.layers.LeakyReLU(alpha=0.2),
        keras.layers.MaxPooling2D(pool_size=2, strides=1, padding="same"),
        keras.layers.Flatten(),
        keras.layers.Dense(units=10),
    ],
)
teacher.summary()

Model: "teacher"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 16, 16, 256)       7168      
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 16, 16, 256)       0         
                                                                 
 max_pooling2d (MaxPooling2D  (None, 16, 16, 256)      0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 8, 8, 512)         1180160   
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 8, 8, 512)         0         
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 8, 8, 512)        0         
 2D)                                                       

In [None]:
# Configure the teacher: Adam optimizer, sparse cross-entropy computed on the
# raw (from_logits=True) outputs, and sparse categorical accuracy as metric.
teacher.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    optimizer=keras.optimizers.Adam(),
)

# Fit on the training split for five epochs, then score on the test split.
teacher.fit(x=x_train, y=y_train, epochs=5)
teacher.evaluate(x=x_test, y=y_test)

Epoch 1/5
Epoch 2/5
  99/1563 [>.............................] - ETA: 11:15 - loss: 1.0628 - sparse_categorical_accuracy: 0.6269

In [None]:
# Create the student
student = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(32, 32, 3)),
        tf.keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ],
    name="student",
)
student.summary()

In [None]:
class Distiller(keras.Model):
    """Train a student model to match a teacher model (knowledge distillation).

    Only the student's trainable variables are updated. The teacher is called
    with ``training=False`` and serves purely as a source of soft targets.
    """

    def __init__(self, student, teacher):
        """Store the two models.

        Args:
            student: Model whose trainable variables are optimized.
            teacher: Model that produces the soft targets; never updated here.
        """
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Optimizer applied to the student's trainable variables.
            metrics: Metrics updated with the labels and the raw student
                outputs.
            student_loss_fn: Loss between the ground-truth labels and the raw
                student outputs.
            distillation_loss_fn: Loss between the temperature-softened
                teacher distribution and the temperature-softened student
                distribution; called as ``fn(teacher_soft, student_soft)``.
            alpha: Weight of ``student_loss_fn``; the distillation term is
                weighted by ``1 - alpha``.
            temperature: Divisor applied to both models' outputs before the
                softmax. Larger values give softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.temperature = temperature
        self.alpha = alpha

    def train_step(self, data):
        """Run one optimization step on the student.

        Args:
            data: An ``(x, y)`` pair of inputs and labels.

        Returns:
            Dict mapping each compiled metric name to its current result, plus
            ``"student_loss"`` and ``"distillation_loss"`` (the unweighted,
            unscaled loss values for this batch).
        """
        x, y = data

        # The teacher only provides targets, so its forward pass stays
        # outside the tape and is not recorded for differentiation.
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)

            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )

            # Scale only the soft-target term by temperature**2 so that its
            # contribution relative to the hard-label term stays comparable
            # when the temperature changes. Multiplying every gradient by
            # temperature**2 (the previous behavior) rescales both terms
            # equally and therefore does not change their balance.
            loss = (
                self.alpha * student_loss
                + (1 - self.alpha) * (self.temperature ** 2) * distillation_loss
            )

        # Differentiate and update after leaving the tape context so that the
        # gradient and optimizer ops are not themselves recorded.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        """Evaluate the student on one batch.

        Args:
            data: An ``(x, y)`` pair of inputs and labels.

        Returns:
            Dict mapping each compiled metric name to its current result, plus
            ``"student_loss"`` for this batch.
        """
        x, y = data

        # Only the student is evaluated; the teacher is not involved here.
        y_prediction = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, y_prediction)

        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results


# Initialize distiller
distiller = Distiller(student=student, teacher=teacher)

In [None]:
# Distillation run with temperature=2.
# Distil into a freshly initialized clone of `student` rather than the shared
# `student` object, so this run does not share weights with any other
# experiment that also uses `student`. The trained model is available as
# `distiller1.student`.
distiller1 = Distiller(student=keras.models.clone_model(student), teacher=teacher)
distiller1.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.3,
    temperature=2,
)

# Distill teacher to student
distiller1.fit(x_train, y_train, epochs=5)

# Evaluate the distilled student on the test dataset
distiller1.evaluate(x_test, y_test)


In [None]:
# Distillation run with temperature=3.
# Distil into a freshly initialized clone of `student` rather than the shared
# `student` object, so this run does not share weights with any other
# experiment that also uses `student`. The trained model is available as
# `distiller2.student`.
distiller2 = Distiller(student=keras.models.clone_model(student), teacher=teacher)
distiller2.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.3,
    temperature=3,
)

# Distill teacher to student
distiller2.fit(x_train, y_train, epochs=5)

# Evaluate the distilled student on the test dataset
distiller2.evaluate(x_test, y_test)


In [None]:
# Distillation run with temperature=4.
# Distil into a freshly initialized clone of `student` rather than the shared
# `student` object, so this run does not share weights with any other
# experiment that also uses `student`. The trained model is available as
# `distiller3.student`.
distiller3 = Distiller(student=keras.models.clone_model(student), teacher=teacher)
distiller3.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.3,
    temperature=4,
)

# Distill teacher to student
distiller3.fit(x_train, y_train, epochs=5)

# Evaluate the distilled student on the test dataset
distiller3.evaluate(x_test, y_test)


In [None]:
# Baseline: the student architecture trained on hard labels only.
# Rebind `student` to a freshly initialized clone before compiling: the
# previous `student` object was passed to Distiller instances earlier in the
# notebook and may already carry trained weights, which would make this
# baseline something other than "trained from scratch".
student = keras.models.clone_model(student)
student.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

In [None]:
# Train the compiled `student` on the labels alone for five epochs, then
# report its loss and accuracy on the test split.
student.fit(x=x_train, y=y_train, epochs=5)
student.evaluate(x=x_test, y=y_test)