# Quantization Aware Training
In this notebook, we will implementing one of the modes of quantization called Quantization Aware Training using PyTorch.

## Import

In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
import os

from tqdm import tqdm
from pathlib import Path

## Load the FashionMNIST Dataset

In [2]:
# Make torch deterministic
_ = torch.manual_seed(0)

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the FashionMNIST dataset
fmnist_trainset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(fmnist_trainset, batch_size=16, shuffle=True)

# Load the FashionMNIST test set
fmnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(fmnist_trainset, batch_size=16, shuffle=True)

# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Define the Model

In [4]:
class FashionNet(nn.Module):
    def __init__(self, hidden_size_1=128, hidden_size_2=256):
        super(FashionNet,self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

model = FashionNet().to(device)
print(model)

FashionNet(
  (quant): QuantStub()
  (linear1): Linear(in_features=784, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=256, bias=True)
  (linear3): Linear(in_features=256, out_features=10, bias=True)
  (relu): ReLU()
  (dequant): DeQuantStub()
)


## Insert min-max observers in the model

In [5]:
model.qconfig = torch.ao.quantization.default_qconfig
model.train()
model_quantized = torch.ao.quantization.prepare_qat(model) # Insert observers
model_quantized

FashionNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=128, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=128, out_features=256, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=256, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

## Train the Model

In [6]:
def train(train_loader, model, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        model.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit

        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

In [7]:
train(train_loader, model_quantized, epochs=10)

Epoch 1: 100%|██████████| 3750/3750 [00:54<00:00, 68.92it/s, loss=0.455]
Epoch 2: 100%|██████████| 3750/3750 [00:49<00:00, 75.44it/s, loss=0.354]
Epoch 3: 100%|██████████| 3750/3750 [00:50<00:00, 74.56it/s, loss=0.322]
Epoch 4: 100%|██████████| 3750/3750 [00:49<00:00, 75.42it/s, loss=0.296]
Epoch 5: 100%|██████████| 3750/3750 [00:50<00:00, 74.96it/s, loss=0.281]
Epoch 6: 100%|██████████| 3750/3750 [00:50<00:00, 73.80it/s, loss=0.268]
Epoch 7: 100%|██████████| 3750/3750 [00:50<00:00, 73.92it/s, loss=0.259]
Epoch 8: 100%|██████████| 3750/3750 [00:49<00:00, 75.41it/s, loss=0.247]
Epoch 9: 100%|██████████| 3750/3750 [00:48<00:00, 76.65it/s, loss=0.238]
Epoch 10: 100%|██████████| 3750/3750 [00:49<00:00, 75.03it/s, loss=0.231]


## Testing

In [8]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')

## Check the collected statistics during training

In [9]:
print(f'Check statistics of the various layers')
model_quantized

Check statistics of the various layers


FashionNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=128, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-2.0463104248046875, max_val=1.1904397010803223)
    (activation_post_process): MinMaxObserver(min_val=-130.43992614746094, max_val=71.39286804199219)
  )
  (linear2): Linear(
    in_features=128, out_features=256, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.7236308455467224, max_val=0.6895429491996765)
    (activation_post_process): MinMaxObserver(min_val=-95.31771850585938, max_val=73.64185333251953)
  )
  (linear3): Linear(
    in_features=256, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-1.617387056350708, max_val=0.4518842399120331)
    (activation_post_process): MinMaxObserver(min_val=-418.06884765625, max_val=68.2452392578125)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

##Quantize the model using the statistics collected

In [10]:
model_quantized.eval()
model_quantized = torch.ao.quantization.convert(model_quantized)

In [11]:
print(f'Check statistics of the various layers')
model_quantized

Check statistics of the various layers


FashionNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=128, scale=1.5892345905303955, zero_point=82, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=128, out_features=256, scale=1.330390214920044, zero_point=72, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=256, out_features=10, scale=3.82924485206604, zero_point=109, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

##Print weights and size of the model after quantization

In [12]:
# Print the weights matrix of the model after quantization
print('Weights after quantization')
print(torch.int_repr(model_quantized.linear1.weight()))

Weights after quantization
tensor([[  6,   7,   4,  ...,   7,   5,   6],
        [  3,   4,   4,  ...,   9, -11,   2],
        [  5,   7,   4,  ...,   5,   7,   7],
        ...,
        [ -1,  -4,   1,  ...,  -7,  -3,   0],
        [  5,   5,   6,  ...,  -5,  -7,   3],
        [ 10,   8,   7,  ...,  11,   9,   7]], dtype=torch.int8)


In [13]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

In [14]:
print('Size of the model after quantization')
print_size_of_model(model_quantized)

Size of the model after quantization
Size (KB): 142.498


In [15]:
print(f'Accuracy of the model after quantization: ')
test(model_quantized)

Accuracy of the model after quantization: 


Testing: 100%|██████████| 3750/3750 [00:20<00:00, 181.68it/s]

Accuracy: 0.862



