In [9]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn 
import matplotlib.pyplot as plt 
from tqdm import tqdm
from pathlib import Path 
import os

# Load MNIST dataset

In [10]:
# Seed torch's global random number generator so repeated runs are repeatable.
# The return value is bound to `_` so the cell does not echo it.
_ = torch.manual_seed(0)

In [12]:
# Convert images to tensors, then normalise with mean 0.1307 and std 0.3081
# (presumably the MNIST pixel statistics — confirm if the dataset changes).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST training split (download=True fetches it into ./data when needed)
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Create a dataloader for training: mini-batches of 10 samples, shuffled
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# NOTE(review): shuffle=True is not required for evaluation; it only matters
# when test() is called with an iteration limit — confirm this is intended.
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# All tensors and models in this notebook are placed on the CPU
device = "cpu"

# Define the model

In [13]:
class VerySimpleNet(nn.Module):
    """Small fully connected classifier for 28x28 images.

    Architecture: Linear(784 -> hidden_size_1) -> ReLU ->
    Linear(hidden_size_1 -> hidden_size_2) -> ReLU -> Linear(hidden_size_2 -> 10).

    Args:
        hidden_size_1: Width of the first hidden layer.
        hidden_size_2: Width of the second hidden layer.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(VerySimpleNet, self).__init__()
        # Sub-modules are registered in forward order; their attribute names
        # determine the state_dict keys, so they must stay as they are.
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Return the 10 raw class scores (no softmax) for each image in ``img``.

        ``img`` is viewed as rows of 784 values, so any input whose element
        count is a multiple of 784 is accepted.
        """
        flat = img.view(-1, 28 * 28)
        hidden = self.relu(self.linear1(flat))
        hidden = self.relu(self.linear2(hidden))
        return self.linear3(hidden)


In [14]:
# Instantiate the float model with its default hidden sizes and place it on `device`
net = VerySimpleNet().to(device)

# Train the model

In [15]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train ``net`` with Adam (lr=0.001) on a cross-entropy loss.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        net: Model to optimise in place.
        epochs: Maximum number of passes over ``train_loader``.
        total_iterations_limit: Optional cap on the total number of batches
            processed across all epochs. When it is reached the function
            returns immediately. It is also used as the progress bar total.

    Returns:
        None. The running average loss of the current epoch is shown in the
        progress bar. Batches are moved to the module-level ``device``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    has_limit = total_iterations_limit is not None

    batches_seen = 0

    for epoch_idx in range(epochs):
        net.train()

        running_loss = 0

        progress = tqdm(train_loader, desc=f'Epoch {epoch_idx+1}')
        if has_limit:
            progress.total = total_iterations_limit
        for batch_no, (x, y) in enumerate(progress, start=1):
            batches_seen += 1
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            logits = net(x.view(-1, 28*28))
            loss = criterion(logits, y)
            # Show the mean loss over the batches of this epoch so far
            running_loss += loss.item()
            progress.set_postfix(loss=running_loss / batch_no)
            loss.backward()
            optimizer.step()

            if has_limit and batches_seen >= total_iterations_limit:
                return
            
def print_size_of_model(model):
    """Print the serialized size of ``model``'s state dict in kilobytes.

    The state dict is written to a temporary file in the current working
    directory, its size is printed, and the file is removed again. The
    removal happens in a ``finally`` clause so the file does not linger when
    saving or measuring fails.

    Args:
        model: Any object exposing ``state_dict()`` (e.g. an ``nn.Module``).
    """
    tmp_path = "temp_delme.p"
    try:
        torch.save(model.state_dict(), tmp_path)
        print("size (KB): ", os.path.getsize(tmp_path)/1e3)
    finally:
        # The file may not exist if state_dict() itself raised
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
            
# File in which the trained float weights are cached between runs
MODEL_FILENAME = 'simplenet_ptq.pt'

# Reuse cached weights when the file exists; otherwise train for one epoch
# and save the result so later runs can skip training.
# NOTE(review): an existing file is loaded without checking that it matches
# the current architecture or data — delete it to force retraining.
if Path(MODEL_FILENAME).exists():
    # NOTE(review): torch.load unpickles the file; only load files you trust.
    net.load_state_dict(torch.load(MODEL_FILENAME))
    print("Loaded model from disk")
