# Project B: Knowledge Distillation for Building Lightweight Deep Learning Models in Visual Classification Tasks

In [1]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from typing import Union

tf.enable_v2_behavior()

BATCH_SIZE = 256
NUM_EPOCHS = 12
NUM_CLASSES = 2 

  from .autonotebook import tqdm as notebook_tqdm


# Data loading

In [11]:
import pandas as pd

# Load train and test splits.
anno_dir = r'D:\Github Repos\ECE1512\-ECE1512_2022W_ProjectRepo_StephanieDiNunzio\Project B\Project_B_Supp\mhist_dataset\annotations.csv' #you should change to your directory
df = pd.read_csv(anno_dir)
df = df[['Image Name','Majority Vote Label','Partition']]
df

Unnamed: 0,Image Name,Majority Vote Label,Partition
0,MHIST_aaa.png,SSA,train
1,MHIST_aab.png,HP,train
2,MHIST_aac.png,SSA,train
3,MHIST_aae.png,HP,train
4,MHIST_aaf.png,SSA,train
...,...,...,...
3147,MHIST_cpn.png,SSA,train
3148,MHIST_cfc.png,SSA,test
3149,MHIST_cgp.png,SSA,test
3150,MHIST_dlf.png,SSA,train


In [15]:
import shutil

hp_test_dir = r'D:\Github Repos\ECE1512\-ECE1512_2022W_ProjectRepo_StephanieDiNunzio\Project B\Project_B_Supp\mhist_dataset\images_sorted\HP_test'
hp_train_dir = r'D:\Github Repos\ECE1512\-ECE1512_2022W_ProjectRepo_StephanieDiNunzio\Project B\Project_B_Supp\mhist_dataset\images_sorted\HP_train'
ssa_test_dir = r'D:\Github Repos\ECE1512\-ECE1512_2022W_ProjectRepo_StephanieDiNunzio\Project B\Project_B_Supp\mhist_dataset\images_sorted\SSA_test'
ssa_train_dir = r'D:\Github Repos\ECE1512\-ECE1512_2022W_ProjectRepo_StephanieDiNunzio\Project B\Project_B_Supp\mhist_dataset\images_sorted\SSA_train'

img_root_dir = r'D:\Github Repos\ECE1512\-ECE1512_2022W_ProjectRepo_StephanieDiNunzio\Project B\Project_B_Supp\mhist_dataset\images\images'

for i, row in df.iterrows():
    img_dir = img_root_dir + '\\' + row['Image Name']
    if row['Majority Vote Label'] == 'SSA':
        if row['Partition'] == 'train':
            new_dir = ssa_train_dir + '\\' + row['Image Name']
            shutil.copyfile(img_dir,new_dir)
        else:
            new_dir = ssa_test_dir + '\\' + row['Image Name']
            shutil.copyfile(img_dir,new_dir)
    else:
        if row['Partition'] == 'train':
            new_dir = hp_train_dir + '\\' + row['Image Name']
            shutil.copyfile(img_dir,new_dir)
        else:
            new_dir = hp_test_dir + '\\' + row['Image Name']
            shutil.copyfile(img_dir,new_dir)
    #break

In [2]:
from keras.preprocessing.image import ImageDataGenerator

test_dir = r'D:\Github Repos\ECE1512\-ECE1512_2022W_ProjectRepo_StephanieDiNunzio\Project B\Project_B_Supp\mhist_dataset\images_sorted\test'
train_dir = r'D:\Github Repos\ECE1512\-ECE1512_2022W_ProjectRepo_StephanieDiNunzio\Project B\Project_B_Supp\mhist_dataset\images_sorted\train'

train_datagen = ImageDataGenerator(rescale=1/255.,
shear_range=0.1,
rotation_range=15,
horizontal_flip=True,
vertical_flip=True)

test_datagen = ImageDataGenerator(rescale=1/255.)

