In [None]:
!pip install image-classifiers

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from classification_models.tfkeras import Classifiers
import matplotlib.pyplot as plt
import os
import time

from modelcompressionutils import eval_tflite_model, get_gzipped_model_size

In [None]:
from collections import namedtuple
from typing import NamedTuple, List

In [None]:
class ModelDescriptor:
    """A named reference to a saved model file that can be evaluated.

    Files ending in ``.tflite`` are evaluated through ``eval_tflite_model``;
    anything else is loaded with ``keras.models.load_model``.
    """

    def __init__(self, name: str, path: str):
        """Store the display ``name`` and file ``path`` of a model.

        Raises:
            Exception: if ``path`` does not exist on disk.
        """
        if not os.path.exists(path):
            raise Exception('Path for model {name} is not valid ({path})'
                            .format(name=name, path=path))
        self.name = name
        self.path = path

    def is_tflite_model(self) -> bool:
        """Return True when the model file has a ``.tflite`` extension."""
        return self.path.endswith('.tflite')

    def evalute_model(self, generator) -> float:
        """Evaluate the model on ``generator`` and return its accuracy.

        Args:
            generator: batches of (images, labels). In this notebook it is the
                categorical ``flow_from_directory`` iterator; the exact type
                accepted by ``eval_tflite_model`` is not visible here.

        Returns:
            For Keras models, the ``'accuracy'`` metric from ``model.evaluate``;
            for TFLite models, whatever ``eval_tflite_model`` returns.
        """
        if self.is_tflite_model():
            return eval_tflite_model(self.path, generator)

        model = keras.models.load_model(self.path)
        # Compile with a single explicit metric so evaluate() returns
        # [loss, accuracy] in a known order.
        model.compile(
            optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )
        # Use the generator passed in, not a notebook-level global.
        res = model.evaluate(generator, verbose=1)
        return res[1]

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    evaluate_model = evalute_model

In [None]:
def validate_model_paths(model_descriptors: List[ModelDescriptor]):
    """Check that every descriptor still points at an existing file.

    Descriptors are checked in list order and checking stops at the first
    missing path.

    Raises:
        Exception: naming the first model whose path does not exist.
    """
    missing = (desc for desc in model_descriptors if not os.path.exists(desc.path))
    first_missing = next(missing, None)
    if first_missing is not None:
        raise Exception('Path for model {name} is not valid ({path})'
                        .format(name=first_missing.name, path=first_missing.path))

In [None]:
class ModelPerformance(NamedTuple):
    """Benchmark record for a single model."""
    name: str
    # Wall-clock seconds; not comparable between TFLite and regular models.
    evaluation_time: float
    accuracy: float
    # This notebook stores sizes in megabytes (bytes / 10**6), so they are
    # floats rather than ints.
    size_on_disk: float
    zipped_size: float

In [None]:
# Build performance records with the typed ModelPerformance class rather than
# a second, untyped namedtuple with the same name and fields.
PerformanceFactory = ModelPerformance

In [None]:
# Root directory holding all model files; change it here to point the
# benchmark at a different mount point.
MODELS_DIR = '/media/pi/KINGSTON/models'

# (display name, path relative to MODELS_DIR), grouped by compression method.
_MODEL_FILES = [
    ('ResNet18', 'base_modelv6.h5'),
    ('Base_TL', 'Quantization/base_tflite_model.tflite'),
    # Distillation
    ('MobNetv2', 'Distillation/best_mobilenetv2.h5'),
    ('Dist_MobNetv2', 'Distillation/best_distilled_mobilenetv2.h5'),
    # Pruning
    ('PR', 'Pruning/pruned_model.h5'),
    ('PR_TL_Fp16', 'Pruning/fp16_quant_pruned.tflite'),
    ('PR_TL_Int8', 'Pruning/int_quant_pruned.tflite'),
    # Quantization
    ('QU_TL_Fp16', 'Quantization/quantized_fp16.tflite'),
    ('QU_TL_Int8', 'Quantization/quantized_int8.tflite'),
    # QAT
    ('QAT_Fp16', 'QAT/qat_fp16.tflite'),
    # NOTE(review): file is 'qat_in8' (not 'qat_int8') - confirm it matches the disk.
    ('QAT_Int8', 'QAT/qat_in8.tflite'),
    # Weight clustering
    ('CL_KM32', 'Weight clustering/clustered_model_kpp32.h5'),
    ('CL_KM32_TL', 'Weight clustering/clustered_model_kpp32_tflite.tflite'),
    ('CL_KM256', 'Weight clustering/clustered_model_kpp256.h5'),
    ('CL_KM256_TL', 'Weight clustering/clustered_model_kpp256_tflite.tflite'),
    ('CL_Lin32', 'Weight clustering/clustered_model_lin32.h5'),
    ('CL_Lin32_TL', 'Weight clustering/clustered_model_lin32_tflite.tflite'),
    # Combined Pruning, Weight clustering and QAT
    ('PCQ_TL_Fp16', 'Combined/quantized_fp16.tflite'),
    ('PCQ_TL_Int8', 'Combined/int8_quantized_model.tflite'),
]

# ModelDescriptor raises on the first missing file, so this cell fails fast.
model_list = [
    ModelDescriptor(name, os.path.join(MODELS_DIR, relative_path))
    for name, relative_path in _MODEL_FILES
]

In [None]:
# ModelDescriptor already checks each path when it is built; this re-check
# catches files that have gone missing since then.
validate_model_paths(model_descriptors=model_list)

