# Mentium Take Home Test

## Challenge Specification

1. Create a simple convolutional network (with 3-5 total layers) to classify MNIST dataset.
2. Set up the training and validation flow and train the network on the MNIST dataset.
3. Calculate the network classification performance on the validation and training datasets.
4. Calculate the network size in Kbytes.
5. Perform an 8-bit post-training quantization of the trained network and recalculate the network
size in KB and its classification performance.
6. Perform an 8-bit quantization-aware training on the original network and recalculate the network
size in KB and its classification performance after the network is trained.

## Considerations and Constraints

1. I am opting to use ready-made library to handle the dataset.
2. I am opting to use PyTorch as required to build the model.
3. All quantization will be done using PyTorch quantization tools only.

In [1]:
# Imports

import os
import json
import torch
import torch.nn as nn

from torchvision.datasets import MNIST
from torchvision import transforms
from sklearn.model_selection import KFold
import numpy as np

import torch.utils
import torch.utils.data

from tqdm import tqdm

import torchinfo

# Quantization
import torch.ao.quantization

## Model Architecture

In [2]:
class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.relu1 = nn.ReLU()
        
        self.maxpool_pad = nn.MaxPool2d(2, padding=1)
        
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.relu2 = nn.ReLU()
        
        self.maxpool = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.relu3 = nn.ReLU()
        
        self.fc1 = nn.Linear(128 * 4 * 4, 128)
        self.fc2 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        
        x = self.maxpool_pad(x)
        
        x = self.conv2(x)
        x = self.relu2(x)
        
        x = self.maxpool_pad(x)
        
        x = self.conv3(x)
        x = self.relu3(x)
        
        x = self.maxpool(x)
        
        # print(x.shape)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x
    
    def predict(self, x):
        with torch.no_grad():
            return torch.argmax(self.forward(x), 1)

## Data Preparation

In [3]:
transforms_ = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

dataset = MNIST(root='data', download=True, transform=transforms_)
# dataset_train, dataset_test = MNIST(root='data', download=True, train=True), MNIST(root='data', download=True, train=False)



In [4]:
# Cross Validation Split
kfold = KFold(n_splits=2, shuffle=True)

## Training Loop

In [5]:
# Globals to store the best model
best_params = {
    'model': None,
    'accuracy': 0,
    'epoch': 0,
    'fold': 0,
    'training_acc': None,
    'val_acc': None
}

In [7]:
# Updated train and test functions with kfold

import torch.utils.data.dataloader

if not os.path.exists('checkpoints'):
    os.makedirs('checkpoints')

def reset_model_params(model):
    for layer in model.children():
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()

def train(model, train_loader, loss_fn, device):
    model.train()
    model.to(device)
    for i, data in enumerate(train_loader):
            X, y = data
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_pred = model(X)
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()
            loss_val = loss.item()
            
            if i % 1000 == 0:
                print(f"Loss: {loss_val}")
                
    return loss_val
        
def validate(model, test_loader, loss_fn, device):
    model.eval()
    for X, y in test_loader:
        X, y = X.to(device), y.to(device)
        y_pred = model.forward(X)
        loss = loss_fn(y_pred, y)
    val_loss = loss.item()
    return val_loss

def test_acc(model, test_loader, device):
    # Validation
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            y_pred = model.predict(X)
            correct += torch.sum(y_pred == y).item()
            total += y.size(0)
        accuracy = correct / total
        print(f"Accuracy: {accuracy}")
        
    return accuracy

NUM_EPOCHS = 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for fold, (train_idx, test_idx) in enumerate(kfold.split(dataset)):
    print(f"Fold {fold}")
    sample_train, sample_test = torch.utils.data.SubsetRandomSampler(train_idx), torch.utils.data.SubsetRandomSampler(test_idx)
    
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, sampler=sample_train)
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=64, sampler=sample_test)
    
    model = MNISTClassifier()
    if torch.cuda.is_available():
        model.cuda()
    model.apply(reset_model_params)
    
    loss_fn = nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    
    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch}")
        loss_val = train(model, train_loader, loss_fn, device=device)
        val_loss = validate(model, test_loader, loss_fn=loss_fn, device=device)
        val_acc = test_acc(model, test_loader, device=device)
      
        if best_params['val_acc'] is None or val_acc > best_params['val_acc'] :
            best_params['model'] = model
            best_params['epoch'] = epoch
            best_params['fold'] = fold
            best_params['training_acc'] = loss_val
            best_params['val_acc'] = val_acc  

