In [None]:
#Set path to MAIN FOLDER OF EXPERIMENT
#cd /path/to/EXPERIMENT_FOLDER/

In [None]:
#LOAD DEPENDENCIES
import os
import time
import pickle
import logging
import numpy as np
import talos
import tensorflow as tf
import matplotlib.pyplot as plt

#IMPORT LOSS, OPTIMIZER, CALLBACK AND LAYERS
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy, KLDivergence
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.layers import Input, Dense, Dropout, GlobalAveragePooling2D

#IMPORT MODEL APIs
from tensorflow.keras.models import Model, load_model, save_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 as selected_model
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as student_preprocess
from tensorflow.keras.applications.efficientnet import preprocess_input as teacher_preprocess

#PREVENT UNNECESSARY ERROR/WARNING MESSAGES
# Only ERROR-level messages from TensorFlow's Python logger are shown.
tf.get_logger().setLevel(logging.ERROR)
# NOTE(review): this environment variable is set only after `import tensorflow`
# in the dependencies cell; the native (C++) log level is commonly read when TF
# is first loaded, so this may have no effect unless set before the import — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

print("LIBRARIES LOADED")

In [None]:
#Tweakable parameters: experiment kind and the student model's name
MODEL_KIND = "KD_model_ENSEMBLE"
MODEL_NAME = "MiniMobileNetV2"

#Models paths: pre-trained student location and the HPO output root
PROPOSED_MODEL_PATH = "models/proposed_model/{}".format(MODEL_NAME)
HPO_PATH = "models/hpo_model/{}".format(MODEL_NAME)

#Figures path: figures/<MODEL_KIND>/<MODEL_NAME>
FIG_PATH = "/".join(("figures", MODEL_KIND, MODEL_NAME))

#Data paths: one sub-folder per split, each keeping its trailing slash
MAIN_DATA_DIR = "ds/"
TRAIN_DATA_DIR, TEST_DATA_DIR, VALIDATION_DATA_DIR = (
    MAIN_DATA_DIR + split + "/" for split in ("train", "test", "val")
)

print("ALL REQUERED PATHS SET")

In [None]:
#Save Model Function
def save_m(model, directory, model_name):
    """Save `model` to `<directory>/<model_name>.h5`, creating `directory` if needed.

    Args:
        model: object exposing a Keras-style `save(path)` method.
        directory: destination folder; missing parent folders are created.
        model_name: file stem; the ".h5" extension is appended.
    """
    # exist_ok=True replaces the exists()-then-makedirs() pattern, which can race
    # when two runs create the same folder.
    os.makedirs(directory, exist_ok=True)
    model.save(os.path.join(directory, model_name + ".h5"))

#Save History Function
def save_h(history, directory, history_name):
    """Pickle `history` to `<directory>/<history_name>.history`, creating `directory` if needed.

    Args:
        history: any picklable object (in this notebook, a Keras `History.history` dict).
        directory: destination folder; missing parent folders are created.
        history_name: file stem; the ".history" extension is appended.
    """
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, history_name + ".history"), "wb") as file:
        pickle.dump(history, file)

#Load model Function
def load_m(directory, model_name):
    """Load the Keras model stored at `<directory>/<model_name>.h5`.

    Args:
        directory: folder containing the model file.
        model_name: file stem; the ".h5" extension is appended.

    Returns:
        The loaded model, or None (after printing a message) when the file is missing.
    """
    model_file = os.path.join(directory, model_name + ".h5")
    # Check the file itself, not just its folder: an existing folder without the
    # .h5 file previously fell through to load_model() and raised.
    if not os.path.isfile(model_file):
        print("Model File Does Not Exist!!")
        return None
    return load_model(model_file)

#Load History Function
def load_h(directory, history_name):
    """Unpickle and return the object stored at `<directory>/<history_name>.history`.

    Args:
        directory: folder containing the history file.
        history_name: file stem; the ".history" extension is appended.

    Returns:
        The unpickled object, or None (after printing a message) when the file is missing.
    """
    history_file = os.path.join(directory, history_name + ".history")
    # Check the file itself, not just its folder (a folder without the file used to raise).
    if not os.path.isfile(history_file):
        print("History File Does Not Exist!!")
        return None
    # SECURITY: pickle.load can execute arbitrary code; only load files this notebook wrote.
    with open(history_file, "rb") as file:
        return pickle.load(file)

