Source: https://nbviewer.org/github/tunib-ai/large-scale-lm-tutorials/blob/main/notebooks/05_data_parallelism.ipynb

### Data Parallelism

In [None]:
import torch
from torch import nn

##### Example 1

In [None]:
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 10)
)

In [None]:
input = torch.randn(10, 10)

In [None]:
logits = model(input)

In [None]:
model

Sequential(
  (0): Linear(in_features=10, out_features=20, bias=True)
  (1): ReLU()
  (2): Linear(in_features=20, out_features=10, bias=True)
)

In [None]:
input.shape

torch.Size([10, 10])

Write the forward pass for data parallelism. The `device_ids` argument holds the IDs of all the devices, and `output_id` is the ID of the main device, where the results are collected.

In [None]:
from torch import nn

In [None]:
def compute_forward_pass_using_data_parallelism(model, input, device_ids, output_id):
    # Split the input batch along dim 0 and send one chunk to each device in `device_ids`.
    inputs = nn.parallel.scatter(input, device_ids)
    
    # Replicate the model onto each device in `device_ids`
    # (the model must already live on a GPU, typically `device_ids[0]`).
    models = nn.parallel.replicate(model, device_ids)
    
    # Run each replica on its own chunk of the batch, in parallel.
    outputs = nn.parallel.parallel_apply(models, inputs)
    
    # Gather the per-device logits and concatenate them on device `output_id`.
    logits = nn.parallel.gather(outputs, output_id)
    
    return logits

In [None]:
# `device_ids` and `output_id` are defined in the next example.
logits = compute_forward_pass_using_data_parallelism(model, input, device_ids, output_id)

In [None]:
logits.shape

torch.Size([10, 10])
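
The four steps above (scatter, replicate, parallel apply, gather) can be sketched on a single CPU, which may help when no GPUs are available. This is only an illustration under stated assumptions: `torch.chunk`, `copy.deepcopy`, and `torch.cat` stand in for the `nn.parallel` primitives, the replicas run one after another rather than concurrently, and the names `cpu_model`, `batch`, and `num_workers` are hypothetical.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)

# Same architecture as the notebook's model, kept on the CPU.
cpu_model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
batch = torch.randn(10, 10)
num_workers = 5  # hypothetical stand-in for the number of GPUs

# "Scatter": split the batch along dim 0 into one chunk per worker.
chunks = torch.chunk(batch, num_workers, dim=0)

# "Replicate": give every worker its own copy of the model weights.
replicas = [copy.deepcopy(cpu_model) for _ in chunks]

# "Parallel apply": each replica processes only its own chunk
# (sequentially here; real data parallelism runs these concurrently).
with torch.no_grad():
    partial_logits = [replica(chunk) for replica, chunk in zip(replicas, chunks)]

# "Gather": concatenate the per-worker outputs back into one batch.
sim_logits = torch.cat(partial_logits, dim=0)

# Reference: the unsplit forward pass on the whole batch.
with torch.no_grad():
    reference = cpu_model(batch)

print(sim_logits.shape, torch.allclose(sim_logits, reference, atol=1e-5))
```

Because every replica holds identical weights, the gathered result should match the single-device forward pass up to floating-point error, which is why data parallelism leaves the model's output unchanged.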

##### Example 2

In [None]:
# The `nn.parallel` functions and `nn.DataParallel` expect device IDs
# as plain Python ints (or `torch.device` objects), not tensors.
device_ids = [0, 1, 2, 3, 4]

In [None]:
output_id = 0

In [None]:
model

Sequential(
  (0): Linear(in_features=10, out_features=20, bias=True)
  (1): ReLU()
  (2): Linear(in_features=20, out_features=10, bias=True)
)

In [None]:
device_ids, output_id

([0, 1, 2, 3, 4], 0)

Write a function that runs one training iteration using PyTorch's built-in module for data parallelism, `nn.DataParallel`.

**Hint**: the function should cover both the forward and the backward pass.

In [None]:
from torch import nn

In [None]:
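def one_iter(inp, targ, model, loss_func, optimizer, device_ids, output_id):
    optimizer.zero_grad()
    # Wrap the model so each forward pass is split across `device_ids`.
    # In a real training loop, wrap once outside the loop rather than on every iteration.
    model = nn.DataParallel(model, device_ids=device_ids, output_device=output_id)
    
    # Forward pass: logits are gathered on `output_id`, where the loss is computed.
    logits = model(inp)
    loss = loss_func(logits, targ)
    # Backward pass: gradients from all replicas are accumulated into the original parameters.
    loss.backward()
    
    optimizer.step()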

In [None]:
one_iter(inp, targ, model, loss_func, optimizer, device_ids, output_id)
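
The call above relies on inputs, targets, a loss function, and an optimizer that the notebook does not define. A minimal self-contained sketch of one such iteration might look like the following. All concrete choices here (`nn.MSELoss`, SGD with `lr=0.1`, random data, and the names `net`, `features`, `targets`, `criterion`, `sgd`) are assumptions made for illustration. `nn.DataParallel` is created without `device_ids`, so it uses every visible GPU and simply calls the wrapped module directly when none is available.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Use the first GPU when available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10)).to(device)
parallel_net = nn.DataParallel(net)  # device_ids=None: all visible GPUs

# Hypothetical loss, optimizer, and data, chosen only for this sketch.
criterion = nn.MSELoss()
sgd = torch.optim.SGD(net.parameters(), lr=0.1)
features = torch.randn(10, 10, device=device)
targets = torch.randn(10, 10, device=device)

weights_before = net[0].weight.detach().clone()

# One iteration: forward through the wrapper, backward, parameter update.
sgd.zero_grad()
loss = criterion(parallel_net(features), targets)
loss.backward()
sgd.step()

# The optimizer updates the original module's parameters, not the replicas'.
weights_changed = not torch.equal(weights_before, net[0].weight.detach())
print(loss.item(), weights_changed)
```

The optimizer is built from `net.parameters()` rather than from the replicas, since the replicas are recreated on every forward pass and their gradients are reduced back onto the original module.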

##### Example 3

In [None]:
class Model(nn.Module):
    def forward(self, x): return self.net(x)

In [None]:
model = nn.DataParallel(
    model,
    device_ids=[0, 1, 2, 3],
    output_device=0
)

Explain how this code causes memory imbalance across all GPUs.

How can we make use of the memory on all GPUs?

In [None]:
for _, data in enumerate(data_loader):    
    ### ....
    outputs = model(x) 
    loss = loss_fn(outputs, labels)
    loss.backward()
    optimizer.step()

**Explain**

`nn.DataParallel` splits each input batch along the batch dimension and sends one chunk to each device in `device_ids`. Each replica runs the forward pass on its chunk, and the resulting logits are concatenated on the output device (device `0` here).

The memory imbalance arises because device `0` alone collects the outputs of all replicas and computes the loss on them. Gathering the outputs and computing the loss on a single GPU raises memory usage on device `0`, while the other devices carry no such overhead.

We can spread this load by computing the loss inside the model's `forward`, so that each replica computes the loss for its own chunk on its own GPU.
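
To get a rough sense of the imbalance, the following sketch counts how many logit elements each device feeds into the loss under the two schemes. The sizes (batch of 64, 50,000 classes, 4 devices) and the helper `loss_elements_per_device` are hypothetical, and the count ignores everything except the logits entering the loss, so it is an illustration rather than a memory measurement.

```python
def loss_elements_per_device(batch_size, num_classes, num_devices, loss_in_forward):
    """Return, per device, the number of logit elements the loss is computed on.

    Assumes the batch divides evenly across devices.
    """
    if loss_in_forward:
        # Every replica computes the loss on its own chunk of the batch.
        per_chunk = (batch_size // num_devices) * num_classes
        return [per_chunk] * num_devices
    # Only the output device (index 0) sees the gathered logits of the full batch.
    return [batch_size * num_classes] + [0] * (num_devices - 1)


# Hypothetical sizes, chosen only for illustration.
gathered = loss_elements_per_device(64, 50_000, 4, loss_in_forward=False)
distributed = loss_elements_per_device(64, 50_000, 4, loss_in_forward=True)

print(gathered)     # [3200000, 0, 0, 0]
print(distributed)  # [800000, 800000, 800000, 800000]
```

Under these assumptions the total work is the same in both cases, but moving the loss into `forward` spreads it evenly instead of concentrating it on device `0`.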

In [None]:
for _, data in enumerate(data_loader):    
    ### ....
    # Each replica returns its own loss, so `losses` holds one value per device.
    _, losses = model(x, labels)
    # Average the per-device losses into a single scalar.
    loss = losses.mean()
    loss.backward()
    optimizer.step()

In [None]:
class Model(nn.Module):
    def forward(self, x, labels):
        output = self.net(x)
        # Compute the loss on this replica's device, on its own chunk of the batch.
        loss = loss_fn(output, labels)
        return output, loss
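
Averaging the per-device losses with `losses.mean()` reproduces the full-batch loss when the loss uses mean reduction and every device receives the same number of samples. The CPU-only sketch below checks this. It assumes `nn.CrossEntropyLoss` as the loss (the notebook does not say which `loss_fn` is used), and the tensors and names are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Assumed loss function; mean reduction is the default.
criterion = nn.CrossEntropyLoss()

example_logits = torch.randn(8, 5)            # 8 samples, 5 classes
example_labels = torch.randint(0, 5, (8,))

# Loss computed once on the whole batch (the "gather logits first" scheme).
full_loss = criterion(example_logits, example_labels)

# Loss computed per chunk, as each replica would do inside `forward`.
per_chunk = torch.stack([
    criterion(chunk_logits, chunk_labels)
    for chunk_logits, chunk_labels in zip(example_logits.chunk(4), example_labels.chunk(4))
])

# Averaging the per-chunk losses, as `losses.mean()` does in the training loop.
mean_of_chunks = per_chunk.mean()

print(torch.allclose(full_loss, mean_of_chunks, atol=1e-6))
```

If the chunks have different sizes (for example, when the batch size is not divisible by the number of devices), the plain mean of per-chunk losses weights samples unevenly and will generally differ slightly from the full-batch loss.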