<a href="https://colab.research.google.com/github/shreyash53/SMAI-Knowledge-Distilation/blob/main/Knowledge_Distillation_in_NN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import pandas as pd
import os

import tensorflow as tf
from tensorflow import keras

In [2]:
# Importing MNIST dataset

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255

x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

print("Train data shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

Train data shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


In [4]:
from sklearn.preprocessing import OneHotEncoder

print ('label shape: ',np.shape(y_train),np.shape(y_test))
def encode(y):
    df = pd.DataFrame(y, columns=["class"])
    encoder = OneHotEncoder(handle_unknown='ignore')
    encoder_df = pd.DataFrame(encoder.fit_transform(df[['class']]).toarray())
    final_df = df.join(encoder_df)
    final_df.drop('class', axis=1, inplace=True)
    return final_df.to_numpy()
y_train = encode(y_train)
y_test = encode(y_test)
print ('label shape: ',np.shape(y_train),np.shape(y_test))
# print(y_test_[2], y_test[2])

label shape:  (60000,) (10000,)
label shape:  (60000, 10) (10000, 10)


In [None]:
# Teacher Model

teacher_model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(64, kernel_size=(3, 3), strides=(2, 2), padding="same"),
    keras.layers.LeakyReLU(alpha=0.2),
    keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1)),
    # keras.layers.Conv2D(256, kernel_size=(3, 3), strides=(2, 2), padding="same"),
    # keras.layers.LeakyReLU(alpha=0.2),
    # keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1)),
    # keras.layers.Conv2D(256, kernel_size=(3, 3), strides=(2, 2), padding="same"),
    # keras.layers.LeakyReLU(alpha=0.2),
    # keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1)),
    keras.layers.Flatten(),
    keras.layers.Dense(1200, activation="relu"),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(1200, activation="relu"),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation="softmax"),
    ], name="teacher"
)

teacher_model.summary()

In [None]:
callback = keras.callbacks.EarlyStopping(monitor='val_loss', patience=2)
teacher_model.compile(loss="categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])
teacher_model.fit(x_train, y_train, epochs=20, validation_split=0.1, callbacks=[callback])

In [7]:
y_test.shape

(10000, 10)

In [8]:
# Evaluate Teacher Model

score = teacher_model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])

Test loss: 0.03838657587766647
Test accuracy: 0.9864000082015991


In [9]:
# Small Model

small_model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(64, kernel_size=(3, 3), strides=(2, 2), padding="same"),
    keras.layers.LeakyReLU(alpha=0.2),
    keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation="relu"),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation="relu"),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation="softmax"),
    ], name="small"
)

small_model.summary()

Model: "small"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_1 (Conv2D)           (None, 14, 14, 64)        640       
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 14, 14, 64)        0         
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 13, 13, 64)       0         
 2D)                                                             
                                                                 
 flatten_1 (Flatten)         (None, 10816)             0         
                                                                 
 dense_3 (Dense)             (None, 10)                108170    
                                                                 
 dropout_2 (Dropout)         (None, 10)                0         
                                                             

In [10]:
# callback = keras.callbacks.EarlyStopping(monitor='val_loss', patience=2)
# small_model.compile(loss="categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])
# small_model.fit(x_train, y_train, epochs=20, validation_split=0.1, callbacks=[callback])

In [11]:
# Evaluate Small Model

# score = small_model.evaluate(x_test, y_test, verbose=0)
# print("Test loss:", score[0])
# print("Test accuracy:", score[1])

In [12]:
# Student Model

student_model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(64, kernel_size=(3, 3), strides=(2, 2), padding="same"),
    keras.layers.LeakyReLU(alpha=0.2),
    keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation="relu"),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation="relu"),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation="softmax"),
    # keras.layers.Reshape((32, 10))
    ], name="student"
)

student_model.summary()

Model: "student"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_2 (Conv2D)           (None, 14, 14, 64)        640       
                                                                 
 leaky_re_lu_2 (LeakyReLU)   (None, 14, 14, 64)        0         
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 13, 13, 64)       0         
 2D)                                                             
                                                                 
 flatten_2 (Flatten)         (None, 10816)             0         
                                                                 
 dense_6 (Dense)             (None, 10)                108170    
                                                                 
 dropout_4 (Dropout)         (None, 10)                0         
                                                           

