<a href="https://colab.research.google.com/github/rrankawat/pytorch-cnn/blob/main/CIFAR_10_Static_Quantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Post-Training Static Quantization (PTQ)

Post-Training Static Quantization (PTQ) is a technique that converts a pre-trained floating-point model to a lower-precision format, like an 8-bit integer, without requiring any retraining.

In [172]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.quantization
from torch.quantization import QuantStub, DeQuantStub
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import os, time

In [173]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


###### Base Model

In [174]:
class CIFARConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)   # -> 16x32x32
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)  # -> 64x32x32
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1) # -> 64x32x32
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1) # -> 128x32x32
        self.bn4 = nn.BatchNorm2d(128)

        self.fc1 = nn.Linear(128*2*2, 256)
        self.fc2 = nn.Linear(256, 10)

        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        # Block 1
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)  # 32 -> 16

        # Block 2
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)  # 16 -> 8

        # Block 3
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2, 2)  # 8 -> 4

        # Block 4
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.max_pool2d(x, 2, 2)  # 4 -> 2

        # Flatten
        x = x.reshape(-1, 128*2*2)

        # Fully connected
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

###### Quantization Wrapper

In [175]:
class QuantizedCIFAR(nn.Module):
    def __init__(self, model_fp32):
        super().__init__()
        self.model_fp32 = model_fp32
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.model_fp32(x)
        x = self.dequant(x)
        return x

###### Load Trained Model Weights

In [176]:
model_fp32 = CIFARConvNet()
model_fp32.load_state_dict(torch.load("/content/drive/My Drive/Colab Notebooks/model_cifar10.pth"))
model_fp32.eval()
print("✅ Trained model loaded successfully!")

✅ Trained model loaded successfully!


In [177]:
# Wrap in quantization wrapper
model_to_quant = QuantizedCIFAR(model_fp32)

###### CIFAR10 Data Loading

In [178]:
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

test_data = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

###### Static Quantization Preparation

In [179]:
model_to_quant.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model_to_quant, inplace=True)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.prepare(model_to_quant, inplace=True)


QuantizedCIFAR(
  (model_fp32): CIFARConvNet(
    (conv1): Conv2d(
      3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (bn1): BatchNorm2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (conv2): Conv2d(
      16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (bn2): BatchNorm2d(
      32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (conv3): Conv2d(
      32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (bn3): BatchNorm2d(
      64, eps=1e-05, momentum=0.1, affine=True, track_running_

In [180]:
# Calibrate using few batches
with torch.no_grad():
    for i, (images, _) in enumerate(test_loader):
        model_to_quant(images)
        if i > 10:
            break

quantized_model = torch.quantization.convert(model_to_quant.eval(), inplace=False)
print("✅ Static quantization done!")

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  quantized_model = torch.quantization.convert(model_to_quant.eval(), inplace=False)


✅ Static quantization done!


###### Accuracy Evaluation Function

In [181]:
def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100 * correct / total

acc_fp32 = evaluate(model_fp32, test_loader)
acc_int8 = evaluate(quantized_model, test_loader)

print(f"🧠 FP32 Accuracy: {acc_fp32:.2f}%")
print(f"⚡ INT8 Quantized Accuracy: {acc_int8:.2f}%")

🧠 FP32 Accuracy: 81.05%
⚡ INT8 Quantized Accuracy: 80.85%


###### Inference Time Comparison

In [182]:
def measure_inference_time(model, loader, num_batches=20):
    model.eval()
    start = time.time()
    with torch.no_grad():
        for i, (images, _) in enumerate(loader):
            _ = model(images)
            if i >= num_batches:
                break
    return (time.time() - start) / num_batches

t_fp32 = measure_inference_time(model_fp32, test_loader)
t_int8 = measure_inference_time(quantized_model, test_loader)

print(f"⏱️ FP32 Inference Time: {t_fp32:.4f} sec/batch")
print(f"⚡ INT8 Inference Time: {t_int8:.4f} sec/batch")

⏱️ FP32 Inference Time: 0.1789 sec/batch
⚡ INT8 Inference Time: 0.0864 sec/batch


###### Model Size Comparison

In [183]:
torch.save(model_fp32.state_dict(), "model_fp32.pth")
torch.save(quantized_model.state_dict(), "model_int8.pth")

fp32_size = os.path.getsize("model_fp32.pth") / 1e6
int8_size = os.path.getsize("model_int8.pth") / 1e6

print(f"📦 FP32 Model Size: {fp32_size:.2f} MB")
print(f"📦 INT8 Quantized Model Size: {int8_size:.2f} MB")

📦 FP32 Model Size: 1.03 MB
📦 INT8 Quantized Model Size: 0.27 MB


###### Save Quantized Model to Drive

In [184]:
torch.save(quantized_model.state_dict(), "/content/drive/My Drive/Colab Notebooks/model_cifar10_quantized.pth")
print("✅ Quantized model saved!")

✅ Quantized model saved!