train_generator = train_datagen.flow_from_directory(train_dir,
class_mode='categorical',
interpolation='bilinear',
target_size=(224, 224),
batch_size=32,
shuffle=True)

test_generator = test_datagen.flow_from_directory(test_dir,
class_mode='categorical',
interpolation='bilinear',
target_size=(224, 224),
batch_size=32,
shuffle=False)

# Model creation

In [3]:
#@test {"output": "ignore"}

# Build CNN teacher.
cnn_model = tf.keras.Sequential()

# your code start from here for stpe 2
from tensorflow.keras import layers
 
cnn_model.add(layers.Conv2D(32,3,strides=1,padding='same',activation='relu',input_shape=(28,28,1)))
cnn_model.add(layers.MaxPool2D(pool_size=(2,2),strides=1))
cnn_model.add(layers.Conv2D(64,3,strides=1,padding='same',activation='relu'))
cnn_model.add(layers.MaxPool2D(pool_size=(2,2),strides=2))
cnn_model.add(layers.Flatten())
cnn_model.add(layers.Dropout(0.5))
cnn_model.add(layers.Dense(128,activation='relu'))
cnn_model.add(layers.Dropout(0.5))
cnn_model.add(layers.Dense(128,activation='relu'))
cnn_model.add(layers.Dropout(0.5))
cnn_model.add(layers.Dense(NUM_CLASSES))
##

#cnn_model.summary()


# Build fully connected student.
fc_model = tf.keras.Sequential()

# your code start from here for step 2

fc_model.add(layers.Flatten())
fc_model.add(layers.Dense(784,activation='relu'))
fc_model.add(layers.Dense(784,activation='relu'))
fc_model.add(layers.Dense(NUM_CLASSES))


# Teacher loss function

In [4]:
import keras
import numpy as np
@tf.function
def compute_teacher_loss(images, labels):
  """Compute class knowledge distillation teacher loss for given images
     and labels.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

  Returns:
    Scalar loss Tensor.
  """
  class_logits = cnn_model(images, training=True)

  # Compute cross-entropy loss for classes.

  # your code start from here for step 3
 
  loss = keras.losses.CategoricalCrossentropy(from_logits=True)
  
  cross_entropy_loss_value = loss(labels,class_logits)


  return cross_entropy_loss_value


# Student loss function

In [5]:
#@test {"output": "ignore"}

# Hyperparameters for distillation (need to be tuned).
ALPHA = 0.5 # task balance between cross-entropy and distillation loss
DISTILLATION_TEMPERATURE = 4. #temperature hyperparameter

def distillation_loss(teacher_logits: tf.Tensor, student_logits: tf.Tensor,
                      temperature: Union[float, tf.Tensor]):
  """Compute distillation loss.

  This function computes cross entropy between softened logits and softened
  targets. The resulting loss is scaled by the squared temperature so that
  the gradient magnitude remains approximately constant as the temperature is
  changed. For reference, see Hinton et al., 2014, "Distilling the knowledge in
  a neural network."

  Args:
    teacher_logits: A Tensor of logits provided by the teacher.
    student_logits: A Tensor of logits provided by the student, of the same
      shape as `teacher_logits`.
    temperature: Temperature to use for distillation.

  Returns:
    A scalar Tensor containing the distillation loss.
  """
 # your code start from here for step 3
  soft_targets = tf.nn.softmax(teacher_logits)

  return tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          soft_targets, student_logits / temperature)) * temperature ** 2

def compute_student_loss(images, labels):
  """Compute class knowledge distillation student loss for given images
     and labels.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

  Returns:
    Scalar loss Tensor.
  """
  student_class_logits = fc_model(images, training=True)

  # Compute class distillation loss between student class logits and
  # softened teacher class targets probabilities.

  # your code start from here for step 3

  teacher_class_logits = cnn_model(images, training=False)
  distillation_loss_value = distillation_loss(teacher_class_logits, student_class_logits, DISTILLATION_TEMPERATURE)

  # Compute cross-entropy loss with hard targets.

  # your code start from here for step 3
  loss = keras.losses.CategoricalCrossentropy(from_logits=True)
  
  cross_entropy_loss_value = loss(labels,student_class_logits).numpy()

  return ALPHA*cross_entropy_loss_value + (1-ALPHA)*distillation_loss_value

