## Understanding the Concepts of Quantization

In [1]:
import os
import tempfile
from pathlib import Path

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm

In [2]:
# Commonly quoted per-channel mean/std of MNIST pixel intensities, used to standardize inputs.
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
DATA_ROOT = './data'

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MNIST_MEAN, MNIST_STD)])

mnist_trainset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=20, shuffle=True)

mnist_testset = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)
# Order does not change full-test-set accuracy, but test(..., total_iterations=k) scores only the
# first k batches; a fixed order keeps such partial evaluations reproducible across runs.
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=False)

In [3]:
# (number of training samples, number of test samples)
tuple(map(len, (mnist_trainset, mnist_testset)))

(60000, 10000)

In [4]:
# Seed torch's global RNG before the model is built so weight init and shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Eager-mode quantized operators in PyTorch run on CPU, so the whole notebook stays on CPU.
device = 'cpu'

In [5]:
class VerySimpleNet(nn.Module):
    """Fully connected classifier for 28x28 images.

    Inputs are flattened to 784 features, passed through two hidden linear
    layers with ReLU activations, and mapped to 10 output scores by a final
    linear layer (no softmax is applied).
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        n_pixels = 28 * 28
        self.linear1 = nn.Linear(n_pixels, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        # One stateless ReLU module is shared by both hidden layers.
        self.relu = nn.ReLU()

    def forward(self, img):
        hidden = img.view(-1, 28 * 28)
        for layer in (self.linear1, self.linear2):
            hidden = self.relu(layer(hidden))
        return self.linear3(hidden)

In [6]:
# Float32 baseline model; Module.to() moves parameters in place and returns the same module.
net = VerySimpleNet()
net = net.to(device)

In [7]:
def train(train_loader, net, epochs= 5, total_iteration_limit= None, lr= 0.001):
    """Train ``net`` in place with Adam and cross-entropy loss.

    Args:
        train_loader: iterable of ``(x, y)`` batches; ``x`` is flattened to
            ``28*28`` features before the forward pass.
        net: model to optimize. Batches are moved to the device of its first
            parameter, so move the model with ``net.to(...)`` beforehand.
        epochs: number of passes over ``train_loader``.
        total_iteration_limit: optional cap on the number of optimizer steps
            summed over *all* epochs; training stops as soon as it is reached
            (a value <= 0 performs no steps).
        lr: Adam learning rate (default matches the previously hard-coded 0.001).
    """
    if total_iteration_limit is not None and total_iteration_limit <= 0:
        return

    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr= lr)
    # Follow the model's own device instead of a notebook-level global.
    model_device = next(net.parameters()).device

    total_iterations = 0

    for epoch in range(epochs):
        net.train()
        loss_sum = 0
        num_iteration = 0

        # The context manager closes the progress bar even on the early return below.
        with tqdm(train_loader, desc= f'epoch {epoch+1}') as data_iterator:
            if total_iteration_limit is not None:
                # The cap spans all epochs, so this epoch runs at most the remaining
                # steps, and never more batches than the loader provides.
                remaining = total_iteration_limit - total_iterations
                if data_iterator.total is None:
                    data_iterator.total = remaining
                else:
                    data_iterator.total = min(data_iterator.total, remaining)

            for x, y in data_iterator:
                num_iteration += 1
                total_iterations += 1
                x = x.to(model_device)
                y = y.to(model_device)

                optimizer.zero_grad()
                output = net(x.view(-1, 28*28))
                loss = cross_el(output, y)
                loss.backward()
                optimizer.step()

                # Running mean of the loss over this epoch, shown on the progress bar.
                loss_sum += loss.item()
                data_iterator.set_postfix(loss= loss_sum / num_iteration)

                if total_iteration_limit is not None and total_iterations >= total_iteration_limit:
                    return

def print_size_of_model(model):
    """Print the on-disk size of ``model.state_dict()`` in kilobytes (1 KB = 1000 bytes).

    The state dict is saved inside a private temporary directory that is
    removed automatically. Nothing in the working directory is created or
    overwritten, and no stray file is left behind if ``torch.save`` raises.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Same file name as before, so the measured size is unchanged.
        tmp_path = os.path.join(tmp_dir, "temp_delme.p")
        torch.save(model.state_dict(), tmp_path)
        print('size (KB) :', os.path.getsize(tmp_path) / 1e3)

