# Module 4 - Exercise 3: PyTorch Optimization with torch.compile

## Learning Objectives
- Understand the basics of torch.compile for model optimization
- Compare performance between compiled and non-compiled models
- Analyze the impact of batch size on optimization benefits
- Evaluate performance across different devices (CPU vs GPU)
- Learn practical optimization techniques for production models

## Environment Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import time
from typing import Dict, List, Tuple
import warnings
warnings.filterwarnings('ignore')

# Seed PyTorch and NumPy with the same fixed value so that the random
# tensors generated below are identical from run to run
torch.manual_seed(42)
np.random.seed(42)

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
if device.type == 'cuda':
    # Only meaningful (and only safe to query) when a CUDA device is present
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Section 1: Understanding torch.compile with Simple Functions

In this section, we'll start with simple functions to understand how torch.compile works and measure its performance benefits.

In [None]:
def simple_computation(x: torch.Tensor) -> torch.Tensor:
    """Reduce ``x`` to a single scalar through a chain of element-wise ops.

    Computes ``y = sin(x) * cos(x)`` element-wise and returns the sum over
    all elements of ``y**2 + exp(-|y|)``.
    """
    product = x.sin() * x.cos()
    squared = product.pow(2)
    decay = torch.exp(-product.abs())
    return (squared + decay).sum()

# torch.compile returns a wrapper immediately; tracing and code generation are
# deferred until the first call, so the warmup loop below absorbs that cost.
compiled_simple_computation = torch.compile(simple_computation)

# Test both versions
test_tensor = torch.randn(1000, 1000, device=device)

# Warmup runs (important for compiled functions: the first call compiles)
for _ in range(3):
    _ = simple_computation(test_tensor)
    _ = compiled_simple_computation(test_tensor)

# CUDA kernels are launched asynchronously. Drain the queue so leftover warmup
# work is not billed to the first timed loop.
if device.type == 'cuda':
    torch.cuda.synchronize()

# Performance comparison
num_iterations = 100

# Time the eager (non-compiled) version. perf_counter is a monotonic,
# high-resolution clock intended for measuring short intervals.
start_time = time.perf_counter()
for _ in range(num_iterations):
    _ = simple_computation(test_tensor)
if device.type == 'cuda':
    # Wait for the GPU to finish; otherwise only launch overhead is timed
    torch.cuda.synchronize()
normal_time = time.perf_counter() - start_time

# Time the compiled version the same way
start_time = time.perf_counter()
for _ in range(num_iterations):
    _ = compiled_simple_computation(test_tensor)
if device.type == 'cuda':
    torch.cuda.synchronize()
compiled_time = time.perf_counter() - start_time

# The truthiness check also protects the division from a zero duration
if normal_time and compiled_time:
    print(f"Normal execution time: {normal_time:.4f} seconds")
    print(f"Compiled execution time: {compiled_time:.4f} seconds")
    print(f"Speedup: {normal_time/compiled_time:.2f}x")

## Section 2: Creating a Custom Dataset and Model

Now let's create a more realistic scenario with a dataset and a neural network model to see how torch.compile performs with real ML workloads.

In [None]:
class SyntheticDataset(Dataset):
    """A synthetic classification dataset for performance testing.

    Holds ``num_samples`` feature vectors drawn from a standard normal
    distribution and one integer class label per sample drawn uniformly
    from ``0 .. num_classes - 1``. All data is generated once in
    ``__init__`` and kept in memory as CPU tensors.

    Args:
        num_samples: Number of (features, label) pairs to generate.
        input_dim: Length of each feature vector.
        num_classes: Number of distinct labels; the default of 10 yields
            labels 0 through 9.
    """

    def __init__(self, num_samples: int = 10000, input_dim: int = 128,
                 num_classes: int = 10):
        # float32 features of shape (num_samples, input_dim)
        self.data = torch.randn(num_samples, input_dim)
        # int64 labels of shape (num_samples,); randint's upper bound is exclusive
        self.targets = torch.randint(0, num_classes, (num_samples,))

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.data[idx], self.targets[idx]

# Create dataset
dataset = SyntheticDataset(num_samples=10000, input_dim=128)

# Do not truth-test the dataset itself: bool(dataset) falls back to __len__,
# which raises TypeError while __len__ still returns None. Checking the
# storage attribute lets the fallback message actually be shown.
dataset_ready = getattr(dataset, 'data', None) is not None
print(f"Dataset size: {len(dataset) if dataset_ready else 'Not implemented'}")

