---
title: "Distributed Data Parallelism"
format:
  html:
    include-in-header: tippy-head.html
    include-after-body: tippy-init.html

execute: 
  echo: true
  eval: false
code-annotations: hover
---

> Coding Distributed Data Parallelism from scratch for 1D parallelism

In [1]:
%load_ext nbdistributed

In [2]:
%dist_init --num-processes 2 --gpu-ids 1,2

Using GPU IDs: [1, 2]
Starting 2 distributed workers...
✓ Successfully started 2 workers
  Rank 0 -> GPU 1
  Rank 1 -> GPU 2
Available commands:
  %%distributed - Execute code on all ranks (explicit)
  %%rank [0,n] - Execute code on specific ranks
  %sync - Synchronize all ranks
  %dist_status - Show worker status
  %dist_mode - Toggle automatic distributed mode
  %dist_shutdown - Shutdown workers

🚀 Distributed mode active: All cells will now execute on workers automatically!
   Magic commands (%, %%) will still execute locally as normal.

🐍 Below are auto-imported and special variables auto-generated into the namespace to use
  `torch`
  `dist`: `torch.distributed` import alias
  `rank` (`int`): The local rank
  `world_size` (`int`): The global world size
  `gpu_id` (`int`): The specific GPU ID assigned to this worker
  `device` (`torch.device`): The current PyTorch device object (e.g. `cuda:1`)


<IPython.core.display.Javascript object>

## What *is* Distributed Data Parallelism?

Let's think about the concept:

In general we are:

- *Not* parallelizing the model*
- *Are* parallelizing the **data**

Which means:

Given `n` GPUs, we have `n` replicants of the model

Each replica `n` sees `B/n` chunks of data at once

Then we could average the gradients of `n` at the end

This is directly a speedup parallelism strategy with a **singular** communication happening, the averaging of the gradients.

Or, written in math form:

\begin{align*}
B_i &= \frac{B}{n} 
&& \text{Each replica gets a mini-batch of size } B/n \\
g &= \frac{1}{n} \sum_{i=1}^{n} g_i 
&& \text{Gradients from all replicas ($g_i$) are averaged ($g$)} \\
\theta_i &= \theta_i - g 
&& \text{Each replica updates its own copy of the parameters ($\theta_i$) using $g$}
\end{align*}


Now let's do the code

## Imports

Here's all the relative imports we will need and be using:

In [14]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd()))
from training_utils.utils import get_smol_model, get_dataset, get, get_smol_tokenizer, set_seed

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### What is `get`?

In [4]:
get("ws")


🔹 Rank 0:
  2

🔹 Rank 1:
  2


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

`get` is a handy utility for grabbing distributed information:

```
`ws` -> dist.get_world_size(pg)
`pg` -> dist.get_process_group()
`rank` -> dist.get_rank(pg) # global
`grank` -> dist.get_rank(pg) # global
`lrank` -> local_rank
```

In [5]:
get("rank")


🔹 Rank 0:
  0

🔹 Rank 1:
  1


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [6]:
get("grank")


🔹 Rank 0:
  0

🔹 Rank 1:
  1


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [7]:
get("lrank")


🔹 Rank 0:
  0

🔹 Rank 1:
  1


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [8]:
import random
import numpy as np

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Back to the task at hand

Let's setup reproducability across all GPUs

In [15]:
set_seed(42)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## Creating `torch.distributed.DistributedDataParallelism`

Let's break down what we might need for our `DistributedDataParallism` class.

1. We know we need *replicants* of the model, so each process needs its own `model`, which we should likely store
2. Given each gpu needs a replicant, we should *verify* they are all replicants (how could we do this?)
3. We need a way to *average* the gradients
4. Thought: does gradient accumulation need to be special?

### The `__init__`:

In [10]:
class SimpleDistributedDataParallelism:
    def __init__(self, model:torch.nn.Module):
        self.model = model

        for param in model.parameters():
            rank0_param = param.data.clone()
            dist.broadcast(rank0_param, src=0)
            if not torch.equal(param.data, rank0_param):
                raise ValueError(
                    "Expected model parameters to be identical during `__init__`, but this is not true. "
                    "Make sure to set the seeds before creating your model"
                )

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Let's test to make sure that we get that `ValueError` by explicitly setting a different seed for rank 0 and rank 1