# Train and evaluation

In [6]:
@tf.function
def compute_num_correct(model, images, labels):
  """Compute number of correctly classified images in a batch.

  Args:
    model: Instance of tf.keras.Model.
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

  Returns:
    Number of correctly classified images.
  """
  class_logits = model(images, training=False)
  return tf.reduce_sum(
      tf.cast(tf.math.equal(tf.argmax(class_logits, -1), tf.argmax(labels, -1)),
              tf.float32)), tf.argmax(class_logits, -1), tf.argmax(labels, -1)


def train_and_evaluate(model, compute_loss_fn):
  """Perform training and evaluation for a given model.

  Args:
    model: Instance of tf.keras.Model.
    compute_loss_fn: A function that computes the training loss given the
      images, and labels.
  """

  # your code start from here for step 4
  optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

  for epoch in range(1, NUM_EPOCHS + 1):
    # Run training.
    print('Epoch {}: '.format(epoch), end='')
    for images, labels in mnist_train:
      with tf.GradientTape() as tape:
         # your code start from here for step 4
        loss_value = compute_loss_fn(images, labels)
        #tape.watch(loss_value)
      grads = tape.gradient(loss_value, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Run evaluation.
    num_correct = 0
    num_total = builder.info.splits['test'].num_examples
    for images, labels in mnist_test:
      # your code start from here for step 4
      num_correct += compute_num_correct(model, images, labels)[0]
    print("Class_accuracy: " + '{:.2f}%'.format(
        num_correct / num_total * 100))

  return num_correct / num_total


# Training models

In [7]:
train_and_evaluate(cnn_model, compute_teacher_loss)

Epoch 1: Class_accuracy: 97.54%
Epoch 2: Class_accuracy: 98.07%
Epoch 3: Class_accuracy: 98.70%
Epoch 4: Class_accuracy: 98.71%
Epoch 5: Class_accuracy: 98.92%
Epoch 6: Class_accuracy: 98.94%
Epoch 7: Class_accuracy: 99.01%
Epoch 8: Class_accuracy: 99.11%
Epoch 9: Class_accuracy: 99.07%
Epoch 10: Class_accuracy: 99.07%
Epoch 11: Class_accuracy: 99.11%
Epoch 12: Class_accuracy: 99.19%


<tf.Tensor: shape=(), dtype=float32, numpy=0.9919>

In [8]:
##RESET MODEL

# Build fully connected student.
fc_model = tf.keras.Sequential()

fc_model.add(layers.Flatten())
fc_model.add(layers.Dense(784,activation='relu'))
fc_model.add(layers.Dense(784,activation='relu'))
fc_model.add(layers.Dense(NUM_CLASSES))


# Hyperparameters for distillation (need to be tuned).
ALPHA = 0.5 # task balance between cross-entropy and distillation loss
DISTILLATION_TEMPERATURE = 2.5 #temperature hyperparameter

# your code start from here for step 5 

accuracy = train_and_evaluate(fc_model, compute_student_loss)

Epoch 1: Class_accuracy: 96.34%
Epoch 2: Class_accuracy: 97.57%
Epoch 3: Class_accuracy: 97.64%
Epoch 4: Class_accuracy: 97.96%
Epoch 5: Class_accuracy: 98.03%
Epoch 6: Class_accuracy: 98.17%
Epoch 7: Class_accuracy: 98.29%
Epoch 8: Class_accuracy: 98.16%
Epoch 9: Class_accuracy: 98.21%
Epoch 10: Class_accuracy: 98.27%
Epoch 11: Class_accuracy: 98.42%
Epoch 12: Class_accuracy: 98.57%


In [9]:
print(accuracy.numpy())

0.9857
