In [9]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn 
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

## load MNIST dataset


In [3]:
# make torch deterministic
_ = torch.manual_seed(0)

In [5]:
import torch.utils


transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081))])

# Load the MINIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# create dataloader for training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

device = "cpu"

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:09<00:00, 1088571.73it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<?, ?it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:45<00:00, 36155.71it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<?, ?it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






## Define the model

In [11]:

class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2 = 100):
        super().__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        
    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

In [12]:
net = VerySimpleNet().to(device)

## Train the model

In [18]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    
    total_iterations = 0
    
    for epoch in range(epochs):
        net.train()
        
        loss_sum = 0
        num_iterations = 0
        
        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch + 1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()
            
            if total_iterations_limit is not None and (total_iterations >= total_iterations_limit):
                return
        

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')
    
MODEL_FILENAME = 'simplenet_ptq.pt'

if Path(MODEL_FILENAME).exists():
    net.load_state_dict(torch.load(MODEL_FILENAME))
    print("Loaded model from disk")
else:
    train(train_loader, net, epochs=1)
    # Save the model to disk
    torch.save(net.state_dict(), MODEL_FILENAME)

Loaded model from disk


## Define the testing loop

In [19]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0
    
    iterations = 0
    
    model.eval()
    
    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                total += 1
            iterations += 1
            if total_iterations is not None and (iterations >= total_iterations):
                break
    
    print(f'Accuracy: {round(correct/total, 3)}')

## Print weights and size of the model before quantization

In [20]:
# weights matrix of the model before quantizaiton
print('Weights before quantization')
print(net.linear1.weight)
print(net.linear1.weight.dtype)

Weights before quantization
Parameter containing:
tensor([[ 0.0181,  0.0062, -0.0277,  ..., -0.0345, -0.0427,  0.0037],
        [ 0.0307, -0.0066,  0.0099,  ...,  0.0511,  0.0323,  0.0331],
        [-0.0134,  0.0397,  0.0118,  ..., -0.0003,  0.0548,  0.0379],
        ...,
        [ 0.0175,  0.0004, -0.0011,  ...,  0.0010,  0.0305,  0.0111],
        [-0.0010, -0.0197, -0.0457,  ..., -0.0275, -0.0398, -0.0106],
        [-0.0100, -0.0256, -0.0232,  ..., -0.0382,  0.0136, -0.0216]],
       requires_grad=True)
torch.float32


In [21]:
print('Size of model before quantization')
print_size_of_model(net)

Size of model before quantization
Size (KB): 360.998


In [23]:
print(f'Accuracy of the model before quantization: ')
test(net)

Accuracy of the model before quantization: 


Testing: 100%|██████████| 1000/1000 [00:00<00:00, 1083.77it/s]

Accuracy: 0.949





# Apply Quantization

## Insert min-max observers in the model

In [24]:
class QuantizedVerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()
        
    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

In [28]:

net_quantized = QuantizedVerySimpleNet().to(device)
# Copy weights from unquantized model
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()

net_quantized.qconfig = torch.ao.quantization.default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized)   # Insert observers
net_quantized

QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

## Calibrate the model using the test set

In [30]:
test(net_quantized)

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 944.05it/s]

Accuracy: 0.949





Accuracy of the model before quantization: is (0.949)

In [32]:
print(f'Check statistics of the various layers')
net_quantized

Check statistics of the various layers


QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-42.980228424072266, max_val=33.47325897216797)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-23.005212783813477, max_val=22.060422897338867)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-29.85936737060547, max_val=21.634349822998047)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

## Quantize model using the statistics collected

collect statistics and apply on model

In [35]:
net_quantized = torch.ao.quantization.convert(net_quantized)

In [36]:
print(f'Check statistics of the various layers')
net_quantized

Check statistics of the various layers


QuantizedVerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.6019960045814514, zero_point=71, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.3548475205898285, zero_point=65, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.4054623544216156, zero_point=74, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

In [38]:
print('Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))

Weights after quantization
tensor([[  4,   1,  -6,  ...,  -8, -10,   1],
        [  7,  -1,   2,  ...,  11,   7,   7],
        [ -3,   9,   3,  ...,   0,  12,   8],
        ...,
        [  4,   0,   0,  ...,   0,   7,   2],
        [  0,  -4, -10,  ...,  -6,  -9,  -2],
        [ -2,  -6,  -5,  ...,  -9,   3,  -5]], dtype=torch.int8)


## Compare the dequantized weights and original weights

In [51]:
print('Original weights: ')
print(net.linear1.weight)
print(f'Data type: {net.linear1.weight.dtype}')

print('Dequantized weights: ')
print(torch.dequantize(net_quantized.linear1.weight()))
print(f' Data type: int8')

Original weights: 
Parameter containing:
tensor([[ 0.0181,  0.0062, -0.0277,  ..., -0.0345, -0.0427,  0.0037],
        [ 0.0307, -0.0066,  0.0099,  ...,  0.0511,  0.0323,  0.0331],
        [-0.0134,  0.0397,  0.0118,  ..., -0.0003,  0.0548,  0.0379],
        ...,
        [ 0.0175,  0.0004, -0.0011,  ...,  0.0010,  0.0305,  0.0111],
        [-0.0010, -0.0197, -0.0457,  ..., -0.0275, -0.0398, -0.0106],
        [-0.0100, -0.0256, -0.0232,  ..., -0.0382,  0.0136, -0.0216]],
       requires_grad=True)
Data type: torch.float32
Dequantized weights: 
tensor([[ 0.0179,  0.0045, -0.0269,  ..., -0.0359, -0.0449,  0.0045],
        [ 0.0314, -0.0045,  0.0090,  ...,  0.0494,  0.0314,  0.0314],
        [-0.0135,  0.0404,  0.0135,  ...,  0.0000,  0.0538,  0.0359],
        ...,
        [ 0.0179,  0.0000,  0.0000,  ...,  0.0000,  0.0314,  0.0090],
        [ 0.0000, -0.0179, -0.0449,  ..., -0.0269, -0.0404, -0.0090],
        [-0.0090, -0.0269, -0.0224,  ..., -0.0404,  0.0135, -0.0224]])
 Data type: int8


## Print size and accuracy of the quantized model

In [52]:
print('Size of the model after quantization')
print_size_of_model(net_quantized)

Size of the model after quantization
Size (KB): 95.394


Original model size was : Size (KB): 360.998

In [53]:
print('Testing the model after quantization')
test(net_quantized)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:02<00:00, 437.01it/s]

Accuracy: 0.948



