### Post-Training Quantization
- calibrates the scale and zero-point parameters
- tracks the statistics of each layer using observers

In [1]:
# add all the imports
import os
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn

from tqdm import tqdm


# Seed PyTorch's default random number generator with a fixed value; the
# generator object returned by manual_seed is discarded.
_ = torch.manual_seed(1337)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define the MNIST datasets, their data loaders, and the compute device.

# The pipeline gets its own name: rebinding `transforms` would shadow the
# imported torchvision.transforms module and make this cell fail when re-run.
# Normalize takes one mean and one std value per channel, each as a sequence.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

mnist_train = datasets.MNIST(root="./data/", train=True, download=True, transform=mnist_transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10, shuffle=True)

mnist_test = datasets.MNIST(root="./data/", train=False, download=True, transform=mnist_transform)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=10, shuffle=True)

# Use CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [3]:
class SimpleNN(nn.Module):
    """Three-layer fully connected network with ReLU after the first two layers.

    Args:
        hidden_size_1: Output width of the first linear layer.
        hidden_size_2: Output width of the second linear layer.
        num_classes: Output width of the final linear layer.
    """

    def __init__(self, hidden_size_1=128, hidden_size_2=128, num_classes=10):
        super().__init__()
        input_features = 28 * 28
        self.linear1 = nn.Linear(input_features, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, num_classes)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten ``img`` to rows of 28 * 28 values and return the layer-3 output."""
        flat = img.view(-1, 28 * 28)
        hidden_a = self.relu(self.linear1(flat))
        hidden_b = self.relu(self.linear2(hidden_a))
        return self.linear3(hidden_b)


In [4]:
# Build the network with its default layer sizes and move it to `device`.
model = SimpleNN().to(device)

#### Training the model

In [None]:
def train(train_loader, model, epochs=10):
    """Train ``model`` in place with Adam and cross-entropy loss.

    Each epoch iterates ``train_loader`` once under a tqdm progress bar whose
    postfix shows the mean loss of the batches seen so far in that epoch.

    Args:
        train_loader: Iterable yielding ``(img, label)`` batches.
        model: Module to optimise; its parameters are updated in place.
        epochs: Number of passes over ``train_loader``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

    for epoch_idx in range(epochs):
        model.train()
        running_loss = 0
        progress = tqdm(train_loader, desc=f'Epoch {epoch_idx + 1}')

        for batch_no, (img, label) in enumerate(progress, start=1):
            img = img.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            loss = criterion(model(img), label)

            # Report the running mean before this batch's weight update.
            running_loss += loss.item()
            progress.set_postfix(loss=running_loss / batch_no)

            loss.backward()
            optimizer.step()


def print_model_size(model):
    """Print the serialized size of ``model``'s state dict in kilobytes.

    The state dict is saved into a private temporary directory, so no file in
    the current working directory is created, overwritten, or removed, and the
    temporary file is cleaned up even if saving or measuring raises.

    Args:
        model: Module whose ``state_dict()`` is serialized with ``torch.save``.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as scratch_dir:
        # Same basename as before; only the containing directory changes.
        scratch_path = os.path.join(scratch_dir, 'temp_model.pt')
        torch.save(model.state_dict(), scratch_path)
        print(f'Size (KB): {os.path.getsize(scratch_path) / 1e3}')


def print_trainable_parameters(model):
    """Print how many trainable parameter elements ``model`` has, in thousands.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    element_total = sum(
        param.numel() for param in model.parameters() if param.requires_grad
    )
    print(f'Total trainable parameters (K): {element_total / 1e3}')


# Reuse a previously saved checkpoint when one exists; otherwise train the
# model from scratch and save its weights for later runs.
MODEL_FILENAME = './models/simplenn_no_quant.pt'
if os.path.exists(MODEL_FILENAME):
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint saved on a different device can still be loaded.
    model.load_state_dict(torch.load(MODEL_FILENAME, map_location=device))
    print("Loaded Model from disk")
else:
    # Create the checkpoint's parent directory. Creating a directory at
    # MODEL_FILENAME itself would make the torch.save call below fail.
    os.makedirs(os.path.dirname(MODEL_FILENAME), exist_ok=True)
    # print trainable parameters
    print_trainable_parameters(model)
    train(train_loader, model, epochs=5)
    torch.save(model.state_dict(), MODEL_FILENAME)

In [10]:
# exist_ok=True keeps this cell re-runnable once the directory already exists.
os.makedirs('models/', exist_ok=True)

#### Test loop

In [14]:
def test(model):
    """Evaluate ``model`` on ``test_loader`` and print its accuracy.

    The model is switched to eval mode and run without gradient tracking. A
    sample counts as correct when the index of the largest output value equals
    its label. Accuracy is printed as a fraction (correct / total).

    Args:
        model: Module returning one row of class scores per input sample.
    """
    correct = 0
    total = 0
    model.eval()

    with torch.no_grad():
        for img, label in tqdm(test_loader, desc="Testing"):
            img, label = img.to(device), label.to(device)
            # Feed the batch that was just moved to `device`.
            output = model(img)
            # Compare the whole batch at once instead of sample by sample.
            predictions = output.argmax(dim=1)
            correct += (predictions == label).sum().item()
            total += output.size(0)

    if total == 0:
        # An empty loader would otherwise raise ZeroDivisionError.
        print('Accuracy: n/a (no test samples)')
        return

    print(f'Accuracy: {(correct / total)}')
