In [3]:
import torch

##### Example 1

In [None]:
# Toy operands for `output = input @ weight.t() + bias`:
# 2 samples with 4 features each, mapped to 2 output units.
input, weight = torch.randn(2, 4), torch.randn(2, 4)
bias = torch.randn(2)

In [None]:
def get_tensor_model_parallel_group():
    """Placeholder for Megatron-LM's tensor-parallel group accessor; returns ``None``."""
    return None

In [None]:
import torch

# Display the helper object to confirm the name is defined in this session.
get_tensor_model_parallel_group

In [None]:
# Shapes of the toy operands, as a tuple of torch.Size objects.
tuple(t.shape for t in (input, weight, bias))

(torch.Size([2, 4]), torch.Size([2, 4]), torch.Size([2]))

Let `L` be the loss function. We want the gradients of `L` with respect to the layer's input, weight, and bias (for example `∂L/∂input`). Using the chain rule, they can be written as shown in the hints below.

**Hints**: 
- `output = input @ weight.t() + bias`
- `∂L/∂input = (∂L/∂output) * (∂output/∂input) = grad_output @ weight`
- `∂L/∂weight = (∂L/∂output) * (∂output/∂weight) = grad_output.t() @ input`
- `∂L/∂bias = grad_output.sum(dim=0)`

Explain the distributed part

In [None]:
class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
    """Linear op ``output = input @ weight.T + bias`` with an asynchronous
    all-reduce of the input gradient across the tensor-model-parallel group.

    In the backward pass the all-reduce of ``grad_input`` is launched with
    ``async_op=True`` so it overlaps with the local ``grad_weight`` /
    ``grad_bias`` computation; the returned work handle is waited on before
    the gradients are handed back to autograd.
    """

    @staticmethod
    def forward(ctx, input, weight, bias):
        # Assumes input is (..., in_features), weight is
        # (out_features, in_features), bias is (out_features,) -- as in the
        # toy tensors above; TODO confirm for other callers.
        ctx.save_for_backward(input, weight)
        output = torch.matmul(input, weight.T) + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # dL/dinput = dL/doutput @ weight
        grad_input = torch.matmul(grad_output, weight)

        # Start reducing the input gradient without blocking...
        handle = torch.distributed.all_reduce(
            grad_input, group=get_tensor_model_parallel_group(), async_op=True
        )

        # Ignored here: Megatron-LM delays ~3us at this point so the
        # all-reduce is scheduled first and gets GPU resources allocated.

        # ...and overlap it with the weight/bias gradients. Flatten all
        # leading dimensions so this is also correct for inputs with more
        # than two dimensions (`.T` on an N-D tensor reverses *all* dims).
        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
        input_2d = input.reshape(-1, input.shape[-1])
        grad_weight = torch.matmul(grad_output_2d.T, input_2d)
        grad_bias = grad_output_2d.sum(dim=0)

        # grad_input is only valid once the all-reduce has completed.
        handle.wait()

        return grad_input, grad_weight, grad_bias


# Backward-compatible alias for the original (misspelled) class name.
ColumnParallelLinearWithcAllreduce = ColumnParallelLinearWithAsyncAllreduce

In [None]:
# `Function.apply` runs the *forward* pass and returns its output; the
# gradients are produced later by autograd when `.backward()` is called.
output = ColumnParallelLinearWithAsyncAllreduce.apply(input, weight, bias)

### `ColumnParallelLinear`

##### Example 1

In [None]:
import torch.distributed as dist

In [None]:
x = torch.as_tensor([69])  # 1-element int64 tensor

In [None]:
xs = [torch.tensor(float(v)) for v in (0, 1)]  # two float32 scalars: 0. and 1.

In [None]:
import torch

In [None]:
x.size()

torch.Size([1])

Create a function that gathers a tensor from every process in a distributed group. The same script runs on every process. Explain each line.

In [None]:
import torch.distributed as dist

