### Multi-GPU training

This follows on from the PyTorch Lightning exercise, where I set up GPU training with `PyTorch Lightning`.
Here I try to implement multi-GPU training and look at its inner workings and implementation.

Problem / Situation:
By default, PyTorch trains a model on a single GPU, even when several GPUs are available. We can use PyTorch's parallel processing capabilities to do `multi-GPU training`.

Training on a single device is a big bottleneck, because plain PyTorch code runs on one device of one machine. _Spreading the work across GPUs helps reduce the run-time._

```
In essence, multi-GPU training enables us to distribute the workload of model training across multiple GPUs, and even across multiple machines if necessary.
```
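
To make "distributing the workload" concrete, here is a minimal sketch of how one batch can be split into per-device shards. The batch shape and the device count are made-up values for illustration, and everything runs on CPU.

```python
import torch

# Hypothetical values: one batch of 64 MNIST-sized images, 4 devices.
batch = torch.randn(64, 1, 28, 28)
num_devices = 4

# torch.chunk splits along the batch dimension into num_devices pieces.
shards = torch.chunk(batch, num_devices, dim=0)

shard_sizes = [shard.shape[0] for shard in shards]
print(shard_sizes)  # [16, 16, 16, 16]
```

Each device then only has to process a quarter of the batch.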

There are two main approaches:
1. DataParallel (DP)
2. DistributedDataParallel (DDP)


Reference: [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/pdf/1706.02677) paper
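
The referenced paper is known for its linear scaling rule: when the minibatch size is multiplied by k, multiply the learning rate by k. The sketch below applies that rule to this notebook's numbers (batch size 64, learning rate 0.001). The helper `scaled_lr` is a hypothetical name, and since the paper studies SGD while this notebook uses Adam, treat the result as a starting point rather than a guarantee.

```python
def scaled_lr(base_lr, base_batch_size, per_device_batch_size, num_devices):
    """Linear scaling rule: the learning rate grows with the effective batch size."""
    effective_batch_size = per_device_batch_size * num_devices
    return base_lr * effective_batch_size / base_batch_size


# Two devices, each with its own batch of 64, give an effective batch of 128.
lr = scaled_lr(base_lr=0.001, base_batch_size=64,
               per_device_batch_size=64, num_devices=2)
print(lr)
```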

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torchmetrics


# Data transformation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load datasets (MNIST ships only train and test splits,
# so the test split doubles as the validation set here)
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
val_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Dataloaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


class SimpleNet(pl.LightningModule):

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 10)
        
        self.criterion = nn.CrossEntropyLoss()

        # Define torchmetrics for tracking
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.precision = torchmetrics.Precision(task="multiclass", num_classes=10)
    # forward for forward pass logic
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x
    # Training_step() for one step of batch training
    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)

        # Compute metrics (convert logits to class predictions)
        preds = outputs.argmax(dim=1)

        acc = self.accuracy(preds, labels)
        prec = self.precision(preds, labels)

        # Log metrics separately
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_precision', prec, on_step=True, on_epoch=True, prog_bar=True)

        return loss
    # Validation_step() for one step of batch validation
    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)

        preds = outputs.argmax(dim=1)
        acc = self.accuracy(preds, labels)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        return train_loader

    def val_dataloader(self):
        return val_loader

    def test_dataloader(self):
        return test_loader

In [5]:
# Define model
model = SimpleNet()

# Declare Trainer object
trainer = pl.Trainer(max_epochs=5,
                     accelerator= "auto",
                     devices=1)

# Fit model on train dataset
trainer.fit(model,
            train_loader,
            val_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



  | Name      | Type                | Params | Mode 
----------------------------------------------------------
0 | fc1       | Linear              | 401 K  | train
1 | fc2       | Linear              | 131 K  | train
2 | fc3       | Linear              | 32.9 K | train
3 | fc4       | Linear              | 1.3 K  | train
4 | criterion | CrossEntropyLoss    | 0      | train
5 | accuracy  | MulticlassAccuracy  | 0      | train
6 | precision | MulticlassPrecision | 0      | train
----------------------------------------------------------
567 K     Trainable params
0         Non-trainable params
567 K     Total params
2.270     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Users/rakeshk94/Desktop/multi_gpu_training/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


                                                                           

/Users/rakeshk94/Desktop/multi_gpu_training/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 4: 100%|██████████| 938/938 [00:15<00:00, 62.23it/s, v_num=18, train_loss_step=0.159, train_acc_step=0.969, train_precision_step=0.969, val_loss=0.135, val_acc=0.960, train_loss_epoch=0.0761, train_acc_epoch=0.976, train_precision_epoch=0.976]  

`Trainer.fit` stopped: `max_epochs=5` reached.




### We can define the number of devices with Lightning

In [None]:
# Declare Trainer object
trainer = pl.Trainer(max_epochs=5,
                     accelerator = "gpu",
                     devices = 6)
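
Requesting more devices than the machine has makes the Trainer fail at start-up, so it can help to check first. The sketch below uses a hypothetical helper, `pick_devices`, that caps the request at the number of CUDA GPUs PyTorch can see; it only covers CUDA and falls back to a single device otherwise.

```python
import torch


def pick_devices(requested):
    """Cap the requested device count at what this machine offers."""
    available = torch.cuda.device_count()  # 0 when no CUDA GPU is visible
    if available == 0:
        return 1  # fall back to a single device (CPU or MPS)
    return min(requested, available)


devices = pick_devices(6)
print(devices)
```

The result can be passed as `devices=devices` to `pl.Trainer`. Lightning also accepts `devices="auto"` to use whatever it finds.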

## Distributed Training strategy with multiple devices

In [None]:
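# Declare Trainer object for multi-GPU training with DataParallel
# Note: strategy="dp" is only accepted by Lightning 1.x; it was removed in Lightning 2.0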
trainer = pl.Trainer(max_epochs=5,
                     accelerator="gpu",
                     devices=2,
                     strategy="dp")

### Process in DP strategy

- The central device (the first GPU) replicates the model to all GPUs and splits each batch between them.
- Each GPU runs the forward pass on its share of the batch and sends its outputs back to the central device.
- The central device computes the loss; the resulting gradients are collected there and used to update the weights of the model held on the central device.
- The updated weights are then sent back to the individual GPUs.

` The main problem with the DataParallel strategy is that the loss computation and the weight update happen on one device, which becomes the bottleneck `
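
The DP steps listed above can be simulated on CPU to see where the work lands. This is a sketch of the idea and not how `torch.nn.DataParallel` is implemented internally: the "devices" are plain copies of a tiny linear model, and the model size, batch, and learning rate are made-up values.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
master = nn.Linear(4, 3)            # model held on the central device
inputs = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))
criterion = nn.CrossEntropyLoss()
num_devices = 2                     # simulated devices

# 1. Replicate the model and scatter the batch.
replicas = [copy.deepcopy(master) for _ in range(num_devices)]
shards = torch.chunk(inputs, num_devices, dim=0)

# 2. Each replica runs the forward pass; outputs are gathered centrally.
outputs = torch.cat([rep(x) for rep, x in zip(replicas, shards)], dim=0)

# 3. The central device computes the loss; backward fills each replica's gradients.
loss = criterion(outputs, labels)
loss.backward()

# 4. Gradients are summed onto the central model, which alone takes the optimizer step.
for name, param in master.named_parameters():
    param.grad = sum(dict(rep.named_parameters())[name].grad for rep in replicas)
optimizer = torch.optim.SGD(master.parameters(), lr=0.1)
optimizer.step()
```

After step 4 the replicas are stale, so the updated weights have to be copied out again before the next batch. That repeated copying, plus the central loss computation, is the overhead DP pays.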

In [None]:
# Declare Trainer object for multi-GPU training
trainer = pl.Trainer(max_epochs=5,
                     accelerator="gpu",
                     devices=2,
                     strategy="ddp")


The data is split into batches, and each batch is sent to a GPU that holds its own replica of the model.
Each GPU therefore computes its own outputs, loss, and gradients.
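
To see how the data can be divided so that every GPU gets different samples, here is a small sketch using PyTorch's `DistributedSampler`. The 10-sample dataset is a made-up stand-in, and the rank is passed explicitly so the example runs in one process without a process group. Lightning normally sets up this sampler for you under DDP, so this is only to show the effect.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))  # hypothetical 10-sample dataset

# One sampler per simulated rank; shuffle=False keeps the split easy to read.
shards = [
    list(DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False))
    for rank in range(2)
]
print(shards)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
```

Each rank sees a disjoint half of the dataset, so together the GPUs cover every sample once per epoch.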

### Process in DDP strategy

- The central machine replicates the model to all GPUs. This happens only once.
- The individual GPUs compute the gradients themselves, communicate them to the other GPUs, and all replicas get updated.
- The central machine is never overloaded with model outputs.
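
The DDP steps above can also be simulated on CPU. This is a sketch of the idea, not of `DistributedDataParallel` itself: the "ranks" are copies of a tiny linear model inside one process, the all-reduce is an explicit average over their gradients, and the sizes and learning rate are made-up values.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
base = nn.Linear(4, 3)
inputs = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))
criterion = nn.CrossEntropyLoss()
world_size = 2                      # simulated ranks

# The model is replicated once, and every rank owns its own optimizer.
replicas = [copy.deepcopy(base) for _ in range(world_size)]
optimizers = [torch.optim.SGD(rep.parameters(), lr=0.1) for rep in replicas]

x_shards = torch.chunk(inputs, world_size)
y_shards = torch.chunk(labels, world_size)

# Each rank computes outputs, loss, and gradients on its own shard.
for rep, x, y in zip(replicas, x_shards, y_shards):
    criterion(rep(x), y).backward()

# Simulated all-reduce: average the gradients so every rank holds the same values.
for params in zip(*(rep.parameters() for rep in replicas)):
    mean_grad = torch.stack([p.grad for p in params]).mean(dim=0)
    for p in params:
        p.grad = mean_grad.clone()

# Every rank applies the same update locally, so the replicas stay in sync.
for opt in optimizers:
    opt.step()

in_sync = all(
    torch.equal(p0, p1)
    for p0, p1 in zip(replicas[0].parameters(), replicas[1].parameters())
)
print(in_sync)  # True
```

Only gradients cross between ranks, and no single device has to collect the model outputs or redistribute the weights.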

In multi-GPU training, the typical idea is to distribute the dataset across multiple machines or GPUs to take advantage of parallel processing capabilities.