In [7]:
# Imports

import numpy as np
import sys
sys.path.append('Knowledge-Distillation/utils/')
import sklearn
import os

import keras
from keras import optimizers

# use non standard flow_from_directory
from image_preprocessing_ver2 import ImageDataGenerator
# it outputs y_batch that contains onehot targets and logits

In [8]:
# Load logits saved from teacher

data_dir = os.path.join(os.getcwd(), 'data')

train_logits = np.load(os.path.join(data_dir, 'mnisttrain_logits.npy'), allow_pickle=True)[()]
val_logits = np.load(os.path.join(data_dir, 'mnistval_logits.npy'), allow_pickle=True)[()]

In [9]:
# Load mnist images and associated teacher logits into ImageDataGenerator for batch processing

data_generator = ImageDataGenerator(data_format='channels_last', rescale=1/255)

train_generator = data_generator.flow_from_directory(os.path.join(data_dir, 'mnist\\train'), train_logits, target_size=(28, 28), color_mode='grayscale')
val_generator = data_generator.flow_from_directory(os.path.join(data_dir, 'mnist\\test'), val_logits, target_size=(28, 28), color_mode='grayscale')

Found 60000 images belonging to 10 classes.
Found 10000 images belonging to 10 classes.


In [10]:
# Setup student model

from keras.models import load_model
from keras import layers

# Temps of 2.5-4 "worked significantly better" that other temps on networks with 30 units per layer
temp = 4

from keras import models, layers

student = models.Sequential()
student.add(layers.Flatten(input_shape=(28, 28, 1)))
student.add(layers.Dense(32, activation='relu'))
student.add(layers.Dense(32, activation='relu'))
student.add(layers.Dense(10))
student.add(layers.Activation('softmax'))

# Remove softmax
student.pop()

# Get student logits and class probabilities
logits = student.layers[-1].output
probabilities = layers.Activation('softmax')(logits)

# Apply temperature to get softed probabilities
logits_T = layers.Lambda(lambda x: x / temp)(logits)
probabilities_T = layers.Activation('softmax')(logits_T)

In [11]:
from keras.models import Model

output = layers.concatenate([probabilities, probabilities_T])
model = Model(student.input, output)

In [12]:
# Distillation loss (soft targets and hard targets)

from keras.losses import categorical_crossentropy as logloss
from keras.metrics import categorical_accuracy, top_k_categorical_accuracy
from keras import backend as K

def distillation_loss(y_true, y_pred, hard_loss_weight):
    # 
    y_true, logits = y_true[:, :10], y_true[:, 10:]
    
    y_soft = K.softmax(logits / temp)
    
    y_pred, y_pred_soft = y_pred[:, :10], y_pred[:, 10:]
    
    return hard_loss_weight * logloss(y_true, y_pred) + logloss(y_soft, y_pred_soft)
    

In [13]:
# Custom metric functions

def accuracy(y_true, y_pred):
    y_true = y_true[:, :10]
    y_pred = y_pred[:, :10]
    return categorical_accuracy(y_true, y_pred)

def top_5_accuracy(y_true, y_pred):
    y_true = y_true[:, :10]
    y_pred = y_pred[:, :10]
    return top_k_categorical_accuracy(y_true, y_pred)

def categorical_crossentropy(y_true, y_pred):
    y_true = y_true[:, :10]
    y_pred = y_pred[:, :10]
    return logloss(y_true, y_pred)

def soft_logloss(y_true, y_pred):
    print(f'y_true: {y_true.shape} y_pred: {y_pred.shape}')
    logits = y_true[:, 10:]
    print(f'logits: {logits.shape}')
    y_soft = K.softmax(logits/temp)
    print(f'y_soft: {y_soft.shape}')
    y_pred_soft = y_pred[:, 10:]
    print(f'y_pred_soft: {y_pred_soft.shape}')
    return logloss(y_soft, y_pred_soft)

In [14]:
hard_loss_weight = 0.07


model.compile(
    optimizer=optimizers.SGD(lr=1e-2, momentum=0.9, nesterov=True),
    loss=lambda y_true, y_pred: distillation_loss(y_true, y_pred, hard_loss_weight),
    metrics=[accuracy, top_5_accuracy, categorical_crossentropy, soft_logloss])

y_true: (?, ?) y_pred: (?, 20)
logits: (?, ?)
y_soft: (?, ?)
y_pred_soft: (?, 10)


In [None]:
from keras.callbacks import ReduceLROnPlateau, EarlyStopping

history = model.fit_generator(
    train_generator,
    epochs=35,
    steps_per_epoch=60000/32,
    verbose=1,
    validation_data=val_generator,
    validation_steps=100,
    callbacks=[
            EarlyStopping(monitor='val_accuracy', patience=4, min_delta=0.01), 
            ReduceLROnPlateau(monitor='val_accuracy', factor=0.1, patience=2, min_delta=0.007)
        ])

In [None]:
import matplotlib.pyplot as plt

plt.plot(model.history.history['categorical_crossentropy'], label='train')
plt.plot(model.history.history['val_categorical_crossentropy'], label='val')
plt.legend()
plt.xlabel('epoch');
plt.ylabel('logloss');

In [None]:
plt.plot(model.history.history['accuracy'], label='train');
plt.plot(model.history.history['val_accuracy'], label='val');
plt.legend();
plt.xlabel('epoch');
plt.ylabel('accuracy');

In [None]:
plt.plot(model.history.history['top_5_accuracy'], label='train');
plt.plot(model.history.history['val_top_5_accuracy'], label='val');
plt.legend();
plt.xlabel('epoch');
plt.ylabel('top5_accuracy');

In [None]:
val_generator_no_shuffle = data_generator.flow_from_directory(
    os.path.join(data_dir, 'mnist\\test'), val_logits,
    target_size=(28, 28),
    batch_size=64, color_mode='grayscale', shuffle=False
)

In [None]:
results = model.evaluate_generator(val_generator_no_shuffle, 80)
print('Errors: ', 10000 - results[1] * 10000)
print('Accuracy:', results[1]*100 )

In [None]:
from keras.models import save_model

save_model(model, 'mnist_distilled_student.h5')