# Project A: Knowledge Distillation for Building Lightweight Deep Learning Models in Visual Classification Tasks

In [None]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from typing import Union

# Opt in to TF2 behavior when running through the tf.compat.v2 API.
tf.enable_v2_behavior()

# Builder used only for dataset metadata (label count, split sizes); the data
# itself is loaded with tfds.load below.
builder = tfds.builder('mnist')
BATCH_SIZE = 256  # examples per batch for both the train and test pipelines
NUM_EPOCHS = 12  # epochs per train_and_evaluate call
NUM_CLASSES = 10  # 10 total classes.

# Data loading

In [None]:
# Load train and test splits.
def preprocess(x):
  """Map a raw MNIST feature dict to an ``(image, one_hot_label)`` pair.

  Args:
    x: Feature dict with an ``'image'`` entry (presumably uint8 pixels) and an
      ``'label'`` entry (integer class id).

  Returns:
    The image converted to float32 via ``tf.image.convert_image_dtype`` and
    the label one-hot encoded over all MNIST label classes.
  """
  depth = builder.info.features['label'].num_classes
  pixels = tf.image.convert_image_dtype(x['image'], tf.float32)
  return pixels, tf.one_hot(x['label'], depth)


# Training pipeline: cache raw examples, preprocess, shuffle with a buffer as
# large as the split, then batch (the partial last batch is dropped).
mnist_train = (
    tfds.load('mnist', split='train', shuffle_files=False)
    .cache()
    .map(preprocess)
    .shuffle(builder.info.splits['train'].num_examples)
    .batch(BATCH_SIZE, drop_remainder=True)
)

# Test pipeline: same preprocessing, no shuffling, and the partial last batch
# is kept so every test example is evaluated.
mnist_test = (
    tfds.load('mnist', split='test')
    .cache()
    .map(preprocess)
    .batch(BATCH_SIZE)
)

In [None]:
# Show the training dataset object; its printed repr describes the batched
# (image, label) element structure produced by the pipeline above.
print(mnist_train)

# Model creation

In [None]:
#@test {"output": "ignore"}

# Build CNN teacher.
def create_teacher_model():
    """Build the CNN teacher for 28x28x1 MNIST images.

    Architecture: two 3x3 'same' convolutions (32 and 64 filters, ReLU), each
    followed by 2x2 max pooling (the first with stride 1, so it does not
    downsample), then Flatten -> Dropout(0.5) -> Dense(128, ReLU) ->
    Dropout(0.5) -> Dense(NUM_CLASSES).

    Returns:
        An uncompiled tf.keras.Sequential whose output is unnormalized logits
        (the final Dense layer has no activation).
    """
    cnn_model = tf.keras.Sequential()

    # your code start from here for step 2
    # Only the first layer needs `input_shape`; Sequential ignores it on later
    # layers, so it is not repeated on the second convolution.
    cnn_model.add(tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=1, padding='same', activation='relu', input_shape=(28, 28, 1)))
    cnn_model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=1, padding='same'))
    cnn_model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=1, padding='same', activation='relu'))
    cnn_model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2, padding='same'))
    cnn_model.add(tf.keras.layers.Flatten())
    cnn_model.add(tf.keras.layers.Dropout(0.5))
    cnn_model.add(tf.keras.layers.Dense(128, activation='relu'))
    cnn_model.add(tf.keras.layers.Dropout(0.5))
    cnn_model.add(tf.keras.layers.Dense(NUM_CLASSES))

    return cnn_model

# Build fully connected student.
def create_student_model():
    """Build the fully connected student: Flatten -> 784 ReLU -> 784 ReLU -> logits.

    `input_shape` is declared on the first layer (as in the teacher) so the
    model is built at construction time. `summary()` and FLOP counting then
    work without first calling the model on data.

    Returns:
        An uncompiled tf.keras.Sequential that outputs NUM_CLASSES
        unnormalized logits.
    """
    fc_model = tf.keras.Sequential()

    # your code start from here for step 2
    fc_model.add(tf.keras.layers.Flatten(input_shape=(28, 28, 1)))
    fc_model.add(tf.keras.layers.Dense(784, activation='relu'))
    fc_model.add(tf.keras.layers.Dense(784, activation='relu'))
    fc_model.add(tf.keras.layers.Dense(NUM_CLASSES))

    return fc_model


# Teacher loss function