In [16]:
%%rank [0]
set_seed(43)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [17]:
local_rank = get("lrank")
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")

model = get_smol_model()
model.to(device)

model = SimpleDistributedDataParallelism(model)


❌ Error on Rank 1: Expected model parameters to be identical during `__init__`, but this is not true. Make sure to set the seeds before creating your model
Traceback (most recent call last):
  File "/home/zach/nbdistributed/src/nbdistributed/worker.py", line 282, in _execute_code_streaming
    tree = ast.parse(code.strip(), mode='eval')
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zach/miniconda3/envs/torch/lib/python3.12/ast.py", line 52, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 1
    local_rank = get("lrank")
               ^
SyntaxError: invalid syntax

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/zach/nbdistributed/src/nbdistributed/worker.py", line 367, in _execute_code_streaming
    exec(compile(tree, '<string>', 'exec'), self.namespace)
  File "<string>", line 8, in <module>
  File "<string>", line

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Now let's try again:

In [18]:
set_seed(43)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [19]:
model = get_smol_model()
model.to(device)
model.train()

model = SimpleDistributedDataParallelism(model)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Okay great! We have the replicant part done. What's next?

### Averaging the Grads

We need process *just* after the backward pass which will average our gradients for us.

This should entail some sort of **collection** operation, which one let's us both **gather** a tensor from all processes and then **average** all of them together?

<details>
<summary>Reveal answer</summary>

`dist.all_reduce`

</details>

In [20]:
class SimpleDistributedDataParallelism:
    def __init__(self, model:torch.nn.Module):
        self.model = model

        for param in model.parameters():
            rank0_param = param.data.clone()
            dist.broadcast(rank0_param, src=0)
            if not torch.equal(param.data, rank0_param):
                raise ValueError(
                    "Expected model parameters to be identical during `__init__`, but this is not true. "
                    "Make sure to set the seeds before creating your model"
                )
    
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    
    def train(self):
        self.model.train()
    
    def eval(self):
        self.model.eval()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

First let's write a base-case, showing it fails:

In [31]:
dataset = get_dataset()["train"]

# Get a sample of each based on the process 
item = dataset[get("rank")]; print(item)


