<a href="https://colab.research.google.com/github/sayakpaul/Knowledge-Distillation-in-Keras/blob/master/Distillation_Toy_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Imports
import tensorflow as tf

from tensorflow.keras import models
from tensorflow.keras import layers

# Set TensorFlow's global random seed up front (for more reproducible runs)
tf.random.set_seed(666)

In [2]:
# Load the FashionMNIST train/test splits as (images, labels) pairs
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# Scale the pixel values by 1/255 (both image arrays at once)
X_train, X_test = X_train / 255., X_test / 255.

# Display the shapes of all four arrays as the cell output
X_train.shape, X_test.shape, y_train.shape, y_test.shape

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


((60000, 28, 28), (10000, 28, 28), (60000,), (10000,))

In [3]:
# Add a trailing channel axis, (N, 28, 28) -> (N, 28, 28, 1), and store the pixels as float32
X_train = X_train.reshape(-1, 28, 28, 1).astype("float32")
X_test = X_test.reshape(-1, 28, 28, 1).astype("float32")

In [4]:
# Teacher network: a basic shallow convnet
def get_teacher_model():
    """Build the (uncompiled) teacher convnet.

    Architecture: two Conv2D (5x5, ReLU) + MaxPooling2D (2x2) stages with 16
    and 32 filters, Dropout(0.2), Flatten, a 128-unit ReLU dense layer and a
    final 10-unit dense layer with no activation, so the model outputs logits.

    Returns:
        A `models.Sequential` whose first layer declares an input shape of
        (28, 28, 1).
    """
    return models.Sequential([
        layers.Conv2D(16, (5, 5), activation="relu", input_shape=(28, 28, 1)),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(32, (5, 5), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.2),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dense(10),  # no activation: raw logits
    ])

In [5]:
# Define the optimizer and the loss function
optimizer = tf.keras.optimizers.Adam()
# from_logits=True: the loss is given raw logits together with integer class labels
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [6]:
# Prepare TF datasets
# The shuffle buffer spans the whole training set: tf.data only produces a
# uniform shuffle when the buffer is at least as large as the dataset, whereas
# a buffer of 100 merely mixes neighbouring samples.
train_ds = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(len(X_train))
    .batch(64)
)
# The test split is evaluated in its stored order, so it is only batched
test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(64)

# Train the teacher model
teacher_model = get_teacher_model()
teacher_model.compile(loss=loss_func, optimizer=optimizer, metrics=["accuracy"])
teacher_model.fit(train_ds,
                  validation_data=test_ds,
                  epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f161017ddd8>

In [7]:
# Evaluate the teacher on the test set, report its accuracy in percent, then serialize it
teacher_eval = teacher_model.evaluate(test_ds)  # index 1 holds the "accuracy" metric
print("Test accuracy: {:.2f}".format(teacher_eval[1]*100))
teacher_model.save("teacher_model.h5")

Test accuracy: 90.25


In [8]:
# Student network: a single-hidden-layer MLP
def get_student_model():
    """Build the (uncompiled) student model.

    The (28, 28, 1) input is flattened and passed through a 784-unit ReLU
    dense layer followed by a 10-unit dense layer with no activation, so the
    model outputs logits.

    Returns:
        A `models.Sequential` instance.
    """
    return models.Sequential([
        layers.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(784, activation="relu"),
        layers.Dense(10),  # no activation: raw logits
    ])

In [9]:
# Credits: https://github.com/google-research/simclr/blob/master/colabs/distillation_self_training.ipynb
def get_kd_loss(student_logits, teacher_logits, temperature=0.5):
    """Knowledge-distillation loss between student and teacher logits.

    Both sets of logits are divided by `temperature`. The teacher's scaled
    logits are turned into probabilities with a softmax and used as the
    target distribution for a softmax cross-entropy against the student's
    scaled logits. The result is weighted by `temperature**2`.

    Args:
        student_logits: Raw (pre-softmax) outputs of the student.
        teacher_logits: Raw (pre-softmax) outputs of the teacher.
        temperature: Divisor applied to both sets of logits.

    Returns:
        The loss tensor returned by
        `tf.compat.v1.losses.softmax_cross_entropy`.
    """
    soft_targets = tf.nn.softmax(teacher_logits / temperature)
    scaled_student_logits = student_logits / temperature
    return tf.compat.v1.losses.softmax_cross_entropy(
        onehot_labels=soft_targets,
        logits=scaled_student_logits,
        weights=temperature**2,
    )

In [10]:
# Student model and a fresh Adam optimizer (rebinds the name `optimizer`)
student_model = get_student_model()
optimizer = tf.keras.optimizers.Adam()

# Running means of the per-batch losses; reset at the end of every epoch.
# Note that the validation loss metric carries the label "test_loss".
train_loss, valid_loss = (
    tf.keras.metrics.Mean(name="train_loss"),
    tf.keras.metrics.Mean(name="test_loss"),
)

# Performance metric: accuracy against integer class labels
train_acc, valid_acc = (
    tf.keras.metrics.SparseCategoricalAccuracy(name="train_acc"),
    tf.keras.metrics.SparseCategoricalAccuracy(name="valid_acc"),
)

In [11]:
# Train the model
@tf.function
def model_train(images, labels):
    """Run one distillation training step on a batch.

    The student is optimized to match the teacher's temperature-softened
    outputs via `get_kd_loss`. `labels` are not part of the loss; they are
    only used to track the student's accuracy.

    Args:
        images: Batch of input images fed to both the teacher and the student.
        labels: Integer class labels for `images`.

    Side effects:
        Applies one optimizer update to `student_model` and updates the
        `train_loss` and `train_acc` metrics.
    """
    # The teacher only supplies targets: run it explicitly in inference mode
    # (so its Dropout layer is disabled) and outside the gradient tape.
    teacher_logits = teacher_model(images, training=False)

    with tf.GradientTape() as tape:
        # Explicit training mode for the model being optimized
        student_logits = student_model(images, training=True)
        loss = get_kd_loss(student_logits, teacher_logits, temperature=0.5)

    # Only the student's weights receive gradients
    gradients = tape.gradient(loss, student_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, student_model.trainable_variables))

    train_loss(loss)
    train_acc(labels, tf.nn.softmax(student_logits))

