In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
from tqdm import tqdm
from pathlib import Path
import os

torch.manual_seed(42)
device = "mps"

# Preparation

#### mnist datasets & loader

In [2]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# train
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
# test
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

#### model

In [3]:
class Classification(nn.Module):
    def __init__(self, hidden_size_1=128, hidden_size_2=128):
        super().__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

#### train script

In [4]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

#### test script

In [5]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')

# Train/Load Model

In [6]:
MODEL_FILENAME = 'simplenet_ptq.pt'

net = Classification().to(device)
if Path(MODEL_FILENAME).exists():
    net.load_state_dict(torch.load(MODEL_FILENAME))
    print('Loaded model from disk')
else:
    train(train_loader, net, epochs=1)
    torch.save(net.state_dict(), MODEL_FILENAME)

Loaded model from disk


In [7]:
test(net)

Testing: 100%|██████████| 1000/1000 [00:05<00:00, 178.12it/s]

Accuracy: 0.96





# (PTQ) Post Training Quantisation

In [10]:
from torchao.quantization import Int8DynamicActivationInt8WeightConfig
from torchao.quantization import quantize_
from copy import deepcopy

q_net = deepcopy(net)
quantize_(q_net, Int8DynamicActivationInt8WeightConfig())

# Inspect Models

activation=\<function _int8_symm_per_token_reduced_range_quant\>:
- INT8
- symmetric (symm)
- per-token (dynamic, runtime)
- reduced range (e.g. [-127, 127])

You do not see:
- activation_scale
- activation_zero_point
- calibration observers
Because they are computed on-the-fly.


<img src="attachments/quant_3.png" width="400">

#### model architecture

In [11]:
print(net)
print(q_net)

