![banner](https://raw.githubusercontent.com/priyammaz/HAL-DL-From-Scratch/main/src/visuals/banner.png)

## Distributed Training

A major problem you may quickly face is the size of your models are getting too large to fit on your GPU! This makes sense as todays state-of-the-art models are tens of billions of parameters, so what do we then do? There is typically two approaches we can take depending on the issue.

#### **Gradient Accumulation** 
If your model completely fills up your GPU memory where there is no space to pass in a reasonable batch size, we can use **gradient accumulation**. 
Gradient accumulation is instead of passing in a batch size of 64 all at once, pass in smaller minibatches of size 8, 8 different times. Then store all those losses and accumulate them at the end before the backward pass. 
    
#### Data Parallelism

![dataparallel](https://naga-karthik.github.io/media/ddp-figures/bothPasses.png)

[credit](https://naga-karthik.github.io/post/pytorch-ddp/)

If you have more than 1 GPU but the model can fit on any one of them, we can do **Data Parallelism**. What this does is make a copy of the full model to each GPU and then passes smaller batchsizes of data to push through all the models. We then accumulate this information at the end and perform backpropagation. An important aspect of this is all the weights are Synced across the GPUs so they are identical models!! Typically GPU 0 is responsible for managing the syncing between the GPUs. This will require us to spawn a bunch of parallel processes and a bunch of steps to setup.


#### Model Parallelism

![modelparallel](https://fairscale.readthedocs.io/en/latest/_images/pipe.png)

[credit](https://fairscale.readthedocs.io/en/latest/deep_dive/pipeline_parallelism.html)


If you have multiple GPUs, but your model is so large that it cannot fit on any of them, then we have to split the model up between GPUS. This means we will have the weights of the model split between the different GPU blocks. This is relatively easy to implement though with some code changes!


### AlexNet
AlexNet is not a very large model by any stretch of the imagination, but the ideas we will explore with this relatively simple implementation will hold true regardless of the architecture! We will be performing the things mentioned above by updating the training and model code from our [Intro to Vision Tutorial](https://github.com/priyammaz/HAL-DL-From-Scratch/tree/main/PyTorch%20for%20Computer%20Vision/Intro%20to%20Vision). Lets take a look at our starting point first copied from there and I will indicate changes for every version we explore. 

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

#### Define Vanilla AlexNet Architecture

In [2]:
class VanillaAlexNet(nn.Module):
    def __init__(self, classes=2, dropout_p=0.5):
        super().__init__()
        self.classes = classes
        
        self.feature_extractor = nn.Sequential(
                nn.Conv2d(in_channels=3, out_channels=64, kernel_size=11, stride=4, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2),  
                nn.BatchNorm2d(num_features=64), # ADDED IN BATCHNORM
                
                nn.Conv2d(in_channels=64, out_channels=192, kernel_size=5, stride=1, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2), 
                nn.BatchNorm2d(num_features=192), # ADDED IN BATCHNORM 
                
                nn.Conv2d(in_channels=192, out_channels=384, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2),
                nn.BatchNorm2d(num_features=384), # ADDED IN BATCHNORM 
                
                
                nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(num_features=256), # ADDED IN BATCHNORM 
                
                nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2),
                nn.BatchNorm2d(num_features=256), # ADDED IN BATCHNORM 
        )
        
        self.avgpool = nn.AdaptiveAvgPool2d((6,6))
        
        self.head = nn.Sequential(
                nn.Dropout(dropout_p),
                nn.Linear(256*6*6, 4096),
                nn.ReLU(),
                nn.Dropout(dropout_p),
                nn.Linear(4096, 4096),
                nn.ReLU(),
                nn.Linear(4096, classes)
        )
        
    def forward(self, x):
        batch_size = x.shape[0]
        
        x = self.feature_extractor(x)
        x = self.avgpool(x)
        x = x.reshape(batch_size, -1)
        x = self.head(x)
        return x

#### Prep DataLoaders for Training

In [3]:
### Build Cats vs Dogs Dataset ###
PATH_TO_DATA = "../data/PetImages/"

### DEFINE TRANSFORMATIONS ###
normalizer = transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]) ### IMAGENET MEAN/STD ###
train_transforms = transforms.Compose([
                                        transforms.Resize((224,224)),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.ToTensor(),
                                        normalizer
                                      ])


dataset = ImageFolder(PATH_TO_DATA, transform=train_transforms)

train_samples, test_samples = int(0.9 * len(dataset)), len(dataset) - int(0.9 * len(dataset))
train_dataset, val_dataset = torch.utils.data.random_split(dataset, lengths=[train_samples, test_samples])

#### Default Training Script

In [4]:
### SELECT DEVICE ###
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on Device {DEVICE}")

### LOAD IN AlexNet ###
model = VanillaAlexNet()
model = model.to(DEVICE)

### MODEL TRAINING INPUTS ###
epochs = 1
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)
loss_fn = nn.CrossEntropyLoss()
batch_size = 128

### BUILD DATALOADERS ###
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

def vanilla_train(model, device, epochs, optimizer, loss_fn, batch_size, trainloader, valloader):
    for epoch in range(1, epochs + 1):
        print(f"Starting Epoch {epoch}")
        training_losses = []
        validation_losses = []
        
        model.train() # Turn On BatchNorm and Dropout
        for image, label in tqdm(trainloader):
            image, label = image.to(DEVICE), label.to(DEVICE)
            optimizer.zero_grad()
            out = model.forward(image)
        
            ### CALCULATE LOSS ##
            loss = loss_fn(out, label)
            training_losses.append(loss.item())

            loss.backward()
            optimizer.step()

        model.eval() # Turn Off Batchnorm 
        for image, label in tqdm(valloader):
            image, label = image.to(DEVICE), label.to(DEVICE)
            with torch.no_grad():
                out = model.forward(image)

                ### CALCULATE LOSS ##
                loss = loss_fn(out, label)
                validation_losses.append(loss.item())


        training_loss_mean = np.mean(training_losses)
        valid_loss_mean = np.mean(validation_losses)

        print("Training Loss:", training_loss_mean) 
        print("Validation Loss:", valid_loss_mean)
        
    return model

model = vanilla_train(model=model,
                      device=DEVICE,
                      epochs=epochs,
                      optimizer=optimizer,
                      loss_fn=loss_fn,
                      batch_size=batch_size,
                      trainloader=trainloader,
                      valloader=valloader)

Training on Device cuda
Starting Epoch 1


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 176/176 [00:29<00:00,  6.00it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:03<00:00,  5.31it/s]

Training Loss: 0.5134251930496909
Validation Loss: 0.41623771488666533





## Gradient Accumulation

We will split our batch size of 128 up with a gradient accumulation steps and then aggregate that information at the end before backpropagation! There is one consideration though! If your batch size is too small, then the batchnorm layers wont have enough samples to work with to calcualte their mean and standard deviation. There are some workarounds for this (LayerNormalization could be one) but it is out of the scope in an introduction to these materials.


**Downsides of Accumulation:**

One thing you will see though is, beause we are splitting up our forward propagation into an iteration of steps rather in parallel, your time for training will definitely go up!

In [5]:
############### OLD CODE ##############
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on Device {DEVICE}")
model = VanillaAlexNet()
model = model.to(DEVICE)
epochs = 1
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)
loss_fn = nn.CrossEntropyLoss()
#######################################


############### NEW CODE ##############
batch_size = 128
gradient_accum_steps = 8 # Split up large batch size into 8 steps
sub_batch_size = 128 // 8 # Get number of samples for each sub-batch
print(f"Target Batch Size: {batch_size} Split into {gradient_accum_steps} Accumulation Steps")

### BUILD DATALOADERS ###
trainloader = DataLoader(train_dataset, batch_size=sub_batch_size, shuffle=True, num_workers=4) # use sub batch size
valloader = DataLoader(val_dataset, batch_size=sub_batch_size, shuffle=False, num_workers=4) # use sub batch size
#######################################

def grad_accum_train(model, device, epochs, optimizer, loss_fn, batch_size, trainloader, valloader, gradient_accum_steps): # Add gradient accum steps
    for epoch in range(1, epochs + 1):
        print(f"Starting Epoch {epoch}")
        training_losses = []
        validation_losses = []
        model.train()
        for image, label in tqdm(trainloader):
            image, label = image.to(DEVICE), label.to(DEVICE)
            optimizer.zero_grad()
            
            
            
            ############### NEW CODE ##############
            accumulated_losses = 0 # Create a list to store loss values
            accumulated_accuracies = []
            for _ in range(gradient_accum_steps):  # Iterate through the number of steps we want
                out = model.forward(image) # Pass through a sub batch of images
                subloss = loss_fn(out, label) # Calculate the loss for the sub batch of images
                accumulated_losses+=subloss # Add the loss
                
            loss = accumulated_losses / gradient_accum_steps # Calculate mean loss across the sub batches
            ########################################
            
            training_losses.append(loss.item())
            loss.backward()
            optimizer.step()

        model.eval() 
        for image, label in tqdm(valloader):
            image, label = image.to(DEVICE), label.to(DEVICE)
            with torch.no_grad():
                out = model.forward(image)
                loss = loss_fn(out, label)
                validation_losses.append(loss.item())

        training_loss_mean = np.mean(training_losses)
        valid_loss_mean = np.mean(validation_losses)

        print("Training Loss:", training_loss_mean) 
        print("Validation Loss:", valid_loss_mean)
        
    return model

model = grad_accum_train(model=model,
                         device=DEVICE,
                         epochs=epochs,
                         optimizer=optimizer,
                         loss_fn=loss_fn,
                         batch_size=batch_size,
                         trainloader=trainloader,
                         valloader=valloader,
                         gradient_accum_steps=gradient_accum_steps)

Training on Device cuda
Target Batch Size: 128 Split into 8 Accumulation Steps
Starting Epoch 1


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1403/1403 [01:33<00:00, 15.01it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 41.65it/s]

Training Loss: 0.490524502828218
Validation Loss: 0.3431617067887997





## Model Parallelism

In this next stage we will be splitting our model up between 2 GPUS! This is very simple and we just need to let the model know which layers belong where. Lets first take a look at the resources we have:

In [6]:
!nvidia-smi

Wed Apr 26 10:15:51 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA TITAN RTX    Off  | 00000000:01:00.0  On |                  N/A |
| 61%   75C    P2    90W / 280W |   4874MiB / 24212MiB |     18%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA TITAN RTX    Off  | 00000000:02:00.0 Off |                  N/A |
| 41%   34C    P8    29W / 280W |     13MiB / 24220MiB |      0%      Default |
|       

We can see that we have 2 GPUs, and they will each be labeled "cuda:0" and "cuda:1". So we can use those to pass our data through the model as we want. 

### Updated Model Code to Pipe GPU Placement

In [7]:
class ModelParallelAlexNet(nn.Module):
    def __init__(self, classes=2, dropout_p=0.5):
        super().__init__()
        self.classes = classes
        
        self.feature_extractor = nn.Sequential(
                nn.Conv2d(in_channels=3, out_channels=64, kernel_size=11, stride=4, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2),  
                nn.BatchNorm2d(num_features=64), # ADDED IN BATCHNORM
                
                nn.Conv2d(in_channels=64, out_channels=192, kernel_size=5, stride=1, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2), 
                nn.BatchNorm2d(num_features=192), # ADDED IN BATCHNORM 
                
                nn.Conv2d(in_channels=192, out_channels=384, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2),
                nn.BatchNorm2d(num_features=384), # ADDED IN BATCHNORM 
                
                
                nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(num_features=256), # ADDED IN BATCHNORM 
                
                nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2),
                nn.BatchNorm2d(num_features=256), # ADDED IN BATCHNORM 
        )
        
        self.avgpool = nn.AdaptiveAvgPool2d((6,6))
        
        self.head = nn.Sequential(
                nn.Dropout(dropout_p),
                nn.Linear(256*6*6, 4096),
                nn.ReLU(),
                nn.Dropout(dropout_p),
                nn.Linear(4096, 4096),
                nn.ReLU(),
                nn.Linear(4096, classes)
        )
        
        ############### NEW CODE ##############
        self.feature_extractor = self.feature_extractor.to("cuda:0")
        self.avgpool = self.avgpool.to("cuda:1")
        self.head = self.head.to("cuda:1")
        ########################################
        
    def forward(self, x):
        batch_size = x.shape[0]
        
        ### MOVE X TO GPU WITH FEATURE EXTRACTOR ###
        x = x.to("cuda:0")
        x = self.feature_extractor(x)
        
        
        ### MOVE X TO GPU WITH REMAINING LAYERS ###
        x = x.to("cuda:1")
        x = self.avgpool(x)
        x = x.reshape(batch_size, -1)
        x = self.head(x)
        return x

