### Post Training Quantization 
- calibration for the scale and zeropoint parameters
- tracks the statistics of a layer using observers

In [1]:
# add all the imports
import os
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn

from tqdm import tqdm


_ = torch.manual_seed(1337)

In [2]:
# define the dataset

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081))])
mnist_train = datasets.MNIST(root="./data/", train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10, shuffle=True)

mnist_test = datasets.MNIST(root="./data/", train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=10, shuffle=True)

# no gpu on mac
device='cuda' if torch.cuda.is_available() else 'cpu'


In [3]:
# Define a simple neural network
class SimpleNN(nn.Module):
    def __init__(self, hidden_size_1=128, hidden_size_2=128, num_classes=10):
        super(SimpleNN, self).__init__()
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, num_classes)
        self.relu = nn.ReLU()
    
    def forward(self, img):
        x = img.view(-1, 28 * 28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)

        return x


In [4]:
model = SimpleNN().to(device)

#### Training the model

In [5]:
def train(train_loader, model, epochs=10):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    
    global_step = 0
    for epoch in range(epochs):
        model.train()
        step = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch + 1}')
        loss_sum = 0
        for batch in data_iterator:
            step += 1
            global_step += 1

            img, label = batch
            img, label = img.to(device), label.to(device)

            optimizer.zero_grad()
            output = model(img)
            loss = loss_fn(output, label)
            loss_sum += loss.item()
            avg_loss = loss_sum / step
            data_iterator.set_postfix(loss=avg_loss)

            loss.backward()
            optimizer.step()


def print_model_size(model):
    torch.save(model.state_dict(), 'temp_model.pt')
    print(f'Size (KB): {os.path.getsize("temp_model.pt",) / 1e3}')
    os.remove("temp_model.pt")


def print_trainable_parameters(model):
    trainable_parameters = [torch.numel(p) for p in model.parameters() if p.requires_grad]
    print(f'Total trainable parameters (K): {sum(trainable_parameters) / 1e3}')

MODEL_PATH = './models/'
MODEL_FILENAME = f'.{MODEL_PATH}/simplenn_no_quant.pt'
if os.path.exists(MODEL_FILENAME):
    model.load_state_dict(torch.load(MODEL_FILENAME))
    print("Loaded Model from disk")
else:
    if not os.path.exists(MODEL_PATH):
        os.makedirs(MODEL_PATH)
    # print trainable parameters
    print_trainable_parameters(model)
    train(train_loader, model, epochs=5)
    torch.save(model.state_dict(), MODEL_FILENAME)

Loaded Model from disk


#### Test loop

In [6]:
def test(model: nn.Module, device):
    correct = 0
    total = 0

    iterations = 0
    model.to(device)
    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc="Testing"):
            img, label = data
            img, label = img.to(device), label.to(device)
            output = model(img)
            for idx, o in enumerate(output):
                if torch.argmax(o) == label[idx]:
                    correct += 1
                total += 1
    
    print(f'Accuracy: {(correct / total)}')


In [7]:
test(model, device=device)

Testing: 100%|██████████| 1000/1000 [00:00<00:00, 1558.75it/s]

Accuracy: 0.9781





In [8]:
# Check the weight size of the model 
print(f'Size of the model before Quantization: ')
print_model_size(model)

Size of the model before Quantization: 
Size (KB): 475.814


### Adding the quantization observers in the model
 - Calibrating the model activations for quantization
 - Use test data for the calibration

In [9]:
class SimpleQuantNN(nn.Module):
    def __init__(self, hidden_size_1=128, hidden_size_2=128, num_classes=10):
        super(SimpleQuantNN, self).__init__()
        # adding quant 
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, num_classes)
        self.relu = nn.ReLU()
        # adding dequant
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28 * 28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        
        return x

In [12]:
# defining the model
quant_model = SimpleQuantNN().to(device)
# We have to load the weights from the trained model
quant_model.load_state_dict(model.state_dict())
quant_model.eval()  # ?

quant_model.qconfig = torch.ao.quantization.default_qconfig
quant_model = torch.ao.quantization.prepare(quant_model)

quant_model


SimpleQuantNN(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=128, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=128, out_features=128, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=128, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [13]:
#######################
# CALIBRATE the MODEL #
#######################

test(quant_model, device='cpu')


Testing: 100%|██████████| 1000/1000 [00:01<00:00, 683.80it/s]

Accuracy: 0.9781





In [14]:
print(f'Check the stats: ')
quant_model

# NOTE: The observer stats of the model are filled with the MinMax stats instead of the -inf and the inf values before.

Check the stats: 


SimpleQuantNN(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=128, bias=True
    (activation_post_process): MinMaxObserver(min_val=-40.58298873901367, max_val=30.44935417175293)
  )
  (linear2): Linear(
    in_features=128, out_features=128, bias=True
    (activation_post_process): MinMaxObserver(min_val=-33.20005798339844, max_val=28.273630142211914)
  )
  (linear3): Linear(
    in_features=128, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-60.843292236328125, max_val=40.28129959106445)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

#### Quantize the calibrated model
 

In [15]:
quant_model = torch.ao.quantization.convert(quant_model)

print(f"Check the stats: ")
quant_model

# NOTE: Normal layers are converted into QuantizedLinear layers and they are quantized.
        # - per_tensor_affine quantization is used

Check the stats: 


SimpleQuantNN(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=128, scale=0.5593097805976868, zero_point=73, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=128, out_features=128, scale=0.48404476046562195, zero_point=69, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=128, out_features=10, scale=0.7962566018104553, zero_point=76, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

In [16]:
print(f'Size of the quantized model:') 
print_model_size(quant_model)

Size of the quantized model:
Size (KB): 124.322


In [17]:
## Look at the model performance # run on cpu
print(f'Accuracy after quantization: ')
test(quant_model, device='cpu')

#NOTE: There is a very minimal drop in the accuracy of the model, 
# - Accuracy: 0.9779

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 705.04it/s]

Accuracy: 0.9779



