# Quantization Aware Training using the Model Compression Toolkit - example in Keras
[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/keras/example_keras_qat.ipynb)
## Overview
This tutorial demonstrates how to use the Quantization Aware Training (QAT) API of the Model Compression Toolkit (MCT). We will train a neural network on the MNIST dataset, quantize it with the MCT QAT API for efficient hardware deployment, and fine-tune the quantized model to recover the accuracy lost to quantization.

## Summary
In this tutorial, we will cover:

1. **Training a Keras model on MNIST:** We'll begin by constructing a simple neural network and training it on the MNIST dataset. 
2. **Configuring Target Platform Capabilities (TPC):** Define the quantization settings for weights and activations.
3. **Preparing the Model for QAT:** Convert the floating-point model into a QAT-ready model using MCT. 
4. **Training the Model with QAT:** Perform quantization-aware training to preserve model accuracy.
5. **Evaluating and Exporting the Quantized Model:** Finalize and export the optimized quantized model for deployment.

## Setup
Install the relevant packages:

```python
TF_VER = '2.14.0'
!pip install -q tensorflow~={TF_VER}
```

```python
import importlib.util

# Install MCT only if it is not already available
if not importlib.util.find_spec('model_compression_toolkit'):
    !pip install model_compression_toolkit
```

```python
import tensorflow as tf
from keras import Model, layers, datasets
import model_compression_toolkit as mct
import numpy as np
```

## Loading and Preprocessing MNIST
Let's load the train and test splits of the MNIST dataset and preprocess them:

```python
num_classes = 10
input_shape = (28, 28, 1)

# Load the MNIST dataset
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

# Normalize the images to [0, 1] range
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

# Add a channel axis to the data
train_images = np.expand_dims(train_images, -1)
test_images = np.expand_dims(test_images, -1)

# Convert class vectors to one-hot (binary class) matrices
train_labels = tf.keras.utils.to_categorical(train_labels, num_classes)
test_labels = tf.keras.utils.to_categorical(test_labels, num_classes)
```
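For readers less familiar with these preprocessing helpers, the plain-Python sketch below mirrors what the two steps do to a single sample: 8-bit pixel values are scaled into [0, 1], and an integer label becomes a length-`num_classes` one-hot vector, which is what `to_categorical` produces for each label. The helper names `normalize_pixels` and `one_hot` are hypothetical and used only for illustration.

```python
def normalize_pixels(pixels):
    """Scale 8-bit pixel intensities (0..255) to floats in [0, 1]."""
    return [p / 255.0 for p in pixels]


def one_hot(label, num_classes=10):
    """Return a one-hot list: 1.0 at index `label`, 0.0 elsewhere."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} outside [0, {num_classes})")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


# Example: a tiny "image" of three pixels and the label 3.
print(normalize_pixels([0, 51, 255]))  # [0.0, 0.2, 1.0]
print(one_hot(3))
```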

## Creating a Keras Model
In this section, we create a simple Keras model to demonstrate the QAT process. The model consists of two convolutional layers, two dense layers, and dropout layers for regularization.

```python
def create_model():
    _input = layers.Input(shape=input_shape)
    x = layers.Conv2D(16, 3, strides=2, padding='same', activation='relu')(_input)
    x = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(x)
    x = layers.Flatten()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=_input, outputs=x)
    model.summary()
    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
    return model
```
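The layer shapes and parameter counts printed by `model.summary()` can be checked by hand. The sketch below assumes Keras' `'same'` padding rule, where the spatial output size is `ceil(input / stride)`, and counts one bias per filter or unit (the Keras default `use_bias=True`). Flatten and Dropout have no parameters. The helper function names are hypothetical.

```python
import math


def conv_same_output(size, stride):
    """Spatial size after a 'same'-padded convolution: ceil(size / stride)."""
    return math.ceil(size / stride)


def conv_params(kernel, in_ch, out_ch):
    """Kernel weights plus one bias per output channel."""
    return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(in_units, out_units):
    """Weight matrix plus one bias per output unit."""
    return in_units * out_units + out_units


h = conv_same_output(28, 2)   # after the first Conv2D: 14
h = conv_same_output(h, 2)    # after the second Conv2D: 7
flat = h * h * 32             # Flatten output length: 1568

params = {
    "conv1": conv_params(3, 1, 16),     # 160
    "conv2": conv_params(3, 16, 32),    # 4640
    "dense1": dense_params(flat, 128),  # 200832
    "dense2": dense_params(128, 10),    # 1290
}
print(flat, sum(params.values()))  # 1568 206922
```

Almost all of the roughly 207k parameters sit in the first Dense layer, so that layer dominates the model size before and after quantization.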

## Training the Model on MNIST
Next, we train the floating-point model on the preprocessed MNIST dataset.

```python
epochs = 6
batch_size = 128

# Train and evaluate the model
model = create_model()
model.fit(train_images, train_labels, epochs=epochs, batch_size=batch_size, validation_data=(test_images, test_labels))
model.evaluate(test_images, test_labels)
```

## Preparing the Model for Hardware-Friendly Quantization Aware Training with MCT
### Target Platform Capabilities
MCT optimizes the model for dedicated hardware. This is done using TPC (for more details, please visit our [documentation](https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/modules/target_platform_capabilities.html)). In this tutorial, we use a TPC configuration that applies 2-bit quantization for weights and 3-bit quantization for activations.

If desired, you can skip this step and directly use the pre-configured [`get_target_platform_capabilities`](https://sonysemiconductorsolutions.github.io/mct-model-optimization/api/api_docs/methods/get_target_platform_capabilities.html) function to obtain an initialized TPC.
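To build intuition for the kernel setting in the TPC below (symmetric, 2 bits, per-channel threshold), here is a plain-Python sketch of symmetric uniform fake quantization. It assumes the common signed scheme in which an n-bit quantizer with threshold `t` has step size `t / 2^(n-1)` and integer levels from `-2^(n-1)` to `2^(n-1) - 1`, and it simply takes each channel's threshold as that channel's maximum absolute weight. MCT selects thresholds with its own error criteria, so this illustrates the quantization grid, not MCT's exact algorithm.

```python
def symmetric_fake_quant(values, threshold, n_bits):
    """Round values to a signed, symmetric n-bit grid and return floats."""
    step = threshold / 2 ** (n_bits - 1)
    q_min, q_max = -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1
    out = []
    for v in values:
        q = round(v / step)              # nearest integer level
        q = max(q_min, min(q_max, q))    # clip to the representable range
        out.append(q * step)
    return out


def per_channel_fake_quant(channels, n_bits=2):
    """Quantize each output channel with its own max-abs threshold."""
    result = []
    for ch in channels:
        threshold = max(abs(v) for v in ch) or 1.0  # avoid a zero threshold
        result.append(symmetric_fake_quant(ch, threshold, n_bits))
    return result


kernel = [[0.8, -0.3, 0.1, -0.8],     # channel 0, threshold 0.8
          [0.05, 0.02, -0.04, 0.01]]  # channel 1, threshold 0.05
for ch in per_channel_fake_quant(kernel, n_bits=2):
    print(ch)
# [0.4, -0.4, 0.0, -0.8]
# [0.025, 0.025, -0.05, 0.0]
```

With 2 bits each channel can only take the four values `{-t, -t/2, 0, t/2}`, so the largest positive weight is clipped to `t/2`. Per-channel thresholds keep a small-magnitude channel (channel 1 here) from collapsing to zero, which a single shared threshold of 0.8 would cause.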

```python
from mct_quantizers import QuantizationMethod
from model_compression_toolkit.constants import FLOAT_BITWIDTH
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import schema

def get_tpc():
    """
    Assumes target hardware that uses a power-of-2 threshold for activations and
    a symmetric threshold for the weights. The activations are quantized to 3 bits, and the kernel weights
    are quantized to 2 bits. Our assumed hardware does not require quantization of some layers
    (e.g. Flatten & Dropout).
    This function generates a TargetPlatformCapabilities object with the above specification.

    Returns:
         TargetPlatformCapabilities object
    """

    # define a default quantization config for all non-specified weights attributes.
    default_weight_attr_config = schema.AttributeQuantizationConfig(
        weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
        weights_n_bits=8,
        weights_per_channel_threshold=False,
        enable_weights_quantization=False,
        lut_values_bitwidth=None)

    # define a quantization config to quantize the kernel (for layers where there is a kernel attribute).
    kernel_base_config = schema.AttributeQuantizationConfig(
        weights_quantization_method=QuantizationMethod.SYMMETRIC,
        weights_n_bits=2,
        weights_per_channel_threshold=True,
        enable_weights_quantization=True,
        lut_values_bitwidth=None)

    # define a quantization config to quantize the bias (for layers where there is a bias attribute).
    bias_config = schema.AttributeQuantizationConfig(
        weights_quantization_method=QuantizationMethod.POWER_OF_TWO,
        weights_n_bits=FLOAT_BITWIDTH,
        weights_per_channel_threshold=False,
        enable_weights_quantization=False,
        lut_values_bitwidth=None)

    # Create a default OpQuantizationConfig where we use default_weight_attr_config as the default
    # AttributeQuantizationConfig for weights with no specific AttributeQuantizationConfig.
    # MCT will compress a layer's kernel and bias according to the configurations that are
    # set in KERNEL_ATTR and BIAS_ATTR that are passed in attr_weights_configs_mapping.
    default_config = schema.OpQuantizationConfig(
        default_weight_attr_config=default_weight_attr_config,
        attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config,
                                      BIAS_ATTR: bias_config},
        activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
        activation_n_bits=3,
        supported_input_activation_n_bits=8,
        enable_activation_quantization=True,
        quantization_preserving=False,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=None,
        signedness=schema.Signedness.AUTO)

    # Set default QuantizationConfigOptions in new TargetPlatformCapabilities to be used when no other
    # QuantizationConfigOptions is set for an OperatorsSet.
    default_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=[default_config])
    no_quantization_config = (default_configuration_options.clone_and_edit(enable_activation_quantization=False)
                              .clone_and_edit_weight_attribute(enable_weights_quantization=False))

    # Disable both weights and activation quantization for Dropout and Flatten layers.
    operator_set = []

    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.DROPOUT, qc_options=no_quantization_config))
    operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.FLATTEN, qc_options=no_quantization_config))

    tpc = schema.TargetPlatformCapabilities(default_qco=default_configuration_options,
                                            tpc_minor_version=1,
                                            tpc_patch_version=0,
                                            tpc_platform_type="custom_qat_notebook_tpc",
                                            operator_set=tuple(operator_set))
    return tpc
```
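The activations above use a power-of-two threshold with 3 bits. The plain-Python sketch below shows what that constraint means under two assumptions of ours: the threshold is the smallest power of two covering the observed maximum, and, because these activations are ReLU outputs, the grid is unsigned with `2^3 = 8` levels in `[0, t)`. In the TPC, `signedness=schema.Signedness.AUTO` lets MCT decide signedness itself, so this is an illustration of the grid, not MCT's exact procedure.

```python
import math


def power_of_two_threshold(max_abs):
    """Smallest power of two that is >= max_abs (max_abs must be > 0)."""
    return 2.0 ** math.ceil(math.log2(max_abs))


def unsigned_fake_quant(values, threshold, n_bits):
    """Quantize non-negative values to 2**n_bits levels in [0, threshold)."""
    step = threshold / 2 ** n_bits
    q_max = 2 ** n_bits - 1
    return [min(q_max, max(0, round(v / step))) * step for v in values]


relu_outputs = [0.0, 0.3, 1.1, 2.6, 3.9]
t = power_of_two_threshold(max(relu_outputs))  # 4.0
print(t, unsigned_fake_quant(relu_outputs, t, n_bits=3))
# 4.0 [0.0, 0.5, 1.0, 2.5, 3.5]
```

Restricting the threshold to a power of two makes the step size a power of two as well (0.5 here), which lets hardware rescale with bit shifts instead of multiplications; the price is a coarser fit to the data range than an arbitrary threshold would give.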


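### Representative Dataset
For quantization with MCT, we need to define a representative dataset, which MCT uses to calibrate the quantization parameters during the initial post-training quantization (PTQ) step. `representative_data_gen` returns a generator function; each item it yields is a list holding one batch of input images (here, a batch containing a single image):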

```python
n_iter = 10

def representative_data_gen():
    def _generator():
        for _ind in range(n_iter):
            yield [train_images[_ind][np.newaxis, ...]]
    return _generator
```
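Because `representative_data_gen()` returns the inner `_generator` function rather than a generator object, the consumer can call it again to start a fresh pass over the data. The plain-Python sketch below mimics that pattern with lists standing in for image batches, and uses it for a toy calibration step that records the largest absolute input value across all batches. The names `make_representative_dataset` and `calibrate_max_abs` are hypothetical; MCT's actual statistics collection is more involved.

```python
def make_representative_dataset(samples, n_iter):
    """Return a callable that yields `n_iter` single-sample batches per call."""
    def _generator():
        for i in range(n_iter):
            # Each item is a list with one entry per model input (here one),
            # and that entry is a batch holding a single sample.
            yield [[samples[i % len(samples)]]]
    return _generator


def calibrate_max_abs(dataset_fn, passes=2):
    """Toy calibration: track the max |value| over several passes."""
    max_abs = 0.0
    for _ in range(passes):
        for batch in dataset_fn():  # calling the function gives a fresh generator
            for image in batch[0]:
                max_abs = max(max_abs, max(abs(v) for v in image))
    return max_abs


images = [[0.0, 0.25], [0.5, 1.0], [0.75, 0.1]]  # flattened toy "images"
gen_fn = make_representative_dataset(images, n_iter=3)
print(len(list(gen_fn())), calibrate_max_abs(gen_fn))  # 3 1.0
```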

### Creating a QAT-Ready Model with MCT
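MCT converts the floating-point model into a quantized model using post-training quantization. The returned model includes trainable quantizers and is ready for fine-tuning, which makes it a "QAT-ready" model. The call also returns a dictionary of `custom_objects` that Keras needs when loading a saved model containing these quantizers.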

```python
qat_model, _, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(
    model,
    representative_data_gen(),
    target_platform_capabilities=get_tpc())
qat_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"], run_eagerly=True)
```

Let's evaluate the model's accuracy after this initial post-training quantization, before any QAT fine-tuning:

```python
score = qat_model.evaluate(test_images, test_labels, verbose=0)
print(f"PTQ model test accuracy: {score[1]:02.4f}")
```

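## Running Quantization Aware Training
We fine-tune the QAT-ready model on the training set so that its weights adapt to the quantization, and then evaluate it on the test set: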

```python
qat_model.fit(train_images, train_labels, epochs=epochs, batch_size=batch_size, validation_split=0.2)

score = qat_model.evaluate(test_images, test_labels, verbose=0)
print(f"QAT model test accuracy: {score[1]:02.4f}")
```
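Trainable quantizers let gradient descent update the underlying float weights even though the forward pass uses rounded values. A common technique for this is the straight-through estimator (STE), which treats rounding as the identity in the backward pass. The toy example below fits a single weight under 2-bit fake quantization with an STE; it illustrates the general technique and is not MCT's implementation. The function names are hypothetical.

```python
def fake_quant(w, threshold=1.0, n_bits=2):
    """Signed symmetric fake quantization of a scalar.

    With the defaults the grid is {-1.0, -0.5, 0.0, 0.5}.
    """
    step = threshold / 2 ** (n_bits - 1)
    q = max(-2 ** (n_bits - 1), min(2 ** (n_bits - 1) - 1, round(w / step)))
    return q * step


def train_ste(w, xs, ys, lr=0.05, steps=20, threshold=1.0):
    """Fit y = q(w) * x with the straight-through estimator."""
    for _ in range(steps):
        q = fake_quant(w, threshold)
        # dLoss/dq for the mean squared error loss.
        grad_q = 2 * sum((q * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        # STE: pass the gradient through rounding unchanged,
        # but block it when w lies outside the clipping range.
        grad_w = grad_q if -threshold <= w <= threshold else 0.0
        w -= lr * grad_w
    return w


xs, ys = [1.0, 2.0, 3.0], [0.5, 1.0, 1.5]  # target: y = 0.5 * x
w = train_ste(-0.6, xs, ys)
print(round(w, 4), fake_quant(w))  # 0.3333 0.5
```

Without the STE, the derivative of rounding is zero almost everywhere and the latent weight would never move. With it, the weight reaches the bin that quantizes to 0.5 after three updates and then stops, because the loss becomes zero.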

## Finalizing the QAT Model
Finalizing replaces the trainable quantizers in the QAT model with inferable ones, leaving a model whose layers hold the quantized weights (fake-quant values) used at inference time.

```python
quantized_model = mct.qat.keras_quantization_aware_training_finalize_experimental(qat_model)

quantized_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
score = quantized_model.evaluate(test_images, test_labels, verbose=0)
print(f"Quantized model test accuracy: {score[1]:02.4f}")
```
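A quick sanity check after finalization is to confirm that each output channel of a quantized kernel takes at most `2^n` distinct values (four for the 2-bit kernels configured in this tutorial). The helpers below are a hypothetical plain-Python sketch that works on a kernel given as a list of channels. Extracting the kernel from the finalized model (for example via a layer's `get_weights()`, then moving the output-channel axis first) is left out, because the exact layer structure may depend on the MCT version.

```python
def max_levels_per_channel(channels):
    """Return the largest number of distinct values found in any channel."""
    return max(len(set(ch)) for ch in channels)


def check_quantized(channels, n_bits):
    """True if every channel fits within 2**n_bits distinct values."""
    return max_levels_per_channel(channels) <= 2 ** n_bits


float_kernel = [[0.8, -0.3, 0.1, -0.8, 0.33], [0.05, 0.02, -0.04, 0.01, 0.0]]
quant_kernel = [[0.4, -0.4, 0.0, -0.8, 0.4], [0.025, 0.025, -0.05, 0.0, 0.0]]
print(check_quantized(float_kernel, 2), check_quantized(quant_kernel, 2))  # False True
```

The exact-equality count relies on fake-quantized values being integer multiples of a per-channel step; if values pass through other float operations first, rounding them to a few decimals before counting is a safer choice.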

Finally, we export the quantized model as a Keras model file (`qmodel.keras`):

```python
mct.exporter.keras_export_model(model=quantized_model, save_model_path='qmodel.keras')
```

## Conclusion
In this tutorial, we explored how to perform Quantization Aware Training (QAT) using the Model Compression Toolkit (MCT) with a Keras model. We began by training a simple neural network on MNIST and configuring the Target Platform Capabilities (TPC) for 2-bit weights and 3-bit activations. We then converted the model into a QAT-ready model, fine-tuned it under these hardware-friendly quantization settings, and finalized and exported the result. Low-bit quantization of this kind can substantially reduce model size and speed up inference on hardware that supports it, while QAT helps recover the accuracy lost to quantization, which makes the approach well suited to edge AI applications.
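To put the size reduction in numbers for this tutorial's model, the back-of-the-envelope calculation below assumes that kernels are stored at 2 bits, that biases stay in 32-bit float (matching the TPC, which leaves bias quantization disabled), and it ignores the small overhead of per-channel thresholds. These are theoretical storage figures: the exported `.keras` file typically keeps the fake-quantized weights as floats, so realizing the savings depends on the deployment toolchain.

```python
# (kernel weight count, bias count) per layer, from the model definition.
layer_sizes = {
    "conv1": (3 * 3 * 1 * 16, 16),
    "conv2": (3 * 3 * 16 * 32, 32),
    "dense1": (7 * 7 * 32 * 128, 128),
    "dense2": (128 * 10, 10),
}

kernel_count = sum(k for k, _ in layer_sizes.values())  # 206736
bias_count = sum(b for _, b in layer_sizes.values())    # 186

float_bytes = (kernel_count + bias_count) * 32 / 8        # everything in float32
quant_bytes = kernel_count * 2 / 8 + bias_count * 32 / 8  # 2-bit kernels, float32 biases

print(float_bytes, quant_bytes, round(float_bytes / quant_bytes, 1))
# 827688.0 52428.0 15.8
```

Under these assumptions the weights shrink from about 808 KiB to about 51 KiB, roughly a 16x reduction, almost entirely thanks to the large first Dense layer.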

Feel free to experiment with different configurations to see how they impact your models.

Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.