In [12]:
# Validating the model
@tf.function
def model_validate(images, labels):
    """Evaluate the student on one batch without updating any weights.

    Computes the same distillation loss as the training step and records it,
    together with the student's accuracy, in the validation metrics.

    Args:
        images: Batch of input images fed to both the teacher and the student.
        labels: Integer class labels for `images`, used for accuracy only.

    Side effects:
        Updates the `valid_loss` and `valid_acc` metrics.
    """
    # Both models run explicitly in inference mode during validation
    teacher_logits = teacher_model(images, training=False)

    student_logits = student_model(images, training=False)
    loss = get_kd_loss(student_logits, teacher_logits, temperature=0.5)

    valid_loss(loss)
    valid_acc(labels, tf.nn.softmax(student_logits))

In [13]:
for epoch in range(10):
    # One pass over the training batches, then one over the test batches
    for (images, labels) in train_ds:
        model_train(images, labels)
    for (images, labels) in test_ds:
        model_validate(images, labels)

    # Read the epoch-level values before the metric state is cleared
    loss = train_loss.result()
    acc = train_acc.result()
    val_loss = valid_loss.result()
    val_acc = valid_acc.result()

    # Start the next epoch from empty accumulators
    for metric in (train_loss, train_acc, valid_loss, valid_acc):
        metric.reset_states()

    # Local logging
    template = "Epoch {}, loss: {:.3f}, acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}"
    print(template.format(epoch+1, loss, acc, val_loss, val_acc))

Epoch 1, loss: 0.099, acc: 0.838, val_loss: 0.082, val_acc: 0.853
Epoch 2, loss: 0.070, acc: 0.873, val_loss: 0.073, val_acc: 0.863
Epoch 3, loss: 0.061, acc: 0.885, val_loss: 0.066, val_acc: 0.872
Epoch 4, loss: 0.056, acc: 0.891, val_loss: 0.063, val_acc: 0.880
Epoch 5, loss: 0.052, acc: 0.896, val_loss: 0.062, val_acc: 0.882
Epoch 6, loss: 0.050, acc: 0.900, val_loss: 0.063, val_acc: 0.882
Epoch 7, loss: 0.047, acc: 0.904, val_loss: 0.063, val_acc: 0.885
Epoch 8, loss: 0.046, acc: 0.906, val_loss: 0.068, val_acc: 0.877
Epoch 9, loss: 0.044, acc: 0.907, val_loss: 0.063, val_acc: 0.886
Epoch 10, loss: 0.043, acc: 0.909, val_loss: 0.063, val_acc: 0.886


In [14]:
# Serialize the distilled student to disk
student_model.save(filepath="student_model.h5")

In [15]:
# List the saved .h5 model files with human-readable sizes
!ls -lh *.h5

-rw-r--r-- 1 root root 2.4M Aug 31 07:06 student_model.h5
-rw-r--r-- 1 root root 982K Aug 31 07:05 teacher_model.h5


The model files can be shrunk further by converting them to TFLite with post-training integer quantization.

In [19]:
# Credits: https://www.tensorflow.org/lite/performance/post_training_quant

def representative_data_gen():
    """Yield sample inputs from `X_train`, one image at a time.

    Takes the first 100 batches of size 1 from the training images and
    yields each one wrapped in a single-element list.
    """
    sample_ds = tf.data.Dataset.from_tensor_slices(X_train).batch(1).take(100)
    for input_value in sample_ds:
        yield [input_value]

def convert_to_tflite(model, tflite_file):
    """Convert a Keras model to a quantized TFLite model and write it to disk.

    The converter is configured with the default optimizations, uses
    `representative_data_gen` as its representative dataset, restricts the
    supported ops to the int8 TFLite builtins and sets int8 input and output
    types.

    Args:
        model: The Keras model to convert.
        tflite_file: Path of the file the converted model is written to.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_data_gen
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    tflite_quant_model = converter.convert()

    # Use a context manager so the file handle is flushed and closed
    # deterministically instead of being left to garbage collection.
    with open(tflite_file, 'wb') as tflite_fp:
        tflite_fp.write(tflite_quant_model)

In [20]:
# Convert both trained models to quantized TFLite files
convert_to_tflite(model=teacher_model, tflite_file="teacher.tflite")
convert_to_tflite(model=student_model, tflite_file="student.tflite")

INFO:tensorflow:Assets written to: /tmp/tmpj76nijkr/assets


INFO:tensorflow:Assets written to: /tmp/tmpj76nijkr/assets


INFO:tensorflow:Assets written to: /tmp/tmpe5_7fs7r/assets


INFO:tensorflow:Assets written to: /tmp/tmpe5_7fs7r/assets


In [21]:
# List the converted .tflite files with human-readable sizes
!ls -lh *.tflite

-rw-r--r-- 1 root root 613K Aug 31 07:24 student.tflite
-rw-r--r-- 1 root root  85K Aug 31 07:24 teacher.tflite
