In [86]:
import torch
import torch.quantization
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchsummary import summary
from convolution import CustomConv2D
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization import default_observer, default_per_channel_weight_observer
from torch.ao.quantization import QConfigMapping

In [87]:
class MNISTModel(nn.Module):
    def __init__(self, is_training=True):
        super(MNISTModel, self).__init__()

        if train_model:
            self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        else:
            self.conv1 = CustomConv2D(1, 32, kernel_size=3, stride=1, padding=1)
            self.conv2 = CustomConv2D(32, 64, kernel_size=3, stride=1, padding=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [88]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [89]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [90]:
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [91]:
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [92]:
model = MNISTModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print(model)
summary(model, (1, 28, 28))

MNISTModel(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=3136, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 28, 28]             320
         MaxPool2d-2           [-1, 32, 14, 14]               0
            Conv2d-3           [-1, 64, 14, 14]          18,496
         MaxPool2d-4             [-1, 64, 7, 7]               0
            Linear-5                  [-1, 128]         401,536
            Linear-6                   [-1, 10]           1,290
Total params: 421,642
Trainable params: 421,642
Non-trainable params: 0
---------------------------------------

In [93]:
def train_model(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    total_loss = 0
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

    print(f"Train Epoch [{epoch}], Loss: {total_loss / len(train_loader)}")

In [94]:
def test_model(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    avg_loss = test_loss / len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)

    print(f"Test set: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)")

    return accuracy

In [95]:
overall_test_accuracy = 0.0
for epoch in range(1, 6):
    train_model(model, device, train_loader, optimizer, criterion, epoch)
    test_accuracy = test_model(model, device, test_loader, criterion)
    overall_test_accuracy += test_accuracy

overall_test_accuracy /= 5
print(f"Overall Test Accuracy: {overall_test_accuracy:.2f}%")

Train Epoch [1], Loss: 0.26161004961171047
Test set: Average loss: 0.0632, Accuracy: 9795/10000 (97.95%)
Train Epoch [2], Loss: 0.06376786234372474
Test set: Average loss: 0.0407, Accuracy: 9863/10000 (98.63%)
Train Epoch [3], Loss: 0.04206291857234975
Test set: Average loss: 0.0387, Accuracy: 9867/10000 (98.67%)
Train Epoch [4], Loss: 0.030932307998312914
Test set: Average loss: 0.0441, Accuracy: 9858/10000 (98.58%)
Train Epoch [5], Loss: 0.025139690039956822
Test set: Average loss: 0.0307, Accuracy: 9894/10000 (98.94%)
Overall Test Accuracy: 98.55%


In [96]:
custom_qconfig_mapping = QConfigMapping().set_global(
    torch.ao.quantization.QConfig(
        activation=default_observer.with_args(quant_min=0, quant_max=255),
        weight=default_per_channel_weight_observer.with_args(quant_min=-128, quant_max=127)
    )
)

In [97]:
model_fp32 = MNISTModel().eval()
qconfig_mapping = get_default_qconfig_mapping()

In [98]:
x = torch.randn((1, 1, 28, 28), dtype=torch.float)
prepared_model = prepare_fx(model_fp32, custom_qconfig_mapping, example_inputs=x)

In [99]:
model.eval()
with torch.no_grad():
    for data, target in train_loader:
        data = data.to(device)
        prepared_model(data)

In [100]:
quantized_model = convert_fx(prepared_model)

In [101]:
test_accuracy = test_model(quantized_model, device, test_loader, criterion)
print(f"Quantized model test accuracy: {test_accuracy:.2f}%")

Test set: Average loss: 2.2993, Accuracy: 1199/10000 (11.99%)
Quantized model test accuracy: 11.99%