### Training Script
In this example, we will be passing in some tensor which in the model will be moved to GPU 0. Internally this tensor will then move to GPU 1, so we need to ensure that when we calculate the loss, our labels we are comparing against are also in GPU 1.

In [8]:
############### NEW CODE ##############
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # We will not set a device anymore
# print(f"Training on Device {DEVICE}")
########################################


### LOAD IN AlexNet Model With Model Parallelism ###
model = ModelParallelAlexNet()
# model = model.to(DEVICE) # No longer sending the model to a device, we did that within the model itself

epochs = 1
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)
loss_fn = nn.CrossEntropyLoss()
batch_size = 128

### BUILD DATALOADERS ###
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

def model_parallel_train(model, device, epochs, optimizer, loss_fn, batch_size, trainloader, valloader):
    for epoch in range(1, epochs + 1):
        print(f"Starting Epoch {epoch}")
        training_losses = []
        validation_losses = []
        
        model.train() # Turn On BatchNorm and Dropout
        for image, label in tqdm(trainloader):
            
            
            ############### NEW CODE ##############
            label = label.to("cuda:1") # We need to ensure out label is sitting where we expect our model output
            ########################################
            
            
            optimizer.zero_grad()
            out = model.forward(image)
        
            ### CALCULATE LOSS ##
            loss = loss_fn(out, label)
            training_losses.append(loss.item())

            loss.backward()
            optimizer.step()

        model.eval() # Turn Off Batchnorm 
        for image, label in tqdm(valloader):
            
            
            ############### NEW CODE ##############
            label = label.to("cuda:1") # We need to ensure out label is sitting where we expect our model output
            ########################################

            
            with torch.no_grad():
                out = model.forward(image)

                ### CALCULATE LOSS ##
                loss = loss_fn(out, label)
                validation_losses.append(loss.item())


        training_loss_mean = np.mean(training_losses)
        valid_loss_mean = np.mean(validation_losses)

        print("Training Loss:", training_loss_mean) 
        print("Validation Loss:", valid_loss_mean)
        
    return model