In [None]:
def gather_tensors(x, group=None):
    """Gather ``x`` from every process of a distributed group.

    This is a collective call: every process in ``group`` must call it.

    Args:
        x: Tensor contributed by this process. ``all_gather`` expects it to
            have the same shape and dtype on every rank.
        group: Process group to gather over. ``None`` (the default) uses the
            default process group, which is the original behaviour.

    Returns:
        list[torch.Tensor]: One tensor per rank of ``group``; entry ``i`` is
        the tensor contributed by rank ``i``.
    """
    world_size = dist.get_world_size(group=group)
    # One receive buffer per rank, matching x's shape/dtype/device.
    xs = [torch.empty_like(x) for _ in range(world_size)]

    dist.all_gather(xs, x, group=group)

    rank = dist.get_rank(group=group)
    print(f"Rank {rank}: xs = {xs}")

    return xs

**Explain**

- `world_size = dist.get_world_size()`: This line retrieves the number of processes involved in the distributed computation, using the `get_world_size` function from the PyTorch `dist` module. This is necessary because the `all_gather` function will collect tensors from all processes, and we need to know how many tensors to expect.

- `xs = [torch.empty_like(x) for _ in range(world_size)]`: This line creates a list `xs` of `world_size` uninitialized tensors, each with the same shape and data type as the input tensor `x`. These tensors will be used to store the gathered data from all processes.

- `dist.all_gather(xs, x)`: This line uses the `all_gather` function from `torch.distributed` (imported as `dist`) to collect data from all processes and store it in the `xs` list. Every process contributes its tensor `x`; after the call, `xs[i]` holds the tensor contributed by rank `i`, on every process.

In [None]:
xs = gather_tensors(x=x)  # collective: run on every process

In [None]:
# Display the list bound to `xs` by the previous cell.
xs

[tensor(0.), tensor(1.)]

##### Example 2

In [None]:
world_size, input_size, output_size = 4, 16, 12
input_data = torch.randn(input_size)  # a single 16-element input vector

In [None]:
(world_size, input_size, output_size)

(4, 16, 12)

In [None]:
import torch
from torch import nn

In [None]:
input_data.size()

torch.Size([16])

Write **the forward pass** of `ColumnParallelLinear` in Megatron-LM and explain it.

**Hints**:

- Focus on how to do parallel computing and ignore details like initialization.
- Do not create a master weight (the full weight of a non-parallel linear layer) and scatter slices of it to each process; each process should allocate only its own shard.
- The final output will be sent to all processes.

In [None]:
class ColumnParallelLinear(nn.Module):
    """Forward-only column-parallel linear layer (Megatron-LM style).

    The weight of ``Y = X W^T + b`` is split along its output dimension.
    Each process holds ``output_size // num_partitions`` rows of the weight
    (and the matching slice of the bias) and computes its slice of ``Y``
    locally. The slices are then all-gathered and concatenated along the last
    dimension, so every process receives the full output.

    Args:
        input_size: Size of the input's last dimension.
        output_size: Total size of the output's last dimension across all
            partitions.
        num_partitions: Number of partitions; expected to equal the
            distributed world size used in ``forward``.

    Note:
        Parameters are left uninitialized (``torch.empty``); initialization
        is out of scope here. ``torch.distributed.all_gather`` fills freshly
        allocated buffers and is not tracked by autograd, so gradients do not
        flow back through this forward pass.
    """

    def __init__(self, input_size, output_size, num_partitions):
        super().__init__()
        self.input_size = input_size
        self.output_size_per_partition = output_size // num_partitions

        self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
        self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))

    def forward(self, input):
        # Fully-qualified name: `F` is not imported at this point in the notebook.
        output_partition = torch.nn.functional.linear(input, self.weight, self.bias)

        world_size = torch.distributed.get_world_size()
        outputs = [torch.empty_like(output_partition) for _ in range(world_size)]
        torch.distributed.all_gather(outputs, output_partition)

        # Rank i's slice becomes the i-th block of the output's last dimension.
        output = torch.cat(outputs, dim=-1)
        return output

