In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn 
import matplotlib.pyplot as plt 
from tqdm import tqdm 
from pathlib import Path
import os 

# Load MNIST dataset

In [2]:
# Make torch deterministic
_ = torch.manual_seed(0)

In [4]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load MNIST dataset and create dataloader for training
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) # all training data loaded
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True) #Batches of 10 images will be loaded for each iteration during training.

# Load MNIST testset and create dataloader for testing
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

device = 'cpu'

# Define the model

In [6]:
class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1 = 100, hidden_size_2 = 100):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1,hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()
    
    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x
    
net = VerySimpleNet().to(device)
net

VerySimpleNet(
  (quant): QuantStub()
  (linear1): Linear(in_features=784, out_features=100, bias=True)
  (linear2): Linear(in_features=100, out_features=100, bias=True)
  (linear3): Linear(in_features=100, out_features=10, bias=True)
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Insert min-max observer in the model

In [9]:
# set dynamic attribute (qconfig) for "net" model
net.qconfig = torch.ao.quantization.default_qconfig # define quantization configuration for net model
net.train() 
net_quantized = torch.ao.quantization.prepare_qat(net) # Insert observer
net_quantized

VerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Train

In [22]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    # Define loss function and optimizer
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001) # updates weights and bias using gradient

    total_iterations = 0

    for epoch in range(epochs):
        net.train() # set model to training mode

        loss_sum = 0
        num_iterations = 0

        # create progress bar for visualizing training process
        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}') # input: iterable & description/// output: decorated iterable
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit # set total num of iterations for the progress bar
        
        # Main kernel
        for data in data_iterator:
            #print(len(data_iterator)) # 6000
            num_iterations += 1
            total_iterations += 1
            #print(data)

            x, y = data # data = ((batchsize, 28,28),(batchsize))
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad() # initialize gradient as zero
            output = net(x.view(-1, 28*28)) # predict
            loss = cross_el(output, y) # calculate loss
            loss_sum += loss.item() 
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss) # update info displayed in the postfix of progress bar
            loss.backward() # calculate gradient 
            optimizer.step() # update weights and bias

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
            

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size of model: ', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

train(train_loader, net_quantized, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [01:33<00:00, 64.43it/s, loss=0.0894]


# Test

In [23]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0
    iterations = 0

    model.eval() # Evaluation mode

    with torch.no_grad():
        for data in tqdm(test_loader, desc=f"Tesing"):
            x,y = data 
            x = x.to(device)
            y = y.to(device)

            output = model(x.view(-1, 28*28)) # (batchsize)
            print(output)
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                total += 1
            iterations += 1

            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f"Accuracy: {round(correct/total, 3)}")

# check the collected statistics during training

In [24]:
print(f"Check statistics of the various layers")
net_quantized

Check statistics of the various layers


VerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-1.1537182331085205, max_val=0.5618109107017517)
    (activation_post_process): MinMaxObserver(min_val=-69.19570922851562, max_val=49.220550537109375)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.6201204061508179, max_val=0.4973856210708618)
    (activation_post_process): MinMaxObserver(min_val=-51.87538528442383, max_val=38.172454833984375)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-1.745202660560608, max_val=0.27717524766921997)
    (activation_post_process): MinMaxObserver(min_val=-140.6678009033203, max_val=33.692996978759766)
  )
  (relu): ReLU()
  (dequant): DeQuantSt

# Quantize the model using the statistics collected

In [25]:
net_quantized.eval()
net_quantized = torch.ao.quantization.convert(net_quantized)

In [26]:
print("check statistics of various layers")
net_quantized

check statistics of various layers


VerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.9324114918708801, zero_point=74, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.7090380787849426, zero_point=73, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=1.3729196786880493, zero_point=102, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

# print weights and size of the model after quantization

In [27]:
print("Weights before quantization")
print(torch.int_repr(net_quantized.linear1.weight()))

Weights before quantization
tensor([[-1,  0,  1,  ...,  3,  0,  5],
        [14, 10,  9,  ..., 12, 11, 16],
        [ 2,  3,  1,  ..., -1,  1,  1],
        ...,
        [ 4,  1,  4,  ...,  1,  4,  0],
        [-1,  2,  6,  ..., -1,  2,  5],
        [ 0, -4, -5,  ..., -1, -4,  0]], dtype=torch.int8)


In [28]:
print("Testing the model after quantization")
test(net_quantized)

Testing the model after quantization


Tesing: 100%|██████████| 1000/1000 [00:05<00:00, 181.97it/s]

Accuracy: 0.965