model = model_parallel_train(model=model,
                             device=DEVICE,
                             epochs=epochs,
                             optimizer=optimizer,
                             loss_fn=loss_fn,
                             batch_size=batch_size,
                             trainloader=trainloader,
                             valloader=valloader)

Starting Epoch 1


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 176/176 [00:29<00:00,  5.95it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:03<00:00,  5.58it/s]

Training Loss: 0.5130078430202875
Validation Loss: 0.41395812183618547





## Data Parallelism

This is probably the easiest implementation for paralleism but it has some limitations compared to the **Distributed Data Parallelism** that we will see in a bit. Namely, DataParallel is limited to a single process on a single machine, meaning we cannot take advantage of multiple processes. Regardless, lets put it together!

We will be using here the Vanilla AlexNet along with our Vanilla training function. As you can see, there is really nothing to do but wrap our model with DataParallel!

In [13]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 
############### NEW CODE ##############
model = VanillaAlexNet()
model = torch.nn.DataParallel(model, device_ids=[0,1]) # Wrap model with Data Parallel and provide IDX for GPUS we want to train on
model = model.to(DEVICE) # Move model to the main GPU node (typically the first one)
#######################################
    
epochs = 1
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)
loss_fn = nn.CrossEntropyLoss()
batch_size = 128
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)


