In [None]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers import Dense,GlobalAveragePooling2D, MaxPool2D
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model, load_model
from keras.optimizers import Adam
from keras.preprocessing import image
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from keras.utils import to_categorical # for encoding data to one-hot form
import tensorflow as tf
from keras.objectives import categorical_crossentropy
from keras import backend as K
import numpy as np
import keras

In [None]:
# Fetch MNIST into /tmp/data/ with one-hot encoded labels.
# NOTE: this rebinds `mnist`, replacing the `keras.datasets.mnist` module
# imported at the top of the notebook.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(train_dir="/tmp/data/", one_hot=True)

In [None]:
# build teacher CNN
def create_teacher_CNN():
    """Assemble the teacher: two conv stages followed by a dense classifier head.

    Each conv stage is Conv2D(3x3, 'same', he_normal) -> BatchNormalization ->
    ReLU -> MaxPool2D(3x3, stride 2, 'same'), with 32 and then 64 filters.
    The flattened features go through Dropout(0.5) and Dense(512, relu) into a
    10-unit layer named "logits"; the model output is the softmax of it.

    Returns:
        An uncompiled keras Model taking (28, 28, 1) inputs and producing
        softmax class probabilities.
    """
    image_in = Input(shape=(28, 28, 1))

    features = image_in
    for n_filters in (32, 64):
        features = Conv2D(n_filters, (3, 3), padding='same',
                          kernel_initializer='he_normal')(features)
        features = BatchNormalization()(features)
        features = Activation('relu')(features)
        features = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(features)

    features = Flatten()(features)
    features = Dropout(0.5)(features)
    hidden = Dense(512, activation="relu")(features)

    # The explicit layer name lets other code look this layer up by name.
    raw_scores = Dense(10, name="logits")(hidden)
    class_probs = Activation('softmax')(raw_scores)

    return Model(inputs=image_in, outputs=class_probs)

In [None]:
# build student net
def build_student_net():
    """Create the student: a one-hidden-layer MLP that emits raw logits twice.

    The same logits tensor is exposed as two model outputs, so two different
    losses can be attached to it at compile time.

    Returns:
        An uncompiled keras Model taking (28, 28, 1) inputs with two
        identical 10-unit logit outputs (no softmax applied).
    """
    image_in = Input(shape=(28, 28, 1))
    flat = Flatten()(image_in)
    hidden = Dense(1024, activation="relu")(flat)
    raw_scores = Dense(10)(hidden)  # linear activation: these are logits
    return Model(inputs=image_in, outputs=[raw_scores] * 2)

In [None]:
# Instantiate a fresh, not-yet-compiled teacher network.
teacher_model = create_teacher_CNN()

In [None]:
# Base learning rate for the teacher; the training loop scales it down from this value.
teacher_lr = 1e-3
opt = Adam(lr=teacher_lr)

# The teacher ends in a softmax, so it is trained with categorical
# cross-entropy and monitored with accuracy.
teacher_model.compile(optimizer=opt, loss=['categorical_crossentropy'], metrics=['accuracy'])

In [None]:
# train teacher model
iteration = 8000

# The test split never changes, so fetch and reshape it once instead of at
# every evaluation step.
test_x, test_y = mnist.test.images, mnist.test.labels
test_x = test_x.reshape((-1,28,28,1))

current_lr = None  # last learning rate written to the optimizer
for i in range(iteration):

    # Step-decay learning rate schedule: x0.2 after 40% of the iterations,
    # x0.02 after 70%, x0.002 after 90%.
    if i > iteration*0.9:
        scheduled_lr = teacher_lr*0.002
    elif i > iteration*0.7:
        scheduled_lr = teacher_lr*0.02
    elif i > iteration*0.4:
        scheduled_lr = teacher_lr*0.2
    else:
        scheduled_lr = teacher_lr

    # Write to the optimizer only when the rate changes (four writes per run)
    # rather than up to three backend assignments on every iteration. Writing
    # the base rate at step 0 also makes a re-run of this cell start from
    # teacher_lr instead of the last decayed value.
    if scheduled_lr != current_lr:
        K.set_value(teacher_model.optimizer.lr, scheduled_lr)
        current_lr = scheduled_lr

    batch_x, batch_y = mnist.train.next_batch(128)
    batch_x = batch_x.reshape((-1,28,28,1))

    # train_on_batch returns [loss, accuracy] for this single-output model.
    tr_results = teacher_model.train_on_batch(batch_x, batch_y)
    tr_loss, tr_acc = tr_results[0], tr_results[1]

    # Report progress on the full test split every 100 iterations.
    if i % 100 == 0:
        val_results = teacher_model.test_on_batch(test_x, test_y)
        val_loss, val_acc = val_results[0], val_results[1]
        print('Iteration: {:}  tr_loss; {:.6}  tr_acc: {:.4}  val_loss: {:.6}  val_acc: {:.4}'.format(i, tr_loss, tr_acc, val_loss, val_acc))

# Final evaluation after the last update.
val_results = teacher_model.test_on_batch(test_x, test_y)
val_loss, val_acc = val_results[0], val_results[1]
print('============== training done =========================')
print('val_loss: {:.6}  val_acc: {:.4}'.format(val_loss, val_acc))
print('======================================================')

In [None]:
# Save the trained teacher to an HDF5 file in the working directory; the
# distillation stage reloads it from this path.
teacher_model.save('./teacher_model.h5')

