# Tutorial 14: Performance Optimization

This tutorial covers comprehensive performance optimization techniques for PyTorch models, from profiling to advanced optimization strategies.

In [None]:
import gc
import itertools
import time

import matplotlib.pyplot as plt
import numpy as np
import psutil
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision
import torchvision.transforms as transforms
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.profiler import profile, record_function, ProfilerActivity
from torch.utils.data import DataLoader, Dataset

# Run on the GPU when PyTorch can see one; otherwise everything stays on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. PyTorch Profiler

The PyTorch profiler is essential for identifying performance bottlenecks in your models.

In [None]:
# Define a simple model for profiling
class SimpleModel(nn.Module):
    """Small two-conv / two-linear CNN used as a profiling target.

    ``fc1`` takes ``128 * 8 * 8`` features, so the forward pass expects
    3x32x32 images: the padded 3x3 convs keep the spatial size and the two
    2x2 max-pools shrink 32x32 down to 8x8.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor (padding=1 preserves height/width).
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        # Classifier head over the flattened 128x8x8 feature map.
        self.fc1 = nn.Linear(128 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Map an ``(N, 3, 32, 32)`` batch to ``(N, 10)`` logits."""
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        hidden = F.relu(self.fc1(x.view(x.size(0), -1)))
        return self.fc2(hidden)

In [None]:
# Profile the model
model = SimpleModel().to(device)
inputs = torch.randn(32, 3, 32, 32).to(device)

# Only request CUDA tracing when the model actually runs on a GPU. On a
# CPU-only run there are no CUDA events, so sorting by CUDA time would order
# the table arbitrarily; sort by CPU time instead.
use_cuda = device.type == 'cuda'
activities = [ProfilerActivity.CPU]
if use_cuda:
    activities.append(ProfilerActivity.CUDA)
sort_key = "cuda_time_total" if use_cuda else "cpu_time_total"

# This is an inference profile, so skip autograd bookkeeping.
with torch.no_grad():
    # Warm up first so one-time costs (CUDA context creation, library
    # initialization, allocator growth) do not dominate the profile.
    for _ in range(3):
        model(inputs)
    if use_cuda:
        torch.cuda.synchronize()

    # Use PyTorch profiler
    with profile(activities=activities,
                 record_shapes=True,
                 profile_memory=True,
                 with_stack=True) as prof:
        with record_function("model_inference"):
            for _ in range(10):
                model(inputs)

# Print profiler results
print(prof.key_averages().table(sort_by=sort_key, row_limit=10))

## 2. Memory Optimization

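Memory optimization is crucial for training large models. Let's explore gradient checkpointing, which discards intermediate activations during the forward pass and recomputes them during the backward pass, trading extra compute for lower memory.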

In [None]:
def get_memory_usage():
    """Return a memory reading in MB (units of 1024**2 bytes).

    With CUDA available this is the memory currently occupied by tensors in
    PyTorch's CUDA allocator (``torch.cuda.memory_allocated``). Otherwise it
    is the resident set size of the whole Python process (via ``psutil``).
    The two readings measure different things, so only compare values taken
    on the same kind of device.
    """
    bytes_per_mb = 1024**2
    if not torch.cuda.is_available():
        return psutil.Process().memory_info().rss / bytes_per_mb  # process RSS
    return torch.cuda.memory_allocated() / bytes_per_mb  # live CUDA tensors

# Memory-efficient gradient checkpointing
class CheckpointedModel(nn.Module):
    """Ten 1024-wide Linear/ReLU/Dropout blocks plus a 10-way linear head,
    with gradient checkpointing applied to every block.

    Checkpointing keeps only each block's input during the forward pass and
    recomputes the block's internals during backward, trading compute for
    activation memory.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(1024, 1024),
                nn.ReLU(),
                nn.Dropout(0.1)
            ) for _ in range(10)
        ])
        self.final = nn.Linear(1024, 10)

    def forward(self, x):
        """Return ``(batch, 10)`` outputs for a ``(batch, 1024)`` input."""
        # Checkpointing only pays off while autograd is recording; under
        # torch.no_grad() there is nothing to save, so run the layers directly.
        use_checkpoint = torch.is_grad_enabled()
        for layer in self.layers:
            if use_checkpoint:
                # use_reentrant=False: the reentrant variant only builds a graph
                # when one of its *inputs* requires grad, so with a plain data
                # tensor x the block parameters would silently receive no
                # gradients. Recent PyTorch also asks for this flag explicitly.
                x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return self.final(x)

In [None]:
# Compare memory usage.
# Checkpointing shrinks the activations autograd keeps alive between the
# forward and the backward pass, so measure right after forward(). After
# backward() those activations are already freed and both models hold the
# same parameter gradients, which would hide the difference.
# NOTE: without CUDA, get_memory_usage() reports whole-process RSS, which is
# only a rough, noisy indicator.
x = torch.randn(128, 1024).to(device)

# Without checkpointing
regular_model = nn.Sequential(*[
    nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(0.1))
    for _ in range(10)
] + [nn.Linear(1024, 10)]).to(device)

mem_before = get_memory_usage()
y1 = regular_model(x)
mem_regular = get_memory_usage() - mem_before  # output + all saved activations
loss1 = y1.sum()
loss1.backward()
print(f"Regular model: {mem_regular:.2f} MB")

# Clear memory
del regular_model, y1, loss1
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# With checkpointing
checkpointed_model = CheckpointedModel().to(device)

mem_before = get_memory_usage()
y2 = checkpointed_model(x)
mem_checkpoint = get_memory_usage() - mem_before  # output + what checkpointing keeps
loss2 = y2.sum()
loss2.backward()
print(f"Checkpointed model: {mem_checkpoint:.2f} MB")

# Guard against a zero/negative baseline (e.g. RSS did not grow on CPU).
if mem_regular > 0:
    print(f"Memory saved: {(1 - mem_checkpoint/mem_regular)*100:.1f}%")
else:
    print("Memory saved: n/a (no measurable baseline)")

## 3. Mixed Precision Training

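Mixed precision training runs most operations in 16-bit floating point while keeping numerically sensitive ones in FP32. On GPUs with Tensor Cores this can significantly speed up training while maintaining model accuracy. `torch.cuda.amp` only affects CUDA operations, so expect no speedup on a CPU-only run.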

In [None]:
# Create a model for mixed precision demo
class MixedPrecisionModel(nn.Module):
    """Two Conv-BN-ReLU stages, global average pooling and a 10-way linear
    head; used to compare FP32 and mixed-precision training speed."""

    def __init__(self):
        super().__init__()
        conv_bn_relu = [
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        ]
        # Adaptive pooling to 1x1 makes the head independent of input H x W.
        self.features = nn.Sequential(*conv_bn_relu, nn.AdaptiveAvgPool2d(1))
        self.classifier = nn.Linear(128, 10)

    def forward(self, x):
        """Return ``(N, 10)`` logits for an ``(N, 3, H, W)`` batch."""
        pooled = self.features(x)
        return self.classifier(pooled.view(pooled.size(0), -1))

In [None]:
# Training function with mixed precision
def train_with_amp(model, dataloader, use_amp=True, epochs=2, max_batches=10):
    """Train ``model`` briefly and report per-epoch wall-clock time and loss.

    Args:
        model: module to train; it is moved to the module-level ``device``.
        dataloader: iterable yielding ``(inputs, targets)`` batches suitable
            for ``F.cross_entropy``.
        use_amp: when True, run the forward pass and loss under
            ``amp.autocast()`` and scale the loss with ``amp.GradScaler``;
            otherwise train in full precision.
        epochs: number of epochs to time.
        max_batches: at most this many batches are used per epoch to keep the
            demo short; ``None`` uses the whole dataloader.

    Returns:
        ``(average_epoch_seconds, losses)`` where ``losses[e]`` is the mean
        training loss over the batches actually processed in epoch ``e``
        (NaN if the dataloader produced no batches).
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters())
    scaler = amp.GradScaler() if use_amp else None
    on_cuda = device.type == 'cuda'

    model.train()
    total_time = 0.0
    losses = []

    for _ in range(epochs):
        if on_cuda:
            torch.cuda.synchronize()  # don't charge earlier queued work to this epoch
        epoch_start = time.perf_counter()
        epoch_loss = 0.0
        num_batches = 0

        # islice stops after max_batches without fetching (and timing) an extra batch.
        for inputs, targets in itertools.islice(dataloader, max_batches):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            if use_amp:
                # Mixed precision forward pass
                with amp.autocast():
                    outputs = model(inputs)
                    loss = F.cross_entropy(outputs, targets)

                # Scaled backward pass
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                # Regular forward/backward
                outputs = model(inputs)
                loss = F.cross_entropy(outputs, targets)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        if on_cuda:
            torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
        total_time += time.perf_counter() - epoch_start
        # Average over the batches that were actually processed.
        losses.append(epoch_loss / num_batches if num_batches else float('nan'))

    return (total_time / epochs if epochs else 0.0), losses

In [None]:
# Create dummy dataset
class DummyDataset(Dataset):
    """Synthetic classification data of a fixed length.

    Every access draws a fresh random 3x32x32 image and a random integer
    label in [0, 10); the requested index is ignored.
    """

    def __init__(self, size=1000):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        image = torch.randn(3, 32, 32)
        label = int(torch.randint(0, 10, (1,))[0])
        return image, label

dataset = DummyDataset()
dataloader = DataLoader(dataset, batch_size=64, num_workers=0)

# Warm-up runs (results discarded). On a GPU the first training run also pays
# one-time costs such as CUDA context and library initialization; without a
# warm-up those costs land on whichever configuration runs first (FP32 here)
# and inflate the reported speedup.
train_with_amp(MixedPrecisionModel(), dataloader, use_amp=False, epochs=1)
train_with_amp(MixedPrecisionModel(), dataloader, use_amp=True, epochs=1)

# Compare training times
model_fp32 = MixedPrecisionModel()
model_amp = MixedPrecisionModel()

print("Training with FP32...")
time_fp32, losses_fp32 = train_with_amp(model_fp32, dataloader, use_amp=False)
print(f"Average epoch time: {time_fp32:.3f}s")

print("\nTraining with AMP...")
time_amp, losses_amp = train_with_amp(model_amp, dataloader, use_amp=True)
print(f"Average epoch time: {time_amp:.3f}s")
print(f"Speedup: {time_fp32/time_amp:.2f}x")
if device.type != 'cuda':
    # torch.cuda.amp only changes the precision of CUDA operations.
    print("Note: no GPU in use, so AMP cannot speed anything up here.")

In [None]:
# Visualize training losses: one line per precision mode, indexed by epoch.
fig, ax = plt.subplots(figsize=(10, 5))
for series, fmt, name in ((losses_fp32, 'o-', 'FP32'),
                          (losses_amp, 's-', 'Mixed Precision')):
    ax.plot(series, fmt, label=name, linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Comparison')
ax.legend()
ax.grid(True)
plt.show()

## 4. Data Loading Optimization

Efficient data loading is crucial for GPU utilization. Let's explore various optimization techniques.

In [None]:
# Optimized dataset with caching and prefetching
class OptimizedDataset(Dataset):
    """Synthetic image dataset with a small in-memory sample cache and
    on-the-fly augmentation.

    Each index maps to a fixed raw ``(image, label)`` pair generated from a
    per-index seed, so a cached sample and a regenerated uncached one are the
    same data; the cache is a pure speed optimization. Augmentation (random
    flip/crop plus normalization) runs on every access and is never cached,
    so repeated reads of an index still get fresh random augmentations.

    With ``num_workers > 0`` each DataLoader worker process works on its own
    copy of the dataset and therefore fills its own independent cache.
    """

    def __init__(self, size=1000, cache_size=100):
        self.size = size
        self.cache_size = cache_size
        self.cache = {}
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        return self.size

    def _load_raw(self, idx):
        """Simulate loading sample ``idx``; deterministic for a given index."""
        gen = torch.Generator().manual_seed(idx)
        image = torch.randn(3, 32, 32, generator=gen)
        label = int(torch.randint(0, 10, (1,), generator=gen))
        return image, label

    def __getitem__(self, idx):
        sample = self.cache.get(idx)
        if sample is None:
            sample = self._load_raw(idx)
            # Keep the first `cache_size` distinct samples seen (no eviction).
            if len(self.cache) < self.cache_size:
                self.cache[idx] = sample
        image, label = sample
        # Augment on every access so each epoch sees new random flips/crops.
        return self.transform(image), label

In [None]:
# Benchmark data loading performance
def benchmark_dataloader(dataset, num_workers, pin_memory=False, prefetch_factor=2,
                         num_batches=50, warmup_batches=5, target_device=None):
    """Time how long a DataLoader takes to deliver batches to a device.

    Args:
        dataset: map-style dataset yielding ``(data, target)`` pairs.
        num_workers: DataLoader worker processes (0 loads in the main process).
        pin_memory: pin host batches so host-to-GPU copies can run asynchronously.
        prefetch_factor: batches prefetched per worker; only passed to the
            DataLoader when ``num_workers > 0``.
        num_batches: maximum number of batches inside the timed region.
        warmup_batches: batches consumed before timing starts. Workers are
            persistent, so their start-up cost stays outside the timed region.
        target_device: where each batch is copied; defaults to the global
            ``device``.

    Returns:
        Elapsed wall-clock seconds for the timed batches.
    """
    if target_device is None:
        target_device = device
    target_device = torch.device(target_device)

    loader_kwargs = {'batch_size': 128, 'num_workers': num_workers,
                     'pin_memory': pin_memory}
    # prefetch_factor and persistent_workers only apply to worker processes;
    # recent PyTorch raises ValueError if prefetch_factor is given with
    # num_workers=0, so only pass them when workers are used.
    if num_workers > 0:
        loader_kwargs['persistent_workers'] = True
        loader_kwargs['prefetch_factor'] = prefetch_factor
    dataloader = DataLoader(dataset, **loader_kwargs)

    # Warmup
    for _ in itertools.islice(dataloader, warmup_batches):
        pass

    start_time = time.perf_counter()
    for data, _ in itertools.islice(dataloader, num_batches):
        # Simulate processing: copy to the device as a training loop would.
        data.to(target_device, non_blocking=True)
    if target_device.type == 'cuda':
        torch.cuda.synchronize()  # non_blocking copies may still be in flight
    return time.perf_counter() - start_time

dataset = OptimizedDataset(5000)

# Test different configurations as (num_workers, pin_memory) pairs.
configs = [
    (0, False),
    (2, False),
    (2, True),
    (4, False),
    (4, True),
]

print("Data loading benchmark:")
results = []
for workers_cfg, pin_cfg in configs:
    elapsed = benchmark_dataloader(dataset, workers_cfg, pin_cfg)
    results.append((workers_cfg, pin_cfg, elapsed))
    print(f"Workers: {workers_cfg}, Pin memory: {pin_cfg} - Time: {elapsed:.3f}s")

In [None]:
# Visualize data loading performance: one bar per (workers, pin_memory) run.
fig, ax = plt.subplots(figsize=(10, 6))

workers = [w for w, _, _ in results]
pin_memory = [pm for _, pm, _ in results]
times = [t for _, _, t in results]

# Red bars mark pinned host memory, blue bars pageable memory.
colors = ['red' if pm else 'blue' for pm in pin_memory]
labels = [f"Workers: {w}\nPin: {pm}" for w, pm in zip(workers, pin_memory)]
positions = range(len(results))

bars = ax.bar(positions, times, color=colors)
ax.set(xlabel='Configuration', ylabel='Time (seconds)',
       title='Data Loading Performance')
ax.set_xticks(positions)
ax.set_xticklabels(labels, rotation=45)

# Legend proxies: one colored swatch per pin_memory setting.
blue_patch, red_patch = (plt.Rectangle((0, 0), 1, 1, fc=c) for c in ("blue", "red"))
ax.legend([blue_patch, red_patch], ['No Pin Memory', 'Pin Memory'])

plt.tight_layout()
plt.show()

## 5. TorchScript Optimization

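TorchScript compiles a model into a static graph that can be optimized and run without the Python interpreter. How much this speeds up inference depends on the model: small models like the one below may see little gain. In PyTorch 2.x, `torch.compile` is the recommended route to speedups, while TorchScript remains useful for deployment.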

In [None]:
# Create a model for scripting
class ScriptableModel(nn.Module):
    """Small CNN written with TorchScript-compatible code only.

    With unpadded 3x3 convolutions and two 2x2 max-pools, a 3x32x32 input is
    reduced to a 64x6x6 feature map (32 -> 30 -> 15 -> 13 -> 6), which is
    the input size ``fc`` expects.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc = nn.Linear(64 * 6 * 6, 10)

    def forward(self, x):
        h = F.max_pool2d(F.relu(self.conv1(x)), 2)
        h = F.max_pool2d(F.relu(self.conv2(h)), 2)
        return self.fc(torch.flatten(h, 1))

In [None]:
# Compare scripted vs regular model
model = ScriptableModel().to(device)
model.eval()

# Script the model
scripted_model = torch.jit.script(model)

# Also try tracing
example_input = torch.randn(1, 3, 32, 32).to(device)
traced_model = torch.jit.trace(model, example_input)

# Benchmark
x = torch.randn(100, 3, 32, 32).to(device)
num_runs = 100
num_warmup = 10


def _time_model(fn, inputs, runs, warmup):
    """Return wall-clock seconds for ``runs`` no-grad calls of ``fn(inputs)``.

    Warm-up calls run first. TorchScript optimizes its graph during the
    first few invocations, and CUDA has one-time setup costs, so timing cold
    calls would penalize whichever model happens to run first. CUDA kernels
    run asynchronously, hence the synchronize calls around the timed loop.
    """
    with torch.no_grad():
        for _ in range(warmup):
            fn(inputs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(runs):
            fn(inputs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    return time.perf_counter() - start


regular_time = _time_model(model, x, num_runs, num_warmup)
scripted_time = _time_model(scripted_model, x, num_runs, num_warmup)
traced_time = _time_model(traced_model, x, num_runs, num_warmup)

print(f"Regular model: {regular_time:.3f}s")
print(f"Scripted model: {scripted_time:.3f}s (Speedup: {regular_time/scripted_time:.2f}x)")
print(f"Traced model: {traced_time:.3f}s (Speedup: {regular_time/traced_time:.2f}x)")

## 6. Tensor Operations Optimization

Efficient tensor operations are crucial for performance. Let's compare different approaches.

In [None]:
# Compare different implementations of the same operation
def batch_norm_manual(x, gamma, beta, eps=1e-5):
    """Batch normalization spelled out with elementary tensor ops (inefficient).

    Normalizes each channel of an ``(N, C, H, W)`` tensor with the mean and
    biased variance taken over the batch and spatial dimensions, then applies
    the per-channel scale ``gamma`` and shift ``beta``.
    """
    reduce_dims = (0, 2, 3)  # everything except the channel axis
    mu = x.mean(dim=reduce_dims, keepdim=True)
    centered = x - mu
    sigma2 = centered.pow(2).mean(dim=reduce_dims, keepdim=True)
    normalized = centered / (sigma2 + eps).sqrt()
    # Broadcast the per-channel affine parameters over N, H and W.
    scale = gamma.view(1, -1, 1, 1)
    shift = beta.view(1, -1, 1, 1)
    return scale * normalized + shift

def batch_norm_optimized(x, gamma, beta, eps=1e-5):
    """Batch normalization through PyTorch's native ``F.batch_norm`` op.

    ``training=True`` with no running-stat buffers means the current batch's
    statistics are used, which is the same normalization as
    ``batch_norm_manual``. Nothing is tracked across calls, so ``momentum``
    has no effect here.
    """
    return F.batch_norm(x, None, None, weight=gamma, bias=beta,
                        training=True, momentum=0.1, eps=eps)

# Test
batch_size, channels, height, width = 32, 64, 32, 32
x = torch.randn(batch_size, channels, height, width).to(device)
gamma = torch.ones(channels).to(device)
beta = torch.zeros(channels).to(device)

# Benchmark
num_runs = 100
num_warmup = 10


def _time_op(fn, runs, warmup):
    """Return wall-clock seconds for ``runs`` calls of ``fn(x, gamma, beta)``.

    CUDA kernels launch asynchronously; without synchronizing before reading
    the clock, the loop would time kernel launches rather than the work. The
    warm-up calls keep one-time setup costs out of the measurement.
    """
    for _ in range(warmup):
        fn(x, gamma, beta)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(runs):
        fn(x, gamma, beta)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter() - start


# Both functions normalize with batch statistics, so their outputs should agree.
outputs_match = torch.allclose(batch_norm_manual(x, gamma, beta),
                               batch_norm_optimized(x, gamma, beta), atol=1e-4)
print(f"Outputs match: {outputs_match}")

manual_time = _time_op(batch_norm_manual, num_runs, num_warmup)
optimized_time = _time_op(batch_norm_optimized, num_runs, num_warmup)

print(f"Manual batch norm: {manual_time:.3f}s")
print(f"Optimized batch norm: {optimized_time:.3f}s")
print(f"Speedup: {manual_time/optimized_time:.1f}x")

## 7. Memory-Efficient Attention

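For transformer models, attention can be a memory bottleneck: for a sequence of length $N$, the score matrix has $N \times N$ entries per head. Let's implement memory-efficient attention that processes queries in chunks, so that during inference only a $\text{chunk\_size} \times N$ slice of scores per head exists at a time.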

In [None]:
class EfficientAttention(nn.Module):
    """Multi-head self-attention that processes queries in chunks.

    Standard attention materializes a ``(B, heads, N, N)`` score matrix at
    once. Here queries are handled ``chunk_size`` positions at a time, so each
    intermediate score matrix is at most ``(B, heads, chunk_size, N)``. This
    lowers *peak* memory when autograd is not recording (e.g. under
    ``torch.no_grad()``). During training every chunk's attention weights
    are kept for backward, so the stored total matches standard attention.

    Args:
        dim: embedding size ``C``; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        chunk_size: number of query positions per chunk; must be positive.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads`` or
            ``chunk_size`` is not positive.
    """

    def __init__(self, dim, num_heads=8, chunk_size=256):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        self.num_heads = num_heads
        self.chunk_size = chunk_size
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, N, C)`` and return ``(B, N, C)``."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        k_t = k.transpose(-2, -1)  # shared by every query chunk

        # Chunked attention computation
        attn_chunks = []
        for start in range(0, N, self.chunk_size):
            q_chunk = q[:, :, start:start + self.chunk_size]  # slicing clamps at N
            attn = (q_chunk @ k_t) * self.scale
            attn = attn.softmax(dim=-1)
            attn_chunks.append(attn @ v)

        # Concatenate chunks back along the sequence axis
        x = torch.cat(attn_chunks, dim=2)
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

# Compare with standard attention
class StandardAttention(nn.Module):
    """Reference multi-head self-attention that materializes the full score matrix.

    The ``(B, heads, N, N)`` attention matrix is built in one shot, so its
    memory grows quadratically with the sequence length ``N``.

    Args:
        dim: embedding size ``C``; must be divisible by ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, N, C)`` and return ``(B, N, C)``."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

In [None]:
# Test memory usage
seq_len = 2048
dim = 512
batch_size = 4

standard_attn = StandardAttention(dim).to(device)
efficient_attn = EfficientAttention(dim, chunk_size=256).to(device)
# Share weights so the two implementations can be checked against each other.
efficient_attn.load_state_dict(standard_attn.state_dict())

x = torch.randn(batch_size, seq_len, dim).to(device)


def _forward_with_peak_mb(module, inputs):
    """Run an inference forward pass; return ``(output, extra_memory_mb)``.

    Chunking only bounds the *transient* score matrices, which are freed
    before ``forward`` returns. On CUDA we therefore read the allocator's peak
    rather than what is still allocated afterwards. Without CUDA this falls
    back to the change in ``get_memory_usage()`` (process RSS), which is only
    a rough indicator.
    """
    # No autograd: otherwise every chunk's weights are kept for backward and
    # the chunked version stores as much as the standard one.
    with torch.no_grad():
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()
            baseline = torch.cuda.memory_allocated()
            out = module(inputs)
            torch.cuda.synchronize()
            return out, (torch.cuda.max_memory_allocated() - baseline) / 1024**2
        before = get_memory_usage()
        out = module(inputs)
        return out, get_memory_usage() - before


# Measure memory for standard attention
if torch.cuda.is_available():
    torch.cuda.empty_cache()
output1, mem_standard = _forward_with_peak_mb(standard_attn, x)

# Measure memory for efficient attention
output2, mem_efficient = _forward_with_peak_mb(efficient_attn, x)

print(f"Outputs match: {torch.allclose(output1, output2, atol=1e-4)}")

# Clear memory
del output1
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print(f"Standard attention memory: {mem_standard:.2f} MB")
print(f"Efficient attention memory: {mem_efficient:.2f} MB")
if mem_standard > 0:
    print(f"Memory saved: {(1 - mem_efficient/mem_standard)*100:.1f}%")
else:
    print("Memory saved: n/a (no measurable baseline)")

## 8. Custom Memory Pool

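For applications with frequent tensor allocations, a custom memory pool can reduce overhead. Note that PyTorch's CUDA caching allocator already reuses freed memory, so measure before adopting one: the gains are often small.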

In [None]:
class TensorPool:
    """Simple tensor pool for reusing allocations.

    Tensors are bucketed by ``(shape, dtype, device)``. ``get`` pops a pooled
    tensor when one is available (a hit) or allocates a new, uninitialized
    one with ``torch.empty`` (a miss). ``release`` hands a tensor back for
    reuse. Pooled tensors keep whatever data they last held, and callers
    should not keep using a tensor after releasing it. The pool has no size
    limit.
    """

    def __init__(self):
        self.pool = {}
        self.stats = {'hits': 0, 'misses': 0}

    @staticmethod
    def _key(shape, dtype, device):
        """Canonical bucket key shared by ``get`` and ``release``."""
        dev = torch.device(device)
        # torch.device('cuda') carries no index, but allocated tensors report
        # 'cuda:<n>'. Resolve the index so a requested device and a released
        # tensor's device land in the same bucket; otherwise nothing ever hits.
        if dev.type == 'cuda' and dev.index is None:
            dev = torch.device('cuda', torch.cuda.current_device())
        return (tuple(shape), dtype, str(dev))

    def get(self, shape, dtype=torch.float32, device='cpu'):
        """Return a ``shape``/``dtype`` tensor on ``device``, reusing a pooled one if possible.

        ``shape`` may be a single int or a sequence of ints, as for ``torch.empty``.
        """
        try:
            shape = tuple(shape)
        except TypeError:  # a single dimension, e.g. get(5)
            shape = (shape,)
        key = self._key(shape, dtype, device)
        bucket = self.pool.get(key)
        if bucket:
            self.stats['hits'] += 1
            return bucket.pop()
        self.stats['misses'] += 1
        return torch.empty(shape, dtype=dtype, device=device)

    def release(self, tensor):
        """Return ``tensor`` to the pool so a later ``get`` can reuse it."""
        key = self._key(tensor.shape, tensor.dtype, tensor.device)
        self.pool.setdefault(key, []).append(tensor)

    def clear(self):
        """Drop every pooled tensor (hit/miss statistics are kept)."""
        self.pool.clear()

    def get_stats(self):
        """Return hit/miss counts, the hit rate and the number of pooled tensors."""
        total = self.stats['hits'] + self.stats['misses']
        hit_rate = self.stats['hits'] / total if total > 0 else 0
        return {
            'hit_rate': hit_rate,
            'pool_size': sum(len(v) for v in self.pool.values()),
            **self.stats
        }

In [None]:
# Demonstrate tensor pool usage
pool = TensorPool()

# Simulate workload with repeated allocations
shapes = [(100, 100), (50, 200), (100, 100), (50, 200)]
num_iterations = 100


def _sync():
    """Wait for queued CUDA work so the timers measure completed kernels."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


# Without pool (torch.empty already goes through PyTorch's caching allocator)
_sync()
start = time.perf_counter()
for _ in range(num_iterations):
    tensors = []
    for shape in shapes:
        t = torch.empty(shape, device=device)
        tensors.append(t)
    # Simulate some computation
    for t in tensors:
        t.fill_(1.0)
_sync()
time_without_pool = time.perf_counter() - start

# With pool
_sync()
start = time.perf_counter()
for _ in range(num_iterations):
    tensors = []
    for shape in shapes:
        t = pool.get(shape, device=device)
        tensors.append(t)
    # Simulate some computation
    for t in tensors:
        t.fill_(1.0)
    # Release tensors back to pool
    for t in tensors:
        pool.release(t)
_sync()
time_with_pool = time.perf_counter() - start

print(f"Time without pool: {time_without_pool:.3f}s")
print(f"Time with pool: {time_with_pool:.3f}s")
print(f"Speedup: {time_without_pool/time_with_pool:.2f}x")
print(f"\nPool statistics: {pool.get_stats()}")

## Performance Optimization Checklist

Here's a comprehensive checklist for optimizing PyTorch models:

In [None]:
# Performance optimization checklist as (task, hint) pairs, printed as a
# numbered list of unchecked boxes.
checklist = [
    ("Profile with torch.profiler", "Identify bottlenecks before optimizing"),
    ("Enable mixed precision training", "Use torch.cuda.amp for faster training"),
    ("Optimize data loading pipeline", "Use multiple workers, pin_memory, and persistent_workers"),
    ("Use gradient checkpointing", "Trade compute for memory in large models"),
    ("Apply model quantization", "Reduce model size and improve inference speed"),
    ("Enable CUDNN benchmarking", "torch.backends.cudnn.benchmark = True"),
    ("Use TorchScript for inference", "Compile models for production deployment"),
    ("Implement custom CUDA kernels", "For performance-critical operations"),
    ("Use distributed training", "Scale across multiple GPUs/nodes"),
    ("Monitor GPU utilization", "Ensure high GPU utilization throughout training"),
    ("Optimize tensor operations", "Use vectorized operations and avoid loops"),
    ("Reduce memory fragmentation", "Clear cache and reuse tensors when possible"),
    ("Use operator fusion", "Combine multiple operations into one"),
    ("Optimize model architecture", "Use efficient layers and reduce redundancy"),
    ("Profile memory usage", "Identify and fix memory leaks"),
]

print("Performance Optimization Checklist")
print("=" * 50)
for number, (task, description) in enumerate(checklist, start=1):
    # Checkbox line, indented hint line, then a blank separator line.
    print(f"{number:2d}. [ ] {task}\n       {description}\n")

## Summary

In this tutorial, we covered comprehensive performance optimization techniques:

1. **Profiling**: Using PyTorch profiler to identify bottlenecks
2. **Memory Optimization**: Gradient checkpointing and efficient memory usage
3. **Mixed Precision**: Accelerating training with automatic mixed precision
4. **Data Loading**: Optimizing data pipelines for maximum throughput
5. **TorchScript**: Compiling models for faster inference
6. **Tensor Operations**: Using efficient implementations
7. **Custom Attention**: Memory-efficient attention mechanisms
8. **Memory Pooling**: Reducing allocation overhead

Remember: **Always profile before optimizing!** Premature optimization can lead to complex code without meaningful performance gains.