# Quantization Aware Training using the Model Compression Toolkit - example in Keras


## Overview
This tutorial will show how to use the Quantization Aware Training API of the Model Compression Toolkit. We will train a model on the MNIST dataset and quantize it with the Model Compression Toolkit QAT API.
[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/example_keras_qat.ipynb)

## Setup
Install relevant packages

In [4]:
! pip install -q tensorflow
! pip install -q model-compression-toolkit 

In [5]:
import tensorflow as tf
from keras.datasets import mnist
from keras import Model, layers, datasets
import model_compression_toolkit as mct
import numpy as np

## Init TargetPlatformModel
Define a TargetPlatformModel (TP Model) that uses power-of-two quantization with 2 bits for weights and 3 bits for activations.

In [None]:
def get_tpc():
    """Build the TargetPlatformCapabilities used for QAT in this tutorial.

    The default operator configuration applies power-of-two quantization to
    both weights (2 bits) and activations (3 bits). Flatten and
    Dropout layers are attached to a "NoQuantization" operator set whose
    options have weights and activation quantization switched off.

    Returns:
        The TargetPlatformCapabilities object wrapping the TP model.
    """
    tp = mct.target_platform
    power_of_two = tp.QuantizationMethod.POWER_OF_TWO

    # Base configuration shared by every operator that is not overridden.
    base_op_config = tp.OpQuantizationConfig(
        activation_quantization_method=power_of_two,
        weights_quantization_method=power_of_two,
        activation_n_bits=3,
        weights_n_bits=2,
        weights_per_channel_threshold=True,
        enable_weights_quantization=True,
        enable_activation_quantization=True,
        quantization_preserving=False,
        fixed_scale=1.0,
        fixed_zero_point=0,
        weights_multiplier_nbits=0,
    )
    base_options = tp.QuantizationConfigOptions([base_op_config])
    platform_model = tp.TargetPlatformModel(base_options)

    with platform_model:
        # Derive the "no quantization" options from the model's defaults
        # while the model is the active context.
        unquantized_options = tp.get_default_quantization_config_options().clone_and_edit(
            enable_weights_quantization=False,
            enable_activation_quantization=False,
        )
        tp.OperatorsSet("NoQuantization", unquantized_options)

    capabilities = tp.TargetPlatformCapabilities(platform_model)
    with capabilities:
        # Flatten and Dropout do not need to be quantized.
        tp.OperationsSetToLayers("NoQuantization", [layers.Flatten, layers.Dropout])

    return capabilities


## Init Keras model

In [None]:
# Number of output classes (one per MNIST digit) and the shape of a single
# input image: 28x28 pixels with one channel.
num_classes = 10
input_shape = (28, 28, 1)

# Functional-API model: two strided convolutions, then a dense classifier
# head with dropout before each dense layer.
_input = layers.Input(shape=input_shape)
x = layers.Conv2D(16, 3, strides=2, padding='same', activation='relu')(_input)
x = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(x)
x = layers.Flatten()(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.5)(x)
# Softmax output yields one probability per class.
x = layers.Dense(num_classes, activation='softmax')(x)
model = Model(inputs=_input, outputs=x)

# Print the layer-by-layer structure of the float model.
model.summary()

## Init MNIST dataset

In [None]:
# Load MNIST, already split into train and test sets
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()

# Convert the images to float32 and divide by 255 to normalize them
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255

# Append a trailing channels axis so each image matches the model's
# input_shape of (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

# Convert the class-label vectors to one-hot (binary class) matrices, as
# required by the categorical cross-entropy loss used for training
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)


## Train a Keras classifier model on MNIST

In [None]:
# Train the float (not yet quantized) model.
# batch_size and epochs are reused later for the QAT training run.
batch_size = 128
epochs = 15

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# validation_split=0.2 holds out 20% of the training data for validation.
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.2)

# Evaluate the float model on the test set; score[1] is the accuracy
# metric requested in compile() (score[0] is the loss).
score = model.evaluate(x_test, y_test, verbose=0)
print(f"Float model test accuracy: {score[1]:02.4f}")


## Prepare model for Hardware-Friendly Quantization Aware Training with MCT
MCT takes the float model and quantizes it in a post-training quantization (PTQ) fashion. It then returns a QAT-ready model to the user for Quantization Aware Training.

In [None]:
def gen_representative_dataset():
    """Return a zero-argument callable producing one representative batch per call.

    Each call returns a single-element list holding the next image of
    ``x_train`` with a leading axis of size 1 added. Once every image of
    ``x_train`` has been consumed, a further call raises StopIteration.
    """
    def _single_image_batches():
        # Lazily walk the training images, one image per batch.
        for sample in x_train:
            yield [sample[np.newaxis, ...]]

    batch_stream = _single_image_batches()
    return batch_stream.__next__


# Hand the trained float model, the representative-dataset callable and the
# target platform capabilities to MCT. Three values are returned: the model
# to train with QAT, a second value that is discarded here, and
# custom_objects (not used in the cells shown here).
# NOTE(review): custom_objects is presumably needed when saving/loading the
# QAT model — confirm against the MCT documentation.
qat_model, _, custom_objects = mct.keras_quantization_aware_training_init(model,
                                                                          gen_representative_dataset(),
                                                                          core_config=mct.CoreConfig(n_iter=10),
                                                                          target_platform_capabilities=get_tpc())

# The returned model is compiled here before evaluate()/fit() are called on it.
qat_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# Accuracy before any QAT fine-tuning, i.e. the post-training-quantization result.
score = qat_model.evaluate(x_test, y_test, verbose=0)
print(f"PTQ model test accuracy: {score[1]:02.4f}")

## User Quantization Aware Training

In [None]:
# Fine-tune the QAT model with the same data, batch size, epoch count and
# validation split that were used for the float model.
qat_model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.2)

# Test-set accuracy after Quantization Aware Training.
score = qat_model.evaluate(x_test, y_test, verbose=0)
print(f"QAT model test accuracy: {score[1]:02.4f}")


In [None]:
## Finalize QAT model
Remove the QuantizeWrapper layers and leave only layers with quantized weights (FakeQuant values).

In [None]:
# Convert the trained QAT model into its final quantized form.
quantized_model = mct.keras_quantization_aware_training_finalize(qat_model)

# The finalized model is a new model object, so it is compiled again before
# being evaluated on the test set.
quantized_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
score = quantized_model.evaluate(x_test, y_test, verbose=0)
print(f"Quantized model test accuracy: {score[1]:02.4f}")

Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.