# Cached float checkpoint. NOTE(review): delete this file after changing the architecture,
# otherwise the stale weights are loaded instead of retraining.
MODEL_FILENAME = 'simplenet_ptq.pt'

if Path(MODEL_FILENAME).exists():
    # map_location lets a checkpoint saved from another device load onto `device`;
    # weights_only restricts unpickling to tensors/containers, which is all a state_dict needs.
    net.load_state_dict(torch.load(MODEL_FILENAME, map_location= device, weights_only= True))
    print("loaded model from disk")
else:
    train(train_loader, net, epochs= 5)
    torch.save(net.state_dict(), MODEL_FILENAME)


loaded model from disk


In [8]:
def test(model: nn.Module, total_iterations: int = None, loader=None):
    """Print the top-1 accuracy of ``model``.

    Args:
        model: classifier that returns one row of class scores per sample.
            It is switched to eval mode.
        total_iterations: if given, stop after this many batches. At least one
            batch is always evaluated.
        loader: iterable of ``(x, y)`` batches; defaults to the global ``test_loader``.
    """
    if loader is None:
        loader = test_loader

    correct = 0
    total = 0

    model.eval()

    with torch.inference_mode(), tqdm(loader, desc= "testing") as progress:
        for iterations, (x, y) in enumerate(progress, start=1):
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            # Vectorized top-1 comparison instead of a per-sample Python loop.
            correct += (output.argmax(dim=1) == y).sum().item()
            total += y.size(0)
            if total_iterations is not None and iterations >= total_iterations:
                break

    if total == 0:
        # An empty loader would otherwise raise ZeroDivisionError.
        print('Accuracy : n/a (no samples evaluated)')
        return
    print(f'Accuracy : {round(correct/total, 3)}')

In [9]:
# First-layer weights of the float model, for comparison with the int8 version later.
w1 = net.linear1.weight
print('weights before quantization ')
print(w1)
print(w1.dtype)

weights before quantization 
Parameter containing:
tensor([[ 0.0607,  0.0424,  0.0059,  ...,  0.0174,  0.0488,  0.0213],
        [-0.0083,  0.0153,  0.0304,  ...,  0.0152,  0.0430,  0.0030],
        [ 0.0363,  0.0573,  0.0394,  ...,  0.0461,  0.0401,  0.0088],
        ...,
        [ 0.0441,  0.0173,  0.0452,  ...,  0.0729,  0.0238,  0.0820],
        [ 0.0039,  0.0593,  0.0658,  ...,  0.0267, -0.0009,  0.0061],
        [ 0.0490,  0.0488,  0.0113,  ...,  0.0463,  0.0136,  0.0208]],
       requires_grad=True)
torch.float32


In [10]:
# Baseline: serialized size of the float model's state_dict.
print("size of the model before quantisation")
print_size_of_model(model=net)

size of the model before quantisation
size (KB) : 360.998


In [11]:
# Baseline: accuracy of the float model on the test loader.
print("accuracy of model before quantization")
test(model=net)

accuracy of model before quantization


testing: 100%|█████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 444.37it/s]

Accuracy : 0.974