**Explain**

- `self.output_size_per_partition = output_size // num_partitions`: This line calculates the output size for each partition by dividing the total output size by the number of partitions. This is done because the output dimension of the linear layer is divided among multiple processes, and each process will handle its corresponding portion of the output dimension.


- `self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))`: This line creates the weight parameter for the current process. Its shape is `(output_size_per_partition, input_size)` because each process is responsible only for its own portion of the output dimension.

- `output_partition = F.linear(input, self.weight, self.bias)`: The output partition corresponding to the current process.

- `outputs = [torch.empty_like(output_partition) for _ in range(world_size)]`: This line creates an `outputs` list with empty tensors that have the same shape as `output_partition`. These tensors will be used to store the `output` of each process.

- `dist.all_gather(outputs, output_partition)`: The `dist.all_gather` function is called to gather the `output_partition` from all processes in the distributed group and store them in the `outputs` list.

##### Example 3

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

Write the forward pass and backward pass of `ColumnParallelLinear` in Megatron-LM

**Hint**: In the backward pass: `Y = [Y1, Y2]` > ... > `X1 + X2 = X`

In [2]:
class f(torch.autograd.Function):
    """Identity in the forward pass; sum all-reduce of the gradient in the backward pass.

    Used in front of the column-parallel matmul: every rank consumes the same
    full input, so the gradient with respect to that input is the sum of the
    per-rank partial gradients.
    """

    @staticmethod
    def forward(ctx, input):
        # No communication needed: each rank already holds the full input.
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # Sum the partial gradients in place over the default process group.
        summed = grad_output
        torch.distributed.all_reduce(summed, op=torch.distributed.ReduceOp.SUM)
        return summed

NameError: name 'torch' is not defined

In [1]:
class g(torch.autograd.Function):
    """All-gather along the last dimension forward; keep this rank's slice backward.

    Forward: each rank contributes its slice and receives the concatenation
    of all ranks' slices (in rank order) along the last dimension.
    Backward: the incoming gradient is split into equal chunks along the last
    dimension and this rank keeps the chunk at its own rank index.
    """

    @staticmethod
    def forward(ctx, input):
        n_ranks = torch.distributed.get_world_size()
        gathered = [torch.empty_like(input) for _ in range(n_ranks)]
        dist.all_gather(gathered, input)
        return torch.cat(gathered, dim=-1)

    @staticmethod
    def backward(ctx, grad_output):
        n_ranks = torch.distributed.get_world_size()
        my_rank = torch.distributed.get_rank()
        slice_width = grad_output.shape[-1] // n_ranks
        return torch.split(grad_output, slice_width, dim=-1)[my_rank]

NameError: name 'torch' is not defined

In [None]:
class ColumnParallelLinear(nn.Module):
    """Column-parallel linear layer with communication in autograd functions.

    Forward: ``input -> f -> F.linear(., W_i, b_i) -> g``, where ``W_i`` and
    ``b_i`` are this rank's ``output_size // num_partitions`` rows of the
    weight and bias. In this notebook ``f`` is the identity forward and
    all-reduces the gradient backward; ``g`` all-gathers the per-rank outputs
    along the last dimension forward and splits the gradient backward.

    Args:
        input_size: Size of the input's last dimension.
        output_size: Total output size across all partitions.
        num_partitions: Number of partitions the output dimension is split into.

    Note:
        Parameters are left uninitialized (``torch.empty``).
    """

    def __init__(self, input_size: int, output_size: int, num_partitions: int):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.output_size_per_partition = output_size // num_partitions

        # Use `nn.Parameter`: a bare `Parameter` name is not imported here.
        # Parameters require grad by default.
        self.weight = nn.Parameter(torch.empty(
            self.output_size_per_partition,
            self.input_size,
        ))
        self.bias = nn.Parameter(torch.empty(
            self.output_size_per_partition,
        ))

    def forward(self, input):
        input_parallel = f.apply(input)
        output_parallel = F.linear(input_parallel, self.weight, self.bias)
        outputs = g.apply(output_parallel)
        return outputs

