# Import the necessary libraries

In [3]:
# torch is the main PyTorch library. It provides tensor computation and gradients,
# along with functionality for deep neural networks
import torch

# torchvision.datasets is part of torchvision, PyTorch's companion library for computer vision.
# It contains many popular datasets, such as MNIST, and handles downloading and loading them.
import torchvision.datasets as datasets

# torchvision.transforms provides common image transformations.
# These can be chained together using torchvision.transforms.Compose.
import torchvision.transforms as transforms

# torch.nn provides the building blocks for neural networks: predefined layers, activation functions
# and loss functions, which can be combined into more complex architectures.
import torch.nn as nn

# matplotlib.pyplot is a collection of functions that helps in creating a variety of charts.
# It is generally used for data visualization.
import matplotlib.pyplot as plt

# tqdm is a Python library that allows you to output a smart progress bar by wrapping around any iterable.
from tqdm import tqdm

# pathlib.Path offers classes representing filesystem paths with semantics appropriate for different operating systems.
from pathlib import Path

# os is a Python standard-library module for interacting with the operating system.
import os

# Load the MNIST dataset

In [4]:
# Make the run reproducible
# Fixing the seed makes the random numbers generated by PyTorch (weight initialization, data shuffling) repeatable
_ = torch.manual_seed(0)

In [5]:
# transforms.Compose chains a series of transformations that prepare each image.
# transforms.ToTensor() converts a PIL Image or numpy.ndarray to a float tensor with values scaled to [0, 1].
# transforms.Normalize((0.1307,), (0.3081,)) subtracts the mean 0.1307 and divides by the standard deviation 0.3081.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset for training.
# root='./data' is the path where the train/test data is stored.
# train=True signifies that we are loading training data.
# download=True will download the data if it's not already present at the specified location.
# transform=transform applies the defined transformations on the data.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# DataLoader wraps the dataset and serves it in mini-batches.
# batch_size=10 means each batch contains 10 samples.
# shuffle=True reshuffles the data at every epoch so the model sees the samples in a random order.
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set the same way, with train=False to select the test split instead of the training split.
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device used for training and testing. PyTorch's eager-mode quantized kernels run on the CPU, so we use "cpu".
device = "cpu"

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 179042044.53it/s]
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 18281873.50it/s]
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 43512044.62it/s]
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 12172861.83it/s]
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
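
ToTensor divides each 8-bit pixel by 255, and Normalize then computes (x - 0.1307) / 0.3081. The plain-Python sketch below (no PyTorch involved; `normalize_pixel` is a hypothetical helper) applies the same arithmetic to the darkest and brightest possible pixels. The resulting interval is the range of values the network's input can take, and it agrees with the min_val and max_val that the QuantStub observer reports after training.

```python
# Plain-Python sketch of ToTensor + Normalize applied to a single MNIST pixel.
MEAN, STD = 0.1307, 0.3081  # the values passed to transforms.Normalize


def normalize_pixel(pixel: int) -> float:
    """Map an 8-bit pixel (0..255) the way ToTensor followed by Normalize would."""
    x = pixel / 255.0        # ToTensor: scale to [0, 1]
    return (x - MEAN) / STD  # Normalize: subtract the mean, divide by the std


lowest = normalize_pixel(0)     # black background pixel
highest = normalize_pixel(255)  # fully white stroke pixel
print(f"input range after normalization: [{lowest:.4f}, {highest:.4f}]")
```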



# Define the model

In [6]:
# Define a very simple neural network class that inherits from nn.Module
class VerySimpleNet(nn.Module):
    # Initialize the network layers
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        # Call the parent class's initializer
        super().__init__()
        # QuantStub marks where float inputs enter the quantized part of the model;
        # after convert() it becomes a Quantize module that maps float tensors to quint8 tensors
        self.quant = torch.ao.quantization.QuantStub()
        # Define the first linear layer with 28*28 input nodes (size of an MNIST image when flattened) and hidden_size_1 output nodes
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        # Define the second linear layer with hidden_size_1 input nodes and hidden_size_2 output nodes
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        # Define the third linear layer with hidden_size_2 input nodes and 10 output nodes (for the 10 classes of digits)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        # Define the activation function to be used between layers
        self.relu = nn.ReLU()
        # DeQuantStub marks where the quantized part of the model ends;
        # after convert() it becomes a DeQuantize module that maps quantized tensors back to float
        self.dequant = torch.ao.quantization.DeQuantStub()

    # Define the forward pass
    def forward(self, img):
        # Flatten the image tensor
        x = img.view(-1, 28*28)
        # Enter the quantized region (only observed during training, a real Quantize op after convert())
        x = self.quant(x)
        # Pass the input through the first linear layer and then through the ReLU activation function
        x = self.relu(self.linear1(x))
        # Pass the output from the previous layer through the second linear layer and then through the ReLU activation function
        x = self.relu(self.linear2(x))
        # Pass the output from the previous layer through the third linear layer
        x = self.linear3(x)
        # Leave the quantized region so the caller receives a float tensor
        x = self.dequant(x)
        # Return the output
        return x

# Instantiate the network and move it to the device (CPU in this case)
net = VerySimpleNet().to(device)
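
QuantStub and DeQuantStub mark the boundaries of the region that will run on integers. After conversion, the input stub performs per-tensor affine quantization to unsigned 8-bit codes and the output stub maps codes back to floats. The plain-Python sketch below shows that round trip for one value over the full unsigned range 0..255; the helper names are hypothetical and the scale and zero_point are illustrative, chosen to resemble those printed for the quant module after conversion.

```python
# Plain-Python sketch of per-tensor affine quantization to an unsigned 8-bit range.


def quantize(x: float, scale: float, zero_point: int, qmin: int = 0, qmax: int = 255) -> int:
    """Map a float to an integer code: q = clamp(round(x / scale) + zero_point, qmin, qmax)."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    """Map an integer code back to an approximate float: x is about scale * (q - zero_point)."""
    return scale * (q - zero_point)


scale, zero_point = 0.0256, 17  # illustrative parameters
x = 1.0
q = quantize(x, scale, zero_point)
x_hat = dequantize(q, scale, zero_point)
print(q, x_hat)
```

Inside the representable range the round trip loses at most scale / 2; values outside it are clamped to qmin or qmax, which is why the observed ranges matter.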


# Insert min-max observers in the model

In [7]:
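# Attach the default quantization configuration to the network.
# Note: default_qconfig uses plain MinMaxObservers for weights and activations, so training only records
# value ranges and the forward pass stays in floating point. For fake-quantized QAT, which simulates int8
# rounding during training, the qconfig from torch.ao.quantization.get_default_qat_qconfig("fbgemm") is used instead.
net.qconfig = torch.ao.quantization.default_qconfig

# Put the network in training mode. prepare_qat requires this; it also matters for modules such as Dropout and BatchNorm.
net.train()

# Prepare the network for Quantization-Aware Training (QAT). prepare_qat returns a copy of the model in which
# the observers specified by the qconfig are attached to the weights and activations. The statistics they record
# during training are later used by convert() to compute the quantization parameters (scale and zero_point).
net_quantized = torch.ao.quantization.prepare_qat(net)

# Display the prepared (not yet quantized) network
net_quantized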


VerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)
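
In this printout every weight_fake_quant slot holds a MinMaxObserver, because default_qconfig only records ranges. With a QAT qconfig such as the one from `torch.ao.quantization.get_default_qat_qconfig("fbgemm")`, those slots hold FakeQuantize modules instead: in the forward pass they quantize and immediately dequantize values so training sees the rounding error, and in the backward pass gradients flow through unchanged inside the representable range (the straight-through estimator). A minimal plain-Python sketch of that idea, with hypothetical helper names and illustrative parameters:

```python
# Plain-Python sketch of fake quantization with a straight-through estimator (STE).


def fake_quantize(x: float, scale: float, zero_point: int, qmin: int = 0, qmax: int = 255) -> float:
    """Forward pass: quantize then dequantize, so the output stays a float but carries rounding error."""
    q = max(qmin, min(qmax, round(x / scale) + zero_point))
    return scale * (q - zero_point)


def fake_quantize_grad(x: float, scale: float, zero_point: int, qmin: int = 0, qmax: int = 255) -> float:
    """Backward pass (STE): gradient 1 inside the representable range, 0 where the value was clamped."""
    low = (qmin - zero_point) * scale
    high = (qmax - zero_point) * scale
    return 1.0 if low <= x <= high else 0.0


scale, zero_point = 0.5, 0  # illustrative parameters
for x in (0.74, 0.76, 200.0, -1.0):
    print(x, fake_quantize(x, scale, zero_point), fake_quantize_grad(x, scale, zero_point))
```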

# Train the model

In [8]:
# Define a function to train the network
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    # CrossEntropyLoss combines LogSoftmax and NLLLoss in one single class, useful for multi-class classification problems
    cross_el = nn.CrossEntropyLoss()
    # Adam is an optimization algorithm that can be used instead of the classical stochastic gradient descent to update network weights iteratively
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    # Loop over the dataset multiple times
    for epoch in range(epochs):
        # Set the network to training mode
        net.train()

        loss_sum = 0
        num_iterations = 0

        # tqdm creates a progress bar
        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        # Loop over the data
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            # Get the inputs and labels
            x, y = data
            # Move the inputs and labels to the device
            x = x.to(device)
            y = y.to(device)
            # Zero the parameter gradients
            optimizer.zero_grad()
            # Forward pass
            output = net(x.view(-1, 28*28))
            # Compute the loss
            loss = cross_el(output, y)
            # Accumulate the loss
            loss_sum += loss.item()
            # Compute the average loss
            avg_loss = loss_sum / num_iterations
            # Update the progress bar
            data_iterator.set_postfix(loss=avg_loss)
            # Backward pass
            loss.backward()
            # Optimize
            optimizer.step()

            # Stop training if the total number of iterations reaches the limit
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

# Define a function to print the size of the model
def print_size_of_model(model):
    # Save the model to a file
    torch.save(model.state_dict(), "temp_delme.p")
    # Print the size of the file
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    # Remove the file
    os.remove('temp_delme.p')

# Train the network for one epoch
train(train_loader, net_quantized, epochs=1)


Epoch 1: 100%|██████████| 6000/6000 [01:07<00:00, 88.41it/s, loss=0.224]
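
As the comment in `train` notes, nn.CrossEntropyLoss applies LogSoftmax followed by NLLLoss. For one sample with logits z and true class y this is -log(softmax(z)[y]), and by default the result is averaged over the batch. A plain-Python sketch of that computation (hypothetical helpers, using the log-sum-exp trick for numerical stability):

```python
import math


def cross_entropy(logits: list[float], target: int) -> float:
    """-log softmax(logits)[target], computed as logsumexp(logits) - logits[target]."""
    m = max(logits)  # subtract the max before exponentiating to avoid overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def mean_cross_entropy(batch_logits: list[list[float]], targets: list[int]) -> float:
    """Average the per-sample losses, like the default reduction='mean'."""
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


print(cross_entropy([0.0, 0.0], 0))        # equal logits over 2 classes: log(2)
print(cross_entropy([10.0, 0.0, 0.0], 0))  # confident and correct: small loss
print(cross_entropy([10.0, 0.0, 0.0], 1))  # confident and wrong: large loss
```

With 10 equal logits the loss is log(10), about 2.30, which is roughly where an untrained 10-class classifier starts.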


# Define the testing loop

In [9]:
# Define a function to test the model
def test(model: nn.Module, total_iterations: int = None):
    # Initialize the number of correct predictions and the total number of predictions
    correct = 0
    total = 0

    # Initialize the number of iterations
    iterations = 0

    # Switch the model to evaluation mode (this affects modules such as Dropout and BatchNorm)
    model.eval()

    # torch.no_grad() disables gradient tracking, which reduces memory use and speeds up inference
    with torch.no_grad():
        # Loop over the test data
        for data in tqdm(test_loader, desc='Testing'):
            # Get the inputs and labels
            x, y = data
            # Move the inputs and labels to the device
            x = x.to(device)
            y = y.to(device)
            # Compute the model output: one row of 10 logits per image
            output = model(x.view(-1, 784))
            # The predicted class is the index of the largest logit; count how many match the labels
            correct += (output.argmax(dim=1) == y).sum().item()
            # Count the number of predictions in this batch
            total += y.size(0)
            # Increment the number of iterations
            iterations += 1
            # If the total number of iterations reaches the limit, break the loop
            if total_iterations is not None and iterations >= total_iterations:
                break
    # Print the accuracy of the model
    print(f'Accuracy: {round(correct/total, 3)}')
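
The test loop counts a prediction as correct when the index of the largest logit equals the label. The same bookkeeping in plain Python, on hypothetical batches of (logits, labels) with 3 classes, looks like this:

```python
# Plain-Python sketch of the accuracy bookkeeping in the test loop.


def argmax(values: list[float]) -> int:
    """Index of the largest value; the first one wins on ties."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(batches: list[tuple[list[list[float]], list[int]]]) -> float:
    correct = 0
    total = 0
    for logits_batch, labels in batches:
        for logits, label in zip(logits_batch, labels):
            correct += int(argmax(logits) == label)
            total += 1
    return correct / total


# Hypothetical batches: each holds 2 samples with 3 logits and their labels
batches = [
    ([[2.0, 0.1, -1.0], [0.0, 3.0, 1.0]], [0, 1]),
    ([[0.5, 0.2, 0.9], [1.0, 1.0, 0.0]], [0, 0]),
]
print(f"Accuracy: {round(accuracy(batches), 3)}")
```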


# Check the collected statistics during training

In [10]:
# Print a header before inspecting the observers of the various layers
print('Check statistics of the various layers')

# net_quantized is the model prepared for Quantization-Aware Training (QAT).
# Printing it shows, for every observer, the min_val and max_val recorded during training.
# convert() will turn these ranges into the scale and zero_point used to map floating point values to int8/uint8.
net_quantized

Check statistics of the various layers


VerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.4810347855091095, max_val=0.3415062427520752)
    (activation_post_process): MinMaxObserver(min_val=-40.239410400390625, max_val=39.7955436706543)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.43307897448539734, max_val=0.3635982275009155)
    (activation_post_process): MinMaxObserver(min_val=-41.65215301513672, max_val=22.231578826904297)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.5131738185882568, max_val=0.20312127470970154)
    (activation_post_process): MinMaxObserver(min_val=-32.679283142089844, max_val=22.15723419189453)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

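Before training, every observer in the prepared model printed min_val=inf and max_val=-inf, an empty range. Each batch seen during training widened that range, producing the statistics above. A minimal plain-Python sketch of that running min/max update (the class name is hypothetical; MinMaxObserver performs the equivalent update on tensors):

```python
import math


class RunningMinMax:
    """Track the smallest and largest values seen so far, starting from an empty range."""

    def __init__(self) -> None:
        self.min_val = math.inf
        self.max_val = -math.inf

    def observe(self, batch: list[float]) -> None:
        # Widen the range to include this batch; earlier extremes are never forgotten
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))


obs = RunningMinMax()
print(obs.min_val, obs.max_val)  # empty range, like the freshly prepared model
obs.observe([0.2, -1.5, 3.0])
obs.observe([0.0, 4.5])
print(obs.min_val, obs.max_val)
```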
# Quantize the model using the statistics collected

In [11]:
# Switch the model to evaluation mode before conversion
net_quantized.eval()

# Convert the prepared model into a quantized model. convert() uses the statistics recorded by the observers
# to compute scale and zero_point, replaces QuantStub and DeQuantStub with Quantize and DeQuantize modules,
# and swaps each Linear layer for a quantized Linear that stores int8 weights.
# The resulting model runs int8 inference on the CPU.
net_quantized = torch.ao.quantization.convert(net_quantized)

In [13]:
# Print a header before inspecting the quantization parameters of the various layers
print('Check statistics of the various layers')

# net_quantized is now the converted, quantized model.
# Printing it shows the quantization parameters chosen for each layer's output activations: the scale and zero_point
# that map floating point values to 8-bit integers (real value = scale * (integer - zero_point)).
# The printout shows how quantization is applied to each layer.
net_quantized

Check statistics of the various layers


VerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.6301965117454529, zero_point=64, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.5030215382575989, zero_point=83, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.4317835867404938, zero_point=76, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)
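
convert() derives each scale and zero_point from the range its observer recorded. The sketch below assumes the affine min/max formulas (the range is first extended to include 0, then scale = (max - min) / (qmax - qmin) and zero_point = qmin - round(min / scale)) with the unsigned range 0..127 used by default_qconfig's activation observer. Under those assumptions it reproduces the parameters printed above from the statistics printed before conversion; `affine_qparams` is a hypothetical helper.

```python
# Sketch: reproduce convert()'s activation scale/zero_point from observed min/max values.


def affine_qparams(min_val: float, max_val: float, qmin: int = 0, qmax: int = 127) -> tuple[float, int]:
    min_neg = min(min_val, 0.0)  # the range must contain 0 so that 0.0 is exactly representable
    max_pos = max(max_val, 0.0)
    scale = (max_pos - min_neg) / (qmax - qmin)
    zero_point = qmin - round(min_neg / scale)
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


# Ranges recorded by the activation observers, copied from the statistics printed before conversion
observed = {
    "quant": (-0.4242129623889923, 2.821486711502075),
    "linear1": (-40.239410400390625, 39.7955436706543),
    "linear2": (-41.65215301513672, 22.231578826904297),
    "linear3": (-32.679283142089844, 22.15723419189453),
}
for name, (lo, hi) in observed.items():
    scale, zp = affine_qparams(lo, hi)
    print(f"{name}: scale={scale:.4f}, zero_point={zp}")
```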

# Print weights and size of the model after quantization

In [15]:
# torch.int_repr takes a quantized tensor and returns an ordinary integer tensor (here int8) holding the
# stored integer values q. The real values they represent are scale * (q - zero_point).
# This function is typically used to inspect the integer values of a quantized tensor.

# net_quantized.linear1.weight() returns the quantized weight tensor of the first linear layer.

# So, torch.int_repr(net_quantized.linear1.weight()) returns the integer representation of the weights
# of the first linear layer of the quantized model.

# Print the integer weight matrix of the first layer after quantization
print('Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))

Weights after quantization
tensor([[  3,   8,  -5,  ...,   9,   4,   4],
        [ -5,  -4,  -3,  ...,  -5,  -2,  -8],
        [  0,   9,  -4,  ...,   0,   5,   7],
        ...,
        [  4,   5,  -4,  ...,  -5,   0, -10],
        [ -5,  -3,   6,  ...,   1,   1,   1],
        [  3,   2,  -2,  ...,   8,  -5,   0]], dtype=torch.int8)
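
The weights are stored as signed int8 values. Assuming the symmetric per-tensor scheme of default_qconfig's weight observer (zero_point = 0 and scale = max(|min|, |max|) / 127.5 for the range -128..127), the recorded linear1 weight range gives the scale that turns these integers back into floats. A plain-Python sketch with hypothetical helper names:

```python
# Sketch: symmetric per-tensor int8 quantization of a weight tensor (zero_point = 0).

QMIN, QMAX = -128, 127


def symmetric_scale(min_val: float, max_val: float) -> float:
    max_abs = max(-min(min_val, 0.0), max(max_val, 0.0))
    return max_abs / ((QMAX - QMIN) / 2)  # divide by 127.5


def quantize_weight(w: float, scale: float) -> int:
    return max(QMIN, min(QMAX, round(w / scale)))


# Weight range recorded for linear1 before conversion
scale = symmetric_scale(-0.4810347855091095, 0.3415062427520752)

# Map a few of the printed int8 values back to approximate float weights
for q in (3, 8, -5, -10):
    print(q, "->", round(q * scale, 4))
```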


In [16]:
# Print a statement to indicate that we are testing the model after quantization
print('Testing the model after quantization')

# Call the test function with the quantized model as the argument.
# The test function will evaluate the model's performance on the test data.
# It will loop over the test data, compute the model's output for each sample,
# compare the output with the true label, and count the number of correct predictions.
# At the end, it will print the accuracy of the model, which is the ratio of the number of correct predictions to the total number of predictions.
test(net_quantized)


Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:04<00:00, 227.58it/s]

Accuracy: 0.957
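
print_size_of_model, defined next to the training loop, measures the serialized state_dict on disk. A back-of-the-envelope count of the raw parameter storage gives a sense of the expected reduction. This sketch assumes 4 bytes per float32 value before quantization and, after quantization, 1 byte per int8 weight with biases kept in float32; it ignores serialization overhead and the stored scale and zero_point, so the file sizes reported by print_size_of_model would be somewhat larger.

```python
# Back-of-the-envelope parameter storage for VerySimpleNet (784 -> 100 -> 100 -> 10).
layer_shapes = [(784, 100), (100, 100), (100, 10)]  # (in_features, out_features)

num_weights = sum(i * o for i, o in layer_shapes)
num_biases = sum(o for _, o in layer_shapes)

fp32_bytes = 4 * (num_weights + num_biases)    # every parameter stored as float32
int8_bytes = 1 * num_weights + 4 * num_biases  # int8 weights, float32 biases (assumption)

print(f"parameters: {num_weights + num_biases}")
print(f"float32 model: ~{fp32_bytes / 1e3:.1f} KB")
print(f"int8 model:    ~{int8_bytes / 1e3:.1f} KB ({fp32_bytes / int8_bytes:.1f}x smaller)")
```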



