In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.quantization
import os
import time
import numpy as np
import matplotlib.pyplot as plt

# ----- Enhanced CNN model -----
class EnhancedCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, 3),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(8, 16, 3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)
        return x

# ----- Enhanced CNN with Quant/DeQuant stubs -----
class QuantEnhancedCNN(EnhancedCNN):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.fc(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        # Store the current mode
        was_training = self.training
        
        # Ensure the model is in eval mode for fusion
        self.eval()
        
        # Fuse conv+bn+relu layers
        torch.quantization.fuse_modules(self.conv, [['0', '1', '2'], ['4', '5', '6']], inplace=True)
        
        # Restore the original training state
        if was_training:
            self.train()

# ----- Data Preparation -----
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# ----- Training Function -----
def train(model, epochs=5, learning_rate=0.001):
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    train_losses = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(trainloader)
        train_losses.append(avg_loss)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
    
    return train_losses

# ----- Evaluation -----
def evaluate(model):
    model.eval()
    correct = total = 0
    class_correct = [0] * 10
    class_total = [0] * 10
    
    with torch.no_grad():
        for inputs, labels in testloader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Per-class accuracy
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += (predicted[i] == label).item()
                class_total[label] += 1
    
    # Overall accuracy
    overall_accuracy = 100 * correct / total
    
    # Per-class accuracy
    class_accuracies = [100 * class_correct[i] / class_total[i] if class_total[i] > 0 else 0 for i in range(10)]
    
    return overall_accuracy, class_accuracies

# ----- Benchmark Inference -----
def benchmark(model, dummy_input, n=100):
    model.eval()
    # Warm-up runs
    with torch.no_grad():
        for _ in range(10):
            model(dummy_input)
    
    # Timing
    start = time.time()
    with torch.no_grad():
        for _ in range(n):
            model(dummy_input)
    end = time.time()
    
    # Calculate metrics
    avg_time = (end - start) / n * 1000  # ms/sample
    memory_usage = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
    
    return avg_time, memory_usage

# ----- Visualization Functions -----
def plot_comparisons(models, metrics):
    plt.figure(figsize=(15, 5))
    
    # Plotting function
    def subplot(position, data, title, ylabel):
        plt.subplot(1, 3, position)
        plt.bar(range(len(models)), data)
        plt.title(title)
        plt.xlabel('Model Type')
        plt.ylabel(ylabel)
        plt.xticks(range(len(models)), [m.__class__.__name__ for m in models])
    
    # Plot metrics
    subplot(1, metrics['sizes'], 'Model Size', 'Size (KB)')
    subplot(2, metrics['inference_times'], 'Inference Time', 'Time (ms/sample)')
    subplot(3, metrics['accuracies'], 'Model Accuracy', 'Accuracy (%)')
    
    plt.tight_layout()
    plt.savefig('quantization_comparison.png')
    plt.close()

# ----- Detailed Comparison -----
def detailed_comparison():
    # Initialize models
    model_fp32 = EnhancedCNN()
    model_ptq = QuantEnhancedCNN()
    model_qat = QuantEnhancedCNN()

    # 1. Train FP32 Model
    print("Training FP32 Model...")
    train_losses_fp32 = train(model_fp32)
    acc_fp32, class_acc_fp32 = evaluate(model_fp32)
    torch.save(model_fp32.state_dict(), "./model_fp32.pth")

    # 2. Post-Training Quantization (PTQ)
    print("\nPreparing Post-Training Quantization...")
    # Load pre-trained weights
    model_ptq.load_state_dict(model_fp32.state_dict())
    
    # Set quantization configuration
    model_ptq.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    
    # Fuse model 
    model_ptq.fuse_model()
    
    # Prepare for quantization
    torch.quantization.prepare(model_ptq, inplace=True)
    
    # Calibration
    for inputs, _ in trainloader:
        model_ptq(inputs)
        break
    
    # Convert to quantized model
    torch.quantization.convert(model_ptq, inplace=True)
    acc_ptq, class_acc_ptq = evaluate(model_ptq)
    torch.save(model_ptq.state_dict(), "./model_ptq.pth")

    # 3. Quantization-Aware Training (QAT)
    print("\nQuantization-Aware Training...")
    # Load pre-trained weights
    model_qat.load_state_dict(model_fp32.state_dict())
    
    # Set QAT configuration
    model_qat.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    
    # Fuse model
    model_qat.fuse_model()
    
    # Prepare for Quantization-Aware Training
    torch.quantization.prepare_qat(model_qat, inplace=True)
    
    # Train the quantization-aware model
    train_losses_qat = train(model_qat, epochs=5)
    
    # Convert to quantized model
    torch.quantization.convert(model_qat, inplace=True)
    
    # Evaluate
    acc_qat, class_acc_qat = evaluate(model_qat)
    torch.save(model_qat.state_dict(), "./model_qat.pth")

    # Benchmark
    dummy_input = torch.randn(1, 1, 28, 28)
    t_fp32, mem_fp32 = benchmark(model_fp32, dummy_input)
    t_ptq, mem_ptq = benchmark(model_ptq, dummy_input)
    t_qat, mem_qat = benchmark(model_qat, dummy_input)

    # Compute model sizes
    size_fp32 = os.path.getsize("./model_fp32.pth") / 1024
    size_ptq = os.path.getsize("./model_ptq.pth") / 1024
    size_qat = os.path.getsize("./model_qat.pth") / 1024

    # Detailed Print Out
    print("\n=== Comprehensive Model Comparison ===")
    
    print("\n1. Model Sizes:")
    print(f"FP32 Model:   {size_fp32:.2f} KB")
    print(f"PTQ Model:    {size_ptq:.2f} KB")
    print(f"QAT Model:    {size_qat:.2f} KB")

    print("\n2. Inference Times (ms/sample):")
    print(f"FP32 Model:   {t_fp32:.2f} ms")
    print(f"PTQ Model:    {t_ptq:.2f} ms")
    print(f"QAT Model:    {t_qat:.2f} ms")

    print("\n3. Memory Usage:")
    print(f"FP32 Model:   {mem_fp32} bytes")
    print(f"PTQ Model:    {mem_ptq} bytes")
    print(f"QAT Model:    {mem_qat} bytes")

    print("\n4. Overall Accuracy:")
    print(f"FP32 Model:   {acc_fp32:.2f}%")
    print(f"PTQ Model:    {acc_ptq:.2f}%")
    print(f"QAT Model:    {acc_qat:.2f}%")

    print("\n5. Per-Class Accuracy:")
    print("Digit\tFP32\t\tPTQ\t\tQAT")
    for i in range(10):
        print(f"{i}:\t{class_acc_fp32[i]:.2f}%\t\t{class_acc_ptq[i]:.2f}%\t\t{class_acc_qat[i]:.2f}%")

    # Optional: Visualization
    metrics = {
        'sizes': [size_fp32, size_ptq, size_qat],
        'inference_times': [t_fp32, t_ptq, t_qat],
        'accuracies': [acc_fp32, acc_ptq, acc_qat]
    }
    plot_comparisons([model_fp32, model_ptq, model_qat], metrics)

# Run the comparison
detailed_comparison()

Training FP32 Model...
Epoch 1, Loss: 0.2360
Epoch 2, Loss: 0.0892
Epoch 3, Loss: 0.0728
Epoch 4, Loss: 0.0623
Epoch 5, Loss: 0.0567

Preparing Post-Training Quantization...

Quantization-Aware Training...
Epoch 1, Loss: 0.0493
Epoch 2, Loss: 0.0455
Epoch 3, Loss: 0.0411
Epoch 4, Loss: 0.0378
Epoch 5, Loss: 0.0359

=== Comprehensive Model Comparison ===

1. Model Sizes:
FP32 Model:   216.47 KB
PTQ Model:    63.63 KB
QAT Model:    63.63 KB

2. Inference Times (ms/sample):
FP32 Model:   53.00 ms
PTQ Model:    25.43 ms
QAT Model:    25.58 ms

3. Memory Usage:
FP32 Model:   0 bytes
PTQ Model:    0 bytes
QAT Model:    0 bytes

4. Overall Accuracy:
FP32 Model:   98.77%
PTQ Model:    98.73%
QAT Model:    98.90%

5. Per-Class Accuracy:
Digit	FP32		PTQ		QAT
0:	99.39%		99.49%		99.59%
1:	99.65%		99.56%		99.91%
2:	99.61%		99.61%		99.32%
3:	98.91%		99.01%		98.71%
4:	97.56%		97.96%		99.39%
5:	99.10%		99.10%		98.99%
6:	99.16%		99.16%		97.91%
7:	97.76%		97.57%		98.93%
8:	98.77%		98.56%		98.87%
9:	97.7