### `RowParallelLinear`

##### Example 1

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

Write the forward pass of `RowParallelLinear` in Megatron-LM. Ignore the details like initialization

In [None]:
class RowParallelLinear(nn.Module):
    """Forward-only row-parallel linear layer (Megatron-LM style).

    For ``Y = X W^T + b``, the weight is split along its *input* dimension.
    Each rank multiplies its slice of the input's last dimension by its
    weight shard. The partial products are summed with an all-reduce, and
    the (unsplit) bias is added once, after the reduction.

    Args:
        input_size: Full size of the input's last dimension.
        output_size: Size of the output's last dimension.

    Note:
        Parameters are left uninitialized (``torch.empty``). The attribute
        name ``input_size_per_patrition`` is kept as-is (sic) for
        compatibility.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        world_size = torch.distributed.get_world_size()
        self.input_size_per_patrition = input_size // world_size
        self.weight = nn.Parameter(torch.empty(
            self.output_size,
            self.input_size_per_patrition,
        ))
        self.bias = nn.Parameter(torch.empty(
            self.output_size,
        ))

    def forward(self, input):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()

        # Width of each rank's slice of the input's last dimension.
        chunk_size = input.shape[-1] // world_size
        input_parallel = torch.split(input, chunk_size, dim=-1)[rank]

        # Partial product WITHOUT bias: adding the bias here would make the
        # all-reduce below sum it world_size times.
        output_parallel = F.linear(input_parallel, self.weight)
        torch.distributed.all_reduce(output_parallel)

        # Add the full bias exactly once, after the partial sums are combined.
        return output_parallel + self.bias

### `VocabParallelEmbedding`

##### Example 0

In [None]:
def extract_range(self, num_embeddings, rank, world_size):
    """Return the half-open vocabulary range ``(start, end)`` owned by ``rank``.

    ``self`` is not used. The vocabulary is cut into ``world_size`` blocks of
    ``num_embeddings // world_size`` entries; any remainder is not assigned
    to any rank.
    """
    block = num_embeddings // world_size
    start = block * rank
    return start, start + block

##### Example 1

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

Implement `VocabParallelEmbedding` from Megatron-LM. Explain your code

**Hints**
- Ignore details like weight initialization. Just focus on the parallelization
- Broadcast the final embedding to all processes

In [None]:
class VocabParallelEmbedding(nn.Module):
    """Embedding table split along the vocabulary dimension (Megatron-LM style).

    Each rank holds the rows ``[vocab_start_idx, vocab_end_idx)`` of the
    table. In ``forward``, each rank looks up the tokens it owns, produces
    zeros for all other tokens, and a sum all-reduce combines the partial
    results so every rank ends up with the full embeddings.

    Args:
        num_embeddings: Total vocabulary size.
        embedding_dim: Size of each embedding vector.

    Note:
        The weight is left uninitialized (``torch.empty``).
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        self.vocab_start_idx, self.vocab_end_idx = self.extract_range(
            self.num_embeddings, rank, world_size
        )
        self.num_embedding_per_patrition = self.vocab_end_idx - self.vocab_start_idx

        self.weight = nn.Parameter(torch.empty(
            self.num_embedding_per_patrition,
            self.embedding_dim
        ))

    def extract_range(self, num_embeddings, rank, world_size):
        """Return the ``[start, end)`` vocabulary range owned by ``rank``."""
        per_patrition_vocab_size = num_embeddings // world_size
        start_idx = rank * per_patrition_vocab_size
        end_idx = start_idx + per_patrition_vocab_size
        return start_idx, end_idx

    def forward(self, input):
        # True where the token id is outside this rank's [start, end) range.
        input_mask = (input < self.vocab_start_idx) | (input >= self.vocab_end_idx)
        # Shift ids to local row indices. Subtraction already returns a new
        # tensor, so the in-place write below does not touch `input`.
        # Masked ids get the placeholder 0 only to keep the lookup in bounds.
        masked_input = input - self.vocab_start_idx
        masked_input[input_mask] = 0

        output_parallel = F.embedding(masked_input, self.weight)
        # Zero exactly the (batch, position, ...) entries this rank does not
        # own. Boolean-mask indexing works for any input shape and, unlike
        # zeroing whole sequence positions, leaves other batch rows intact.
        output_parallel[input_mask, :] = 0.0

        torch.distributed.all_reduce(
            output_parallel,
            op=torch.distributed.ReduceOp.SUM
        )

        return output_parallel

