# Quantization Aware Training using the Model Compression Toolkit - example in Keras


## Overview
This tutorial shows how to use the Quantization Aware Training (QAT) API of the Model Compression Toolkit (MCT). We train a model on the MNIST dataset and then quantize it with the MCT QAT API.
[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/keras/example_keras_qat.ipynb)

## Setup
Install relevant packages

In [4]:
# TensorFlow version installed for this tutorial.
TF_VER = '2.14.0'

!pip install -q tensorflow=={TF_VER}
# NOTE(review): mct-nightly is not pinned to a release; the MCT modules imported in the
# next cells may change between nightly builds — consider pinning if this notebook breaks.
! pip install -q mct-nightly

In [5]:
import tensorflow as tf
from keras import Model, layers, datasets
import model_compression_toolkit as mct
import numpy as np


## Create TargetPlatformCapabilities
For this tutorial, we will use a TargetPlatformCapabilities (TPC) with quantization of 2 bits for weights and 3 bits for activations.

You can skip this part and use [get_target_platform_capabilities](https://sony.github.io/model_optimization/docs/api/api_docs/methods/get_target_platform_capabilities.html) to get an initialized TPC.

In [7]:
from model_compression_toolkit import DefaultDict
from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import AttributeQuantizationConfig, Signedness
from model_compression_toolkit.constants import FLOAT_BITWIDTH
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, KERAS_KERNEL, BIAS_ATTR, BIAS

# Short alias for MCT's target-platform module; get_tpc() builds its configs through it.
tp = mct.target_platform


def get_tpc(*, kernel_n_bits=2, activation_n_bits=3):
    """
    Build a TargetPlatformCapabilities (TPC) describing an assumed target hardware.

    The assumed hardware uses a power-of-2 threshold for activations and a symmetric,
    per-channel threshold for the kernel weights. By default, activations are quantized
    to 3 bits and kernel weights to 2 bits. Biases and any other weight attributes are
    left unquantized. Some layers (Flatten & Dropout) are not quantized at all.

    Args:
        kernel_n_bits: Bit-width used to quantize layer kernels (Conv2D / Dense).
            Defaults to 2.
        activation_n_bits: Bit-width used to quantize activations. Defaults to 3.

    Returns:
         TargetPlatformCapabilities object
    """

    # define a default quantization config for all non-specified weights attributes.
    # Quantization is disabled for them (enable_weights_quantization=False).
    default_weight_attr_config = AttributeQuantizationConfig(
        weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
        weights_n_bits=8,
        weights_per_channel_threshold=False,
        enable_weights_quantization=False,
        lut_values_bitwidth=None)

    # define a quantization config to quantize the kernel (for layers where there is a kernel attribute):
    # symmetric, per-channel thresholds with a configurable bit-width.
    kernel_base_config = AttributeQuantizationConfig(
        weights_quantization_method=tp.QuantizationMethod.SYMMETRIC,
        weights_n_bits=kernel_n_bits,
        weights_per_channel_threshold=True,
        enable_weights_quantization=True,
        lut_values_bitwidth=None)

    # define a config for the bias (for layers where there is a bias attribute).
    # The bias is kept in float: quantization is disabled and the bit-width is FLOAT_BITWIDTH.
    bias_config = AttributeQuantizationConfig(
        weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
        weights_n_bits=FLOAT_BITWIDTH,
        weights_per_channel_threshold=False,
        enable_weights_quantization=False,
        lut_values_bitwidth=None)

    # Create a default OpQuantizationConfig where we use default_weight_attr_config as the default
    # AttributeQuantizationConfig for weights with no specific AttributeQuantizationConfig.
    # MCT will compress a layer's kernel and bias according to the configurations that are
    # set in KERNEL_ATTR and BIAS_ATTR that are passed in attr_weights_configs_mapping.
    default_config = tp.OpQuantizationConfig(
        default_weight_attr_config=default_weight_attr_config,
        attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config,
                                      BIAS_ATTR: bias_config},
        activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
        activation_n_bits=activation_n_bits,
        supported_input_activation_n_bits=8,
        enable_activation_quantization=True,
        quantization_preserving=False,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=None,
        signedness=Signedness.AUTO)

    # Set default QuantizationConfigOptions in new TargetPlatformModel to be used when no other
    # QuantizationConfigOptions is set for an OperatorsSet.
    default_configuration_options = tp.QuantizationConfigOptions([default_config])
    tp_model = tp.TargetPlatformModel(default_configuration_options)
    with tp_model:
        default_qco = tp.get_default_quantization_config_options()
        # Group of OperatorsSets that should not be quantized (neither activations nor weights).
        tp.OperatorsSet("NoQuantization",
                        default_qco.clone_and_edit(enable_activation_quantization=False)
                        .clone_and_edit_weight_attribute(enable_weights_quantization=False))
        # Group of linear OperatorsSets such as convolution and matmul.
        tp.OperatorsSet("LinearOp")

    tpc = tp.TargetPlatformCapabilities(tp_model)
    with tpc:
        # No need to quantize Flatten and Dropout layers
        tp.OperationsSetToLayers("NoQuantization", [layers.Flatten, layers.Dropout])
        # Assign the framework layers' attributes to KERNEL_ATTR and BIAS_ATTR that were used during creation
        # of the default OpQuantizationConfig.
        tp.OperationsSetToLayers("LinearOp", [layers.Dense, layers.Conv2D],
                                 attr_mapping={KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
                                               BIAS_ATTR: DefaultDict(default_value=BIAS)})
    return tpc