else:
    train(train_loader, net, epochs=1)
    torch.save(net.state_dict(), MODEL_FILENAME)



Epoch 1: 100%|██████████| 6000/6000 [01:45<00:00, 56.96it/s, loss=0.223]


# Define the testing loop

In [17]:
def test(model: nn.Module, total_iterations: int = None):
    """Evaluate ``model`` on the module-level ``test_loader`` and print its accuracy.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        total_iterations: Optional maximum number of batches to evaluate.
            ``None`` evaluates the whole loader.

    Returns:
        None. The accuracy (rounded to three decimals) is printed.
    """
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for x, y in tqdm(test_loader, desc='Testing'):
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            # Compare the predicted class of every row with its label in a
            # single vectorised operation instead of a per-sample Python loop.
            # Assumes y holds one label per output row (1-D), as MNIST does.
            correct += (output.argmax(dim=1) == y).sum().item()
            total += output.size(0)
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    if total == 0:
        # An empty loader would otherwise raise ZeroDivisionError below
        print('Accuracy: n/a (no samples were evaluated)')
        return
    print(f'Accuracy: {round(correct/total, 3)}')

# Print weights and size of the model before quantization

In [18]:
# Print the first layer's weight matrix and its dtype before quantization,
# as a reference point for the comparison made after conversion
print("Weights before quantization")
print(net.linear1.weight)
print(net.linear1.weight.dtype)

Weights before quantization
Parameter containing:
tensor([[-0.0127,  0.0067, -0.0419,  ...,  0.0095, -0.0087, -0.0104],
        [-0.0197, -0.0149, -0.0104,  ..., -0.0202, -0.0059, -0.0299],
        [ 0.0199,  0.0550,  0.0068,  ...,  0.0197,  0.0413,  0.0481],
        ...,
        [ 0.0278,  0.0316, -0.0031,  ..., -0.0084,  0.0108, -0.0261],
        [ 0.0056,  0.0137,  0.0458,  ...,  0.0261,  0.0261,  0.0256],
        [ 0.0102,  0.0049, -0.0093,  ...,  0.0271, -0.0221, -0.0020]],
       requires_grad=True)
torch.float32


In [19]:
# Serialized size of the float model's state dict (baseline for the comparison)
print("size of the model before quantization")
print_size_of_model(net)

size of the model before quantization
size (KB):  360.998


In [20]:
# Baseline: accuracy of the float model on the test loader
print("Accuracy of the model before quantization: ")
test(net)

Accuracy of the model before quantization: 


Testing: 100%|██████████| 1000/1000 [00:05<00:00, 173.66it/s]

Accuracy: 0.961





# Insert min-max observers in the model

