# In [None]:
import torch
from torch import nn
from torch.ao import quantization
from torchvision import models
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder


class QuantizedVGG16(nn.Module):
    """Wrap a float model between a quantization stub and a dequantization stub.

    The input is passed through ``QuantStub``, then the wrapped model, then
    ``DeQuantStub``. Until the module has been prepared and converted by
    ``torch.ao.quantization`` the stubs leave tensors unchanged, so the wrapper
    returns the same values as the wrapped model.

    Args:
        model_fp32: The float model to wrap; stored as ``self.model_fp32``.
    """

    def __init__(self, model_fp32):
        super().__init__()
        # Submodules are registered in the order quant, dequant, model_fp32.
        self.quant = quantization.QuantStub()
        self.dequant = quantization.DeQuantStub()
        self.model_fp32 = model_fp32

    def forward(self, x):
        """Return ``dequant(model_fp32(quant(x)))``."""
        return self.dequant(self.model_fp32(self.quant(x)))

# Settings for this script. Only "batch_size" and "transform" are read below;
# "learning_rate" and "epochs" are not used in this section.
hyperparams = {
    "batch_size": 4,
    "learning_rate": 0.0001,
    "epochs": 5,
    "transform": transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # NOTE(review): presumably these statistics match the ones used
            # when the checkpoint was trained -- confirm against the training
            # script.
            transforms.Normalize(
                mean=[0.48235, 0.45882, 0.40784],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    ),
}

# Number of batches fed through the observers during calibration.
num_calibration_batches = 10

# Rebuild the 2-class VGG16 and restore its trained float weights.
model = models.vgg16(num_classes=2)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a machine
# without CUDA; the model is moved to `device` right below.
model.load_state_dict(torch.load("../models/VGG16.pt", map_location="cpu"))

device = "cuda" if torch.cuda.is_available() else "cpu"
quantized_model = QuantizedVGG16(model).to(device)
# Post-training static quantization must observe inference-time activations:
# eval mode switches off train-time behaviour (e.g. dropout) so the collected
# statistics match what the deployed model will see.
quantized_model.eval()

quantization_backend = "fbgemm"
quantized_model.qconfig = quantization.get_default_qconfig(quantization_backend)

# Insert observers that record activation ranges during calibration.
model_static_quantized = quantization.prepare(quantized_model)

calibration_dataset = ImageFolder(
    "../datasets/pet/test",
    transform=hyperparams["transform"]
)
calibration_dataloader = DataLoader(
    calibration_dataset,
    batch_size=hyperparams["batch_size"]
)

# Calibration: forward passes only, so autograd bookkeeping is disabled to
# avoid building graphs that are never used.
with torch.no_grad():
    for i, (images, _) in enumerate(calibration_dataloader):
        if i >= num_calibration_batches:
            break
        model_static_quantized(images.to(device))

# The observed model is moved back to the CPU before conversion, where the
# quantized (fbgemm) kernels execute.
model_static_quantized.to("cpu")
model_static_quantized = quantization.convert(model_static_quantized)

# Persist the converted model as TorchScript.
torch.jit.save(torch.jit.script(model_static_quantized), "../models/PTSQ_VGG16.pt")