In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.datasets import *
from torchvision.transforms import *
import os
from tqdm import tqdm
import copy

In [2]:
# Configuration class
class Config:
    BATCH_SIZE = 512
    LEARNING_RATE = 0.001
    EPOCHS = 1
    QAT_EPOCHS = 5
    WEIGHT_BITS = 8
    ACTIVATION_BITS = 8
    PATIENCE = 3
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class QuantizeSTE(torch.autograd.Function):
    """
    Straight-Through Estimator (STE) for quantization.
    During the forward pass, the input tensor is quantized.
    During the backward pass, the gradients are passed through as if the quantization operation was the identity function.
    """
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, bit_width: int):
        # Quantize the tensor
        max_val = tensor.abs().max()
        scale = max_val / (float(2 ** (bit_width - 1) - 1)) if max_val != 0 else 1.0
        q_tensor = torch.round(tensor / scale).clamp(
            min=-(2 ** (bit_width - 1)),
            max=(2 ** (bit_width - 1) - 1)
        )
        # Save scale for backward pass
        ctx.save_for_backward(scale)
        return q_tensor.to(torch.int8), scale

    @staticmethod
    def backward(ctx, grad_q_tensor, grad_scale):
        # During backward pass, pass the gradient through as if the quantization was the identity function
        scale, = ctx.saved_tensors
        grad_tensor = grad_q_tensor / scale
        return grad_tensor, None  # No gradient for bit_width


class QuantizedLinear(nn.Module):
    def __init__(self, layer: nn.Module, bit_width: int = 8, 
                 act_bit_width: int = 8, device: str = 'cpu'):
        super(QuantizedLinear, self).__init__()
            
        self.bit_width = bit_width
        self.act_bit_width = act_bit_width
        self.device = device

        self.is_conv = isinstance(layer, nn.Conv2d)
        if self.is_conv:
            self.stride = layer.stride
            self.padding = layer.padding

        self.weight = nn.Parameter(layer.weight.data.detach().clone()).to(device)
        self.bias = nn.Parameter(layer.bias.data.detach().clone()).to(device) if layer.bias is not None else None
    
        self.register_buffer('weight_scale', torch.tensor(1.0).to(device))
        self.register_buffer('act_scale', torch.tensor(1.0).to(device))
        self.register_buffer('qweight', None)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Quantize weights using STE
        if self.qweight is None or self.training:
            self.qweight, self.weight_scale = QuantizeSTE.apply(self.weight, self.bit_width)
            self.qweight = self.qweight.to(self.device)
        
        # Perform the linear or convolutional operation
        if self.is_conv:
            x = F.conv2d(x, self.qweight.float() * self.weight_scale, self.bias, stride=self.stride, padding=self.padding)
        else:
            x = F.linear(x, self.qweight.float() * self.weight_scale, self.bias)
        
        # Quantize activations using STE
        if self.act_bit_width is not None:
            x, self.act_scale = QuantizeSTE.apply(x, self.act_bit_width)
            x = x.float() * self.act_scale
        
        return x

    def __repr__(self):
        if self.is_conv:
            return f"QuantizedConv2d(in_channels={self.weight.shape[1]}, out_channels={self.weight.shape[0]}, bias={self.bias is not None})"
        else:
            return f"QuantizedLinear(in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, bias={self.bias is not None})"


In [4]:
# Data loaders
def get_data_loaders():
    image_size = 32
    transforms = {
        "train": Compose([
            RandomCrop(image_size, padding=4),
            RandomHorizontalFlip(),
            ToTensor(),
        ]),
        "test": ToTensor(),
    }

    train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transforms['train'], download=True)
    test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transforms['test'], download=True)
    
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=Config.BATCH_SIZE, shuffle=False)
    
    return train_loader, test_loader


