# Importing Libraries

In [None]:
from tensorflow.keras.models import model_from_json
from fer_model import get_fer_model
import tensorflow as tf
import tempfile 

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from livelossplot import PlotLossesKerasTF
from tensorflow.keras.optimizers import Adam

import os
import zipfile

import tensorflow_model_optimization as tfmot
import pickle

In [None]:
# Whole-second Unix timestamp taken once, when this cell runs. Later cells
# embed it in the checkpoint and metrics file names, so every artefact written
# by one run of the notebook shares the same suffix.
import time
TIMESTAMP = round(time.time())

In [None]:
import os

# Export KMP_DUPLICATE_LIB_OK=True into this process's environment.
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort; TensorFlow is already imported by the time this cell runs, so confirm
# that setting the variable here is early enough to have an effect.
os.environ.update(KMP_DUPLICATE_LIB_OK="True")

# Loading the RAFDB Dataset

In [None]:
from data import load_rafdb

# Load both RAF-DB splits with the project helper, "train" first and "test"
# second. The resulting objects are used further down as Keras data sources
# and are expected to expose `.n` and `.batch_size`.
train_generator, test_generator = (
    load_rafdb(split) for split in ("train", "test")
)

# Load Baseline Model

In [None]:
# Build the baseline FER network (100-pixel input, 1 channel, 7 output
# classes) and restore the previously trained weights from disk.
model = get_fer_model(
    out_classes=7,
    input_channels=1,
    input_size=100,
)
model.load_weights("weights_rafdb/model_weights_1626211720.h5")

In [None]:
# Evaluate the unmodified baseline on the test split so that the clustered and
# quantised variants below have a reference point.
from evaluation import get_metrics_rafdb

print("Baseline model performance:")
# Bare call as the last statement: its return value is shown as the cell
# output when run in a notebook (it is not stored anywhere).
get_metrics_rafdb(model, test_generator)

# Apply Weight Clustering

In [None]:
# Respect an N_CLUSTERS that already exists in this namespace (for example one
# set before this cell was executed); otherwise fall back to 32 clusters.
N_CLUSTERS = locals().get("N_CLUSTERS", 32)
print("N_CLUSTERS =", N_CLUSTERS)

In [None]:
# Short aliases for the tfmot clustering entry points used in this cell.
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

# N_CLUSTERS centroids, initialised with the LINEAR scheme.
clustering_params = {
    'number_of_clusters': N_CLUSTERS,
    'cluster_centroids_init': CentroidInitialization.LINEAR,
}

# Cluster a whole model
clustered_model = cluster_weights(model, **clustering_params)

# Use smaller learning rate for fine-tuning clustered model.
# `learning_rate` is the supported keyword. The old `lr` alias is deprecated:
# newer Keras optimizers either drop it with only a warning (so training would
# silently run at the default rate) or reject it outright.
opt = Adam(learning_rate=0.0005)

clustered_model.compile(
    optimizer=opt,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Fine-tune the clustered model for a few epochs, keeping only the weights
# from the epoch with the best validation accuracy.
epochs = 3

# Whole batches only: any final partial batch of either split is not counted.
steps_per_epoch = train_generator.n // train_generator.batch_size
validation_steps = test_generator.n // test_generator.batch_size

# Checkpoint file is tagged with the cluster count and this run's timestamp.
checkpoint_name = "weights_rafdb/clustered_model_c%d_%s.h5" % (N_CLUSTERS, TIMESTAMP)

# Scale the learning rate by 0.1 (never below 1e-5) once val_loss has stopped
# improving for 2 epochs.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    mode='auto',
    factor=0.1,
    patience=2,
    min_lr=0.00001,
)

# Write weights only, and only when val_accuracy reaches a new maximum.
checkpoint = ModelCheckpoint(
    checkpoint_name,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

callbacks = [PlotLossesKerasTF(), checkpoint, reduce_lr]

clustered_model.fit(
    x=train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=test_generator,
    validation_steps=validation_steps,
    callbacks=callbacks,
)

# Discard the last-epoch weights in favour of the best checkpoint on disk.
clustered_model.load_weights(checkpoint_name)

In [None]:
# Replace the fine-tuned model with the result of tfmot's strip_clustering.
# Per the tfmot API this removes the clustering wrappers added by
# cluster_weights, leaving a plain Keras model for evaluation and conversion.
clustered_model = tfmot.clustering.keras.strip_clustering(clustered_model)

# Evaluate Clustered Model

In [None]:
# Evaluate the stripped, clustered model on the test split and keep the result
# in `metrics`; the next cell pickles it to disk.
print("Model performance after weight clustering:")
metrics = get_metrics_rafdb(clustered_model, test_generator)
# Bare expression: displayed as the cell output when run in a notebook.
metrics

### Store results

In [None]:
# Persist the clustered model's metrics, tagged with the cluster count and the
# run timestamp. Create the log directory first so that a missing folder does
# not raise FileNotFoundError and lose the results of the fine-tuning run.
os.makedirs("logs_rafdb", exist_ok=True)
with open("logs_rafdb/clustered_model_metrics_c%d_%s" % (N_CLUSTERS, TIMESTAMP),
          'wb') as clustered_model_metrics_file:
    pickle.dump(metrics, clustered_model_metrics_file)

# Apply Quantisation to Clustered Model 

In [None]:
# Convert the stripped, clustered Keras model to TensorFlow Lite.
converter = tf.lite.TFLiteConverter.from_keras_model(clustered_model)
# This optimisation includes the quantisation
# NOTE(review): no representative dataset is supplied, so this is presumably
# TFLite's default post-training (dynamic-range) quantisation -- confirm
# against the TFLite converter documentation.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# The converted model is kept in memory and handed to the evaluation helper in
# the next cell; this cell does not write it to disk.
quantized_and_clustered_tflite_model = converter.convert()

# Evaluate Quantised and Clustered Model

In [None]:
# Evaluate the converted TFLite model on the test split with the project's
# quantised-model helper.
from evaluation import get_metrics_quantised
# Rebinds `metrics`: the clustered-only metrics computed earlier are replaced
# in memory from here on.
metrics = get_metrics_quantised(quantized_and_clustered_tflite_model, test_generator, dataset="rafdb")

In [None]:
# Show the metrics of the clustered and quantised model computed in the
# previous cell.
print("Clustered and quantised model performance:")
# Bare expression: displayed as the cell output when run in a notebook.
metrics

### Store results

In [None]:
# Persist the metrics of the clustered and quantised model, tagged with the
# cluster count and the run timestamp. Create the log directory first so that
# a missing folder does not raise FileNotFoundError after the whole pipeline
# has already run.
os.makedirs("logs_rafdb", exist_ok=True)
with open("logs_rafdb/clustered_and_quantised_model_metrics_c%d_%s" % (N_CLUSTERS, TIMESTAMP),
          'wb') as clustered_and_quantised_model_metrics_file:
    pickle.dump(metrics, clustered_and_quantised_model_metrics_file)