In [None]:
@tf.function
def compute_teacher_loss(model, images, labels):
  """Compute the cross-entropy training loss for the teacher.

  Args:
    model: Instance of tf.keras.Model that outputs logits.
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    Scalar loss Tensor: the mean softmax cross-entropy over the batch.
  """
  subclass_logits = model(images, training=True)

  # Compute cross-entropy loss for subclasses.
  # your code start from here for step 3
  per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
      labels=labels, logits=subclass_logits)

  # The op returns one value per example; average it so the result is the
  # documented scalar.
  cross_entropy_loss_value = tf.reduce_mean(per_example_loss)

  return cross_entropy_loss_value

# Student loss function

In [None]:
#@test {"output": "ignore"}

# Hyperparameters for distillation (need to be tuned).
ALPHA = 0.5 # weight of the hard-label cross-entropy; (1 - ALPHA) weights the distillation loss
DISTILLATION_TEMPERATURE = 4. # temperature dividing teacher and student logits in distillation_loss

def distillation_loss(teacher_logits: tf.Tensor, student_logits: tf.Tensor,
                      temperature: Union[float, tf.Tensor]):
  """Return the temperature-scaled distillation loss of Hinton et al. (2014).

  Both sets of logits are divided by `temperature`. The teacher's are turned
  into soft targets with a softmax, and the batch-mean cross-entropy between
  those targets and the softened student logits is multiplied by
  `temperature ** 2`. The factor keeps gradient magnitudes roughly
  independent of the temperature ("Distilling the knowledge in a neural
  network").

  Args:
    teacher_logits: Tensor of logits produced by the teacher.
    student_logits: Tensor of logits produced by the student; same shape as
      `teacher_logits`.
    temperature: Softening temperature.

  Returns:
    Scalar Tensor with the distillation loss.
  """
  # your code start from here for step 3
  teacher_probs = tf.nn.softmax(teacher_logits / temperature)
  softened_student_logits = student_logits / temperature

  per_example_ce = tf.nn.softmax_cross_entropy_with_logits(
      labels=teacher_probs, logits=softened_student_logits)

  return tf.reduce_mean(per_example_ce) * temperature ** 2

def compute_student_loss(student_model, teacher_model, images, labels):
  """Compute the knowledge-distillation training loss for the student.

  The loss is ``ALPHA * CE(labels, student) + (1 - ALPHA) * KD``, where KD is
  `distillation_loss` between teacher and student logits at
  DISTILLATION_TEMPERATURE. Both hyperparameters are module-level globals,
  read on every call. This function is not wrapped in `tf.function`, so
  reassigning them between runs takes effect.

  Args:
    student_model: tf.keras.Model being trained.
    teacher_model: tf.keras.Model providing the soft targets.
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    Scalar loss Tensor.
  """
  student_subclass_logits = student_model(images, training=True)

  # Compute subclass distillation loss between student subclass logits and
  # softened teacher subclass targets probabilities. The teacher runs in
  # inference mode, and stop_gradient makes explicit that it is not trained
  # through this loss.
  # your code start from here for step 3
  teacher_subclass_logits = tf.stop_gradient(
      teacher_model(images, training=False))
  distillation_loss_value = distillation_loss(
      teacher_subclass_logits, student_subclass_logits,
      DISTILLATION_TEMPERATURE)

  # Compute cross-entropy loss with hard targets, averaged over the batch so
  # both terms are scalars (distillation_loss already returns a batch mean).
  # your code start from here for step 3
  cross_entropy_loss_value = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels=labels, logits=student_subclass_logits))

  return ALPHA * cross_entropy_loss_value + (1 - ALPHA) * distillation_loss_value

# Train and evaluation

In [None]:
from six import with_metaclass
@tf.function
def compute_num_correct(model, images, labels):
  """Count correctly classified images in a batch.

  Args:
    model: Instance of tf.keras.Model.
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    A tuple ``(num_correct, predicted_classes, true_classes)``:
    - num_correct: float32 scalar count of correct predictions.
    - predicted_classes: argmax class index of the model output per example.
    - true_classes: argmax class index of `labels` per example.
  """
  class_logits = model(images, training=False)
  # Compute each argmax once and reuse it for both the count and the return.
  predicted_classes = tf.argmax(class_logits, -1)
  true_classes = tf.argmax(labels, -1)
  num_correct = tf.reduce_sum(
      tf.cast(tf.math.equal(predicted_classes, true_classes), tf.float32))
  return num_correct, predicted_classes, true_classes


