In [None]:
import tensorflow as tf
import numpy as np

In [None]:
class Distiller(tf.keras.Model):
  """Keras model that trains a student network by knowledge distillation.

  Each training step minimizes

      alpha * student_loss_fn(y, student_logits)
      + (1 - alpha) * T**2 * distillation_loss_fn(softmax(teacher_logits / T),
                                                  softmax(student_logits / T))

  where T is `temperature`. Only the student's trainable variables receive
  gradients; the teacher is evaluated with `training=False` outside the
  gradient tape and is never updated.
  """

  def __init__(self, student, teacher):
    """Wraps a student and a teacher model.

    Args:
      student: the model being trained. Its outputs go to `student_loss_fn`
        and to a softmax, so it presumably emits unnormalized logits --
        confirm against the losses passed to `compile`.
      teacher: the model whose softened outputs the student learns to match.
        Only called with `training=False`; its weights are not updated here.
    """
    super(Distiller, self).__init__()
    self.teacher = teacher
    self.student = student

  def compile(
      self,
      optimizer,
      metrics,
      student_loss_fn,
      distillation_loss_fn,
      alpha=0.1,
      temperature=0.3
  ):
    """Configures the optimizer, metrics and the distillation objective.

    Args:
      optimizer: optimizer applied to the student's trainable variables.
      metrics: metrics updated from `(y, student outputs)` in train/test steps.
      student_loss_fn: loss between ground-truth labels and student outputs.
      distillation_loss_fn: loss called as
        `distillation_loss_fn(teacher_probs, student_probs)` on the
        temperature-softened distributions.
      alpha: weight of the student (hard-label) loss; the distillation loss
        is weighted by `1 - alpha`.
      temperature: divisor applied to both models' logits before the softmax.
        NOTE(review): values below 1 sharpen rather than soften the
        distributions; the 0.3 default may have been meant as 3 -- confirm.
    """
    super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
    self.student_loss_fn = student_loss_fn
    self.distillation_loss_fn = distillation_loss_fn
    self.alpha = alpha
    self.temperature = temperature

  def call(self, inputs, training=False):
    """Forward pass for inference: delegates to the student.

    Gives `distiller(x)` / `distiller.predict(x)` a concrete forward pass;
    training and evaluation still go through `train_step` / `test_step`.
    """
    return self.student(inputs, training=training)

  def train_step(self, data):
    """Runs one distillation update on a batch.

    Args:
      data: an `(x, y)` batch as delivered by `fit`. Exactly two elements are
        unpacked, so batches that also carry sample weights are not supported.

    Returns:
      dict of every tracked metric's current value plus this batch's
      `student_loss` and temperature-scaled `distillation_loss`.
    """
    x, y = data
    # Teacher only supplies targets: inference mode, outside the tape.
    teacher_predictions = self.teacher(x, training=False)
    with tf.GradientTape() as tape:
      student_predictions = self.student(x, training=True)
      student_loss = self.student_loss_fn(y, student_predictions)
      # Soften both distributions with the same temperature. The softmax is
      # taken over the last axis, assumed to be the class axis (identical to
      # axis=1 for the (batch, classes) logits built in this notebook).
      distillation_loss = self.distillation_loss_fn(
          tf.nn.softmax(teacher_predictions / self.temperature, axis=-1),
          tf.nn.softmax(student_predictions / self.temperature, axis=-1)
      )
      # Soft-target gradients shrink as 1/T**2 (Hinton et al., 2015); rescale
      # by T**2 so the hard/soft balance chosen via `alpha` does not silently
      # depend on the temperature.
      distillation_loss = distillation_loss * (self.temperature ** 2)
      loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

    # Compute gradients for the student only.
    trainable_vars = self.student.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)

    # Update weights.
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

    # Update metrics against the hard labels.
    self.compiled_metrics.update_state(y, student_predictions)

    results = {m.name: m.result() for m in self.metrics}
    results.update(
        {"student_loss": student_loss, "distillation_loss": distillation_loss}
    )
    return results

  def test_step(self, data):
    """Evaluates the student alone on a batch; the teacher is not consulted.

    Args:
      data: an `(x, y)` batch.

    Returns:
      dict of every tracked metric's current value plus this batch's
      `student_loss`.
    """
    x, y = data
    y_prediction = self.student(x, training=False)
    student_loss = self.student_loss_fn(y, y_prediction)
    self.compiled_metrics.update_state(y, y_prediction)
    results = {m.name: m.result() for m in self.metrics}
    results.update({"student_loss": student_loss})
    return results


