# Quantization Aware Training using the Model Compression Toolkit - example in Keras


## Overview
This tutorial will show how to use the Quantization Aware Training API of the Model Compression Toolkit. We will train a model on the MNIST dataset and quantize it with the Model Compression Toolkit QAT API.

[Run this tutorial in Google Colab](https://colab.research.google.com/github/elad-c/model_optimization/blob/main/tutorials/example_keras_qat.ipynb)

## Setup
Install relevant packages

In [4]:
! pip install -q tensorflow
! pip install -q model-compression-toolkit 

[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.9.1 requires protobuf<3.20,>=3.9.2, but you have protobuf 3.20.1 which is incompatible.[0m[31m
[0m

In [5]:
import tensorflow as tf
from keras.datasets import mnist
from keras import Model, layers, datasets
import model_compression_toolkit as mct
from model_compression_toolkit.qat.keras.quantization_facade import DEFAULT_KERAS_TPC as default_tpc
import numpy as np

OSError: /data/projects/swat/envs/eladc/jupyterlab/lib/python3.8/site-packages/tensorflow/python/platform/../../core/platform/_cpu_feature_guard.so: cannot open shared object file: No such file or directory

## Init Keras model

In [None]:
# Problem dimensions: ten output classes and 28x28 single-channel inputs.
num_classes = 10
input_shape = (28, 28, 1)

# The network is a plain chain of layers, so each layer is applied to the
# output of the previous one, starting from the input tensor. Layers are
# listed in the exact order in which they are applied.
_input = layers.Input(shape=input_shape)
x = _input
for _layer in (
    # Two stride-2 convolutions with ReLU activations.
    layers.Conv2D(16, 3, strides=2, padding='same', activation='relu'),
    layers.Conv2D(32, 3, strides=2, padding='same', activation='relu'),
    # Dense head with dropout before each dense layer.
    layers.Flatten(),
    layers.Dropout(0.5),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),
    # Softmax output with one unit per class.
    layers.Dense(num_classes, activation='softmax'),
):
    x = _layer(x)

# Wrap the input/output tensors into a trainable Keras model and print it.
model = Model(inputs=_input, outputs=x)

model.summary()

## Init MNIST dataset

In [None]:
# Fetch MNIST, already split into train and test sets.
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()

# Convert the images to float32, divide by 255 to normalize them, and append
# a trailing channels axis so each image matches the model's `input_shape`.
x_train = (x_train.astype("float32") / 255)[..., np.newaxis]
x_test = (x_test.astype("float32") / 255)[..., np.newaxis]

# Turn the integer class labels into binary (one-hot) class matrices.
y_train, y_test = (
    tf.keras.utils.to_categorical(_labels, num_classes)
    for _labels in (y_train, y_test)
)


## Train a Keras classifier model on MNIST

In [None]:
# Hyper-parameters for the float training run; they are reused later for the
# quantization aware training run.
batch_size, epochs = 128, 15

# Train the float model, holding out 20% of the training data for validation.
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])
model.fit(
    x_train,
    y_train,
    epochs=epochs,
    batch_size=batch_size,
    validation_split=0.2,
)

# Evaluate the float model on the test set. `score` holds the loss followed
# by the compiled metrics, so index 1 is the accuracy.
score = model.evaluate(x_test, y_test, verbose=0)
print(f"Float model test accuracy: {score[1]:02.4f}")


## Prepare model for Hardware-Friendly Quantization Aware Training with MCT
The MCT takes the float model and quantizes it in a post-training quantization (PTQ) fashion. It then returns a QAT-ready model to the user for Quantization Aware Training.

In [None]:
def gen_representative_dataset():
    """Build a callable that serves training images one at a time.

    Returns:
        The bound ``__next__`` method of a generator over ``x_train``. Each
        call returns a one-element list holding the next image with a new
        leading axis of size 1. ``StopIteration`` is raised once every image
        in ``x_train`` has been served.
    """
    def _single_image_batches():
        for image in x_train:
            # Prepend an axis of size 1 so the image becomes a batch of one.
            batch = np.expand_dims(image, axis=0)
            yield [batch]

    batches = _single_image_batches()
    return batches.__next__


# Set quantization params to: 2 bits for weights, 3 bits for activations.
# Note that this edits the imported default target platform capabilities
# object in place.
_base_config = default_tpc.tp_model.default_qco.base_config
_base_config.weights_n_bits = 2
_base_config.activation_n_bits = 3

# Hand the trained float model and the representative-data callable to MCT.
# Three values come back; the second one is not needed here and is discarded.
qat_model, _, custom_objects = mct.keras_quantization_aware_training_init(
    model,
    gen_representative_dataset(),
    core_config=mct.CoreConfig(n_iter=10),
    target_platform_capabilities=default_tpc,
)

# Evaluate the returned model on the test set before any QAT fine-tuning;
# this is the number reported as the PTQ accuracy.
qat_model.compile(optimizer="adam",
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])
score = qat_model.evaluate(x_test, y_test, verbose=0)
print(f"PTQ model test accuracy: {score[1]:02.4f}")

## User Quantization Aware Training

In [None]:
# Fine-tune the QAT-ready model, reusing the `batch_size` and `epochs` values
# and holding out 20% of the training data for validation.
qat_model.fit(
    x_train,
    y_train,
    epochs=epochs,
    batch_size=batch_size,
    validation_split=0.2,
)

# Evaluate on the test set after quantization aware training.
score = qat_model.evaluate(x_test, y_test, verbose=0)
print(f"QAT model test accuracy: {score[1]:02.4f}")


In [None]:
## Finalize QAT model
Remove the QuantizeWrapper layers and leave only layers with quantized weights (FakeQuant values).

In [None]:
# Convert the trained QAT model into its final quantized form.
quantized_model = mct.keras_quantization_aware_training_finalize(qat_model)

# The finalized model is compiled with the same loss, optimizer and metric
# as before, then evaluated on the test set.
quantized_model.compile(optimizer="adam",
                        loss="categorical_crossentropy",
                        metrics=["accuracy"])
score = quantized_model.evaluate(x_test, y_test, verbose=0)
print(f"Quantized model test accuracy: {score[1]:02.4f}")