In [None]:
# Imports

import numpy as np
import sys
# https://github.com/Ujjwal-9/Knowledge-Distillation
sys.path.append('Knowledge/utils/')
import sklearn
import os
import keras
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.models import Model
import os
import keras
from keras import optimizers

# use non standard flow_from_directory
# it outputs y_batch that contains onehot targets and logits
from image_preprocessing_ver2 import ImageDataGenerator

In [None]:
# Load logits saved from teacher

data_dir = os.path.join(os.getcwd(), 'data')

train_logits_path = os.path.join(data_dir, 'cifar10_train_logits.npy')
val_logits_path = os.path.join(data_dir, 'cifar10_val_logits.npy')

# NOTE(review): allow_pickle=True unpickles arbitrary objects, so only load
# files that were produced locally by the teacher notebook.
# The trailing [()] index presumably unwraps a 0-d object array holding the
# saved logits container -- TODO confirm against the teacher's np.save call.
train_logits = np.load(train_logits_path, allow_pickle=True)[()]
val_logits = np.load(val_logits_path, allow_pickle=True)[()]

In [None]:
# Load cifar10 images and associated teacher logits into ImageDataGenerator for batch processing

# Pixel values are rescaled by 1/255 before being fed to the student.
data_generator = ImageDataGenerator(data_format='channels_last', rescale=1/255)

batch_size = 128
# Empty base path: the image folders are resolved relative to the working directory.
data_dir = r''

# Build the folder paths from components instead of embedding a '\\' separator,
# which only resolves on Windows. On Windows the resulting path is unchanged.
train_generator = data_generator.flow_from_directory(
    os.path.join(data_dir, 'cifar10', 'train'), train_logits,
    target_size=(32, 32), color_mode='rgb', batch_size=batch_size)
val_generator = data_generator.flow_from_directory(
    os.path.join(data_dir, 'cifar10', 'test'), val_logits,
    target_size=(32, 32), color_mode='rgb', batch_size=batch_size)

In [None]:
# Define student model

from keras import models, layers

def load_student_for_training():
    """Build the uncompiled student network.

    The student is a small CNN for 32x32 RGB inputs: two conv/relu/max-pool
    stages (16 then 32 filters), a 64-unit dense layer with batch
    normalization, and a 10-way softmax output. Every kernel uses the same
    seeded lecun_normal initializer.

    Returns:
        A keras Sequential model whose last layer is the softmax activation.
    """
    from keras.initializers import lecun_normal
    from keras.layers import BatchNormalization

    init = lecun_normal(seed=1)

    # Layers are created in the exact order in which they are stacked.
    layer_stack = [
        # Convolutional stage 1
        Conv2D(16, (3, 3), input_shape=(32, 32, 3), kernel_initializer=init),
        Activation('relu'),
        MaxPooling2D(pool_size=(2, 2)),
        # Convolutional stage 2
        Conv2D(32, (3, 3), kernel_initializer=init),
        Activation('relu'),
        MaxPooling2D(pool_size=(2, 2)),
        # Classifier head
        Flatten(),
        Dense(64, kernel_initializer=init),
        BatchNormalization(),
        Activation('relu'),
        Dense(10, kernel_initializer=init),
        Activation('softmax'),
    ]

    student = Sequential()
    for layer in layer_stack:
        student.add(layer)

    return student

In [None]:
# Distillation loss (soft targets and hard targets)

from keras.losses import categorical_crossentropy as logloss
from keras.metrics import categorical_accuracy, top_k_categorical_accuracy
from keras import backend as K

def distillation_loss(y_true, y_pred, hard_loss_weight, temp):
    """Weighted hard-label loss plus soft-target loss for distillation.

    Args:
        y_true: tensor whose first 10 columns are the one-hot targets and
            whose remaining columns are the teacher logits.
        y_pred: tensor whose first 10 columns are the student probabilities
            and whose remaining columns are the temperature-softened
            student probabilities.
        hard_loss_weight: multiplier applied to the hard-label cross-entropy.
        temp: temperature used to soften the teacher logits.

    Returns:
        hard_loss_weight * CE(one-hot, student) + CE(soft teacher, soft student).
    """
    hard_targets = y_true[:, :10]
    teacher_logits = y_true[:, 10:]

    # Soften the teacher distribution with the distillation temperature.
    soft_targets = K.softmax(teacher_logits / temp)

    student_probs = y_pred[:, :10]
    student_soft_probs = y_pred[:, 10:]

    hard_term = logloss(hard_targets, student_probs)
    soft_term = logloss(soft_targets, student_soft_probs)
    return hard_loss_weight * hard_term + soft_term
    