def save_fig(directory, fig_name):
    """Save the current matplotlib figure as `<directory>/<fig_name>.tiff` (600 dpi, tight bbox).

    Args:
        directory: destination folder; missing parent folders are created.
        fig_name: file stem; the ".tiff" extension is appended.
    """
    # Same race-free folder creation as the other save helpers.
    os.makedirs(directory, exist_ok=True)
    plt.savefig(os.path.join(directory, fig_name + '.tiff'), bbox_inches='tight', dpi=600, format='tiff')

print("ALL CUSTOM FUNCTIONS DEFIEND")

In [None]:
#DATA GENERATORS: shared batch size, image geometry and the Keras input tensor
BATCH_SIZE = 4

img_rows = img_cols = 224
INPUT_SHAPE = (img_rows, img_cols, 3)
MODEL_INPUT = Input(shape=INPUT_SHAPE)
print("INPUT SIZE -->", MODEL_INPUT.shape, "\n")
NUM_CLASSES = 199

def create_data_generator(pre_process=None):
    """Build train and validation DirectoryIterators from TRAIN_DATA_DIR / VALIDATION_DATA_DIR.

    Images are resized to (img_rows, img_cols), batched by BATCH_SIZE and labelled
    one-hot ('categorical'). The validation iterator is not shuffled.

    Args:
        pre_process: optional function passed to ImageDataGenerator as
            `preprocessing_function`; None leaves pixel values untouched.

    Returns:
        ((train_generator, nb_train_samples), (validation_generator, nb_validation_samples))
        on success, or (None, None) when a split folder is missing or has no images.
    """
    train_datagen = ImageDataGenerator(preprocessing_function=pre_process)
    val_datagen = ImageDataGenerator(preprocessing_function=pre_process)

    if not os.path.exists(TRAIN_DATA_DIR):
        print("TRAIN DATA DOES NOT EXITS!")
        return None, None
    print("LOAD TRAIN SAMPLES...")
    train_generator = train_datagen.flow_from_directory(
            TRAIN_DATA_DIR,
            target_size=(img_rows,img_cols),
            batch_size=BATCH_SIZE,
            class_mode='categorical',
            seed=42)

    #CHECK THE NUMBER OF SAMPLES
    nb_train_samples = len(train_generator.filenames)
    if nb_train_samples == 0:
        print("NO DATA TRAIN FOUND IN TRAIN FOLDER!")
        return None, None

    print()
    # Bug fix: this check used to test TRAIN_DATA_DIR, so a missing validation
    # folder was never reported here.
    if not os.path.exists(VALIDATION_DATA_DIR):
        print("VALIDATION DATA DOES NOT EXITS!")
        return None, None
    print("LOAD VALIDATION SAMPLES...")
    validation_generator = val_datagen.flow_from_directory(
            VALIDATION_DATA_DIR,
            target_size=(img_rows,img_cols),
            batch_size=BATCH_SIZE,
            class_mode='categorical',
            seed=42,
            shuffle=False)

    #CHECK THE NUMBER OF SAMPLES
    nb_validation_samples = len(validation_generator.filenames)
    if nb_validation_samples == 0:
        print("NO DATA VALIDATION FOUND IN VALIDATION FOLDER!")
        return None, None

    print()
    num_classes = len(train_generator.class_indices)
    print("GENERATER ARE SET!")
    print('CLASSES TO TRAIN', num_classes, 'classes')
    # One-hot labels only line up between splits if both use the same class->index map.
    if validation_generator.class_indices != train_generator.class_indices:
        print("WARNING: TRAIN AND VALIDATION CLASS FOLDERS DIFFER!")

    return (train_generator, nb_train_samples), (validation_generator, nb_validation_samples)

# Dry run without preprocessing: prints whether the train/validation folders can be
# read and how many classes were found; the returned generators are discarded.
_, _=create_data_generator()

In [None]:
#Knowledge Distiller(KD)

