# Quantization-Aware Training with Range Learning

This notebook shows a working code example of how to use AIMET to perform Quantization-Aware Training (QAT). QAT is an AIMET feature that adds quantization simulation ops (sometimes called fake quantization ops) to a trained ML model and then uses a standard training pipeline to fine-tune the model for a few epochs. The resulting model should show improved accuracy on quantized ML accelerators.

AIMET supports two different types of QAT:
1. Simply referred to as QAT - quantization parameters like per-tensor scale/offsets for activations are computed once. During fine-tuning, the model weights are updated to minimize the effects of quantization in the forward pass, keeping the quantization parameters constant.
2. Referred to as QAT with range-learning - quantization parameters like per-tensor scale/offsets for activations are computed initially. Then both the quantization parameters and the model weights are jointly updated during fine-tuning to minimize the effects of quantization in the forward pass.

This notebook specifically shows a working code example for #2 above. You can find a separate notebook for #1 in the same folder.
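
To make the idea of a quantization simulation (fake quantization) op more concrete, here is a minimal sketch in plain Python. It is an illustration only, not AIMET's implementation: the function name `fake_quantize` and the exact scale/offset convention (an integer offset on an unsigned grid) are assumptions made for this example.

```python
import math


def fake_quantize(x, scale, offset, bitwidth=8):
    """Quantize x onto an integer grid, then map it back to floating point.

    Assumed convention for this sketch: the grid covers
    [scale * offset, scale * (offset + 2**bitwidth - 1)].
    """
    q_min, q_max = 0, 2 ** bitwidth - 1
    # Round to the nearest grid step and shift by the integer offset
    q = math.floor(x / scale + 0.5) - offset
    # Values outside the representable range are clipped (saturated)
    q = max(q_min, min(q_max, q))
    # De-quantize: the output is a float again, but restricted to grid values
    return (q + offset) * scale


inside = fake_quantize(0.537, scale=0.1, offset=-128)     # snaps to the nearest grid value
clipped_hi = fake_quantize(100.0, scale=0.1, offset=-128)  # saturates at the top of the range
clipped_lo = fake_quantize(-100.0, scale=0.1, offset=-128) # saturates at the bottom of the range
```

With plain QAT, `scale` and `offset` in a sketch like this would stay fixed during fine-tuning; with range learning they would be updated together with the weights.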

#### Overall flow
This notebook covers the following:
1. Instantiate the example evaluation and training pipeline
2. Load the FP32 model and evaluate the model to find the baseline FP32 accuracy
3. Create a quantization simulation model (with fake quantization ops inserted) and evaluate this simulation model to get a quantized accuracy score
4. Fine-tune the quantization simulation model and evaluate the simulation model to get a post-finetuned quantized accuracy score

#### What this notebook is not
* This notebook is not designed to show state-of-the-art QAT results. For example, it uses a relatively quantization-friendly model (ResNet50). Also, some optimization parameters, like the number of epochs to fine-tune, are deliberately chosen so that the notebook executes more quickly.

---
## Dataset

This notebook relies on the ImageNet dataset for the task of image classification. If you already have a version of the dataset readily available, please use that. Otherwise, please download the dataset from an appropriate location (e.g. https://image-net.org/challenges/LSVRC/2012/index.php#) and convert it into tfrecords.

**Note1**: The ImageNet tfrecords dataset typically has the following characteristics, and the dataloader provided in this example notebook relies on them:
- A folder containing tfrecords files starting with **'train\*'** for training files and **'valid\*'** for validation files. Each tfrecord file should have features: **'image/encoded'** for image data and **'image/class/label'** for its corresponding class.
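
As a quick sanity check of the naming convention described above, the following sketch splits a list of file names into training and validation files by prefix. The helper name `split_tfrecord_files` and the sample file names are hypothetical; the actual dataloader in this notebook may select files differently.

```python
from fnmatch import fnmatchcase


def split_tfrecord_files(filenames):
    """Group file names by the 'train*' and 'valid*' prefixes."""
    train = sorted(f for f in filenames if fnmatchcase(f, 'train*'))
    valid = sorted(f for f in filenames if fnmatchcase(f, 'valid*'))
    return train, valid


# Hypothetical file names, for illustration only
files = ['train-00000-of-01024', 'validation-00000-of-00128', 'README.txt']
train_files, valid_files = split_tfrecord_files(files)
```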

**Note2**: To speed up the execution of this notebook, you may use a reduced subset of the ImageNet dataset. For example, the entire ILSVRC2012 dataset has 1000 classes, roughly 1000 training samples per class and 50 validation samples per class. For the purpose of running this notebook, you could reduce the dataset to, say, 2 samples per class and then convert it into tfrecords. This exercise is left up to the reader and is not necessary.
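
One possible way to pick such a reduced subset, before converting it to tfrecords, is sketched below. It assumes you have a list of (path, label) pairs; the helper name `subsample_per_class` and the sample data are hypothetical, and the tfrecord conversion itself is not shown.

```python
from collections import defaultdict


def subsample_per_class(samples, per_class=2):
    """Keep only the first `per_class` samples seen for each label."""
    kept = []
    counts = defaultdict(int)
    for path, label in samples:
        if counts[label] < per_class:
            kept.append((path, label))
            counts[label] += 1
    return kept


# Hypothetical (path, label) pairs, for illustration only
all_samples = [('a0.jpg', 0), ('a1.jpg', 0), ('a2.jpg', 0), ('b0.jpg', 1)]
subset = subsample_per_class(all_samples, per_class=2)
```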

Edit the cell below and specify the directory where the downloaded ImageNet dataset is saved.

In [None]:
TFRECORDS_DIR = '/path/to/tfrecords/dir/'        # Please replace this with a real directory

We suppress TensorFlow's C++ log messages below the ERROR level (setting TF_CPP_MIN_LOG_LEVEL to '2' filters out INFO and WARNING messages) and disable eager execution. We also set the Python logging verbosity to ERROR, so TensorFlow will display only messages labeled ERROR or more critical.

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
tf.logging.set_verbosity(tf.logging.ERROR)

---
## 1. Example Evaluation and Training Pipeline

The following is an example training and validation loop for this image classification task.

- **Does AIMET have any limitations on how the training, validation pipeline is written?** Not really. We will see later that AIMET will modify the user's model to create a QuantizationSim model which is still a TensorFlow model. This QuantizationSim model can be used in place of the original model when doing inference or training.
- **Does AIMET put any limitation on the interface of the evaluate() or train() methods?** Not really. You should be able to use your existing evaluate and train routines as-is.


In [None]:
from typing import List

from Examples.common import image_net_config
from Examples.tensorflow.utils.image_net_evaluator import ImageNetDataLoader
from Examples.tensorflow.utils.image_net_evaluator import ImageNetEvaluator
from Examples.tensorflow.utils.image_net_trainer import ImageNetTrainer

class ImageNetDataPipeline:
    """
    Provides APIs for model evaluation and finetuning using ImageNet Dataset.
    """
    
    @staticmethod
    def get_val_dataloader():
        """
        Instantiates a validation dataloader for ImageNet dataset and returns it
        """
        data_loader = ImageNetDataLoader(TFRECORDS_DIR,
                                         image_size=image_net_config.dataset['image_size'],
                                         batch_size=image_net_config.evaluation['batch_size'],
                                         format_bgr=True)

        return data_loader
    
    @staticmethod
    def evaluate(sess: tf.Session) -> float:
        """
        Given a TF session, evaluates its Top-1 accuracy on the validation dataset
        :param sess: The TF session whose graph is to be evaluated.
        :return: The Top-1 accuracy on the validation dataset.
        """
        evaluator = ImageNetEvaluator(TFRECORDS_DIR, training_inputs=['keras_learning_phase:0'],
                                      data_inputs=['input_1:0'], validation_inputs=['labels:0'],
                                      image_size=image_net_config.dataset['image_size'],
                                      batch_size=image_net_config.evaluation['batch_size'],
                                      format_bgr=True)

        return evaluator.evaluate(sess)

    
    @staticmethod
    def finetune(sess: tf.Session, update_ops_name: List[str], epochs: int, learning_rate: float, decay_steps: int):
        """
        Given a TF session, finetunes it to improve its accuracy
        :param sess: The sess graph to fine-tune.
        :param update_ops_name: list of name of update ops (mostly BatchNorms' moving averages).
                                tf.GraphKeys.UPDATE_OPS collections is always used
                                in addition to this list
        :param epochs: The number of epochs used during the finetuning step.
        :param learning_rate: The learning rate used during the finetuning step.
        :param decay_steps: A number used to adjust(decay) the learning rate after every decay_steps epochs in training.
        """
        trainer = ImageNetTrainer(TFRECORDS_DIR, training_inputs=['keras_learning_phase:0'],
                                  data_inputs=['input_1:0'], validation_inputs=['labels:0'],
                                  image_size=image_net_config.dataset['image_size'],
                                  batch_size=image_net_config.train['batch_size'],
                                  num_epochs=epochs, format_bgr=True)

        trainer.train(sess, update_ops_name=update_ops_name, learning_rate=learning_rate, decay_steps=decay_steps)


---
## 2. Load the model and evaluate to get a baseline FP32 accuracy score

For this example notebook, we are going to load a pretrained ResNet50 model from Keras and convert it to a TensorFlow session. Similarly, you can load any pretrained TensorFlow model instead.


Calling clear_session() releases the global state: this helps avoid clutter from old models and layers, especially when memory is limited.


By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. Since batchnorm ops are folded, these need to be ignored during training.

In [None]:
from tensorflow.compat.v1.keras.applications.resnet import ResNet50

tf.keras.backend.clear_session()

model = ResNet50(weights='imagenet', input_shape=(224, 224, 3))
update_ops_name = [op.name for op in model.updates] # Used for finetuning

The following utility method in AIMET sets BN layers in the model to eval mode. This allows AIMET to more easily read the BN parameters from the graph. Eventually we will fold BN layers into adjacent conv layers.

In [None]:
from aimet_tensorflow.utils.graph import update_keras_bn_ops_trainable_flag

model = update_keras_bn_ops_trainable_flag(model, load_save_path="./", trainable=False)

AIMET features currently support TensorFlow sessions. **add_image_net_computational_nodes_in_graph** adds an output layer, softmax and loss functions to the ResNet50 model graph.

In [None]:
from Examples.tensorflow.utils.add_computational_nodes_in_graph import add_image_net_computational_nodes_in_graph

sess = tf.keras.backend.get_session()

# Creates the computation graph of ResNet within the tensorflow session.
add_image_net_computational_nodes_in_graph(sess, model.output.name, image_net_config.dataset['images_classes'])

Since all tensorflow input and output tensors have names, we identify the tensors needed by AIMET APIs here. 

In [None]:
starting_op_names = [model.input.name.split(":")[0]]
output_op_names = [model.output.name.split(":")[0]]

Next, we check whether TensorFlow can use a CUDA device. This example code will use CUDA if it is available in your current execution environment, and the CPU otherwise.

In [None]:
use_cuda = tf.test.is_gpu_available(cuda_only=True)

---
Let's determine the FP32 (floating point 32-bit) accuracy of this model using the evaluate() routine

In [None]:
accuracy = ImageNetDataPipeline.evaluate(sess=sess)
print(accuracy)

---
## 3. Create a quantization simulation model and determine quantized accuracy

## Fold Batch Normalization layers
Before we determine the simulated quantized accuracy using QuantizationSimModel, we will fold the BatchNormalization (BN) layers in the model. These layers get folded into adjacent Convolutional layers. The BN layers that cannot be folded are left as they are.

**Why do we need to do this?**
On quantized runtimes (like TFLite, Snapdragon Neural Processing SDK, etc.), it is common practice to fold the BN layers. Doing so results in an inferences/sec speedup since unnecessary computation is avoided. From a floating-point perspective, a BN-folded model is mathematically equivalent at inference time to a model with BN layers, and produces the same accuracy. However, folding the BN layers can increase the range of the tensor values for the weight parameters of the adjacent layers, and this can have a negative impact on the quantized accuracy of the model (especially when using INT8 or lower precision). So we want to simulate that on-target behavior by doing BN folding here.
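
The equivalence and the range increase described above can be checked on a toy case. The sketch below uses a single scalar weight as a stand-in for one conv channel; the helper names and all numeric values are made up for illustration, and this is not how AIMET implements folding.

```python
import math


def conv_then_bn(x, w, b, gamma, beta, mean, var, eps=1e-5):
    """Scalar 'conv' followed by batch norm in inference mode."""
    y = w * x + b
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold the BN statistics into the preceding layer's weight and bias."""
    k = gamma / math.sqrt(var + eps)
    return w * k, (b - mean) * k + beta


# Made-up parameters for a single channel
w, b = 0.5, 0.1
gamma, beta, mean, var = 2.0, -0.3, 0.2, 0.01

w_folded, b_folded = fold_bn(w, b, gamma, beta, mean, var)

x = 1.7
unfolded_out = conv_then_bn(x, w, b, gamma, beta, mean, var)
folded_out = w_folded * x + b_folded  # no BN op needed any more
```

Here the two outputs agree up to floating-point error, while the folded weight is much larger in magnitude than the original one because the BN variance is small. A wider weight range of this kind is what can hurt accuracy once the weights are quantized.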

The following code calls AIMET to fold the BN layers on the given model and returns a new session

In [None]:
from aimet_tensorflow.batch_norm_fold import fold_all_batch_norms

BN_folded_sess, _ = fold_all_batch_norms(sess,
                                        input_op_names=starting_op_names,
                                        output_op_names=output_op_names)

## Create Quantization Sim Model

Now we use AIMET to create a QuantizationSimModel. This basically means that AIMET will insert fake quantization ops in the model graph and will configure them.
A few of the parameters are explained here
- **quant_scheme**: We set this to "training_range_learning_with_tf_init"
    - This is the key setting that enables "range learning". With this choice of quant scheme, AIMET will use the TF quant scheme to initialize the quantization parameters like scale/offset. And then those parameters are set to be trainable so they can continue to be updated during fine-tuning.
    - Another choice for quant_scheme is "training_range_learning_with_tf_enhanced_init". This is similar to the above, but the initialization for scale/offset is done using the TF Enhanced scheme. Since in both schemes the quantization parameters are set to be trainable, there is not much benefit to using this choice instead of "training_range_learning_with_tf_init".
- **default_output_bw**: Setting this to 8 means that we are asking AIMET to perform all activation quantizations in the model using integer 8-bit precision
- **default_param_bw**: Setting this to 8 means that we are asking AIMET to perform all parameter quantizations in the model using integer 8-bit precision

There are other parameters that are set to default values in this example. Please check the AIMET API documentation of QuantizationSimModel to see reference documentation for all the parameters.
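
To give a feel for what "trainable" quantization parameters mean, here is a toy sketch of range learning in plain Python. It learns only a scale (offset fixed at 0) using a straight-through, LSQ-style gradient estimate. That technique is chosen here for illustration and is an assumption; it is not necessarily what AIMET does internally. All names and values are hypothetical.

```python
import math


def quantize(x, scale, n_levels):
    """Unsigned fake quantization on the grid {0, scale, ..., n_levels * scale}."""
    q = math.floor(x / scale + 0.5)
    q = max(0, min(n_levels, q))
    return q * scale


def grad_wrt_scale(x, scale, n_levels):
    """Straight-through estimate of d(quantize)/d(scale)."""
    v = x / scale
    if v <= 0:
        return 0.0
    if v >= n_levels:
        return float(n_levels)        # clipped: output is n_levels * scale
    return math.floor(v + 0.5) - v     # inside the range: rounding residual


def mse(data, scale, n_levels):
    return sum((quantize(x, scale, n_levels) - x) ** 2 for x in data) / len(data)


data = [i / 100 for i in range(101)]  # toy "activations" in [0, 1]
n_levels = 3                           # 2-bit grid, so the effect is easy to see
initial_scale = 4.0 / n_levels         # deliberately poor initial range [0, 4]
scale = initial_scale
initial_loss = mse(data, scale, n_levels)

learning_rate = 0.05
for _ in range(500):
    grad = sum(2 * (quantize(x, scale, n_levels) - x) * grad_wrt_scale(x, scale, n_levels)
               for x in data) / len(data)
    scale -= learning_rate * grad      # the range itself is updated by gradient descent

final_loss = mse(data, scale, n_levels)
```

In this toy setup the learned scale shrinks from its poor initial value toward the actual data range, and the quantization error drops accordingly. In real QAT with range learning, the task loss (not a reconstruction error) drives the update, and the weights are trained at the same time.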

In [None]:
from aimet_common.defs import QuantScheme
from aimet_tensorflow.quantsim import QuantizationSimModel

sim = QuantizationSimModel(session=BN_folded_sess,
                           starting_op_names=starting_op_names,
                           output_op_names=output_op_names,
                           quant_scheme=QuantScheme.training_range_learning_with_tf_init,
                           default_output_bw=8,
                           default_param_bw=8,
                           use_cuda=use_cuda)

---
## Compute Encodings
Even though AIMET has added 'quantizer' nodes to the model graph, the model is not ready to be used yet. Before we can use the sim model for inference or training, we need to find appropriate scale/offset quantization parameters for each 'quantizer' node. For activation quantization nodes, we need to pass unlabeled data samples through the model to collect range statistics, which then let AIMET calculate appropriate scale/offset quantization parameters. This process is sometimes referred to as calibration. AIMET simply refers to it as 'computing encodings'.

So we create a routine to pass unlabeled data samples through the model. This should be fairly simple - use the existing train or validation data loader to extract some samples and pass them to the model. We don't need to compute any loss metric etc. So we can just ignore the model output for this purpose. A few pointers regarding the data samples
- In practice, we need a very small percentage of the overall data samples for computing encodings. For example, the training dataset for ImageNet has 1M samples. For computing encodings we only need 500 or 1000 samples.
- It may be beneficial if the samples used for computing encodings are well distributed. Not all classes need to be covered, since we are only looking at the range of values at every layer activation. However, we definitely want to avoid an extreme scenario where only 'dark' or only 'light' samples are used, e.g. using only pictures captured at night might not give ideal results.

The following shows an example of a routine that passes unlabeled samples through the model for computing encodings. This routine can be written in many different ways; this is just an example.

In [None]:
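def pass_calibration_data(session: tf.compat.v1.Session, _):
    """
    Forward-pass callback for compute_encodings: runs roughly 500 validation samples through the graph.
    :param session: The TF session of the quantization simulation model.
    :param _: Unused callback argument.
    """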
    data_loader = ImageNetDataPipeline.get_val_dataloader()
    batch_size = data_loader.batch_size

    input_label_tensors = [session.graph.get_tensor_by_name('input_1:0'),
                           session.graph.get_tensor_by_name('labels:0')]
    
    train_tensors = [session.graph.get_tensor_by_name('keras_learning_phase:0')]
    train_tensors_dict = dict.fromkeys(train_tensors, False)
    
    eval_outputs = [session.graph.get_operation_by_name('top1-acc').outputs[0]]

    samples = 500

    batch_cntr = 0
    for input_label in data_loader:
        input_label_tensors_dict = dict(zip(input_label_tensors, input_label))

        feed_dict = {**input_label_tensors_dict, **train_tensors_dict}

        with session.graph.as_default():
            _ = session.run(eval_outputs, feed_dict=feed_dict)

        batch_cntr += 1
        if (batch_cntr * batch_size) > samples:
            break
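
Note that the loop above stops only once strictly more than `samples` inputs have been seen, so the number of batches consumed depends on the batch size. The small sketch below mirrors just that stopping condition; the function name `batches_consumed` and the batch sizes are hypothetical, and it assumes the data loader does not run out of data first.

```python
def batches_consumed(batch_size, samples=500):
    """Count the batches run before the `> samples` stopping condition triggers."""
    batch_cntr = 0
    while True:
        batch_cntr += 1  # one batch passed through the model
        if (batch_cntr * batch_size) > samples:
            break
    return batch_cntr


with_batch_32 = batches_consumed(32)    # 16 batches, i.e. 512 samples
with_batch_100 = batches_consumed(100)  # 6 batches, i.e. 600 samples (500 is not > 500)
```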

---
Now we call AIMET to use the above routine to pass data through the model and then subsequently compute the quantization encodings. Encodings here refer to scale/offset quantization parameters.

In [None]:
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=None)
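
As a rough illustration of what an encoding is, the sketch below derives a scale and an integer offset from an observed min/max range. This is a simplified min/max calculation written for this example; the helper name `compute_encoding` is hypothetical, and AIMET's actual quant schemes may differ in detail.

```python
def compute_encoding(observed_min, observed_max, bitwidth=8):
    """Derive (scale, offset) from an observed activation range.

    Assumes the observed range is not degenerate (not both values equal to zero).
    """
    # Extend the range so that zero is always representable
    lo = min(observed_min, 0.0)
    hi = max(observed_max, 0.0)
    num_steps = 2 ** bitwidth - 1
    scale = (hi - lo) / num_steps
    offset = round(lo / scale)  # integer grid position of the minimum (<= 0)
    return scale, offset


# Made-up activation values collected during "calibration"
activations = [-1.0, 0.3, 1.55, 0.9]
scale, offset = compute_encoding(min(activations), max(activations))
```

For this made-up range of [-1.0, 1.55] and 8 bits, the scale comes out at about 0.01 with an offset of -100.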

---
Now the QuantizationSim model is ready to be used for inference or training. First we can pass this model to the same evaluation routine we used before. The evaluation routine will now give us a simulated quantized accuracy score for INT8 quantization instead of the FP32 accuracy score we saw before.

In [None]:
accuracy = ImageNetDataPipeline.evaluate(sim.session)
print(accuracy)

---
## 4. Perform QAT

To perform quantization aware training (QAT), we simply train the model for a few more epochs (typically 15-20). As with any training job, hyper-parameters need to be searched for optimal results. Good starting points are to use a learning rate on the same order as the ending learning rate when training the original model, and to drop the learning rate by a factor of 10 every 5 epochs or so.

For the purpose of this example notebook, we are going to train only for 1 epoch. But feel free to change these parameters as you see fit.
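
The "drop the learning rate by a factor of 10 every 5 epochs" rule of thumb can be written down as a simple step schedule. The sketch below is a generic illustration with a hypothetical helper name; it is not the schedule code used by the trainer in this notebook.

```python
def step_decay_lr(initial_lr, epoch, drop_every=5, factor=10.0):
    """Learning rate for a given epoch under a step-decay schedule."""
    return initial_lr / (factor ** (epoch // drop_every))


# Learning rate per epoch for a hypothetical 15-epoch fine-tuning run
schedule = [step_decay_lr(1e-3, epoch) for epoch in range(15)]
```

With these settings, epochs 0-4 use 1e-3, epochs 5-9 use about 1e-4, and epochs 10-14 use about 1e-5.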

In [None]:
ImageNetDataPipeline.finetune(sim.session, update_ops_name=update_ops_name, epochs=1, learning_rate=1e-3, decay_steps=5)

---
After we are done with QAT, we can run quantization simulation inference against the validation dataset at the end to observe any improvements in accuracy.

In [None]:
accuracy = ImageNetDataPipeline.evaluate(sim.session)
print(accuracy)

---
Depending on your settings, you may have observed a slight gain in accuracy after one epoch of training. Of course, this was just an example. Please try this against the model of your choice and experiment with the hyper-parameters to get the best results.

So we should have an improved model after QAT. The next step is to take this model to target. For this purpose, we need to export the model with the updated weights and without the fake quant ops. AIMET QuantizationSimModel provides an export API for this purpose. In the cell below, the exported files are written to the given path and named using the given filename prefix.

In [None]:
os.makedirs('./output/', exist_ok=True)
sim.export(path='./output/', filename_prefix='resnet50_after_qat_range_learning')

---
## Summary

We hope this notebook was useful for understanding how to use AIMET to perform QAT with range-learning.

A few additional resources:
- Refer to the AIMET API docs to know more details of the APIs and optional parameters
- Refer to the other example notebooks to understand how to use AIMET post-training quantization techniques and the vanilla QAT method (without range-learning)