<a href="https://colab.research.google.com/github/rrankawat/pytorch-cnn/blob/main/CIFAR_10_Static_Quantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Post-Training Static Quantization (PTQ)

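Post-Training Static Quantization (PTQ) converts a pre-trained floating-point model to a lower-precision format, such as 8-bit integers, without any retraining. Weights are quantized directly. Activation ranges are estimated before conversion by running a small calibration set through the model.

> **Note:** as written, this notebook never trains the network, so both the FP32 and INT8 models score ~10%, which is chance level for CIFAR-10's 10 classes. Load trained weights before quantizing to get a meaningful accuracy comparison.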

In [51]:
import os
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.quantization
from torch.quantization import QuantStub, DeQuantStub
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

###### Base Model

In [52]:
class CIFARConvNet(nn.Module):
    """Small CIFAR-10 classifier: four conv/BN/ReLU/max-pool blocks + two linear layers.

    Expects 3-channel 32x32 inputs. Each block halves the spatial size
    (32 -> 16 -> 8 -> 4 -> 2), so the flattened feature vector has 128*2*2
    elements. Returns unnormalized class scores of shape (batch, 10).
    """

    def __init__(self):
        super().__init__()
        # Comments give each conv's output shape (C x H x W) for a 32x32 input.
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)    # -> 16x32x32
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)   # -> 32x16x16
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)   # -> 64x8x8
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1)  # -> 128x4x4
        self.bn4 = nn.BatchNorm2d(128)

        self.fc1 = nn.Linear(128*2*2, 256)  # input: 128x2x2 after the last pool
        self.fc2 = nn.Linear(256, 10)

        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        # Block 1
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)  # 32 -> 16

        # Block 2
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)  # 16 -> 8

        # Block 3
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2, 2)  # 8 -> 4

        # Block 4
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.max_pool2d(x, 2, 2)  # 4 -> 2

        # Flatten everything but the batch dimension. Unlike view(-1, 128*2*2),
        # this never silently re-batches a wrongly sized input; fc1 raises instead.
        x = torch.flatten(x, 1)

        # Fully connected
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

###### Quantization-Ready Model

In [53]:
class QuantCIFARConvNet(CIFARConvNet):
    """CIFARConvNet wrapped with QuantStub/DeQuantStub for eager-mode static quantization.

    ``quant`` marks where float inputs enter the quantized region and ``dequant``
    where outputs leave it. After ``convert``, every op between them runs on
    quantized tensors, so ``forward`` may only use ops with quantized support.
    """

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)  # Quantize input

        # Same layer sequence as CIFARConvNet.forward. ReLU is functional (F.relu),
        # not a module, so fuse_modules can only fold the conv+bn pairs.
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)  # 32 -> 16

        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)  # 16 -> 8

        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2, 2)  # 8 -> 4

        x = F.relu(self.bn4(self.conv4(x)))
        x = F.max_pool2d(x, 2, 2)  # 4 -> 2

        # Flatten all but the batch dim. Like reshape, flatten copes with
        # non-contiguous tensors (presumably why view was abandoned here), and
        # unlike reshape(-1, 512) it cannot silently change the batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        x = self.dequant(x)  # Dequantize output
        return x

###### CIFAR10 Data Loading

In [54]:
# Only ToTensor is applied (pixels scaled to [0, 1]); there is no mean/std Normalize.
# NOTE(review): calibration inputs must be preprocessed exactly like the training
# data of whatever weights are quantized — add Normalize(...) here if they used it.
transform = transforms.Compose([
    transforms.ToTensor(),
])

DATA_ROOT = './data'
SEED = 42  # fixes the shuffle order, so the calibration batches are reproducible

train_set = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
test_set = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)

# Calibration later takes the first few batches of train_loader; without a seeded
# generator those batches (and the resulting quantized accuracy) change every run.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2,
                          generator=torch.Generator().manual_seed(SEED))
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=2)

###### Evaluate Model Accuracy

In [55]:
def evaluate(model, dataloader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over ``dataloader``.

    Puts ``model`` into eval mode (this mutates it) and runs it without gradient
    tracking.

    Args:
        model: module mapping a batch of inputs to class scores of shape
            (batch, num_classes); the prediction is the argmax over dim 1.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to; should match the model's device.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)  # index of the highest class score
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluate() got an empty dataloader; accuracy is undefined")
    return 100 * correct / total

###### Prepare Model for Quantization

In [56]:
# Eager-mode static quantization (fbgemm/qnnpack backends) only has CPU kernels:
# a model prepared and calibrated on CUDA cannot be converted/run as int8 there.
# Keep the whole pipeline on CPU even when a GPU is available.
device = torch.device("cpu")

In [57]:
# NOTE: nothing in this notebook trains the network. With freshly initialized
# weights, both the FP32 and INT8 models score ~10% (chance level for 10 classes).
# Save a trained state_dict for this architecture at FP32_CHECKPOINT to quantize
# a meaningful model.
FP32_CHECKPOINT = "cifar_convnet_fp32.pth"

model_fp32 = QuantCIFARConvNet().to(device)
if os.path.exists(FP32_CHECKPOINT):
    model_fp32.load_state_dict(
        torch.load(FP32_CHECKPOINT, map_location=device, weights_only=True)
    )
else:
    warnings.warn(
        f"{FP32_CHECKPOINT!r} not found; quantizing an UNTRAINED model "
        "(expect ~10% accuracy)."
    )
model_fp32.eval()

QuantCIFARConvNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
  (dropout): Dropout(p=0.25, inplace=False)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

In [58]:
# Fold each BatchNorm into the conv that precedes it: the conv absorbs the BN
# scale/shift and bnN becomes Identity. Only conv+bn pairs are listed because
# ReLU is applied functionally (F.relu) in forward and is not a module to fuse.
conv_bn_pairs = [
    ['conv1', 'bn1'],
    ['conv2', 'bn2'],
    ['conv3', 'bn3'],
    ['conv4', 'bn4'],
]
torch.quantization.fuse_modules(model_fp32, conv_bn_pairs, inplace=True)

QuantCIFARConvNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): Identity()
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): Identity()
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): Identity()
  (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn4): Identity()
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
  (dropout): Dropout(p=0.25, inplace=False)
  (quant): QuantStub()
  (dequant): DeQuantStub()
)

In [59]:
# Set quantization configuration
# Choose a quantized backend this machine supports ('fbgemm' on x86, 'qnnpack' on
# ARM). Set the runtime engine and the qconfig together so they agree.
QUANT_BACKEND = 'fbgemm' if 'fbgemm' in torch.backends.quantized.supported_engines else 'qnnpack'
torch.backends.quantized.engine = QUANT_BACKEND
model_fp32.qconfig = torch.quantization.get_default_qconfig(QUANT_BACKEND)

In [60]:
# Prepare model for static quantization
# Attach observers to the modules covered by the qconfig; calibration then records their activation ranges
torch.quantization.prepare(model=model_fp32, inplace=True)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.prepare(model_fp32, inplace=True)


QuantCIFARConvNet(
  (conv1): Conv2d(
    3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (bn1): Identity()
  (conv2): Conv2d(
    16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (bn2): Identity()
  (conv3): Conv2d(
    32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (bn3): Identity()
  (conv4): Conv2d(
    64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (bn4): Identity()
  (fc1): Linear(
    in_features=512, out_features=256, bias=True
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (fc2): Linear(
    in_features=256, out_features=10, bias=True
    (activation_post_process): HistogramObserv

###### Calibration Step

In [61]:
# Feed a few training batches through the observed model so its observers record
# activation ranges. Forward passes only; no gradients are computed.
NUM_CALIBRATION_BATCHES = 10  # batches of 128 images; raise for more stable ranges

model_fp32.eval()  # explicit: don't depend on an earlier cell having set eval mode
print("Calibrating the model...")
with torch.no_grad():
    # zip stops as soon as range() is exhausted, so no extra batch is loaded.
    for _batch_idx, (images, _labels) in zip(range(NUM_CALIBRATION_BATCHES), train_loader):
        model_fp32(images.to(device))

Calibrating the model...


###### Convert to Quantized Model

In [62]:
# Swap the observed float modules for int8 versions using the calibrated ranges.
# inplace=False converts a copy, so model_fp32 stays usable as the float baseline.
model_fp32.eval()
quantized_model = torch.quantization.convert(model_fp32, inplace=False)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  quantized_model = torch.quantization.convert(model_fp32.eval(), inplace=False)


###### Evaluate Both Models (FP32 vs INT8)

In [63]:
# For fair comparison, move models to CPU (quantized runs only on CPU)
# For fair comparison, move models to CPU (quantized runs only on CPU).
# .to() moves a module in place and returns it, so model_fp32_cpu IS model_fp32:
# the fused model with its calibration observers still attached.
model_fp32_cpu = model_fp32.to('cpu')
quantized_model_cpu = quantized_model.to('cpu')

acc_fp32 = evaluate(model_fp32_cpu, test_loader, 'cpu')
acc_quant = evaluate(quantized_model_cpu, test_loader, 'cpu')

print(f"\nFP32 Model Accuracy: {acc_fp32:.2f}%")
print(f"Quantized Model Accuracy: {acc_quant:.2f}%")
print(f"Accuracy drop: {acc_fp32 - acc_quant:.2f} pts")

# If the float model itself is near chance, the FP32-vs-INT8 comparison says nothing
# about quantization quality — most likely the weights were never trained.
chance_acc = 100 / len(test_set.classes)
if acc_fp32 < 2 * chance_acc:
    warnings.warn(
        f"FP32 accuracy {acc_fp32:.2f}% is near chance ({chance_acc:.2f}%); "
        "the model appears untrained, so this comparison is not meaningful."
    )


FP32 Model Accuracy: 10.00%
Quantized Model Accuracy: 10.00%


###### Compare Model Sizes

In [64]:
def model_size(model, name):
    """Save ``model``'s state_dict to ``<name>.pth``, print and return its size.

    The file is written to the current working directory and kept there.

    Args:
        model: module whose ``state_dict()`` is serialized with ``torch.save``.
        name: file stem and the label used in the printed line.

    Returns:
        Size of the saved file in megabytes (1 MB = 1e6 bytes).
    """
    path = f"{name}.pth"
    torch.save(model.state_dict(), path)
    size_mb = os.path.getsize(path) / 1e6
    print(f"{name} size: {size_mb:.2f} MB")
    return size_mb

# Compare the serialized sizes of both models. model_fp32_cpu still carries the
# observers added by prepare, so its state_dict may include their buffers in
# addition to the float weights.
for candidate, label in ((model_fp32_cpu, "original_fp32"),
                         (quantized_model_cpu, "quantized_int8")):
    model_size(candidate, label)

print("\nQuantization complete ✅")

original_fp32 size: 1.00 MB
quantized_int8 size: 0.25 MB

Quantization complete ✅