In [None]:
class OptimizedModel(nn.Module):
    """A three-layer MLP with batch normalization for optimization testing.

    Architecture: ``fc1 -> bn1 -> ReLU -> fc2 -> bn2 -> ReLU -> fc3``.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of both hidden layers.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim: int = 128, hidden_dim: int = 256, num_classes: int = 10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map a ``(batch, input_dim)`` tensor to ``(batch, num_classes)`` logits.

        Raw logits are returned (no softmax), which is what
        ``nn.CrossEntropyLoss`` expects. In training mode ``BatchNorm1d``
        needs more than one sample per batch; in eval mode any batch size
        works because the running statistics are used.
        """
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        return self.fc3(x)

# Create model instances
model_normal = OptimizedModel().to(device)
model_compiled = OptimizedModel().to(device)

# Wrap the second instance with torch.compile. The returned module shares the
# wrapped model's parameters; the actual compilation is deferred until its
# first forward call. mode='default' is the baseline compilation mode.
model_compiled = torch.compile(model_compiled, mode='default')

print(f"Model created with {sum(p.numel() for p in model_normal.parameters())} parameters")

## Section 3: Performance Comparison with Different Batch Sizes

Let's analyze how batch size affects the performance benefits of torch.compile.

In [None]:
def _synchronize_device(device) -> None:
    """Block until queued CUDA work on ``device`` is done; no-op otherwise."""
    if device is not None and device.type == 'cuda':
        torch.cuda.synchronize(device)


def benchmark_inference(model, dataloader, num_batches: int = 50) -> float:
    """Benchmark model inference performance.

    Puts ``model`` in eval mode, runs 3 untimed warmup forward passes, then
    times ``num_batches`` forward passes under ``torch.no_grad()``.

    Args:
        model: Module to benchmark (eager or wrapped by ``torch.compile``).
        dataloader: Any iterable yielding ``(data, target)`` pairs, e.g. a
            ``DataLoader`` or a plain list of batches. Targets are ignored.
        num_batches: Number of timed forward passes. If the iterable yields
            fewer batches, the staged batches are reused cyclically.

    Returns:
        Average wall-clock seconds per forward pass.

    Raises:
        ValueError: If ``num_batches`` is not positive or ``dataloader``
            yields no batches.
    """
    if num_batches <= 0:
        raise ValueError("num_batches must be positive")

    model.eval()

    # Use the device of the model's first parameter as the target for inputs.
    # A parameter-less model gives no device to target, so inputs stay put.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    # Stage inputs on the model's device before timing so that data loading
    # and host-to-device copies are not counted as inference time.
    inputs = []
    for data, _ in dataloader:
        inputs.append(data if model_device is None else data.to(model_device))
        if len(inputs) >= num_batches:
            break
    if not inputs:
        raise ValueError("dataloader yielded no batches")

    # Keep only batches shaped like the first one. A smaller final batch
    # could otherwise trigger a torch.compile recompilation inside the timed
    # loop and distort the measurement.
    inputs = [batch for batch in inputs if batch.shape == inputs[0].shape]

    with torch.no_grad():
        # Warmup: absorbs lazy compilation and one-time initialization costs
        warmup_batches = 3
        for i in range(warmup_batches):
            _ = model(inputs[i % len(inputs)])
        _synchronize_device(model_device)

        # Actual benchmark
        start_time = time.perf_counter()
        for i in range(num_batches):
            _ = model(inputs[i % len(inputs)])
        # CUDA launches are asynchronous; wait for completion before stopping
        # the clock, otherwise only kernel-launch overhead is measured.
        _synchronize_device(model_device)
        end_time = time.perf_counter()

    avg_time = (end_time - start_time) / num_batches
    return avg_time

# Test different batch sizes
batch_sizes = [1, 4, 16, 32, 64, 128]
normal_times = []
compiled_times = []

for batch_size in batch_sizes:
    # shuffle=False so both models are fed the same batches in the same order
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Explicit None checks instead of truthiness: truth-testing a DataLoader
    # calls len() on it, which is not the question being asked here.
    if model_normal is None or model_compiled is None:
        continue

    # Benchmark normal model
    normal_time = benchmark_inference(model_normal, dataloader)

    # Benchmark compiled model. A new batch size may trigger a recompilation;
    # the warmup passes inside benchmark_inference are there to absorb it.
    compiled_time = benchmark_inference(model_compiled, dataloader)

    # Skip unusable measurements (None or zero) rather than crash while
    # formatting or dividing below.
    if not normal_time or not compiled_time:
        continue

    normal_times.append(normal_time)
    compiled_times.append(compiled_time)

    print(f"Batch size {batch_size:3d}: Normal={normal_time:.4f}s, "
          f"Compiled={compiled_time:.4f}s, Speedup={normal_time/compiled_time:.2f}x")