model = vanilla_train(model=model,
                      device=DEVICE,
                      epochs=epochs,
                      optimizer=optimizer,
                      loss_fn=loss_fn,
                      batch_size=batch_size,
                      trainloader=trainloader,
                      valloader=valloader)

Starting Epoch 1


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 176/176 [00:29<00:00,  5.93it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:03<00:00,  5.43it/s]

Training Loss: 0.5091178042983467
Validation Loss: 0.4232334554195404





## Distributed DataParallelism

Everything from here on forward is not going to work on a Jupyter Notebook. We will explore the ideas here but then the actual script that can run will be available in [ddp.py](ddp.py)

There are some ideas we need to talk about first before we keep going:

- **Process/Worker:** Refers to an instance of a Python program, where each process will control a single GPU.
- **Node:** An entire computer and all its components. Training on multiple nodes means training accross multiple computers
- **World Size**: Total number of processes participating in the compute task. Typically this is your **Number of GPUS**
- **Global Rank**: The Unique ID assigned to each of your GPUS across all nodesl. Typically 0 is the main process
- **Local Rank**: The Unique ID assigned to each GPU inside a single Node (only matters in multinode training, otherwise identical to Global Rank)


#### Step 1) Setup the IP Address/Port Number for the GPUS to communicate through
- Typically our address will be "localhost" and the Port can be any number you want that is not being used by your computer, we will just use 12355. These will be accessed by the multithread by reading our environmental variables, so we will update our environmental variables as such. 
- We need to initialize the process group with the rank and world size. Rank will be determined by the DDP Module so we will just have a variable for it to pass in, but we need to let it know the world size, or the number of GPUs we have access to. 

