# Importing Packages

In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import vgg16
from torch.quantization import QuantStub, DeQuantStub
import torch.quantization
import torch.optim as optim
from torchinfo import summary
from tqdm import tqdm

# Downloading Data
- Note that we apply a transform to the original CIFAR-100 data set $(3,32,32) \rightarrow (3,224,224)$ where we have (channels, height, width)

In [2]:
transform_train = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                         std=[0.2673, 0.2564, 0.2762])
])

transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                         std=[0.2673, 0.2564, 0.2762])
])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                         download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                        download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=4)

# Making the Model
- It is important to note that we have to insert the `QuantStub` and `DeQuantStub` functions to tell the tool where to apply quantization for later
- Also, in here we pull in the original `vgg16` model and we need to replace the final cassification layer to have 100 outputs instead of 1,000 because the former was due to CIFAR-100 having 100 classes only and the latter was because VGG16 was originally trained for ImageNet.
- Then, there is a `fuse_model` function that fuses the convolution and ReLU together to have one ConvRelu operation. But particularly what this does is to have 1 quantization point (scale + zero point) for the combination of conv and relu all together.

In [3]:
class QuantizableVGG16(nn.Module):
    def __init__(self, num_classes=100):
        super(QuantizableVGG16, self).__init__()
        # LOOK! These are important to be defined
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        # Use pretrained VGG16's features
        self.features = vgg16(pretrained=True).features

        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    # Observe how the quant and dequant are called
    # You quantize the inputs then at the end
    # They are dequantized
    def forward(self, x):
        x = self.quant(x)
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        # Fuse Conv + ReLU in VGG16 features
        for idx in range(len(self.features)):
            if isinstance(self.features[idx], nn.Conv2d):
                next_idx = idx + 1
                if next_idx < len(self.features) and isinstance(self.features[next_idx], nn.ReLU):
                    torch.quantization.fuse_modules(self.features, [str(idx), str(next_idx)], inplace=True)

        # Fuse classifier layers: Linear + ReLU
        for idx in [0, 3]:  # indices of Linear layers followed by ReLU
            torch.quantization.fuse_modules(self.classifier, [str(idx), str(idx + 1)], inplace=True)


# Model Settings

In [4]:
# Create the model
model = QuantizableVGG16()
summary(model, input_size=(1, 3, 224, 224))



Layer (type:depth-idx)                   Output Shape              Param #
QuantizableVGG16                         [1, 100]                  --
├─QuantStub: 1-1                         [1, 3, 224, 224]          --
├─Sequential: 1-2                        [1, 512, 7, 7]            --
│    └─Conv2d: 2-1                       [1, 64, 224, 224]         1,792
│    └─ReLU: 2-2                         [1, 64, 224, 224]         --
│    └─Conv2d: 2-3                       [1, 64, 224, 224]         36,928
│    └─ReLU: 2-4                         [1, 64, 224, 224]         --
│    └─MaxPool2d: 2-5                    [1, 64, 112, 112]         --
│    └─Conv2d: 2-6                       [1, 128, 112, 112]        73,856
│    └─ReLU: 2-7                         [1, 128, 112, 112]        --
│    └─Conv2d: 2-8                       [1, 128, 112, 112]        147,584
│    └─ReLU: 2-9                         [1, 128, 112, 112]        --
│    └─MaxPool2d: 2-10                   [1, 128, 56, 56]          --

In [5]:
# Go into training mode
model.train()

# Fuse Conv, BN, and ReLU layers
model.fuse_model()

# Set quantization configuration
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# Prepare for QAT
torch.quantization.prepare_qat(model, inplace=True)



QuantizableVGG16(
  (quant): QuantStub(
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0'), observer_enabled=tensor([1], device='cuda:0'), scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (dequant): DeQuantStub()
  (features): Sequential(
    (0): ConvReLU2d(
      3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1], device='cuda:0'), observer_enabled=tensor([1], device='cuda:0'), scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False


# Set to use GPUs in parallel

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(device)

DataParallel(
  (module): QuantizableVGG16(
    (quant): QuantStub(
      (activation_post_process): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1], device='cuda:0'), observer_enabled=tensor([1], device='cuda:0'), scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
        (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
    (dequant): DeQuantStub()
    (features): Sequential(
      (0): ConvReLU2d(
        3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
          fake_quant_enabled=tensor([1], device='cuda:0'), observer_enabled=tensor([1], device='cuda:0'), scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qschem

# Set training parameters and train model

In [7]:
# Training parameters
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for inputs, labels in progress_bar:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        progress_bar.set_postfix(loss=running_loss / (progress_bar.n + 1))


Epoch 1/50: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [02:50<00:00,  2.29it/s, loss=2.46]
Epoch 2/50: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [02:49<00:00,  2.31it/s, loss=1.36]
Epoch 3/50: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [02:51<00:00,  2.28it/s, loss=1.05]
Epoch 4/50: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

# Evaluate Model Before Quantization

In [8]:
def orig_evaluate(model, dataloader):
    model.eval()
    model.to('cuda')

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    acc = 100 * correct / total
    print(f"Accuracy of the quantized model: {acc:.2f}%")
    return acc


orig_evaluate(model, testloader)

Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:17<00:00,  4.57it/s]

Accuracy of the quantized model: 72.92%





72.92

# Convert Model to be Quantized
- This step makes the model with int8 values

In [9]:
model.eval()
model.cpu()
quantized_model = torch.quantization.convert(model, inplace=False)

# Save quantized model

In [10]:
torch.save(quantized_model.state_dict(), 'quantized_vgg16_cifar100.pth')

# Setting Quantized Model to CPU
- This is mandatory as CPU can do the int8 operations but GPUs can only do FP32

In [11]:
# Ensure model is quantized and on CPU
quantized_model = quantized_model.module  # unwrap from DataParallel
quantized_model.eval()
quantized_model.to('cpu')  # <--- this fixes the RuntimeError

QuantizableVGG16(
  (quant): Quantize(scale=tensor([0.0309]), zero_point=tensor([61]), dtype=torch.quint8)
  (dequant): DeQuantize()
  (features): Sequential(
    (0): QuantizedConvReLU2d(3, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.10570178180932999, zero_point=0, padding=(1, 1))
    (1): Identity()
    (2): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.16206395626068115, zero_point=0, padding=(1, 1))
    (3): Identity()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.3249315023422241, zero_point=0, padding=(1, 1))
    (6): Identity()
    (7): QuantizedConvReLU2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.5744616985321045, zero_point=0, padding=(1, 1))
    (8): Identity()
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): QuantizedConvReLU2d(128, 256, kernel_size=(3, 3), stride=(1, 1), sc

# Evaluate the Quantized Model

In [12]:
def evaluate(model, dataloader):
    model.eval()
    model.to('cpu')  # ensure model is on CPU

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.cpu()
            labels = labels.cpu()

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    acc = 100 * correct / total
    print(f"Accuracy of the quantized model: {acc:.2f}%")
    return acc


evaluate(quantized_model, testloader)

Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [07:23<00:00,  5.61s/it]

Accuracy of the quantized model: 72.97%





72.97

# Saving ONNX Model

In [13]:
dummy_input = torch.randn(1, 3, 224, 224)  # ONNX needs a batch dimension

In [14]:
torch.onnx.export(
    quantized_model,
    dummy_input,
    "quantized_vgg16.onnx",
    export_params=True,
    opset_version=13,  # Use 13+ for better quant support
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)