In [None]:
# Visualize the results
if normal_times and compiled_times:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Label each measurement with its batch size.
    # NOTE(review): assumes the i-th timing belongs to batch_sizes[i], i.e.
    # no batch size was skipped during the sweep - confirm against the loop.
    tick_labels = [str(size) for size in batch_sizes[:len(normal_times)]]
    positions = np.arange(len(tick_labels))
    bar_width = 0.35

    # Left panel: grouped bars, eager vs compiled time per batch
    ax1.bar(positions - bar_width / 2, normal_times[:len(tick_labels)],
            bar_width, label='Normal')
    ax1.bar(positions + bar_width / 2, compiled_times[:len(tick_labels)],
            bar_width, label='Compiled')
    ax1.set_xticks(positions)
    ax1.set_xticklabels(tick_labels)
    ax1.set_xlabel('Batch size')
    ax1.set_ylabel('Average time per batch (s)')
    ax1.set_title('Inference time: normal vs compiled')
    ax1.legend()

    # Right panel: speedup ratio (values above 1 mean compiled is faster)
    speedups = [normal / compiled
                for normal, compiled in zip(normal_times, compiled_times)]
    ax2.plot(positions, speedups[:len(tick_labels)], marker='o')
    ax2.axhline(1.0, color='gray', linestyle='--', label='No speedup')
    ax2.set_xticks(positions)
    ax2.set_xticklabels(tick_labels)
    ax2.set_xlabel('Batch size')
    ax2.set_ylabel('Speedup (normal / compiled)')
    ax2.set_title('torch.compile speedup by batch size')
    ax2.legend()

    plt.tight_layout()
    plt.show()

## Section 4: Device Comparison (CPU vs GPU)

Let's compare the optimization benefits across different devices to understand where torch.compile provides the most value.

In [None]:
def compare_devices(batch_size: int = 32, num_batches: int = 30):
    """Compare performance across CPU and GPU (if available).

    For each available device, builds one eager and one compiled
    ``OptimizedModel``, benchmarks both on the same pre-staged batches and
    records the timings.

    Args:
        batch_size: Batch size of the synthetic data fed to the models.
        num_batches: Number of timed forward passes per model.

    Returns:
        Dict keyed by device name (``'cpu'``, ``'cuda'``). Each value is a
        dict with ``'normal'`` and ``'compiled'`` (average seconds per batch)
        and ``'speedup'`` (normal / compiled, or 0 if the compiled time is
        falsy). Devices that are unavailable or produced no timing are absent.
    """
    results = {}

    # One fixed dataset shared by every device so the inputs are identical
    dataset = SyntheticDataset(num_samples=1000, input_dim=128)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    for device_name in ['cpu', 'cuda']:
        if device_name == 'cuda' and not torch.cuda.is_available():
            print("CUDA not available, skipping GPU tests")
            continue

        current_device = torch.device(device_name)
        print(f"\nTesting on {device_name.upper()}...")

        # Fresh models per device; compilation is lazy and happens during the
        # warmup passes inside benchmark_inference.
        model_normal = OptimizedModel().to(current_device)
        model_compiled = torch.compile(OptimizedModel().to(current_device))

        # Move every batch to the device once, up front, so host-to-device
        # copies are not part of what gets timed.
        device_dataloader = [(data.to(current_device), target.to(current_device))
                             for data, target in dataloader]

        normal_time = benchmark_inference(model_normal, device_dataloader, num_batches)
        compiled_time = benchmark_inference(model_compiled, device_dataloader, num_batches)

        if normal_time is None or compiled_time is None:
            # Nothing was measured; omit the entry instead of storing None
            # values that callers would fail to format.
            print(f"No timing available for {device_name.upper()}, skipping")
            continue

        results[device_name] = {
            'normal': normal_time,
            'compiled': compiled_time,
            'speedup': normal_time / compiled_time if compiled_time else 0
        }

    return results

# Benchmark each available device with batches of 64 and 30 timed passes
device_results = compare_devices(batch_size=64, num_batches=30)

# Report one four-line summary per device that produced results
for device_name, metrics in device_results.items():
    print(
        f"\n{device_name.upper()} Results:",
        f"  Normal: {metrics['normal']:.4f}s",
        f"  Compiled: {metrics['compiled']:.4f}s",
        f"  Speedup: {metrics['speedup']:.2f}x",
        sep="\n",
    )

