# Distributed PyTorch with nbdistributed

## Overview
This notebook uses the `nbdistributed` package for interactive distributed computing from Jupyter notebooks.

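**Note:** This notebook targets a machine with a CUDA GPU (Windows/Linux) and can coordinate with other GPU nodes. The example cells below use the `gloo` backend with tensors on the default device, and the setup cell falls back to CPU, so the examples also run on a machine without a GPU.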

## Requirements
- `nbdistributed` package
- PyTorch with CUDA support
- GPU available


In [None]:
# Install required packages
%pip install nbdistributed torch torchvision --index-url https://download.pytorch.org/whl/cu118


In [None]:
import torch
import torch.distributed as dist
from nbdistributed import Config, distributed

# Choose the compute device up front ("cuda" when a CUDA device is visible,
# otherwise "cpu"), then report what was found.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"CUDA available: {device == 'cuda'}")
if device == "cuda":
    # Only query the device name when CUDA is actually usable.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Running on CPU")


## Configure Distributed Setup

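On a single machine, nbdistributed runs locally. For multi-node runs (processes on several machines):
- Set the `MASTER_ADDR` and `MASTER_PORT` environment variables to the same values on every machine, so all processes rendezvous at one address. With PyTorch's default `env://` initialization, each process also needs its own `RANK` and the shared `WORLD_SIZE`.
- Run this notebook on each participating machine.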


In [None]:
@distributed(backend='gloo')
def test_all_gather():
    """Gather a small rank-tagged tensor from every process in the group.

    Each rank builds a float32 tensor ``[10 * rank, 10 * rank + 1]`` on the
    default device (CPU unless the default device was changed) and calls
    ``dist.all_gather``. Afterwards every rank holds one such tensor per rank,
    with slot ``i`` filled by rank ``i``. Communication uses the ``gloo``
    backend, so this runs with or without a GPU.

    Returns:
        List of ``world_size`` tensors gathered from all ranks.
    """
    my_rank = dist.get_rank()
    n_ranks = dist.get_world_size()

    base = my_rank * 10
    payload = torch.tensor([base, base + 1], dtype=torch.float32)

    # all_gather writes into caller-allocated buffers: one per rank, each
    # matching the shape/dtype of the local payload.
    buckets = [torch.zeros_like(payload) for _ in range(n_ranks)]
    dist.all_gather(buckets, payload)

    print(f"Rank {my_rank}: Gathered tensors = {buckets}")
    return buckets


# Launch the gather on all ranks. NOTE(review): what `result` holds (the
# per-process return value vs. a collection of all ranks' returns) depends on
# nbdistributed's `distributed` decorator, which is not visible here -- confirm.
result = test_all_gather()


In [None]:
@distributed(backend='gloo')
def test_all_reduce():
    """Sum every process's rank across the process group with ``all_reduce``.

    Each rank contributes a one-element float32 tensor holding its own rank
    number (created on the default device; communication uses ``gloo``).
    ``dist.all_reduce(..., op=SUM)`` overwrites that tensor in place with the
    sum over all ranks. For world size N, every rank therefore ends with
    ``0 + 1 + ... + (N - 1)``.

    Returns:
        The reduced sum as a Python float (identical on every rank).
    """
    rank = dist.get_rank()

    # Each rank contributes its rank number
    local_value = torch.tensor([float(rank)], dtype=torch.float32)

    print(f"Rank {rank}: Local value before reduce = {local_value.item()}")

    # In-place sum across all ranks: after this call local_value holds the total.
    dist.all_reduce(local_value, op=dist.ReduceOp.SUM)

    print(f"Rank {rank}: Sum across all ranks = {local_value.item()}")
    return local_value.item()


# Run the distributed function.
# NOTE(review): the shape of `sum_result` (a single float vs. one value per rank)
# depends on nbdistributed's `distributed` decorator, which is not visible here.
sum_result = test_all_reduce()
print(f"\nFinal result: {sum_result}")
