# Distributed Training

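This notebook demonstrates distributed training techniques in PyTorch: DataParallel (DP), DistributedDataParallel (DDP), model parallelism, pipeline parallelism, and Fully Sharded Data Parallel (FSDP). The multi-process examples (DDP, FSDP) are shown as code to launch with `torchrun` rather than executed inside the notebook kernel.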

In [None]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DataParallel, DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
import matplotlib.pyplot as plt
import numpy as np

# Seed every RNG this notebook draws from so a fresh Restart & Run All reproduces results
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Report the software/hardware environment the outputs below were produced on
for label, value in (
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
    ("Number of GPUs", torch.cuda.device_count()),
):
    print(f"{label}: {value}")

## 1. Introduction to Distributed Training

Distributed training enables us to:
- Scale training across multiple GPUs and machines
- Train larger models that don't fit on a single GPU
- Reduce training time significantly

In [None]:
# Types of parallelism
print("Types of Parallelism:")
print("1. Data Parallel: Split data across devices, replicate model")
print("2. Model Parallel: Split model across devices")
print("3. Pipeline Parallel: Split model into stages")
print("4. Fully Sharded Data Parallel: Shard everything across devices")

# Ask this PyTorch build which collective backends it actually ships with,
# instead of assuming them. Gloo does not need CUDA (setup_ddp falls back to
# it on CPU-only machines), so it is reported even without a GPU.
if dist.is_available():
    print("\nAvailable backends:")
    if dist.is_nccl_available():
        print("- NCCL (recommended for GPUs)")
    if dist.is_gloo_available():
        print("- Gloo (CPU and GPU support)")
else:
    print("\ntorch.distributed is not available in this PyTorch build")

## 2. Create Sample Dataset and Model

In [None]:
class SyntheticDataset(Dataset):
    """A synthetic classification dataset for demonstration.

    Each item is a ``(data, label)`` pair:

    - ``data`` is a float tensor of shape ``(input_dim,)`` drawn from a standard normal.
    - ``label`` is an int in ``[0, num_classes)``.

    The label is drawn independently of the data, so there is nothing to learn.
    The dataset only exercises the training and data-loading machinery.

    Args:
        size: number of items reported by ``len()``.
        input_dim: length of each feature vector.
        num_classes: number of distinct label values.
        seed: base seed. Item ``idx`` is generated from ``seed + idx``, so the
            same index always yields the same sample in every epoch, DataLoader
            worker and DDP process. Pass ``None`` to draw fresh samples from the
            global torch RNG on every access.

    Raises:
        IndexError: from ``__getitem__`` when ``idx`` is outside ``[-size, size)``.
    """
    def __init__(self, size=10000, input_dim=784, num_classes=10, seed=42):
        self.size = size
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.seed = seed

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        idx = int(idx)
        pos = idx + self.size if idx < 0 else idx
        # Raising IndexError keeps Python's sequence protocol (``for x in ds``) finite.
        if not 0 <= pos < self.size:
            raise IndexError(f"index {idx} out of range for dataset of size {self.size}")

        # A per-index generator makes item ``pos`` a pure function of (seed, pos).
        generator = None
        if self.seed is not None:
            generator = torch.Generator().manual_seed(self.seed + pos)
        data = torch.randn(self.input_dim, generator=generator)
        label = torch.randint(0, self.num_classes, (1,), generator=generator).item()
        return data, label

class SimpleNet(nn.Module):
    """Three-layer MLP classifier used throughout the notebook.

    Layout: Linear(input_dim -> hidden_dim) -> ReLU -> Dropout(0.2)
            -> Linear(hidden_dim -> hidden_dim) -> ReLU -> Dropout(0.2)
            -> Linear(hidden_dim -> num_classes).
    ``forward`` returns the raw logits of the last layer (no softmax).
    """
    def __init__(self, input_dim=784, hidden_dim=256, num_classes=10):
        super().__init__()
        # Registration order (fc1, fc2, fc3, dropout) fixes both the order in
        # which parameters are randomly initialized and the state_dict keys.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        hidden = x
        # The two hidden layers share the same Linear -> ReLU -> Dropout pattern
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(F.relu(layer(hidden)))
        return self.fc3(hidden)

# Test the model: report its size, then push a tiny random batch through it
model = SimpleNet()
param_count = sum(map(torch.numel, model.parameters()))
print("Model parameters: {:,}".format(param_count))

# Test forward pass
dummy_input = torch.randn(4, 784)
output = model(dummy_input)
print("Output shape: {}".format(output.shape))

## 3. Data Parallel (DP)

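DataParallel is the simplest way to use multiple GPUs on a single machine. It runs in a single process and splits each batch across the GPUs with threads. PyTorch recommends DDP instead, even on one machine.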

In [None]:
# Data Parallel example: wrap the model in DataParallel only when >= 2 GPUs
# are visible; otherwise run the same forward pass on a single device.
n_visible_gpus = torch.cuda.device_count()
use_data_parallel = n_visible_gpus >= 2

if use_data_parallel:
    print(f"Using {n_visible_gpus} GPUs with DataParallel")
else:
    print("DataParallel requires at least 2 GPUs")
    print("Demonstrating concept with single device...")

model_dp = SimpleNet()
if use_data_parallel:
    # DataParallel splits each input batch across the visible GPUs
    model_dp = DataParallel(model_dp)

batch_size = 32
data = torch.randn(batch_size, 784)
if torch.cuda.is_available():
    model_dp = model_dp.cuda()
    data = data.cuda()

output = model_dp(data)
print(f"Output shape: {output.shape}")

if use_data_parallel:
    per_gpu = batch_size // n_visible_gpus
    print(f"Batch split across GPUs: {batch_size} / {n_visible_gpus} = {per_gpu} per GPU")

### Training with DataParallel

In [None]:
def train_with_dp(num_epochs=2, dataset_size=1000, batch_size=64, lr=0.001, log_interval=5):
    """Train a SimpleNet on SyntheticDataset, using DataParallel when >1 GPU is visible.

    Args:
        num_epochs: number of passes over the dataset.
        dataset_size: number of synthetic samples.
        batch_size: global batch size (DataParallel splits it across GPUs).
        lr: Adam learning rate.
        log_interval: print the batch loss every this many batches; 0 disables it.

    Returns:
        list[float]: the average training loss of each epoch.
    """
    model = SimpleNet()

    # Wrap with DataParallel if multiple GPUs available
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        print(f"Using {n_gpus} GPUs")
        model = DataParallel(model)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    use_cuda = device.type == 'cuda'

    # Pinned host memory lets the non_blocking host->GPU copies below run asynchronously
    dataset = SyntheticDataset(size=dataset_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=use_cuda)
    num_batches = len(dataloader)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    epoch_losses = []
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch_idx, (data, target) in enumerate(dataloader):
            data = data.to(device, non_blocking=use_cuda)
            target = target.to(device, non_blocking=use_cuda)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # .item() synchronizes with the device; do it once per step
            batch_loss = loss.item()
            total_loss += batch_loss

            if log_interval and batch_idx % log_interval == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}/{num_batches}, Loss: {batch_loss:.4f}")

        avg_loss = total_loss / num_batches
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch} - Average Loss: {avg_loss:.4f}\n")

    return epoch_losses

# Train with DataParallel
dp_epoch_losses = train_with_dp(num_epochs=2)

## 4. Distributed Data Parallel (DDP)

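DDP is more efficient than DP and supports multi-node training. It runs one process per GPU and all-reduces gradients during the backward pass, which avoids DP's per-step model replication and GIL contention.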

In [None]:
# DDP setup functions
def setup_ddp(rank, world_size, master_addr='localhost', master_port='12355'):
    """Initialize the default process group for one DDP worker process.

    Uses the NCCL backend when CUDA is available, otherwise Gloo.

    Args:
        rank: global rank of this process, in ``[0, world_size)``.
        world_size: total number of processes in the job.
        master_addr, master_port: rendezvous address. Used only when the
            ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables are not
            already set. Launchers such as ``torchrun`` export them, and
            overwriting them with localhost would break multi-node jobs.
    """
    os.environ.setdefault('MASTER_ADDR', master_addr)
    os.environ.setdefault('MASTER_PORT', str(master_port))

    backend = "nccl" if torch.cuda.is_available() else "gloo"
    if backend == "nccl":
        # Bind this process to its own GPU before NCCL init. Otherwise the
        # current CUDA device is cuda:0 in every process. LOCAL_RANK (set by
        # torchrun) is the per-node GPU index; fall back to rank for mp.spawn.
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', rank)))
    dist.init_process_group(backend, rank=rank, world_size=world_size)

def cleanup_ddp():
    """Destroy the default process group if one is currently initialized.

    Safe to call from a ``finally`` block even when setup_ddp failed or was
    never called; ``dist.destroy_process_group()`` alone raises in that case.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

# Note: DDP requires spawning multiple processes, so the worker below is
# displayed as source text (to be launched externally) rather than run here.
ddp_worker_source = """def train_ddp(rank, world_size):
    setup_ddp(rank, world_size)
    
    # Create model and move to device
    device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
    model = SimpleNet().to(device)
    
    # Wrap with DDP
    ddp_model = DDP(model, device_ids=[rank] if torch.cuda.is_available() else None)
    
    # Create dataset with DistributedSampler
    dataset = SyntheticDataset()
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)
    
    # Training loop...
    cleanup_ddp()
"""
ddp_launch_command = "torchrun --nproc_per_node=2 distributed_training.py --distributed"

print("DDP Training Function:")
print(ddp_worker_source)

print("\nTo launch DDP training:")
print(ddp_launch_command)

### DDP Best Practices

In [None]:
# DDP best practices, kept as (heading, detail lines) pairs and printed in order
ddp_tips = [
    ("1. Use DistributedSampler:", [
        "   - Ensures each process gets different data",
        "   - Call sampler.set_epoch(epoch) for proper shuffling",
    ]),
    ("2. Synchronize when needed:", [
        "   - Use dist.barrier() for synchronization",
        "   - Use dist.all_reduce() for metric aggregation",
    ]),
    ("3. Save checkpoints from rank 0:", [
        "   if rank == 0:\n       torch.save(model.state_dict(), 'checkpoint.pth')",
    ]),
    ("4. Handle random seeds:", [
        "   torch.manual_seed(42 + rank)  # Different seed per process",
    ]),
]

print("DDP Best Practices:")
for heading, tip_lines in ddp_tips:
    print(f"\n{heading}")
    for tip_line in tip_lines:
        print(tip_line)

## 5. Model Parallel

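Model parallelism splits the model's layers across multiple devices. In this naive form the devices run one after another, so only one device is busy at any time. Pipeline parallelism (next section) addresses this.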

In [None]:
class ModelParallelNet(nn.Module):
    """A SimpleNet-style MLP split into two stages that may live on different devices.

    - Stage 1 (``fc1`` + ReLU) is placed on ``device1``.
    - Stage 2 (``fc2`` + ReLU + ``fc3``) is placed on ``device2``.

    By default, ``device1`` is cuda:0 (or cpu without CUDA). ``device2`` is
    cuda:1 when at least two GPUs are visible; otherwise it shares
    ``device1``. A single-GPU machine therefore keeps everything on cuda:0
    instead of moving activations to the CPU for stage 2.

    Args:
        input_dim, hidden_dim, num_classes: layer sizes.
        device1, device2: optional explicit placements (anything accepted by
            ``torch.device``) that override the defaults.
    """
    def __init__(self, input_dim=784, hidden_dim=256, num_classes=10, device1=None, device2=None):
        super().__init__()

        # Determine devices
        n_gpus = torch.cuda.device_count()
        if device1 is None:
            device1 = 'cuda:0' if n_gpus > 0 else 'cpu'
        if device2 is None:
            device2 = 'cuda:1' if n_gpus > 1 else device1
        self.device1 = torch.device(device1)
        self.device2 = torch.device(device2)

        # Split model across devices (ReLU has no parameters, so it needs no placement)
        self.fc1 = nn.Linear(input_dim, hidden_dim).to(self.device1)
        self.relu1 = nn.ReLU()

        self.fc2 = nn.Linear(hidden_dim, hidden_dim).to(self.device2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim, num_classes).to(self.device2)

    def forward(self, x):
        # Stage 1 on device1
        x = x.to(self.device1)
        x = self.relu1(self.fc1(x))

        # Hand the activations to stage 2 on device2
        x = x.to(self.device2)
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)

        return x

# Create and test model parallel network
model_mp = ModelParallelNet()
if torch.cuda.device_count() >= 2:
    print("Creating Model Parallel network across 2 GPUs")
else:
    print("Model Parallel requires at least 2 GPUs")
    # Report the real placement: with a single GPU this is not a CPU-only run.
    print(f"Demonstrating concept on {model_mp.device1} -> {model_mp.device2}...")

# Test forward pass (the model moves the CPU input to device1 itself)
data = torch.randn(4, 784)
output = model_mp(data)
print(f"Output shape: {output.shape}")
print(f"Output device: {output.device}")

## 6. Pipeline Parallel

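Pipeline parallelism splits each batch into micro-batches and streams them through the model stages, so different devices can work on different micro-batches at the same time.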

In [None]:
# Pipeline Parallel visualization
def gpipe_schedule(num_stages, num_micro_batches):
    """Return one text row per stage showing a GPipe-style pipeline schedule.

    Each time slot is 4 characters wide:

    - ``[Fm]`` / ``[Bm]`` is the forward / backward pass of micro-batch ``m`` (1-based).
    - Four spaces is an idle slot.

    Stage ``s`` (0-based) runs forward ``m`` at slot ``s + m``. Backward passes
    start at the last stage once every forward has finished. They then flow
    back toward the first stage, with micro-batches in reverse order.
    Trailing idle slots are trimmed from each row.
    """
    fill_end = num_stages + num_micro_batches - 1  # first slot of the backward phase
    total_slots = 2 * fill_end
    rows = []
    for s in range(num_stages):
        slots = ['    '] * total_slots
        for m in range(num_micro_batches):
            slots[s + m] = f'[F{m + 1}]'
            slots[fill_end + (num_stages - 1 - s) + (num_micro_batches - 1 - m)] = f'[B{m + 1}]'
        rows.append(''.join(slots).rstrip())
    return rows


print("Pipeline Parallel Concept:")
print("\nModel is split into stages, each on a different device.")
print("Micro-batches are processed in a pipeline fashion.\n")

stages = ['Stage 1', 'Stage 2', 'Stage 3', 'Stage 4']
micro_batches = 4

print("Pipeline Schedule (F=Forward, B=Backward, one 4-char column per time step):")
print("Time →")
for stage, row in zip(stages, gpipe_schedule(len(stages), micro_batches)):
    print(f"{stage}: {row}")

# Each stage is busy in 2*M of the 2*(S+M-1) slots; the rest is the pipeline "bubble".
bubble = (len(stages) - 1) / (len(stages) + micro_batches - 1)
print(f"\nIdle (bubble) fraction per stage: {bubble:.0%} - shrinks as micro-batches increase")

print("\nBenefits:")
print("- Reduces GPU idle time compared with naive model parallelism")
print("- Enables training of very deep models")
print("- Can be combined with data parallelism")

## 7. Fully Sharded Data Parallel (FSDP)

FSDP enables training of extremely large models by sharding parameters, gradients, and optimizer states across the data-parallel workers.

In [None]:
# FSDP concept demonstration
print("Fully Sharded Data Parallel (FSDP)\n")

# Back-of-the-envelope memory estimate for full-precision Adam training
model_params = 7_000_000_000  # 7B parameters
bytes_per_param = 4  # FP32
num_gpus = 8
GIB = 1024**3  # sizes below are binary gigabytes (GiB)

# Calculate memory usage
model_size_gb = (model_params * bytes_per_param) / GIB
optimizer_size_gb = model_size_gb * 2  # Adam has 2 states per parameter
gradient_size_gb = model_size_gb
total_gb = model_size_gb + optimizer_size_gb + gradient_size_gb
fsdp_per_gpu_gb = total_gb / num_gpus  # FULL_SHARD splits all three evenly

print(f"Model: {model_params/1e9:.0f}B parameters")
print("\nMemory Requirements (FP32):")
print(f"- Model parameters: {model_size_gb:.1f} GiB")
print(f"- Optimizer states: {optimizer_size_gb:.1f} GiB")
print(f"- Gradients: {gradient_size_gb:.1f} GiB")
print(f"- Total per GPU (DDP): {total_gb:.1f} GiB")

print(f"\nWith FSDP ({num_gpus} GPUs):")
print(f"- Per GPU: {fsdp_per_gpu_gb:.1f} GiB")
print(f"- Memory reduction: {(1 - 1/num_gpus) * 100:.0f}%")
print("- (Estimates exclude activations and the temporarily all-gathered parameters of the active FSDP unit)")

print("\nFSDP Features:")
print("- Shards model parameters across GPUs")
print("- Shards optimizer states")
print("- Shards gradients")
print("- Optional CPU offloading")
print("- Mixed precision support")

### FSDP Configuration Example

In [None]:
# FSDP configuration example (printed as text; FSDP needs an initialized process group)
print("FSDP Configuration Example:\n")

fsdp_config = """
from functools import partial

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap every submodule with >= 1M parameters as its own FSDP unit. Without an
# auto-wrap policy the whole model is a single unit, so all of its parameters
# are all-gathered at once and peak memory returns to the unsharded size.
auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=1_000_000)

# Configure mixed precision. bfloat16 (Ampere or newer GPUs) needs no loss
# scaling; float16 would additionally require ShardedGradScaler.
mixed_precision = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Configure CPU offloading (saves GPU memory at the cost of speed)
cpu_offload = CPUOffload(offload_params=True)

# Wrap model with FSDP
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mixed_precision,
    cpu_offload=cpu_offload,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
)
"""

print(fsdp_config)

## 8. Performance Comparison

In [None]:
# Performance comparison visualization.
# NOTE: these are illustrative, hard-coded numbers - not measurements from this machine.
methods = ['Single GPU', 'DP (4 GPUs)', 'DDP (4 GPUs)', 'FSDP (4 GPUs)']
num_devices = [1, 4, 4, 4]  # GPUs used by each method
throughput = [100, 320, 380, 350]  # Samples/second
memory_usage = [16, 64, 64, 20]  # GB, summed over all GPUs in the run
# Scaling efficiency = achieved speedup / ideal linear speedup over the single-GPU run
scaling_efficiency = [
    100 * tp / (n * throughput[0]) for tp, n in zip(throughput, num_devices)
]
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']


def _annotated_bars(ax, values, ylabel, title, ymax, label_offset, fmt=str):
    """Draw one bar per method on `ax` and write each value just above its bar."""
    ax.bar(methods, values, color=colors)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_ylim(0, ymax)
    ax.tick_params(axis='x', labelrotation=20)  # make the four method labels less likely to overlap
    for i, v in enumerate(values):
        ax.text(i, v + label_offset, fmt(v), ha='center')


fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
_annotated_bars(ax1, throughput, 'Throughput (samples/sec)', 'Training Throughput', 400, 10)
_annotated_bars(ax2, memory_usage, 'Memory Usage (GB, all GPUs)', 'Total GPU Memory Usage', 70, 2)
_annotated_bars(ax3, scaling_efficiency, 'Scaling Efficiency (%)', 'Multi-GPU Scaling Efficiency',
                110, 2, fmt=lambda v: f'{v:g}%')
fig.suptitle('Illustrative numbers (not measured)')
plt.tight_layout()
plt.show()

print("Key Observations:")
print("- DDP provides best scaling efficiency")
print("- FSDP reduces memory usage significantly")
print("- DP has lower efficiency: one process drives all GPUs (GIL contention) and it replicates the model and scatters/gathers data every step")

## 9. Choosing the Right Strategy

In [None]:
# Decision tree for distributed training, encoded as (question, if-yes, if-no) steps
decision_steps = [
    ("Model fits on single GPU?", "Use DDP for multi-GPU speedup", "Continue to #2"),
    ("Model fits with gradient checkpointing?", "Use DDP with gradient checkpointing", "Continue to #3"),
    ("Model has natural splitting points?", "Consider Pipeline Parallel", "Use FSDP"),
]
extra_considerations = [
    "Single machine? → DP is simplest (but DDP is better)",
    "Multiple machines? → Must use DDP or FSDP",
    "Very large model (>10B params)? → FSDP is likely necessary",
    "Low bandwidth between nodes? → Consider gradient compression",
]

print("Choosing the Right Distributed Training Strategy:\n")
for step_no, (question, if_yes, if_no) in enumerate(decision_steps, start=1):
    print(f"{step_no}. {question}")
    print(f"   YES → {if_yes}")
    print(f"   NO  → {if_no}\n")

print("Additional Considerations:")
for note in extra_considerations:
    print(f"- {note}")

## 10. Best Practices and Tips

In [None]:
# Best practices for distributed training, grouped by topic
print("Distributed Training Best Practices:\n")

best_practices = {
    "Data Loading": [
        "Use DistributedSampler for proper data distribution",
        "Set sampler.set_epoch(epoch) for different shuffling",
        "Pin memory for faster GPU transfer",
        "Use multiple workers for data loading"
    ],
    "Gradient Synchronization": [
        "Use SyncBatchNorm for batch normalization layers",
        "Consider gradient accumulation for large batches",
        "Use gradient clipping to prevent instabilities"
    ],
    "Checkpointing": [
        "Save checkpoints only from rank 0 process",
        # map_location is an argument of torch.load (not torch.save)
        "Pass map_location to torch.load when restoring checkpoints on another device",
        "Save optimizer state for resuming training"
    ],
    "Performance": [
        "Profile with torch.profiler to find bottlenecks",
        "Use mixed precision training (AMP) for speedup",
        "Overlap computation and communication",
        "Tune batch size for optimal GPU utilization"
    ],
    "Debugging": [
        "Set TORCH_DISTRIBUTED_DEBUG=DETAIL for debugging",
        "Use dist.barrier() for synchronization points",
        "Monitor GPU utilization and memory usage",
        "Start with small scale before scaling up"
    ]
}

for category, practices in best_practices.items():
    print(f"{category}:")
    for practice in practices:
        print(f"  • {practice}")
    print()

## Summary

In this tutorial, we covered:
1. **Data Parallel (DP)** - Simple multi-GPU on single machine
2. **Distributed Data Parallel (DDP)** - Efficient multi-GPU/multi-node
3. **Model Parallel** - For models too large for single GPU
4. **Pipeline Parallel** - Efficient training of deep models
5. **Fully Sharded Data Parallel (FSDP)** - For extremely large models

Key takeaways:
- Use DDP for most distributed training scenarios
- Consider FSDP for very large models
- Combine strategies for optimal performance
- Always profile and monitor your training