# Import the necessary libraries

In [16]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

*  torch and torch.nn: PyTorch core and neural network layers.
*  torchvision.datasets/transforms: Convenient datasets and image preprocessing
*  matplotlib: Not used later, but typically for plotting (safe to remove in this notebook).
*  tqdm: Pretty progress bars for training/testing loops. pathlib / os: File handling utilities.





# Load the MNIST dataset

In [17]:
# Make torch deterministic and sets a random seed
_ = torch.manual_seed(0)

In [18]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = "cpu"

ToTensor() converts PIL images to tensors (values in [0,1]).
Normalize(mean, std) standardizes images using MNIST’s mean/std (~0.13/0.31). This helps training converge.

Loads training set (downloads if missing) and wraps it in a DataLoader:
batch_size=10: the number of images per step.
shuffle=True: mix the order each epoch for better training.

QAT → quantized inference is primarily supported on CPU in PyTorch. So we keep it on CPU.





# Define the model

In [19]:
class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(VerySimpleNet,self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

net = VerySimpleNet().to(device)

What’s happening here?

This is a simple MLP for MNIST:

Input: 28×28 = 784 pixels → flatten to a vector
Two hidden layers (Linear + ReLU)
Output: 10 logits (digits 0…9)



Why QuantStub and DeQuantStub?
Simulating the quantization for weight and biases.

QuantStub() and DeQuantStub() mark where quantization starts and ends.
During training with QAT:

QuantStub inserts observers and fake quantization at the input to simulate the
full precision weights.

Layers in between will be trained with simulated INT8 behavior.
DeQuantStub converts the output back to FP32 (so losses work normally).

# Insert min-max observers in the model

In [20]:
net.qconfig = torch.ao.quantization.default_qconfig
net.train()
net_quantized = torch.ao.quantization.prepare_qat(net) # Insert observers
net_quantized

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  net_quantized = torch.ao.quantization.prepare_qat(net) # Insert observers


VerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

qconfig: Defines how to quantify activations and weights:

What kind of observer to use (e.g., MinMax)
Fake‑quant modules to simulate INT8 during training


prepare_qat(model):

Inserts observers and fake‑quant modules into the model (at places like stubs and supported layers).

# Train the model

In [21]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

train(train_loader, net_quantized, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:49<00:00, 120.72it/s, loss=0.224]


Uses CrossEntropyLoss for classification and Adam optimizer.
Inside the loop:

Flatten input (view(-1, 28*28))
Forward pass through the QAT‑prepared model → fake quant applies
Compute loss, backprop, and update weights
Uses tqdm to display average loss nicely



The crucial part: this training happens with fake quantization in place, so the model learns quantization‑friendly weights.

# Define the testing loop

In [23]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')

# Check the collected statistics during training

In [24]:
print(f'Check statistics of the various layers')
net_quantized

Check statistics of the various layers


VerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.4810347855091095, max_val=0.3415062427520752)
    (activation_post_process): MinMaxObserver(min_val=-40.239410400390625, max_val=39.7955436706543)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.43307897448539734, max_val=0.3635982275009155)
    (activation_post_process): MinMaxObserver(min_val=-41.65215301513672, max_val=22.231578826904297)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.5131738185882568, max_val=0.20312127470970154)
    (activation_post_process): MinMaxObserver(min_val=-32.679283142089844, max_val=22.15723419189453)
  )
  (relu): ReLU()
  (dequant): DeQuantS

# Quantize the model using the statistics collected

In [10]:
net_quantized.eval()
net_quantized = torch.ao.quantization.convert(net_quantized)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  net_quantized = torch.ao.quantization.convert(net_quantized)


In [11]:
print(f'Check statistics of the various layers')
net_quantized

Check statistics of the various layers


VerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.6301965117454529, zero_point=64, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.5030215382575989, zero_point=83, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.4317835867404938, zero_point=76, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

# Print weights and size of the model after quantization

In [12]:
# Print the weights matrix of the model before quantization
print('Weights before quantization')
print(torch.int_repr(net_quantized.linear1.weight()))

Weights before quantization
tensor([[  3,   8,  -5,  ...,   9,   4,   4],
        [ -5,  -4,  -3,  ...,  -5,  -2,  -8],
        [  0,   9,  -4,  ...,   0,   5,   7],
        ...,
        [  4,   5,  -4,  ...,  -5,   0, -10],
        [ -5,  -3,   6,  ...,   1,   1,   1],
        [  3,   2,  -2,  ...,   8,  -5,   0]], dtype=torch.int8)


In [13]:
print('Testing the model after quantization')
test(net_quantized)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:03<00:00, 293.27it/s]

Accuracy: 0.957





In [26]:

# Before conversion
print("Accuracy before convert (QAT prepared):")
print(net_quantized)
test(net_quantized)

# Compare sizes
print("FP32 model size:")
print_size_of_model(net)

# Convert to INT8
net_quantized.eval()
net_int8 = torch.ao.quantization.convert(net_quantized)

print("INT8 model size:")
print_size_of_model(net_int8)

# After conversion
print("Accuracy after convert (INT8):")
test(net_int8)


Accuracy before convert (QAT prepared):
VerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.4810347855091095, max_val=0.3415062427520752)
    (activation_post_process): MinMaxObserver(min_val=-41.82633590698242, max_val=39.7955436706543)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.43307897448539734, max_val=0.3635982275009155)
    (activation_post_process): MinMaxObserver(min_val=-41.65215301513672, max_val=22.231578826904297)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.5131738185882568, max_val=0.20312127470970154)
    (activation_post_process): MinMaxObserver(min_val=-32.679283142089844, max_val=22.15723419189453)
  )

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 303.69it/s]
For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  net_int8 = torch.ao.quantization.convert(net_quantized)


Accuracy: 0.958
FP32 model size:
Size (KB): 361.465
INT8 model size:
Size (KB): 95.797
Accuracy after convert (INT8):


Testing: 100%|██████████| 1000/1000 [00:03<00:00, 276.97it/s]

Accuracy: 0.957



