# Why Quantization Aware Training?

When you quantize a model after training (Post-Training Quantization, or PTQ), you replace float weights and activations with lower-precision versions (e.g. int8) once training is complete.
This can cause a significant accuracy drop, especially in sensitive models (such as NLP models or small models).

To mitigate this, Quantization-Aware Training (QAT) simulates quantization during training, so the model learns to be robust to quantization artifacts.

To do this, we insert fake quantization modules into the model's computational graph to simulate the effect of quantization during training.
The loss is then computed on weights and activations that already carry quantization error, so every weight update takes that error into account. This usually leads to a more robust model.

## 🔧 What Actually Happens in QAT?

1. You start with a pretrained model in float32.

2. You mark the quantized region with QuantStub and DeQuantStub, and the preparation step attaches fake quantization modules and layer-specific observers to each layer.

3. During training:
  - The model performs fake quantization on weights and activations.
  - This means it simulates int8 rounding and clipping but keeps everything in float32.
  - Keeping the values in float32 lets gradients flow through and keeps training stable.
  - The optimizer updates the weights to compensate for quantization errors.

4. After training:
  - The model is converted to a real int8 quantized version.
  - Now you have real quantized weights and can deploy to edge devices efficiently.

## 🔁 The Quantize-Dequantize Flow in QAT
The key thing QAT does is simulate this cycle:

```
Input (float32)
   ↓
QuantStub  ⟶  FakeQuantize (simulates int8)
   ↓
Model layers (convolution, linear, etc.)
   ↓
DeQuantStub ⟶ back to float32 for output
```

Every layer internally also uses fake quantization to simulate quantized weights and activations.

### In QAT, all calculations stay in float32 during training

The fake quantization layers only simulate the effects of int8 by:

1. Taking your float32 tensor

2. Applying the quantization formula `q = round(x / scale) + zero_point`

3. Clamping `q` to the int8 range (−128 to 127 for signed, 0 to 255 for unsigned)

4. Immediately converting it back to float32 with `scale * (q - zero_point)` for the rest of the computation

Here is the flow inside a fake quantization module:
- Quantize (fake):
  1. Take the float32 activation
  2. Map it to the int8 grid using scale and zero-point
  3. Clamp to the int8 range (e.g. −128 to 127)
  4. Keep the result stored as float32

- Dequantize (fake):
  1. Convert that "pretend int8" value back to float32 using `scale * (q - zero_point)`
  2. The next layer still sees float32 input

So at every layer we have a sequence of quantize and dequantize operations:
```
float32 → fake quantize → fake dequantize → float32 ops
```

This way:
- You get the rounding and clipping errors of int8 quantization.
- You still keep the full float32 computation graph, so backpropagation and the optimizer work normally.

Key point: the real int8 math only happens after training, for inference, when you run:

```python
model_int8 = torch.quantization.convert(model)
```
That is when PyTorch swaps in actual quantized kernels and your weights and activations become int8 for real inference:
- The fake quantize + fake dequantize modules are replaced by real quantize ops.
- The layers in between actually store and compute in int8.
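
The quantize/dequantize round trip described above can be sketched in plain Python. The helper name `fake_quantize` and the example `scale` and `zero_point` values are assumptions chosen for illustration, not values from this notebook. The sketch only shows how rounding and clamping change a value while the result stays a float.

```python
def fake_quantize(x: float, scale: float, zero_point: int,
                  qmin: int = -128, qmax: int = 127) -> float:
    """Quantize x to the int8 grid, then immediately dequantize it."""
    q = round(x / scale) + zero_point      # map to the integer grid
    q = max(qmin, min(qmax, q))            # clamp to the int8 range
    return scale * (q - zero_point)        # back to a float value


# Hypothetical quantization parameters, chosen for illustration only.
scale, zero_point = 0.1, 0

in_range = fake_quantize(0.26, scale, zero_point)   # rounding error only (about 0.3)
clipped = fake_quantize(100.0, scale, zero_point)   # clipping error (about 12.7)

print(in_range, clipped)
```

The first value lands on the nearest grid point, while the second is pulled down to the largest representable value. These are the two error sources the model sees during QAT.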


## ✅ Benefits of QAT
- Accuracy is typically much closer to the original float model than with PTQ.
- Great for models that are sensitive to quantization errors.
- Enables deployment to int8 hardware (e.g. mobile CPUs, edge devices).


## Quantization Aware Training (QAT): gradient
During backpropagation, the model needs the gradient of the loss function with respect to every weight and input. A problem arises:

What is the derivative of the quantization operation we defined before? Rounding is piecewise constant, so its true derivative is zero almost everywhere, which would stop the weights from learning.
A typical solution is to approximate the gradient with the Straight-Through Estimator (STE).
The STE treats the derivative as 1 if the value being quantized is in the range [α, β], and as 0 otherwise
(Wu et al., "Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation").
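
A minimal sketch of the STE rule in plain Python, so it runs without any dependencies. The function names `quantize_forward` and `ste_backward`, and the example `scale` and `zero_point`, are assumptions made for illustration. Frameworks apply the same idea inside their automatic differentiation rather than through explicit functions like these.

```python
def quantize_forward(x, scale, zero_point, qmin=-128, qmax=127):
    """Forward pass: quantize then dequantize (the output stays a float)."""
    q = round(x / scale) + zero_point
    q = max(qmin, min(qmax, q))
    return scale * (q - zero_point)


def ste_backward(x, upstream_grad, scale, zero_point, qmin=-128, qmax=127):
    """Backward pass: pass the gradient through unchanged inside [alpha, beta]."""
    alpha = scale * (qmin - zero_point)   # smallest representable float
    beta = scale * (qmax - zero_point)    # largest representable float
    return upstream_grad if alpha <= x <= beta else 0.0


scale, zero_point = 0.1, 0   # hypothetical values; the range is about [-12.8, 12.7]

grad_inside = ste_backward(0.26, 1.5, scale, zero_point)    # passed through
grad_outside = ste_backward(100.0, 1.5, scale, zero_point)  # blocked: value was clipped
print(grad_inside, grad_outside)   # prints "1.5 0.0"
```

Inside the range, the rounding step is treated as the identity for gradient purposes. Outside it, the clamp saturates, so no gradient is passed back.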


# Import the necessary libraries

```python
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os
```

# Load the MNIST dataset

```python
# Make torch deterministic
_ = torch.manual_seed(0)
```

```python
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST training set
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training set
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = "cpu"
```

# Define the model

```python
class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

        # Step 2: add QuantStub and DeQuantStub.
        # They mark where tensors enter and leave the quantized region:
        #   QuantStub:   float32 -> quantized, at the model input
        #   DeQuantStub: quantized -> float32, before the output
        # During QAT, prepare_qat() attaches the modules configured in qconfig to the
        # stubs and layers, so the forward pass can mimic the rounding and clamping of
        # int8 math on floating-point tensors and the weights learn to tolerate the noise.
        # At the end of training, convert() replaces them with real quantized int8 operators.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28*28)

        # Quantize input
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)

        # Dequantize output
        x = self.dequant(x)
        return x

net = VerySimpleNet().to(device)
```

# Insert min-max observers in the model

```python
# Note: default_qconfig attaches plain MinMaxObserver modules (see the output below).
# They record value ranges but do not round or clamp values in the forward pass.
# To get FakeQuantize modules, a QAT qconfig such as
# torch.ao.quantization.get_default_qat_qconfig('fbgemm') could be used instead.
# All outputs shown in this notebook were produced with default_qconfig.
net.qconfig = torch.ao.quantization.default_qconfig
net.train()

# Prepare the model for QAT.
# This swaps the layers for their QAT counterparts and attaches the modules
# configured in qconfig to the weights and activations.
net_quantized = torch.ao.quantization.prepare_qat(net)
net_quantized
```

```
VerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)
```

# Train the model

```python
# Train (or fine-tune) the model with QAT.
# This is where the model learns to be robust to quantization effects.
# Usually you train for only a few epochs, starting from pretrained weights.

def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train `net` with Adam and cross-entropy, optionally stopping after a fixed number of iterations."""
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            # Show the running average loss in the progress bar
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

def print_size_of_model(model):
    """Save the state dict to a temporary file and print its size in KB."""
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

train(train_loader, net_quantized, epochs=1)
```

```
Epoch 1: 100%|██████████| 6000/6000 [00:22<00:00, 271.31it/s, loss=0.224]
```


# Define the testing loop

```python
def test(model: nn.Module, total_iterations: int = None):
    """Print the classification accuracy of `model` on the MNIST test set."""
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            # Count the samples whose highest-scoring class matches the label
            correct += (output.argmax(dim=1) == y).sum().item()
            total += y.size(0)
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')
```

# Check the collected statistics during training

```python
print('Check statistics of the various layers')
net_quantized
```

```
Check statistics of the various layers


VerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.5494080781936646, max_val=0.3067437410354614)
    (activation_post_process): MinMaxObserver(min_val=-42.466209411621094, max_val=40.64482879638672)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.4645603895187378, max_val=0.33165353536605835)
    (activation_post_process): MinMaxObserver(min_val=-40.01139450073242, max_val=22.106660842895508)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.46996980905532837, max_val=0.21300984919071198)
    (activation_post_process): MinMaxObserver(min_val=-30.044422149658203, max_val=23.014163970947266)
  )
  (relu): ReLU()
  (dequant): DeQua
```

# Quantize the model using the statistics collected during training

### One of the big differences between QAT and PTQ

- In PTQ:
  - The model has never seen quantization effects during training
  - You must run inference on a representative dataset to collect activation statistics (min/max or histogram)
  - Those stats are used to calculate the scale and zero-point for each quantization point
  - Without that calibration step, you can't convert the model, because it wouldn't know how to map float32 → int8

- In QAT:
  - The model is trained with fake quantize/dequantize modules already in place
  - Those modules collect stats during training itself
  - By the end of training, they already know the right scale and zero-point
  - So you can directly run `torch.quantization.convert(model)` and the model is ready for int8 inference, with no separate calibration pass

- So in short:
  - PTQ → needs post-training calibration
  - QAT → calibration is built into training

```python
net_quantized.eval()

# Convert the model to a fully quantized version.
# This replaces fake quant/dequant with real int8 arithmetic operations.
net_quantized = torch.ao.quantization.convert(net_quantized)
```

```python
print('Check statistics of the various layers')
net_quantized
```

```
Check statistics of the various layers


VerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.6544176340103149, zero_point=65, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.4891185462474823, zero_point=82, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.4177841544151306, zero_point=72, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)
```
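
The scale and zero-point printed above can be checked by hand against the min/max statistics collected during training. The sketch below is an illustration and does not use PyTorch's own code. The helper name `minmax_qparams` is hypothetical, and the integer range 0 to 127 is an assumption (a reduced unsigned range) chosen because it reproduces the printed values.

```python
def minmax_qparams(min_val, max_val, qmin=0, qmax=127):
    """Affine scale and zero-point from an observed [min_val, max_val] range."""
    # Widen the range to include 0.0 so that zero is exactly representable.
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = qmin - round(min_val / scale)
    zero_point = max(qmin, min(qmax, zero_point))   # keep it inside the integer range
    return scale, zero_point


# Observed ranges copied from the printed observer statistics.
quant_scale, quant_zp = minmax_qparams(-0.4242129623889923, 2.821486711502075)
lin1_scale, lin1_zp = minmax_qparams(-42.466209411621094, 40.64482879638672)

print(round(quant_scale, 4), quant_zp)   # prints "0.0256 17"
print(round(lin1_scale, 4), lin1_zp)     # prints "0.6544 65"
```

Both results agree with the `quant` and `linear1` entries of the converted model, which shows where the statistics gathered during training end up after conversion.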

# Print weights and size of the model after quantization

```python
# Print the weight matrix of the model after quantization (as raw int8 values)
print('Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))
```

```
Weights after quantization
tensor([[ 4,  9, -3,  ...,  9,  5,  5],
        [-7, -6, -5,  ..., -7, -4, -9],
        [ 0,  8, -3,  ...,  0,  5,  6],
        ...,
        [ 8,  9,  1,  ...,  0,  4, -4],
        [-2,  0,  8,  ...,  3,  3,  3],
        [ 6,  4,  1,  ...,  9, -2,  3]], dtype=torch.int8)
```
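
For the size side of the comparison, a back-of-envelope estimate can be made from the layer shapes alone. This is a rough sketch, not a measurement. It assumes the weights are stored as one byte each after conversion, that the biases stay in float32, and that serialization overhead and the stored scales and zero-points are negligible.

```python
# Layer shapes from VerySimpleNet: (in_features, out_features)
layers = [(784, 100), (100, 100), (100, 10)]

n_weights = sum(i * o for i, o in layers)   # 89,400 weights
n_biases = sum(o for _, o in layers)        # 210 biases

float32_bytes = 4 * (n_weights + n_biases)
# Assumption: int8 weights (1 byte each), float32 biases (4 bytes each).
int8_bytes = 1 * n_weights + 4 * n_biases

print(f"float32: {float32_bytes / 1e3:.1f} KB")
print(f"int8:    {int8_bytes / 1e3:.1f} KB")
print(f"ratio:   {float32_bytes / int8_bytes:.2f}x")
```

Under these assumptions the quantized model is close to four times smaller. The `print_size_of_model` helper defined earlier would give the actual on-disk numbers.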


```python
# Run inference
# The model now uses real quantized weights and activations
print('Testing the model after quantization')
test(net_quantized)
```

```
Testing the model after quantization
Testing: 100%|██████████| 1000/1000 [00:01<00:00, 716.87it/s]
Accuracy: 0.952
```





## Recap
- QuantStub/DeQuantStub mark where tensors enter and leave the quantized region, and the modules attached to them during preparation let the model train while seeing quantization noise.
- Fake quantization: the forward pass quantizes and dequantizes activations and weights; the backward pass stays in full precision, with the STE standing in for the gradient of rounding.
- QAT usually gives better accuracy than post-training quantization because it teaches the model to handle quantization artifacts.
- The final convert step drops the fake modules and replaces them with quantized implementations (e.g. int8 matmul).