## Libraries

In [7]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from tqdm import tqdm
import os
from pathlib import Path

## Load the datsets

In [2]:
transform=transforms.Compose(
    [transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))]

)

#Load the dataset
mnist=datasets.MNIST(root='./data',train=True,download=True,transform=transform)

train_dataloader=torch.utils.data.DataLoader(mnist,batch_size=10,shuffle=True)
mnist_test=datasets.MNIST(root='./data',train=False,download=True,transform=transform)
test_dataloader=torch.utils.data.DataLoader(mnist_test,batch_size=10,shuffle=True)


device="cpu"

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 56410636.96it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1899095.32it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 13707429.00it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3570189.05it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






# Create the Model

In [3]:
class SimpleNet(nn.Module):
    def __init__(self,hidden_size_1=100,hidden_size_2=200):
        super(SimpleNet,self).__init__()
        self.linear1=nn.Linear(28*28,hidden_size_1)
        self.linear2=nn.Linear(hidden_size_1,hidden_size_2)
        self.linear3=nn.Linear(hidden_size_2,10)
        self.relu=nn.ReLU()

    def forward(self,image):
        image=image.view(-1,28*28)
        image=self.relu(self.linear1(image))
        image=self.relu(self.linear2(image))
        image=self.linear3(image)
        return image

In [4]:
net=SimpleNet().to(device)

## Training the Model

In [5]:
def train(train_dataloader,net,epoch=5,total_iterations_limit=None):
  cross_entropy=nn.CrossEntropyLoss()
  optimizer=torch.optim.Adam(net.parameters(),lr=0.001)

  total_iters=0
  for e in range(epoch):
    net.train()
    loss_sum=0
    num_iters=0

    data_iterator=tqdm(train_dataloader,desc=f'Epoch {epoch+1}')
    if total_iterations_limit is not None:
      data_iterator.total=total_iterations_limit
    for data in data_iterator:
      num_iters+=1
      total_iters+=1
      image,label=data
      image,label=image.to(device),label.to(device)
      optimizer.zero_grad()
      output=net(image.view(-1,28*28))
      loss=cross_entropy(output,label)
      loss_sum+=loss.item()
      avg_loss=loss_sum/num_iters
      data_iterator.set_postfix(loss=avg_loss)
      loss.backward()
      optimizer.step()
      if total_iterations_limit is not None and total_iters>=total_iterations_limit:
        return


In [9]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

MODEL_FILENAME = 'simplenet_ptq.pt'

if Path(MODEL_FILENAME).exists():
    net.load_state_dict(torch.load(MODEL_FILENAME))
    print('Loaded model from disk')
else:
    train(train_dataloader, net, epoch=1)
    # Save the model to disk
    torch.save(net.state_dict(), MODEL_FILENAME)

Epoch 2: 100%|██████████| 6000/6000 [01:12<00:00, 82.62it/s, loss=0.215]


In [15]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_dataloader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')

# Before Quantization

In [10]:
# Print the weights matrix of the model before quantization
print('Weights before quantization')
print(net.linear1.weight)
print(net.linear1.weight.dtype)

Weights before quantization
Parameter containing:
tensor([[ 0.0080,  0.0578,  0.0237,  ...,  0.0003,  0.0079,  0.0556],
        [ 0.0396, -0.0197,  0.0119,  ...,  0.0084,  0.0382, -0.0078],
        [ 0.0125,  0.0403,  0.0097,  ...,  0.0050, -0.0146, -0.0065],
        ...,
        [ 0.0320, -0.0159,  0.0083,  ...,  0.0187, -0.0224,  0.0310],
        [-0.0153, -0.0110, -0.0316,  ..., -0.0404,  0.0094,  0.0189],
        [-0.0277, -0.0152, -0.0175,  ...,  0.0042,  0.0063, -0.0131]],
       requires_grad=True)
torch.float32


In [11]:
print('Size of the model before quantization')
print_size_of_model(net)

Size of the model before quantization
Size (KB): 405.414


In [16]:
print(f'Accuracy of the model before quantization: ')
test(net)

Accuracy of the model before quantization: 


Testing: 100%|██████████| 1000/1000 [00:03<00:00, 293.54it/s]

Accuracy: 0.967





# Quantized Model

In [21]:
class QuantizedVerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=200):
        super(QuantizedVerySimpleNet,self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

In [23]:
net_quantized = QuantizedVerySimpleNet().to(device)
# Copy weights from unquantized model
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()

net_quantized.qconfig = torch.ao.quantization.default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized) # Insert observers
net_quantized

QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=200, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=200, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [24]:
test(net_quantized)

Testing: 100%|██████████| 1000/1000 [00:04<00:00, 241.11it/s]

Accuracy: 0.967





In [26]:
net_quantized = torch.ao.quantization.convert(net_quantized)

In [27]:
print(f'Check statistics of the various layers')
net_quantized

Check statistics of the various layers


QuantizedVerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.5790975689888, zero_point=74, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=200, scale=0.4156707525253296, zero_point=76, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=200, out_features=10, scale=0.44545111060142517, zero_point=80, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

In [28]:
# Print the weights matrix of the model after quantization
print('Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))

Weights after quantization
tensor([[ 2, 13,  5,  ...,  0,  2, 12],
        [ 9, -4,  3,  ...,  2,  8, -2],
        [ 3,  9,  2,  ...,  1, -3, -1],
        ...,
        [ 7, -3,  2,  ...,  4, -5,  7],
        [-3, -2, -7,  ..., -9,  2,  4],
        [-6, -3, -4,  ...,  1,  1, -3]], dtype=torch.int8)


#Compare the Dequantized weights

In [30]:
print('Original weights: ')
print(net.linear1.weight)
print('')
print(f'Dequantized weights: ')
print(torch.dequantize(net_quantized.linear1.weight()))

Original weights: 
Parameter containing:
tensor([[ 0.0080,  0.0578,  0.0237,  ...,  0.0003,  0.0079,  0.0556],
        [ 0.0396, -0.0197,  0.0119,  ...,  0.0084,  0.0382, -0.0078],
        [ 0.0125,  0.0403,  0.0097,  ...,  0.0050, -0.0146, -0.0065],
        ...,
        [ 0.0320, -0.0159,  0.0083,  ...,  0.0187, -0.0224,  0.0310],
        [-0.0153, -0.0110, -0.0316,  ..., -0.0404,  0.0094,  0.0189],
        [-0.0277, -0.0152, -0.0175,  ...,  0.0042,  0.0063, -0.0131]],
       requires_grad=True)

Dequantized weights: 
tensor([[ 0.0092,  0.0595,  0.0229,  ...,  0.0000,  0.0092,  0.0550],
        [ 0.0412, -0.0183,  0.0137,  ...,  0.0092,  0.0366, -0.0092],
        [ 0.0137,  0.0412,  0.0092,  ...,  0.0046, -0.0137, -0.0046],
        ...,
        [ 0.0321, -0.0137,  0.0092,  ...,  0.0183, -0.0229,  0.0321],
        [-0.0137, -0.0092, -0.0321,  ..., -0.0412,  0.0092,  0.0183],
        [-0.0275, -0.0137, -0.0183,  ...,  0.0046,  0.0046, -0.0137]])


In [31]:
print('Size of the model after quantization')
print_size_of_model(net_quantized)

Size of the model after quantization
Size (KB): 106.786


In [32]:
print('Testing the model after quantization')
test(net_quantized)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:03<00:00, 263.43it/s]

Accuracy: 0.966