class KDistiller(Model):
    """Knowledge-distillation wrapper that trains `student` against labels and a frozen `teacher`.

    Per-batch training objective (see train_step):
        alpha * student_loss_fn(y, student_logits)
        + (1 - alpha) * distillation_loss_fn(softmax(teacher_logits / T),
                                            softmax(student_logits / T))
    Only the student's trainable variables are updated; the teacher always runs with
    training=False. Validation (test_step) scores the student alone.

    Args:
        student: Keras model being trained; its outputs are treated as logits.
        teacher: Keras model providing soft targets; its outputs are treated as logits.
        student_preprocess: optional callable applied to each raw input batch before
            it reaches the student (skipped when None).
        teacher_preprocess: optional callable applied to each raw input batch before
            it reaches the teacher (skipped when None).
    """

    def __init__(self, student, teacher, student_preprocess=None, teacher_preprocess=None):
        super(KDistiller, self).__init__()
        self.teacher = teacher
        self.student = student
        self.student_preprocess = student_preprocess
        self.teacher_preprocess = teacher_preprocess
        # Running means of the losses. If raw per-batch loss tensors are returned from
        # train_step/test_step, Keras' epoch logs (and thus History and the saved
        # .history files) hold only the LAST batch's value. Mean metrics give
        # per-epoch averages and are reset by Keras through the `metrics` property.
        self.total_loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.student_loss_tracker = tf.keras.metrics.Mean(name="student_loss")
        self.distillation_loss_tracker = tf.keras.metrics.Mean(name="distillation_loss")

    @property
    def metrics(self):
        """Compiled metrics plus the loss trackers, so Keras resets all of them each epoch."""
        trackers = [self.total_loss_tracker, self.student_loss_tracker, self.distillation_loss_tracker]
        # The base class may already collect metrics assigned as attributes; avoid duplicates.
        others = [m for m in super(KDistiller, self).metrics if all(m is not t for t in trackers)]
        return others + trackers

    def compile(self, optimizer,  metrics, student_loss_fn, distillation_loss_fn, alpha, temperature):
        """Configure the optimizer, metrics and the two loss terms.

        Args:
            optimizer: Keras optimizer applied to the student's variables.
            metrics: metrics computed on (y, student predictions), e.g. ['accuracy'].
            student_loss_fn: hard-label loss called as fn(y_true, student_preds).
            distillation_loss_fn: soft-target loss called as fn(teacher_soft, student_soft).
            alpha: weight of the hard-label loss; (1 - alpha) weights the distillation loss.
            temperature: softmax temperature T used to soften both networks' outputs.
        """
        super(KDistiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        """Run one optimisation step on the student and return epoch-averaged metrics/losses."""
        # UNPACK DATA: the same raw batch feeds both networks
        x, y = data
        student_x = x
        teacher_x = x

        # PREPROCESS DATA: each network gets its own preprocessing, if any
        if self.student_preprocess is not None:
            student_x = self.student_preprocess(student_x)
        if self.teacher_preprocess is not None:
            teacher_x = self.teacher_preprocess(teacher_x)

        # FORWARD PASS OF TEACHER (outside the tape: the teacher is never updated)
        teacher_preds = self.teacher(teacher_x, training=False)

        with tf.GradientTape() as tape:
            # FORWARD PASS OF STUDENT
            student_preds = self.student(student_x, training=True)

            # CALCULATE STUDENT (HARD-LABEL) LOSS
            student_loss = self.student_loss_fn(y, student_preds)

            # CALCULATE DISTILLATION LOSS between temperature-softened distributions.
            # NOTE(review): Hinton et al. (2015) scale this term by T**2 to keep its gradient
            # magnitude comparable across temperatures; it is not scaled here — confirm intent.
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_preds / self.temperature, axis=1),
                tf.nn.softmax(student_preds / self.temperature, axis=1),
            )

            # CALCULATE TOTAL LOSS
            total_loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # CALCULATE GRADIENT (student only) AND SET WEIGHTS
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(total_loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # SET METRICS: keys are the compiled metrics (e.g. "accuracy") plus
        # "loss", "student_loss" and "distillation_loss"
        self.total_loss_tracker.update_state(total_loss)
        self.student_loss_tracker.update_state(student_loss)
        self.distillation_loss_tracker.update_state(distillation_loss)
        self.compiled_metrics.update_state(y, student_preds)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate the student alone; reports compiled metrics and the averaged student loss."""
        # UNPACK DATA
        x, y = data

        # PREPROCESS DATA (guarded like train_step: the original crashed when it was None)
        if self.student_preprocess is not None:
            x = self.student_preprocess(x)

        # GET PREDICTIONS FROM STUDENT
        y_preds = self.student(x, training=False)

        # CALCULATE STUDENT LOSS
        student_loss = self.student_loss_fn(y, y_preds)

        # SET METRICS: only what validation computes (no distillation/total loss here)
        self.student_loss_tracker.update_state(student_loss)
        self.compiled_metrics.update_state(y, y_preds)
        results = {m.name: m.result() for m in self.compiled_metrics.metrics}
        results["student_loss"] = self.student_loss_tracker.result()
        return results

    def call(self, inputs, training=None):
        """Forward pass through the student only."""
        return self.student(inputs, training=training)

**Teacher Model:** EnsembleModel

In [None]:
TEACHER_NAME = "EnsembleModel"
# Teacher checkpoint folder: models/teacher_model/<TEACHER_NAME>
TEACHER_MODEL_PATH = "/".join(("models/teacher_model", TEACHER_NAME))

In [None]:
#LOAD THE TEACHER AND PRINT ITS STRUCTURE
def get_teacher():
    """Load a fresh copy of the teacher from disk and expose its outputs as logits.

    Returns:
        The loaded Keras model with `layers[-1].activation` set to None, or None when
        `load_m` could not find the model file.
    """
    model = load_m(TEACHER_MODEL_PATH, TEACHER_NAME)
    # Check before touching `.layers`: the original dereferenced the model first,
    # so a missing file raised AttributeError instead of returning None.
    if model is None:
        return None
    # KDistiller softens the teacher's outputs with softmax(logits / T), so the final
    # activation is removed.
    # NOTE(review): if the ensemble's last layer has no real activation (e.g. an
    # averaging layer), this assignment has no effect and the teacher still emits
    # probabilities — confirm.
    model.layers[-1].activation = None
    print("TEACHER MODEL LOADED SUCCESSFULLY!")
    return model

# Load once (the original loaded the teacher twice just to print its summary).
_teacher_preview = get_teacher()
if _teacher_preview is not None:
    print("PLEASE CHECK THE ENTIRE MODEL UP TO THE END")
    _teacher_preview.summary()
del _teacher_preview

In [None]:
#Sanity Checker: quick look at the teacher's score on the validation split.
#Re-Create Data Generator, this time with the TEACHER's preprocessing
_, (validation_generator, nb_validation_samples)  = create_data_generator(pre_process=teacher_preprocess)

# NOTE(review): get_teacher() removes the last layer's activation, while the loaded
# model keeps the loss it was compiled with; the reported loss may therefore be off,
# although argmax-based accuracy is unaffected — confirm.
get_teacher().evaluate(validation_generator)

**Student Model:** Mini-MNV2

In [None]:
#LOAD THE PRE-TRAINED STUDENT (no new layers are built here)
def get_student(model_input):
    """Load a fresh copy of the pre-trained student and expose its outputs as logits.

    Args:
        model_input: unused; kept so existing calls `get_student(MODEL_INPUT)` still work.

    Returns:
        The loaded Keras model with `layers[-1].activation` set to None, or None when
        `load_m` could not find the model file.
    """
    model = load_m(PROPOSED_MODEL_PATH, MODEL_NAME)
    # Check before touching `.layers` (the original crashed with AttributeError on None).
    if model is None:
        return None
    # The student is trained on logits (CategoricalCrossentropy(from_logits=True) and
    # softmax(logits / T) in KDistiller), so the final activation is dropped.
    model.layers[-1].activation = None
    print("STUDENT MODEL SUCESSFULLY BUILT!")
    return model

#PLOT THE MODEL STRUCTURE
print("PLEASE CHECK THE ENTIRE MODEL UP TO THE END")
_student_preview = get_student(MODEL_INPUT)
if _student_preview is not None:
    _student_preview.summary()
del _student_preview

In [None]:
#FIXED HYPERPARAMETERS (identical for every HPO round)
BATCH_SIZE, EPOCHS, DROPOUT_RATE = 4, 30, 0.5
OPTIMIZER = Adam

#HPO HYPERPARAMETERS (the grid explored by Talos)
TEMPERATURE = [5, 2, 10]
ALPHA = [0.1, 0.3, 0.5]
LEARNING_RATE = [0.001, 0.01, 0.0001]

print("FIXED HYPERPARAMETERS")
print("---------------------")
for _hp_label, _hp_value in (("BATCH_SIZE -->", BATCH_SIZE),
                             ("EPOCHS SET -->", EPOCHS),
                             ("DROPOUT_RATE -->", DROPOUT_RATE)):
    print(_hp_label, _hp_value)
print("OPTIMIZER -->", OPTIMIZER.__name__, "\n")

print("HPO HYPERPARAMETERS")
print("--------------------")
for _hp_label, _hp_value in (("TEMPERATURE -->", TEMPERATURE),
                             ("ALPHA -->", ALPHA),
                             ("LEARNING_RATE -->", LEARNING_RATE)):
    print(_hp_label, _hp_value)

**Training the Student with Knowledge Distillation (KD)**

In [None]:
# Placeholders for talos.Scan's x/y arguments: distiller_model ignores them and
# reads the images from directory generators instead.
dummy_x, dummy_y = [], []

# HPO PARAMETERS: grid over softmax temperature, loss weight alpha and learning rate
p = dict(temperature=TEMPERATURE, alpha=ALPHA, lr=LEARNING_RATE)

def distiller_model(x_train, y_train, x_val, y_val, params):
    """Run one Talos HPO round: distil the teacher into a freshly loaded student.

    Args:
        x_train, y_train, x_val, y_val: unused placeholders passed by talos.Scan; the
            images come from the directory generators built by create_data_generator().
        params: dict with 'temperature', 'alpha' and 'lr' for this round.

    Returns:
        (History, trained student model), handed back to talos.Scan.

    Side effects:
        Saves the student (.h5) and its history dict under
        HPO_PATH/HPO(t=<temperature>,a=<alpha>,l=<lr>)/, where ChooseBest reloads them.

    Raises:
        RuntimeError: if the data folders or the teacher/student models cannot be loaded.
    """
    print("\nCURRENT PARAMETERS:", params)

    #START TIMER (`time` is imported in the dependencies cell)
    start_time = time.time()

    #LOAD DATA FIRST so a missing/empty folder fails before any model is loaded.
    #No preprocessing here: KDistiller applies each network's own preprocessing.
    print()
    train_data, val_data = create_data_generator()
    if train_data is None or val_data is None:
        raise RuntimeError("TRAIN/VALIDATION DATA COULD NOT BE LOADED!")
    train_generator, nb_train_samples = train_data
    validation_generator, nb_validation_samples = val_data

    #GET NEW TEACHER AND STUDENT MODEL (fresh copies from disk for every round)
    teacher = get_teacher()
    student = get_student(MODEL_INPUT)
    if teacher is None or student is None:
        raise RuntimeError("TEACHER OR STUDENT MODEL COULD NOT BE LOADED!")

    #SET CALLBACK: halve the learning rate when val_accuracy stalls for 2 epochs
    reduce_lr = ReduceLROnPlateau(monitor='val_accuracy',
                                  factor=0.5,
                                  patience=2,
                                  verbose=1,
                                  mode='max',
                                  min_lr=0.000001)
    callbacks = [reduce_lr]

    #CREATE KNOWLEDGE DISTILLER
    distiller = KDistiller(
        student=student, teacher=teacher,
        student_preprocess=student_preprocess,
        teacher_preprocess=teacher_preprocess
    )

    #COMPILE KNOWLEDGE DISTILLER
    distiller.compile(
        optimizer=OPTIMIZER(learning_rate=params['lr']),
        metrics=['accuracy'],
        student_loss_fn=CategoricalCrossentropy(from_logits=True),
        distillation_loss_fn=KLDivergence(),
        alpha=params['alpha'],
        temperature=params['temperature'],
    )

    #DISTILLING
    # Same floor-division step counts as before; max(1, ...) only prevents a zero
    # step count when a split holds fewer than BATCH_SIZE images.
    distiller_history = distiller.fit(train_generator,
                                      validation_data=validation_generator,
                                      steps_per_epoch=max(1, nb_train_samples // BATCH_SIZE),
                                      validation_steps=max(1, nb_validation_samples // BATCH_SIZE),
                                      epochs=EPOCHS,
                                      callbacks=callbacks,
                                      )
    #STOP TIMER
    elapsed_time = time.time() - start_time
    train_time = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
    print('\n\n' + train_time, 'train_time\n')
    print(elapsed_time, 'Seconds\n\n')

    print("MODEL SERIALIZING WAIT FOR A MOMENT...\n")
    # One folder per hyper-parameter combination; ChooseBest rebuilds this exact path.
    run_dir = HPO_PATH + '/HPO(t={0},a={1},l={2})'.format(params['temperature'], params['alpha'], params['lr'])
    save_m(distiller.student, run_dir, MODEL_NAME)
    save_h(distiller_history.history, run_dir, MODEL_NAME)

    return distiller_history, distiller.student

# Search the hyper-parameter grid `p` (presumably every combination, since no
# round/fraction limit is given); each round trains and saves one student via
# distiller_model.
scan_object = talos.Scan(
    dummy_x,
    dummy_y,
    params=p,
    model=distiller_model,
    experiment_name='logs/',
    x_val=dummy_x,
    y_val=dummy_y,
    save_weights=True,
    print_params=False,
)

**Select Best Student**

In [None]:
#Re-Create Data Generator with the STUDENT's preprocessing: during distillation
#KDistiller applied student_preprocess outside the student model, so the saved
#students are evaluated below on already-preprocessed validation batches.
_, (validation_generator, nb_validation_samples)  = create_data_generator(pre_process=student_preprocess)

In [None]:
#Select the model with highest validation accuracy
def ChooseBest():
    """Reload every HPO student and return the one with the highest validation accuracy.

    Iterates over ALPHA x TEMPERATURE x LEARNING_RATE, loads the student saved by
    distiller_model under HPO_PATH/HPO(t=..,a=..,l=..)/ and evaluates it on the global
    `validation_generator`.

    Returns:
        (best_alpha, best_temp, best_lr, best_model). On ties the first model found is
        kept. Rounds whose model file is missing are skipped; if no model can be loaded,
        best_model is None and the three hyper-parameters are 0.
    """
    best_alpha = best_temp = best_lr = 0
    best_model = None
    # Start below any real accuracy so the first evaluated model is always kept
    # (starting at 0 used to leave best_model as the int 0 when every model scored 0.0).
    max_val_acc = float("-inf")

    for a in ALPHA:
        for t in TEMPERATURE:
            for l in LEARNING_RATE:
                print("\nFor Temperature = {0} & alpha= {1} & lr={2}".format(t, a, l))
                #load trained model with temp t, alpha a and learning rate l
                model_path = HPO_PATH + '/HPO(t={0},a={1},l={2})'.format(t, a, l)
                student_model = load_m(model_path, MODEL_NAME)
                if student_model is None:
                    # e.g. an HPO round that failed before saving: skip instead of crashing.
                    continue
                # Saved students output logits, hence from_logits=True.
                student_model.compile(metrics=['accuracy'], loss=CategoricalCrossentropy(from_logits=True))
                #validate model; look the metric up by name instead of by list position
                val_acc = student_model.evaluate(validation_generator, return_dict=True)['accuracy']

                #update best parameters
                if val_acc > max_val_acc:
                    max_val_acc = val_acc
                    best_model = student_model
                    best_alpha = a
                    best_temp = t
                    best_lr = l

    return best_alpha, best_temp, best_lr, best_model

best_alpha, best_temp, best_lr, best_model = ChooseBest()

# Report the winning hyper-parameters (the first line starts after a blank line)
for _best_label, _best_value in (('\nBest Temperature:', best_temp),
                                 ('Best Alpha:', best_alpha),
                                 ('Best Learning Rate:', best_lr)):
    print(_best_label, _best_value)

**Evaluating the best student model on the validation and test sets**

In [None]:
#LOAD TEST DATA (student preprocessing, unshuffled so predictions keep file order)
test_datagen = ImageDataGenerator(preprocessing_function=student_preprocess)

# Reset first: if the test folder is missing, a re-run must not silently reuse a
# stale `test_generator` left in the kernel by an earlier run.
test_generator = None
nb_test_samples = 0

if not os.path.exists(TEST_DATA_DIR):
    print("TEST DATA DOES NOT EXITS!")
else:
    print("LOAD TEST SAMPLES...")
    test_generator = test_datagen.flow_from_directory(
                TEST_DATA_DIR,
                target_size=(img_rows,img_cols),
                batch_size=BATCH_SIZE,
                class_mode='categorical',
                seed=42,
                shuffle=False)

    #CHECK THE NUMBER OF SAMPLES
    nb_test_samples = len(test_generator.filenames)
    if nb_test_samples == 0:
        print("NO DATA TEST FOUND IN TEST FOLDER!")

In [None]:
#Evaluate the Best Student on the validation and test sets
#(the teacher is not evaluated in this cell; both generators use student_preprocess)
print("Evaluating Best Student on validation dataset")
best_model.evaluate(validation_generator)

print("\nEvaluating Best Student on test dataset")
best_model.evaluate(test_generator)

**Saving Model**

In [None]:
import shutil

des = PROPOSED_MODEL_PATH + '-KD'
best_run_dir = HPO_PATH + '/HPO(t={0},a={1},l={2})'.format(best_temp,best_alpha,best_lr)

#save models: copy the winning HPO folder (model .h5 + .history) to `des`.
# Verify the source before touching the destination, so nothing is deleted on a bad path.
if not os.path.isdir(best_run_dir):
    raise FileNotFoundError(best_run_dir)
# shutil.copytree refuses to write into an existing folder, so re-running this cell
# used to fail with FileExistsError; replace the previous copy instead.
if os.path.exists(des):
    print("[INFO] REPLACING EXISTING", des)
    shutil.rmtree(des)
shutil.copytree(best_run_dir, des)

print("[INFO] BEST STUDENT MODEL AND HISTORY SAVED")

In [None]:
#Figure settings shared by the accuracy and loss plots
dpi = 1000
plt.rcParams.update({'figure.dpi': dpi})
figsize = (12, 12)

#Training history of the best student copied to PROPOSED_MODEL_PATH + '-KD'
history = load_h(PROPOSED_MODEL_PATH + '-KD', MODEL_NAME)

#Markers
(marker_train_accuracy, marker_validation_accuracy,
 marker_train_loss, marker_validation_loss) = ('s', 'x', 'o', '|')
marker_fillstyle_train = marker_fillstyle_validation = 'none'
marker_plot_markersize, marker_plot_markerwidth = 25, 3

#Lines
line_style_train, line_style_validation = '-', '--'
line_width_train = '5'
line_width_val = line_width_train
line_color_train_accuracy = line_color_val_accuracy = 'black'
line_color_train_loss = line_color_val_loss = 'black'

#Labels
train_accuracy_label, validation_accuracy_label = 'Train Acc', 'Val Acc'
train_loss_label, validation_loss_label = 'Train Loss', 'Val Loss'
x_label_font_size = 56
y_label_font_size = x_label_font_size
x_label_font = 'Tahoma'
y_label_font = x_label_font

#Ticks
spine_axis_thickness = 4
tick_font_size, tick_length = 42, 12
tick_width = spine_axis_thickness

#Legend
legend_border_pad = 0.35
legend_line_width = 5
legend_font_size = 50
legend_edge_color = 'black'
legend_label_spacing = 0.5
legend_location, legend_ncol = 'best', 1
legend_font = 'Tahoma'
legend_has_frame = True

In [None]:
#Accuracy Graph (train vs validation accuracy per epoch)
# NOTE(review): plt.style.use("default") presumably also resets the figure.dpi set
# in the style cell — confirm that is intended.
plt.style.use("default")
plt.rc('xtick', labelsize = tick_font_size, direction="in")
plt.rc('ytick', labelsize = tick_font_size, direction="in")

# Derive the epoch axis from the recorded history instead of EPOCHS, so the plot
# still matches when training ran for a different number of epochs.
epochs = len(history["accuracy"])
epoch_axis = np.arange(1, epochs + 1)

# A single explicit figure: the original first called plt.figure(...), but
# plt.subplots then opened a second figure and left the first one empty.
fig, ax = plt.subplots(figsize = figsize)
fig.subplots_adjust(bottom = 0.15)
plt.setp(ax.spines.values(), linewidth = spine_axis_thickness)
ax.tick_params(length = tick_length,
               width = tick_width,
               right = True,
               top = True)

ax.plot(epoch_axis,
        history["accuracy"],
        mew = marker_plot_markerwidth,
        color = line_color_train_accuracy,
        lw = line_width_train,
        marker = marker_train_accuracy,
        markersize = marker_plot_markersize,
        fillstyle = marker_fillstyle_train,
        ls = line_style_train,
        label = train_accuracy_label)

ax.plot(epoch_axis,
        history["val_accuracy"],
        mew = marker_plot_markerwidth,
        color = line_color_val_accuracy,
        lw = line_width_val,
        marker = marker_validation_accuracy,
        markersize = marker_plot_markersize,
        fillstyle = marker_fillstyle_validation,
        ls = line_style_validation,
        label = validation_accuracy_label)

# NOTE(review): both curves carry labels, but this figure draws no legend or axis
# labels (the loss figure does) — confirm this is intended.
fig.tight_layout()
save_fig(FIG_PATH, MODEL_NAME + '-AccuracyGraph')

In [None]:
#Loss Graph (train vs validation loss per epoch)
# Validation only reports the student's cross-entropy ("val_student_loss"), whereas
# history["loss"] is the alpha-weighted KD objective (cross-entropy + KL term).
# Plot the training "student_loss" so both curves measure the same quantity.
train_loss_key = "student_loss"
val_loss_key = "val_student_loss"

plt.style.use("default")
plt.rc('xtick', labelsize = tick_font_size, direction="in")
plt.rc('ytick', labelsize = tick_font_size, direction="in")

# Derive the epoch axis from the recorded history instead of EPOCHS.
epochs = len(history[train_loss_key])
epoch_axis = np.arange(1, epochs + 1)

# A single explicit figure (the original's extra plt.figure(...) call only produced
# an empty figure, because plt.subplots opens a new one).
fig, ax = plt.subplots(figsize = figsize)
fig.subplots_adjust(bottom = 0.15)
plt.setp(ax.spines.values(), linewidth = spine_axis_thickness)
ax.tick_params(length = tick_length,
               width = tick_width,
               right = True,
               top = True)

ax.plot(epoch_axis,
        history[train_loss_key],
        mew = marker_plot_markerwidth,
        color = line_color_train_loss,
        lw = line_width_train,
        marker = marker_train_loss,
        markersize = marker_plot_markersize,
        fillstyle = marker_fillstyle_train,
        ls = line_style_train,
        label = train_loss_label)

ax.plot(epoch_axis,
        history[val_loss_key],
        mew = marker_plot_markerwidth,
        color = line_color_val_loss,
        lw = line_width_val,
        marker = marker_validation_loss,
        markersize = marker_plot_markersize,
        fillstyle = marker_fillstyle_validation,
        ls = line_style_validation,
        label = validation_loss_label)

ax.set_xlabel("Epochs", fontfamily = x_label_font, fontsize = x_label_font_size, color ='black')
ax.set_ylabel("Accuracy/Loss", fontfamily = y_label_font, fontsize = y_label_font_size, color = 'black')

legend = ax.legend(loc = legend_location,
                   ncol = legend_ncol,
                   frameon = legend_has_frame,
                   fontsize=legend_font_size,
                   edgecolor=legend_edge_color,
                   borderpad=legend_border_pad,
                   labelspacing=legend_label_spacing)

frame = legend.get_frame()
frame.set_linewidth(legend_line_width)
frame.set_edgecolor(legend_edge_color)
plt.setp(legend.texts, family = legend_font)

fig.tight_layout()
save_fig(FIG_PATH, MODEL_NAME + '-LossGraph')