
# 0. Why?



Large language models are large.

They are so large that even
the latest and greatest hardware accelerators, like
[NVIDIA's H100 GPU](https://lambdalabs.com/blog/nvidia-h100-gpu-deep-learning-performance-analysis),
cannot fit in memory all the parameters and intermediate results needed to transform
input text into output text, plus the gradients
used to make the model better during training.

That makes training large language models a distributed programming problem,
where the work of computing an output is split, or _distributed_,
across multiple accelerators or machines.

That's common enough for neural networks:
it's actually pretty typical for training
to require many GPUs,
and the simple solution is to split work
across data points, which
[isn't so hard](https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html?highlight=distributed%20data%20parallel).

But for the largest models,
things are worse still:
you can't fit all the calculations
for the outputs of _even a single layer_
working on _even a single datapoint_
on one accelerator.

So training large language models requires
multiple layers of parallelization.

And despite the increases in scale
and the rapid pace of innovation in language model applications
in the past few years (and even months!),
the best reference for understanding the fundamentals of how that problem
is posed and solved is still a paper first released in 2019:
[_Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism_](https://arxiv.org/abs/1909.08053).

In this blog post/notebook,
we'll walk through the main ideas in that paper.
Our goal will be to build up an understanding of how to implement
a Megatron-style linear layer.

We'll use that understanding to implement
`ColumnParallelLinear` in PyTorch.

## 1. Three Nested Parallelizations


### Data Parallelism

Parallelize by splitting up the data.

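Each worker gets a piece of a batch and is responsible for running its own replica of the model on it.

Easy! Fun, even! Elements of a batch should have nothing to do with one another.

Synchronization: share gradients. After each backward pass, the workers average their gradients, so every replica applies the same update and the replicas stay identical.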

### Pipeline Parallelism

Parallelize by splitting up the model into distinct steps --
by layer, typically.

Each worker gets assigned one or more layers.

Note, this is also known as "vertical splitting",
because distributed systems people like
to write neural networks going left-to-right
instead of bottom-to-top,
as is [ancient tradition](https://www.iro.umontreal.ca/~vincentp/ift3395/lectures/backprop_old.pdf)
in the world of NN research.

We prefer the term "pipeline parallelism"
because it's clearer and more evocative.

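Synchronization: workers hand activations forward and gradients backward to their neighboring stages.
Implemented naively, during training a worker needs to wait for
the rest of the forward pass and the beginning of the backward pass to complete before it can move on,
sitting idle in the meantime. These idle stretches are called "bubbles".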

Notice that after you've split up your batch into microbatches,
you can then further split up the model running on each microbatch:
model parallelism is nested within data parallelism.
And the only thing different model groups need to share is gradients.
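
To make the pipeline split concrete, here is a minimal single-process sketch: a toy `nn.Sequential` model is cut by layer into two stages, and each microbatch flows through stage 0 and then stage 1. The layer sizes, the two-stage cut, and the three microbatches are arbitrary choices for illustration; in a real setup each stage would live on its own device and the hand-off between stages would be a device-to-device transfer.

```python
import torch
from torch import nn

torch.manual_seed(0)

# A toy 4-layer model; the sizes are arbitrary
toy_model = nn.Sequential(
    nn.Linear(8, 16), nn.ReLU(),
    nn.Linear(16, 16), nn.ReLU(),
)

# Pipeline parallelism: cut the model by layer into two stages.
# In a real setup, each stage would be placed on a different device.
stages = [toy_model[:2], toy_model[2:]]

toy_batch = torch.randn(6, 8)
toy_microbatches = toy_batch.chunk(3)  # 3 microbatches of 2 examples each

mb_outputs = []
for microbatch in toy_microbatches:
    activation = microbatch
    for stage in stages:
        # in a real pipeline, this hand-off is a transfer to the next stage's device
        activation = stage(activation)
    mb_outputs.append(activation)

pipelined = torch.cat(mb_outputs)
print(torch.allclose(pipelined, toy_model(toy_batch), atol=1e-5))  # True
```

Running the stages one after another in a single loop, as here, is exactly the naive schedule that produces bubbles; the scheduler later in this notebook overlaps them.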

### Tensor Parallelism

Parallelize by splitting up matrix multiplications.

Each worker gets assigned a piece of a matrix multiplication.

This is the trickiest bit, because more information
needs to be communicated between workers, and it's where the Megatron-LM paper makes its intellectual contribution.

But notice that if you do it right, tensor parallelism can be nested inside pipeline parallelism,
which is nested inside data parallelism.

That's three levels of parallelization,
and that degree of decomposition is what it takes to scale models to tens or hundreds of billions of parameters.
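
With all three levels nested, the number of GPUs is the product of the three degrees of parallelism. The sketch below shows one possible way to give each GPU (global rank) its coordinates in that grid. The layout is our assumption for illustration, not necessarily what any particular library does: the tensor-parallel index varies fastest, so the most communication-heavy level can sit on neighboring GPUs, and the data-parallel index varies slowest, matching the nesting described above. The function name `rank_to_coords` is hypothetical.

```python
from typing import Tuple


def rank_to_coords(rank: int, dp: int, pp: int, tp: int) -> Tuple[int, int, int]:
    """Map a global rank to hypothetical (data, pipeline, tensor) parallel indices.

    Assumed layout: tensor index varies fastest, data index varies slowest.
    """
    world_size = dp * pp * tp
    assert 0 <= rank < world_size
    tp_idx = rank % tp
    pp_idx = (rank // tp) % pp
    dp_idx = rank // (tp * pp)
    return dp_idx, pp_idx, tp_idx


# e.g. 2 model replicas x 4 pipeline stages x 8-way tensor parallelism
dp, pp, tp = 2, 4, 8
print(dp * pp * tp)                    # 64
print(rank_to_coords(0, dp, pp, tp))   # (0, 0, 0)
print(rank_to_coords(9, dp, pp, tp))   # (0, 1, 1)
print(rank_to_coords(63, dp, pp, tp))  # (1, 3, 7)
```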

## 2. Megatron-LM: A Recipe for Combining Pipeline & Tensor Parallelism for Transformers

Data parallelism is straightforward enough --
in principle, you take whatever setup you use to process a batch with the other techniques
and "copy" it to multiple instances,
each of which draws microbatches from its own smaller shard of the dataset,
and which share gradients at the end of each batch.

Megatron-LM is all about the harder part:
how to parallelize the work within a microbatch.

At a high level, Megatron-LM first breaks down a model
into different stages,
with each stage having several layers -
that's our pipeline parallelism.

Within each layer of a given stage in the pipeline,
the computation is divided into smaller sections,
with each section assigned to a different GPU -
that's our tensor parallelism.

To make sure we do our parallelization efficiently,
we need to be smart when we define the "sections"
of our computation.

For a matrix multiplication,
there are two choices:
splitting by row and splitting by column.

But for most neural network layers,
there's only one sensible choice:
split along the neuron dimension.
That way, you can calculate the output
of your non-linearity without communication
between workers.
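
A quick single-process check of that claim, simulating two workers with plain tensor slicing and using the `inputs @ weights` convention (weights of shape `[n_inputs, n_outputs]`). Splitting the weights by column lets each worker apply the non-linearity (here `torch.relu`, chosen arbitrarily) to its own slice. Splitting by row instead gives each worker a partial sum, and those partial sums have to be exchanged and added, which means communication, before the non-linearity can be applied. The shapes are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 6)  # [batch, n_inputs]
A = torch.randn(6, 8)  # [n_inputs, n_outputs], the `inputs @ weights` convention
reference = torch.relu(X @ A)

# Column split: each simulated worker owns half of the output neurons
A1, A2 = A.chunk(2, dim=1)
col_split = torch.cat([torch.relu(X @ A1), torch.relu(X @ A2)], dim=1)
print(torch.allclose(col_split, reference, atol=1e-5))  # True: no communication before the ReLU

# Row split: each simulated worker owns half of the inputs and computes a partial sum
X1, X2 = X.chunk(2, dim=1)
B1, B2 = A.chunk(2, dim=0)
partial1, partial2 = X1 @ B1, X2 @ B2
relu_then_sum = torch.relu(partial1) + torch.relu(partial2)  # skips communication: wrong in general
sum_then_relu = torch.relu(partial1 + partial2)              # exchanges partial sums first
print(torch.allclose(relu_then_sum, reference, atol=1e-5))
print(torch.allclose(sum_then_relu, reference, atol=1e-5))  # True
```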

## 3. Let’s implement `ColumnParallelLinear` from scratch


This is a parallelized version of a linear layer where we parallelize by column. [[Megatron's ColumnParallelLinear]](https://github.com/NVIDIA/Megatron-LM/blob/060415572f4365a2e895f8036c4e37dad0efbdf5/megatron/core/tensor_parallel/layers.py#L418)

### But why `Column` parallel?

There are, roughly speaking, two kinds of matrices:
- matrices that represent a collection of _vectors_
- matrices that represent a collection of _functions to apply to vectors_

A single matrix can switch between being one or the other,
depending on how it's used,
but most of the time a matrix only does one of those two things.

For example, a batch of data is a collection of vectors, while the weights of a layer in a neural network are a collection of functions to apply to vectors -- each of which is what some might call a "neuron".

When we distribute work,
we always want to split such that
the different pieces of the computation are as independent as possible.

For a collection, that's the dimension that goes across
different elements of the collection.

So when we're parallelizing a batch,
that means splitting the different entries in the batch
onto different workers.

And when we're parallelizing across a layer,
we want to split the different neurons onto different workers.

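Writing the layer as `inputs @ weights`, with `weights` of shape `[n_inputs, n_outputs]`,
that means we want to split the weights along
their last dimension, the columns:

```python
# [batch, n_outputs] = [batch, n_inputs] @ [n_inputs, n_outputs] + [n_outputs]
out = inputs @ weights + biases
```

PyTorch's `nn.Linear` and `F.linear` store the weights transposed, as `[n_outputs, n_inputs]`, and compute `inputs @ weights.T + biases`. The code below uses that layout, so there the same split runs along the *first* dimension of the stored weight.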

In [None]:
import torch
import torch.distributed as dist

torch.random.manual_seed(117)

world_size = 4
batch_size, input_size, output_size = 10, 16, 12

inputs = torch.randn(batch_size, input_size, requires_grad=False)
weights = torch.randn(output_size, input_size, requires_grad=True)
biases = torch.randn(output_size, requires_grad=True)

outputs = torch.matmul(inputs, weights.T) + biases

In [None]:
def compute_column_parallel_linear(inputs, weights, biases, n_partitions):
    # weights use PyTorch's nn.Linear layout [n_outputs, n_inputs],
    # so each row holds one "neuron" (one column of weights.T)
    num_neurons = weights.shape[0]

    # partition into n_partitions groups of "neurons"
    partition_size = num_neurons // n_partitions
    weight_chunks = torch.split(weights, partition_size, dim=0)

    # now these can run independently (one per worker)
    partial_outputs = [torch.matmul(inputs, w.T) for w in weight_chunks]

    # and then we get the final result by combining them -- along the same dimension we split
    out = torch.cat(partial_outputs, dim=-1)

    return out + biases

In [None]:
outputs_parallel = compute_column_parallel_linear(inputs, weights, biases, n_partitions=4)

# allclose rather than equal: splitting a matmul can change floating-point rounding
assert torch.allclose(outputs, outputs_parallel)


In summary, the `ColumnParallelLinear` class divides the work of a linear layer across multiple processes by splitting the layer's output dimension among them. In the forward pass, every process receives the full input, computes its own slice of the output, and the slices from all the processes are gathered (all-gather) into the final output tensor. In the backward pass, each process keeps only the slice of the output gradient that corresponds to its own outputs and uses it to compute the gradients of its own weights and biases; the partial gradients with respect to the input are then summed (all-reduce) across processes.
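
Before the multi-process version, here is a single-process sketch of the same data flow with the communication replaced by plain tensor operations: `torch.cat` stands in for all-gather and a Python `sum` stands in for all-reduce. The four simulated ranks and the tensor shapes are illustrative assumptions. It checks that gradients computed from each rank's slice of the output gradient match those of an ordinary `F.linear`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_ranks, n_examples, n_in, n_out = 4, 5, 16, 12

x = torch.randn(n_examples, n_in, requires_grad=True)
W = torch.randn(n_out, n_in, requires_grad=True)  # nn.Linear layout: [n_out, n_in]
b = torch.randn(n_out, requires_grad=True)

# Reference: an ordinary, non-parallel linear layer
y = F.linear(x, W, b)
grad_y = torch.randn_like(y)  # an arbitrary upstream gradient
y.backward(grad_y)

# Each simulated rank owns a block of rows of W, i.e. a group of output neurons
W_shards = [w.detach().clone().requires_grad_() for w in W.chunk(n_ranks, dim=0)]
b_shards = [s.detach().clone().requires_grad_() for s in b.chunk(n_ranks)]
# "Broadcast": every rank gets its own copy of the full input
x_copies = [x.detach().clone().requires_grad_() for _ in range(n_ranks)]

# Forward: independent local matmuls, then "all-gather" the output slices
y_shards = [F.linear(xc, w, s) for xc, w, s in zip(x_copies, W_shards, b_shards)]
y_parallel = torch.cat(y_shards, dim=-1)

# Backward: each rank only needs its own slice of the output gradient
for y_shard, g in zip(y_shards, grad_y.chunk(n_ranks, dim=-1)):
    y_shard.backward(g)

# Each rank holds only part of the input gradient, so the parts are summed ("all-reduce")
grad_x = sum(xc.grad for xc in x_copies)

print(torch.allclose(y_parallel, y, atol=1e-5))                                  # True
print(torch.allclose(torch.cat([w.grad for w in W_shards]), W.grad, atol=1e-5))  # True
print(torch.allclose(torch.cat([s.grad for s in b_shards]), b.grad, atol=1e-5))  # True
print(torch.allclose(grad_x, x.grad, atol=1e-5))                                 # True
```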

In [None]:
from torch import nn
import torch.nn.functional as F

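class ColumnParallelLinear(torch.nn.Module):
    """A linear layer whose output features ("neurons") are split across ranks."""

    def __init__(self, input_size, output_size):
        super().__init__()
        world_size = torch.distributed.get_world_size()
        self.input_size = input_size
        self.output_size = output_size
        # each rank owns output_size // world_size of the output features
        self._output_size_per_partition = output_size // world_size

        # nn.Linear layout [n_outputs, n_inputs]: this rank's block of rows
        self.weight = nn.Parameter(torch.randn(
            self._output_size_per_partition,
            self.input_size,
        ))
        self.bias = nn.Parameter(torch.randn(
            self._output_size_per_partition,
        ))

    def forward(self, input):
        # identity going forward; all-reduces the input gradient going backward
        output_parallel = F.linear(Broadcast.apply(input), self.weight, self.bias)
        # all-gather the per-rank slices into the full [..., output_size] output
        outputs = Gather.apply(output_parallel)
        return outputs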

In [None]:
class Broadcast(torch.autograd.Function):
    """Identity in the forward pass, all-reduce (sum) in the backward pass.

    Every rank already holds the full input, so nothing is sent going forward.
    Going backward, each rank only has the input gradient that flows back from
    its own slice of the output, so the partial gradients must be summed.
    """

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        torch.distributed.all_reduce(
            grad_output,  # modified in-place!
            op=torch.distributed.ReduceOp.SUM
        )
        return grad_output

# The Megatron-LM paper calls this operator `f`
f = Broadcast

In [None]:
class Gather(torch.autograd.Function):
    """All-gather in the forward pass, split in the backward pass."""

    @staticmethod
    def forward(ctx, input):
        world_size = torch.distributed.get_world_size()

        # one buffer per rank to receive that rank's output slice
        inputs = [torch.empty_like(input) for _ in range(world_size)]

        torch.distributed.all_gather(inputs, input)

        # stitch the slices back together along the (split) output dimension
        inputs = torch.cat(inputs, dim=-1)

        return inputs

    @staticmethod
    def backward(ctx, grad_output):
        rank = torch.distributed.get_rank()

        world_size = torch.distributed.get_world_size()

        dim_size = grad_output.shape[-1]

        dim_size_per_partition = dim_size // world_size

        # keep only the gradient for this rank's own slice of the output
        grad_chunks = torch.split(grad_output, dim_size_per_partition, dim=-1)

        return grad_chunks[rank].contiguous()

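### Explain

From the `ColumnParallelLinear` class

- `self._output_size_per_partition = output_size // world_size`: This line calculates the output size of each partition by dividing the total output size by the number of processes. The output dimension of the linear layer is divided among the processes, and each process handles its own portion of it.

- `self.weight = nn.Parameter(torch.randn(self._output_size_per_partition, self.input_size))`: This line initializes the weight parameter for the current process. Since each process is only responsible for its own portion of the output dimension, it only stores the corresponding rows of the full `[output_size, input_size]` weight matrix. (The test below overwrites these random values with slices of the reference weights.)

- `output_parallel = F.linear(Broadcast.apply(input), self.weight, self.bias)`: This computes the output partition corresponding to the current process.

From the `Broadcast` class (the operator the Megatron-LM paper calls `f`)

- The forward pass is the identity, because every process already holds the full input. The backward pass all-reduces (sums) the gradient with respect to the input, because each process only computes the part of that gradient that flows back from its own slice of the output.

From the `Gather` class

- `inputs = [torch.empty_like(input) for _ in range(world_size)]`: This line creates a list of empty tensors with the same shape as `input`, one per process, to receive each process's output partition.

- `torch.distributed.all_gather(inputs, input)`: This gathers the output partition from every process in the group into the `inputs` list. Concatenating the list along the last dimension then gives every process the full output.

- In the backward pass, each process keeps only the slice of the output gradient that corresponds to its own output partition.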

In [None]:
outputs.sum().backward()

weight_grads = weights.grad.detach().requires_grad_(False)
bias_grads = biases.grad.detach().requires_grad_(False)

In [None]:
import os

def run_parallel(
    rank, world_size,
    input_size, output_size,
    inputs, weights, biases, outputs,
    weight_grads, bias_grads
):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12359'
    torch.distributed.init_process_group(
        "gloo",
        rank=rank,
        world_size=world_size
    )

    torch.use_deterministic_algorithms(True)
    torch.random.manual_seed(rank)

    model = ColumnParallelLinear(input_size, output_size)

    # Partition the weights and biases and assign to the model
    partition_size = weights.shape[0] // world_size
    partition_start, partition_end = rank * partition_size, (rank + 1) * partition_size

    model.weight.data = weights[partition_start: partition_end].detach().requires_grad_(True)
    model.bias.data = biases[partition_start: partition_end].detach().requires_grad_(True)

    outputs_parallel = model(inputs.detach().requires_grad_(False))
    outputs_parallel.sum().backward()

    # This rank owns rows [partition_start, partition_end) of the full weight and bias
    expected_weight_grad = weight_grads[partition_start:partition_end]
    expected_bias_grad = bias_grads[partition_start:partition_end]

    print(f"rank={rank}, parallel_output.shape: {outputs_parallel.shape}, non_parallel_output.shape: {outputs.shape}\n")
    print(f"rank={rank}, is the forward correct? {torch.allclose(outputs_parallel, outputs)}\n")
    print(f"rank={rank}, is the gradient of the weight correct? {torch.allclose(model.weight.grad, expected_weight_grad)}\n")
    print(f"rank={rank}, is the gradient of the bias correct? {torch.allclose(model.bias.grad, expected_bias_grad)}\n")

    torch.distributed.destroy_process_group()

In [None]:
from torch.multiprocessing import Process

processes = []

for rank in range(world_size):
    p = Process(target=run_parallel, args=(
        rank, world_size,
        input_size, output_size,
        # Because PyTorch does not support sending tensors
        # that require gradients through inter-process communication
        # we need to detach them from the computational graph
        inputs, weights.detach(), biases.detach(), outputs.detach(),
        weight_grads, bias_grads
    ))
    processes.append(p)
    p.start()

for p in processes:
    p.join()

#### Write a ParallelMLP from scratch

In [None]:
class Scatter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        last_dim_size = input.shape[-1]
        chunk_size = last_dim_size // world_size  # width of each rank's slice
        input_chunks = torch.split(input, chunk_size, dim=-1)
        return input_chunks[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = torch.distributed.get_world_size()
        grad_outputs = [torch.empty_like(grad_output) for _ in range(world_size)]
        torch.distributed.all_gather(grad_outputs, grad_output)
        grad_outputs = torch.cat(grad_outputs, dim=-1)
        return grad_outputs

class Reduce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        world_size = torch.distributed.get_world_size()
        if world_size == 1:
            return input
        torch.distributed.all_reduce(input)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

In [None]:
class RowParallelLinear(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        world_size = torch.distributed.get_world_size()
        input_size_per_partition = input_size // world_size

        self.weight = nn.Parameter(torch.randn(
            output_size,
            input_size_per_partition
        ))
        self.bias = nn.Parameter(torch.randn(output_size))

    def forward(self, input):
        dist.barrier()
        input_parallel = Scatter.apply(input)
        output_parallel = F.linear(input_parallel, self.weight)
        outputs = Reduce.apply(output_parallel)
        return outputs + self.bias
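
The `ColumnParallelLinear` → ReLU → `RowParallelLinear` stack used below all-gathers the hidden activation and then immediately scatters it again. The Megatron-LM paper's MLP avoids that round trip: the first layer's output stays split across workers, the ReLU is applied locally, and the second layer consumes its local slice directly, so the forward pass needs only one all-reduce, at the end. A single-process sketch of that idea, with the all-reduce simulated by a Python `sum` and the sizes chosen arbitrarily:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_ranks, n_examples, n_in, n_hidden, n_out = 4, 5, 16, 48, 12

x = torch.randn(n_examples, n_in)
W1, b1 = torch.randn(n_hidden, n_in), torch.randn(n_hidden)  # first layer (nn.Linear layout)
W2, b2 = torch.randn(n_out, n_hidden), torch.randn(n_out)    # second layer

reference = F.linear(F.relu(F.linear(x, W1, b1)), W2, b2)

partials = []
for W1_i, b1_i, W2_i in zip(
    W1.chunk(n_ranks, dim=0),  # column-parallel: each rank owns a block of hidden neurons
    b1.chunk(n_ranks),
    W2.chunk(n_ranks, dim=1),  # row-parallel: the matching block of the second layer's inputs
):
    h_i = F.relu(F.linear(x, W1_i, b1_i))  # local hidden slice; the ReLU needs no communication
    partials.append(F.linear(h_i, W2_i))   # this rank's partial sum of the final output

# The one forward-pass all-reduce, simulated by a sum; the bias is added once, afterwards
out = sum(partials) + b2
print(torch.allclose(out, reference, atol=1e-3))  # True
```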

In [None]:
def run_parallel(
    rank, world_size,
    input_size, output_size,
    inputs, weights, biases, outputs,
    weight_grads, bias_grads
):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12359'
    torch.distributed.init_process_group(
        "gloo",
        rank=rank,
        world_size=world_size
    )

    torch.use_deterministic_algorithms(True)
    torch.random.manual_seed(rank)

    hidden_size = output_size * 4
    model = nn.Sequential(
        ColumnParallelLinear(input_size, hidden_size),
        nn.ReLU(),
        RowParallelLinear(hidden_size, output_size),
    )

    def load_data(model, layer_idx, idx):
        if layer_idx == 0:
            partition_size = weights[idx].shape[0] // world_size
        elif layer_idx == 2:
            partition_size = weights[idx].shape[1] // world_size

        partition_start, partition_end = rank * partition_size, (rank + 1) * partition_size

        if layer_idx == 0:
            model[layer_idx].weight.data = weights[idx][partition_start: partition_end].detach().requires_grad_(True)
            model[layer_idx].bias.data = biases[idx][partition_start:partition_end].detach().requires_grad_(True)
        elif layer_idx == 2:
            model[layer_idx].weight.data = weights[idx][:, partition_start:partition_end].detach().requires_grad_(True)
            # RowParallelLinear keeps the full bias on every rank and adds it once, after the all-reduce
            model[layer_idx].bias.data = biases[idx].detach().requires_grad_(True)
        return model

    model = load_data(model, layer_idx=0, idx=0)
    model = load_data(model, layer_idx=2, idx=1)

    outputs_parallel = model(inputs)
    outputs_parallel.sum().backward()

    print(f"rank={rank}, parallel_output.shape: {outputs_parallel.shape}, non_parallel_output.shape: {outputs.shape}\n")
    print(f"rank={rank}, is the forward correct? {torch.allclose(outputs_parallel, outputs, rtol=0.01)}\n")

    for layer_idx, grad_idx in [[0, 0], [2, 1]]:
        if layer_idx == 0:
            partition_size = weight_grads[grad_idx].shape[0] // world_size
            grad_chunks = torch.split(weight_grads[grad_idx], partition_size, dim=0)
            bias_chunks = torch.split(bias_grads[grad_idx], partition_size, dim=0)
        elif layer_idx == 2:
            partition_size = weight_grads[grad_idx].shape[1] // world_size
            grad_chunks = torch.split(weight_grads[grad_idx], partition_size, dim=1)

        print(f"rank={rank}, is the gradient of the weight correct? {torch.allclose(model[layer_idx].weight.grad, grad_chunks[rank])}\n")
        if layer_idx == 0:
            print(f"rank={rank}, is the gradient of the bias correct? {torch.allclose(model[layer_idx].bias.grad, bias_chunks[rank])}\n")
        else:
            print(f"rank={rank}, is the gradient of the bias correct? {torch.allclose(model[layer_idx].bias.grad, bias_grads[grad_idx])}\n")

    torch.distributed.destroy_process_group()

In [None]:
from copy import deepcopy

processes = []
world_size = 4
batch_size, input_size, output_size = 10, 16, 12
hidden_size = output_size * 4

inputs = torch.randn(batch_size, input_size, requires_grad=False)

model = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, output_size),
)
outputs = model(inputs)
outputs.sum().backward()

weights = [
    model[0].weight.data.detach(),
    model[2].weight.data.detach(),
]
biases = [model[0].bias.data.detach(), model[2].bias.data.detach()]
weight_grads = [
    model[0].weight.grad.detach().requires_grad_(False),
    model[2].weight.grad.detach().requires_grad_(False)
]
bias_grads = [
    model[0].bias.grad.detach().requires_grad_(False),
    model[2].bias.grad.detach().requires_grad_(False)

]

for rank in range(world_size):
    p = Process(target=run_parallel, args=(
        rank, world_size,
        input_size, output_size,
        # Because PyTorch does not support sending tensors
        # that require gradients through inter-process communication
        # we need to detach them from the computational graph
        inputs, deepcopy(weights), deepcopy(biases), outputs.detach(),
        deepcopy(weight_grads), deepcopy(bias_grads)
    ))
    processes.append(p)
    p.start()

for p in processes:
    p.join()

## 4. Distributed Communication



When we train a model in a distributed manner, there are four basic communication operations that we need to perform:

- Broadcast: We start with a tensor in one process and send it to all the other processes within the group. This is like sharing a piece of information with everyone in the group.
- Scatter: We take a tensor from one process and distribute its elements or chunks to all the other processes in the group. This is like dividing up a task among all the members in a team.
- Gather: We gather data from all the processes in the group and assemble it into a single tensor at the destination process. This is like collecting everyone’s input and putting it together in one place.
- Reduce: We take data from all processes in the group, apply a specific operation to it (like summing, multiplying, finding the minimum or maximum), and then store the result in the destination process. This is like combining everyone’s efforts and producing a single output.

Gather and reduce also come in "all-" variants, `all_gather` and `all_reduce`, which deliver the result to every process instead of a single destination. Those are the variants the layers above use.
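
To pin down what each operation computes, here is a single-process sketch that models a group of processes as a Python list holding one tensor per rank, including the all-gather and all-reduce versions. The `sim_*` helpers are hypothetical: they mimic what the collectives compute, not the `torch.distributed` API (which, for example, gathers into a list of tensors rather than a single concatenated one).

```python
import torch

# A "process group" is modeled as a Python list: element i is rank i's tensor.
# These helpers are hypothetical and only mimic the semantics of the collectives.

def sim_broadcast(tensor, n_ranks):
    """Every rank ends up with a copy of the source rank's tensor."""
    return [tensor.clone() for _ in range(n_ranks)]

def sim_scatter(tensor, n_ranks):
    """Rank i receives the i-th chunk of the source rank's tensor."""
    return list(tensor.chunk(n_ranks))

def sim_gather(per_rank):
    """The destination rank receives every rank's tensor (concatenated here)."""
    return torch.cat(per_rank)

def sim_reduce(per_rank):
    """The destination rank receives the element-wise sum of every rank's tensor."""
    return torch.stack(per_rank).sum(dim=0)

def sim_all_gather(per_rank):
    """Like gather, but every rank receives the result."""
    return [sim_gather(per_rank) for _ in per_rank]

def sim_all_reduce(per_rank):
    """Like reduce, but every rank receives the result."""
    return [sim_reduce(per_rank) for _ in per_rank]


x = torch.arange(8.0)
print(sim_broadcast(x, 4)[3])     # tensor([0., 1., 2., 3., 4., 5., 6., 7.])
chunks = sim_scatter(x, 4)
print(chunks[1])                  # tensor([2., 3.])
print(sim_gather(chunks))         # tensor([0., 1., 2., 3., 4., 5., 6., 7.])
print(sim_reduce(chunks))         # tensor([12., 16.])
print(sim_all_reduce(chunks)[2])  # tensor([12., 16.])
```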

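However, we can’t just drop these operations from PyTorch, like `torch.distributed.broadcast`, straight into a model: the collectives in `torch.distributed` are not tracked by autograd, so they have no backward pass. Let’s say we broadcast a tensor `x` from device 0 to devices 1, 2, and 3 during the forward pass. Every device then computes a gradient for its copy of `x`, and during the backward pass those gradients have to travel in the reverse direction and be summed, which is a reduce. This means we have to write a broadcast operation that handles both passes, which we do by wrapping it in a `torch.autograd.Function` with its own `forward` and `backward`.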


In [None]:
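def is_grad_enabled(input):
    # only route through autograd when a gradient will actually be needed
    return torch.is_grad_enabled() and input.requires_grad

def broadcast(inputs):
    # in tensor parallelism every rank already holds the same input,
    # so the forward "broadcast" is just a local copy
    return inputs.clone()

def reduce(inputs):
    world_size_of_parallel_group = torch.distributed.get_world_size()

    # nothing to sum with a single process
    if world_size_of_parallel_group == 1:
        return inputs

    # sum the tensor across all ranks; every rank receives the result (in-place)
    torch.distributed.all_reduce(
        inputs,
        op=torch.distributed.ReduceOp.SUM
    )

    return inputs

class Broadcast(torch.autograd.Function):
    # forward: copy the input; backward: sum the gradients from all ranks
    @staticmethod
    def forward(ctx, input):
        return broadcast(input)

    @staticmethod
    def backward(ctx, grad_output):
        return reduce(grad_output)

def broadcast_with_backward(inputs):
    if is_grad_enabled(inputs):
        outputs = Broadcast.apply(inputs)
    else:
        outputs = broadcast(inputs)
    return outputs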

The same pairing, broadcast going forward and (all-)reduce going backward, shows up in data parallelism. Before training starts, one worker, usually rank 0 (the "master"), broadcasts the model parameters to everyone else (`torch.distributed.broadcast()`), so all the workers, each holding its own replica of the model, start from identical weights.

In the backward pass, each worker does its own calculation on its own mini-batch. Each computes its own gradients: the directions in which to tweak the weights to reduce the loss on that worker's mini-batch. What we want is a single direction that works well for all the model replicas, so the workers sum their gradients with `torch.distributed.all_reduce()` and divide by the number of workers to get the average. Every worker then applies the same averaged gradient, so the replicas stay in sync without re-broadcasting the parameters, and the model gets a bit better with each step.
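
Why does averaging give the right update? If the loss is a mean over examples and every worker gets a shard of equal size, the average of the per-worker gradients equals the gradient on the full batch. A single-process sketch, with workers simulated by slicing the batch; the toy `nn.Linear` model and the choice of `F.mse_loss` are arbitrary assumptions:

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
n_workers = 4
dp_model = nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)  # 8 examples -> 2 per worker

# Reference: gradient of the mean loss over the full batch
dp_model.zero_grad()
F.mse_loss(dp_model(x), y).backward()
full_grad = dp_model.weight.grad.clone()

# Data parallelism: each simulated worker computes a gradient on its own shard...
worker_grads = []
for x_shard, y_shard in zip(x.chunk(n_workers), y.chunk(n_workers)):
    dp_model.zero_grad()
    F.mse_loss(dp_model(x_shard), y_shard).backward()
    worker_grads.append(dp_model.weight.grad.clone())

# ...then the gradients are summed (what all_reduce computes) and divided by the worker count
averaged = torch.stack(worker_grads).sum(dim=0) / n_workers
print(torch.allclose(averaged, full_grad, atol=1e-5))  # True
```

With unequal shard sizes, a plain average would weight the examples unevenly, which is one reason data-parallel setups usually give every worker the same number of examples.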

## 5. Pipeline Scheduler

So, let's say we have a big model with 10 layers, like a Transformer model. And we've got 5 devices (like GPUs) to run our model. We want to split the model into 5 parts (which we're gonna call 'partitions') and each part will run on one device.

But here's the problem. In a Transformer model, each layer needs the result from the previous layer before it can do its work. It's like a relay race, you can't start running until you've got the baton from the runner before you. So if we split our model into 5 parts, then the second part can't start until the first part is done, the third part can't start until the second part is done, and so on. That means that most of the time, most of our devices are just sitting around doing nothing. That's a bummer!

![image.png](attachment:680b6f85-3a24-47d2-956a-9ef8bacc48c1.png)

So what can we do? Here's where GPipe comes in. Instead of feeding a big batch of data to our model all at once, GPipe splits that batch into smaller chunks, which we're gonna call 'micro-batches'. And here's the trick: while one micro-batch is being processed by the second part of our model, the next micro-batch can start being processed by the first part of the model.

This way, there's always something for each part of the model to do. It's like a factory assembly line. As soon as one car is done with one station, it moves to the next station and a new car moves into the first station. This keeps all our devices busy (although they might still have some idle time, like when a worker in the factory is waiting for the next car to arrive).
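
Splitting a batch into micro-batches doesn't change what the model learns: GPipe accumulates the gradients of all the micro-batches and applies a single update at the end of the batch. A minimal single-process check, assuming a loss that sums over examples (so the micro-batch gradients simply add up) and an arbitrary two-layer toy model:

```python
import torch
from torch import nn

torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
toy_batch = torch.randn(12, 4)

# Reference: one forward/backward pass over the whole batch
toy_model.zero_grad()
toy_model(toy_batch).sum().backward()
full_grads = [p.grad.clone() for p in toy_model.parameters()]

# Micro-batches: backward() adds into .grad, so the gradients accumulate
toy_model.zero_grad()
for microbatch in toy_batch.chunk(4):  # 4 micro-batches of 3 examples each
    toy_model(microbatch).sum().backward()
accumulated = [p.grad for p in toy_model.parameters()]

print(all(torch.allclose(a, f, atol=1e-5) for a, f in zip(accumulated, full_grads)))  # True
```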

So how does GPipe know what to do at each moment? That's the job of GPipe's scheduler. The scheduler works in 'clock cycles'. For each clock cycle, it figures out which partitions should be active and which micro-batch each partition should work on.

3 microbatches

3 layers

each layer is split into 2

In [None]:
n_microbatches = 3
n_partritions = 3

How many clock cycles does the schedule take? It takes `n_microbatches` clock cycles for all micro-batches to pass through the first partition. Once the last micro-batch leaves the first partition, it still has to go through the remaining partitions, one per clock cycle. Since there are `n_partritions` partitions, that takes `n_partritions-1` additional clock cycles (the first partition's cycle is already counted).

Therefore, the total number of clock cycles is `n_microbatches+n_partritions-1`
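
From this count we can also estimate how much of the schedule is bubble. Assuming every partition takes exactly one clock cycle per micro-batch, and counting only the forward pass, each partition is busy for `n_microbatches` of the `n_microbatches + n_partitions - 1` cycles, so the idle fraction is `(n_partitions - 1) / (n_microbatches + n_partitions - 1)`. A small sketch (the function name `bubble_fraction` is ours):

```python
def bubble_fraction(n_microbatches: int, n_partitions: int) -> float:
    """Fraction of partition-cycles spent idle in a forward-only GPipe schedule.

    Assumes every partition takes exactly one clock cycle per micro-batch.
    """
    n_clock_cycles = n_microbatches + n_partitions - 1
    busy = n_microbatches * n_partitions   # each partition processes every micro-batch once
    total = n_clock_cycles * n_partitions  # partition-cycles available
    return 1 - busy / total                # == (n_partitions - 1) / n_clock_cycles


print(bubble_fraction(3, 3))   # 0.4
print(bubble_fraction(32, 3))  # ~0.0588: more micro-batches shrink the bubble
```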

In [None]:
n_clock_cycles = n_microbatches+n_partritions-1

In [None]:
n_clock_cycles

In [None]:
for clock_idx in range(n_clock_cycles):
    start_partrition = max(clock_idx+1-n_microbatches, 0)
    end_partrition = min(clock_idx+1, n_partritions)

    tasks = []
    for partrition_idx in range(start_partrition, end_partrition):
        microbatch_idx = clock_idx-partrition_idx
        task = (microbatch_idx, partrition_idx)
        tasks.append(task)

    print(f"Clock cycle {clock_idx}: {tasks}")

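**Explain**

Partition `j` processes micro-batch `i` at clock cycle `i + j`, so at cycle `clock_idx` partition `j` has work exactly when `0 <= clock_idx - j < n_microbatches`. That is where the two bounds come from, and why `microbatch_idx = clock_idx - partrition_idx`.

`min(clock_idx+1, n_partritions)`
- At each clock cycle, one more partition becomes active in the pipeline. At cycle `clock_idx`, partitions `0` through `clock_idx` may have work, so the (exclusive) end of the range is `clock_idx+1`.
- However, we cannot exceed the total number of partitions (`n_partritions`), so we use `min` to cap the range.

`max(clock_idx+1-n_microbatches, 0)`
- Once a partition has processed all `n_microbatches` micro-batches, it goes idle. Partition `j` finishes at cycle `n_microbatches-1+j`, so the first partition that still has work at cycle `clock_idx` is `clock_idx+1-n_microbatches`, clamped at `0`.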
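
To see the whole schedule at a glance, we can lay the tasks out as a grid with one row per partition and one column per clock cycle. A small sketch; it derives the same `(microbatch_idx, partition_idx)` tasks as the loop above from the rule that partition `j` handles micro-batch `clock_idx - j`, and the helper name `render_schedule` is hypothetical:

```python
def render_schedule(n_microbatches: int, n_partitions: int) -> str:
    """Render a forward-only GPipe schedule; '.' marks an idle (bubble) slot."""
    n_clock_cycles = n_microbatches + n_partitions - 1
    rows = []
    for partition_idx in range(n_partitions):
        cells = []
        for clock_idx in range(n_clock_cycles):
            microbatch_idx = clock_idx - partition_idx
            if 0 <= microbatch_idx < n_microbatches:
                cells.append(str(microbatch_idx))
            else:
                cells.append(".")
        rows.append(f"partition {partition_idx}: " + " ".join(cells))
    return "\n".join(rows)


print(render_schedule(3, 3))
# partition 0: 0 1 2 . .
# partition 1: . 0 1 2 .
# partition 2: . . 0 1 2
```

The dots in the lower-left and upper-right corners are the bubbles.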

## 6. Let's build pipeline parallelism from scratch

In [None]:
def clock_cycles(n_microbatches, n_partritions):
    n_clock_cycles = n_partritions + n_microbatches - 1
    for clock_idx in range(n_clock_cycles):
        start_partrition = max(clock_idx+1-n_microbatches, 0)
        end_partrition = min(clock_idx+1, n_partritions)

        tasks = []
        for partrition_idx in range(start_partrition, end_partrition):
            microbatch_idx = clock_idx-partrition_idx
            task = (microbatch_idx, partrition_idx)
            tasks.append(task)

        yield tasks

In [None]:
for schedules in clock_cycles(n_microbatches, n_partritions):
    print(schedules)

In [None]:
from typing import Iterable, List, Tuple, Annotated, Optional, Generator, Dict, Callable, Any
from dataclasses import dataclass
from contextlib import contextmanager
from queue import Queue
from threading import Thread
import time
import os
from copy import deepcopy

import torch
from torch import nn
import torch.multiprocessing as mp

In [None]:
@dataclass
class QueueOutput:
    task: Callable
    output: Any
    is_done: bool = False


def wait_and_execute(device: torch.device, in_queue: Queue, out_queue: Queue):
    """Wait for a task and execute it."""
    while True:
        task = in_queue.get()

        if task.is_done is True:
            break

        try:
            output = task.compute()
        except Exception as e:
            # Raising here would kill this worker thread and leave the main thread
            # blocked on out_queue.get() forever, so report the failure instead.
            out_queue.put(QueueOutput(task=task, output=e, is_done=False))
            continue

        out_queue.put(QueueOutput(task=task, output=output, is_done=True))


@contextmanager
def spawn_worker(
    devices: List[torch.device],
) -> Generator[
    Tuple[
        Annotated[List[Queue], "A list of tasks to be executed"],
        Annotated[List[Queue], "A list of tasks has been executed"],
    ],
    None,
    None,
]:
    """Spawn new worker threads."""
    in_queues: List[Queue] = []
    out_queues: List[Queue] = []

    workers: Dict[torch.device, Tuple[Queue, Queue]] = {}

    for device in devices:
        # TODO: remove device
        try:
            in_queue, out_queue = workers[device]
        except KeyError:
            in_queue = Queue()
            out_queue = Queue()
            workers[device] = (in_queue, out_queue)

            thread = Thread(target=wait_and_execute, args=(device, in_queue, out_queue), daemon=True)
            thread.start()

        in_queues.append(in_queue)
        out_queues.append(out_queue)

    yield (in_queues, out_queues)

In [None]:
class Task:
    def __init__(self, compute: Callable[[], torch.Tensor], is_done: bool = False):
        self._compute = compute
        self.is_done = is_done

    def compute(self) -> torch.Tensor:
        return self._compute()

In [None]:
class Pipeline:
    """A base class for pipeline."""

    def __init__(
        self,
        batches: List[torch.Tensor],
        partitions: List[nn.Sequential],
        devices: Optional[List[torch.device]] = None,
    ) -> None:
        """Initialize the pipeline.

        Args:
            batches (List[torch.Tensor]): A list of micro-batches.
            partitions (List[nn.Sequential]): A partitioned model.
            devices (Optional[List[torch.device]], optional): A list of devices, one per partition. Defaults to None.
        """
        self.batches = batches
        self.partitions = partitions
        self.devices = devices

    def fit(self):
        batches = self.batches
        partitions = self.partitions
        devices = self.devices

        n_batches = len(batches)
        n_partitions = len(partitions)

        with spawn_worker(devices) as (in_queues, out_queues):
            for schedule in clock_cycles(n_batches, n_partitions):
                self._compute(schedule, in_queues, out_queues)

    def _compute(self, schedule: List[Tuple[int, int]], in_queues: List[Queue], out_queues: List[Queue]):
        """Compute the partitions."""
        batches = self.batches
        partitions = self.partitions

        for microbatch_idx, partition_idx in schedule:
            batch = batches[microbatch_idx]
            partrition = partitions[partition_idx]

            def compute(batch, partrition):
                def wrapper():
                    return partrition(batch)

                return wrapper

            task = Task(compute=compute(batch, partrition))
            in_queues[partition_idx].put(task)

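        for microbatch_idx, partition_idx in schedule:
            queue_output = out_queues[partition_idx].get()
            task, output = queue_output.task, queue_output.output

            # a worker reports a failed task with is_done=False
            if not queue_output.is_done:
                raise RuntimeError(
                    f"Failed to compute micro-batch {microbatch_idx} on partition {partition_idx}"
                ) from output

            # put the output back to the batch
            batches[microbatch_idx] = output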

#### **Test the forward and backward timelines of the pipeline**

In [None]:
N_MICROBATCHES = 3
N_PARTITIONS = 2

forward_timeline = []
backward_timeline = []

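def backward_hook(module, grad_input, grad_output):
    # Each backward call undoes one forward call, so count the microbatch
    # index back down from the number of forward passes this stage has seen.
    backward_timeline.append((module.microbatch_idx - 1, module.partition_idx))
    module.microbatch_idx -= 1

class AddOne(nn.Module):
    """A tiny pipeline stage (despite the name, it wraps an nn.Linear) that records when it runs forward and backward."""

    def __init__(self, partition_idx, is_logging):
        super().__init__()
        self.microbatch_idx = 0
        self.partition_idx = partition_idx
        self.is_logging = is_logging
        self.net = nn.Linear(1, 1)
        # register_backward_hook is deprecated in recent PyTorch releases in
        # favor of register_full_backward_hook, but it still works here.
        self.register_backward_hook(backward_hook)

    def forward(self, x):
        if self.is_logging:
            time.sleep(0.5)  # simulate an expensive stage
            forward_timeline.append((self.microbatch_idx, self.partition_idx))
            self.microbatch_idx += 1

        return self.net(x)


def loss_func(x):
    return x.mean()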
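
The backward-timeline test relies on backward hooks firing in the reverse of the forward order. Here is a minimal standalone sketch of that behavior using `register_full_backward_hook`, the non-deprecated variant. It is separate from the notebook's `AddOne`, and the names `make_hook` and `backward_order` are illustrative.

```python
import torch
from torch import nn

backward_order = []


def make_hook(name):
    # Full backward hooks fire once the gradient w.r.t. the module's inputs is computed
    def hook(module, grad_input, grad_output):
        backward_order.append(name)
    return hook


model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
model[0].register_full_backward_hook(make_hook("stage0"))
model[1].register_full_backward_hook(make_hook("stage1"))

x = torch.ones(1, 1, requires_grad=True)
model(x).mean().backward()
print(backward_order)  # ['stage1', 'stage0']
```

Gradients flow from the last layer back to the first, so the later stage's hook fires first. The pipeline test expects exactly this ordering within each microbatch.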

In [None]:
batch = torch.arange(0, N_MICROBATCHES, dtype=torch.float32, requires_grad=True)
microbatches = [x.unsqueeze(0) for x in batch.unbind()]
partitions = [nn.Sequential(AddOne(partition_idx=x, is_logging=True)) for x in range(N_PARTITIONS)]
devices = [torch.device("cpu") for _ in range(N_PARTITIONS)]

In [None]:
pipeline = Pipeline(microbatches, partitions, devices)

assert pipeline.batches == microbatches
assert pipeline.partitions == partitions

In [None]:
pipeline.fit()

assert forward_timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)]

# The pipeline writes each partition's output back into `microbatches` in place
outputs = microbatches

In [None]:
forward_timeline

In [None]:
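# Backpropagate the last microbatch first, mirroring GPipe's reversed backward
# order, so the hook's countdown labels match the real microbatch indices.
for x in reversed(outputs):
    loss = loss_func(x)
    loss.backward()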

assert backward_timeline == [(2, 1), (2, 0), (1, 1), (1, 0), (0, 1), (0, 0)] or backward_timeline == [
    (2, 1),
    (2, 0),
    (1, 1),
    (0, 1),
    (1, 0),
    (0, 0),
]
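
The first accepted backward timeline is the forward clock cycles replayed in reverse. Partition `j` can only backpropagate microbatch `i` after partition `j + 1` has produced the gradient for it. A minimal sketch, assuming the same GPipe-style cycles as in the forward pass (`gpipe_backward_schedule` is a hypothetical name):

```python
from typing import List, Tuple


def gpipe_backward_schedule(n_microbatches: int, n_partitions: int) -> List[List[Tuple[int, int]]]:
    """Replay the forward clock cycles in reverse for the backward pass.

    Backward task (i, j) needs the gradient from (i, j + 1), which belongs to
    forward cycle i + j + 1, so walking the cycles from last to first
    always runs a task after the one it depends on.
    """
    n_cycles = n_microbatches + n_partitions - 1
    schedule = []
    for k in reversed(range(n_cycles)):
        cycle = [
            (k - j, j)
            for j in range(n_partitions)
            if 0 <= k - j < n_microbatches
        ]
        schedule.append(cycle)
    return schedule


print([task for cycle in gpipe_backward_schedule(3, 2) for task in cycle])
```

Tasks inside one cycle are independent, so their relative order only matters for the exact-match assertion, not for correctness.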

In [None]:
backward_timeline

#### **Let's put it all together**

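Now we combine both kinds of parallelism. We run a `ParallelMLP`-style network through the pipeline (forward and backward), where each pipeline stage is itself tensor-parallel. The first partition is a `ColumnParallelLinear` followed by a ReLU, and the second is a `RowParallelLinear`. Each of the `world_size` processes runs the two-stage pipeline on its own shard of the weights, and the tensor-parallel layers communicate across processes.

The test compares both the outputs and the gradients against an equivalent non-parallel `nn.Sequential` model.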
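
This sketch shows why the first layer's weight is sharded by rows and the second layer's by columns. It runs in a single process and is not part of the notebook's test. Assumptions: the `world_size` ranks are simulated with a loop, and a plain sum stands in for the all-reduce that `RowParallelLinear` presumably performs internally.

```python
import torch
from torch import nn

torch.manual_seed(0)
world_size = 4
input_size, hidden_size, output_size = 16, 48, 12

fc1 = nn.Linear(input_size, hidden_size)
fc2 = nn.Linear(hidden_size, output_size)
x = torch.randn(10, input_size)

with torch.no_grad():
    reference = fc2(torch.relu(fc1(x)))

    # Column parallel: each rank owns a slice of fc1's output features
    # (rows of the weight) plus the matching slice of the bias.
    w1_shards = torch.chunk(fc1.weight, world_size, dim=0)
    b1_shards = torch.chunk(fc1.bias, world_size, dim=0)
    # Row parallel: each rank owns the matching slice of fc2's input features
    # (columns of the weight); the bias is kept whole and added once.
    w2_shards = torch.chunk(fc2.weight, world_size, dim=1)

    partial_outputs = []
    for rank in range(world_size):
        # ReLU is elementwise, so it can be applied to each hidden shard independently
        h_shard = torch.relu(x @ w1_shards[rank].T + b1_shards[rank])
        partial_outputs.append(h_shard @ w2_shards[rank].T)

    # Summing the partial outputs plays the role of the all-reduce across ranks
    output = torch.stack(partial_outputs).sum(dim=0) + fc2.bias

print(torch.allclose(output, reference, atol=1e-5))
```

No communication is needed between the two layers. Each rank's slice of the hidden activations feeds directly into its own columns of the second weight, and only the final outputs have to be summed.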

In [None]:
def run_pipeline(rank, world_size, input_size, hidden_size, output_size, microbatches, weights, biases, outputs, weight_grads, bias_grads):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12359'
    torch.distributed.init_process_group(
        "gloo",
        rank=rank,
        world_size=world_size
    )

    partitions = [
        nn.Sequential(ColumnParallelLinear(input_size, hidden_size), nn.ReLU()),
        nn.Sequential(RowParallelLinear(hidden_size, output_size)),
    ]

    partitions = load_param(rank, world_size, weights, biases, partitions)
    devices = [torch.device("cpu") for _ in range(len(partitions))]
    pipeline = Pipeline(microbatches, partitions, devices)

    assert pipeline.batches == microbatches
    assert pipeline.partitions == partitions

    pipeline.fit()

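    # The pipeline writes its outputs back into `microbatches` in place
    parallel_outputs = microbatches
    print(f"rank={rank}, number of outputs: {len(parallel_outputs)}\n")
    print(f"rank={rank}, output shape: {parallel_outputs[0].shape}\n")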

    for x, y in zip(outputs, parallel_outputs):
        assert torch.allclose(x, y, rtol=0.01)

    for x in parallel_outputs:
        x.sum().backward()

    for layer_idx, grad_idx in [[0, 0], [1, 1]]:
        if layer_idx == 0:
            partition_size = weight_grads[grad_idx].shape[0] // world_size
            grad_chunks = torch.split(weight_grads[grad_idx], partition_size, dim=0)
            bias_chunks = torch.split(bias_grads[grad_idx], partition_size, dim=0)
        elif layer_idx == 1:
            partition_size = weight_grads[grad_idx].shape[1] // world_size
            grad_chunks = torch.split(weight_grads[grad_idx], partition_size, dim=1)

        print(f"rank={rank}, is the gradient of the weight correct? {torch.allclose(partitions[layer_idx][0].weight.grad, grad_chunks[rank])}\n")
        if layer_idx == 0:
            print(f"rank={rank}, is the gradient of the bias correct? {torch.allclose(partitions[layer_idx][0].bias.grad, bias_chunks[rank])}\n")
        else:
            print(f"rank={rank}, is the gradient of the bias correct? {torch.allclose(partitions[layer_idx][0].bias.grad, bias_grads[grad_idx])}\n")

    torch.distributed.destroy_process_group()


In [None]:
world_size = 4
batch_size, input_size, output_size = 10, 16, 12
hidden_size = output_size * world_size

batch = torch.randn(batch_size, input_size, dtype=torch.float32)
microbatches = [x.unsqueeze(0) for x in batch.unbind(dim=0)]

model = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, output_size),
)
outputs = model(batch)
outputs.sum().backward()

def extract_params(model):
    weights = [model[0].weight.detach(), model[2].weight.detach()]
    biases = [model[0].bias.detach(), model[2].bias.detach()]
    return weights, biases

def extract_grads(model):
    # .detach() already returns tensors that don't require grad
    weight_grads = [model[0].weight.grad.detach(), model[2].weight.grad.detach()]
    bias_grads = [model[0].bias.grad.detach(), model[2].bias.grad.detach()]
    return weight_grads, bias_grads

weights, biases = extract_params(model)
weight_grads, bias_grads = extract_grads(model)

In [None]:
def load_param(rank, world_size, weights, biases, partitions):
    """Load this rank's shard of the reference weights into the partitions."""

    def calculate_start_end_idx(rank, idx):
        if idx == 0:  # column parallel: split the output features (rows of the weight)
            partition_size = weights[idx].shape[0] // world_size
        elif idx == 1:  # row parallel: split the input features (columns of the weight)
            partition_size = weights[idx].shape[1] // world_size
        return rank * partition_size, (rank + 1) * partition_size

    def load(model, idx):
        partition_start, partition_end = calculate_start_end_idx(rank, idx)
        if idx == 0:  # column parallel: weight and bias are both sharded
            model[idx][0].weight.data = weights[idx][partition_start:partition_end].clone()
            model[idx][0].bias.data = biases[idx][partition_start:partition_end].clone()
        elif idx == 1:  # row parallel: weight is sharded, bias is kept in full
            model[idx][0].weight.data = weights[idx][:, partition_start:partition_end].clone()
            model[idx][0].bias.data = biases[idx].clone()
        return model

    partitions = load(partitions, idx=0)
    partitions = load(partitions, idx=1)
    return partitions

In [None]:
from copy import deepcopy
from torch.multiprocessing import Process

processes = []

for rank in range(world_size):
    p = Process(target=run_pipeline, args=(
        rank, world_size,
        input_size, hidden_size, output_size,
        microbatches, deepcopy(weights), deepcopy(biases),
        outputs.detach(),
        deepcopy(weight_grads), deepcopy(bias_grads),
    ))
    p.start()
    processes.append(p)

for p in processes:
    p.join()

### 7. Are we done?

Alright, so after learning about how pipeline parallelism works, you might think "Cool, I got this!". But hold on a sec, because there's a lot more to it. Setting up pipeline parallelism, like using GPipe, can get pretty complex and bring up some tricky problems to solve.

Think about having a bunch of devices working together. They need to talk to each other, right? But if they're not careful about when and how they communicate, they might end up waiting on each other even when they don't really need to. That's kind of like being on a group chat where everyone needs to wait for one person to respond before they can keep talking. Not so fun, huh?

Then there's the problem of skip connections. That's when a layer's output skips ahead and is passed directly to a layer further down the line, bypassing the layers in between. In a pipeline, those layers may live on different devices, so the skipped tensor has to travel from one stage to a later one without going through the stages in the middle. But how do we figure out which tensors are skip connections and get them where they need to go?

And how do we even decide how to split the model into partitions in the first place? Not all layers are the same. Some might need more memory than others, so we can't just split the model evenly. It's like dividing up chores at home, but some chores take longer than others. If we're not careful, some devices might run out of memory, while others are sitting around twiddling their thumbs.
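
One simple way to frame the partitioning problem: given an estimated cost per layer (memory or time, e.g. from profiling), split the layers into contiguous stages so that the most expensive stage is as cheap as possible. The sketch below solves that with a small dynamic program. It is an illustration under those assumptions, not the algorithm GPipe or Megatron actually uses, and `balance_partitions` is a hypothetical name.

```python
from typing import List


def balance_partitions(costs: List[float], n_partitions: int) -> List[int]:
    """Split layers into contiguous stages, minimizing the costliest stage.

    costs[i] is an estimated cost for layer i. Returns how many consecutive
    layers each stage gets; every stage gets at least one layer.
    """
    n = len(costs)
    if not 1 <= n_partitions <= n:
        raise ValueError("need 1 <= n_partitions <= number of layers")

    prefix = [0.0]
    for c in costs:
        prefix.append(prefix[-1] + c)

    inf = float("inf")
    # best[k][i]: smallest possible max-stage cost when the first i layers form k stages
    best = [[inf] * (n + 1) for _ in range(n_partitions + 1)]
    # cut[k][i]: index of the first layer of the k-th stage in that best split
    cut = [[0] * (n + 1) for _ in range(n_partitions + 1)]
    best[0][0] = 0.0

    for k in range(1, n_partitions + 1):
        for i in range(k, n + 1):
            for j in range(k - 1, i):  # the k-th stage holds layers j .. i-1
                candidate = max(best[k - 1][j], prefix[i] - prefix[j])
                if candidate < best[k][i]:
                    best[k][i] = candidate
                    cut[k][i] = j

    # Walk the cut points back from the last stage to recover the stage sizes
    sizes = []
    i = n
    for k in range(n_partitions, 0, -1):
        j = cut[k][i]
        sizes.append(i - j)
        i = j
    return sizes[::-1]


print(balance_partitions([1, 1, 1, 4, 1, 1, 1], 3))  # [3, 1, 3]
```

A naive split by layer count (`[3, 2, 2]`) would put the cost-4 layer together with a neighbor, giving a heaviest stage of 5. The balanced split isolates it, so the heaviest stage costs 4.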

There's also a tricky issue with how PyTorch works. In pipeline parallelism, we need to make sure that the backward pass (that's when the model is learning from its mistakes) on one device finishes before the backward pass on the previous device starts. But PyTorch isn't aware of this requirement. So how do we make sure this happens?

And what about gradient checkpointing? That's a technique where, instead of storing every intermediate activation during the forward pass, we keep only a few checkpoints (like bookmarks in a book) and recompute the rest during the backward pass, trading extra computation for memory. But this gets complicated because PyTorch builds the computational graph (that's like the roadmap of how our data moves through the model) during the forward pass, and the recomputation has to happen during the backward pass at just the right moment, right before each partition needs it. So how do we schedule that recomputation?
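
PyTorch ships the basic building block as `torch.utils.checkpoint.checkpoint`. Here is a minimal single-device sketch with no pipeline involved. It checks that checkpointing changes what is stored, not the gradients you get.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
x = torch.randn(4, 8, requires_grad=True)

# With checkpointing, the activations inside `block` are not kept after the
# forward pass; autograd reruns `block` during backward to recreate them.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
grad_with_checkpoint = x.grad.clone()

# Reference: an ordinary forward/backward that stores every activation
x.grad = None
block(x).sum().backward()
grad_without_checkpoint = x.grad.clone()

print(torch.allclose(grad_with_checkpoint, grad_without_checkpoint))
```

In a pipeline, the hard part is not this call itself. It is making sure each partition's recomputation runs on the right device and at the right point in the backward schedule.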

Even more, what happens if a node (a machine in the cluster, often hosting several devices) crashes or stops working? How do we get back to where we were before the crash? And can all the nodes resume from the same point at the same time?

And there's even more! In a big cluster (that's a group of devices working together), we're not the only ones running experiments. Other teams are too. So what happens if another team needs some of the resources that we're using? How can we add and remove nodes dynamically for our training task without messing everything up?

And guess what? This isn't even the half of it! You only truly get the full picture when you roll up your sleeves and start building these things from the ground up. So, here's a heads-up for our next tutorial: we're gonna build Megatron's MPU (model parallel unit) from scratch. It's kind of like the boss of your GPUs. It keeps track of which group of GPUs each process belongs to for each type of parallelism (tensor parallelism, pipeline parallelism, and data parallelism), so every GPU knows who to talk to. Think of it like a traffic cop, directing the flow of data.
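
To make "who talks to whom" concrete, the docstring of Megatron-LM's `initialize_model_parallel` describes an example with 16 GPUs, 2-way tensor parallelism, and 4-way pipeline parallelism. The sketch below reproduces that rank layout with plain lists. `build_parallel_groups` is a hypothetical helper, and the real code also creates the `torch.distributed` process groups and handles more options.

```python
from typing import Dict, List


def build_parallel_groups(world_size: int, tp: int, pp: int) -> Dict[str, List[List[int]]]:
    """Group global ranks for tensor (tp), pipeline (pp), and data parallelism.

    Neighboring ranks share a tensor-parallel group, and consecutive pipeline
    stages are world_size // pp ranks apart.
    """
    if world_size % (tp * pp) != 0:
        raise ValueError("world_size must be divisible by tp * pp")
    ranks_per_stage = world_size // pp  # = tp * data-parallel size

    # Consecutive ranks split each layer's tensors
    tensor = [list(range(s, s + tp)) for s in range(0, world_size, tp)]
    # One rank from every stage forms a pipeline
    pipeline = [list(range(r, world_size, ranks_per_stage)) for r in range(ranks_per_stage)]
    # Ranks in the same stage holding the same tensor shard are data-parallel replicas
    data = []
    for stage in range(pp):
        start = stage * ranks_per_stage
        for offset in range(tp):
            data.append(list(range(start + offset, start + ranks_per_stage, tp)))
    return {"tensor": tensor, "pipeline": pipeline, "data": data}


groups = build_parallel_groups(world_size=16, tp=2, pp=4)
for kind, members in groups.items():
    print(kind, members)
```

Every rank appears in exactly one group of each kind. That is what lets a single GPU simultaneously be a tensor-parallel peer, a pipeline stage, and a data-parallel replica.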

Stay tuned, and let's get ready to tackle this beast in the next tutorial!