## Section 5: Advanced Compilation Modes

torch.compile offers different modes that trade off compilation time for runtime performance. Let's explore these modes.

In [None]:
# Compilation modes to test
compile_modes = ['default', 'reduce-overhead', 'max-autotune']
mode_results = {}

# Create test data
test_dataset = SyntheticDataset(num_samples=2000, input_dim=128)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Fetch the first batch once, outside any timed region, so DataLoader work
# and the host-to-device copy are not counted as compilation time.
first_batch = next(iter(test_loader))[0].to(device)

for mode in compile_modes:
    print(f"\nTesting mode: {mode}")

    # Fresh model per mode so that each mode compiles from scratch
    model = OptimizedModel().to(device)
    # Switch to eval BEFORE the first call: benchmark_inference runs in eval
    # mode, so the compilation measured below should be of the eval-mode
    # forward pass as well.
    model.eval()
    compiled_model = torch.compile(model, mode=mode)

    # Measure compilation + first run time (torch.compile is lazy, so the
    # first forward call is where compilation actually happens)
    compile_start = time.perf_counter()
    with torch.no_grad():
        _ = compiled_model(first_batch)
    if device.type == 'cuda':
        # Include the asynchronous GPU work triggered by the first call
        torch.cuda.synchronize()
    compile_time = time.perf_counter() - compile_start

    # Steady-state runtime once compilation is out of the way
    runtime = benchmark_inference(compiled_model, test_loader)

    mode_results[mode] = {
        'compile_time': compile_time,
        'runtime': runtime
    }

    print(f"  Compilation time: {compile_time:.2f}s")
    if runtime is not None:
        print(f"  Runtime per batch: {runtime:.4f}s")
    else:
        # Avoid a formatting TypeError when no runtime was measured
        print("  Runtime per batch: not available")

## Section 6: Training Loop Optimization

Let's see how torch.compile affects a complete training loop, including both forward and backward passes.

In [None]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Train model for one epoch.

    Args:
        model: Module to train; it is switched to train mode.
        dataloader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.
        device: Device (``torch.device`` or string) that each batch is moved
            to before the forward pass.

    Returns:
        Tuple ``(avg_loss, epoch_time)``: the mean per-batch loss (0.0 if the
        dataloader yielded no batches) and the elapsed wall-clock seconds.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    start_time = time.perf_counter()

    for data, target in dataloader:
        data, target = data.to(device), target.to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        # Every batch counts toward the average, including one whose loss is
        # exactly zero; truth-testing the loss tensor would drop such a batch.
        total_loss += loss.item()
        num_batches += 1

    # The last optimizer step may still be running asynchronously on the GPU;
    # wait for it so the measured time covers the whole epoch.
    if torch.device(device).type == 'cuda':
        torch.cuda.synchronize()

    epoch_time = time.perf_counter() - start_time
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

    return avg_loss, epoch_time

# Prepare training
train_dataset = SyntheticDataset(num_samples=5000, input_dim=128)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Create models for training comparison
model_train_normal = OptimizedModel().to(device)
model_train_compiled = OptimizedModel().to(device)

# Start both models from identical weights so that the reported losses are
# comparable, not just the timings.
model_train_compiled.load_state_dict(model_train_normal.state_dict())

# Compile the training model. The returned wrapper shares the wrapped
# module's parameters, so the optimizer below updates the real weights.
model_train_compiled = torch.compile(model_train_compiled)

# Create optimizers and criterion
optimizer_normal = torch.optim.Adam(model_train_normal.parameters(), lr=0.001)
optimizer_compiled = torch.optim.Adam(model_train_compiled.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Train for a few epochs
num_epochs = 3
print("Training comparison:")
print("-" * 50)

for epoch in range(num_epochs):
    # Train normal model
    loss_normal, time_normal = train_epoch(
        model_train_normal, train_loader, optimizer_normal, criterion, device
    )

    # Train compiled model. The first epoch also pays the one-time compilation
    # cost, so later epochs give the fairer speed comparison.
    loss_compiled, time_compiled = train_epoch(
        model_train_compiled, train_loader, optimizer_compiled, criterion, device
    )

    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Normal - Loss: {loss_normal:.4f}, Time: {time_normal:.2f}s")
    print(f"  Compiled - Loss: {loss_compiled:.4f}, Time: {time_compiled:.2f}s")
    if time_compiled > 0:
        print(f"  Speedup: {time_normal/time_compiled:.2f}x")