# Efficient Multi-GPU Training with PyTorch FSDP

## Introduction

This example shows how to train a model with PyTorch on multiple GPUs, and on a single GPU when multiple GPUs are not available. The code is organized into blocks that import the required libraries, define a simple convolutional neural network (CNN) for MNIST digit classification, set up distributed training with Fully Sharded Data Parallel (FSDP), and fall back to single-GPU training when fewer than two GPUs are present or the distributed run fails. Each block is commented to explain its purpose, which makes the example a practical reference for distributed training strategies in PyTorch.

## 1. Import Statements

This block imports the libraries used throughout the example: PyTorch core modules, `torch.distributed` for multi-process training, and `torchvision` for the MNIST dataset and its transformations. The distributed environment itself is configured later, in the `setup` function.

```python
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
import torch.distributed as dist
```

## 2. GPU Availability Check  

This function returns the number of CUDA GPUs visible to PyTorch (via `torch.cuda.device_count()`) and prints it. `run_training` later uses this count to decide whether distributed training is possible.

```python
def check_gpu_availability():
    # Get the number of available GPUs
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available: {num_gpus}")
    return num_gpus
```
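Because the branching in `run_training` depends only on the GPU count, the decision can be isolated in a pure function and checked without any GPUs attached. This is a minimal sketch; `choose_training_mode` is a hypothetical helper that is not part of the original code, and it mirrors the `num_gpus >= 2` threshold used by the launcher.

```python
def choose_training_mode(num_gpus: int) -> str:
    """Return "distributed" when at least two GPUs are visible, else "single".

    Hypothetical helper mirroring the `num_gpus >= 2` check in run_training().
    """
    if num_gpus < 0:
        raise ValueError("num_gpus must be non-negative")
    return "distributed" if num_gpus >= 2 else "single"


if __name__ == "__main__":
    for n in (0, 1, 2, 8):
        print(n, choose_training_mode(n))
```

In the notebook, the argument would come from `check_gpu_availability()`.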

## 3. SimpleCNN Model Definition  

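This block defines a small convolutional neural network (CNN) as an `nn.Module` subclass. It has two convolutional blocks, each a 3×3 convolution followed by a ReLU activation and 2×2 max pooling, and a fully connected output layer that produces one score (logit) per digit class. Each pooling layer halves the spatial size of the 28×28 MNIST input (28 → 14 → 7), so the final feature map holds 64 × 7 × 7 values.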

```python
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # First convolutional layer: input channels=1, output channels=32
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()  # Activation function after conv1
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 28x28 -> 14x14

        # Second convolutional layer: input channels=32, output channels=64
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()  # Activation function after conv2
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 14x14 -> 7x7

        # Fully connected layer that maps the flattened features to 10 classes (MNIST digits)
        self.fc = nn.Linear(64 * 7 * 7, 10)  # Assumes 28x28 input images

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))  # conv -> ReLU -> pool
        x = self.pool2(self.relu2(self.conv2(x)))  # conv -> ReLU -> pool
        x = torch.flatten(x, 1)                    # (N, 64, 7, 7) -> (N, 3136), keeps the batch dim
        return self.fc(x)                          # Class logits
```
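The `64 * 7 * 7` input size of `self.fc` follows from the standard output-size formula for convolution and pooling layers, out = floor((n + 2 * padding - kernel) / stride) + 1. The sketch below traces a 28×28 MNIST image through the layer settings used in `SimpleCNN`; `conv_out` and `trace_simple_cnn` are illustrative helpers introduced here, and they assume dilation 1 and the default floor rounding.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d/MaxPool2d layer (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_simple_cnn(size: int = 28) -> int:
    # conv1: kernel 3, stride 1, padding 1 keeps the size unchanged
    size = conv_out(size, kernel=3, stride=1, padding=1)
    # pool1: 2x2 window with stride 2 halves the size
    size = conv_out(size, kernel=2, stride=2)
    # conv2 and pool2 repeat the same pattern
    size = conv_out(size, kernel=3, stride=1, padding=1)
    size = conv_out(size, kernel=2, stride=2)
    return size


side = trace_simple_cnn(28)
print(side, 64 * side * side)  # 7 3136
```

A different input resolution would change `side`, which is why the hard-coded `64 * 7 * 7` ties this model to 28×28 inputs.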

## 4. Distributed Training Setup and Cleanup  

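This block defines helper functions for distributed training. The `setup` function points all processes at a rendezvous address on the local machine (`localhost`, port `12355`), so as written it supports single-node training only; it then initializes the process group with the NCCL backend and binds each process to the GPU matching its rank. `cleanup` destroys the process group once training is complete.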

```python
def setup(rank, world_size):
    """Sets up the process group for distributed training."""
    # Rendezvous address shared by all processes (single machine);
    # values that are already set in the environment are kept
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # Pin this process to its GPU before the NCCL communicator is created
    torch.cuda.set_device(rank)
    # Initialize the distributed process group using the NCCL backend (optimized for GPUs)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """Cleans up the distributed process group."""
    if dist.is_initialized():
        dist.destroy_process_group()
```
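A fixed port such as `12355` fails if another job on the same machine already holds it. One common workaround, sketched here as an assumption rather than as part of the original code, is to ask the operating system for a free port in the parent process before spawning workers and export it, so that the child processes inherit it through `os.environ`. `find_free_port` is a hypothetical helper.

```python
import os
import socket


def find_free_port() -> int:
    """Ask the OS for an unused TCP port on localhost (hypothetical helper)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))  # port 0 lets the OS pick any free port
        return s.getsockname()[1]


# Set once in the parent, before spawning; child processes inherit os.environ.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", str(find_free_port()))
print(os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"])
```

This only takes effect if `setup` does not overwrite `MASTER_PORT` (for example, if it uses `os.environ.setdefault` as well). The port is released when the socket closes, so another program could still grab it before rank 0 binds it; the approach narrows the window for collisions rather than eliminating it.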

## 5. Distributed Training Function Using FSDP  

This function runs in each spawned process and trains with Fully Sharded Data Parallel (FSDP). It initializes the process group, loads MNIST with a `DistributedSampler` so each rank trains on a different shard of the data, wraps the model in FSDP, and runs the training loop, logging the loss every 100 batches from rank 0 only.

```python
def train_distributed(rank, world_size, epochs=2):
    """Trains the model using FSDP with multiple GPUs (one process per GPU)."""
    try:
        # Set up the distributed environment for the current rank
        setup(rank, world_size)
        device = torch.device(f"cuda:{rank}")

        # Import FSDP here so single-GPU runs never need it
        import functools
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

        # Define data transformations for the MNIST dataset
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # Download MNIST on rank 0 only; the other ranks wait, then read the local copy
        if rank == 0:
            datasets.MNIST('./data', train=True, download=True)
        dist.barrier()
        dataset = datasets.MNIST('./data', train=True, download=False, transform=transform)

        # Each rank iterates over its own shard of the dataset
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
        dataloader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=2)

        # Move the model to this rank's GPU and wrap it with FSDP.
        # A submodule gets its own FSDP unit only if it has >= min_num_params parameters.
        model = SimpleCNN().to(device)
        wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e8))
        fsdp_model = FSDP(model, auto_wrap_policy=wrap_policy)

        # Build the optimizer AFTER wrapping: FSDP replaces the original parameters
        # with sharded flat parameters, so the optimizer must see the wrapped ones
        optimizer = optim.Adam(fsdp_model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        # Training loop
        for epoch in range(epochs):
            fsdp_model.train()
            # Reshuffle the shards differently every epoch
            sampler.set_epoch(epoch)
            for batch_idx, (data, target) in enumerate(dataloader):
                # Move data and labels to this rank's GPU
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()             # Reset gradients
                output = fsdp_model(data)         # Forward pass
                loss = criterion(output, target)  # Compute loss
                loss.backward()                   # Backward pass; gradients are reduce-scattered across ranks
                optimizer.step()                  # Update this rank's parameter shard

                # Print training status from rank 0 only
                if batch_idx % 100 == 0 and rank == 0:
                    print(f"Rank {rank}, Epoch: {epoch}, Batch: {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")

        if rank == 0:
            print("Distributed training finished!")
    except Exception as e:
        print(f"Error in rank {rank}: {e}")
        # Re-raise so torch.multiprocessing.spawn reports the failure and run_training can fall back
        raise
    finally:
        # Clean up the distributed process group
        cleanup()
```
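`size_based_auto_wrap_policy` gives a submodule its own FSDP unit only when its not-yet-wrapped parameter count reaches `min_num_params`, which defaults to 100 million (`int(1e8)`). Counting `SimpleCNN`'s parameters by hand shows that no layer comes close, so in this example FSDP shards the whole model as a single unit; the policy only starts to matter for much larger models. The arithmetic as a plain-Python sketch (`conv2d_params` and `linear_params` are illustrative helpers, assuming every layer has a bias, as `nn.Conv2d` and `nn.Linear` do by default):

```python
def conv2d_params(in_ch: int, out_ch: int, kernel: int) -> int:
    # weight: out_ch x in_ch x kernel x kernel, plus one bias per output channel
    return out_ch * in_ch * kernel * kernel + out_ch


def linear_params(in_features: int, out_features: int) -> int:
    # weight: out_features x in_features, plus one bias per output feature
    return out_features * in_features + out_features


layers = {
    "conv1": conv2d_params(1, 32, 3),
    "conv2": conv2d_params(32, 64, 3),
    "fc": linear_params(64 * 7 * 7, 10),
}
total = sum(layers.values())
MIN_NUM_PARAMS = int(1e8)  # default threshold of size_based_auto_wrap_policy

print(layers)                   # {'conv1': 320, 'conv2': 18496, 'fc': 31370}
print(total)                    # 50186
print(total >= MIN_NUM_PARAMS)  # False
```

With PyTorch installed, `sum(p.numel() for p in SimpleCNN().parameters())` gives the same total.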

## 6. Single GPU Training Function  

This function trains the same model without FSDP on a single device: the first GPU if one is available, otherwise the CPU. It loads MNIST with a shuffling `DataLoader`, defines the model, loss function, and optimizer, and logs the loss every 100 batches.

```python
def train_single_gpu(epochs=2):
    """Trains the model on a single GPU without FSDP."""
    # Select the appropriate device (GPU if available, otherwise CPU)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Define data transformations for MNIST
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load MNIST dataset
    dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    # DataLoader for batching and shuffling
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)

    # Instantiate the SimpleCNN model and move it to the device
    model = SimpleCNN().to(device)
    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    for epoch in range(epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(dataloader):
            # Move data and target to the device
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()             # Reset gradients
            output = model(data)              # Forward pass
            loss = criterion(output, target)  # Compute loss
            loss.backward()                   # Backward pass
            optimizer.step()                  # Update model parameters

            # Print training progress every 100 batches
            if batch_idx % 100 == 0:
                print(f"Epoch: {epoch}, Batch: {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")

    print("Single GPU training finished!")
    return model
```
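The single-GPU log reports 938 batches per epoch: the MNIST training split has 60,000 images, and ceil(60000 / 64) = 938 because `DataLoader` keeps the final partial batch by default. Under `DistributedSampler`, each rank instead iterates over roughly 1/world_size of the data. The sketch below re-creates the sampler's non-shuffled index assignment in plain Python to make the split concrete; it assumes the default `drop_last=False`, where the index list is padded with repeated samples so it divides evenly across ranks, and `shard_indices` and `batches_per_epoch` are hypothetical names.

```python
import math
from typing import List


def shard_indices(dataset_len: int, world_size: int, rank: int) -> List[int]:
    """Indices one rank would see, mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    indices = list(range(dataset_len))
    num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size
    # Pad by repeating indices from the start so every rank gets the same count
    padding = total_size - dataset_len
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    # Rank r takes every world_size-th index, starting at position r
    return indices[rank:total_size:world_size]


def batches_per_epoch(num_samples: int, batch_size: int) -> int:
    # drop_last=False keeps the final partial batch
    return math.ceil(num_samples / batch_size)


print(batches_per_epoch(60000, 64))                            # 938
print(len(shard_indices(60000, 2, 0)))                         # 30000
print(batches_per_epoch(len(shard_indices(60000, 2, 0)), 64))  # 469
print(shard_indices(10, 3, 0), shard_indices(10, 3, 2))        # [0, 3, 6, 9] [2, 5, 8, 1]
```

So with two GPUs, each rank would run 469 optimizer steps per epoch, each on a global batch of 2 × 64 = 128 images. With `shuffle=True`, the sampler first permutes the indices using a generator seeded with `seed + epoch`, which is why `sampler.set_epoch(epoch)` is called at the start of every epoch.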

## 7. Main Training Launcher  

This section chooses between distributed and single-GPU training based on the number of available GPUs. With two or more GPUs, it spawns one process per GPU using `torch.multiprocessing.spawn`; otherwise, or if the distributed run raises an error, it falls back to single-GPU training.

```python
def run_training():
    """Spawns processes for multi-GPU training or runs single GPU training based on availability."""
    num_gpus = check_gpu_availability()

    if num_gpus >= 2:
        print(f"Running with distributed training on {num_gpus} GPUs")
        world_size = num_gpus
        try:
            # Spawn one process per GPU; each calls train_distributed(rank, world_size)
            torch.multiprocessing.spawn(train_distributed,
                                        args=(world_size,),
                                        nprocs=world_size,
                                        join=True)
        except Exception as e:
            print(f"Distributed training failed with error: {str(e)}")
            print("Falling back to single GPU training")
            train_single_gpu()
    else:
        print("Not enough GPUs for distributed training, using single GPU mode")
        train_single_gpu()

# Execute training when the script is run directly
if __name__ == "__main__":
    run_training()
```

```text
Number of GPUs available: 2
Running with distributed training on 2 GPUs
W0414 07:18:36.912000 31 torch/multiprocessing/spawn.py:160] Terminating process 65 via signal SIGTERM
Distributed training failed with error: process 1 terminated with exit code 1
Falling back to single GPU training
Using device: cuda:0

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9.91M/9.91M [00:00<00:00, 60.2MB/s]
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.65MB/s]
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.2MB/s]
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.04MB/s]
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch: 0, Batch: 0/938, Loss: 2.3069
Epoch: 0, Batch: 100/938, Loss: 0.1601
Epoch: 0, Batch: 200/938, Loss: 0.1409
Epoch: 0, Batch: 300/938, Loss: 0.0263
Epoch: 0, Batch: 400/938, Loss: 0.0408
Epoch: 0, Batch: 500/938, Loss: 0.0309
Epoch: 0, Batch: 600/938, Loss: 0.0319
Epoch: 0, Batch: 700/938, Loss: 0.0488
Epoch: 0, Batch: 800/938, Loss: 0.1460
Epoch: 0, Batch: 900/938, Loss: 0.0267
Epoch: 1, Batch: 0/938, Loss: 0.0619
Epoch: 1, Batch: 100/938, Loss: 0.0113
Epoch: 1, Batch: 200/938, Loss: 0.0156
Epoch: 1, Batch: 300/938, Loss: 0.0272
Epoch: 1, Batch: 400/938, Loss: 0.1430
Epoch: 1, Batch: 500/938, Loss: 0.0045
Epoch: 1, Batch: 600/938, Loss: 0.1344
Epoch: 1, Batch: 700/938, Loss: 0.0099
Epoch: 1, Batch: 800/938, Loss: 0.0446
Epoch: 1, Batch: 900/938, Loss: 0.0045
Single GPU training finished!
```
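In the output above, the distributed attempt ends with `process 1 terminated with exit code 1`, and the worker's traceback is not shown. One frequent cause in Jupyter or Kaggle notebooks (an assumption here, not confirmed by the log) is that `torch.multiprocessing.spawn` starts fresh interpreter processes that must import `train_distributed` from `__main__`, which is not possible for functions defined in notebook cells. A common alternative is to save the code as a script and launch it with `torchrun --nproc_per_node=2 train.py` (`train.py` is a hypothetical file name). `torchrun` starts one process per GPU and passes rank information through environment variables such as `RANK`, `LOCAL_RANK`, and `WORLD_SIZE`. A minimal sketch of reading them, falling back to single-process values when they are absent (`read_dist_env` and `DistEnv` are hypothetical names):

```python
import os
from typing import NamedTuple


class DistEnv(NamedTuple):
    rank: int
    local_rank: int
    world_size: int


def read_dist_env(environ=None) -> DistEnv:
    """Read the rank variables torchrun exports; default to a single process."""
    env = os.environ if environ is None else environ
    return DistEnv(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


if __name__ == "__main__":
    info = read_dist_env()
    if info.world_size > 1:
        # Under torchrun, MASTER_ADDR/MASTER_PORT are already set, so the
        # process group can be initialized from the environment ("env://")
        print(f"distributed worker {info.rank}/{info.world_size} on GPU {info.local_rank}")
    else:
        print("single-process run")
```

On a single machine, each worker could then pass `info.local_rank` as its device index and `info.world_size` as the world size, instead of receiving them from `mp.spawn`.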


## Conclusion

The example illustrates the practical pieces of both distributed and single-GPU training in PyTorch. Because the code is split into clear sections, it shows how to:
- Check GPU availability and choose the training mode,
- Set up a distributed environment with PyTorch's NCCL backend,
- Define a neural network and wrap it with FSDP for sharded training,
- Load the MNIST dataset with a distributed sampler so each rank trains on its own shard,
- Fall back to single-GPU training when fewer than two GPUs are available or the distributed run fails.

In the run captured above, the distributed attempt exited with an error and training completed on a single GPU instead, so the output demonstrates the fallback path rather than FSDP itself. The code is a starting point for developers who want to apply distributed training to deep learning workloads while remaining usable on machines with only one GPU.