Note: the Python cells in this notebook are not (yet) meant to be run; they are included for illustration only. Last revision: 13th March.

## Quantization
by Pau Fabregat

### Table of contents
1. [Prepare the model](#1)
    1. [Models from torchvision](#11)
    2. [Custom model](#12)
2. [Obtaining the quantized checkpoint](#2)
    1. [Post-Training Static](#21)
    2. [Quantization-Aware Training (QAT) with PyTorch Lightning](#22)
    3. [Quantization-Aware Training (QAT) with PyTorch](#23)
3. [Loading the quantized checkpoint](#3)
4. [Usual errors](#4)
5. [Sources and useful links](#5)

<div id='1'/>

## 1. Prepare the model
Some changes have to be made to the model before it can be quantized.

<div id='11'/>

### 1.1. Models from torchvision
The `torchvision.models.quantization` module contains quantization-ready versions of some of the most widely used models, such as ResNet, MobileNet, or ShuffleNet.

In [None]:
from torchvision.models.quantization import shufflenetv2 as qshufflenet
model = qshufflenet.shufflenet_v2_x1_0(pretrained=True, quantize=False)

We can train this model in the usual manner. With `quantize=True`, we would instead be given the already quantized int8 model. We do not want this, because we want to keep training the model with our own data.

<div id='12'/>

### 1.2. Custom model <a name="1.2"></a>
If the model we are using is not included in `torchvision.models.quantization`, we need to make some manual modifications. The following block is taken from the PyTorch documentation:

> It is necessary to currently make some modifications to the model definition
prior to Eager mode quantization. This is because currently quantization works on a module
by module basis. Specifically, for all quantization techniques, the user needs to:  
> 1. Convert any operations that require output requantization (and thus have additional parameters) from functionals to module form (for example, using ``torch.nn.ReLU`` instead of ``torch.nn.functional.relu``).  
> 2. Specify which parts of the model need to be quantized either by assigning ``.qconfig`` attributes on submodules or by specifying ``qconfig_mapping``. For example, setting ``model.conv1.qconfig = None`` means that the ``model.conv`` layer will not be quantized, and setting ``model.linear1.qconfig = custom_qconfig`` means that the quantization settings for ``model.linear1`` will be using ``custom_qconfig`` instead of the global qconfig.    
>
>For static quantization techniques which quantize activations, the user needs to do the following in addition:  
> 1. Specify where activations are quantized and de-quantized. This is done using `torch.ao.quantization.QuantStub` and `torch.ao.quantization.DeQuantStub` modules.  
> 2. Use `torch.ao.nn.quantized.FloatFunctional` to wrap tensor operations that require special handling for quantization into modules. Examples are operations like ``add`` and ``cat`` which require special handling to determine output quantization parameters.  
> 3. Fuse modules: combine operations/modules into a single module to obtain higher accuracy and performance. This is done using the `torch.ao.quantization.fuse_modules` API, which takes in lists of modules to be fused. We currently support the following fusions: [Conv, Relu], [Conv, BatchNorm], [Conv, BatchNorm, Relu], [Linear, Relu]
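
As a rough illustration of the checklist above, the sketch below builds a toy model with module-form ReLU and quant/dequant stubs, and then fuses Conv, BatchNorm and ReLU. `TinyNet` and its layer names are hypothetical and exist only for this example; they are not part of the project code.

```python
import torch
from torch import nn
from torch.ao.quantization import DeQuantStub, QuantStub, fuse_modules


class TinyNet(nn.Module):
    """Hypothetical toy model, used only to illustrate the checklist."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float -> quantized (after convert)
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()         # module form instead of F.relu
        self.dequant = DeQuantStub()  # quantized -> float (after convert)

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        return self.dequant(x)


model = TinyNet().eval()  # folding Conv+BN for inference requires eval mode
fused = fuse_modules(model, [["conv", "bn", "relu"]])  # returns a fused copy

# After fusion, `conv` holds the fused block and the other two slots are placeholders
print(type(fused.conv).__name__, type(fused.bn).__name__, type(fused.relu).__name__)
```

Before `prepare` and `convert` are applied, the stubs pass tensors through unchanged, so the fused model can still be checked against the float model.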

For example, in the multitask model, it was convenient not to quantize the last linear layer of the road classifier:

In [None]:
self.model.model.road_head[-1].linear.qconfig = None
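
The effect of `qconfig = None` can be checked on a small stand-alone model. The following is a minimal sketch, assuming a hypothetical `TwoLayer` model with a quantized body and a float head; random tensors stand in for calibration data.

```python
import torch
from torch import nn
from torch.ao.quantization import (
    DeQuantStub, QuantStub, convert, get_default_qconfig, prepare,
)


class TwoLayer(nn.Module):
    """Hypothetical model: quantized body, float head."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.body = nn.Linear(16, 8)
        self.dequant = DeQuantStub()
        self.head = nn.Linear(8, 2)  # will stay in float

    def forward(self, x):
        x = self.dequant(self.body(self.quant(x)))
        return self.head(x)  # the head receives a float tensor


model = TwoLayer().eval()
model.qconfig = get_default_qconfig("fbgemm")
model.head.qconfig = None  # exclude the head from quantization

prepared = prepare(model)
with torch.no_grad():
    prepared(torch.randn(32, 16))  # stand-in for real calibration data
quantized = convert(prepared)

print(type(quantized.body))  # quantized Linear module
print(type(quantized.head))  # unchanged float Linear module
```

Note that the excluded layer must receive a float tensor, which is why the `DeQuantStub` is placed before the head in this sketch.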

The SiLU activation function used in YOLOv5 can't be quantized. Although it accepts a quantized tensor (likely because dequantization is done internally), explicitly dequantizing before and requantizing after this function is faster:

In [None]:
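# In __init__ (nn.SiLU must be instantiated before it is called):
self.quant = torch.ao.quantization.QuantStub()
self.dequant = torch.ao.quantization.DeQuantStub()
self.act = nn.SiLU()
# In forward: dequantize, apply SiLU in float, then requantize
x = self.quant(self.act(self.dequant(x)))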

The three 'simple' operations of addition, multiplication, and concatenation have to be adapted:

In [None]:
self.ff = torch.ao.nn.quantized.FloatFunctional()

# Adapt addition
# Old:
x + y
# New:
self.ff.add(x, y)

# Adapt multiplication
# Old:
x * y
# New:
self.ff.mul(x, y)

# Adapt concatenation
# Old (tensors are passed as a list; dim=1 is just an example):
torch.cat([x, y], dim=1)
# New:
self.ff.cat([x, y], dim=1)
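
A quick way to convince oneself that these replacements are safe is to compare them with the plain operators on float tensors. The sketch below assumes nothing beyond `FloatFunctional` itself; the tensor shapes are arbitrary.

```python
import torch
from torch.ao.nn.quantized import FloatFunctional

ff = FloatFunctional()
x = torch.randn(1, 4, 8, 8)
y = torch.randn(1, 4, 8, 8)

# Before prepare/convert, FloatFunctional simply forwards to the float ops
summed = ff.add(x, y)
product = ff.mul(x, y)
stacked = ff.cat([x, y], dim=1)  # list of tensors plus the dimension

print(torch.equal(summed, x + y), torch.equal(product, x * y), stacked.shape)
```

Because each `FloatFunctional` instance collects statistics for its own output during calibration, it is probably safer to create one instance per operation rather than sharing a single `self.ff` across different additions or concatenations.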

For instance, the Bottleneck module in YOLOv5 has to be converted from:

In [None]:
class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

To the following code:

In [None]:
class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2
        self.ff = torch.ao.nn.quantized.FloatFunctional()

    def forward(self, x):
        return self.ff.add(x, self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x))

<div id='2'/>

## 2. Obtaining the quantized checkpoint <a name="2"></a>
PyTorch supports three types of quantization:
1. Post-training dynamic quantization (PTQ-dynamic). This method does not support convolutional layers, so it is not suitable for this project.
2. Post-training static quantization (PTQ-static).
3. Quantization-aware training (QAT).

<div id='21'/>

### 2.1. Post-Training Static <a name="2.1"></a>
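A key part of post-training static quantization is calibration: sample data is run through the prepared model so that the observers can record the activation ranges. The general workflow is the following: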

In [None]:
model.eval()
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
model = torch.ao.quantization.prepare(model)
calibrate(model, dataloader)  # run sample data through the prepared model
model = torch.ao.quantization.convert(model)
torch.save(model.state_dict(), "quantized_checkpoint.ckpt")

We will use a very similar procedure when loading a quantized model. The previous block still needs a definition of the `calibrate` function (around 100 samples are reported to be enough), and we can optionally fuse some layers (for instance, convolutions and batch norms):

In [None]:
# Load the non-quantized checkpoint
# Use CPU, not GPU
ckpt = torch.load("checkpoint.ckpt", map_location="cpu")
model.load_state_dict(ckpt, strict=False)

# Set model to eval and define the qconfig
model.eval()
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")

# Fuse layers (optional)
for m in model.modules():
    if isinstance(m, Conv):  # Conv is the conv + batch norm block of the model
        torch.ao.quantization.fuse_modules(m, ['conv', 'bn'], inplace=True)

# Prepare for static quantization
model = torch.ao.quantization.prepare(model)

# Calibrate the model. Use at least 100 samples.
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
print("Calibrating model... please wait...")
calibrate(model, dataloader)  # dataloader must be defined beforehand

# Quantize and save the model
model = torch.ao.quantization.convert(model)
torch.save(model.state_dict(), "quantized_checkpoint.ckpt")
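
To make the calibration step more concrete, here is a simplified pure-Python sketch of what a min/max observer does: it tracks the observed value range and derives a scale and zero point for 8-bit affine quantization. The helper names are hypothetical, and PyTorch's own observers are more elaborate (histograms, reduced ranges, per-channel statistics), so this only conveys the idea.

```python
def minmax_qparams(values, qmin=0, qmax=255):
    """Derive (scale, zero_point) for unsigned 8-bit affine quantization."""
    lo = min(min(values), 0.0)  # the representable range must contain zero
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # avoid a zero scale for constant input
    zero_point = int(round(qmin - lo / scale))
    return scale, max(qmin, min(qmax, zero_point))


def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Map a float to its clamped integer representation."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Map an integer representation back to an approximate float."""
    return (q - zero_point) * scale


# "Calibration": values observed while running sample data through the model
observed = [-1.0, -0.25, 0.0, 0.5, 3.0]
scale, zero_point = minmax_qparams(observed)

q = quantize(0.5, scale, zero_point)
print(scale, zero_point, q, dequantize(q, scale, zero_point))
```

If the calibration samples do not cover the range of activations seen at inference time, values outside the observed range are clamped, which is one reason a representative set of samples matters.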


<div id='22'/>

### 2.2. Quantization-Aware Training (QAT) with PyTorch Lightning <a name="2.2"></a>
The PyTorch Lightning (PL) library has its own QAT callback. If we pass this callback to the PL trainer, QAT is performed automatically.

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import QuantizationAwareTraining
callbacks = {}

# We define the PL QAT callback
qat_callback = QuantizationAwareTraining(qconfig='fbgemm', observer_type='average')
callbacks["qat_callback"] = qat_callback

# Optionally add more callbacks (LearningRateMonitor, ModelCheckpoint, EarlyStopping, ...)

# We pass the callbacks to the trainer as a list
# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
# `args` is assumed to be the argparse.Namespace holding the Trainer options
trainer = pl.Trainer.from_argparse_args(args, callbacks=list(callbacks.values()))
trainer.fit(module, datamodule=data_module)

<div id='23'/>

### 2.3. Quantization-Aware Training (QAT) with PyTorch <a name="2.3"></a>
TO DO...
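
While this section is still to be written, the eager-mode flow can be sketched roughly as follows: attach a QAT qconfig, call `prepare_qat` in training mode, train as usual, and finally `convert` in eval mode. `TinyClassifier`, the random data and the hyperparameters are placeholders chosen for the example, not the project's actual training setup.

```python
import torch
from torch import nn
from torch.ao.quantization import (
    DeQuantStub, QuantStub, convert, get_default_qat_qconfig, prepare_qat,
)


class TinyClassifier(nn.Module):
    """Hypothetical model used only for this sketch."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.fc1 = nn.Linear(16, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 4)
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc2(self.relu(self.fc1(x)))
        return self.dequant(x)


model = TinyClassifier()
model.train()  # prepare_qat expects the model to be in training mode
model.qconfig = get_default_qat_qconfig("fbgemm")
model = prepare_qat(model)  # inserts fake-quantization modules

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
for _ in range(5):  # random tensors stand in for a real dataloader
    inputs = torch.randn(8, 16)
    targets = torch.randint(0, 4, (8,))
    optimizer.zero_grad()
    loss_fn(model(inputs), targets).backward()
    optimizer.step()

model.eval()
quantized = convert(model)  # replace fake-quantized modules with int8 ones
print(quantized)
```

The state dict of `quantized` can then be saved in the same way as in the post-training static workflow.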

<div id='3'/>

## 3. Loading the quantized checkpoint <a name="3"></a>
Once the quantized checkpoint has been saved, loading the model is straightforward.

In [None]:
# Rebuild the quantized model structure. No calibration is needed here,
# because the checkpoint provides the quantization parameters.
model.eval()
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
model = torch.ao.quantization.prepare(model)
model = torch.ao.quantization.convert(model)

# Load the quantized checkpoint
# Use CPU, not GPU
ckpt = torch.load("quantized_checkpoint.ckpt", map_location="cpu")
model.load_state_dict(ckpt, strict=False)
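
The round trip can be sketched on a toy model to show why the structure has to be converted before the state dict is loaded. In this sketch `TinyNet` and `quantized_skeleton` are hypothetical names, and an in-memory state dict stands in for the checkpoint file.

```python
import torch
from torch import nn
from torch.ao.quantization import (
    DeQuantStub, QuantStub, convert, get_default_qconfig, prepare,
)


class TinyNet(nn.Module):
    """Hypothetical model used only for this sketch."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.fc = nn.Linear(16, 4)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))


def quantized_skeleton(calibration_batch=None):
    """Build the converted (int8) structure, optionally calibrating first."""
    model = TinyNet().eval()
    model.qconfig = get_default_qconfig("fbgemm")
    model = prepare(model)
    if calibration_batch is not None:
        with torch.no_grad():
            model(calibration_batch)
    return convert(model)


# "Saving side": calibrated and converted model whose state we keep
source = quantized_skeleton(torch.randn(64, 16))
state = source.state_dict()  # in practice this would come from torch.load(...)

# "Loading side": same conversion steps without calibration, then load the state
restored = quantized_skeleton()
restored.load_state_dict(state)  # the default strict=True surfaces key mismatches

x = torch.randn(8, 16)
print(torch.equal(source(x), restored(x)))
```

The uncalibrated skeleton triggers an observer warning during `convert`; this is expected here, since its placeholder scales and zero points are overwritten by the loaded state. Loading with `strict=True` at least once is a cheap way to confirm that the rebuilt structure matches the checkpoint.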

<div id='4'/>

## 4. Usual errors <a name="4"></a>
Passing a non-quantized Tensor into a quantized kernel:

RuntimeError: Could not run 'quantized::some_operator' with arguments from the 'CPU' backend...

Solution: we have to quantize the tensor before using the operator:

In [None]:
self.quant = torch.ao.quantization.QuantStub()
x = self.quant(x)

Passing a quantized Tensor into a non-quantized kernel:

RuntimeError: Could not run 'aten::thnn_conv2d_forward' with arguments from the 'QuantizedCPU' backend.

Solution: we have to dequantize the tensor before using the operator:

In [None]:
self.dequant = torch.ao.quantization.DeQuantStub()
x = self.dequant(x)
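
When debugging either error, it helps to check on which side of the quantization boundary a tensor currently is. The sketch below quantizes and dequantizes a tensor by hand, which is roughly what the stubs do after conversion; the scale and zero point are arbitrary example values.

```python
import torch

x = torch.tensor([-1.0, 0.0, 0.5, 2.0])

# Manual counterpart of QuantStub / DeQuantStub after convert;
# scale and zero_point are arbitrary example values
xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=10, dtype=torch.quint8)
xf = xq.dequantize()

# is_quantized tells which kind of kernel the tensor can be passed to
print(xq.is_quantized, xq.dtype)
print(xf.is_quantized, xf.dtype)
```

Printing `x.is_quantized` just before the failing operator in `forward` usually shows whether a `QuantStub` or a `DeQuantStub` is missing.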

<div id='5'/>

## 5. Sources and useful links <a name="5"></a>
https://pytorch.org/blog/introduction-to-quantization-on-pytorch/  
https://pytorch.org/docs/stable/quantization.html  
https://pytorch.org/tutorials/recipes/quantization.html  
https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html  
https://pytorch-lightning.readthedocs.io/en/stable/advanced/pruning_quantization.html  