When we initialize the group, you will notice we are using *nccl* as our backend which stands for NVIDIA Collective Communication Library. There are some other options such as **gloo** and **mpi** but each have varying capabilities that you can explore [here](https://pytorch.org/docs/stable/distributed.html)!

In [15]:
### Import In DDP Related Packages ###
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp

In [None]:
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

#### Step 2) Build a Distributed Data Sampler

If we want to split our dataset across multiple GPUS, we also need to be able to sample data in a distrbuted way for a specific GPU (the rank). Therefore lets build a distributed sampler DataLoader.

In [18]:
def build_distributed_sampler(path_to_data, batch_size, world_size, rank):
    
    ############### OLD CODE ##############
    dataset = ImageFolder(path_to_data)
    normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transforms = Compose([
        Resize((224, 224)),
        RandomHorizontalFlip(),
        ToTensor(),
        normalizer])

    val_transforms = Compose([
        Resize((224, 224)),
        ToTensor(),
        normalizer])

    train_samples, test_samples = int(0.9 * len(dataset)), len(dataset) - int(0.9 * len(dataset))
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, lengths=[train_samples, test_samples])

    train_dataset.dataset.transform = train_transforms
    val_dataset.dataset.transform = val_transforms
    #######################################
    
    ############### NEW CODE ##############
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)

    trainloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)

    valloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, sampler=val_dataset)
    
    return trainloader, valloader

#### Step 3) Wrap the Entire Training Scipt in a Main Function

The code we write in the main function will be spawed and run on all GPUS indicated in the world size. The Spawning (done later by the torch.multiprocessing module) will determine the rank (which GPU) each process should go to. So again, we will have a parameter for rank but we don't need to worry about it too much.

There is another complication though we need to consider: **Batch Normalization**. As we know, BatchNormalization will calculate the mean and standard deviation for a specific batch, but now that we have batches distributed between GPUS we need to somehow bring it all together. To do this we will convert all the BatchNorms with SyncBatchNorms!