In [21]:
class QuantizedVerySimpleNet(nn.Module):
    """``VerySimpleNet`` wrapped with quantization stubs for static quantization.

    The linear layers use the same attribute names as ``VerySimpleNet`` so its
    state dict can be loaded directly. ``quant`` marks where the float input
    enters the quantized region and ``dequant`` where the output leaves it.
    Until the model is converted, both stubs pass tensors through unchanged.

    Args:
        hidden_size_1: Width of the first hidden layer.
        hidden_size_2: Width of the second hidden layer.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2 = 100):
        super().__init__()
        # Use the torch.ao.quantization namespace, consistent with the
        # prepare/convert calls made elsewhere in this notebook.
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, img):
        """Return the 10 class scores per image as a regular float tensor."""
        x = img.view(-1, 28*28)
        x = self.quant(x)      # entry point of the quantized region
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)    # exit point of the quantized region
        return x


In [22]:
# Build the stub-wrapped copy of the network on the same device
net_quantized = QuantizedVerySimpleNet().to(device)

# Copy the trained weights from the unquantized model
net_quantized.load_state_dict(net.state_dict())
# Switch to eval mode before preparing the model
net_quantized.eval()

# Attach the default quantization config, then let prepare() insert observers
# (MinMaxObserver instances, as the printed module tree shows)
net_quantized.qconfig = torch.ao.quantization.default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized) # insert observers
net_quantized

QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Calibrate the model using the test set

In [23]:
# Running the prepared model over the data lets the inserted observers record
# the min/max of the values they see; the printed accuracy is a by-product.
# NOTE(review): calibration uses the test loader here; a held-out slice of the
# training data would avoid tuning on the evaluation set — confirm intent.
test(net_quantized)

Testing: 100%|██████████| 1000/1000 [00:07<00:00, 140.95it/s]

Accuracy: 0.961





In [24]:
# Show the min/max values each observer recorded during calibration
print("check statistics of the various layers")
net_quantized

check statistics of the various layers


QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-53.427085876464844, max_val=38.368324279785156)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-29.539627075195312, max_val=25.24178695678711)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-25.761398315429688, max_val=22.611188888549805)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Quantize the model using the statistics collected

In [25]:
# Replace the observed modules with their quantized counterparts, using the
# statistics gathered during calibration to derive scales and zero points
net_quantized = torch.ao.quantization.convert(net_quantized)

In [26]:
# Show the scale and zero point chosen for each layer of the converted model
print("check statistics of various layers")
net_quantized

check statistics of various layers


QuantizedVerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.7227985262870789, zero_point=74, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.4313497245311737, zero_point=68, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.3808865249156952, zero_point=68, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

# Print the model weights after quantization

In [27]:
# On the converted layer `weight` is called as a method (note the parentheses);
# int_repr() exposes the integer values stored for the quantized tensor
print("Weight after quantization")
print(torch.int_repr(net_quantized.linear1.weight()))

Weight after quantization
tensor([[-3,  1, -9,  ...,  2, -2, -2],
        [-4, -3, -2,  ..., -4, -1, -6],
        [ 4, 11,  1,  ...,  4,  8, 10],
        ...,
        [ 6,  6, -1,  ..., -2,  2, -5],
        [ 1,  3,  9,  ...,  5,  5,  5],
        [ 2,  1, -2,  ...,  6, -5,  0]], dtype=torch.int8)


# Compare the dequantized weights with the original weights

In [29]:
# Float weights of the first layer of the original model
print("Original weight")
print(net.linear1.weight)

# The same layer's quantized weights mapped back to floats; the difference
# from the originals is the error introduced by quantization
print("Dequantized weights")
print(torch.dequantize(net_quantized.linear1.weight()))

Original weight
Parameter containing:
tensor([[-0.0127,  0.0067, -0.0419,  ...,  0.0095, -0.0087, -0.0104],
        [-0.0197, -0.0149, -0.0104,  ..., -0.0202, -0.0059, -0.0299],
        [ 0.0199,  0.0550,  0.0068,  ...,  0.0197,  0.0413,  0.0481],
        ...,
        [ 0.0278,  0.0316, -0.0031,  ..., -0.0084,  0.0108, -0.0261],
        [ 0.0056,  0.0137,  0.0458,  ...,  0.0261,  0.0261,  0.0256],
        [ 0.0102,  0.0049, -0.0093,  ...,  0.0271, -0.0221, -0.0020]],
       requires_grad=True)
Dequantized weights
tensor([[-0.0146,  0.0049, -0.0437,  ...,  0.0097, -0.0097, -0.0097],
        [-0.0194, -0.0146, -0.0097,  ..., -0.0194, -0.0049, -0.0291],
        [ 0.0194,  0.0534,  0.0049,  ...,  0.0194,  0.0388,  0.0486],
        ...,
        [ 0.0291,  0.0291, -0.0049,  ..., -0.0097,  0.0097, -0.0243],
        [ 0.0049,  0.0146,  0.0437,  ...,  0.0243,  0.0243,  0.0243],
        [ 0.0097,  0.0049, -0.0097,  ...,  0.0291, -0.0243,  0.0000]])


# Print size and accuracy of the quantized model

In [30]:
# Serialized size of the converted model, to compare with the float baseline
print("size of the model after quantization")
print_size_of_model(net_quantized)
# Accuracy of the converted model on the test loader
# NOTE(review): "Tesing" in the message below is a typo for "Testing"
print("Tesing the model after quantization")
test(net_quantized)

size of the model after quantization
size (KB):  95.394
Tesing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:04<00:00, 200.45it/s]

Accuracy: 0.961



