# Model Optimization Techniques: Quantization and Knowledge Distillation

This notebook demonstrates a practical implementation of model compression techniques using:
- **Real Dataset**: CIFAR-10 for image classification
- **Real Models**: ResNet architectures with different sizes
- **Quantization**: Post-Training Quantization (PTQ), Quantization-Aware Training (QAT), and dynamic quantization
- **Knowledge Distillation**: Teacher-Student training that distills a ResNet-34 teacher into a smaller ResNet-18 student

## 📋 Table of Contents
1. [Setup and Data Loading](#setup)
2. [Baseline Models](#baseline)
3. [Quantization Techniques](#quantization)
4. [Knowledge Distillation](#distillation)
5. [Performance Comparison](#comparison)
6. [Production Export](#export)

---

## 1. Setup and Data Loading
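The `Normalize` transforms below subtract a fixed per-channel mean and divide by a per-channel standard deviation. As a minimal sketch of where such constants come from, the helper here reduces a `(N, C, H, W)` batch over every axis except the channel axis. It runs on random stand-in images; applying it to CIFAR-10 (for example by converting torchvision's `trainset.data`, a `(N, 32, 32, 3)` uint8 array, to channel-first floats in `[0, 1]`) is left as an assumption.

```python
import torch

def channel_mean_std(images: torch.Tensor):
    """Per-channel mean and std of a (N, C, H, W) batch with values in [0, 1]."""
    # Reduce over batch, height and width; keep the channel dimension
    mean = images.mean(dim=(0, 2, 3))
    std = images.std(dim=(0, 2, 3))
    return mean, std

# Random stand-in for CIFAR-10 images (assumption: real data would come from trainset.data / 255)
torch.manual_seed(0)
fake_images = torch.rand(256, 3, 32, 32)
mean, std = channel_mean_std(fake_images)

# Applying the statistics gives roughly zero mean and unit std per channel
normalized = (fake_images - mean[None, :, None, None]) / std[None, :, None, None]
print("mean:", mean, "std:", std)
```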

In [None]:
# Install required packages
!pip install torch torchvision
!pip install torchsummary
!pip install matplotlib seaborn

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.quantization as quant

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
import os
from collections import OrderedDict
import pandas as pd
from torchsummary import summary

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# CIFAR-10 Data Loading
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Download and load datasets
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Create data loaders
train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
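# For PTQ calibration: a training subset with test-time transforms, so the test set stays held out for evaluation
calibration_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_test)
calibration_loader = DataLoader(torch.utils.data.Subset(calibration_set, range(3200)), batch_size=32, shuffle=False, num_workers=2)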

# CIFAR-10 classes
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
num_classes = len(classes)

print(f"Training samples: {len(trainset)}")
print(f"Test samples: {len(testset)}")
print(f"Number of classes: {num_classes}")

# Visualize some samples
def show_samples(data_loader, num_samples=8):
    data_iter = iter(data_loader)
    images, labels = next(data_iter)
    
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for i in range(num_samples):
        img = images[i].numpy().transpose((1, 2, 0))
        img = img * np.array([0.2023, 0.1994, 0.2010]) + np.array([0.4914, 0.4822, 0.4465])
        img = np.clip(img, 0, 1)
        
        ax = axes[i//4, i%4]
        ax.imshow(img)
        ax.set_title(f'{classes[labels[i]]}')
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()

show_samples(test_loader)

## 2. Baseline Models

We'll create ResNet models of different sizes:
- **Teacher Model**: ResNet-34 (larger, more accurate)
- **Student Model**: ResNet-18 (smaller, faster)
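For intuition about the size gap, the parameter counts can be derived by hand: a bias-free `k×k` convolution has `k·k·C_in·C_out` weights, and each BatchNorm adds a learnable scale and shift per channel (its running statistics are buffers, not parameters). The sketch below applies this to the CIFAR-style ResNets defined next (3×3 stem, stage widths 64/128/256/512, 1×1 projection shortcuts). The helper names are illustrative, and the totals should agree with the `parameters()` counts printed by the next cell.

```python
def conv_bn_params(in_ch, out_ch, k):
    # Bias-free k x k convolution plus BatchNorm weight and bias
    return k * k * in_ch * out_ch + 2 * out_ch

def basic_block_params(in_planes, planes, stride):
    total = conv_bn_params(in_planes, planes, 3) + conv_bn_params(planes, planes, 3)
    if stride != 1 or in_planes != planes:
        total += conv_bn_params(in_planes, planes, 1)  # 1x1 projection shortcut
    return total

def cifar_resnet_params(num_blocks, num_classes=10):
    total = conv_bn_params(3, 64, 3)  # 3x3 stem convolution + BatchNorm
    in_planes = 64
    for planes, n, first_stride in zip((64, 128, 256, 512), num_blocks, (1, 2, 2, 2)):
        for i in range(n):
            stride = first_stride if i == 0 else 1
            total += basic_block_params(in_planes, planes, stride)
            in_planes = planes
    total += 512 * num_classes + num_classes  # final linear layer (weights + bias)
    return total

r18 = cifar_resnet_params([2, 2, 2, 2])
r34 = cifar_resnet_params([3, 4, 6, 3])
print(f"ResNet-18: {r18:,}  ResNet-34: {r34:,}  ratio: {r34 / r18:.2f}x")
```

Almost all of the extra parameters sit in the wide 256- and 512-channel stages, which is where ResNet-34 adds most of its additional blocks.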

In [None]:
# ResNet Building Blocks
from torch.ao.nn.quantized import FloatFunctional

class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
        
        # Residual add via FloatFunctional: eager-mode quantization can observe it and
        # swap in a quantized add, whereas a plain `+=` fails on quantized tensors
        self.skip_add = FloatFunctional()
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.skip_add.add(out, self.shortcut(x))
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)
        
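        # Quant/DeQuant stubs mark where tensors enter and leave the int8 region
        # (used by PTQ and QAT; they are identity ops in the float model)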
        self.quant = quant.QuantStub()
        self.dequant = quant.DeQuantStub()
    
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.quant(x)  # For quantization
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        out = self.dequant(out)  # For quantization
        return out

def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])

# Create models
teacher_model = ResNet34().to(device)
student_model = ResNet18().to(device)

print("\n=== Model Architectures ===")
print("\nTeacher Model (ResNet-34):")
teacher_params = sum(p.numel() for p in teacher_model.parameters())
print(f"Parameters: {teacher_params:,}")

print("\nStudent Model (ResNet-18):")
student_params = sum(p.numel() for p in student_model.parameters())
print(f"Parameters: {student_params:,}")
print(f"Parameter ratio (teacher/student): {teacher_params/student_params:.2f}x")

In [None]:
# Training and evaluation utilities
def train_model(model, train_loader, test_loader, epochs=20, lr=0.01):
    """Train a model with standard cross-entropy loss"""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    train_losses = []
    train_accuracies = []
    test_accuracies = []
    
    for epoch in range(epochs):
        # Training
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        
        scheduler.step()
        
        train_acc = 100. * correct / total
        train_loss = running_loss / len(train_loader)
        
        # Testing
        test_acc = evaluate_model(model, test_loader)
        
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        test_accuracies.append(test_acc)
        
        if epoch % 5 == 0:
            print(f'Epoch {epoch:2d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%')
    
    return train_losses, train_accuracies, test_accuracies

import copy
import io

def evaluate_model(model, test_loader, device=device):
    """Evaluate top-1 accuracy (%) with the model and inputs on `device`"""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    return 100. * correct / total

def measure_model_performance(model, test_loader, num_runs=50, device='cpu'):
    """Measure serialized size, single-image latency, and accuracy.
    
    Benchmarks a copy of the model on CPU by default: eager-mode quantized
    kernels only run on CPU, so float and int8 models are compared on the same device.
    """
    device = torch.device(device)
    model = copy.deepcopy(model).to(device).eval()
    
    # Model size (MB) from the serialized state dict; parameters() misses the
    # packed int8 weights that quantized modules store outside their parameters
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    size_mb = buffer.getbuffer().nbytes / (1024 ** 2)
    
    # Inference time for a single 32x32 image
    dummy_input = torch.randn(1, 3, 32, 32, device=device)
    
    with torch.no_grad():
        # Warmup
        for _ in range(10):
            model(dummy_input)
        
        # Measure
        if device.type == 'cuda':
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_runs):
            model(dummy_input)
        if device.type == 'cuda':
            torch.cuda.synchronize()
        end_time = time.perf_counter()
    
    avg_time_ms = (end_time - start_time) * 1000 / num_runs
    
    # Accuracy
    accuracy = evaluate_model(model, test_loader, device)
    
    return {
        'size_mb': size_mb,
        'inference_time_ms': avg_time_ms,
        'accuracy': accuracy,
        # Counts float parameters only; packed int8 weights are not included
        'parameters': sum(p.numel() for p in model.parameters())
    }

In [None]:
# Train Teacher Model (ResNet-34)
print("=== Training Teacher Model (ResNet-34) ===")
teacher_train_losses, teacher_train_acc, teacher_test_acc = train_model(
    teacher_model, train_loader, test_loader, epochs=25, lr=0.01
)

# Evaluate teacher performance
teacher_performance = measure_model_performance(teacher_model, test_loader)
print(f"\nTeacher Model Performance:")
print(f"  Size: {teacher_performance['size_mb']:.2f} MB")
print(f"  Inference Time: {teacher_performance['inference_time_ms']:.2f} ms")
print(f"  Accuracy: {teacher_performance['accuracy']:.2f}%")

# Save teacher model
torch.save(teacher_model.state_dict(), 'teacher_resnet34.pth')
print("\nTeacher model saved!")

## 3. Quantization Techniques

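We'll implement three quantization approaches (each produces an int8 model that runs on CPU in PyTorch's eager mode):
1. **Post-Training Quantization (PTQ)**: Static int8 quantization of a trained model, using a small calibration set to estimate activation ranges; no retraining
2. **Quantization-Aware Training (QAT)**: Fine-tuning with simulated ("fake") quantization so the weights adapt to int8 rounding
3. **Dynamic Quantization**: Weights are quantized ahead of time and activations on the fly at inference; it targets `nn.Linear` layers, which in this ResNet means only the final classifier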
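All three approaches rest on the same affine mapping between floats and 8-bit integers: `q = clamp(round(x / scale) + zero_point)` and `x ≈ (q - zero_point) * scale`. The observers that `prepare` inserts record range statistics during calibration, from which a scale and zero point are derived; QAT inserts this quantize-then-dequantize round trip ("fake quantization") into the forward pass during training. The sketch below derives the parameters from a plain min/max range for a signed int8 tensor, which is an assumption: PyTorch's default `fbgemm` configs use other observers (for example histogram-based ranges for activations and per-channel ranges for weights).

```python
import torch

def affine_qparams(x_min, x_max, qmin=-128, qmax=127):
    """Scale and zero point mapping [x_min, x_max] onto the integers [qmin, qmax]."""
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # include 0 so that 0.0 is exactly representable
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = int(round(qmin - x_min / scale))
    return scale, max(qmin, min(qmax, zero_point))

def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    return torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)

def dequantize(q, scale, zero_point):
    return (q.to(torch.float32) - zero_point) * scale

torch.manual_seed(0)
x = torch.randn(1000)
scale, zp = affine_qparams(x.min().item(), x.max().item())
q = quantize(x, scale, zp)
x_hat = dequantize(q, scale, zp)  # the "fake-quantized" values QAT trains against
print(f"scale={scale:.5f}, zero_point={zp}, max abs error={(x - x_hat).abs().max().item():.5f}")

# Cross-check against PyTorch's built-in per-tensor quantization
# (integer codes can differ by 1 where float rounding lands exactly on a .5 boundary)
q_ref = torch.quantize_per_tensor(x, scale, zp, torch.qint8)
code_diff = (q_ref.int_repr().int() - q.int()).abs().max().item()
print("max integer-code difference vs torch:", code_diff)
```

The rounding error of any in-range value is at most `scale / 2`, which is why a tighter range (smaller `scale`) generally means less quantization noise, as long as outliers are not clipped.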

In [None]:
# 1. Post-Training Quantization (PTQ), static int8
def apply_post_training_quantization(model, calibration_loader, num_batches=100):
    """Apply static post-training quantization (the result runs on CPU)"""
    # Float CPU copy of the trained model; eager-mode int8 kernels are CPU-only
    quantized_model = ResNet34()
    quantized_model.load_state_dict(model.state_dict())
    quantized_model.eval()
    
    # Default int8 config for the fbgemm (x86) backend
    quantized_model.qconfig = quant.get_default_qconfig('fbgemm')
    
    # Fuse the stem conv+bn (eval-mode fusion folds BN into the conv weights)
    torch.quantization.fuse_modules(quantized_model, [['conv1', 'bn1']], inplace=True)
    
    # Insert observers that record activation ranges
    quant.prepare(quantized_model, inplace=True)
    
    # Calibration: run representative CPU batches through the observed model
    print("Calibrating model for quantization...")
    with torch.no_grad():
        for i, (inputs, _) in enumerate(calibration_loader):
            if i >= num_batches:
                break
            quantized_model(inputs)
    
    # Convert observed modules into int8 quantized modules
    quant.convert(quantized_model, inplace=True)
    
    return quantized_model

# Apply PTQ to teacher model
print("=== Applying Post-Training Quantization ===")
ptq_model = apply_post_training_quantization(teacher_model, calibration_loader)
ptq_performance = measure_model_performance(ptq_model, test_loader)

print(f"\nPTQ Model Performance:")
print(f"  Size: {ptq_performance['size_mb']:.2f} MB")
print(f"  Inference Time: {ptq_performance['inference_time_ms']:.2f} ms")
print(f"  Accuracy: {ptq_performance['accuracy']:.2f}%")
print(f"  Compression Ratio: {teacher_performance['size_mb']/ptq_performance['size_mb']:.2f}x")
print(f"  Speedup: {teacher_performance['inference_time_ms']/ptq_performance['inference_time_ms']:.2f}x")
print(f"  Accuracy Drop: {teacher_performance['accuracy'] - ptq_performance['accuracy']:.2f}%")

In [None]:
# 2. Quantization-Aware Training (QAT)
def train_qat_model(model, train_loader, test_loader, epochs=10, lr=0.001):
    """Fine-tune a copy of the model with quantization-aware training, then convert it to int8"""
    # Start from a float copy of the trained model
    qat_model = ResNet34()
    qat_model.load_state_dict(model.state_dict())
    qat_model.train()
    
    # QAT config: fake-quantization modules for the fbgemm (x86) backend
    qat_model.qconfig = quant.get_default_qat_qconfig('fbgemm')
    
    # QAT fusion must use fuse_modules_qat on a model in train mode
    torch.ao.quantization.fuse_modules_qat(qat_model, [['conv1', 'bn1']], inplace=True)
    quant.prepare_qat(qat_model, inplace=True)
    
    # Fake-quantized training can run on the GPU
    qat_model.to(device)
    
    # Training setup (optimizer created after prepare_qat so it sees the swapped modules)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(qat_model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    
    print("Training QAT model...")
    for epoch in range(epochs):
        qat_model.train()
        running_loss = 0.0
        
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = qat_model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        # Evaluate the fake-quantized model (evaluate_model switches to eval mode)
        if epoch % 2 == 0:
            test_acc = evaluate_model(qat_model, test_loader)
            print(f'Epoch {epoch}: Avg Loss: {running_loss / len(train_loader):.4f}, Test Accuracy: {test_acc:.2f}%')
    
    # Convert to a real int8 model; quantized kernels require CPU
    qat_model.cpu().eval()
    quantized_qat_model = quant.convert(qat_model, inplace=False)
    
    return quantized_qat_model

# Train QAT model
print("\n=== Training Quantization-Aware Model ===")
qat_model = train_qat_model(teacher_model, train_loader, test_loader, epochs=8)
qat_performance = measure_model_performance(qat_model, test_loader)

print(f"\nQAT Model Performance:")
print(f"  Size: {qat_performance['size_mb']:.2f} MB")
print(f"  Inference Time: {qat_performance['inference_time_ms']:.2f} ms")
print(f"  Accuracy: {qat_performance['accuracy']:.2f}%")
print(f"  Compression Ratio: {teacher_performance['size_mb']/qat_performance['size_mb']:.2f}x")
print(f"  Speedup: {teacher_performance['inference_time_ms']/qat_performance['inference_time_ms']:.2f}x")
print(f"  Accuracy Drop: {teacher_performance['accuracy'] - qat_performance['accuracy']:.2f}%")

In [None]:
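# 3. Dynamic Quantization
import copy

def apply_dynamic_quantization(model):
    """Apply dynamic int8 quantization to nn.Linear layers (CPU-only kernels)"""
    # Work on a CPU copy so the original float model stays on its device.
    # In this ResNet only the final classifier is an nn.Linear; the conv layers stay float32.
    float_model = copy.deepcopy(model).cpu().eval()
    quantized_model = torch.quantization.quantize_dynamic(
        float_model, {nn.Linear}, dtype=torch.qint8
    )
    return quantized_model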

# Apply dynamic quantization
print("\n=== Applying Dynamic Quantization ===")
dynamic_q_model = apply_dynamic_quantization(teacher_model)
dynamic_performance = measure_model_performance(dynamic_q_model, test_loader)

print(f"\nDynamic Quantization Model Performance:")
print(f"  Size: {dynamic_performance['size_mb']:.2f} MB")
print(f"  Inference Time: {dynamic_performance['inference_time_ms']:.2f} ms")
print(f"  Accuracy: {dynamic_performance['accuracy']:.2f}%")
print(f"  Compression Ratio: {teacher_performance['size_mb']/dynamic_performance['size_mb']:.2f}x")
print(f"  Speedup: {teacher_performance['inference_time_ms']/dynamic_performance['inference_time_ms']:.2f}x")
print(f"  Accuracy Drop: {teacher_performance['accuracy'] - dynamic_performance['accuracy']:.2f}%")

## 4. Knowledge Distillation

We'll train a smaller ResNet-18 model to mimic the behavior of our ResNet-34 teacher.
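The `DistillationLoss` below follows the Hinton-style formulation `L = α · T² · KL(p_T ‖ q_T) + (1 - α) · CE(student, labels)`, where `p_T` and `q_T` are the teacher's and student's softmax outputs at temperature `T`. The `T²` factor keeps the soft-target gradients on a scale comparable to the hard-label term as `T` grows. As a sanity check, this sketch writes the KL term out explicitly on random logits (the `kd_loss` name is illustrative) and compares it with `F.kl_div(..., reduction='batchmean')`; it also shows that raising `T` lowers the teacher's peak probability, spreading probability mass onto the non-target classes.

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, alpha=0.7, T=4.0):
    """alpha * T^2 * KL(p_T || q_T) + (1 - alpha) * cross-entropy on hard labels."""
    log_p = F.log_softmax(teacher_logits / T, dim=1)  # teacher log-probs at temperature T
    log_q = F.log_softmax(student_logits / T, dim=1)  # student log-probs at temperature T
    # KL(p || q) = sum_c p_c * (log p_c - log q_c), averaged over the batch
    kl = (log_p.exp() * (log_p - log_q)).sum(dim=1).mean()
    ce = F.cross_entropy(student_logits, labels)
    return alpha * T ** 2 * kl + (1 - alpha) * ce

torch.manual_seed(0)
teacher_logits = 3 * torch.randn(8, 10)
student_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

# Peak teacher probability for the first sample at two temperatures
peak = {T: F.softmax(teacher_logits[0] / T, dim=0).max().item() for T in (1.0, 4.0)}
print(peak)

manual = kd_loss(student_logits, teacher_logits, labels)
builtin = (0.7 * 4.0 ** 2 * F.kl_div(F.log_softmax(student_logits / 4.0, dim=1),
                                     F.softmax(teacher_logits / 4.0, dim=1),
                                     reduction='batchmean')
           + 0.3 * F.cross_entropy(student_logits, labels))
print(manual.item(), builtin.item())
```

Note that `F.kl_div` expects log-probabilities as its first argument and probabilities as its second, which is why the student goes through `log_softmax` and the teacher through `softmax`.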

In [None]:
# Knowledge Distillation Implementation
class DistillationLoss(nn.Module):
    """Combined loss for knowledge distillation"""
    def __init__(self, alpha=0.7, temperature=4.0):
        super(DistillationLoss, self).__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
        self.ce_loss = nn.CrossEntropyLoss()
        
    def forward(self, student_logits, teacher_logits, target_labels):
        # Distillation loss (KL divergence between soft predictions)
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)
        distillation_loss = self.kl_div(student_log_probs, teacher_probs) * (self.temperature ** 2)
        
        # Standard classification loss
        classification_loss = self.ce_loss(student_logits, target_labels)
        
        # Combined loss
        total_loss = self.alpha * distillation_loss + (1 - self.alpha) * classification_loss
        
        return total_loss, distillation_loss, classification_loss

def train_student_with_distillation(teacher_model, student_model, train_loader, test_loader, 
                                  epochs=20, alpha=0.7, temperature=4.0, lr=0.01):
    """Train student model using knowledge distillation"""
    
    teacher_model.eval()  # Teacher is frozen
    student_model.train()
    
    # Loss and optimizer
    distillation_criterion = DistillationLoss(alpha=alpha, temperature=temperature)
    optimizer = optim.SGD(student_model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    train_losses = []
    distillation_losses = []
    classification_losses = []
    test_accuracies = []
    
    print(f"Starting knowledge distillation training (α={alpha}, T={temperature})...")
    
    for epoch in range(epochs):
        student_model.train()
        running_loss = 0.0
        running_dist_loss = 0.0
        running_class_loss = 0.0
        
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Get teacher predictions (no gradients)
            with torch.no_grad():
                teacher_logits = teacher_model(inputs)
            
            # Get student predictions
            optimizer.zero_grad()
            student_logits = student_model(inputs)
            
            # Calculate combined loss
            total_loss, dist_loss, class_loss = distillation_criterion(
                student_logits, teacher_logits, targets
            )
            
            # Backward pass
            total_loss.backward()
            optimizer.step()
            
            running_loss += total_loss.item()
            running_dist_loss += dist_loss.item()
            running_class_loss += class_loss.item()
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}: '
                      f'Total: {total_loss.item():.4f}, '
                      f'Distill: {dist_loss.item():.4f}, '
                      f'Class: {class_loss.item():.4f}')
        
        scheduler.step()
        
        # Record losses
        train_losses.append(running_loss / len(train_loader))
        distillation_losses.append(running_dist_loss / len(train_loader))
        classification_losses.append(running_class_loss / len(train_loader))
        
        # Evaluate
        if epoch % 5 == 0:
            test_acc = evaluate_model(student_model, test_loader)
            test_accuracies.append(test_acc)
            print(f'Epoch {epoch}: Test Accuracy: {test_acc:.2f}%')
    
    return train_losses, distillation_losses, classification_losses, test_accuracies

# Train student model with knowledge distillation
print("\n=== Knowledge Distillation Training ===")
distilled_student = ResNet18().to(device)

dist_losses, kd_losses, class_losses, kd_test_acc = train_student_with_distillation(
    teacher_model, distilled_student, train_loader, test_loader,
    epochs=20, alpha=0.7, temperature=4.0, lr=0.01
)

# Evaluate distilled student
distilled_performance = measure_model_performance(distilled_student, test_loader)

print(f"\nDistilled Student Model Performance:")
print(f"  Size: {distilled_performance['size_mb']:.2f} MB")
print(f"  Inference Time: {distilled_performance['inference_time_ms']:.2f} ms")
print(f"  Accuracy: {distilled_performance['accuracy']:.2f}%")
print(f"  Compression vs Teacher: {teacher_performance['size_mb']/distilled_performance['size_mb']:.2f}x")
print(f"  Speedup vs Teacher: {teacher_performance['inference_time_ms']/distilled_performance['inference_time_ms']:.2f}x")

In [None]:
# Train baseline student model (without distillation) for comparison
print("\n=== Training Baseline Student Model (No Distillation) ===")
baseline_student = ResNet18().to(device)

baseline_train_losses, baseline_train_acc, baseline_test_acc = train_model(
    baseline_student, train_loader, test_loader, epochs=20, lr=0.01
)

baseline_performance = measure_model_performance(baseline_student, test_loader)

print(f"\nBaseline Student Model Performance:")
print(f"  Size: {baseline_performance['size_mb']:.2f} MB")
print(f"  Inference Time: {baseline_performance['inference_time_ms']:.2f} ms")
print(f"  Accuracy: {baseline_performance['accuracy']:.2f}%")

print(f"\nDistillation Improvement:")
print(f"  Accuracy Gain: {distilled_performance['accuracy'] - baseline_performance['accuracy']:.2f}%")

## 5. Performance Comparison

Let's compare all our optimization techniques:
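Before reading the table, it can help to separate the models that are genuinely competitive from ones that another model beats on every axis. The hypothetical helper below keeps the Pareto-optimal (non-dominated) entries on two criteria, smaller size and higher accuracy. The toy numbers are invented for illustration and are not results from this notebook.

```python
def pareto_front(models):
    """Names of (name, size, accuracy) entries not dominated by any other entry.

    An entry is dominated if another one is no larger and no less accurate,
    and strictly better on at least one of the two criteria.
    """
    front = []
    for name, size, acc in models:
        dominated = any(
            s <= size and a >= acc and (s < size or a > acc)
            for _, s, a in models
        )
        if not dominated:
            front.append(name)
    return front

# Hypothetical (name, size in MB, accuracy in %) tuples, for illustration only
candidates = [("A", 80.0, 94.0), ("B", 21.0, 93.0), ("C", 43.0, 92.0), ("D", 20.0, 90.0)]
front = pareto_front(candidates)
print(front)
```

Once `results_df` exists, one way to feed it in is `pareto_front(list(results_df[['Model', 'Size (MB)', 'Accuracy (%)']].itertuples(index=False, name=None)))`; latency could be added as a third criterion in the same way.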

In [None]:
# Comprehensive Performance Comparison
results_data = {
    'Model': [
        'Teacher (ResNet-34)',
        'Post-Training Quantization',
        'Quantization-Aware Training',
        'Dynamic Quantization',
        'Baseline Student (ResNet-18)',
        'Distilled Student (ResNet-18)'
    ],
    'Size (MB)': [
        teacher_performance['size_mb'],
        ptq_performance['size_mb'],
        qat_performance['size_mb'],
        dynamic_performance['size_mb'],
        baseline_performance['size_mb'],
        distilled_performance['size_mb']
    ],
    'Inference Time (ms)': [
        teacher_performance['inference_time_ms'],
        ptq_performance['inference_time_ms'],
        qat_performance['inference_time_ms'],
        dynamic_performance['inference_time_ms'],
        baseline_performance['inference_time_ms'],
        distilled_performance['inference_time_ms']
    ],
    'Accuracy (%)': [
        teacher_performance['accuracy'],
        ptq_performance['accuracy'],
        qat_performance['accuracy'],
        dynamic_performance['accuracy'],
        baseline_performance['accuracy'],
        distilled_performance['accuracy']
    ],
    'Parameters': [
        teacher_performance['parameters'],
        ptq_performance['parameters'],
        qat_performance['parameters'],
        dynamic_performance['parameters'],
        baseline_performance['parameters'],
        distilled_performance['parameters']
    ]
}

# Create DataFrame
results_df = pd.DataFrame(results_data)

# Calculate compression ratios and speedups
baseline_size = teacher_performance['size_mb']
baseline_time = teacher_performance['inference_time_ms']
baseline_acc = teacher_performance['accuracy']

results_df['Size Compression'] = baseline_size / results_df['Size (MB)']
results_df['Speed Improvement'] = baseline_time / results_df['Inference Time (ms)']
results_df['Accuracy Drop'] = baseline_acc - results_df['Accuracy (%)']

# Display results
print("=== COMPREHENSIVE PERFORMANCE COMPARISON ===")
print(results_df.round(2))

# Create visualizations
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Model Size Comparison
axes[0, 0].bar(range(len(results_df)), results_df['Size (MB)'], color='skyblue')
axes[0, 0].set_title('Model Size Comparison')
axes[0, 0].set_ylabel('Size (MB)')
axes[0, 0].set_xticks(range(len(results_df)))
axes[0, 0].set_xticklabels(results_df['Model'], rotation=45, ha='right')

# Inference Time Comparison
axes[0, 1].bar(range(len(results_df)), results_df['Inference Time (ms)'], color='lightcoral')
axes[0, 1].set_title('Inference Time Comparison')
axes[0, 1].set_ylabel('Inference Time (ms)')
axes[0, 1].set_xticks(range(len(results_df)))
axes[0, 1].set_xticklabels(results_df['Model'], rotation=45, ha='right')

# Accuracy Comparison
axes[1, 0].bar(range(len(results_df)), results_df['Accuracy (%)'], color='lightgreen')
axes[1, 0].set_title('Accuracy Comparison')
axes[1, 0].set_ylabel('Accuracy (%)')
axes[1, 0].set_xticks(range(len(results_df)))
axes[1, 0].set_xticklabels(results_df['Model'], rotation=45, ha='right')

# Efficiency Plot (Size vs Accuracy)
scatter_colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown']
for i, (idx, row) in enumerate(results_df.iterrows()):
    axes[1, 1].scatter(row['Size (MB)'], row['Accuracy (%)'], 
                      s=100, c=scatter_colors[i], label=row['Model'], alpha=0.7)

axes[1, 1].set_title('Efficiency: Model Size vs Accuracy')
axes[1, 1].set_xlabel('Size (MB)')
axes[1, 1].set_ylabel('Accuracy (%)')
axes[1, 1].legend(bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()
plt.show()

# Print key insights
print("\n=== KEY INSIGHTS ===")
print(f"Best compression: {results_df.loc[results_df['Size Compression'].idxmax(), 'Model']} "
      f"({results_df['Size Compression'].max():.2f}x)")
print(f"Best speedup: {results_df.loc[results_df['Speed Improvement'].idxmax(), 'Model']} "
      f"({results_df['Speed Improvement'].max():.2f}x)")
print(f"Best accuracy: {results_df.loc[results_df['Accuracy (%)'].idxmax(), 'Model']} "
      f"({results_df['Accuracy (%)'].max():.2f}%)")
print(f"Distilled Student trade-off: "
      f"{teacher_performance['size_mb']/distilled_performance['size_mb']:.1f}x smaller, "
      f"{teacher_performance['inference_time_ms']/distilled_performance['inference_time_ms']:.1f}x faster, "
      f"{teacher_performance['accuracy'] - distilled_performance['accuracy']:.1f} percentage-point accuracy drop")

## 6. Production Export

Export optimized models for production deployment:
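The export cell below relies on `torch.jit.trace`, which records the operations executed for one example input. A quick way to gain confidence in an exported artifact is a save/load round trip that checks the reloaded module reproduces the original outputs. This sketch does the round trip in memory with a tiny stand-in model (hypothetical; any of the exported models could take its place). Because tracing records a single execution path, models with data-dependent control flow would need `torch.jit.script` instead.

```python
import io

import torch
import torch.nn as nn

# Tiny stand-in model (hypothetical; substitute any model exported in this section)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
).eval()
example = torch.randn(1, 3, 32, 32)

# Trace, serialize to an in-memory buffer, and load it back on CPU
traced = torch.jit.trace(model, example)
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer, map_location='cpu')

# The reloaded module should reproduce the eager model's outputs
with torch.no_grad():
    expected = model(example)
    actual = loaded(example)
print("outputs match:", torch.allclose(expected, actual, atol=1e-6))
```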

In [None]:
# Export models for production
import copy
import torch.jit

def export_model_for_production(model, model_name, example_input):
    """Export a CPU copy of the model as TorchScript and as a state dict"""
    # Work on a CPU copy: quantized models only run on CPU, and CPU artifacts are the most portable
    model = copy.deepcopy(model).cpu().eval()
    example_input = example_input.cpu()
    
    print(f"Exporting {model_name}...")
    
    # 1. TorchScript (JIT)
    try:
        traced_model = torch.jit.trace(model, example_input)
        traced_model.save(f'{model_name}_torchscript.pt')
        print(f"  ✓ TorchScript saved as {model_name}_torchscript.pt")
    except Exception as e:
        print(f"  ✗ TorchScript export failed: {e}")
    
    # 2. State Dict (for PyTorch loading)
    torch.save(model.state_dict(), f'{model_name}_state_dict.pth')
    print(f"  ✓ State dict saved as {model_name}_state_dict.pth")
    
    # 3. ONNX (commented out as it might fail with quantized models)
    # try:
    #     torch.onnx.export(
    #         model, example_input, f'{model_name}.onnx',
    #         export_params=True, opset_version=11,
    #         do_constant_folding=True,
    #         input_names=['input'], output_names=['output']
    #     )
    #     print(f"  ✓ ONNX saved as {model_name}.onnx")
    # except Exception as e:
    #     print(f"  ✗ ONNX export failed: {e}")

# Example input for tracing
example_input = torch.randn(1, 3, 32, 32).to(device)

# Export all models
print("=== EXPORTING MODELS FOR PRODUCTION ===")

models_to_export = [
    (teacher_model, "teacher_resnet34"),
    (distilled_student, "distilled_resnet18"),
    (baseline_student, "baseline_resnet18"),
    (dynamic_q_model, "dynamic_quantized_resnet34"),
    (ptq_model, "ptq_quantized_resnet34"),
    (qat_model, "qat_quantized_resnet34")
]

for model, name in models_to_export:
    export_model_for_production(model, name, example_input)
    print()

# Create a deployment guide
deployment_guide = """
=== DEPLOYMENT GUIDE ===

Model Recommendations:

1. HIGHEST ACCURACY: teacher_resnet34
   - Use when: Maximum accuracy is required, resources are not constrained
   - Size: {:.1f}MB, Accuracy: {:.2f}%

2. BEST BALANCE: distilled_resnet18 
   - Use when: Need good accuracy with smaller size
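   - Size: {:.1f}MB, Accuracy: {:.2f}% ({:.1f} percentage-point drop from teacher)
   - {:.1f}x smaller, {:.1f}x faster than teacher

3. LINEAR-LAYER INT8: dynamic_quantized_resnet34
   - Use when: A quick CPU model is needed without calibration data
   - Caveat: Only the final nn.Linear layer is quantized, so size/speed gains on this conv-heavy ResNet are small
   - Size: {:.1f}MB, Accuracy: {:.2f}%
   - {:.1f}x compression, {:.1f}x speedup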

Loading in Production:
```python
# Method 1: Load TorchScript
model = torch.jit.load('distilled_resnet18_torchscript.pt')
model.eval()

# Method 2: Load state dict (requires the ResNet18 class defined in this notebook)
model = ResNet18()
model.load_state_dict(torch.load('distilled_resnet18_state_dict.pth', map_location='cpu'))
model.eval()
```
""".format(
    teacher_performance['size_mb'], teacher_performance['accuracy'],
    distilled_performance['size_mb'], distilled_performance['accuracy'],
    teacher_performance['accuracy'] - distilled_performance['accuracy'],
    teacher_performance['size_mb'] / distilled_performance['size_mb'],
    teacher_performance['inference_time_ms'] / distilled_performance['inference_time_ms'],
    dynamic_performance['size_mb'], dynamic_performance['accuracy'],
    teacher_performance['size_mb'] / dynamic_performance['size_mb'],
    teacher_performance['inference_time_ms'] / dynamic_performance['inference_time_ms']
)

print(deployment_guide)

# Save deployment guide
with open('deployment_guide.txt', 'w') as f:
    f.write(deployment_guide)

print("\n✓ All models exported and deployment guide saved!")

## 🎯 Conclusion

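This notebook demonstrated a practical implementation of model optimization techniques:

### ✅ **Key Takeaways:**

1. **Quantization Results:**
   - Post-Training Quantization: Up to ~4x smaller weights (fp32 → int8) without retraining; the accuracy cost depends on calibration and should be measured
   - Quantization-Aware Training: Usually preserves more accuracy than PTQ by fine-tuning with simulated quantization
   - Dynamic Quantization: Simplest to apply, but it only covers `nn.Linear` and recurrent layers, so it mainly benefits Transformer and RNN models rather than CNNs like this ResNet

2. **Knowledge Distillation Results:**
   - The ResNet-18 student has about 1.9x fewer parameters than the ResNet-34 teacher; the baseline-vs-distilled comparison isolates how much accuracy the teacher's soft targets add
   - Soft targets from the teacher encode class similarities that one-hot labels lack
   - The temperature T controls how much of that information is exposed; T and α are hyperparameters worth tuning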

3. **Production Considerations:**
   - Always measure actual performance on target hardware
   - Consider accuracy-efficiency trade-offs based on use case
   - Export models in multiple formats for deployment flexibility
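Following the first production point, latency is usually better summarized by a median and a tail percentile than by a single mean, since one-off stalls skew averages. The helper below is a CPU-oriented sketch (the function name, warmup count, and p90 choice are assumptions, not part of the notebook). On GPU, `torch.cuda.synchronize()` would be needed around the timers, and CPU results depend on the thread count set with `torch.set_num_threads`.

```python
import statistics
import time

import torch
import torch.nn as nn

def benchmark_latency(model, example_input, warmup=10, runs=50):
    """Median and p90 single-call latency in milliseconds (CPU sketch)."""
    model.eval()
    timings = []
    with torch.inference_mode():
        for _ in range(warmup):  # warm up allocators and kernel caches
            model(example_input)
        for _ in range(runs):
            start = time.perf_counter()
            model(example_input)
            timings.append((time.perf_counter() - start) * 1000)
    timings.sort()
    p90 = timings[int(0.9 * (len(timings) - 1))]
    return statistics.median(timings), p90

# Tiny stand-in model (hypothetical); pass any of the notebook's models instead
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
median_ms, p90_ms = benchmark_latency(model, torch.randn(1, 3, 32, 32))
print(f"median {median_ms:.3f} ms, p90 {p90_ms:.3f} ms")
```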

### 🚀 **Next Steps:**
- Try these techniques on your own models and datasets
- Experiment with different compression ratios and architectures
- Combine multiple techniques for maximum optimization
- Benchmark on your target deployment hardware

---
*This notebook provides a foundation for model optimization - adapt the techniques to your specific use case and requirements.*