# Import the necessary libraries

In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# for reproducibility
torch.manual_seed(42)

<torch._C.Generator at 0x114c5e390>

# Loading the MNIST dataset

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# loading the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# creating a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# loading the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# creating a dataloader for the testing
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 4173076.61it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 24621075.98it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 2430906.96it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4387500.87it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






# Defining the model

In [4]:
class NeuralNetwork(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.ln1 = nn.Linear(28*28, hidden_size_1) 
        self.ln2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.ln3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.ln1(x))
        x = self.relu(self.ln2(x))
        x = self.ln3(x)
        x = self.dequant(x)
        return x

In [5]:
model = NeuralNetwork().to(device)

# Inserting min-max observers in the model

In [6]:
model.qconfig = torch.ao.quantization.default_qconfig
model.train()
model_quantized = torch.ao.quantization.prepare_qat(model) # Inserting observers
model_quantized

NeuralNetwork(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (ln1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (ln2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (ln3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Training the model

In [7]:
def train(train_loader, model, epochs=5, total_iterations_limit=None):
    cross_entropy = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        model.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(x.view(-1, 28*28))
            loss = cross_entropy(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
            
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (KB):', os.path.getsize("temp.p")/1e3)
    os.remove('temp.p')

In [8]:
train(train_loader, model_quantized, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:48<00:00, 122.78it/s, loss=0.222]


# Defining the testing loop

In [9]:
def test(model: nn.Module, total_iterations: int = None):
    
    correct = 0
    total = 0
    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')

# Checking the collected statistics during training

In [10]:
print(f'Statistics of the various layers')
model_quantized

Statistics of the various layers


NeuralNetwork(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (ln1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.5424702167510986, max_val=0.3144207000732422)
    (activation_post_process): MinMaxObserver(min_val=-53.354026794433594, max_val=35.39004898071289)
  )
  (ln2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.41298532485961914, max_val=0.3560512363910675)
    (activation_post_process): MinMaxObserver(min_val=-28.404935836791992, max_val=24.387720108032227)
  )
  (ln3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.4521658420562744, max_val=0.26482340693473816)
    (activation_post_process): MinMaxObserver(min_val=-32.44058609008789, max_val=21.72846794128418)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

# Quantizing the model using the statistics collected

In [11]:
model_quantized.eval()
net_quantized = torch.ao.quantization.convert(model_quantized)

In [12]:
print(f'Statistics of the various layers after quantization')
model_quantized

Statistics of the various layers after quantization


NeuralNetwork(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (ln1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.5424702167510986, max_val=0.3144207000732422)
    (activation_post_process): MinMaxObserver(min_val=-53.354026794433594, max_val=35.39004898071289)
  )
  (ln2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.41298532485961914, max_val=0.3560512363910675)
    (activation_post_process): MinMaxObserver(min_val=-28.404935836791992, max_val=24.387720108032227)
  )
  (ln3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.4521658420562744, max_val=0.26482340693473816)
    (activation_post_process): MinMaxObserver(min_val=-32.44058609008789, max_val=21.72846794128418)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [13]:
print('Testing the model after quantization')
test(model_quantized)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:03<00:00, 254.11it/s]

Accuracy: 0.956