**Explain**
- Each process is responsible for one partition of the vocabulary embedding table. To determine that partition, we compute its `vocab_start_idx` and `vocab_end_idx`; the number of vocabulary rows held by the current process is `num_embedding_per_partition`. We then allocate the weight matrix for this partition only.

- During the forward pass, we create an `input_mask` that marks tokens outside the current process's partition. We subtract `vocab_start_idx` from the input so that token IDs become row indices into this process's local embedding table, and we set all masked entries to `0`. Index `0` is only a placeholder that keeps the lookup in bounds. Each process is responsible for only part of the embedding table, and tokens outside its partition must not contribute to its output, so their embeddings are zeroed in the next step.

- Next, we compute the partial embeddings for the masked input using the partition's weight matrix. We also find the indices of the masked elements in the input tensor and set the corresponding elements in the output tensor to `0`.

- Finally, we perform an all-reduce (sum) across all processes. Each in-vocabulary token is looked up by exactly one process and zeroed on all the others, so the sum reconstructs the complete embedding tensor. Because all-reduce leaves the result on every rank, no separate broadcast is needed: every process ends up with the complete embeddings.

### `ParallelMLP`

##### Example 1

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
def ColumnParallelLinear(input_size, output_size):
    """Placeholder standing in for the column-parallel layer; returns ``None``."""
    return None

def RowParallelLinear(input_size, output_size):
    """Placeholder standing in for the row-parallel layer; returns ``None``."""
    return None

In [None]:
from torch import nn

In [None]:
RowParallelLinear(1024, 200)  # (input_size, output_size)

In [None]:
ColumnParallelLinear(100, 1024)  # (input_size, output_size)

Write the `ParallelMLP` in Megatron-LM with GELU activation, where the middle hidden state is four times the input size

In [None]:
class ParallelMLP(nn.Module):
    """Tensor-parallel transformer MLP: ``h -> 4h -> GELU -> h`` (Megatron-LM style).

    ``dense_h_to_4h`` is a column-parallel linear layer and ``dense_4h_to_h``
    is a row-parallel linear layer. Following Megatron-LM's convention, both
    are expected to return an ``(output, bias)`` pair, where ``bias`` is
    either ``None`` or a bias that has not yet been added to ``output``.

    Args:
        hidden_size: Model hidden size ``h``; the intermediate size is ``4h``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.gelu = nn.GELU()
        self.dense_h_to_4h = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size * 4
        )
        self.dense_4h_to_h = RowParallelLinear(
            input_size=hidden_size * 4,
            output_size=hidden_size
        )

    def forward(self, hidden_states):
        """Return ``(output, output_bias)`` from the 4h -> h projection."""
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
        # The expansion layer's bias must be applied *before* the GELU;
        # unpacking it and ignoring it silently drops the bias.
        if bias_parallel is not None:
            intermediate_parallel = intermediate_parallel + bias_parallel
        intermediate_parallel = self.gelu(intermediate_parallel)
        # The projection's bias is passed through to the caller unchanged.
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias

In [None]:
parallel_mlp = ParallelMLP(1024)  # hidden_size=1024 -> intermediate size 4096

In [None]:
parallel_mlp.__class__

__main__.ParallelMLP

### `VocabParallelCrossEntropy`

##### Example 1

In [None]:
class VocabParallel