In [None]:
import itertools

import torch
import torch.nn as nn
import torch.quantization
from torchvision import datasets, transforms, models

# ✅ Device setup
# The quantized model below is built for the fbgemm CPU backend, so everything stays on CPU.
# `device` is informational only: nothing in this script calls `.to(device)`.
device = torch.device("cpu")
print(f"Using: {device}")

# ✅ Preprocessing matched to the pretrained weights
# The ImageNet-pretrained MobileNetV2 weights expect inputs normalized with the
# ImageNet per-channel statistics below. Without Normalize, calibration observes
# (and the model predicts on) raw [0, 1] pixels the network was never trained on.
# Resize(128) keeps this quick test fast. The weights were trained at 224 px,
# so predictions at 128 px are expected to be somewhat less reliable.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize(128),   # upscale 32x32 CIFAR images (shorter side -> 128)
    transforms.ToTensor(),    # PIL image -> float CHW tensor in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# ✅ Tiny CIFAR-10 subset for quick testing
# A fixed seed makes the randomly chosen subset (and therefore the calibration
# data and every number printed below) reproducible across runs.
SUBSET_SIZE = 1000
SPLIT_SEED = 0
testset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
subset_size = min(SUBSET_SIZE, len(testset))  # never request more samples than exist
test_subset, _ = torch.utils.data.random_split(
    testset,
    [subset_size, len(testset) - subset_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
testloader = torch.utils.data.DataLoader(test_subset, batch_size=32, shuffle=False)

# ✅ Load the quantization-ready MobileNetV2 model (float weights; quantize=False)
model_fp32 = models.quantization.mobilenet_v2(weights="DEFAULT", quantize=False)
# Eval mode must be set before fusion so BatchNorm uses its running statistics.
model_fp32.eval()

# ✅ Set quantization backend
# fbgemm targets x86 CPUs. Builds without it (e.g. ARM / Apple Silicon) offer qnnpack instead.
# Choose once and reuse the same name for the qconfig so the two can never disagree.
_supported_engines = torch.backends.quantized.supported_engines
QUANT_BACKEND = "fbgemm" if "fbgemm" in _supported_engines else "qnnpack"
torch.backends.quantized.engine = QUANT_BACKEND

# ✅ Fuse conv/bn/relu sequences in place (required before eager-mode static quantization)
model_fp32.fuse_model()

# ✅ Define quantization config (must match the engine selected above)
model_fp32.qconfig = torch.quantization.get_default_qconfig(QUANT_BACKEND)

# ✅ Prepare and calibrate
# prepare() (default inplace=False) returns a copy of model_fp32 with observers inserted.
# Running a few batches through it records the activation ranges used by convert().
NUM_CALIBRATION_BATCHES = 5
model_prepared = torch.quantization.prepare(model_fp32)
with torch.no_grad():
    # islice pulls only the batches needed. `list(testloader)[:5]` would first
    # load and transform every batch in the loader and then discard most of them.
    # NOTE(review): these calibration images are also evaluated below; consider a
    # disjoint split if the evaluation numbers need to be unbiased.
    for imgs, _ in itertools.islice(testloader, NUM_CALIBRATION_BATCHES):
        model_prepared(imgs)

# ✅ Convert to quantized version
# Replaces the observed float modules with int8 kernels on a fresh copy,
# so model_prepared itself is left untouched.
model_int8 = torch.quantization.convert(
    module=model_prepared,
    inplace=False,
)
print("✅ Model quantized successfully!")

# ✅ Evaluate quickly
# NOTE: the pretrained MobileNetV2 head scores 1000 ImageNet classes, while
# CIFAR-10 labels are 0-9. `preds == labels` therefore only counts coincidental
# index matches and is NOT a real accuracy. The meaningful check for post-training
# quantization here is top-1 agreement between the float (fused) model and the
# int8 model. A true CIFAR-10 accuracy needs a 10-way classifier head that has been
# fine-tuned on CIFAR-10 before quantizing.
correct, agree, total = 0, 0, 0
with torch.no_grad():
    for imgs, labels in testloader:
        preds_int8 = model_int8(imgs).argmax(1)
        # model_fp32 is still the fused float model: prepare()/convert() operated on copies.
        preds_fp32 = model_fp32(imgs).argmax(1)
        total += labels.size(0)
        correct += (preds_int8 == labels).sum().item()   # kept for compatibility; see NOTE
        agree += (preds_int8 == preds_fp32).sum().item()

if total == 0:
    print("⚠️ No samples evaluated — is the test subset empty?")
else:
    print(f"🎯 fp32 ↔ int8 top-1 agreement: {100 * agree / total:.2f}%")
    print(
        "   Label match vs CIFAR-10 targets (not a real accuracy, see NOTE): "
        f"{100 * correct / total:.2f}%"
    )