def train_and_evaluate(model, compute_loss_fn, with_kd=False, teacher_model=None):
  """Train `model` on `mnist_train` and report test accuracy after each epoch.

  Args:
    model: Instance of tf.keras.Model to train.
    compute_loss_fn: Loss function. It is called as
      ``compute_loss_fn(model, images, labels)`` when `with_kd` is False, and
      as ``compute_loss_fn(model, teacher_model, images, labels)`` when True.
    with_kd: Whether to train with knowledge distillation from `teacher_model`.
    teacher_model: Teacher used for distillation; required when `with_kd`.

  Returns:
    The best per-epoch test accuracy, in percent, as a Python float.

  Raises:
    ValueError: If `with_kd` is True but `teacher_model` is None.
  """
  if with_kd and teacher_model is None:
    raise ValueError('teacher_model is required when with_kd=True.')

  # your code start from here for step 4
  optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
  num_total = builder.info.splits['test'].num_examples
  best_accuracy = 0.0

  for epoch in range(1, NUM_EPOCHS + 1):
    # Run training.
    print('Epoch {}: '.format(epoch), end='')
    for images, labels in mnist_train:
      with tf.GradientTape() as tape:
        if with_kd:
          loss_value = compute_loss_fn(model, teacher_model, images, labels)
        else:
          loss_value = compute_loss_fn(model, images, labels)

      grads = tape.gradient(loss_value, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Run evaluation over the whole test split (mnist_test keeps its partial
    # last batch, so every test example is counted).
    num_correct = 0.0
    for images, labels in mnist_test:
      batch_correct, _, _ = compute_num_correct(model, images, labels)
      num_correct += float(batch_correct)

    accuracy = num_correct / num_total * 100
    print("Class_accuracy: " + '{:.2f}%'.format(accuracy))
    best_accuracy = max(best_accuracy, accuracy)

  return best_accuracy


# Training models

In [None]:
# your code start from here for step 5
# Try two distillation settings in turn, each with a freshly built teacher and
# student. Using ALPHA / DISTILLATION_TEMPERATURE as the loop targets rebinds
# the module-level globals that compute_student_loss reads on every call.
# NOTE(review): both settings save to the same paths, so the second run
# overwrites the models saved by the first.
for ALPHA, DISTILLATION_TEMPERATURE in ((0.5, 4.), (0.9, 8.)):
  cnn_model = create_teacher_model()
  fc_model = create_student_model()

  train_and_evaluate(cnn_model, compute_teacher_loss, with_kd=False, teacher_model=None)
  cnn_model.save("task1_teacher_model")
  train_and_evaluate(fc_model, compute_student_loss, with_kd=True, teacher_model=cnn_model)
  fc_model.save("task1_student_model")

In [None]:
# Intentionally empty cell.

# Test accuracy vs. temperature curve

In [None]:
# your code start from here for step 6
import matplotlib
from matplotlib import pyplot as plt

ALPHA = 0.5 # task balance between cross-entropy and distillation loss
DISTILLATION_TEMPERATURE_LIST = [1,2,4,16,32,64] #temperature hyperparameter list
accuracy_list = []

for temp_single in DISTILLATION_TEMPERATURE_LIST:

  print('current tempreature is ', temp_single)
  # compute_student_loss reads this global on every call.
  DISTILLATION_TEMPERATURE = temp_single
  # Start each temperature from a freshly initialised student so the runs are
  # comparable. Reusing fc_model would keep training the previous run's weights.
  fc_model = create_student_model()
  best_accuracy = train_and_evaluate(fc_model, compute_student_loss, with_kd=True, teacher_model=cnn_model)

  # train_and_evaluate already evaluated and printed the test accuracy after
  # every epoch; keep the best one for the plot.
  accuracy_list.append(float(best_accuracy))


plt.figure(figsize=(6,6))
plt.xlabel("temperature hyperparameter list")
plt.ylabel("accuracy")
plt.plot(DISTILLATION_TEMPERATURE_LIST, accuracy_list, color='r', linewidth=1.0, linestyle='--', marker='o')
plt.show()

# Train student from scratch

In [None]:
# Build fully connected student.
# your code start from here for step 7
# Reuse the student factory so the no-distillation baseline has exactly the
# same architecture as the distilled student.
fc_model_no_distillation = create_student_model()


#@test {"output": "ignore"}

def compute_plain_cross_entropy_loss(model, images, labels):
  """Compute the plain (no distillation) cross-entropy loss.

  Used to train the student from scratch as a baseline. It is the same
  hard-label cross-entropy term used in `compute_student_loss`, without any
  teacher.

  Args:
    model: tf.keras.Model being trained.
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.

  Returns:
    Scalar loss Tensor: the mean softmax cross-entropy over the batch.
  """
  # your code start from here for step 7
  student_subclass_logits = model(images, training=True)
  cross_entropy_loss = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels=labels, logits=student_subclass_logits))

  return cross_entropy_loss

