# Quantization Aware Training Example Code

This notebook is an example of quantization simulation and fine-tuning with the AIMET library. The general procedure is to use AIMET's QuantizationSimModel to compute quantization encodings, then fine-tune the simulated quantized model.

We now present an overview of the technique. The weights of the pretrained model (in our case, ResNet18) are originally 32-bit floating point numbers. Quantization maps them onto an 8-bit grid through a rounding procedure; AIMET simulates this in floating point, so the quantized model can then be fine-tuned.
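As a rough illustration of that rounding step, the sketch below quantizes a handful of made-up weights with a min/max ("tf"-style) asymmetric 8-bit encoding and maps them back to floating point. The helper names `compute_encoding` and `quantize_dequantize` are hypothetical, and AIMET's own encoding computation may differ in detail; the point is only that every value lands within half a quantization step of where it started.

```python
# Minimal sketch of min/max ("tf"-style) asymmetric 8-bit quantization.
# Assumption: illustrative only; AIMET's own encoding computation may differ in detail.


def compute_encoding(x_min: float, x_max: float, bitwidth: int = 8):
    """Return (scale, offset) for a uniform grid of 2**bitwidth levels covering [x_min, x_max]."""
    # Widen the range so that 0.0 is always inside it and can be represented exactly.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    if x_max == x_min:  # all-zero tensor: any positive scale works
        return 1.0, 0
    scale = (x_max - x_min) / (2 ** bitwidth - 1)  # 255 steps for 8 bits
    offset = round(x_min / scale)                  # integer grid position of x_min (<= 0)
    return scale, offset


def quantize_dequantize(x: float, scale: float, offset: int, bitwidth: int = 8) -> float:
    """Round x to the nearest grid point, clamp to the integer range, and map back to float."""
    q = round(x / scale) - offset          # integer code (Python's round breaks ties to even)
    q = max(0, min(2 ** bitwidth - 1, q))  # clamp to [0, 2**bitwidth - 1]
    return (q + offset) * scale


weights = [-0.62, -0.1, 0.0, 0.33, 0.91]
scale, offset = compute_encoding(min(weights), max(weights))
restored = [quantize_dequantize(w, scale, offset) for w in weights]
print(scale, offset)
print(restored)  # each value lies within scale / 2 of the original
```

Keeping 0.0 inside the range matters in practice: zero padding and ReLU outputs are then reproduced exactly after quantization.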


This script utilizes AIMET to perform Quantization Aware Training on a ResNet18 pretrained model with the ImageNet data set. This is intended as a working example to show how AIMET APIs can be invoked.

Scenario parameters:
1. AIMET quantization aware training using simulation model
2. Quant Scheme: 'tf_enhanced'
3. rounding_mode: 'nearest'
4. default_output_bw: 8, default_param_bw: 8
5. Encoding computation using 5 batches of data (1 batch in the test setting)
6. Input shape: [1, 3, 224, 224]
7. Learning rate: 0.01
8. Learning rate schedule: [5, 10]

#### The example code shows the following:
1. Instantiate the data pipeline for evaluation
2. Load the pretrained ResNet18 PyTorch model
3. Calculate model accuracy
    * 3.1. Calculate floating point accuracy
    * 3.2. Calculate Quant Simulator accuracy
4. Apply AIMET CLE and BC
    * 4.1. Apply AIMET CLE and calculate QuantSim accuracy
    * 4.2. Apply AIMET BC and calculate QuantSim accuracy
5. Fine-tune Model

The first three cells below take care of all necessary imports:

In [None]:
import warnings
warnings.filterwarnings("ignore", ".*param.*")

# Imports necessary for the notebook
import os
import copy
import argparse
from datetime import datetime
from functools import partial
import torch
from torchvision.models import resnet18

In [None]:
# AIMET Imports for Quantization
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel, QuantParams
from aimet_torch.bias_correction import correct_bias

from aimet_torch.cross_layer_equalization import equalize_model
from aimet_torch.batch_norm_fold import fold_all_batch_norms


In [None]:
# Imports needed for the Data Pipeline
from Examples.common import image_net_config
from Examples.torch.utils.image_net_evaluator import ImageNetEvaluator
from Examples.torch.utils.image_net_trainer import ImageNetTrainer
from Examples.torch.utils.image_net_evaluator import ImageNetDataLoader

## Setting Up Our Config Dictionary

The config dictionary specifies the dataset location, device, logging directory, and fine-tuning settings used throughout the notebook.

config:
This mapping expects the following parameters:
1. **dataset_dir:** Path to a directory containing the ImageNet dataset. This folder should contain the subfolders 'train' for the training set and 'val' for the validation set.
2. **use_cuda:** A boolean indicating whether to run quantization and fine-tuning on the GPU.
3. **logdir:** Path to a directory for logs and saved models.
4. **epochs:** Number of fine-tuning epochs.
5. **learning_rate:** Initial learning rate for fine-tuning.
6. **learning_rate_schedule:** Epochs at which the fine-tuning learning rate is decayed.

To see when each parameter is used, read the code cells that reference `config`.  
**Note:** You will have to replace the dataset_dir path with the path to your own ImageNet/TinyImageNet dataset.
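A quick sanity check of the config can save a long run that fails halfway through. The helper below, `check_config`, is hypothetical (not part of AIMET or the Examples package): a minimal sketch, assuming the folder layout described above, that reports missing keys, missing 'train'/'val' subfolders, and a non-boolean `use_cuda`.

```python
import os
import tempfile

# Hypothetical helper, not part of AIMET or the Examples package.
REQUIRED_KEYS = ("dataset_dir", "use_cuda", "logdir", "epochs",
                 "learning_rate", "learning_rate_schedule")
REQUIRED_SPLITS = ("train", "val")


def check_config(config: dict) -> list:
    """Return a list of problems found in the notebook config (empty if it looks usable)."""
    problems = [f"missing key: {key!r}" for key in REQUIRED_KEYS if key not in config]
    dataset_dir = config.get("dataset_dir")
    if dataset_dir is not None:
        for split in REQUIRED_SPLITS:
            split_dir = os.path.join(dataset_dir, split)
            if not os.path.isdir(split_dir):
                problems.append(f"missing subfolder: {split_dir}")
    if "use_cuda" in config and not isinstance(config["use_cuda"], bool):
        problems.append("use_cuda must be True or False")
    return problems


# Demo on a throwaway directory that mimics the expected layout.
with tempfile.TemporaryDirectory() as root:
    for split in REQUIRED_SPLITS:
        os.makedirs(os.path.join(root, split))
    good = {"dataset_dir": root, "use_cuda": False, "logdir": root,
            "epochs": 1, "learning_rate": 1e-2, "learning_rate_schedule": [5, 10]}
    good_problems = check_config(good)
    bad_problems = check_config(dict(good, dataset_dir=os.path.join(root, "missing"),
                                     use_cuda="yes"))

print(good_problems)  # []
print(bad_problems)   # two missing subfolders plus the use_cuda type problem
```

In the notebook you could call `check_config(config)` right after the config cell and stop if the returned list is non-empty.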

In [None]:
config = {'dataset_dir': "path/to/dataset",
          'use_cuda': True,
          'logdir': os.path.join("benchmark_output", "QAT_"+datetime.now().strftime("%Y-%m-%d-%H-%M-%S")), 
          'epochs': 1, 
          'learning_rate': 1e-2, 
          'learning_rate_schedule': [5, 10]}

os.makedirs(config['logdir'], exist_ok=True)

## 1. Instantiate Data Pipeline

The ImageNetDataPipeline class takes care of loading data, evaluating, and fine-tuning a model using a dataset directory. For more detail on how it works, see the relevant files under Examples/torch/utils.

The data pipeline class is simply a template for the user to follow. The methods for this class can be replaced by the user to fit their needs.
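If you replace `evaluate`, keep its signature: later cells bind `use_cuda` with `functools.partial` and hand the result to `compute_encodings`, which (as the notebook's usage implies) calls it as `callback(model, num_batches)`. The sketch below is a hypothetical, framework-free stand-in (`ToyDataPipeline`, `toy_model`, and the toy batches are all made up) that shows this contract and checks that `iterations` limits how many batches are consumed.

```python
from functools import partial
from typing import Callable, List, Optional, Tuple

Batch = Tuple[List[float], List[int]]  # (inputs, labels) in this toy setting


class ToyDataPipeline:
    """Hypothetical stand-in with the same evaluate(model, iterations, use_cuda) signature."""

    def __init__(self, batches: List[Batch]):
        self._batches = batches
        self.batches_seen = 0  # counts how many batches evaluate() actually consumed

    def evaluate(self, model: Callable[[float], int], iterations: Optional[int] = None,
                 use_cuda: bool = False) -> float:
        # use_cuda is accepted only for signature compatibility; this toy runs on the CPU.
        limit = len(self._batches) if iterations is None else iterations
        correct = total = 0
        for inputs, labels in self._batches[:limit]:
            self.batches_seen += 1
            for x, y in zip(inputs, labels):
                correct += int(model(x) == y)
                total += 1
        return correct / total if total else 0.0


def toy_model(x: float) -> int:
    """Toy classifier: class 1 for positive inputs, class 0 otherwise."""
    return int(x > 0)


pipeline = ToyDataPipeline([([0.5, -1.0], [1, 0]), ([2.0, -0.3], [1, 1]), ([-4.0, 3.0], [0, 1])])

# Same binding as the notebook: fix use_cuda by keyword, leave (model, iterations) positional,
# then call it the way the notebook's compute_encodings usage implies: callback(model, num_batches).
callback = partial(pipeline.evaluate, use_cuda=False)
toy_accuracy = callback(toy_model, 2)
print(toy_accuracy, pipeline.batches_seen)  # 0.75 2
```

Because `iterations` is the second positional parameter, `forward_pass_callback_args=num_batches` ends up controlling how many batches the encoding pass uses.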

In [None]:
class ImageNetDataPipeline:
    """
    Provides APIs for model quantization using evaluation and finetuning.
    """

    def __init__(self, config):
        """
        :param config: Config dictionary (dataset_dir, use_cuda, logdir and fine-tuning settings).
        """
        self._config = config

    def data_loader(self):
        """
        :return: ImageNetDataloader
        """
        
        data_loader = ImageNetDataLoader(is_training=False, images_dir=self._config["dataset_dir"],
                                         image_size=image_net_config.dataset['image_size']).data_loader

        return data_loader
    
    def evaluate(self, model: torch.nn.Module, iterations: int = None, use_cuda: bool = False) -> float:
        """
        Evaluate the specified model using the specified number of samples from the validation set.
        :param model: The model to be evaluated.
        :param iterations: Number of batches to evaluate; None evaluates the whole validation set.
        :param use_cuda: If True then use a GPU for inference.
        :return: Top-1 accuracy over the evaluated batches.
        """

        # Your code goes here

        evaluator = ImageNetEvaluator(self._config['dataset_dir'], image_size=image_net_config.dataset['image_size'],
                                      batch_size=image_net_config.evaluation['batch_size'],
                                      num_workers=image_net_config.evaluation['num_workers'])

        return evaluator.evaluate(model, iterations, use_cuda)
    
    def finetune(self, model: torch.nn.Module):
        """
        Finetunes the model. The implementation provided here is just an example;
        provide your own implementation if needed.

        :param model: The model to finetune.
        """

        # Your code goes here instead of the example from below

        trainer = ImageNetTrainer(self._config['dataset_dir'], image_size=image_net_config.dataset['image_size'],
                                  batch_size=image_net_config.train['batch_size'],
                                  num_workers=image_net_config.train['num_workers'])

        trainer.train(model, max_epochs=self._config['epochs'], learning_rate=self._config['learning_rate'],
                      learning_rate_schedule=self._config['learning_rate_schedule'], use_cuda=self._config['use_cuda'])

        torch.save(model, os.path.join(self._config['logdir'], 'finetuned_model.pth'))

## 2. Load the Model, Initialize DataPipeline

The next cell initializes the data pipeline and loads the pretrained ResNet18 model, moving it to the GPU when `use_cuda` is set.

Before quantizing the model, we calculate its original floating point (FP32) accuracy on the provided dataset (Section 3.1).

In [None]:
data_pipeline = ImageNetDataPipeline(config)

model = resnet18(pretrained=True)
if config['use_cuda']:
    if torch.cuda.is_available():
        model.to(torch.device('cuda'))
    else:
        raise Exception("use_cuda is True but cuda is unavailable")
model.eval()

## 3. Quantization Simulator

The next cells are for the actual quantization step. The quantization parameters are specified in the following cell:

1. **quant_scheme**: The scheme used to quantize the model. Choose from QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced.

2. **rounding_mode**: The rounding mode used for quantization. There are two possible choices: 'nearest' or 'stochastic'.

3. **default_output_bw**: The bitwidth of the activation tensors. The value of this should be a power of 2, less than 32.

4. **default_param_bw**: The bitwidth of the parameter tensors. The value of this should be a power of 2, less than 32.

5. **num_batches**: The number of batches passed through the model when computing the quantization encodings. Only 5 batches are used here to speed up the process; the images in these batches should still be enough to compute representative encodings. The test setting below uses a single batch.
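To make `rounding_mode` and the bitwidth parameters concrete, the sketch below contrasts deterministic nearest rounding with stochastic rounding (round up with probability equal to the fractional part, which is unbiased on average) and counts the integer levels available at a given bitwidth. The function names are hypothetical, and AIMET's internal tie-breaking and random-number handling are not shown here and may differ.

```python
import math
import random


def num_levels(bitwidth: int) -> int:
    """Number of integer codes available at a given bitwidth (256 for 8 bits)."""
    return 2 ** bitwidth


def round_nearest(v: float) -> int:
    """Deterministic rounding to the closest integer (ties rounded up in this sketch)."""
    return math.floor(v + 0.5)


def round_stochastic(v: float, rng: random.Random) -> int:
    """Round up with probability equal to the fractional part, so the result is unbiased."""
    low = math.floor(v)
    return low + int(rng.random() < v - low)


rng = random.Random(0)  # fixed seed so the demo is repeatable
samples = [round_stochastic(2.3, rng) for _ in range(10_000)]
print(round_nearest(2.3), round_nearest(2.5))             # 2 3
print(sorted(set(samples)), sum(samples) / len(samples))  # [2, 3] and a mean close to 2.3
print(num_levels(8), num_levels(16))                      # 256 65536
```

Nearest rounding always maps 2.3 to 2, introducing a consistent bias; stochastic rounding trades that bias for noise.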

In [None]:
quant_scheme = QuantScheme.post_training_tf_enhanced
rounding_mode = 'nearest'
default_output_bw = 8
default_param_bw = 8

#Uncomment one of the following lines
# num_batches = 5 #Typical
num_batches = 1 #Test

### 3.1. Calculate floating point accuracy

In [None]:
accuracy = data_pipeline.evaluate(model, use_cuda=config['use_cuda'])
print("Original Model Accuracy: ", accuracy)

### 3.2. Calculate Quant Simulator accuracy

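We now set up the quantization simulator and compute the quantization encodings. `compute_encodings` calls the data pipeline's evaluate function as a forward-pass callback on `num_batches` batches, and the tensor statistics gathered during these passes are used to compute the encodings. The resulting simulated quantized (INT8) model is then evaluated on the dataset.

It is customary to fold batch norms before quantization; however, the Cross Layer Equalization API expects a model whose batch norms have not been folded. For this reason, we fold the batch norms on a copy of the model and keep the original for Section 4.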
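Folding works because an inference-mode batch norm is just a per-channel affine map, so it can be absorbed into the preceding layer's weights and bias: with factor = gamma / sqrt(var + eps), the folded weights are w * factor and the folded bias is (b - mean) * factor + beta. The pure-Python sketch below checks this for a single output channel with made-up numbers; it illustrates the standard arithmetic, not AIMET's `fold_all_batch_norms` code.

```python
import math
from typing import List, Tuple


def linear(weights: List[float], bias: float, x: List[float]) -> float:
    """One output channel of a linear/conv layer: w . x + b."""
    return sum(w * xi for w, xi in zip(weights, x)) + bias


def batch_norm(y: float, gamma: float, beta: float, mean: float, var: float,
               eps: float = 1e-5) -> float:
    """Inference-mode BatchNorm using the stored running mean and variance."""
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_bn(weights: List[float], bias: float, gamma: float, beta: float,
            mean: float, var: float, eps: float = 1e-5) -> Tuple[List[float], float]:
    """Absorb the BatchNorm into the preceding layer's weights and bias."""
    factor = gamma / math.sqrt(var + eps)
    return [w * factor for w in weights], (bias - mean) * factor + beta


w, b = [0.5, -1.2, 2.0], 0.1
bn = dict(gamma=1.5, beta=-0.2, mean=0.3, var=4.0)
x = [1.0, 2.0, -0.5]

unfolded = batch_norm(linear(w, b, x), **bn)
folded_w, folded_b = fold_bn(w, b, **bn)
folded = linear(folded_w, folded_b, x)
print(unfolded, folded)  # equal up to floating point rounding
```

After folding, the simulator only has to quantize one set of weights per layer instead of modelling a separate batch norm.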

In [None]:
dummy_input = torch.rand(1, 3, 224, 224)
if config['use_cuda']:
    dummy_input = dummy_input.to(torch.device('cuda'))


BN_folded_model = copy.deepcopy(model)
_ = fold_all_batch_norms(BN_folded_model, input_shapes=(1, 3, 224, 224))

quantizer = QuantizationSimModel(model=BN_folded_model,
                                 quant_scheme=quant_scheme,
                                 dummy_input=dummy_input,
                                 rounding_mode=rounding_mode,
                                 default_output_bw=default_output_bw,
                                 default_param_bw=default_param_bw)

quantizer.compute_encodings(forward_pass_callback=partial(data_pipeline.evaluate,
                                                          use_cuda=config['use_cuda']),
                            forward_pass_callback_args=num_batches)

# Calculate quantized (INT8) accuracy before applying CLE and BC
accuracy = data_pipeline.evaluate(quantizer.model, use_cuda=config['use_cuda'])
print("Quantized (INT8) Model Top-1 Accuracy: ", accuracy)

## 4. Apply AIMET CLE and BC


### 4.1. Cross Layer Equalization

The next cell performs cross-layer equalization on the model in place. The `equalize_model` function folds batch norms, applies cross-layer scaling, and then folds high biases.
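The cross-layer scaling step can be sketched for two consecutive fully connected layers with a ReLU between them. For each channel i, with r1 the weight range of output channel i in layer 1 and r2 the weight range of input channel i in layer 2, dividing layer 1's channel by s = sqrt(r1 / r2) and multiplying layer 2's column by s leaves both ranges equal to sqrt(r1 * r2). The network output is unchanged because relu(z / s) == relu(z) / s for s > 0. This is a minimal pure-Python illustration with made-up weights; `equalize_model` also folds batch norms, folds high biases, and handles convolutions, none of which is shown here.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def forward(w1: Matrix, b1: List[float], w2: Matrix, x: List[float]) -> List[float]:
    """Two-layer net: W2 @ relu(W1 @ x + b1)."""
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(w1, b1)]
    return [sum(w * h for w, h in zip(row, hidden)) for row in w2]


def equalize_pair(w1: Matrix, b1: List[float], w2: Matrix) -> Tuple[Matrix, List[float], Matrix]:
    """Cross-layer scaling: divide channel i of layer 1 by s_i, multiply input column i of layer 2 by s_i."""
    w1, b1, w2 = [row[:] for row in w1], b1[:], [row[:] for row in w2]  # work on copies
    for i in range(len(w1)):
        r1 = max(abs(w) for w in w1[i])      # weight range of output channel i in layer 1
        r2 = max(abs(row[i]) for row in w2)  # weight range of input channel i in layer 2
        if r1 == 0.0 or r2 == 0.0:
            continue                         # nothing to balance for a dead channel
        s = math.sqrt(r1 / r2)               # afterwards both ranges equal sqrt(r1 * r2)
        w1[i] = [w / s for w in w1[i]]
        b1[i] /= s                           # valid because relu(z / s) == relu(z) / s for s > 0
        for row in w2:
            row[i] *= s
    return w1, b1, w2


w1 = [[4.0, -2.0], [0.1, 0.05]]
b1 = [0.5, -0.01]
w2 = [[0.1, 3.0], [-0.2, 1.0]]
x = [0.7, -0.3]

eq_w1, eq_b1, eq_w2 = equalize_pair(w1, b1, w2)
print(forward(w1, b1, w2, x), forward(eq_w1, eq_b1, eq_w2, x))  # same outputs
print([max(abs(w) for w in row) for row in eq_w1])              # per-channel ranges, layer 1
print([max(abs(row[i]) for row in eq_w2) for i in range(2)])    # matching ranges, layer 2
```

Before scaling, channel 0 of layer 1 spans [-4, 4] while channel 1 spans only [-0.1, 0.1], so a single per-tensor 8-bit grid would waste most of its levels on channel 1; after scaling the ranges are much closer.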

In [None]:
# This API will equalize the model in-place
equalize_model(model, input_shapes=(1, 3, 224, 224))

Then, the model is quantized, and the accuracy is noted. This is done before the bias correction step in order to measure the individual impacts of each technique.

In [None]:
dummy_input = torch.rand(1, 3, 224, 224)
if config['use_cuda']:
    dummy_input = dummy_input.to(torch.device('cuda'))

cle_quantizer = QuantizationSimModel(model=model,
                                     quant_scheme=quant_scheme,
                                     dummy_input=dummy_input,
                                     rounding_mode=rounding_mode,
                                     default_output_bw=default_output_bw,
                                     default_param_bw=default_param_bw)

cle_quantizer.compute_encodings(forward_pass_callback=partial(data_pipeline.evaluate,
                                                              use_cuda=config['use_cuda']),
                                forward_pass_callback_args=num_batches)

accuracy = data_pipeline.evaluate(cle_quantizer.model, use_cuda=config['use_cuda'])
print("CLE applied Model Top-1 accuracy on Quant Simulator: ", accuracy)

### 4.2. Bias Correction

Perform bias correction and calculate the accuracy of the QuantSim model. The first cell sets two parameters related to this step:

1. **num_quant_samples**: The number of samples used to compute quantization encodings during bias correction
2. **num_bias_correct_samples**: The number of samples used during bias correction
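The idea behind empirical bias correction can be sketched on a single fully connected layer: quantizing the weights shifts the *mean* of the layer's output, so we run some samples through both the float and the quantized layer, measure the average output error per channel, and subtract it from the bias. The layer, the toy quantizer (snapping to multiples of 0.25), and the data are all made up; AIMET's `correct_bias` works on real layers and a data loader and may differ in detail.

```python
import random
from typing import List

Matrix = List[List[float]]


def layer(weights: Matrix, bias: List[float], x: List[float]) -> List[float]:
    """Fully connected layer: W @ x + b."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def quantize_weights(weights: Matrix, step: float) -> Matrix:
    """Toy weight quantizer: snap every weight to the nearest multiple of `step`."""
    return [[round(w / step) * step for w in row] for row in weights]


def mean_output_error(w_q: Matrix, b_q: List[float], w_fp: Matrix, b_fp: List[float],
                      samples: List[List[float]]) -> List[float]:
    """Average (quantized output - float output) per output channel."""
    errors = [0.0] * len(b_fp)
    for x in samples:
        for c, (q, f) in enumerate(zip(layer(w_q, b_q, x), layer(w_fp, b_fp, x))):
            errors[c] += (q - f) / len(samples)
    return errors


def empirical_bias_correction(w_fp: Matrix, w_q: Matrix, bias: List[float],
                              samples: List[List[float]]) -> List[float]:
    """Subtract the measured mean output error from the bias of the quantized layer."""
    shift = mean_output_error(w_q, bias, w_fp, bias, samples)
    return [b - e for b, e in zip(bias, shift)]


rng = random.Random(1)
# Non-negative inputs, like post-ReLU activations, so the weight error shifts the mean output.
samples = [[rng.uniform(0.0, 1.0) for _ in range(3)] for _ in range(200)]
w_fp = [[0.34, -0.77, 0.12], [0.58, 0.26, -0.41]]
bias = [0.05, -0.10]
w_q = quantize_weights(w_fp, step=0.25)

corrected_bias = empirical_bias_correction(w_fp, w_q, bias, samples)
error_before = mean_output_error(w_q, bias, w_fp, bias, samples)
error_after = mean_output_error(w_q, corrected_bias, w_fp, bias, samples)
print(error_before, error_after)  # the mean error drops to (numerically) zero
```

Correcting the bias removes only the systematic part of the error; the per-sample noise from quantization remains, which is one reason fine-tuning can still help afterwards.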

In [None]:
# Uncomment one of the following sets of parameters
# num_quant_samples = 16 #Typical
# num_bias_correct_samples = 16 #Typical

num_quant_samples = 1 #Test
num_bias_correct_samples = 1 #Test

Here the actual bias correction steps are performed:

In [None]:
data_loader = data_pipeline.data_loader()

bc_params = QuantParams(weight_bw=default_param_bw,
                        act_bw=default_output_bw,
                        round_mode=rounding_mode,
                        quant_scheme=quant_scheme)

correct_bias(model,
             bc_params,
             num_quant_samples=num_quant_samples,
             data_loader=data_loader,
             num_bias_correct_samples=num_bias_correct_samples)

Finally, the bias-corrected model is quantized and its accuracy is printed.

In [None]:
dummy_input = torch.rand(1, 3, 224, 224)
if config['use_cuda']:
    dummy_input = dummy_input.to(torch.device('cuda'))

quantsim = QuantizationSimModel(model=model,
                                quant_scheme=quant_scheme,
                                dummy_input=dummy_input,
                                rounding_mode=rounding_mode,
                                default_output_bw=default_output_bw,
                                default_param_bw=default_param_bw,
                                in_place=False)

quantsim.compute_encodings(forward_pass_callback=partial(data_pipeline.evaluate,
                                                         use_cuda=config['use_cuda']),
                           forward_pass_callback_args=num_batches)

accuracy = data_pipeline.evaluate(quantsim.model, use_cuda=config['use_cuda'])
print("Quantized (INT8) Model Top-1 Accuracy After Bias Correction: ", accuracy)

## 5. Fine-tune Model


After the model is quantized, the model is finetuned, then evaluated and saved.
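Fine-tuning through simulated quantization needs gradients to flow through a rounding step, whose true derivative is zero almost everywhere. A common workaround is the straight-through estimator (STE): the forward pass uses the quantized weight, while the backward pass treats rounding as the identity and updates a floating point "shadow" weight. The toy below (one weight, y = w * x, a quantization step of 0.1, made-up data) illustrates that idea only; it is not AIMET's training code.

```python
import random
from typing import List, Tuple

Data = List[Tuple[float, float]]


def fake_quant(w: float, step: float) -> float:
    """Simulated quantization in the forward pass: snap w to the nearest multiple of step."""
    return round(w / step) * step


def qat_loss(w: float, data: Data, step: float) -> float:
    """Mean squared error of the model y = fake_quant(w) * x."""
    return sum((fake_quant(w, step) * x - y) ** 2 for x, y in data) / len(data)


def train_qat(w: float, data: Data, step: float, lr: float, epochs: int) -> float:
    """SGD on a float shadow weight; the forward pass only ever sees the quantized value."""
    for _ in range(epochs):
        for x, y in data:
            w_q = fake_quant(w, step)          # forward pass with the quantized weight
            grad_wq = 2 * (w_q * x - y) * x    # d/d(w_q) of (w_q * x - y) ** 2
            w -= lr * grad_wq                  # straight-through: treat d(w_q)/dw as 1
    return w


rng = random.Random(0)
data = [(x, 0.7 * x) for x in (rng.uniform(-1.0, 1.0) for _ in range(100))]
step = 0.1
w0 = -0.5
w_final = train_qat(w0, data, step, lr=0.05, epochs=20)
print(qat_loss(w0, data, step), qat_loss(w_final, data, step))  # the loss drops sharply
print(fake_quant(w_final, step))  # the quantized weight settles on the grid point nearest 0.7
```

Keeping the shadow weight in floating point lets many small updates accumulate until they cross a rounding boundary, which would be impossible if the stored weight itself were rounded after every step.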

In [None]:
# The model is fine-tuned for config['epochs'] epochs (1 here); increasing the number of epochs can improve accuracy
print("Starting Model Finetuning")
data_pipeline.finetune(quantsim.model)

# Calculate and log the accuracy of quantized-finetuned model
accuracy = data_pipeline.evaluate(quantsim.model, use_cuda=config['use_cuda'])
print("After Quantization Aware Training, top-1 accuracy = %.2f" % accuracy)

print("Quantization Aware Training Complete")

input_shape = (1, 3, 224, 224)
dummy_input = torch.rand(input_shape)

# Save the quantized model
quantsim.export(path=config['logdir'], filename_prefix='QAT_resnet', dummy_input=dummy_input.cpu())
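After export, it can be useful to inspect the encodings that were written next to the model. The sketch below assumes the export produces a JSON file named `<filename_prefix>.encodings` with a top-level `"param_encodings"` mapping from parameter names to lists of dicts containing `"bitwidth"` and `"scale"`; both the file name and the layout are assumptions to verify against your AIMET version. The helper `summarize_param_encodings` is hypothetical, and the demo file contains made-up values, not real export output.

```python
import json
import os
import tempfile
from typing import Dict, Tuple


def summarize_param_encodings(encodings_path: str) -> Dict[str, Tuple[int, float]]:
    """Map each parameter name to (bitwidth, scale) taken from its first encoding entry.

    Assumption: the file is JSON shaped like
    {"activation_encodings": {...}, "param_encodings": {name: [{"bitwidth": ..., "scale": ...}, ...]}}.
    Check this against the file your AIMET version actually writes.
    """
    with open(encodings_path) as f:
        data = json.load(f)
    summary = {}
    for name, entries in data.get("param_encodings", {}).items():
        first = entries[0] if isinstance(entries, list) else entries
        summary[name] = (int(first["bitwidth"]), float(first["scale"]))
    return summary


# Demo with a small, made-up file (values are illustrative, not real export output).
sample = {
    "activation_encodings": {},
    "param_encodings": {
        "conv1.weight": [{"bitwidth": 8, "min": -0.5, "max": 0.5, "offset": -128, "scale": 0.00392}],
        "fc.weight": [{"bitwidth": 8, "min": -0.2, "max": 0.3, "offset": -102, "scale": 0.00196}],
    },
}
with tempfile.TemporaryDirectory() as logdir:
    path = os.path.join(logdir, "QAT_resnet.encodings")  # assumed name: <filename_prefix>.encodings
    with open(path, "w") as f:
        json.dump(sample, f)
    summary = summarize_param_encodings(path)

for name, (bitwidth, scale) in sorted(summary.items()):
    print(f"{name}: {bitwidth}-bit, scale={scale}")
```

With the real export you would point `summarize_param_encodings` at the file written into `config['logdir']`.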