In [None]:
def get_teacher_logits(teacher_net):
    """Build a model that maps the teacher's input to its 'logits' layer output.

    Args:
        teacher_net: a keras Model that contains a layer named 'logits'.

    Returns:
        A keras Model with `teacher_net`'s input and, as its output, the
        output tensor of the layer named 'logits'.
    """
    return Model(inputs=teacher_net.input,
                 outputs=teacher_net.get_layer('logits').output)

In [None]:
# Reload the pretrained teacher from the file written by the save cell.
teacher_model = load_model('./teacher_model.h5')

# Build the (untrained) student network.
student = build_student_net()
# Model exposing the output of the teacher's 'logits' layer; its predictions
# are used as the soft targets for the student.
teacher_kd_model = get_teacher_logits(teacher_model)
# Flag the teacher as non-trainable; it is only queried with predict().
teacher_kd_model.trainable = False


T = 7 # temperature: divisor applied to both teacher and student logits in kd_loss
a = 0.9 # weight of the kd loss; the ground-truth loss gets (1 - a)

def kd_loss(y_true, y_pred):
    """Distillation loss between temperature-softened teacher and student outputs.

    Both sets of logits are divided by the module-level temperature `T` and
    passed through a softmax; the loss is the cross-entropy of the softened
    student distribution against the softened teacher distribution.

    The result is multiplied by T**2. Gradients of the softened cross-entropy
    with respect to the student logits scale as 1/T**2, so without this factor
    the soft-target term would be far weaker than the hard-label term it is
    combined with, and the loss weights would not reflect the intended
    balance. Note that the reported loss value is larger by the same factor.

    Args:
        y_true: teacher logits (not probabilities).
        y_pred: student logits (not probabilities).

    Returns:
        Per-sample loss tensor.
    """
    soft_y = K.softmax(y_true/T)
    soft_pred = K.softmax(y_pred/T)
    # Compensate for the 1/T**2 gradient scaling introduced by the temperature.
    return categorical_crossentropy(soft_y, soft_pred) * (T ** 2)


def cross_entropy(y_true, y_pred):
    """Categorical cross-entropy between target labels and raw student logits.

    The softmax is folded into the loss via `from_logits=True` rather than
    applied explicitly and then log-ed. The explicit form clips probabilities
    before taking the log, which loses precision and flattens gradients when
    the logits are large in magnitude.

    Args:
        y_true: target class distribution (one-hot labels in this notebook).
        y_pred: student logits (not probabilities).

    Returns:
        Per-sample loss tensor.
    """
    # Keyword arguments keep the call independent of the backend's
    # positional order for target/output.
    return K.categorical_crossentropy(target=y_true, output=y_pred, from_logits=True)


# Base learning rate for the student; the training loop scales it down from this value.
student_lr = 1e-3
opt = Adam(lr=student_lr)

# The student has two (identical) logit outputs: the first is scored with
# kd_loss and weighted by a, the second with cross_entropy and weighted by
# (1 - a).
student.compile(
    optimizer=opt,
    loss=[kd_loss, cross_entropy],
    loss_weights=[a, 1.0 - a],
    metrics=['accuracy'],
)

In [None]:
# train student model
counter = 0
iteration = 8000

# Neither the test split nor the teacher changes while the student trains, so
# the reshaped test images and the teacher's logits for them are computed once
# here instead of re-running the teacher over the whole test set at every
# evaluation step.
test_x, test_y = mnist.test.images, mnist.test.labels
test_x = test_x.reshape((-1,28,28,1))
teacher_test_y = teacher_kd_model.predict(test_x)

current_lr = None  # last learning rate written to the optimizer
for i in range(iteration):

    # Step-decay learning rate schedule: x0.2 after 40% of the iterations,
    # x0.02 after 70%, x0.002 after 90%.
    if i > iteration*0.9:
        scheduled_lr = student_lr*0.002
    elif i > iteration*0.7:
        scheduled_lr = student_lr*0.02
    elif i > iteration*0.4:
        scheduled_lr = student_lr*0.2
    else:
        scheduled_lr = student_lr

    # Write to the optimizer only when the rate changes (four writes per run)
    # rather than up to three backend assignments on every iteration. Writing
    # the base rate at step 0 also makes a re-run of this cell start from
    # student_lr instead of the last decayed value.
    if scheduled_lr != current_lr:
        K.set_value(student.optimizer.lr, scheduled_lr)
        current_lr = scheduled_lr

    # get ground truth
    batch_x, batch_y = mnist.train.next_batch(128)
    batch_x = batch_x.reshape((-1,28,28,1))

    # get teacher label (teacher logits for this batch)
    teacher_y = teacher_kd_model.predict(batch_x)

    # Targets are ordered like the student's outputs: teacher logits for the
    # first output, ground-truth labels for the second.
    tr_results = student.train_on_batch(batch_x, [teacher_y, batch_y])
    # First entry is the weighted total loss; the last entry is the metric of
    # the second output, i.e. accuracy against the ground-truth labels.
    tr_loss, tr_acc = tr_results[0], tr_results[-1]

    # Report progress on the full test split every 100 iterations.
    if i % 100 == 0:
        val_results = student.test_on_batch(test_x, [teacher_test_y, test_y])
        val_loss, val_acc = val_results[0], val_results[-1]

        print('Iteration: {:}  tr_loss; {:.6}  tr_acc: {:.4}  val_loss: {:.6}  val_acc: {:.4}'.format(i, tr_loss, tr_acc, val_loss, val_acc))

# Final evaluation after the last update.
val_results = student.test_on_batch(test_x, [teacher_test_y, test_y])
val_loss, val_acc = val_results[0], val_results[-1]
print('============== training done =========================')
print('val_loss: {:.6}  val_acc: {:.4}'.format(val_loss, val_acc))
print('======================================================')