## Model testing

In [None]:
# Classifiers.get returns (model constructor, preprocessing function); only the
# preprocessing function is used. Index the pair instead of unpacking into `_`,
# because assigning `_` stops IPython from updating its last-output variable.
preprocess_input = Classifiers.get('resnet18')[1]

In [None]:
# Generator settings for the validation data.
BATCH_SIZE = 16
SEED = 1  # only matters if the generator shuffles or augments
# Validation dataset root, laid out for flow_from_directory (one sub-folder per class).
VALIDATION_DS_PATH = '/media/pi/KINGSTON/CRC-VAL-HE-7K'

In [None]:
# Apply the ResNet18 preprocessing to every image; no augmentation is configured.
val_img_gen = keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=preprocess_input,
)
# shuffle=False keeps a fixed sample order, so every model sees the same batches.
val_dg = val_img_gen.flow_from_directory(
    directory=VALIDATION_DS_PATH,
    target_size=(224, 224),
    class_mode='categorical',
    batch_size=BATCH_SIZE,
    shuffle=False,
    seed=SEED,
)

In [None]:
# Sizes are reported in decimal megabytes (10**6 bytes).
BYTES_PER_MB = 10 ** 6


def benchmark_model(model_desc: ModelDescriptor, generator):
    """Evaluate one model and collect its performance record.

    Args:
        model_desc: the model to benchmark.
        generator: validation batches, passed straight to ``evalute_model``.

    Returns:
        A ``PerformanceFactory`` record with the accuracy, the wall-clock
        evaluation time in seconds and the file sizes in megabytes.
    """
    print(f'Processing {model_desc.name}')
    # The timed span is the whole evalute_model() call, which for Keras models
    # also loads and compiles the model.
    start_time = time.perf_counter()
    accuracy = model_desc.evalute_model(generator)
    evaluation_time = time.perf_counter() - start_time

    return PerformanceFactory(
        name=model_desc.name,
        evaluation_time=evaluation_time,
        accuracy=accuracy,
        size_on_disk=os.path.getsize(model_desc.path) / BYTES_PER_MB,
        # NOTE(review): assumes get_gzipped_model_size returns bytes, like os.path.getsize.
        zipped_size=get_gzipped_model_size(model_desc.path) / BYTES_PER_MB,
    )


# Every model is evaluated on the same validation generator.
results = [benchmark_model(model_desc, val_dg) for model_desc in model_list]

In [None]:
# Metrics to plot: every record field except the model name. Read them from the
# record type rather than results[0], so an empty results list does not raise
# IndexError.
fields = [field for field in PerformanceFactory._fields if field != 'name']

In [None]:
# Single source of truth for the compression-method colour scheme, shared by
# the bar colouring and the legend. Models whose names match no prefix are
# treated as plain (uncompressed) models.
_PLAIN_MODEL = ('Plain Model', 'blue')
# (model-name prefix, legend label, colour), in legend order. No prefix here is
# a prefix of another, so at most one entry can match a given name.
_METHOD_COLORS = [
    ('Dist', 'Distillation', 'cyan'),
    ('PR', 'Pruning', 'green'),
    ('QU', 'Post Training Quantization', 'orangered'),
    ('QAT', 'Quantization Aware Training (QAT)', 'darkviolet'),
    ('CL', 'Weight Clustering', 'gold'),
    ('PCQ', 'Pruning + Weight Clustering + QAT (PCQAT)', 'purple'),
]


def model_color_code(model_name: str) -> str:
    """Return the bar colour for a model, chosen by the prefix of its name.

    Names that match no known method prefix (e.g. 'ResNet18', 'Base_TL')
    get the plain-model colour.
    """
    for prefix, _label, color in _METHOD_COLORS:
        if model_name.startswith(prefix):
            return color
    return _PLAIN_MODEL[1]


def plot_method_color_legend():
    """Show a stand-alone legend mapping each compression method to its colour."""
    # Import explicitly: a bare `import matplotlib` does not guarantee that the
    # `matplotlib.patches` submodule is loaded.
    from matplotlib.patches import Patch

    # Legend handles are built manually, as in https://stackoverflow.com/a/53615732
    plain_label, plain_color = _PLAIN_MODEL
    handles = [Patch(color=plain_color, label=plain_label)]
    handles += [Patch(color=color, label=label) for _prefix, label, color in _METHOD_COLORS]
    plt.legend(handles=handles, loc='center', markerscale=2.0, fontsize='xx-large')
    # The figure contains only the legend, so hide the empty axes.
    plt.gca().set_axis_off()
    plt.show()

In [None]:
# Show the colour key once, before the per-metric bar charts.
plot_method_color_legend();

In [None]:
# One bar chart per metric, with bars coloured by compression method.
# Model names and colours do not depend on the metric, so compute them once.
names = [perf.name for perf in results]
colors = [model_color_code(perf.name) for perf in results]

for field in fields:
    values = [getattr(perf, field) for perf in results]

    fig = plt.figure(figsize=(20, 5))
    ax = fig.add_axes([0, 0, 1, 1])
    bars = ax.bar(names, values, color=colors)
    # Label each bar with its value, horizontally centred and sitting on the bar top.
    for rect, val in zip(bars, values):
        height = float(rect.get_height())
        ax.text(rect.get_x() + rect.get_width() / 2.0, height, '%.2f' % val,
                ha='center', va='bottom', fontsize=10)
    ax.set_title(field)
    ax.set_xlabel('Model name')
    ax.set_ylabel(field)
    plt.show()