In [None]:
# Custom metric functions

def accuracy(y_true, y_pred):
    """Categorical accuracy computed on the first 10 columns of both tensors."""
    hard_targets, hard_probs = y_true[:, :10], y_pred[:, :10]
    return categorical_accuracy(hard_targets, hard_probs)

def top_5_accuracy(y_true, y_pred):
    """Top-k categorical accuracy (library default k) on the first 10 columns."""
    hard_targets, hard_probs = y_true[:, :10], y_pred[:, :10]
    return top_k_categorical_accuracy(hard_targets, hard_probs)

def categorical_crossentropy(y_true, y_pred):
    """Hard-label cross-entropy computed on the first 10 columns of both tensors."""
    hard_targets, hard_probs = y_true[:, :10], y_pred[:, :10]
    return logloss(hard_targets, hard_probs)

def soft_logloss(y_true, y_pred, temp):
    """Cross-entropy between softened teacher targets and softened student output.

    The teacher logits (columns 10 onward of y_true) are divided by temp and
    passed through a softmax; the result is compared against columns 10
    onward of y_pred.
    """
    teacher_logits = y_true[:, 10:]
    student_soft_probs = y_pred[:, 10:]
    soft_targets = K.softmax(teacher_logits / temp)
    return logloss(soft_targets, student_soft_probs)

In [None]:
def distill(temp, hard_weight, epochs=25, verbose=False,
            train_samples=50000, val_samples=10000):
    """
    Train a fresh student with the distillation loss and evaluate it.

    The student's final softmax is removed and replaced by two heads that are
    concatenated into one output: the plain class probabilities and the
    temperature-softened probabilities. Training reads the notebook-level
    `train_generator`, `val_generator` and `batch_size`; the final evaluation
    reads `val_generator_no_shuffle`, which must be defined before calling.

    Only soft_logloss is redefined here, because it depends on a non-standard
    param (temp) and model.compile wouldn't take lambdas as metrics. The other
    metrics do not depend on temp, so the module-level definitions are reused.

    Args:
        temp: distillation temperature applied to teacher and student logits.
        hard_weight: weight of the hard-label loss term.
        epochs: maximum number of training epochs (early stopping may end sooner).
        verbose: truthy to show Keras progress output while fitting.
        train_samples: number of training images; sets the steps per epoch.
        val_samples: number of validation images; sets the evaluation steps.

    Returns:
        (results, history, model): the evaluate_generator output (loss followed
        by accuracy, top_5_accuracy, categorical_crossentropy, soft_logloss),
        the Keras History from fitting, and the trained two-headed model.
    """
    def soft_logloss(y_true, y_pred):
        logits = y_true[:, 10:]
        y_soft = K.softmax(logits/temp)
        y_pred_soft = y_pred[:, 10:]
        return logloss(y_soft, y_pred_soft)

    student = load_student_for_training()

    # Remove softmax
    student.pop()

    # Get student logits and class probabilities
    logits = student.layers[-1].output
    probabilities = layers.Activation('softmax')(logits)

    # Apply temperature to get softened probabilities
    # Temps of 2.5-4 "worked significantly better" than other temps on networks with 30 units per layer
    logits_T = layers.Lambda(lambda x: x / temp)(logits)
    probabilities_T = layers.Activation('softmax')(logits_T)

    # Define student that outputs probabilities and softened probabilities
    output = layers.concatenate([probabilities, probabilities_T])
    model = Model(student.input, output)

    model.compile(
        optimizer='adam',
        loss=lambda y_true, y_pred: distillation_loss(y_true, y_pred, hard_weight, temp),
        metrics=[accuracy, top_5_accuracy, categorical_crossentropy, soft_logloss])

    # Step counts must be whole numbers; round up so the final partial batch
    # is included.
    steps_per_epoch = int(np.ceil(train_samples / batch_size))
    # Exactly one pass over the validation set. The previous value
    # (50000 / batch_size) used the training-set size and therefore evaluated
    # the validation images several times over.
    eval_steps = int(np.ceil(val_samples / batch_size))

    history = model.fit_generator(
        train_generator,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        verbose=1 if verbose else 0,
        validation_data=val_generator,
        validation_steps=25,
        callbacks=[
                EarlyStopping(monitor='val_accuracy', patience=5, min_delta=0.005)
            ])

    results = model.evaluate_generator(val_generator_no_shuffle, steps=eval_steps)
    return results, history, model

In [None]:
# Define val data generator

