In [1]:
"""
Minimal PyTorch examples for different parallelism strategies:
1. DDP (Distributed Data Parallel)
2. FSDP 2 (Fully Sharded Data Parallel v2)
3. Tensor Parallelism
4. Sequence/Context Parallelism

Run with: torchrun --nproc_per_node=4 parallelism_examples.py --strategy ddp
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    ColwiseParallel, RowwiseParallel, parallelize_module
)
from torch.distributed.tensor import DeviceMesh, distribute_tensor, Replicate, Shard
import argparse
import os


In [5]:
class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size=1000, d_model=512, nhead=8, num_layers=6, seq_len=128):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.randn(seq_len, d_model))
        
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, 
            nhead=nhead, 
            dim_feedforward=d_model * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        # x shape: (batch, seq_len)
        seq_len = x.size(1)
        
        x = self.embedding(x) + self.pos_encoding[:seq_len].unsqueeze(0)
        
        x = self.transformer(x)
        
        x = self.classifier(x)
        return x

In [39]:
def setup_distributed():
    """Initialize distributed training environment"""
    if dist.is_initialized():
        dist.destroy_process_group()
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    device_str = "cpu"
    count = torch.cpu.device_count()
    if torch.cuda.is_available():
        device_str = "cuda"
        count = torch.cuda.device_count()
    # elif torch.backends.mps.is_available():
    #     device_str = "mps"
    #     count = torch.mps.device_count()
    device = torch.device(f"{device_str}:{rank % count}")
    torch.set_default_device(device)
    return rank, world_size, device


def create_dummy_data(batch_size, seq_len, vocab_size, device):
    """Create dummy training data"""
    inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    targets = inputs[:, 1:].contiguous()
    inputs = inputs[:, :-1].contiguous()
    return inputs, targets

In [36]:
import os

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'

In [40]:
# =============================================================================
# 1. DDP (Distributed Data Parallel)
# =============================================================================
def ddp_example():
    """DDP replicates model on each GPU, synchronizes gradients"""
    rank, world_size, device = setup_distributed()
    
    model = SimpleTransformer().to(device)
    model = DDP(model, device_ids=[device])

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    for step in range(5):
        batch_size = 4
        inputs, targets = create_dummy_data(batch_size, 128, 1000, device)
        
        outputs = model(inputs)  # Shape: (batch, seq_len, vocab_size)
        
        loss = F.cross_entropy(outputs.view(-1, outputs.size(-1)), targets.view(-1))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if rank == 0:
            print(f"DDP Step {step}: Loss = {loss.item():.4f}")
    
    dist.destroy_process_group()

In [41]:
ddp_example()

ValueError: DistributedDataParallel device_ids and output_device arguments only work with single-device/multiple-device GPU modules or CPU modules, but got device_ids [device(type='cpu', index=0)], output_device None, and module parameters {device(type='cpu')}.

In [None]:
# =============================================================================
# 2. FSDP 2 (Fully Sharded Data Parallel)
# =============================================================================
def fsdp_example():
    """FSDP 2 shards model parameters across GPUs using manual wrapping"""
    rank, world_size, device = setup_distributed()
    
    # Create model with manual FSDP wrapping
    model = SimpleTransformer().to(device)
    
    # FSDP 2 uses manual wrapping - wrap each transformer layer individually
    # This gives more control over sharding strategy
    for i, layer in enumerate(model.transformer.layers):
        model.transformer.layers[i] = FSDP(
            layer,
            device_id=device,
        )
    
    # Wrap the entire model as well
    model = FSDP(
        model,
        device_id=device,
    )
    
    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    # Training loop
    for step in range(5):
        batch_size = 4
        inputs, targets = create_dummy_data(batch_size, 128, 1000, device)
        
        # Forward pass - FSDP handles all-gather internally
        outputs = model(inputs)
        
        # Loss
        loss = F.cross_entropy(outputs.view(-1, outputs.size(-1)), targets.view(-1))
        
        # Backward pass - FSDP handles reduce-scatter internally  
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if rank == 0:
            print(f"FSDP Step {step}: Loss = {loss.item():.4f}")
    
    dist.destroy_process_group()
