# Importing Libraries

In [None]:
from tensorflow.keras.models import model_from_json
from fer_model import get_fer_model
import tensorflow as tf
import tempfile 

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from livelossplot import PlotLossesKerasTF
from tensorflow.keras.optimizers import Adam

import os
import zipfile

import tensorflow_model_optimization as tfmot
import pickle

from evaluation import get_metrics_rafdb

In [None]:
import time

# Whole-second Unix time, used below to tag the artefacts written by this run.
TIMESTAMP = round(time.time())
print("Timestamp is", TIMESTAMP)

# Keep a SPARSITY that is already defined in this namespace (e.g. set before
# this cell runs); otherwise fall back to 0.6.
SPARSITY = locals().get("SPARSITY", 0.6)
print("SPARSITY =", SPARSITY)

In [None]:
import os

# Presumably works around the "duplicate OpenMP runtime" abort seen with some
# TensorFlow/NumPy builds -- TODO confirm it is still needed.
# NOTE(review): TensorFlow is already imported in the first cell, so this runs
# after that import; confirm the setting still takes effect at this point.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

# Loading the RAFDB Dataset

In [None]:
from data import load_rafdb

# Build one generator per split, "train" first and "test" second.
train_generator, test_generator = map(load_rafdb, ("train", "test"))

# Load Baseline Model

In [None]:
# Build the FER network and restore the stored baseline weights.
model = get_fer_model(
    out_classes=7,
    input_channels=1,
    input_size=100,
)
model.load_weights("weights_rafdb/model_weights_1626211720.h5")

In [None]:
print("Baseline model performance:")
# Use a context manager so the file handle is closed once the metrics are read.
# NOTE: pickle.load can execute arbitrary code -- only load metric files that
# were produced by this project.
with open("logs_rafdb/model_metrics_1626211720", "rb") as file:
    baseline_metrics = pickle.load(file)
for metric_name, metric_value in baseline_metrics.items():
    print(metric_name, "-->", metric_value)

# Apply Pruning

In [None]:
def prune(model, sparsity=0.8, epochs=2):
    """Prune `model` at a constant sparsity, fine-tune it and return the stripped model.

    The model is wrapped with `tfmot.sparsity.keras.prune_low_magnitude` using a
    `ConstantSparsity` schedule, fine-tuned on the notebook-level generators,
    reloaded from the checkpoint with the best validation accuracy, and finally
    passed through `strip_pruning` and re-compiled.

    Relies on the notebook-level names `train_generator`, `test_generator` and
    `TIMESTAMP`.

    Args:
        model: Keras model to prune.
            NOTE(review): the pruning wrappers presumably share weights with
            `model`, so fine-tuning may modify the passed-in model -- confirm.
        sparsity: Target sparsity handed to `ConstantSparsity`; also used
            (as a percentage) in the checkpoint file name.
        epochs: Number of fine-tuning epochs. Defaults to 2.

    Returns:
        The compiled model returned by `strip_pruning`, holding the
        best-checkpoint weights.
    """
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(
            target_sparsity=sparsity, begin_step=0, end_step=-1, frequency=32
        )
    }
    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    # `learning_rate` is the supported argument name; `lr` is deprecated.
    opt = Adam(learning_rate=0.0005)
    model_for_pruning.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

    # Fine-Tuning
    logdir = tempfile.mkdtemp()
    print("Saving logs to:", logdir)

    steps_per_epoch = train_generator.n // train_generator.batch_size
    validation_steps = test_generator.n // test_generator.batch_size

    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                                  patience=2, min_lr=0.00001, mode='auto')
    # round() instead of relying on %d truncation: 0.29 * 100 evaluates to
    # 28.999999999999996, which "%d" alone would render as 28.
    checkpoint_name = "weights_rafdb/pruned_model_weights_s%d_%s.h5" % (round(sparsity * 100), TIMESTAMP)
    checkpoint = ModelCheckpoint(checkpoint_name,
                                 monitor='val_accuracy',
                                 save_weights_only=True,
                                 save_best_only=True,
                                 mode='max', verbose=1)
    callbacks = [PlotLossesKerasTF(),
                 checkpoint,
                 reduce_lr,
                 # Advances the pruning step during fit().
                 tfmot.sparsity.keras.UpdatePruningStep(),
                 tfmot.sparsity.keras.PruningSummaries(log_dir=logdir)]

    model_for_pruning.fit(
        x=train_generator,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_data=test_generator,
        validation_steps=validation_steps,
        callbacks=callbacks
    )

    # Only use the best weights: reload the checkpoint before stripping the
    # pruning wrappers.
    model_for_pruning.load_weights(checkpoint_name)
    model_for_pruning = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
    # Compile again because strip_pruning returns a new model object.
    model_for_pruning.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
    return model_for_pruning

In [None]:
# Prune and fine-tune the loaded model at the configured sparsity.
pruned_model = prune(model=model, sparsity=SPARSITY)

# Evaluate Pruned Model

In [None]:
# Compute the metrics of the pruned model on the test generator.
print("Model performance after pruning:")
metrics = get_metrics_rafdb(pruned_model, test_generator)
# Bare expression so the notebook renders the result as the cell output.
metrics

### Store results

In [None]:
# Persist the pruned-model metrics, tagged with sparsity percentage and run timestamp.
# round() instead of relying on %d truncation: 0.29 * 100 evaluates to
# 28.999999999999996, which "%d" alone would render as 28.
with open("logs_rafdb/pruned_model_metrics_s%d_%s" % (round(SPARSITY * 100), TIMESTAMP), 'wb') as pruned_model_metrics_file:
    pickle.dump(metrics, pruned_model_metrics_file)

# Apply Quantisation to the Pruned Model 

In [None]:
# Convert the pruned Keras model into a TensorFlow Lite model.
converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)
# Enabling the default optimisation set is what applies quantisation here.
# NOTE(review): no representative dataset is supplied, so this is presumably
# post-training dynamic-range quantisation -- confirm.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()

# Evaluate Pruned and Quantised Model

In [None]:
# NOTE(review): this import sits mid-notebook, while the other `evaluation`
# import lives in the first cell; consider moving it there.
from evaluation import get_metrics_quantised
# Evaluate the converted TFLite model on the test generator. `metrics` is
# rebound here, replacing the pruned-only results computed in the earlier cell.
metrics = get_metrics_quantised(quantized_and_pruned_tflite_model, test_generator, dataset="rafdb")

In [None]:
# Show the metrics of the pruned + quantised model computed in the previous cell.
print("Pruned and quantised model performance:")
# Bare expression so the notebook renders the result as the cell output.
metrics

### Store results

In [None]:
# Persist the pruned + quantised metrics, tagged with sparsity percentage and run timestamp.
# round() instead of relying on %d truncation: 0.29 * 100 evaluates to
# 28.999999999999996, which "%d" alone would render as 28.
with open("logs_rafdb/pruned_and_quantised_model_metrics_s%d_%s" % (round(SPARSITY * 100), TIMESTAMP),
          'wb') as pruned_and_quantised_model_metrics_file:
    pickle.dump(metrics, pruned_and_quantised_model_metrics_file)