# Unshuffled validation iterator used for the final evaluation.
# The folder path is assembled from components rather than with an embedded
# '\\' separator, which only resolves on Windows.
val_generator_no_shuffle = data_generator.flow_from_directory(
    os.path.join(data_dir, 'cifar10', 'test'), val_logits,
    target_size=(32, 32),
    batch_size=128, color_mode='rgb', shuffle=False
)

In [None]:
# Iterate over param combinations, saving results of each training session

def _sweep_temperatures(hard_weight, out_path, temps=(1, 2, 5, 10), epochs=50):
    """Distill one student per temperature and save the evaluation results.

    Args:
        hard_weight: hard-loss weight forwarded to distill.
        out_path: .npy file the result table is written to.
        temps: temperatures to try, in order.
        epochs: maximum epochs per training session.

    Returns:
        Object array of shape (n_successful_runs, 2); column 0 holds the
        temperature and column 1 the list returned by evaluate_generator.
    """
    records = []
    for temp in temps:
        try:
            results, _, _ = distill(temp, hard_weight, epochs=epochs)
            print(f'Acc: {results[1]} \n Temp: {temp} \n')
            records.append((temp, results))
        except ValueError as e:
            # Best effort: report the failing temperature and keep sweeping.
            print('Temp:', temp)
            print(e)

    # Each record pairs a scalar with a list of metrics, i.e. it is ragged.
    # np.asarray on such input raises ValueError on NumPy >= 1.24, so the
    # (n, 2) object table is filled in explicitly instead.
    table = np.empty((len(records), 2), dtype=object)
    for row, (temp, results) in enumerate(records):
        table[row, 0] = temp
        table[row, 1] = results

    # Make sure the output folder exists before saving.
    os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
    np.save(out_path, table)
    return table


temp_accuracies_weight_1 = _sweep_temperatures(0.1, 'data/temp_accuracies_weight_1.npy')
temp_accuracies_weight_2 = _sweep_temperatures(0.5, 'data/temp_accuracies_weight_2.npy')
temp_accuracies_weight_3 = _sweep_temperatures(0.9, 'data/temp_accuracies_weight_3.npy')

# weight_accuracies = []
# for weight in np.arange(0, 1.6, 0.1):
#     try:
#         results, _, _ = distill(10, weight, epochs=50)
#         print(f'Acc: {results[1]} \n Weight: {weight} \n')
#         weight_accuracies.append((weight, results))
#     except ValueError as e:
#         print('Weight:', weight)
#         print(e) 
# weight_accuracies = np.asarray(weight_accuracies)
# np.save('data/weight_accuracies2.npy', weight_accuracies)

In [None]:
# Load distillation results and plot

import matplotlib.pyplot as plt
import numpy as np

# Load the three saved sweeps (one per hard-loss weight).
temp_accuracies1 = np.load('data/temp_accuracies_weight_1.npy', allow_pickle=True)
temp_accuracies2 = np.load('data/temp_accuracies_weight_2.npy', allow_pickle=True)
temp_accuracies3 = np.load('data/temp_accuracies_weight_3.npy', allow_pickle=True)

# Each row is (temp, results); results[1] is the accuracy metric.
accuracies1 = [row[1][1] for row in temp_accuracies1]
accuracies2 = [row[1][1] for row in temp_accuracies2]
accuracies3 = [row[1][1] for row in temp_accuracies3]

sweep_curves = (
    (temp_accuracies1, accuracies1, 'Weight = 0.1'),
    (temp_accuracies2, accuracies2, 'Weight = 0.5'),
    (temp_accuracies3, accuracies3, 'Weight = 0.9'),
)
for sweep_table, sweep_accuracies, curve_label in sweep_curves:
    plt.plot(sweep_table[:, 0], sweep_accuracies, label=curve_label)

plt.title('Temp vs Accuracy')
plt.xlabel('Temp')
plt.ylabel('Accuracy')
plt.legend()
plt.savefig('weight_v_accuracy3.png')

In [None]:
import matplotlib.pyplot as plt

# NOTE(review): `accs` is not defined in any cell visible in this notebook;
# it is expected to be a sequence of (temp, accuracy) pairs -- TODO confirm
# where it comes from, otherwise this cell fails on a fresh kernel.
ts = [pair[0] for pair in accs]
a = [pair[1] for pair in accs]

plt.plot(ts, a)
plt.title('Distillation Temp vs Accuracy')
plt.xlabel('Temp')
plt.ylabel('Accuracy')

In [None]:
# Distill a single model (temperature 2, hard-loss weight 0.5) with progress output
res, history, model = distill(temp=2, hard_weight=0.5, verbose=True)

In [None]:
# Plot hard loss

import matplotlib.pyplot as plt