Classification(
  (linear1): Linear(in_features=784, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=128, bias=True)
  (linear3): Linear(in_features=128, out_features=10, bias=True)
  (relu): ReLU()
)
Classification(
  (linear1): Linear(in_features=784, out_features=128, weight=LinearActivationQuantizedTensor(activation=<function _int8_symm_per_token_reduced_range_quant at 0x167401f30>, weight=AffineQuantizedTensor(shape=torch.Size([128, 784]), block_size=(1, 784), device=mps:0, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=None, quant_max=None)))
  (linear2): Linear(in_features=128, out_features=128, weight=LinearActivationQuantizedTensor(activation=<function _int8_symm_per_token_reduced_range_quant at 0x167401f30>, weight=AffineQuantizedTensor(shape=torch.Size([128, 128]), block_size=(1, 128), device=mps:0, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=None, quant_max=None)))
  (linear3): Linear(in_features=128, out_fea

#### original weight & bias

In [12]:
print(net.linear1.weight)
print(net.linear1.bias)

Parameter containing:
tensor([[ 0.0389,  0.0412,  0.0032,  ..., -0.0026,  0.0209,  0.0251],
        [-0.0285, -0.0451,  0.0090,  ..., -0.0203, -0.0098,  0.0018],
        [ 0.0225,  0.0249,  0.0278,  ...,  0.0106,  0.0044,  0.0291],
        ...,
        [-0.0007,  0.0222,  0.0490,  ...,  0.0158, -0.0129,  0.0316],
        [ 0.0130,  0.0458,  0.0080,  ...,  0.0117,  0.0280,  0.0555],
        [-0.0136,  0.0313,  0.0211,  ...,  0.0062,  0.0145,  0.0012]],
       device='mps:0', requires_grad=True)
Parameter containing:
tensor([-0.0125,  0.0362,  0.0010, -0.0220, -0.0636, -0.0202,  0.0311, -0.0310,
        -0.0279, -0.0124,  0.0104,  0.0056,  0.0190, -0.0433, -0.0403, -0.0144,
        -0.0422, -0.0489,  0.0070,  0.0476,  0.0027, -0.0185,  0.0521, -0.0434,
        -0.0529, -0.0260, -0.0357, -0.0259, -0.0414, -0.0560,  0.0050, -0.0021,
        -0.0013, -0.0516, -0.0169,  0.0261, -0.0083,  0.0039,  0.0015, -0.0128,
         0.0150,  0.0152, -0.0263,  0.0036, -0.0391,  0.0013,  0.0054, -0.0354,

#### quantized weight

In [13]:
print(q_net.linear1.weight)
print(q_net.linear1.bias)
# bias is not quantized since it is usually int32 precision to match gpu accumulator
# see last part in https://youtu.be/0VdNflU08yA?si=zTiAwwZ3W1YT-ruG

LinearActivationQuantizedTensor(AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=tensor([[ 13,  14,   1,  ...,  -1,   7,   9],
        [-12, -19,   4,  ...,  -9,  -4,   1],
        [ 10,  11,  13,  ...,   5,   2,  13],
        ...,
        [  0,   5,  11,  ...,   4,  -3,   7],
        [  5,  16,   3,  ...,   4,  10,  20],
        [ -6,  14,   9,  ...,   3,   6,   1]], device='mps:0',
       dtype=torch.int8)... , scale=tensor([0.0029, 0.0024, 0.0022, 0.0013, 0.0024, 0.0016, 0.0021, 0.0014, 0.0020,
        0.0016, 0.0019, 0.0020, 0.0027, 0.0004, 0.0015, 0.0020, 0.0012, 0.0027,
        0.0022, 0.0019, 0.0021, 0.0020, 0.0023, 0.0006, 0.0015, 0.0022, 0.0004,
        0.0031, 0.0003, 0.0020, 0.0023, 0.0029, 0.0021, 0.0005, 0.0017, 0.0022,
        0.0026, 0.0040, 0.0014, 0.0034, 0.0019, 0.0017, 0.0020, 0.0026, 0.0005,
        0.0018, 0.0004, 0.0038, 0.0017, 0.0018, 0.0021, 0.0028, 0.0021, 0.0019,
        0.0026, 0.0017, 0.0019, 0.0033, 0.0013, 0.0016, 0.0022, 0.0021, 0.0007,
        

#### dequantised outputs

In [14]:
def activation_hook(name):
    def hook(module, inp, out):
        print(out.dtype)
        print(out)
    return hook

q_net.eval()
hook_handle = q_net.linear1.register_forward_hook(
    activation_hook("linear3")
)

iterations = 1
for idx, (x, y) in enumerate(tqdm(train_loader, desc='Inspection')):
    if idx >= iterations:
        break
    x = x.to(device)
    y = y.to(device)
    q_net(x)

hook_handle.remove()

Inspection:   0%|          | 1/6000 [00:00<10:59,  9.09it/s]

torch.float32
tensor([[ -3.1619,  -5.2125,  -2.5498,  ...,  -4.8291, -12.8415,  -3.8565],
        [ -9.3881,  -6.7253, -16.2016,  ..., -13.1316, -18.7740,  -3.7072],
        [ 10.8166,   2.4249,  10.6563,  ...,  -1.0033,  -3.9378,  -4.4136],
        ...,
        [-15.1940,  -0.6169,  -7.4738,  ...,  -7.8985,  -1.1337,  -6.4681],
        [ 13.4286,  -2.4671,  10.0608,  ...,  -1.0275,  -6.8543,  -7.2977],
        [ -4.1664,   8.4670,  -9.4396,  ..., -10.8934, -13.6064, -12.9925]],
       device='mps:0', grad_fn=<AsStridedBackward0>)





# Comparisons

In [15]:
print("Test Evaluation on Original Model")
test(net)
print("Test Evaluation on Quantised Model")
test(q_net)

Test Evaluation on Original Model


Testing: 100%|██████████| 1000/1000 [00:05<00:00, 174.06it/s]


Accuracy: 0.96
Test Evaluation on Quantised Model


Testing: 100%|██████████| 1000/1000 [00:39<00:00, 25.21it/s]

Accuracy: 0.96





In [16]:
print("Size of Original Model")
print_size_of_model(net)
print("Size of Quantized Model")
print_size_of_model(q_net)

Size of Original Model
Size (KB): 476.153
Size of Quantized Model
Size (KB): 125.366
