#### What is Quantization?
Quantization is a core method used to reduce computational and memory costs of large neural network models by converting high precision floating point representations (e.g. float32) to lower precision integer types (e.g. int8) with minimal impact on model accuracy. This process helps to significantly reduce inference time and compute resources.

Quantization reduces neural network weights and activations from 32-bit floats to 8-bit integers, reducing storage and memory requirements by up to 4x. Int8 matrix multiplication is also much faster on most hardwares (e.g. cpu and embedded accelerators)

#### How Quantization works?
Quantization replaces original Weighs $W$ and biases $b$ stored as 32-bit floats with a lower-precision integer representation usually $W$ represented in int8 while $b$ in int32. The matrix multiplication with the input to the layer is then performed in this lower precision representation before de-quantization via the scale and zero-point parameters and feeding the output to the next layer.

<!-- #### Types of Quantization -->

#### PyTorch Implementation of Post-Training Quantization

In [1]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose(
    [transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

# The Model
import torch
import torch.nn as nn
import torch.quantization

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # Add a quantization layer
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        # self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(14 * 14 * 64, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        # Add dequantization layer
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x) # quantize input to int8
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        x = self.dequant(x) # dequantize output to float
        return x

model_fp32 = CNN()
model_fp32

100%|██████████| 9.91M/9.91M [00:02<00:00, 4.56MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 134kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.26MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.60MB/s]


CNN(
  (quant): QuantStub()
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=12544, out_features=128, bias=True)
  (relu3): ReLU()
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (dequant): DeQuantStub()
)

In [3]:
# Train the model
import torch.optim as optim
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_fp32.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_fp32.parameters(), lr=0.001)

print(f"Training on {device}")
for epoch in range(5):
    running_loss = 0.0
    model_fp32.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model_fp32(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

Training on cuda
Epoch 1, Loss: 0.01163397313121483
Epoch 2, Loss: 0.00725513479790467
Epoch 3, Loss: 0.007493622528057428
Epoch 4, Loss: 0.005139252133150162
Epoch 5, Loss: 0.004771352358855667


In [5]:
# Fuse Layers
def fuse_model(model):
    torch.quantization.fuse_modules(model_fp32,
     [['conv1', 'relu1'], ['conv2', 'relu2'], ['fc1', 'relu3']], inplace=True)
fuse_model(model_fp32)

In [7]:
# Set quantization configuration. We use the fbgemm configuration
# FBGEMM (Facebook GEneral Matrix Multiplication) is a low-precision,
# high-performance matrix-matrix multiplications and convolution library for
# server-side inference. This is preferred for x86 CPUs. For ARM CPUs use 'qnnpack'.
model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

In [9]:
# Prepare for quantization
model_fp32.cpu()
model_fp32_prepared = torch.quantization.prepare(model_fp32)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  model_fp32_prepared = torch.quantization.prepare(model_fp32)


In [12]:
model_fp32_prepared

CNN(
  (quant): QuantStub(
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([0.0079]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
      (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.0)
    )
  )
  (conv1): ConvReLU2d(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([0.0090]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
      (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.1397202014923096)
    )
  )
  (relu1): Identity()
  (conv2): ConvReLU2d(
    (0): Conv2

In [10]:
# Calibration with a sample data
print("Calibrating...\n")
model_fp32_prepared.eval()
with torch.no_grad():
    for images, labels in train_loader:
        model_fp32_prepared(images)
        break

Calibrating...



In [14]:
# Convert to Quantized Model
quantized_model = torch.quantization.convert(model_fp32_prepared, inplace=False)
quantized_model

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  quantized_model = torch.quantization.convert(model_fp32_prepared, inplace=False)


CNN(
  (quant): Quantize(scale=tensor([0.0079]), zero_point=tensor([0]), dtype=torch.quint8)
  (conv1): QuantizedConvReLU2d(1, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.008974174968898296, zero_point=0, padding=(1, 1))
  (relu1): Identity()
  (conv2): QuantizedConvReLU2d(32, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.04490518197417259, zero_point=0, padding=(1, 1))
  (relu2): Identity()
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): QuantizedLinearReLU(in_features=12544, out_features=128, scale=0.29909777641296387, zero_point=0, qscheme=torch.per_channel_affine)
  (relu3): Identity()
  (fc2): QuantizedLinear(in_features=128, out_features=10, scale=0.5427078604698181, zero_point=59, qscheme=torch.per_channel_affine)
  (dequant): DeQuantize()
)

In [15]:
# Run the test data through each model and compare float32 and quantized int8 accuracy
def evaluate_model(model, data_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Accuracy: {100.0 * correct / total}%")
    return correct / total

print(f"Evaluating original float32 model")
evaluate_model(model_fp32, test_loader)

print(f"Evaluating quantized model")
evaluate_model(quantized_model, test_loader)

Evaluating original float32 model
Accuracy: 98.89%
Evaluating quantized model
Accuracy: 98.91%


0.9891

In [16]:
# Let's check the model file sizes
import os

# Save model
torch.save(model_fp32.state_dict(), "fp32_model.pth")
torch.save(quantized_model.state_dict(), "quantized_model.pth")

fp32_size = os.path.getsize("fp32_model.pth")
quantized_size = os.path.getsize("quantized_model.pth")

print(f"Original Model Size: {fp32_size / (1024 * 1024)} MB")
print(f"Quantized Model Size: {quantized_size / (1024 * 1024)} MB")

Original Model Size: 6.205862998962402 MB
Quantized Model Size: 1.5630693435668945 MB