# Validation curve first, then training, both indexed by 1-based epoch.
hard_loss_curves = (
    ('val_categorical_crossentropy', 'val'),
    ('categorical_crossentropy', 'training'),
)
for history_key, curve_label in hard_loss_curves:
    per_epoch = history.history[history_key]
    plt.plot(range(1, len(per_epoch) + 1), per_epoch, label=curve_label)

plt.title('Progression of hard logloss')
plt.xlabel('epoch')
plt.ylabel('hard logloss')
plt.legend()
plt.savefig('hard_logloss.png')

In [None]:
# Plot soft loss

import matplotlib.pyplot as plt

# Validation curve first, then training, both indexed by 1-based epoch.
soft_loss_curves = (
    ('val_soft_logloss', 'val'),
    ('soft_logloss', 'training'),
)
for history_key, curve_label in soft_loss_curves:
    per_epoch = history.history[history_key]
    plt.plot(range(1, len(per_epoch) + 1), per_epoch, label=curve_label)

plt.title('Progression of soft logloss')
plt.xlabel('epoch')
plt.ylabel('soft logloss')
plt.legend()
plt.savefig('soft_logloss.png')

In [None]:
# Plot temp v acc

import matplotlib.pyplot as plt
import numpy as np

# NOTE(review): no cell visible in this notebook writes
# 'data/temp_accuracies1.npy'; presumably it comes from an earlier run --
# TODO confirm the file exists before relying on this cell.
temp_accuracies = np.load('data/temp_accuracies1.npy', allow_pickle=True)

# Each row is (temp, results); results[1] is read as the accuracy.
accuracies = [row[1][1] for row in temp_accuracies]
swept_temps = temp_accuracies[:, 0]

plt.plot(swept_temps, accuracies)
plt.title('Distillation Temp vs Accuracy')
plt.xlabel('Temp')
plt.ylabel('Accuracy')
plt.savefig('temp_v_accuracy1.png')

In [None]:
# Plot weight v acc

import matplotlib.pyplot as plt
import numpy as np

# NOTE(review): no active cell visible in this notebook writes
# 'data/weight_accuracies1.npy'; presumably it comes from an earlier run --
# TODO confirm the file exists before relying on this cell.
weight_accuracies = np.load('data/weight_accuracies1.npy', allow_pickle=True)

# Each row is (weight, results); results[1] is read as the accuracy.
accuracies = [row[1][1] for row in weight_accuracies]
swept_weights = weight_accuracies[:, 0]

plt.plot(swept_weights, accuracies)
plt.title('Hard Loss Weight vs Accuracy')
plt.xlabel('Weight')
plt.ylabel('Accuracy')
plt.savefig('weight_v_accuracy1.png')

In [None]:
# Plot scatter plot of soft loss v hard loss

import matplotlib.pyplot as plt

# One point per epoch: x is the soft loss, y is the hard loss.
loss_pairs = (
    ('soft_logloss', 'categorical_crossentropy', 'training'),
    ('val_soft_logloss', 'val_categorical_crossentropy', 'val'),
)
for soft_key, hard_key, series_label in loss_pairs:
    plt.scatter(history.history[soft_key], history.history[hard_key], label=series_label)

plt.title('Hard Loss vs. Soft Loss')
plt.xlabel('soft_logloss')
plt.ylabel('logloss')
plt.legend()
plt.savefig('logloss_vs_softlogloss.png')

In [None]:
# Plot acc curves

# Training curve first, then validation, one point per epoch.
for history_key, curve_label in (('accuracy', 'train'), ('val_accuracy', 'val')):
    plt.plot(history.history[history_key], label=curve_label)
plt.legend()
plt.xlabel('epoch')
# The trailing ';' keeps the notebook from echoing the returned text object.
plt.ylabel('accuracy');

In [None]:
# Evaluate the distilled student on the validation images exactly once.
# The previous 80 steps of 128 images requested 10240 samples for a
# 10000-image set, so (assuming the iterator wraps around like the stock
# Keras iterators) one batch was counted twice in the reported accuracy.
n_val_images = 10000
eval_batch_size = 128  # matches the batch size of val_generator_no_shuffle
eval_steps = int(np.ceil(n_val_images / eval_batch_size))

results = model.evaluate_generator(val_generator_no_shuffle, eval_steps)
# results[1] is the accuracy metric, so scale it to a count of misclassified images.
print('Errors: ', n_val_images - results[1] * n_val_images)
print('Accuracy:', results[1]*100 )

In [None]:
from keras.models import save_model

# Saving fails if the target folder is missing, so create it first.
os.makedirs('models', exist_ok=True)
save_model(model, 'models/cifar10_distilled_student.h5')