# Baseline: train the student architecture on hard labels only (no teacher).
train_and_evaluate(fc_model_no_distillation, compute_plain_cross_entropy_loss, with_kd=False, teacher_model=None)

# Comparing the teacher and student models (number of parameters and FLOPs)

In [None]:
# your code start from here for step 8
#!pip install keras_flops

from keras_flops import get_flops

# FLOPs of one forward pass with batch_size=1 for each model.
# NOTE(review): get_flops presumably needs each model to be built (have known
# inputs); all three have been trained by this point.
teacher_model_flops = get_flops(cnn_model, batch_size=1)
student_model_flops = get_flops(fc_model, batch_size=1)
student_without_KD_model_flops = get_flops(fc_model_no_distillation, batch_size=1)


print('The flops of teacher model is',teacher_model_flops)
print('The flops of student model is',student_model_flops)
print('The flops of fc_model_no_distillation model is',student_without_KD_model_flops)

# Layer-by-layer architecture and parameter counts.
cnn_model.summary()
fc_model.summary()
fc_model_no_distillation.summary()


# Implementing the state-of-the-art KD algorithm

In [None]:
# Distillation hyperparameters for the teacher-assistant experiment below;
# compute_ta_or_student_loss reads these globals on every call.
ALPHA = 0.5 # task balance between cross-entropy and distillation loss
DISTILLATION_TEMPERATURE = 4. #temperature hyperparameter

# your code start from here for step 12
# Seyed Iman Mirzadeh, Mehrdad Farajtabar, Ang Li, Nir Levine, Akihiro Matsukawa, and Hassan Ghasemzadeh. 
# Improved knowledge distillation via teacher assistant. 
# In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pages 5191–5198, 2020. 
# https://ojs.aaai.org/index.php/AAAI/article/view/5963/5819

# Build Teacher assistant: the teacher's first conv block (32 filters + stride-1
# pooling) followed by the student's two 784-unit dense layers, dropout and a
# logits layer.
ta_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=1, padding='same',
                           activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=1, padding='same'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(784, activation='relu'),
    tf.keras.layers.Dense(784, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(NUM_CLASSES),
])

def compute_ta_or_student_loss(images, labels, learn_model, teach_model):
  """Compute the KD loss for a learner distilled from a teacher.

  Used both for teacher -> teacher-assistant and for teacher-assistant ->
  student. The loss is ``ALPHA * CE(labels, learner) + (1 - ALPHA) * KD``,
  where KD is `distillation_loss` at DISTILLATION_TEMPERATURE. Both are
  module-level globals read on every call.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of one-hot labels.
    learn_model: tf.keras.Model being trained.
    teach_model: tf.keras.Model providing the soft targets.

  Returns:
    Scalar loss Tensor.
  """
  student_subclass_logits = learn_model(images, training=True)

  # Compute subclass distillation loss between student subclass logits and
  # softened teacher subclass targets probabilities. The teacher runs in
  # inference mode and is explicitly excluded from differentiation.
  # your code start from here for step 3
  teacher_subclass_logits = tf.stop_gradient(
      teach_model(images, training=False))
  distillation_loss_value = distillation_loss(
      teacher_subclass_logits, student_subclass_logits,
      DISTILLATION_TEMPERATURE)

  # Compute cross-entropy loss with hard targets, averaged over the batch so
  # both terms are scalars.
  # your code start from here for step 3
  cross_entropy_loss_value = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels=labels, logits=student_subclass_logits))

  return ALPHA * cross_entropy_loss_value + (1 - ALPHA) * distillation_loss_value

def train_and_evaluate_state_of_the_art(model, teach_model):
  """Train `model` by distilling from `teach_model` (teacher-assistant KD step).

  This delegates to `train_and_evaluate` with `compute_ta_or_student_loss` as
  the loss. The optimizer, epoch loop and printed per-epoch test accuracy are
  therefore shared with the other experiments instead of duplicated.

  Args:
    model: tf.keras.Model being trained (the teacher assistant or the student).
    teach_model: tf.keras.Model providing the soft targets.

  Returns:
    The best per-epoch test accuracy returned by `train_and_evaluate`.
  """
  # your code start from here for step 4
  def _kd_loss(learn_model, teacher, images, labels):
    # Adapt to the (model, teacher, images, labels) argument order that
    # train_and_evaluate uses when with_kd=True.
    return compute_ta_or_student_loss(images, labels, learn_model, teacher)

  return train_and_evaluate(model, _kd_loss, with_kd=True, teacher_model=teach_model)