In [19]:
# Distilation

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student
    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """     Configure the distiller.student_loss_fn: Loss function of difference between student
                predictions and ground-truth
                distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
                alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
                temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn=student_loss_fn
        self.distillation_loss_fn= distillation_loss_fn
        self.temperature= temperature
        self.alpha= alpha
        
    def train_step(self, data):
        x,y=data
        
        # Forward pass of teacher
        teacher_prediction=self.teacher(x, training=False)
        print("Tecaher prediction   ...", teacher_prediction)
        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predcition= self.student(x, training=True)
            # student_predcition = tf.reshape(student_predcition, (1,320))
            print("Student Prediction: ", student_predcition)
            # Compute losses

            # y = tf.reshape(y, (1,320))
            print(y.shape, student_predcition.shape)
            student_loss= self.student_loss_fn(y, student_predcition)
            print("Student loss: ", student_loss)
            print("Teacher Prediction: ", teacher_prediction)
            # teacher_prediction = tf.reshape(teacher_prediction, (1,320))
            # print("Student Prediction: ", student_predcition)
            distillation_loss=self.distillation_loss_fn(
                tf.nn.softmax(teacher_prediction/self.temperature, axis=1),
                tf.nn.softmax(student_predcition/self.temperature, axis=1)
            )
            print("Distillation Liss")
            loss= self.alpha * student_loss + (1-self.alpha) * distillation_loss
            print("Loss in distiller :",loss)
            # Compute gradients
            trainable_vars= self.student.trainable_variables
            gradients=tape.gradient(loss, trainable_vars)
            gradients = [gradient * (self.temperature ** 2) for gradient in gradients]
            # Update weights
            self.optimizer.apply_gradients(zip(gradients, trainable_vars))
            
            # Update the metrics configured in `compile()`
            # print(y.reshape(320,1))
            self.compiled_metrics.update_state(y, student_predcition)
            
            # Return a dict of performance
            results={ m.name: m.result()  for m in self.metrics}
            results.update({"student_loss": student_loss, "distillation_loss": distillation_loss})
            print("Train...", results)
            return results
        
    def test_step(self, data):
        # Unpack the data
        x, y = data
        
        ## Compute predictions
        y_prediction= self.student(x, training=False)
        
        # calculate the loss
        student_loss= self.student_loss_fn(y, y_prediction)
        
        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
        
        # Return a dict of performance
        results ={m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        print("Test...", results)
        return results# Initialize  distiller

In [22]:
distiller= Distiller(student=student_model, teacher=teacher_model)
distiller.compile(optimizer=keras.optimizers.Adam(),
                 metrics=[keras.metrics.CategoricalAccuracy()],
                 student_loss_fn=keras.losses.CategoricalCrossentropy(),
                 distillation_loss_fn=keras.losses.CategoricalCrossentropy(),
                 alpha=0.7,
                 temperature=7)

In [15]:
y_train

array([[0., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.]])

In [23]:
# Distill teacher to student

distiller.fit(x_train/255, y_train/255, epochs=5)# Evaluate student on test dataset
score = distiller.evaluate(x_test/255, y_test/255)

Epoch 1/5
Tecaher prediction   ... Tensor("teacher/dense_2/Softmax:0", shape=(32, 10), dtype=float32)
Student Prediction:  Tensor("student/dense_8/Softmax:0", shape=(32, 10), dtype=float32)
(32, 10) (32, 10)
Student loss:  Tensor("categorical_crossentropy/weighted_loss/value:0", shape=(), dtype=float32)
Teacher Prediction:  Tensor("teacher/dense_2/Softmax:0", shape=(32, 10), dtype=float32)
Distillation Liss
Loss in distiller : Tensor("add:0", shape=(), dtype=float32)
Train... {'categorical_accuracy': <tf.Tensor 'Identity:0' shape=() dtype=float32>, 'student_loss': <tf.Tensor 'categorical_crossentropy/weighted_loss/value:0' shape=() dtype=float32>, 'distillation_loss': <tf.Tensor 'categorical_crossentropy_1/weighted_loss/value:0' shape=() dtype=float32>}
Tecaher prediction   ... Tensor("teacher/dense_2/Softmax:0", shape=(32, 10), dtype=float32)
Student Prediction:  Tensor("student/dense_8/Softmax:0", shape=(32, 10), dtype=float32)
(32, 10) (32, 10)
Student loss:  Tensor("categorical_cro

KeyboardInterrupt: ignored

In [160]:
print("Test loss:", score[0])
print("Test accuracy:", score[1])

Test loss: 0.09740000218153
Test accuracy: 5471.373046875