print("Best Model : ")
print(json.dumps({k:best_params[k] for k in best_params.keys() if 'model' not in k}, indent=4))

# Save the best model
torch.save(best_params['model'].state_dict(), f"checkpoints/best_model.pth")

Fold 0
Epoch 0


Loss: -0.0995897501707077
Accuracy: 0.4885333333333333
Fold 1
Epoch 0
Loss: -0.10104354470968246
Accuracy: 0.5833
Best Model : 
{
    "accuracy": 0,
    "epoch": 0,
    "fold": 1,
    "training_acc": -0.602523684501648,
    "val_acc": 0.5833
}


## Network Size Calculation

1. The network size will be calculated using the `torchsummary` library. While this is not extremely difficult to implement, I am opting to use the library to save time.
2. The network size will be calculated in KBytes.


In [8]:
import torchinfo

summary = torchinfo.summary(model, input_size = (1, 1, 28, 28))

# network_fp32_size = summary.

In [9]:
summary

Layer (type:depth-idx)                   Output Shape              Param #
MNISTClassifier                          [1, 10]                   --
├─Conv2d: 1-1                            [1, 32, 28, 28]           320
├─ReLU: 1-2                              [1, 32, 28, 28]           --
├─MaxPool2d: 1-3                         [1, 32, 15, 15]           --
├─Conv2d: 1-4                            [1, 64, 15, 15]           18,496
├─ReLU: 1-5                              [1, 64, 15, 15]           --
├─MaxPool2d: 1-6                         [1, 64, 8, 8]             --
├─Conv2d: 1-7                            [1, 128, 8, 8]            73,856
├─ReLU: 1-8                              [1, 128, 8, 8]            --
├─MaxPool2d: 1-9                         [1, 128, 4, 4]            --
├─Linear: 1-10                           [1, 128]                  262,272
├─Linear: 1-11                           [1, 10]                   1,290
├─Softmax: 1-12                          [1, 10]                   -

In [10]:
print("Model file size : ", os.path.getsize('checkpoints/model_0_0.pth') / 1024, "Kbytes")

Model file size :  1395.39453125 Kbytes


## Quantization

### Post-Training Quantization

In [11]:
# The model needs to be redeffined and loaded from the checkpoint

In [12]:
import torch.ao.quantization

# NOTE: This is not working as expected due to pytorch expecting the dequant to precede the softmax layer 
# class MNISTClassificationPostTrainingStatic(MNISTClassifier):
#     def __init__(self, model_file):
#         super(MNISTClassificationPostTrainingStatic, self).__init__()
#         self.quant = torch.ao.quantization.QuantStub()
#         self.dequant = torch.ao.quantization.DeQuantStub()
#         self.model = MNISTClassifier()
#         self.model.load_state_dict(torch.load(model_file))
        
#     def forward(self, x):
#         x = self.quant(x)
#         x = self.model(x)
#         x = self.dequant(x)
#         return x
    
#     def predict(self, x):
#         with torch.no_grad():
#             return torch.argmax(self.forward(x), 1)

class MNISTClassificationQuantize(nn.Module):
    def __init__(self, weight_path) -> None:
        super(MNISTClassificationQuantize, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.relu1 = nn.ReLU() # Cannot use the same relu layer as the quantization expects the relu to be fused with the conv layer
        
        self.maxpool_pad = nn.MaxPool2d(2, padding=1)
        
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.relu2 = nn.ReLU()
        
        self.maxpool = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.relu3 = nn.ReLU()
        
        self.fc1 = nn.Linear(128 * 4 * 4, 128)
        self.fc2 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(dim=1)
        
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        
        self.load_state_dict(torch.load(weight_path))
        
    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.relu1(x)
        
        x = self.maxpool_pad(x)
        
        x = self.conv2(x)
        x = self.relu2(x)
        
        x = self.maxpool_pad(x)
        
        x = self.conv3(x)
        x = self.relu3(x)
        
        x = self.maxpool(x)
        
        # print(x.shape)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.dequant(x)
        x = self.softmax(x)
        return x
    
    def predict(self, x):
        with torch.no_grad():
            return torch.argmax(self.forward(x), 1)
        
model_fp32 = MNISTClassificationQuantize('checkpoints/model_0_0.pth')

model_fp32.eval()
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
model_fp32_fused = torch.ao.quantization.fuse_modules(
    model_fp32, [
        ['conv1', 'relu1'], 
        ['conv2', 'relu2'], 
        ['conv3', 'relu3']
        ])
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)