# train teacher assistant model (distilled from the CNN teacher)
train_and_evaluate_state_of_the_art(ta_model, teach_model=cnn_model)
ta_model.save("task1_ta_model")

# train student model from teacher assistant model, then save the student
# (previously the teacher assistant was saved under the student's path).
# NOTE(review): fc_model is the student already trained in earlier cells, so
# this continues training it; use create_student_model() for a fresh TA->student run.
train_and_evaluate_state_of_the_art(fc_model, teach_model=ta_model)
fc_model.save("task1_student_from_ta_model")

In [None]:
# Layer-by-layer architecture and parameter count of the teacher assistant.
ta_model.summary()

# (Optional) XAI method to explain models

In [None]:
# your code start from here for step 13
!pip install pillow
from PIL import Image, ImageDraw
def build_mask_randomly(h=7, w=7, H=224, W=224, p_1=0.5, resample=Image.BILINEAR):
    """Generate one random, smoothed RISE mask of shape ``(H, W)``.

    A coarse ``h x w`` binary grid (each cell is 1 with probability ``p_1``) is
    upsampled with `resample` to ``(H + h) x (W + w)``. It is then cropped back
    to ``H x W`` at a random offset, which shifts the grid by a random
    sub-cell amount.

    Args:
        h, w: Height and width of the coarse random grid.
        H, W: Height and width of the returned mask; must exceed h and w.
        p_1: Probability that a grid cell is kept (set to 1).
        resample: PIL resampling filter used for upsampling.

    Returns:
        Array of shape ``(H, W)`` scaled so its maximum is 1. If no grid cell
        was kept, the mask is all zeros instead of NaN.

    Raises:
        ValueError: If the grid is not strictly smaller than the output mask.
    """
    if H <= h or W <= w:
        raise ValueError('Masks should be higher dimensions.')
    grid = np.random.choice([0, 1], size=(h, w), p=[1 - p_1, p_1])

    # Upsample. PIL sizes are (width, height) while numpy shapes are
    # (rows, cols), so pass (W + w, H + h) to get an (H + h, W + w) array.
    upsampled = Image.fromarray(grid * 255.)
    upsampled = upsampled.resize((W + w, H + h), resample=resample)
    mask = np.array(upsampled)

    # Randomly crop the mask back to H x W.
    w_crop = np.random.randint(0, w + 1)
    h_crop = np.random.randint(0, h + 1)
    mask = mask[h_crop:H + h_crop, w_crop:W + w_crop]

    # Normalize to [0, 1]; an all-zero grid would otherwise give 0/0 = NaN.
    peak = np.max(mask)
    if peak > 0:
        mask = mask / peak
    return mask

def RISE(img, model, class_index, N_MASKS=8000, H=28, W=28, C=1, from_logits=True):
    """Compute a RISE saliency map (Petsiuk et al., 2018) for one image.

    The image is multiplied by N_MASKS random smooth masks from
    `build_mask_randomly`, and the model scores every masked copy. The masks
    are then summed, each weighted by the probability of `class_index` for its
    masked copy.

    Args:
        img: Array holding H*W*C values (e.g. shape (H, W, C)).
        model: Keras model; ``model.predict`` must return per-class scores of
            shape (N_MASKS, num_classes).
        class_index: Class whose probability weights the masks.
        N_MASKS: Number of random masks to sample.
        H, W, C: Image height, width and channel count.
        from_logits: If True (default), the scores are treated as logits and
            converted to probabilities with a softmax before weighting. The
            models in this notebook end in a Dense layer without activation.
            Pass False for a model that already outputs probabilities.

    Returns:
        float32 array of shape (H, W), min-max normalized to [0, 1] (all
        zeros if the map is constant).
    """
    masks = np.zeros((N_MASKS, H, W), dtype=np.float32)
    for i in range(N_MASKS):
        masks[i] = build_mask_randomly(H=H, W=W)

    # Apply every mask to every channel in one broadcast:
    # (1, H, W, C) * (N_MASKS, H, W, 1) -> (N_MASKS, H, W, C).
    image = np.asarray(img, dtype=np.float32).reshape(H, W, C)
    X = image[np.newaxis] * masks[..., np.newaxis]

    scores = np.asarray(model.predict(X, verbose=0), dtype=np.float32)
    if from_logits:
        # Numerically stable softmax over the class axis; RISE weights masks
        # by class probability, not by raw (possibly negative) logits.
        scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
        scores /= scores.sum(axis=-1, keepdims=True)

    # Weighted sum over masks: sum_i scores[i, class_index] * masks[i].
    sum_mask = np.tensordot(scores[:, class_index], masks, axes=1)

    sum_mask -= np.min(sum_mask)
    peak = np.max(sum_mask)
    if peak > 0:
        sum_mask /= peak
    return sum_mask

