<a href="https://colab.research.google.com/github/testgithubprecious/Ml_projects/blob/main/QAT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# Install if not already: pip install torch torchvision

import torch
import torch.nn as nn
import torch.quantization
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ------------------------------
# Load MNIST (downloaded into ./data on first run)
# ------------------------------
# ToTensor converts each image into a float tensor scaled to [0, 1].
transform = transforms.ToTensor()
train_dataset, test_dataset = (
    datasets.MNIST('./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Shuffle only the training split; evaluation order does not matter.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=1000)

# ------------------------------
# Define quantizable MLP
# ------------------------------
class QuantizedMLP(nn.Module):
    """Three-layer MLP classifier wrapped in quant/dequant stubs for QAT.

    The submodule names ``fc1``/``relu1``/``fc2``/``relu2`` are referenced by
    name when fusing Linear+ReLU pairs, so keep them stable.

    Args:
        in_features: Number of features per sample after flattening
            (default 28*28, a flattened MNIST image).
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=28 * 28, hidden1=256, hidden2=128, num_classes=10):
        super().__init__()
        # QuantStub/DeQuantStub mark where tensors enter and leave the
        # quantized region once the model is prepared and converted.
        self.quant = torch.quantization.QuantStub()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden2, num_classes)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Return class logits for a batch ``x`` (first dimension is the batch).

        All non-batch dimensions are flattened, so e.g. (N, 1, 28, 28) and
        (N, 784) inputs both work. ``torch.flatten`` copies when necessary,
        so unlike ``view`` it also accepts non-contiguous tensors.
        """
        x = torch.flatten(x, 1)
        x = self.quant(x)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return self.dequant(x)

# ------------------------------
# Choose a compute device and build the model on it
# ------------------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = QuantizedMLP()
model = model.to(device)

# ------------------------------
# Fuse Linear+ReLU pairs (before prepare_qat)
# ------------------------------
# Each inner list names adjacent submodules to merge into a single fused
# module, so each pair is quantized as one op. Fusion is done before
# prepare_qat so the fused modules are the ones swapped for QAT versions.
torch.quantization.fuse_modules(
    model,
    modules_to_fuse=[['fc1', 'relu1'], ['fc2', 'relu2']],
    inplace=True,
)

# ------------------------------
# Prepare QAT configuration
# ------------------------------
# The qconfig's backend should match the engine that will run the converted
# int8 model. Use fbgemm (x86) when this PyTorch build supports it, otherwise
# fall back to qnnpack (e.g. ARM builds), and pin the runtime engine to the
# same backend so training-time fake-quant and int8 inference agree.
qat_backend = 'fbgemm' if 'fbgemm' in torch.backends.quantized.supported_engines else 'qnnpack'
torch.backends.quantized.engine = qat_backend
model.qconfig = torch.quantization.get_default_qat_qconfig(qat_backend)
# Swap the (fused) float modules for QAT modules with fake-quant. The model
# is still in training mode here (modules are constructed in train mode).
torch.quantization.prepare_qat(model, inplace=True)

# ------------------------------
# Loss and optimizer used during training
# ------------------------------
# Created after prepare_qat, so the optimizer holds the parameters of the
# QAT-prepared model.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

def train(model, loader, *, opt=None, loss_fn=None, dev=None):
    """Run one training epoch of ``model`` over ``loader``; return the mean loss.

    Args:
        model: Module to train; it is switched to training mode first.
        loader: Iterable of ``(images, labels)`` batches.
        opt: Optimizer to step. Defaults to the script-level ``optimizer``.
        loss_fn: Loss callable ``loss_fn(outputs, labels)``. Defaults to the
            script-level ``criterion``.
        dev: Device each batch is moved to. Defaults to the script-level
            ``device``.

    Returns:
        Per-sample mean loss over the epoch as a float, or 0.0 if ``loader``
        yields no batches. (Assumes ``loss_fn`` averages over the batch, as
        ``nn.CrossEntropyLoss`` does by default.)
    """
    # Fall back to the script globals so existing ``train(model, loader)``
    # calls keep working unchanged.
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    model.train()
    loss_sum = 0.0
    seen = 0
    for images, labels in loader:
        images, labels = images.to(dev), labels.to(dev)
        opt.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        opt.step()
        # Keep the running sum as a tensor to avoid a host sync every step;
        # weight by batch size so a short final batch is not over-counted.
        loss_sum = loss_sum + loss.detach() * labels.size(0)
        seen += labels.size(0)
    return float(loss_sum) / seen if seen else 0.0

# ------------------------------
# Train with QAT (forward passes use fake-quantized modules)
# ------------------------------
NUM_EPOCHS = 3  # kept short for a demo
for epoch in range(NUM_EPOCHS):
    train(model, train_loader)
    print(f"🧠 Epoch {epoch+1} complete (QAT)")

# ------------------------------
# Convert the trained QAT model to a real int8 model
# ------------------------------
# Eager-mode quantized kernels run on CPU, so move the model there and switch
# to eval mode before replacing fake-quant modules with quantized ones.
# inplace=False leaves the QAT model itself intact.
model.cpu()
model.eval()
quantized_model = torch.quantization.convert(model, inplace=False)

# ------------------------------
# Evaluate quantized model
# ------------------------------
def evaluate(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` on ``loader`` as a fraction in [0, 1].

    Args:
        model: Classifier whose output has class scores along dim 1.
        loader: Iterable of ``(images, labels)`` batches.
        device: If given, batches are moved here before the forward pass.
            Leave as ``None`` for the converted int8 model, which runs on CPU.

    Returns:
        Correct predictions divided by the number of samples actually seen,
        or 0.0 if ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for images, labels in loader:
            if device is not None:
                images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            seen += labels.size(0)
    # Divide by samples seen rather than len(loader.dataset): they differ with
    # drop_last or custom samplers, and plain iterables have no .dataset.
    return correct / seen if seen else 0.0

# Report top-1 accuracy of the int8 model on the MNIST test split.
acc = evaluate(quantized_model, loader=test_loader)
print(f"📦 Final Quantized Model Accuracy: {acc:.2%}")