inp_fp32 = torch.randn(1, 1, 28, 28)

model_fp32_prepared(inp_fp32)

model_fp32_quantized = torch.ao.quantization.convert(model_fp32_prepared)

res = model_fp32_quantized(inp_fp32)





In [13]:
torch.onnx.export(
    model_fp32,
    # torch.randn(1, 1, 28, 28).cuda() if torch.cuda.is_available() else torch.randn(1, 1, 28, 28),
    torch.randn(1, 1, 28, 28),
    'model_fp32_additional.onnx',
    export_params=True,
)

In [14]:
torch.onnx.export(
    model_fp32_quantized,
    torch.randn(1, 1, 28, 28),
    'model_fp32_quantized.onnx',
    export_params=True,
)

## QAT - Quantization Aware Training

As with the post-training quantization, I will use the PyTorch quantization tools to perform the quantization aware training.

In [17]:
# Quantization Aware Training

model_fp32 = MNISTClassificationQuantize('checkpoints/best_model.pth')
model_fp32.train()
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
model_fp32_fused = torch.ao.quantization.fuse_modules(
    model_fp32, [
        ['conv1', 'relu1'], 
        ['conv2', 'relu2'], 
        ['conv3', 'relu3']
        ])
model_fp32_qat_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused)

QAT_FINE_TUNE_EPOCHS = 1

qat_best_params = {
    'model': None,
    'accuracy': 0,
    'epoch': 0,
    'fold': 0,
    'training_acc': None,
    'val_acc': None
}

qat_device = torch.device('cpu') # Choosing CPU for now

# K Fold not needed as we are using the best model
for epoch in range(QAT_FINE_TUNE_EPOCHS):
    print(f"QAT Finetune Epoch : {epoch}")
    loss_val = train(model_fp32_qat_prepared, train_loader, loss_fn, device=qat_device)
    val_loss = validate(model_fp32_qat_prepared, test_loader, loss_fn=loss_fn, device=qat_device)
    val_acc = test_acc(model_fp32_qat_prepared, test_loader, device=qat_device)
    
    if qat_best_params['val_acc'] is None or val_acc > qat_best_params['val_acc'] :
        qat_best_params['model'] = model_fp32_qat_prepared
        qat_best_params['epoch'] = epoch
        qat_best_params['fold'] = fold
        qat_best_params['training_acc'] = loss_val
        qat_best_params['val_acc'] = val_acc
        
print("Best Model : ")
print(json.dumps({k:qat_best_params[k] for k in qat_best_params.keys() if 'model' not in k}, indent=4))

# Save the best model
torch.save(qat_best_params['model'].state_dict(), f"checkpoints/qat_best_model.pth")

model_fp32_qat_prepared.eval()
model_int8 = torch.ao.quantization.convert(model_fp32_qat_prepared)

inp_fp32 = torch.randn(1, 1, 28, 28)
res = model_int8(inp_fp32)

# Export ONNX
torch.onnx.export(
    model_int8,
    torch.randn(1, 1, 28, 28),
    'model_int8.onnx',
    export_params=True,
)   



QAT Finetune Epoch : 0
Loss: -0.568603515625
Accuracy: 0.5833666666666667
Best Model : 
{
    "accuracy": 0,
    "epoch": 0,
    "fold": 1,
    "training_acc": -0.6180827021598816,
    "val_acc": 0.5833666666666667
}


In [21]:
# import torchinfo
# torchinfo.summary(model_fp32_quantized, input_size=(1, 1, 28, 28))

# TODO: Calculate the size of the model

def calculate_model_size(model):
    model_size = 0
    for param in model.parameters():
        print(param.numel(), param.element_size())
        model_size += param.numel() * param.element_size()
    return model_size


# model_fp32_size = calculate_model_size(model_fp32)
# model_fp32_quantized_size = calculate_model_size(model_fp32_quantized)
model_int8_size = calculate_model_size(model_int8)

In [22]:
dir(model_int8)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_compiled_call_impl',
 '_forward_hooks',
 '_forward_hooks_always_called',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_se

In [20]:
# print("Model FP32 Size : ", model_fp32_size / 1024, "Kbytes")
# print("Model FP32 Quantized Size : ", model_fp32_quantized_size / 1024, "Kbytes")
print("Model INT8 Size : ", model_int8_size / 1024, "Kbytes")

Model FP32 Size :  1391.5390625 Kbytes
Model INT8 Size :  0.0 Kbytes
