## Quantization Aware Training
 -  adding dummy modules that quantize/ dequantize the weights & activations
 -  doing this while training, makes the model to learn the quantization patterns
 -  This leads to improved inference accuracy.

 **This is used in modern LLM quantization / finetuning (e.g. QLoRA.)**

In [1]:
# imports
import os
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import torch.nn as nn

from tqdm import tqdm

_ = torch.manual_seed(0)


In [2]:
# define dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1037,), (0.3081))])
mnist_train = MNIST(root="./data/", train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10, shuffle=True)

mnist_test = MNIST(root='./data/', train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=10, shuffle=False)

# get gpu
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# define the quantization aware trainer
import torch.ao.quantization


class SimpleNetwork(nn.Module):
    def __init__(self, hidden_1=128, hidden_2=128, num_classes=10):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.linear_1 = nn.Linear(28 * 28, hidden_1)
        self.linear_2 = nn.Linear(hidden_1, hidden_2)
        self.linear_3 = nn.Linear(hidden_2, num_classes)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28 * 28)
        x = self.quant(x)
        x = self.relu(self.linear_1(x))
        x = self.relu(self.linear_2(x))
        x = self.linear_3(x)
        x = self.dequant(x)

        return x
    
model = SimpleNetwork()

In [4]:
model.qconfig = torch.ao.quantization.default_qconfig
model.train()

model_quantized = torch.ao.quantization.prepare_qat(model)
model_quantized

SimpleNetwork(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear_1): Linear(
    in_features=784, out_features=128, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear_2): Linear(
    in_features=128, out_features=128, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear_3): Linear(
    in_features=128, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

### Train the model

In [5]:
def train(train_loader, model, epochs=5, steps_per_epoch=None):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

    global_step = 0
    
    for epoch in range(epochs):
        model.train()

        loss_sum = 0
        step = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch + 1}')
        if steps_per_epoch is None:
            steps_per_epoch = data_iterator.total

        for data in data_iterator:
            step += 1
            global_step += 1
            x, y = data
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            # forward pass
            logits = model(x)
            # loss computation
            loss = loss_fn(logits, y)
            loss_sum += loss.item()

            avg_loss = loss_sum / step
            data_iterator.set_postfix({'loss': avg_loss})

            loss.backward()
            optimizer.step()

            if step == steps_per_epoch:
                # stopping the iterator
                break
    
    return model

def print_model_size(model):
    torch.save(model.state_dict(), "tmp.pth")
    print(f'Size (KB): {os.path.getsize("tmp.pth") / 1e3}')


def print_trainable_parameters(model):
    trainable_parameters = sum([param.numel() for param in model.parameters() if param.requires_grad])
    print(f'Trainable Parameters (K): {trainable_parameters / 1e3}')


MODEL_DIR = './models'
MODEL_FILENAME = f'{MODEL_DIR}/simplenn_qat.pth'

if os.path.exists(MODEL_FILENAME):
    model.load_state_dict(torch.load(MODEL_FILENAME))
    print(f"Found model at {MODEL_FILENAME}: Model loaded.")
else:
    if not os.path.exists(MODEL_DIR):
        os.makedirs(MODEL_DIR)
    
    print_trainable_parameters(model)
    train(train_loader, model_quantized, epochs=5)
    torch.save(model_quantized.state_dict(), MODEL_FILENAME)        

Trainable Parameters (K): 118.282


Epoch 1: 100%|█████████▉| 5999/6000 [00:41<00:00, 143.91it/s, loss=0.252]
Epoch 2: 100%|█████████▉| 5999/6000 [00:31<00:00, 190.64it/s, loss=0.107]
Epoch 3: 100%|█████████▉| 5999/6000 [00:31<00:00, 191.48it/s, loss=0.0737]
Epoch 4: 100%|█████████▉| 5999/6000 [00:31<00:00, 193.20it/s, loss=0.0557]
Epoch 5: 100%|█████████▉| 5999/6000 [00:32<00:00, 187.36it/s, loss=0.0444]


### Test loop

In [6]:
def test(model: nn.Module, device):
    hits = 0
    total = 0
    # accuracy = hits / total

    # set the model device
    model.to(device)
    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc="Testing"):
            img, label = data
            img, label = img.to(device), label.to(device)
            output = model(img)
            for idx, o in enumerate(output):
                if torch.argmax(o) == label[idx]:
                    hits += 1
                
                total += 1
    
    print(f'Accuracy: {(hits / total)}')
                

In [7]:
test(model_quantized, device)

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 504.07it/s]

Accuracy: 0.9759





In [8]:
print(f'Model quantization statistics: ')
print(model_quantized)

Model quantization statistics: 
SimpleNetwork(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.33657902479171753, max_val=2.909120559692383)
  )
  (linear_1): Linear(
    in_features=784, out_features=128, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.42874079942703247, max_val=0.28564441204071045)
    (activation_post_process): MinMaxObserver(min_val=-33.25809860229492, max_val=25.94108009338379)
  )
  (linear_2): Linear(
    in_features=128, out_features=128, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.37882840633392334, max_val=0.3766877055168152)
    (activation_post_process): MinMaxObserver(min_val=-26.20763397216797, max_val=21.723499298095703)
  )
  (linear_3): Linear(
    in_features=128, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.37869828939437866, max_val=0.21585747599601746)
    (activation_post_process): MinMaxObserver(min_val=-45.703155517578125, max_val=28.475431442260742)
  

In [9]:
model_quantized.eval()
# model_quantized.load_state_dict(model_trained.state_dict())

model_quantized = torch.ao.quantization.convert(model_quantized)

RuntimeError: Didn't find engine for operation quantized::linear_prepack NoQEngine