In [20]:
def main(rank, world_size):
    ### Setup The Environmental Variables and Initialize ###
    setup(rank, world_size)

    ### Build our Distributed DataLoaders ###
    trainloader, valloader = build_distributed_sampler("../data/PetImages/", batch_size=64, world_size=world_size,
                                                       rank=rank)

    ### Define Model and Convert to SyncBatchNorm ###
    model = VanillaAlexNet().to(rank)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

    ### Set Optimizer and Loss Function ###
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    loss_fn = nn.CrossEntropyLoss()

    ### Start Training loop ###
    epochs = 5

    ### Standard Trainin Loop ###
    for epoch in range(1, epochs + 1):
        if rank == 0:  # Only print when we are using the default node 0 (otherwise everything will print once for each GPU)
            print(f"Starting Epoch {epoch}")

        ### Set Epoch for train and valloader for every epoch. This is necessary for proper shuffling in every iteration
        trainloader.sampler.set_epoch(epoch)
        valloader.sampler.set_epoch(epoch)

        training_losses = []
        validation_losses = []
        training_accuracies = []
        validation_accuracies = []

        model.train()

        for image, label in trainloader:
            image, label = image.to(rank), label.to(rank)
            optimizer.zero_grad()
            out = model.forward(image)
            loss = loss_fn(out, label)
            training_losses.append(loss.item())

            ### CALCULATE ACCURACY ###
            predictions = torch.argmax(out, axis=1)
            accuracy = (predictions == label).sum() / len(predictions)
            training_accuracies.append(accuracy.item())

            loss.backward()
            optimizer.step()

        model.eval()
        for image, label in valloader:
            image, label = image.to(rank), label.to(rank)
            with torch.no_grad():
                out = model.forward(image)
                loss = loss_fn(out, label)
                validation_losses.append(loss.item())

                ### CALCULATE ACCURACY ###
                predictions = torch.argmax(out, axis=1)
                accuracy = (predictions == label).sum() / len(predictions)
                validation_accuracies.append(accuracy.item())

        
        ### Convert Lists of Metrics to Tensors, Take the Mean, Store in Corresponing GPU ###
        training_loss_mean = torch.mean(torch.tensor(training_losses, dtype=torch.float)).to(rank)
        valid_loss_mean = torch.mean(torch.tensor(validation_losses, dtype=torch.float)).to(rank)
        training_acc_mean = torch.mean(torch.tensor(training_accuracies, dtype=torch.float)).to(rank)
        valid_acc_mean = torch.mean(torch.tensor(validation_accuracies, dtype=torch.float)).to(rank)

        ### AGGREGATE LOSSES AND MEANS ACROSS ALL GPUS ###
        torch.distributed.all_reduce(training_loss_mean, op=dist.ReduceOp.SUM)
        torch.distributed.all_reduce(valid_loss_mean, op=dist.ReduceOp.SUM)
        torch.distributed.all_reduce(training_acc_mean, op=dist.ReduceOp.SUM)
        torch.distributed.all_reduce(valid_acc_mean, op=dist.ReduceOp.SUM)

        ### DIVIDE THE SUM BY NUMBER OF GPUS (WORLD SIZE)
        training_loss_mean = training_loss_mean / world_size
        valid_loss_mean = valid_loss_mean / world_size
        training_acc_mean = training_acc_mean / world_size
        valid_acc_mean = valid_acc_mean / world_size

        if rank == 0:  # Only print when we are using the default node 0 (otherwise everything will print once for each GPU)
            print("Training Loss:", training_loss_mean.item())
            print("Training Accuracy:", training_acc_mean.item())
            print("Validation Loss:", valid_loss_mean.item())
            print("Validation Accuracy:", valid_acc_mean.item())

    dist.destroy_process_group()  ## End Training and Remove Everything

### Step 4) Spawn Processes and we are Done!!

```
world_size = 2
mp.spawn(main,
         args=(world_size,),
         nprocs=world_size
    )
```

We wont be running this inside the Jupyter Notebook as it will give errors but we will have all this in the [ddp.py](ddp.py)!

#### Rank
Internally, the mp.spawn will generate ranks from [0 to world_size - 1], so if we have 2 GPU's then the options for ranks are 0 and 1. It will then append the additional arguments that we pass in args that the main function needs. Our main function only needed to know the World Size, so our args was just world size, and then again internally the world size will be appended to our current rank and passed to the main function as a tuple.


### We Can Now Train Larger Models!!
Again, these techniques were applied to a relatively simple model, but the principles will be identical regardless! As a recap again:

- If you only have 1 GPU, your only option is **Gradient Accumulation**
- If you have multiple GPUS and your model can fit on one of them, then you can do DataParallel, or preferrably, DistributedDataParallel. 
- If you have multiple GPUS and your model CANNOT fit on them, then you can split the model between GPUS.


All the principles above can also be put together. You can do Distributed Data Parallel but also do gradient accumulation. If we have 4 GPUs we can split a model between two of them and then copy it to the other two, so we will be Distributed Data Parallel and Model Parallel! The setup we do is really dependent on the scale of the model you are trying to train. 


### Next Steps
Although it is nice to understand how to build a distributed system yourself, there are other tools that we will explore in the future that can help optimize the workflow without much effort. The two main ones of interest will be **Ray** for MultiGPU Hyperparameter turning and **HuggingFace Accelerate**!