In [None]:
# Seed NumPy and TensorFlow before any weights are created so that weight
# initialization and other TF/NumPy randomness repeat across
# "Restart & Run All". (GPU kernels may still be nondeterministic.)
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Teacher: a wide CNN (256/512 filters) trained normally, then used as the
# source of soft targets for the student.
teacher = tf.keras.Sequential([
  tf.keras.Input(shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2), padding='same'),
  tf.keras.layers.LeakyReLU(alpha=0.2),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same'),
  tf.keras.layers.Conv2D(512, (3, 3), strides=(2, 2), padding='same'),
  tf.keras.layers.Flatten(),
  # Raw logits (no softmax): the compiled loss uses from_logits=True and the
  # Distiller applies its own temperature softmax.
  tf.keras.layers.Dense(10)
], name='teacher')

In [None]:
# Student: the teacher's layer layout with 16x fewer filters per conv layer
# (16/32 instead of 256/512). Like the teacher, it outputs raw logits.
student = tf.keras.Sequential(
    layers=[
        tf.keras.layers.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding='same'),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same'),
        tf.keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding='same'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ],
    name='student',
)

In [None]:
# Same architecture as `student`, rebuilt as a separate model (Keras documents
# clone_model as re-initializing weights rather than copying them). It is
# trained without a teacher later as the no-distillation baseline.
student_scratch = tf.keras.models.clone_model(model=student)

In [None]:
batch_size = 64  # mini-batch size intended for fit(batch_size=...)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()


def _to_model_input(images):
  """Rescales 8-bit pixel values to float32 in [0, 1] and adds a channel axis.

  Assumes `images` holds 28x28 images with values in 0..255 (as returned by
  the MNIST loader above); output shape is (n, 28, 28, 1).
  """
  return np.reshape(images.astype('float32') / 255.0, (-1, 28, 28, 1))


# Apply the identical preprocessing to both splits.
x_train = _to_model_input(x_train)
x_test = _to_model_input(x_test)

# Cheap invariants: catch a wrong reshape, a label/image mismatch, or a
# missed rescale before any training starts.
assert x_train.shape[1:] == x_test.shape[1:] == (28, 28, 1)
assert len(x_train) == len(y_train) and len(x_test) == len(y_test)
assert 0.0 <= x_train.min() and x_train.max() <= 1.0

In [None]:
# Standard supervised setup for the teacher: its Dense(10) head emits logits,
# hence from_logits=True.
teacher.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In [None]:
# Use the batch size declared with the data (otherwise it is dead config and
# fit() silently falls back to its own default); `;` hides the History repr.
teacher.fit(x_train, y_train, batch_size=batch_size, epochs=5);

In [None]:
# Held-out loss and accuracy of the teacher.
teacher.evaluate(x=x_test, y=y_test)

In [None]:
# Wrap the untrained student with the trained teacher. alpha=0.1 weights the
# hard-label loss at 10% and the soft-target (teacher-matching) loss at 90%;
# temperature=10 softens both models' distributions before comparing them.
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=tf.keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

In [None]:
# Same declared batch size as the other training runs so the comparison is
# like-for-like; `;` hides the History repr.
distiller.fit(x_train, y_train, batch_size=batch_size, epochs=3);

In [None]:
# Held-out metrics of the distilled student (Distiller.test_step uses the student only).
distiller.evaluate(x=x_test, y=y_test)

In [None]:
# Baseline: identical optimizer, loss and metric to the teacher's setup, so the
# only difference from the distilled student is the absence of a teacher.
student_scratch.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In [None]:
# Same epochs and declared batch size as the distiller run; `;` hides the History repr.
student_scratch.fit(x_train, y_train, batch_size=batch_size, epochs=3);

In [None]:
# Held-out metrics of the baseline student, for comparison with the distilled one.
student_scratch.evaluate(x=x_test, y=y_test)