## Init Keras model

In [None]:
num_classes = 10
input_shape = (28, 28, 1)

# Small CNN classifier: two strided convolutions, then a dropout-regularized dense head.
# The layers are listed in order and applied one after another to the input tensor.
_input = layers.Input(shape=input_shape)
x = _input
for _layer in (
        layers.Conv2D(16, 3, strides=2, padding='same', activation='relu'),
        layers.Conv2D(32, 3, strides=2, padding='same', activation='relu'),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax'),
):
    x = _layer(x)
model = Model(inputs=_input, outputs=x)

model.summary()

## Init MNIST dataset

In [None]:
# Load MNIST, already split into train and test sets
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()

# Scale pixel values into [0, 1] and append a trailing channel axis to each image array
x_train, x_test = (np.expand_dims(images.astype("float32") / 255, -1)
                   for images in (x_train, x_test))

# One-hot encode the integer class labels into num_classes-wide vectors
y_train, y_test = (tf.keras.utils.to_categorical(labels, num_classes)
                   for labels in (y_train, y_test))


## Train a Keras classifier model on MNIST

In [None]:
# Training hyper-parameters for the float model (reused later for QAT)
batch_size, epochs = 128, 15

model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(x_train, y_train,
          epochs=epochs,
          batch_size=batch_size,
          validation_split=0.2)

# Float baseline on the test split; score holds [loss, accuracy]
score = model.evaluate(x_test, y_test, verbose=0)
print(f"Float model test accuracy: {score[1]:02.4f}")


## Prepare model for Hardware-Friendly Quantization Aware Training with MCT
MCT takes the float model and quantizes it in a post-training quantization (PTQ) fashion. The returned model contains trainable quantizers and is ready to be retrained (i.e., a "QAT-ready" model).

In [None]:
# Number of training images MCT receives as its representative dataset
n_iter = 10


def gen_representative_dataset():
    """
    Create the representative-dataset callable passed to MCT.

    Returns:
        A zero-argument function. Calling it returns a generator that yields the
        first ``n_iter`` images of ``x_train`` one at a time. Each image gets a
        leading batch axis of size 1 and is wrapped in a single-element list.
    """
    def _generator():
        for sample_idx in range(n_iter):
            batch = np.expand_dims(x_train[sample_idx], axis=0)
            yield [batch]

    return _generator


# Build the "QAT ready" model. MCT quantizes the float model PTQ-style and returns it
# with trainable quantizers, plus custom objects for later (de)serialization.
representative_data_gen = gen_representative_dataset()
qat_core_config = mct.core.CoreConfig()
qat_tpc = get_tpc()
qat_model, _, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(
    model,
    representative_data_gen,
    core_config=qat_core_config,
    target_platform_capabilities=qat_tpc)

# Accuracy right after quantization, before any retraining
qat_model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"], run_eagerly=True)
score = qat_model.evaluate(x_test, y_test, verbose=0)
print(f"PTQ model test accuracy: {score[1]:02.4f}")

## Run Quantization Aware Training

In [None]:
# Retrain the QAT-ready model with the same schedule used for the float model
qat_model.fit(x_train, y_train,
              epochs=epochs,
              batch_size=batch_size,
              validation_split=0.2)

# Test accuracy after QAT; score holds [loss, accuracy]
score = qat_model.evaluate(x_test, y_test, verbose=0)
print(f"QAT model test accuracy: {score[1]:02.4f}")


In [None]:
# Finalize the QAT model: remove the QuantizeWrapper layers, keeping only layers
# whose weights hold quantized (FakeQuant) values.
quantized_model = mct.qat.keras_quantization_aware_training_finalize_experimental(qat_model)

# Compile so the finalized model can be evaluated (it is not trained further here)
quantized_model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
score = quantized_model.evaluate(x_test, y_test, verbose=0)
print(f"Quantized model test accuracy: {score[1]:02.4f}")

In [None]:
# Save the finalized quantized model as a Keras model file
export_path = 'qmodel.keras'
mct.exporter.keras_export_model(model=quantized_model, save_model_path=export_path)

Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.