In [None]:
import numpy as np

one_sample, = mnist_test.take(1)
images, labels = one_sample

import random
random.seed(99)
# Pick one example of the first test batch (valid indices 0 .. len(images) - 1).
number = random.randint(0, len(images) - 1)
# Keep the channel axis for the models (they were built for 28x28x1 input);
# squeeze it only for display.
sample_image = images[number]
test_image = tf.squeeze(sample_image)
test_label = np.argmax(labels[number])
# plt.figure(figsize=(12,12))
plt.imshow(test_image)
plt.axis("off")
plt.title("The inference of the image label is: {}".format(test_label))
plt.show()

# Predicted class of each model for the sample. argmax of the logits equals
# argmax of the softmax probabilities, so no softmax is needed.
sample_batch = tf.expand_dims(sample_image, axis=0)
teacher_pre = np.argmax(cnn_model(sample_batch, training=False))
student_with_KD_pre = np.argmax(fc_model(sample_batch, training=False))
student_without_KD_pre = np.argmax(fc_model_no_distillation(sample_batch, training=False))
print("teacher_pre:",teacher_pre,"\nstudent_with_KD_pre:",student_with_KD_pre,"\nstudent_without_KD_pre",student_without_KD_pre)

In [None]:
# RISE saliency maps for the true label of the selected test digit, one per
# model. RISE already min-max normalizes its output; the extra 1e-10 here only
# guards the division against a constant map.
# NOTE: the masks are drawn from np.random, which random.seed(99) above does
# not seed, so these maps differ between runs.
rise_teacher = RISE(images[number].numpy(), cnn_model, class_index=test_label, N_MASKS=8000)
rise_teacher -= rise_teacher.min()
rise_teacher /= rise_teacher.max()+1e-10
rise_student_with_KD = RISE(images[number].numpy(), fc_model, class_index=test_label, N_MASKS=8000)
rise_student_with_KD -= rise_student_with_KD.min()
rise_student_with_KD /= rise_student_with_KD.max()+1e-10
rise_student_without_KD = RISE(images[number].numpy(), fc_model_no_distillation, class_index=test_label, N_MASKS=8000)
rise_student_without_KD -= rise_student_without_KD.min()
rise_student_without_KD /= rise_student_without_KD.max()+1e-10

In [None]:
# Plot the three explanation maps, one row per model: the original digit on
# the left and that model's RISE saliency map on the right.
explanation_panels = [
    ('Teacher model map', rise_teacher),
    ('Student model with KD map', rise_student_with_KD),
    ('Student without KD map', rise_student_without_KD),
]

plt.figure(figsize=(8,8))
for row, (map_title, saliency_map) in enumerate(explanation_panels):
    plt.subplot(3, 2, 2 * row + 1)
    plt.title("origin label: {}".format(test_label))
    plt.imshow(test_image)
    plt.axis('off')

    plt.subplot(3, 2, 2 * row + 2)
    plt.title(map_title)
    plt.imshow(saliency_map, cmap='jet', alpha=0.5)
    plt.axis('off')

plt.show()

In [None]:
# Overlay each model's RISE map (semi-transparent 'jet' colormap) on the
# original digit, one column per model.
plt.figure(figsize=(12,12))
plt.subplot(1,3,1)
plt.title('Teacher')
plt.imshow(test_image)
plt.imshow(rise_teacher, cmap='jet', alpha=0.5)
plt.axis('off')

plt.subplot(1,3,2)
plt.title('Student model with KD')
plt.imshow(test_image)
plt.imshow(rise_student_with_KD, cmap='jet', alpha=0.5)
plt.axis('off')

plt.subplot(1,3,3)
plt.title('Student model without KD')
plt.imshow(test_image)
plt.imshow(rise_student_without_KD, cmap='jet', alpha=0.5)
plt.axis('off')

plt.show()