In [None]:
import torch
from torch import nn
from torch import optim
from torch.ao import quantization
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms
from torchvision.datasets import ImageFolder


class QuantizedVGG16(nn.Module):
    """Wrap a float model between a QuantStub and a DeQuantStub.

    The stubs mark where eager-mode quantization should quantize the input
    and dequantize the output. Until the wrapper is prepared/converted, both
    stubs pass tensors through unchanged, so the wrapper behaves exactly like
    the wrapped model.

    Args:
        model_fp32: the float ``nn.Module`` to wrap (registered as
            ``self.model_fp32``).
    """

    def __init__(self, model_fp32):
        super().__init__()
        # Entry point: float input -> quantized input after conversion.
        self.quant = quantization.QuantStub()
        # Exit point: quantized output -> float output after conversion.
        self.dequant = quantization.DeQuantStub()
        self.model_fp32 = model_fp32

    def forward(self, x):
        """Run ``x`` through quant stub, wrapped model, then dequant stub."""
        quantized_input = self.quant(x)
        wrapped_output = self.model_fp32(quantized_input)
        return self.dequant(wrapped_output)


# Training configuration and the preprocessing pipeline shared by the
# train and test splits.
hyperparams = {
    "batch_size": 4,
    "learning_rate": 0.0001,
    "epochs": 5,
    # Resize the short side to 256, center-crop to 224x224, convert to a
    # tensor, then normalize each channel.
    "transform": transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # NOTE(review): this mean equals (123, 117, 104) / 255, which is
            # slightly different from the [0.485, 0.456, 0.406] torchvision
            # documents for its ImageNet weights; confirm it is intentional.
            transforms.Normalize(
                mean=[0.48235, 0.45882, 0.40784],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    ),
}

train_dataset = ImageFolder("../datasets/pet/train", transform=hyperparams["transform"])
test_dataset = ImageFolder("../datasets/pet/test", transform=hyperparams["transform"])

train_dataloader = DataLoader(train_dataset, batch_size=hyperparams["batch_size"], shuffle=True, drop_last=True)
# Evaluate in a fixed order. Previously the test loader also shuffled, and
# with drop_last=True a different random set of test images was discarded on
# every run, making the reported accuracy non-deterministic.
# NOTE(review): drop_last=True still skips up to batch_size - 1 test images;
# consider drop_last=False if the evaluation counts the images it actually sees.
test_dataloader = DataLoader(test_dataset, batch_size=hyperparams["batch_size"], shuffle=False, drop_last=True)

# Pretrained VGG16 whose last classifier layer is replaced by one sized to the
# number of classes found in the training dataset.
model = models.vgg16(weights="VGG16_Weights.IMAGENET1K_V1")
model.classifier[6] = nn.Linear(4096, len(train_dataset.classes))

quantization_backend = "fbgemm"
device = "cuda" if torch.cuda.is_available() else "cpu"
quantized_model = QuantizedVGG16(model).to(device)
quantized_model.qconfig = quantization.get_default_qat_qconfig(quantization_backend)
# prepare_qat defaults to inplace=False and returns a prepared *copy*. The
# original call discarded that return value, so training ran on the plain
# float model with no fake-quantization. Prepare in place and rebind the name
# so the rest of the script uses the QAT-prepared model.
quantized_model = quantization.prepare_qat(quantized_model, inplace=True)

criterion = nn.CrossEntropyLoss().to(device)
# Create the optimizer only after preparation so it tracks the parameters of
# the model that is actually trained.
optimizer = optim.SGD(quantized_model.parameters(), lr=hyperparams["learning_rate"])

# Quantization-aware training loop: one pass over train_dataloader per epoch,
# printing the mean per-batch loss.
for epoch in range(hyperparams["epochs"]):
    # Ensure training mode each epoch, so re-running this cell after the
    # evaluation cell (which calls .eval()) still trains correctly.
    quantized_model.train()
    cost = 0.0

    for images, classes in train_dataloader:
        images = images.to(device)
        classes = classes.to(device)

        output = quantized_model(images)
        loss = criterion(output, classes)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a Python float. Adding the loss tensor itself would chain
        # autograd history across every batch of the epoch.
        cost += loss.item()

    # Guard against an empty loader (e.g. fewer images than one batch with
    # drop_last=True) instead of raising ZeroDivisionError.
    cost = cost / max(len(train_dataloader), 1)
    print(f"Epoch : {epoch+1:4d}, Cost : {cost:.3f}")

# Top-1 accuracy of the QAT model on the test set, then conversion to a real
# quantized model and TorchScript export.
with torch.no_grad():
    quantized_model.eval()

    # Number of correct top-1 predictions.
    accuracy = 0.0
    # Number of test images actually evaluated. This keeps the denominator
    # right regardless of the loader's batch/drop_last settings.
    num_samples = 0
    for images, classes in test_dataloader:
        images = images.to(device)
        classes = classes.to(device)

        outputs = quantized_model(images)
        # The argmax of the logits gives the top-1 class directly; a softmax first is unnecessary.
        outputs_classes = torch.argmax(outputs, dim=-1)

        accuracy += int(torch.eq(classes, outputs_classes).sum())
        num_samples += classes.size(0)

    print(f"acc@1 : {accuracy / max(num_samples, 1) * 100:.2f}%")

# Move to CPU before conversion; the converted model targets the CPU
# quantization backend selected above.
quantized_model = quantized_model.to("cpu")
# convert also defaults to inplace=False and returns a new module. The
# original call discarded it, so the saved file held the unconverted model.
quantized_model = quantization.convert(quantized_model)
torch.jit.save(torch.jit.script(quantized_model), "QAT_VGG16.pt")