In [12]:
class QuantizedVerySimpleNet(nn.Module):
    """Quantization-ready version of VerySimpleNet for eager-mode post-training quantization.

    The layer attribute names (``linear1``..``linear3``) match VerySimpleNet, so a
    trained VerySimpleNet state_dict can be loaded directly. ``quant`` and
    ``dequant`` are stubs that eager-mode ``prepare``/``convert`` turn into
    observers and then real quantize/dequantize ops. Before conversion they leave
    the data unchanged, so the float forward pass matches VerySimpleNet.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        # torch.ao.quantization is the namespace used by prepare/convert later in the
        # notebook; torch.quantization is only a legacy alias for it.
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28*28)
        # Float -> quantized boundary (identity until the model is converted).
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        # Quantized -> float boundary, so callers always receive float scores.
        x = self.dequant(x)
        return x

In [13]:
# Quantization-ready copy of the trained float model, in eval mode.
net_quantised = QuantizedVerySimpleNet().to(device).eval()
net_quantised.load_state_dict(net.state_dict())

# The default qconfig attaches MinMaxObservers (see the printed modules below).
net_quantised.qconfig = torch.ao.quantization.default_qconfig
# prepare() inserts the observers that record activation ranges during calibration.
net_quantised = torch.ao.quantization.prepare(net_quantised)
net_quantised

QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [14]:
# Calibration pass: feeding data through the prepared model lets its observers record
# activation min/max values. NOTE(review): test() iterates test_loader, so this calibrates on
# the evaluation data; a subset of the training data would avoid that leakage.
test(model=net_quantised)

testing: 100%|█████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 340.73it/s]

Accuracy : 0.974





In [15]:
heading = "check statistics of various layers"
print(heading)
# After calibration each MinMaxObserver reports the activation range it saw.
net_quantised

check statistics of various layers


QuantizedVerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-65.33385467529297, max_val=34.708717346191406)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-47.84341812133789, max_val=32.8056640625)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-57.502113342285156, max_val=31.2893123626709)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [16]:
# Swap observed float modules for quantized ones, using the calibrated ranges.
net_quantised = torch.ao.quantization.convert(net_quantised, inplace=False)

In [17]:
heading = "statistics of various layers"
print(heading)
# Converted modules now carry their own scale / zero_point.
net_quantised

statistics of various layers


QuantizedVerySimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=0.7877367734909058, zero_point=83, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.6350321173667908, zero_point=75, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.6991450786590576, zero_point=82, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

In [18]:
print("weights after the quantization")
# On the converted layer, `weight` is a method returning a quantized tensor;
# int_repr exposes the stored integer values.
quantized_w1 = net_quantised.linear1.weight()
print(torch.int_repr(quantized_w1))

weights after the quantization
tensor([[ 8,  6,  1,  ...,  2,  7,  3],
        [-1,  2,  4,  ...,  2,  6,  0],
        [ 5,  8,  5,  ...,  6,  6,  1],
        ...,
        [ 6,  2,  6,  ..., 10,  3, 11],
        [ 1,  8,  9,  ...,  4,  0,  1],
        [ 7,  7,  2,  ...,  6,  2,  3]], dtype=torch.int8)


In [19]:
# Side by side: the float weights and the int8 weights mapped back to float,
# to show the rounding introduced by quantization.
print("original weight :")
print(net.linear1.weight)
print()
print("Dequantized weights")
print(torch.dequantize(net_quantised.linear1.weight()))

orginal weight :
Parameter containing:
tensor([[ 0.0607,  0.0424,  0.0059,  ...,  0.0174,  0.0488,  0.0213],
        [-0.0083,  0.0153,  0.0304,  ...,  0.0152,  0.0430,  0.0030],
        [ 0.0363,  0.0573,  0.0394,  ...,  0.0461,  0.0401,  0.0088],
        ...,
        [ 0.0441,  0.0173,  0.0452,  ...,  0.0729,  0.0238,  0.0820],
        [ 0.0039,  0.0593,  0.0658,  ...,  0.0267, -0.0009,  0.0061],
        [ 0.0490,  0.0488,  0.0113,  ...,  0.0463,  0.0136,  0.0208]],
       requires_grad=True)

Dequantized weights
tensor([[ 0.0574,  0.0431,  0.0072,  ...,  0.0144,  0.0502,  0.0215],
        [-0.0072,  0.0144,  0.0287,  ...,  0.0144,  0.0431,  0.0000],
        [ 0.0359,  0.0574,  0.0359,  ...,  0.0431,  0.0431,  0.0072],
        ...,
        [ 0.0431,  0.0144,  0.0431,  ...,  0.0718,  0.0215,  0.0789],
        [ 0.0072,  0.0574,  0.0646,  ...,  0.0287,  0.0000,  0.0072],
        [ 0.0502,  0.0502,  0.0144,  ...,  0.0431,  0.0144,  0.0215]])


In [20]:
# Serialized size of the converted model, to compare with the float baseline.
print("size of model after quantization")
print_size_of_model(model=net_quantised)

size of model after quantization
size (KB) : 95.394


In [21]:
# Accuracy of the converted model on the same test loader as the float baseline.
print("testing model after quantization")
test(model=net_quantised)

testing model after quantization


testing: 100%|█████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 394.62it/s]

Accuracy : 0.974



