# Post Training Quantization
In this notebook, we will implementing one of the modes of quantization called Post Training Quantization using PyTorch.

## Import

In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
import os

from tqdm import tqdm
from pathlib import Path

## Load the FashionMNIST Dataset

In [2]:
# Make torch deterministic
_ = torch.manual_seed(0)

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the FashionMNIST dataset
fmnist_trainset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(fmnist_trainset, batch_size=16, shuffle=True)

# Load the FashionMNIST test set
fmnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(fmnist_trainset, batch_size=16, shuffle=True)

# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:03<00:00, 7688938.22it/s] 


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 139658.40it/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 2511678.10it/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 18743296.00it/s]


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 89183050.56it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 75192857.74it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 24696436.87it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4263770.99it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






## Define the Model

In [4]:
class FashionNet(nn.Module):
    def __init__(self, hidden_size_1=128, hidden_size_2=256):
        super(FashionNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

In [5]:
model = FashionNet().to(device)
print(model)

FashionNet(
  (linear1): Linear(in_features=784, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=256, bias=True)
  (linear3): Linear(in_features=256, out_features=10, bias=True)
  (relu): ReLU()
)


## Train the Model

In [6]:
def train(train_loader, model, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        model.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit

        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

In [7]:
MODEL_FILENAME = 'fashion_net.pt'

if Path(MODEL_FILENAME).exists():
    model.load_state_dict(torch.load(MODEL_FILENAME))
    print('Loaded model from disk')
else:
    train(train_loader, model, epochs=10)
    # Save the model to disk
    torch.save(model.state_dict(), MODEL_FILENAME)

Epoch 1: 100%|██████████| 3750/3750 [00:45<00:00, 82.17it/s, loss=0.458]
Epoch 2: 100%|██████████| 3750/3750 [00:47<00:00, 79.21it/s, loss=0.355]
Epoch 3: 100%|██████████| 3750/3750 [00:45<00:00, 82.26it/s, loss=0.32]
Epoch 4: 100%|██████████| 3750/3750 [00:46<00:00, 80.21it/s, loss=0.301]
Epoch 5: 100%|██████████| 3750/3750 [00:45<00:00, 81.80it/s, loss=0.284]
Epoch 6: 100%|██████████| 3750/3750 [00:46<00:00, 80.45it/s, loss=0.27]
Epoch 7: 100%|██████████| 3750/3750 [00:46<00:00, 80.02it/s, loss=0.258]
Epoch 8: 100%|██████████| 3750/3750 [00:47<00:00, 79.78it/s, loss=0.247]
Epoch 9: 100%|██████████| 3750/3750 [00:45<00:00, 81.54it/s, loss=0.241]
Epoch 10: 100%|██████████| 3750/3750 [00:46<00:00, 79.82it/s, loss=0.232]


## Testing

In [8]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')

In [9]:
test(model)

Testing: 100%|██████████| 3750/3750 [00:18<00:00, 200.10it/s]

Accuracy: 0.922





##Print weights and size of the model before quantization

In [10]:
# Print the weights matrix of the model before quantization
print('Weights before quantization')
print(model.linear1.weight)
print(model.linear1.weight.dtype)

Weights before quantization
Parameter containing:
tensor([[ 0.0755,  0.0828, -0.0091,  ...,  0.0536,  0.0798,  0.0720],
        [-0.0101, -0.0183, -0.0618,  ..., -0.1195, -0.0224, -0.0202],
        [ 0.0850,  0.1174,  0.0451,  ...,  0.0233,  0.0333,  0.0921],
        ...,
        [ 0.0236, -0.0083,  0.0074,  ..., -0.0738, -0.0481,  0.0258],
        [ 0.1485,  0.1255,  0.0529,  ...,  0.1001,  0.1012,  0.1476],
        [ 0.0257,  0.0074,  0.0256,  ...,  0.1097, -0.0216,  0.0140]],
       requires_grad=True)
torch.float32


In [11]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

In [12]:
print('Size of the model before quantization')
print_size_of_model(model)

Size of the model before quantization
Size (KB): 546.918


In [13]:
print(f'Accuracy of the model before quantization: ')
test(model)

Accuracy of the model before quantization: 


Testing: 100%|██████████| 3750/3750 [00:19<00:00, 191.87it/s]

Accuracy: 0.922





## Insert min-max observers in the model

In [14]:
class QuantizedFashionNet(nn.Module):
    def __init__(self, hidden_size_1=128, hidden_size_2=256):
        super(QuantizedFashionNet,self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

In [15]:
model_quantized = QuantizedFashionNet().to(device)

In [16]:
# Copy weights from unquantized model
model_quantized.load_state_dict(model.state_dict())
model_quantized.eval()

model_quantized.qconfig = torch.ao.quantization.default_qconfig
model_quantized = torch.ao.quantization.prepare(model_quantized) # Insert observers
model_quantized

QuantizedFashionNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=128, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=128, out_features=256, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=256, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

## Calibrate the model using the test set

In [17]:
test(model_quantized)

Testing: 100%|██████████| 3750/3750 [00:20<00:00, 187.45it/s]

Accuracy: 0.922





In [18]:
print(f'Check statistics of the various layers')
model_quantized

Check statistics of the various layers


QuantizedFashionNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=128, bias=True
    (activation_post_process): MinMaxObserver(min_val=-119.12142944335938, max_val=73.69100189208984)
  )
  (linear2): Linear(
    in_features=128, out_features=256, bias=True
    (activation_post_process): MinMaxObserver(min_val=-106.03083038330078, max_val=78.91856384277344)
  )
  (linear3): Linear(
    in_features=256, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-263.02685546875, max_val=49.16145324707031)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

## Quantize the model using the statistics collected

In [19]:
model_quantized = torch.ao.quantization.convert(model_quantized)

In [20]:
print(f'Check statistics of the various layers')
model_quantized

Check statistics of the various layers


QuantizedFashionNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=128, scale=1.5182081460952759, zero_point=78, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=128, out_features=256, scale=1.4562945365905762, zero_point=73, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=256, out_features=10, scale=2.4581754207611084, zero_point=107, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

In [21]:
# Print the weights matrix of the model after quantization
print('Weights after quantization')
print(torch.int_repr(model_quantized.linear1.weight()))

Weights after quantization
tensor([[ 5,  6, -1,  ...,  4,  5,  5],
        [-1, -1, -4,  ..., -8, -2, -1],
        [ 6,  8,  3,  ...,  2,  2,  6],
        ...,
        [ 2, -1,  0,  ..., -5, -3,  2],
        [10,  8,  4,  ...,  7,  7, 10],
        [ 2,  0,  2,  ...,  7, -1,  1]], dtype=torch.int8)


## Compare the dequantized weights and the original weights

In [22]:
print('Original weights: ')
print(model.linear1.weight)
print('')
print(f'Dequantized weights: ')
print(torch.dequantize(model_quantized.linear1.weight()))
print('')

Original weights: 
Parameter containing:
tensor([[ 0.0755,  0.0828, -0.0091,  ...,  0.0536,  0.0798,  0.0720],
        [-0.0101, -0.0183, -0.0618,  ..., -0.1195, -0.0224, -0.0202],
        [ 0.0850,  0.1174,  0.0451,  ...,  0.0233,  0.0333,  0.0921],
        ...,
        [ 0.0236, -0.0083,  0.0074,  ..., -0.0738, -0.0481,  0.0258],
        [ 0.1485,  0.1255,  0.0529,  ...,  0.1001,  0.1012,  0.1476],
        [ 0.0257,  0.0074,  0.0256,  ...,  0.1097, -0.0216,  0.0140]],
       requires_grad=True)

Dequantized weights: 
tensor([[ 0.0744,  0.0893, -0.0149,  ...,  0.0595,  0.0744,  0.0744],
        [-0.0149, -0.0149, -0.0595,  ..., -0.1191, -0.0298, -0.0149],
        [ 0.0893,  0.1191,  0.0447,  ...,  0.0298,  0.0298,  0.0893],
        ...,
        [ 0.0298, -0.0149,  0.0000,  ..., -0.0744, -0.0447,  0.0298],
        [ 0.1489,  0.1191,  0.0595,  ...,  0.1042,  0.1042,  0.1489],
        [ 0.0298,  0.0000,  0.0298,  ...,  0.1042, -0.0149,  0.0149]])



## Print size and accuracy of the quantized model

In [23]:
print('Size of the model after quantization')
print_size_of_model(model_quantized)

Size of the model after quantization
Size (KB): 142.498


In [24]:
print('Testing the model after quantization')
test(model_quantized)

Testing the model after quantization


Testing: 100%|██████████| 3750/3750 [00:19<00:00, 191.95it/s]

Accuracy: 0.901