🔹 Rank 1:
  {'labels': 0, 'input_ids': [73, 38469, 44437, 9394, 17290, 901, 637, 99, 1092, 9417, 260, 5891, 288, 6015, 846, 341, 281, 216, 33, 41, 41, 40, 327, 1885, 216, 34, 30, 37, 4533, 1673, 73, 38469, 44437, 10897, 17290, 901, 637, 99, 281, 216, 33, 41, 41, 37, 327, 1885, 216, 38, 41, 35, 2215, 284, 3459, 357, 288, 6015, 846, 341, 327, 1885, 216, 33, 30, 40, 4533, 281, 216, 33, 41, 41, 40, 1673, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}

🔹 Rank 0:
  {'labels': 1, 'input_ids': [8112, 291, 16668, 13259, 650, 5717, 3297, 5337, 384, 1217, 476, 260, 6267, 476, 3297, 282, 18519, 1006, 18412, 650, 2364, 1673, 7140, 810

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [22]:
item = {k:torch.tensor(v).unsqueeze(0).to(device) for k,v in item.items()};
model = get_smol_model()
model.to(device);
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ddp_model = SimpleDistributedDataParallelism(model)
output = ddp_model(**item)
output.loss.backward()
optimizer.step()
optimizer.zero_grad()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [23]:
output.loss


🔹 Rank 0:
  tensor(0.6914, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<NllLossBackward0>)

🔹 Rank 1:
  tensor(0.6953, device='cuda:1', dtype=torch.bfloat16,
       grad_fn=<NllLossBackward0>)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Let's use the same logic we did before:

In [24]:
for i, param in enumerate(model.model.parameters()):
    local_param = param.data
    gathered = [torch.empty_like(local_param) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local_param)

    for rank_idx, other_param in enumerate(gathered):
        if not torch.allclose(local_param, other_param, atol=1e-6):
            raise ValueError(
                f"[Rank {dist.get_rank()}] Parameter {i} mismatch with rank {rank_idx}. "
                f"Max diff: {(local_param - other_param).abs().max().item()}"
            )


❌ Error on Rank 1: [Rank 1] Parameter 0 mismatch with rank 0. Max diff: 0.00390625
Traceback (most recent call last):
  File "/home/zach/nbdistributed/src/nbdistributed/worker.py", line 282, in _execute_code_streaming
    tree = ast.parse(code.strip(), mode='eval')
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zach/miniconda3/envs/torch/lib/python3.12/ast.py", line 52, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 1
    for i, param in enumerate(model.model.parameters()):
    ^^^
SyntaxError: invalid syntax

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/zach/nbdistributed/src/nbdistributed/worker.py", line 367, in _execute_code_streaming
    exec(compile(tree, '<string>', 'exec'), self.namespace)
  File "<string>", line 8, in <module>
ValueError: [Rank 1] Parameter 0 mismatch with rank 0. Max diff: 0.00390625




<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Now let's write a method for synchronizing the gradients (which will be called right after `loss.backward`)

In [25]:
class SimpleDistributedDataParallelism:
    def __init__(self, model:torch.nn.Module):
        self.model = model

        for param in model.parameters():
            rank0_param = param.data.clone()
            dist.broadcast(rank0_param, src=0)
            if not torch.equal(param.data, rank0_param):
                raise ValueError(
                    "Expected model parameters to be identical during `__init__`, but this is not true. "
                    "Make sure to set the seeds before creating your model"
                )

    def sync_gradients(self):
        """
        Should be called before the backward pass, iterates 
        through all params, and:
        1. Check if it is `None` (not trainable)
        2. If trainable, will perform an `all_reduce` using `SUM`
        (aka: take the global average of all grads)
        """
        for param in self.model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad /= dist.get_world_size()
    
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    
    def train(self):
        self.model.train()
    
    def eval(self):
        self.model.eval()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [26]:
del model.model, model
del optimizer
torch.cuda.empty_cache()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [27]:
set_seed(42)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [34]:
model = get_smol_model()
model.to(device);
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ddp_model = SimpleDistributedDataParallelism(model)
output = ddp_model(**item)
output.loss.backward()
ddp_model.sync_gradients()
optimizer.step()
optimizer.zero_grad()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [35]:
for i, param in enumerate(model.model.parameters()):
    local_param = param.data
    gathered = [torch.empty_like(local_param) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local_param)

    for rank_idx, other_param in enumerate(gathered):
        if not torch.allclose(local_param, other_param, atol=1e-6):
            raise ValueError(
                f"[Rank {dist.get_rank()}] Parameter {i} mismatch with rank {rank_idx}. "
                f"Max diff: {(local_param - other_param).abs().max().item()}"
            )


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

It works!

## Putting it all together

Now let's put it all together and see just how much faster using DDP is vs non-parallelism

In [8]:
from torch.utils.data import DataLoader
import time

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [9]:
tokenizer = get_smol_tokenizer()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [16]:
%%rank [0]

train_ds = dataset.shuffle(seed=42)

def collate_func(batch):
    return tokenizer.pad(
        batch,
        padding="longest",
        max_length=None,
        pad_to_multiple_of=8,
        return_tensors="pt",
    )

per_device_batch_size = 16


train_dataloader = DataLoader(
    train_ds,
    batch_size=per_device_batch_size,
    collate_fn=collate_func,
    drop_last=True,
    shuffle=True
)

model = get_smol_model()
model.to(device)
optimizer = torch.optim.SGD(model.model.parameters(), lr=1e-3)

start_time = time.time()
num_batches = 0
for (i, batch) in enumerate(train_dataloader):
    if i > 20:
        break
    # Move batch to GPU
    batch = {k: v.to(device) for k, v in batch.items()}
    
    torch.cuda.synchronize()  # Ensure previous GPU ops are done
    batch_start = time.time()
    
    output = model(**batch)
    output.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    torch.cuda.synchronize()  # Wait for GPU ops to complete
    num_batches += 1

torch.cuda.synchronize()  # Ensure all GPU ops are done
total_time = time.time() - start_time
avg_time_per_batch = total_time / num_batches
print(f"Total training time: {total_time:.2f} seconds")
print(f"Average time per batch: {avg_time_per_batch:.4f} seconds")


🔹 Rank 0:
  Total training time: 1.58 seconds
  Average time per batch: 0.0751 seconds


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [11]:
dataset = get_dataset()["train"]

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [12]:
# do initial shuffle
train_ds = dataset.shuffle(seed=42)

# Shard data for first parallel dimension
# Takes dataset of [0, 1, 2, ... n] -> [[0, 1, 2, ... n/ws], [n/ws, n/ws+1, ... n-1]]
ds_length = len(train_ds)
ds_length_per_rank = ds_length // get("ws")
rank = get("rank")
start = rank * ds_length_per_rank
end = start + ds_length_per_rank if rank != get("ws") - 1 else ds_length

train_shard = train_ds.select(list(range(start, end)))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [16]:
train_ds = dataset.shuffle(seed=42)

def collate_func(batch):
    return tokenizer.pad(
        batch,
        padding="longest",
        max_length=None,
        pad_to_multiple_of=8,
        return_tensors="pt",
    )

per_device_batch_size = 8


train_dataloader = DataLoader(
    train_shard,
    batch_size=per_device_batch_size,
    collate_fn=collate_func,
    drop_last=True,
    shuffle=True
)

model = get_smol_model()
model.to(device)
optimizer = torch.optim.SGD(model.model.parameters(), lr=1e-3)

start_time = time.time()
num_batches = 0
for (i, batch) in enumerate(train_dataloader):
    if i > 20:
        break
    # Move batch to GPU
    batch = {k: v.to(device) for k, v in batch.items()}
    
    torch.cuda.synchronize()  # Ensure previous GPU ops are done
    batch_start = time.time()
    
    output = model(**batch)
    output.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    torch.cuda.synchronize()  # Wait for GPU ops to complete
    num_batches += 1

torch.cuda.synchronize()  # Ensure all GPU ops are done
total_time = time.time() - start_time
avg_time_per_batch = total_time / num_batches
print(f"Total training time: {total_time:.2f} seconds")
print(f"Average time per batch: {avg_time_per_batch:.4f} seconds")


🔹 Rank 0:
  Total training time: 1.13 seconds
  Average time per batch: 0.0540 seconds

🔹 Rank 1:
  Total training time: 1.16 seconds
  Average time per batch: 0.0551 seconds


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

As we can see, using multiple GPUs we can get double the batch size (since 2*16 == GBS of 32) *without* sacrificing speed (our communication time is a fraction of the total training time, in this instance)

## Expanding for Gradient Accumulation

In [None]:
from contextlib import contextmanager

In [None]:
class SimpleDistributedDataParallelism:
    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.enable_grad_sync()

        # Ensure model is the same across all ranks
        for param in model.parameters():
            rank0_param = param.data.clone()
            dist.broadcast(rank0_param, src=0)
            if not torch.equal(param.data, rank0_param):
                raise ValueError(
                    "Expected model parameters to be identical during `__init__`, but this is not true. "
                    "Make sure to set the seeds before creating your model"
                )

    def sync_gradients(self):
        """
        Call after backward if gradients should be synchronized.
        """
        if not self.do_sync:
            return  # skip syncing
        for param in self.model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad /= dist.get_world_size()

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    @property
    def do_sync(self):
        return self._do_sync

    def enable_grad_sync(self):
        self._do_sync = True

    def disable_grad_sync(self):
        self._do_sync = False

    @contextmanager
    def no_sync(self):
        """
        Context manager to temporarily disable gradient syncing.
        """
        prev = self.do_sync
        self.disable_grad_sync()
        try:
            yield
        finally:
            self._do_sync = prev

    def train(self):
        self.model.train()

    def eval(self):
        self.model.eval()

In [None]:
for i, batch in enumerate(dataloader):
    if i % grad_accum_steps == 0:
        ddp_model.enable_grad_sync()
    else:
        ddp_model.disable_grad_sync()

    output = ddp_model(batch)
    output.loss.backward()

    if ddp_model.do_sync:
        ddp_model.sync_gradients()
        optimizer.step()
        optimizer.zero_grad()
