# Post-Training Quantization (PTQ)



## Step-by-Step PTQ Flow:

1. Start with a pretrained full-precision model (float32):
   You start with a trained model (say in PyTorch or TensorFlow), with weights and activations in float32.

2. Attach observers to the model (usually to each layer or submodule):
   Observers track the range of values seen during a forward pass (min/max or percentiles).

3. Run calibration data (unlabeled):
   Pass a few batches of real input data through the model so the observers can collect statistics.
   This step doesn't modify the model yet; it just gathers information about the distribution of values in the weights and, especially, the activations. Why? Because quantization requires knowing the range of values to scale them properly.

   What about outputs like Y = XW + B?

   Suppose X and W are quantized. Then Y = XW + B is accumulated in int32. But Y is a new tensor whose range we haven’t seen before, so how do we map it back to int8 (or dequantize it)?
   Answer: we observe Y during calibration too. You pass a few real inputs through the model and record Y’s range or percentiles, then compute its scale and zero-point as well.

4. Compute scale (s) and zero-point (z) using:
   Min-max ranges, or percentile ranges to reduce the influence of outliers.

5. Quantize weights and activations to int8 using the computed s and z.

6. Perform inference using integer-only arithmetic.
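
To make steps 4 and 5 concrete, here is a minimal pure-Python sketch of affine quantization. The helper names (`compute_qparams`, `quantize`, `dequantize`) and the signed int8 range are assumptions chosen for illustration; they are not PyTorch APIs.

```python
def compute_qparams(x_min, x_max, q_min=-128, q_max=127):
    """Return (scale, zero_point) mapping [x_min, x_max] onto [q_min, q_max]."""
    # Extend the range to include 0.0 so that zero maps exactly to an integer.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (q_max - q_min)
    if scale == 0.0:
        scale = 1.0  # degenerate all-zero range; any positive scale works
    zero_point = q_min - round(x_min / scale)
    zero_point = max(q_min, min(q_max, zero_point))
    return scale, zero_point


def quantize(x, scale, zero_point, q_min=-128, q_max=127):
    """Map a float to an integer code, clamping to the representable range."""
    q = round(x / scale) + zero_point
    return max(q_min, min(q_max, q))


def dequantize(q, scale, zero_point):
    """Map an integer code back to an approximate float."""
    return scale * (q - zero_point)


values = [-1.0, -0.25, 0.0, 0.5, 3.0]
scale, zero_point = compute_qparams(min(values), max(values))
quantized = [quantize(v, scale, zero_point) for v in values]
restored = [dequantize(q, scale, zero_point) for q in quantized]
print(scale, zero_point, quantized, restored)
```

Values inside the observed range come back with an error of at most half a quantization step (`scale / 2`).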

### Calibration and Observers (Key Insight)

You can’t quantize well unless you know the range of values (min, max, or distribution).
So, we run a few batches of real (but unlabeled) data through the model and record:
- Min/max values of weights and activations
- Or better: percentiles (to ignore extreme outliers)

These recorded values are then used to compute the scale and zero-point for each layer.
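
The difference between the two kinds of statistics can be sketched with two toy observers. The class names below are hypothetical, and the percentile observer uses a simple nearest-rank rule that stores every value, which is only reasonable for a small illustration.

```python
class SimpleMinMaxObserver:
    """Tracks the running min and max over all observed batches."""

    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch):
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def value_range(self):
        return self.min_val, self.max_val


class SimplePercentileObserver:
    """Stores observed values and clips the range at the given percentiles."""

    def __init__(self, lower=1.0, upper=99.0):
        self.lower, self.upper = lower, upper
        self.values = []

    def observe(self, batch):
        self.values.extend(batch)

    def value_range(self):
        ordered = sorted(self.values)
        last = len(ordered) - 1
        # Nearest-rank percentile: pick the element closest to the requested rank.
        lo = ordered[round(last * self.lower / 100.0)]
        hi = ordered[round(last * self.upper / 100.0)]
        return lo, hi


# Two calibration batches; 50.0 is a deliberate outlier.
batches = [[i / 100.0 for i in range(-100, 100)], [0.1, 0.2, 50.0]]

minmax, pct = SimpleMinMaxObserver(), SimplePercentileObserver()
for batch in batches:
    minmax.observe(batch)
    pct.observe(batch)

print("min/max range:   ", minmax.value_range())
print("percentile range:", pct.value_range())
```

The single outlier stretches the min/max range to 50.0, while the percentile range stays close to the bulk of the data, so more of the integer codes are spent on values that actually occur.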

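### Input quantization

The input to a layer (sometimes called the “activation”) can be quantized either with parameters computed on the fly from each incoming tensor (dynamic quantization) or with fixed parameters obtained from observers during calibration (static quantization).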




# Post-Training Quantization: Weights & Biases Only (No Activations)

## Goal
Convert the trained model's float32 weights and biases to int8 (or lower-bit) integers to reduce memory usage and improve inference efficiency, especially on edge devices.

---

## Starting Point
You have a fully trained model with:
- Weights `W` (float32)
- Biases `B` (float32)

---

## Step-by-Step Flow (for Weights & Biases Only)

### Step 1: Load the Trained Model
- The model is already trained with float32 weights and biases.

### Step 2: Attach **Observers**
- Observers are modules that **record the min and max** of the weights.
- This range helps us decide how to map float values to integers.

### Step 3: **Calibrate the Weights**
- Calibration = let the observers **collect statistics.**
- Weights are fixed after training, so their min/max can be read directly from the tensors; input data is only needed when activations are observed as well.
- You **don’t need labels**.
- This is **not training**; it’s just **observation**.
- From the observed min/max range of weights, we compute two parameters:
  - `s` = scale
  - `z` = zero-point

### Step 4: **Quantize the Weights and Biases**
- Use the computed `(s, z)` from calibration to convert:
  - `W_fp32 → W_int8`
  - `B_fp32 → B_int32` *(usually higher precision for bias)*
- Now the model stores quantized values.
- This conversion happens once, offline, after calibration (the weights are quantized statically).

---

## Important Notes

- **We don’t dequantize W or B during inference.**
  - Once quantized, the weights stay in int8 and the biases in int32, and both are used directly in integer operations.
- **No backpropagation here.**
  - This is for inference only. Weights are "frozen".
- **Bias is usually quantized more precisely** (e.g., int32) to preserve numerical accuracy during matrix multiplications.
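
A small sketch of why the bias gets more bits. It assumes symmetric per-tensor weight quantization and the common convention that the bias shares the accumulator scale `input_scale * weight_scale`. The function names, the example values, and `input_scale` are illustrative assumptions.

```python
def quantize_weights_symmetric(weights, q_max=127):
    """Symmetric per-tensor quantization: zero-point is 0, scale comes from max |w|."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / q_max if max_abs > 0 else 1.0
    q = [max(-q_max, min(q_max, round(w / scale))) for w in weights]
    return q, scale


def quantize_bias(biases, input_scale, weight_scale):
    """Quantize the bias at the accumulator scale and clamp to the int32 range."""
    bias_scale = input_scale * weight_scale
    int32_min, int32_max = -2**31, 2**31 - 1
    q = [max(int32_min, min(int32_max, round(b / bias_scale))) for b in biases]
    return q, bias_scale


weights = [0.0564, -0.0359, 0.0126, -0.0068]  # illustrative values
biases = [0.12, -0.03]                        # illustrative values
input_scale = 0.0256                          # assumed activation scale

w_q, w_scale = quantize_weights_symmetric(weights)
b_q, b_scale = quantize_bias(biases, input_scale, w_scale)
print("int8 weights:", w_q)
print("int32 biases:", b_q)
```

Because the bias scale is the product of two small scales, the integer bias values are far outside the int8 range, which is why a wider type is used for them.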

---

## Benefits
- Smaller model size (int8 is 4x smaller than float32).
- Faster inference on CPUs or low-power hardware (e.g., phones, microcontrollers).


# Post-Training Quantization: Activations (e.g., ReLU outputs)

## 1. Pretrained Model
- You start with a fully trained model with:
  - Float32 weights `W`
  - Float32 biases `B`
  - Float32 activations `X`, `Y`, etc.

---

## 2. Attach **Activation Observers**
- Insert observer modules **after each activation** or output of interest.
  - For example, after `ReLU`, `GELU`, etc.
- These observers will record the **min and max** values of the outputs during real inference-like usage.

---

## 3. **Calibrate the Activations**
- Run inference on some **representative data** (a few hundred to a few thousand samples).
  - This lets the observers **collect the actual min/max** ranges of activation outputs.
  - No gradients, no labels needed.
- From this range, compute:
  - `s_act` = activation scale
  - `z_act` = activation zero-point

---

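## 4. Quantize Activations at Runtime
- Unlike weights, activation **values** are not known ahead of time, so they cannot be converted once and stored.
- Instead:
  - At inference time, input tensors (e.g., `X`) are quantized **on the fly** using the fixed `(scale, zero_point)` computed during calibration:
    ```
    X_int8 = clamp(round(X_fp32 / s_act) + z_act, q_min, q_max)
    ```
  - Then integer matrix multiplication is performed, accumulating in int32 (shown for symmetric weights, i.e. a weight zero-point of 0):
    ```
    Y_int32 = W_int8 * (X_int8 - z_act) + B_int32
    ```
  - And the output is **requantized** to int8 using the weight scale `s_w`, the input scale `s_act`, and the output scale and zero point:
    ```
    Y_int8 = clamp(round(Y_int32 * (s_w * s_act / s_output)) + z_output, q_min, q_max)
    ```
  - Only the quantization **parameters** are static here; the tensors themselves are quantized at runtime. If the parameters are instead recomputed from each incoming tensor, that is dynamic quantization.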

---

## 5. Inference (Quantized Model)
- Now all forward passes use:
  - `int8` inputs
  - `int8` weights
  - `int32` accumulations
  - `int8` outputs

- Float32 only appears if you convert inputs or outputs back to float.
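
The runtime flow in sections 4 and 5 can be traced by hand for a single output value. The sketch below uses plain Python integers in place of int8/int32 tensors; all scales and zero-points are made-up assumptions standing in for calibrated values, and the weights are assumed symmetric (zero-point 0).

```python
def quantize(x, scale, zero_point, q_min, q_max):
    """Affine quantization of one float value with clamping."""
    return max(q_min, min(q_max, round(x / scale) + zero_point))


# Float reference for one output: y = w . x + b
x = [0.5, -0.2, 1.0]
w = [0.3, -0.1, 0.2]
b = 0.05
y_ref = sum(wi * xi for wi, xi in zip(w, x)) + b

# Assumed quantization parameters (in practice they come from calibration).
s_x, z_x = 0.01, 20   # input activations, unsigned 8-bit range [0, 255]
s_w = 0.3 / 127       # weights, symmetric signed range [-127, 127]
s_y, z_y = 0.005, 10  # output activations, unsigned 8-bit range [0, 255]

x_q = [quantize(v, s_x, z_x, 0, 255) for v in x]
w_q = [quantize(v, s_w, 0, -127, 127) for v in w]
b_q = round(b / (s_x * s_w))  # bias stored at the accumulator scale

# Integer-only accumulation (this would be int32 in a real kernel).
acc = sum(wq * (xq - z_x) for wq, xq in zip(w_q, x_q)) + b_q

# Requantize the accumulator to the output's 8-bit representation.
y_q = max(0, min(255, round(acc * (s_x * s_w / s_y)) + z_y))

# Dequantize only to compare against the float reference.
y_deq = s_y * (y_q - z_y)
print("float reference:", y_ref)
print("quantized path: ", y_deq, "(integer code", y_q, ")")
```

The quantized path lands within one output quantization step of the float result, which is the kind of agreement static PTQ aims for.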

---

## Summary Table

| Component   | Quantized? | How Quantized                                                  | Static or Dynamic                         |
|-------------|------------|----------------------------------------------------------------|-------------------------------------------|
| Weights     | ✅ Yes     | From the weight range, once after training                     | Static                                    |
| Biases      | ✅ Yes     | Scale derived from weight & activation scales                  | Static (int32)                            |
| Activations | ✅ Yes     | Parameters calibrated with observers; values quantized at runtime | Static parameters, applied at runtime  |


# Import the necessary libraries

In [None]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

# Load the MNIST dataset

In [43]:
# Make torch deterministic
_ = torch.manual_seed(0)

In [44]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = "cpu"

# Define the model

In [None]:
class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(VerySimpleNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

In [46]:
net = VerySimpleNet().to(device)

# Train the model

In [None]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

MODEL_FILENAME = 'simplenet_ptq.pt'

if Path(MODEL_FILENAME).exists():
    net.load_state_dict(torch.load(MODEL_FILENAME))
    print('Loaded model from disk')
else:
    train(train_loader, net, epochs=1)
    # Save the model to disk
    torch.save(net.state_dict(), MODEL_FILENAME)

Loaded model from disk


# Define the testing loop

In [48]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')

# Print weights and size of the model before quantization

In [49]:
# Print the weights matrix of the model before quantization
print('Weights before quantization')
print(net.linear1.weight)
print(net.linear1.weight.dtype)

Weights before quantization
Parameter containing:
tensor([[-0.0068,  0.0126, -0.0359,  ...,  0.0154, -0.0028, -0.0045],
        [-0.0141, -0.0093, -0.0048,  ..., -0.0146, -0.0003, -0.0243],
        [ 0.0251,  0.0601,  0.0120,  ...,  0.0249,  0.0464,  0.0533],
        ...,
        [ 0.0564,  0.0601,  0.0255,  ...,  0.0201,  0.0394,  0.0024],
        [-0.0070,  0.0011,  0.0332,  ...,  0.0135,  0.0135,  0.0130],
        [ 0.0103,  0.0049, -0.0092,  ...,  0.0272, -0.0221, -0.0020]],
       requires_grad=True)
torch.float32


In [50]:
print('Size of the model before quantization')
print_size_of_model(net)

Size of the model before quantization
Size (KB): 360.559


In [51]:
print(f'Accuracy of the model before quantization: ')
test(net)

Accuracy of the model before quantization: 


Testing: 100%|██████████| 1000/1000 [00:01<00:00, 799.89it/s]

Accuracy: 0.963





# Insert min-max observers in the model

In [None]:
class QuantizedVerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(QuantizedVerySimpleNet,self).__init__()

        # torch.quantization.QuantStub()
        # In PTQ (Post-Training Quantization):
        # QuantStub and DeQuantStub are placeholders in the model.
        # During model preparation, they don't quantize or dequantize the data yet; they just mark
        # the spots in the model graph where quantization and dequantization should happen once we convert the model.
        # After calibration and conversion, they are replaced by real quant/dequant ops.
        # So they are placeholders initially, and become real ops after convert().
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

When you place a QuantStub() in the forward pass, everything between the `self.quant(...)` call and the matching `self.dequant(...)` call runs in the quantized domain after convert(), as long as the layers inside are quantizable modules (like nn.Linear, nn.Conv2d, etc.) and have been prepared with observers.

```
float32 input 
   │
QuantStub → (convert) → real quantize (float32 → int8)
   │
[ int8 layer1 → int8 layer2 → int8 layer3 ... ]
   │
DeQuantStub → (convert) → real dequantize (int8 → float32)
   │
float32 output
```

You don’t have to sprinkle self.quant() before every layer: the tensor stays quantized until you explicitly dequantize it.

Important details:
  - During prepare, PyTorch attaches observers to every eligible module that has a qconfig. Here the qconfig is set on the root module, so it propagates to all layers.
  - At convert, those float modules get swapped for quantized equivalents (nnq.Linear, nnq.Conv2d) that expect quantized inputs and store int8 weights.
  - Any layer placed outside the quant-dequant block receives float tensors, so it should be left unquantized (in eager mode, for example, by setting its qconfig to None).

`torch.quantization.QuantStub()` is a module in PyTorch's `torch.ao.quantization` library used for quantization. It acts as a placeholder or "stub" within a neural network model, indicating where a floating-point tensor should be quantized to a lower-precision format (e.g., 8-bit integer) during the quantization process.

Here's a breakdown of its role:

- Quantization Point: QuantStub explicitly marks the input to a section of the model that will be quantized. When performing post-training static quantization or quantization-aware training, PyTorch's quantization tools identify these QuantStub instances and insert the necessary quantization operations (like observers and quantizers) during the prepare and convert stages.

- Calibration and Conversion: During calibration, QuantStub (along with a qconfig set on the model) helps collect statistics about the activation distributions to determine appropriate quantization parameters (scale and zero-point). In the convert step, QuantStub instances are replaced by actual torch.nn.quantized.Quantize modules, which perform the conversion from floating-point to quantized tensors.





#### Dequantize the model outputs before returning them

Here’s why:
  - Quantized tensors are usually in torch.qint8 or torch.quint8 format, which most downstream code (loss functions, evaluation metrics, plotting, saving to JSON, etc.) doesn’t understand.
  - Dequantizing converts them back to standard torch.float32 values so they can be consumed normally.

```
float input  
   │
QuantStub  (float → int8)
   │
[int8 layers ... int8 weights, int8 activations]
   │
DeQuantStub  (int8 → float)
   │
float output
```

The final output of your model (before you hand it back to the user or another float-based component) should usually be dequantized.

In [None]:
net_quantized = QuantizedVerySimpleNet().to(device)
# Copy weights from unquantized model
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()

# Attach a quantization configuration (qconfig) to the model.
# This tells PyTorch which observers to use (min-max, histogram, etc.).
net_quantized.qconfig = torch.ao.quantization.default_qconfig
# Alternative: a backend-specific config; "fbgemm" is the backend optimized for x86 CPUs.
# net_quantized.qconfig = torch.quantization.get_default_qconfig("fbgemm")


# Insert observers into the model. The printout below shows the activation observers
# (MinMaxObserver) attached to the QuantStub and to each Linear layer.
# Observers record min/max ranges during calibration to later compute scale (s) and zero point (z).
net_quantized = torch.ao.quantization.prepare(net_quantized) # Insert observers
net_quantized
net_quantized

QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Calibrate the model using the test set

In [None]:

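# Calibration - pass real or synthetic data through the model.
# Only the activation observers need data; the weight ranges depend on the weights alone.
# It does NOT update weights; it just gathers activation stats via observers.
# The test set is used here for simplicity; a held-out slice of the training data
# would avoid calibrating on the same data used for evaluation.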
test(net_quantized)

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 750.89it/s]

Accuracy: 0.963





In [55]:
print(f'Check statistics of the various layers')
net_quantized

Check statistics of the various layers


QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-53.58397674560547, max_val=34.898128509521484)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-24.331275939941406, max_val=26.62542152404785)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-28.273700714111328, max_val=20.937761306762695)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Quantize the model using the statistics collected

In [None]:

# This step:
#   - Quantizes weights using s and z computed from the weight ranges
#   - Fixes the activation s and z from the ranges observed during calibration
#   - Replaces float modules with quantized integer equivalents
net_quantized = torch.ao.quantization.convert(net_quantized)

In [None]:
print(f'Check statistics of the various layers')
net_quantized

# At this point, the model is quantized: weights and activations are stored as 8-bit integers.
# The Linear layers and ReLUs now operate on quantized tensors internally.

Check statistics of the various layers


QuantizedVerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.6967094540596008, zero_point=77, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.40123382210731506, zero_point=61, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.3874918520450592, zero_point=73, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)
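
The scales and zero-points in this printout can be reproduced by hand from the min/max values the observers recorded. The sketch below assumes the activation observers map onto the reduced range 0..127, which is consistent with the printed numbers; `affine_qparams` is an illustrative helper, not a PyTorch function.

```python
# (min_val, max_val) copied from the observer printout after calibration.
observed = {
    "quant":   (-0.4242129623889923, 2.821486711502075),
    "linear1": (-53.58397674560547, 34.898128509521484),
    "linear2": (-24.331275939941406, 26.62542152404785),
    "linear3": (-28.273700714111328, 20.937761306762695),
}

Q_MIN, Q_MAX = 0, 127  # assumption: reduced 7-bit range for activations


def affine_qparams(min_val, max_val, q_min=Q_MIN, q_max=Q_MAX):
    """Scale and zero-point for an affine mapping of [min_val, max_val]."""
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = (max_val - min_val) / (q_max - q_min)
    zero_point = q_min - round(min_val / scale)
    return scale, max(q_min, min(q_max, zero_point))


results = {name: affine_qparams(lo, hi) for name, (lo, hi) in observed.items()}
for name, (scale, zero_point) in results.items():
    print(f"{name}: scale={scale:.4f}, zero_point={zero_point}")
```

Up to float32 rounding, the computed values agree with the `scale` and `zero_point` fields shown for `quant`, `linear1`, `linear2`, and `linear3` above.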

# Print weights of the model after quantization

In [58]:
# Print the weights matrix of the model after quantization
print('Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))

Weights after quantization
tensor([[-2,  3, -8,  ...,  4, -1, -1],
        [-3, -2, -1,  ..., -3,  0, -6],
        [ 6, 14,  3,  ...,  6, 11, 12],
        ...,
        [13, 14,  6,  ...,  5,  9,  1],
        [-2,  0,  8,  ...,  3,  3,  3],
        [ 2,  1, -2,  ...,  6, -5,  0]], dtype=torch.int8)


# Compare the dequantized weights and the original weights

In [59]:
print('Original weights: ')
print(net.linear1.weight)
print('')
print(f'Dequantized weights: ')
print(torch.dequantize(net_quantized.linear1.weight()))
print('')

Original weights: 
Parameter containing:
tensor([[-0.0068,  0.0126, -0.0359,  ...,  0.0154, -0.0028, -0.0045],
        [-0.0141, -0.0093, -0.0048,  ..., -0.0146, -0.0003, -0.0243],
        [ 0.0251,  0.0601,  0.0120,  ...,  0.0249,  0.0464,  0.0533],
        ...,
        [ 0.0564,  0.0601,  0.0255,  ...,  0.0201,  0.0394,  0.0024],
        [-0.0070,  0.0011,  0.0332,  ...,  0.0135,  0.0135,  0.0130],
        [ 0.0103,  0.0049, -0.0092,  ...,  0.0272, -0.0221, -0.0020]],
       requires_grad=True)

Dequantized weights: 
tensor([[-0.0087,  0.0131, -0.0348,  ...,  0.0174, -0.0044, -0.0044],
        [-0.0131, -0.0087, -0.0044,  ..., -0.0131,  0.0000, -0.0261],
        [ 0.0261,  0.0609,  0.0131,  ...,  0.0261,  0.0479,  0.0522],
        ...,
        [ 0.0566,  0.0609,  0.0261,  ...,  0.0218,  0.0392,  0.0044],
        [-0.0087,  0.0000,  0.0348,  ...,  0.0131,  0.0131,  0.0131],
        [ 0.0087,  0.0044, -0.0087,  ...,  0.0261, -0.0218,  0.0000]])
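
The gap between the original and dequantized weights is pure rounding error. The sketch below re-quantizes the visible entries of the first printed row. The weight scale is not printed in the notebook, so it is inferred here from the outputs (dequantized value 0.0566 for integer 13); treat it as an approximation, with a zero-point of 0 assumed.

```python
# Visible entries of the first row of the original weights and their printed int8 codes.
original = [-0.0068, 0.0126, -0.0359, 0.0154, -0.0028, -0.0045]
printed_int8 = [-2, 3, -8, 4, -1, -1]

# Assumption: scale inferred from the printouts, symmetric (zero-point 0).
weight_scale = 0.0566 / 13

requantized = [round(w / weight_scale) for w in original]
dequantized = [q * weight_scale for q in requantized]
errors = [abs(w - d) for w, d in zip(original, dequantized)]

print("integer codes:", requantized)
print("dequantized:  ", [round(d, 4) for d in dequantized])
print("max error:    ", max(errors), "(half a step is", weight_scale / 2, ")")
```

Each weight moves by at most half a quantization step, which is why the dequantized values are all multiples of roughly 0.0044.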



# Print size and accuracy of the quantized model

In [60]:
print('Size of the model after quantization')
print_size_of_model(net_quantized)

Size of the model after quantization
Size (KB): 94.955
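
The measured sizes can be sanity-checked from the parameter count. The sketch below only counts raw parameter bytes; serialization overhead and quantization metadata are ignored, so the estimates are lower bounds rather than exact file sizes.

```python
# (in_features, out_features) of the three Linear layers in the model.
layers = [(784, 100), (100, 100), (100, 10)]

n_weights = sum(i * o for i, o in layers)
n_biases = sum(o for _, o in layers)
n_params = n_weights + n_biases

fp32_kb = 4 * n_params / 1e3      # 4 bytes per float32 parameter
int8_weights_kb = n_weights / 1e3  # 1 byte per int8 weight, biases excluded

# Sizes reported by print_size_of_model before and after quantization.
measured_fp32_kb, measured_int8_kb = 360.559, 94.955
ratio = measured_fp32_kb / measured_int8_kb

print(f"parameters: {n_params}")
print(f"float32 estimate: {fp32_kb:.2f} KB (measured {measured_fp32_kb} KB)")
print(f"int8 weights estimate: {int8_weights_kb:.2f} KB (measured {measured_int8_kb} KB)")
print(f"measured compression ratio: {ratio:.2f}x")
```

The measured ratio is a little below the ideal 4x because the saved file also contains data that does not shrink, such as per-layer quantization parameters and serialization overhead.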


In [61]:
print('Testing the model after quantization')
test(net_quantized)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:01<00:00, 774.94it/s]

Accuracy: 0.963