In [5]:
# Model evaluation
def evaluate_model(model: nn.Module, loader: DataLoader) -> float:
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(Config.DEVICE), labels.to(Config.DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    return 100 * correct / total

# Model training
def train_model(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader, 
                criterion: nn.Module, optimizer: optim.Optimizer, qat: bool = False):
    epochs = Config.EPOCHS if not qat else Config.QAT_EPOCHS
    for epoch in tqdm(range(epochs)):
        model.train()
        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(Config.DEVICE), labels.to(Config.DEVICE)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Evaluate the model at the end of each epoch
        accuracy = evaluate_model(model, test_loader)

In [6]:
# Replace layers with quantized versions
def replace_layers(model: nn.Module, device: str):
    for name, module in model.named_children():
        if isinstance(module, (nn.Linear,nn.Conv2d)):
            quant_layer = QuantizedLinear(module, 
                                        bit_width=Config.WEIGHT_BITS,
                                        act_bit_width=Config.ACTIVATION_BITS,
                                        device=device)
            setattr(model, name, quant_layer)
        else:
            replace_layers(module, device)
            

In [7]:
from collections import defaultdict, OrderedDict

class VGG(nn.Module):
      ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
    
      def __init__(self) -> None:
        super().__init__()
    
        layers = []
        counts = defaultdict(int)
    
        def add(name: str, layer: nn.Module) -> None:
          layers.append((f"{name}{counts[name]}", layer))
          counts[name] += 1
    
        in_channels = 3
        for x in self.ARCH:
          if x != 'M':
            # conv-bn-relu
            add("conv", nn.Conv2d(in_channels, x, 3, padding=1, bias=False))
            add("bn", nn.BatchNorm2d(x))
            add("relu", nn.ReLU(True))
            in_channels = x
          else:
            # maxpool
            add("pool", nn.MaxPool2d(2))
    
        self.backbone = nn.Sequential(OrderedDict(layers))
        self.classifier = nn.Linear(512, 10)
    
      def forward(self, x: torch.Tensor) -> torch.Tensor:
        # backbone: [N, 3, 32, 32] => [N, 512, 2, 2]
        x = self.backbone(x)
    
        # avgpool: [N, 512, 2, 2] => [N, 512]
        x = x.mean([2, 3])
    
        # classifier: [N, 512] => [N, 10]
        x = self.classifier(x)
        return x

In [8]:
#Load Dataloader
train_loader, test_loader = get_data_loaders()

#Load Pre Trained Model and Weights
model_path = 'vgg.cifar.pretrained.pth'
model = VGG().cuda()
checkpoint = torch.load(model_path, weights_only=True)
model.load_state_dict(checkpoint)

print(f"\n Original Model: {model}")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=Config.LEARNING_RATE)
accuracy = evaluate_model(model, test_loader)
print(f"Original Model Accuracy: {accuracy:.2f}%")


Files already downloaded and verified
Files already downloaded and verified

 Original Model: VGG(
  (backbone): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=

In [10]:
# Apply QAT by replacing layers
qat_model = copy.deepcopy(model)
replace_layers(qat_model, Config.DEVICE)
print(f"Quantized Model: {qat_model}")


Quantized Model: VGG(
  (backbone): Sequential(
    (conv0): QuantizedConv2d(in_channels=3, out_channels=64, bias=False)
    (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (conv1): QuantizedConv2d(in_channels=64, out_channels=128, bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): QuantizedConv2d(in_channels=128, out_channels=256, bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv3): QuantizedConv2d(in_channels=256, out_channels=256, bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu3): ReLU(inplace=True)
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [11]:
#Run Quantization Aware Training
qat_model.to(Config.DEVICE)  # Move the model to the specified device
optimizer = optim.Adam(qat_model.parameters(), lr=Config.LEARNING_RATE)
train_model(qat_model, train_loader, test_loader, criterion, optimizer, qat=True)

accuracy = evaluate_model(qat_model, test_loader)
print(f"Quantized Model Accuracy: {accuracy:.2f}%")

100%|██████████| 5/5 [01:04<00:00, 12.96s/it]


Quantized Model Accuracy: 92.94%


In [12]:
#Inspect the weight dtypes of the QAT Model
for name, module in qat_model.named_modules():
    if isinstance(module, QuantizedLinear):
        print(f"QLayer: {name},  Qweight dtype: {module.qweight.dtype}")

QLayer: backbone.conv0,  Qweight dtype: torch.int8
QLayer: backbone.conv1,  Qweight dtype: torch.int8
QLayer: backbone.conv2,  Qweight dtype: torch.int8
QLayer: backbone.conv3,  Qweight dtype: torch.int8
QLayer: backbone.conv4,  Qweight dtype: torch.int8
QLayer: backbone.conv5,  Qweight dtype: torch.int8
QLayer: backbone.conv6,  Qweight dtype: torch.int8
QLayer: backbone.conv7,  Qweight dtype: torch.int8
QLayer: classifier,  Qweight dtype: torch.int8
