# Quantization-Aware Training with BatchNorm Re-estimation

This notebook shows a working code example of how to use AIMET to perform quantization-aware training (QAT) with batchnorm re-estimation.
Batchnorm re-estimation is a technique for countering the potential instability of batchnorm statistics (i.e. the running mean and variance) during QAT. More specifically, batchnorm re-estimation recalculates the batchnorm statistics from the model after QAT has finished. As a result, the batchnorm statistics come from the stable outputs of the trained model rather than from the likely noisy outputs produced while the model was still changing during QAT.
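
To see why the running statistics can lag behind the model, here is a small NumPy sketch of the exponential moving average that Keras `BatchNormalization` uses for its moving statistics (`moving = momentum * moving + (1 - momentum) * batch_stat`, default `momentum=0.99`). The batch means are made up to mimic activations whose distribution shifts late in QAT; they are not measurements from this notebook, and `ema_running_mean` is a hypothetical helper.

```python
import numpy as np


def ema_running_mean(batch_means, momentum=0.99, init=0.0):
    """Replay the Keras-style moving-average update over a sequence of batch means."""
    running = init
    for m in batch_means:
        running = momentum * running + (1.0 - momentum) * m
    return running


# Hypothetical scenario: the activation mean is 0.0 for 500 steps, then shifts
# to 1.0 for the last 50 steps (e.g. because the weights keep changing during QAT).
batch_means = np.concatenate([np.zeros(500), np.ones(50)])

running = ema_running_mean(batch_means)
recomputed = batch_means[-50:].mean()  # statistics recomputed from the final model only

print(f"EMA running mean:  {running:.3f}")
print(f"Re-estimated mean: {recomputed:.3f}")
```

Even after 50 steps at the new distribution, the moving average has covered only about 40% of the shift (`1 - 0.99**50`), which is the kind of lag that recomputing the statistics on the final model avoids.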

#### Overall flow
This notebook covers the following steps:
1. Load the dataset
2. Create the model in Keras
3. Train and evaluate the model
4. Quantize the model with QuantSim
5. Finetune and evaluate the quantization simulation model
6. Re-estimate batchnorm statistics and compare the eval score before and after re-estimation
7. Fold the re-estimated batchnorm layers
8. Export the quantization simulation model


---
## 1. Load the dataset

This notebook relies on the MNIST dataset for classification, as provided by Keras.

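```python
import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixels to [0, 1] and add a channel axis to match the (28, 28, 1) model input
x_train = (x_train / 255.0).astype("float32")[..., tf.newaxis]
x_test = (x_test / 255.0).astype("float32")[..., tf.newaxis]
```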

## 2. Create the model in Keras

Currently, only Keras models built using the Sequential or Functional APIs are compatible with QuantSim; models that use subclassed layers are not. Therefore, we use the Functional API to create the model for this example.

```python
tf.keras.backend.clear_session()
inputs = tf.keras.Input(shape=(28, 28, 1), name="inputs")
conv = tf.keras.layers.Conv2D(16, (3, 3), name='conv1')(inputs)
bn = tf.keras.layers.BatchNormalization(fused=True)(conv)
relu = tf.keras.layers.ReLU()(bn)
pool = tf.keras.layers.MaxPooling2D()(relu)
conv2 = tf.keras.layers.Conv2D(8, (3, 3), name='conv2')(pool)
flatten = tf.keras.layers.Flatten()(conv2)
dense = tf.keras.layers.Dense(10)(flatten)
functional_model = tf.keras.Model(inputs=inputs, outputs=dense)
```

## 3. Train and evaluate the model

Before we can quantize the model and apply QAT, the FP32 model must be trained so that we can get a baseline accuracy.

```python
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

functional_model.compile(optimizer='adam',
                         loss=loss_fn,
                         metrics=['accuracy'])

functional_model.fit(x_train, y_train, epochs=5)

# Evaluate the FP32 baseline on the test data using `evaluate`
print("Evaluate FP32 model on test data")
results = functional_model.evaluate(x_test, y_test, batch_size=128)
print("test loss, test acc:", results)
```

## 4. Create a QuantizationSim Model

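Now we use AIMET to create a QuantizationSimModel. AIMET inserts fake quantization ops into the model graph and configures them.
A few of the parameters are explained here:
- **quant_scheme**: We set this to `QuantScheme.training_range_learning_with_tf_init`, which initializes the quantization ranges with the 'tf' (min/max) scheme and then learns them during QAT
    - Other supported options include 'tf_enhanced' or 'tf', or the QuantScheme enum values `QuantScheme.post_training_tf`, `QuantScheme.post_training_tf_enhanced` and `QuantScheme.training_range_learning_with_tf_enhanced_init`
- **config_file**: Path to a JSON quantsim configuration that controls which ops and parameters are quantized, symmetric encodings, per-channel quantization and supergroups. This example writes it to `/tmp/default_config_per_channel.json` before creating the sim
- **default_output_bw**: Setting this to 8 (the default, so it is not passed explicitly here) means that AIMET performs all activation quantizations in the model using 8-bit integer precision
- **default_param_bw**: Setting this to 8 (also the default) means that AIMET performs all parameter quantizations in the model using 8-bit integer precision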

Other parameters keep their default values in this example. See the AIMET API documentation for QuantizationSimModel for the full list of parameters.
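
To make the bitwidth settings concrete, here is a NumPy sketch of the quantize-dequantize ("fake quantization") round trip that a simulated quantizer applies, using an asymmetric encoding derived from a min/max range in the spirit of the 'tf' scheme. The helper names and the offset convention are illustrative assumptions; AIMET's own encoding computation may differ in detail.

```python
import numpy as np


def compute_encoding(x_min, x_max, bitwidth=8):
    """Illustrative asymmetric scale/offset for a float range (not AIMET's exact math)."""
    # Widen the range to include 0 so that zero is exactly representable.
    x_min, x_max = min(float(x_min), 0.0), max(float(x_max), 0.0)
    num_steps = 2 ** bitwidth - 1              # 255 steps for 8 bits
    scale = (x_max - x_min) / num_steps
    offset = int(np.round(x_min / scale))      # integer grid position of x_min
    return scale, offset, num_steps


def fake_quantize(x, scale, offset, num_steps):
    """Quantize to integers in [0, num_steps], then dequantize back to float."""
    q = np.clip(np.round(x / scale) - offset, 0, num_steps)
    return (q + offset) * scale


x = np.linspace(-1.0, 1.0, 1001)
for bw in (8, 4):
    scale, offset, n = compute_encoding(x.min(), x.max(), bw)
    err = np.abs(fake_quantize(x, scale, offset, n) - x).max()
    print(f"{bw}-bit: scale={scale:.5f}, max abs error={err:.5f}")
```

Within the calibrated range, the rounding error is at most half a step (`scale / 2`), so fewer bits means a coarser grid and a larger error.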

```python
import json
from aimet_common.defs import QuantScheme
from aimet_tensorflow.keras.quantsim import QuantizationSimModel

# Quantsim configuration: symmetric per-channel parameter quantization,
# unquantized biases, and Conv/BatchNorm/ReLU-style supergroups
default_config_per_channel = {
        "defaults":
            {
                "ops":
                    {
                        "is_output_quantized": "True"
                    },
                "params":
                    {
                        "is_quantized": "True",
                        "is_symmetric": "True"
                    },
                "strict_symmetric": "False",
                "unsigned_symmetric": "True",
                "per_channel_quantization": "True"
            },

        "params":
            {
                "bias":
                    {
                        "is_quantized": "False"
                    }
            },

        "op_type":
            {
                "Squeeze":
                    {
                        "is_output_quantized": "False"
                    },
                "Pad":
                    {
                        "is_output_quantized": "False"
                    },
                "Mean":
                    {
                        "is_output_quantized": "False"
                    }
            },

        "supergroups":
            [
                {
                    "op_list": ["Conv", "Relu"]
                },
                {
                    "op_list": ["Conv", "Clip"]
                },
                {
                    "op_list": ["Conv", "BatchNormalization", "Relu"]
                },
                {
                    "op_list": ["Add", "Relu"]
                },
                {
                    "op_list": ["Gemm", "Relu"]
                }
            ],

        "model_input":
            {
                "is_input_quantized": "True"
            },

        "model_output":
            {}
    }

# Write the configuration to disk so that it can be passed to QuantizationSimModel
with open("/tmp/default_config_per_channel.json", "w") as f:
    json.dump(default_config_per_channel, f)

# Create the sim model; 8-bit activation and parameter bitwidths are the defaults
qsim = QuantizationSimModel(functional_model,
                            quant_scheme=QuantScheme.training_range_learning_with_tf_init,
                            config_file="/tmp/default_config_per_channel.json")
```
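
The configuration above enables symmetric, per-channel parameter quantization. As a rough illustration (not AIMET's code), the sketch below computes a symmetric 8-bit scale per output channel for a Keras-layout convolution kernel `(kh, kw, in_channels, out_channels)` and compares it with a single per-tensor scale. The signed range `[-127, 127]` and the outlier channel are simplifying assumptions chosen for the example.

```python
import numpy as np


def symmetric_scales(kernel, bitwidth=8, per_channel=True):
    """Symmetric scale(s) for a Keras conv kernel laid out as (kh, kw, in, out)."""
    q_max = 2 ** (bitwidth - 1) - 1  # 127 for 8 bits (illustrative signed range)
    if per_channel:
        # One scale per output channel: reduce |w| over every axis except the last.
        abs_max = np.max(np.abs(kernel), axis=tuple(range(kernel.ndim - 1)))
    else:
        abs_max = np.max(np.abs(kernel))  # one scale for the whole tensor
    return abs_max / q_max


def quant_dequant_symmetric(kernel, scales, bitwidth=8):
    """Round to the signed integer grid and map back; `scales` broadcasts over the last axis."""
    q_max = 2 ** (bitwidth - 1) - 1
    q = np.clip(np.round(kernel / scales), -q_max, q_max)
    return q * scales


rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 3, 1, 16))  # same shape as conv1's kernel
kernel[..., 0] *= 50.0                   # hypothetical outlier channel with a much wider range

for per_channel in (False, True):
    scales = symmetric_scales(kernel, per_channel=per_channel)
    err = np.abs(quant_dequant_symmetric(kernel, scales) - kernel)
    print(f"per_channel={per_channel}: mean abs error on channels 1-15 = {err[..., 1:].mean():.5f}")
```

With a single per-tensor scale, the outlier channel stretches the grid for every channel; per-channel scales keep the well-behaved channels precise.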


**Compute Encodings**

Even though AIMET has added 'quantizer' nodes to the model graph, the model is not ready to be used yet. Before we can use the sim model for inference or training, we need to find appropriate scale/offset quantization parameters for each 'quantizer' node. For activation quantization nodes, we need to pass unlabeled data samples through the model to collect range statistics, from which AIMET calculates the scale/offset quantization parameters. This process is sometimes referred to as calibration; AIMET simply calls it 'computing encodings'.
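
Conceptually, range collection for the 'tf' scheme amounts to tracking the minimum and maximum value seen at each activation quantizer across the calibration batches and turning that range into a scale/offset. The sketch below uses a hypothetical `RangeObserver` class to show the idea; AIMET's quantizers do this internally, and schemes such as 'tf_enhanced' use a different range search that is not shown.

```python
import numpy as np


class RangeObserver:
    """Hypothetical min/max tracker for one activation quantizer."""

    def __init__(self):
        self.min_val = np.inf
        self.max_val = -np.inf

    def update(self, activations):
        # Widen the observed range with this batch's extremes.
        self.min_val = min(self.min_val, float(np.min(activations)))
        self.max_val = max(self.max_val, float(np.max(activations)))

    def encoding(self, bitwidth=8):
        # Include 0 in the range, then split it into 2**bitwidth - 1 steps.
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = (hi - lo) / (2 ** bitwidth - 1)
        offset = int(np.round(lo / scale))
        return scale, offset


rng = np.random.default_rng(0)
observer = RangeObserver()
for _ in range(10):  # ten hypothetical calibration batches
    batch = np.maximum(rng.normal(size=(32, 16)), 0.0)  # ReLU-like activations
    observer.update(batch)

print(observer.min_val, observer.max_val, observer.encoding())
```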

So we create a routine to pass unlabeled data samples through the model. This should be fairly simple: use the existing train or validation data to extract some samples and pass them to the model. We don't need to compute a loss or any other metric, so we can simply ignore the model output. A few pointers regarding the data samples:
- In practice, only a very small percentage of the overall data samples is needed for computing encodings.
- It may be beneficial if the samples used for computing encodings are well distributed. Covering every class is not necessary, since we are only looking at the range of values at each layer activation. However, we definitely want to avoid extreme cases, such as using only positive or only negative samples.
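
If you want the calibration samples to be spread across the data, one simple option is a small class-stratified subset. The hypothetical helper below picks up to `per_class` random indices for each label; as noted above, covering every class is not required, so this is just one easy way to avoid a skewed sample.

```python
import numpy as np


def balanced_calibration_indices(labels, per_class, seed=0):
    """Return shuffled indices with up to `per_class` samples from each label."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    picked = []
    for cls in np.unique(labels):
        cls_idx = np.flatnonzero(labels == cls)
        take = min(per_class, cls_idx.size)
        picked.append(rng.choice(cls_idx, size=take, replace=False))
    indices = np.concatenate(picked)
    rng.shuffle(indices)
    return indices


# Usage with the MNIST arrays loaded earlier, e.g. 10 images per digit:
# calib_images = x_train[balanced_calibration_indices(y_train, per_class=10)]
```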

The following cell shows one example of a routine that passes unlabeled samples through the model for computing encodings; it can be written in many different ways.
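
As an alternative to the one-line lambda used below, here is a sketch of a callback with the same `(model, args)` calling convention that feeds the samples in fixed-size batches and calls the model in inference mode. The function name and the `(data, num_samples, batch_size)` layout of `args` are assumptions made for this example.

```python
def pass_calibration_data(model, args):
    """Run `num_samples` samples through `model` in batches; outputs are discarded."""
    data, num_samples, batch_size = args
    num_samples = min(num_samples, len(data))
    for start in range(0, num_samples, batch_size):
        batch = data[start:min(start + batch_size, num_samples)]
        model(batch, training=False)  # only the activation ranges matter here
```

It could then be passed as `qsim.compute_encodings(pass_calibration_data, (x_test, 100, 32))`. Batching keeps memory bounded when you calibrate on more samples than fit comfortably in one forward pass.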

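```python
# Forward-pass callback: AIMET calls it as callback(model, args); here it runs the
# first 100 test images through the sim model and ignores the output
qsim.compute_encodings(lambda m, _: m(x_test[0:100]), None)
```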

Next, we evaluate the performance of the quantized model.

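```python
# Use the same loss and metric as the FP32 model so that the results are comparable
qsim.model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                   loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                   metrics=['accuracy'])
print("Evaluate quantized model on test data")
results = qsim.model.evaluate(x_test, y_test, batch_size=128)
print("test loss, test acc:", results)
```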

## 5. Perform QAT

To perform quantization aware training (QAT), we simply train the model for a few more epochs (typically 15-20). As with any training job, hyper-parameters need to be searched for optimal results. Good starting points are to use a learning rate on the same order as the ending learning rate when training the original model, and to drop the learning rate by a factor of 10 every 5 epochs or so.
For the purpose of this example notebook, we train for only 1 epoch on the first 1,024 training samples, but feel free to change these settings as you see fit.
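
The learning-rate recipe above (start near the final learning rate of the original training run and divide it by 10 roughly every 5 epochs) can be written as a step-decay schedule. `make_step_decay` is a hypothetical helper; the `1e-3` starting value matches the default Adam rate used to train the FP32 model here, not a tuned QAT setting.

```python
def make_step_decay(initial_lr, drop_factor=0.1, epochs_per_drop=5):
    """Return a schedule(epoch, lr) function implementing step decay."""

    def schedule(epoch, lr=None):
        # `lr` (the current rate passed by Keras) is accepted but not used:
        # the rate depends only on the epoch index.
        return initial_lr * drop_factor ** (epoch // epochs_per_drop)

    return schedule


schedule = make_step_decay(initial_lr=1e-3)
for epoch in (0, 4, 5, 10, 15):
    print(epoch, schedule(epoch))
```

For a longer QAT run, the schedule could be added to the `callbacks` list of `qsim.model.fit` via `tf.keras.callbacks.LearningRateScheduler(make_step_decay(1e-3))`.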

```python
quantized_callback = tf.keras.callbacks.TensorBoard(log_dir="./log/quantized")
history = qsim.model.fit(
    x_train[0:1024], y_train[0:1024], batch_size=32, epochs=1, validation_data=(x_test, y_test),
    callbacks=[quantized_callback]
)
```

Finally, let's evaluate the validation accuracy of our model after QAT.

```python
print("Evaluate quantized model (post QAT) on test data")
results = qsim.model.evaluate(x_test, y_test, batch_size=128)
print("test loss, test acc:", results)
```

## 6. Re-estimate BatchNorm Statistics

AIMET provides a helper function, `reestimate_bn_stats`, for re-estimating batchnorm statistics. It returns a handle; when it is used as a context manager, as in the cell below, the original statistics are restored when the `with` block exits.
Here is the full list of parameters for this function:
* **model**: Model whose BatchNorm statistics are re-estimated.
* **dataloader**: Training data (a batched `tf.data.Dataset` in this example).
* **num_batches** (optional): The number of batches to use for re-estimation. (Default: 100)
* **forward_fn** (optional): Optional adapter function that performs a forward pass given a model and an input batch yielded from the data loader. If not specified, inputs yielded from the data loader are expected to be passed directly to the model.
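
Conceptually, re-estimation keeps the trained model fixed, runs `num_batches` batches through it, and replaces each BatchNorm layer's moving mean and variance with statistics measured on those batches. The NumPy sketch below averages per-batch, per-channel statistics of a BN layer's input; it illustrates the idea only and is not AIMET's implementation, which works on the Keras layers directly and may aggregate the statistics differently. The input shape mirrors `conv1`'s output, but the data is synthetic.

```python
import numpy as np


def reestimate_channel_stats(batches, num_batches):
    """Average per-channel mean/variance over the first `num_batches` NHWC batches."""
    means, variances = [], []
    for i, batch in enumerate(batches):
        if i >= num_batches:
            break
        # Reduce over batch, height and width; keep the channel axis (Keras NHWC layout).
        means.append(batch.mean(axis=(0, 1, 2)))
        variances.append(batch.var(axis=(0, 1, 2)))
    return np.mean(means, axis=0), np.mean(variances, axis=0)


rng = np.random.default_rng(0)
# Hypothetical BN inputs: 16 channels with true mean 2.0 and variance 9.0 (26x26 like conv1's output).
batches = [rng.normal(loc=2.0, scale=3.0, size=(8, 26, 26, 16)) for _ in range(12)]
mean, var = reestimate_channel_stats(batches, num_batches=10)
print(mean.round(2), var.round(2))
```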

```python
import numpy as np
from aimet_tensorflow.keras.bn_reestimation import reestimate_bn_stats, _get_bn_submodules

# Unlabeled re-estimation data: the first 100 training images in batches of 4
batch_size = 4
dataset = tf.data.Dataset.from_tensor_slices(x_train[0:100])
dataset = dataset.batch(batch_size=batch_size)
it = iter(dataset)
dummy_inputs = next(it)

# Snapshot the original BN statistics, momentum and model output
bn_layers = _get_bn_submodules(qsim.model)
bn_mean_ori = {layer.name: layer.moving_mean.numpy() for layer in bn_layers}
bn_var_ori = {layer.name: layer.moving_variance.numpy() for layer in bn_layers}
bn_momentum_ori = {layer.name: layer.momentum for layer in bn_layers}
output_ori = qsim.model(dummy_inputs, training=False)

# Inside the `with` block, the re-estimated statistics are in effect
with reestimate_bn_stats(qsim.model, dataset, 1):
    # check re-estimated mean, var, momentum
    bn_mean_est = {layer.name: layer.moving_mean.numpy() for layer in bn_layers}
    bn_var_est = {layer.name: layer.moving_variance.numpy() for layer in bn_layers}
    bn_momentum_est = {layer.name: layer.momentum for layer in bn_layers}
    assert not all(np.allclose(bn_mean_ori[key], bn_mean_est[key]) for key in bn_mean_est)
    assert not all(np.allclose(bn_var_ori[key], bn_var_est[key]) for key in bn_var_est)
    assert not (bn_momentum_ori == bn_momentum_est)
    output_est = qsim.model(dummy_inputs, training=False)
    assert not np.allclose(output_est, output_ori)

# check that the original mean, var, momentum are restored after the `with` block
bn_mean_restored = {layer.name: layer.moving_mean.numpy() for layer in bn_layers}
bn_var_restored = {layer.name: layer.moving_variance.numpy() for layer in bn_layers}
bn_momentum_restored = {layer.name: layer.momentum for layer in bn_layers}

assert all(np.allclose(bn_mean_ori[key], bn_mean_restored[key]) for key in bn_mean_ori)
assert all(np.allclose(bn_var_ori[key], bn_var_restored[key]) for key in bn_var_ori)
assert bn_momentum_ori == bn_momentum_restored

# Compare eval scores, re-estimating without `with` so that the new
# statistics persist for folding and export
print("Evaluate quantized model (post QAT, before BN re-estimation) on test data")
results = qsim.model.evaluate(x_test, y_test, batch_size=128)
print("test loss, test acc:", results)

reestimate_bn_stats(qsim.model, dataset, 1)

print("Evaluate quantized model (post QAT, after BN re-estimation) on test data")
results = qsim.model.evaluate(x_test, y_test, batch_size=128)
print("test loss, test acc:", results)
```

## 7. Fold BatchNorm Layers

So far, we have improved our quantization simulation model through QAT and batchnorm re-estimation. The next step is to take this model to target, but first we fold the batchnorm layers so that the model runs more efficiently on target devices.
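
For intuition, the classic way to fold a BatchNorm layer into the convolution before it rescales the kernel per output channel and adjusts the bias: with `s = gamma / sqrt(var + eps)`, the folded kernel is `W * s` and the folded bias is `beta + (b - mean) * s`. The NumPy sketch below checks this identity on a 1x1 convolution written as a matrix multiply. AIMET's `fold_all_batch_norms_to_scale` operates on the quantization simulation model and, as its name suggests, folds into the quantizer scale, so its internals differ from this plain float folding.

```python
import numpy as np


def fold_bn_into_conv(kernel, bias, gamma, beta, mean, var, eps=1e-3):
    """Fold inference-mode BN into a kernel whose last axis is the output channel."""
    s = gamma / np.sqrt(var + eps)        # per-output-channel scale
    folded_kernel = kernel * s            # broadcasts over the last axis
    folded_bias = beta + (bias - mean) * s
    return folded_kernel, folded_bias


def batch_norm(x, gamma, beta, mean, var, eps=1e-3):
    """Inference-mode BatchNorm using moving statistics (eps=1e-3 is the Keras default)."""
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


rng = np.random.default_rng(0)
c_in, c_out = 4, 3
x = rng.normal(size=(10, c_in))           # 10 "pixels", c_in channels each
kernel = rng.normal(size=(c_in, c_out))   # a 1x1 convolution is a matrix multiply
bias = rng.normal(size=c_out)
gamma, beta = rng.normal(size=c_out), rng.normal(size=c_out)
mean, var = rng.normal(size=c_out), rng.uniform(0.5, 2.0, size=c_out)

reference = batch_norm(x @ kernel + bias, gamma, beta, mean, var)
folded_kernel, folded_bias = fold_bn_into_conv(kernel, bias, gamma, beta, mean, var)
print("folded output matches conv + BN:", np.allclose(x @ folded_kernel + folded_bias, reference))
```

Because the scale broadcasts over the last axis, the same function also applies to a full `(kh, kw, in, out)` Keras kernel. This is also why re-estimation should happen before folding: the folded weights bake in whatever mean and variance the BN layer holds at that moment.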

```python
from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms_to_scale
fold_all_batch_norms_to_scale(qsim)
```

---
## 8. Export Model
As the final step, we will export the model to run it on actual target devices. AIMET QuantizationSimModel provides an export API for this purpose.

```python
import os
os.makedirs('./output/', exist_ok=True)
qsim.export(path='./output/', filename_prefix='mnist_after_bn_re_estimation_qat_range_learning')
```
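
After exporting, it can be useful to sanity-check the encodings file. The sketch below assumes the export writes a JSON file named `<filename_prefix>.encodings` with top-level `activation_encodings` and `param_encodings` entries, which is the layout used by the AIMET versions we are aware of; verify the file name and structure against your installed version. `summarize_encodings` is a hypothetical helper.

```python
import json
import os


def summarize_encodings(path):
    """Count entries in an AIMET-style JSON encodings file (assumed layout)."""
    with open(path) as f:
        encodings = json.load(f)
    return {
        "activation_encodings": len(encodings.get("activation_encodings", {})),
        "param_encodings": len(encodings.get("param_encodings", {})),
    }


# Assumed file name: <path>/<filename_prefix>.encodings, matching the export call above
enc_path = os.path.join("./output", "mnist_after_bn_re_estimation_qat_range_learning.encodings")
if os.path.exists(enc_path):
    print(summarize_encodings(enc_path))
else:
    print("No encodings file found at", enc_path)
```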

## Summary

We hope this notebook was useful for understanding how to use the batchnorm re-estimation feature of AIMET.

A few additional resources:
- Refer to the [AIMET API docs](https://quic.github.io/aimet-pages/AimetDocs/api_docs/index.html) to know more details of the APIs and optional parameters.
- Refer to the [other example notebooks](https://github.com/quic/aimet/tree/develop/Examples/tensorflow/quantization/keras) to understand how to use AIMET post-training quantization techniques and QAT methods.