### 3. Model Parallel Implementation

In [None]:
import torch
from torch import nn

In [None]:
# Base feature width; the MLP defined below expands it to 4*n and projects back to n.
n = 256

In [None]:
# Feed-forward block: widen n -> 4*n, apply GELU, project back to n, then dropout.
mlp = nn.Sequential(
    nn.Linear(in_features=n, out_features=4 * n),
    nn.GELU(),
    nn.Linear(in_features=4 * n, out_features=n),
    nn.Dropout(p=0.1),
)

In [None]:
mlp

Sequential(
  (0): Linear(in_features=256, out_features=1024, bias=True)
  (1): GELU(approximate='none')
  (2): Linear(in_features=1024, out_features=256, bias=True)
  (3): Dropout(p=0.1, inplace=False)
)

Column Parallel Linear `n -> 4*n`

`megatron.model.transformer`

In [None]:
import torch.distributed as dist

from einops import rearrange

Ignore minor details such as initialization; focus on the main idea of column parallelism.

In [None]:
def get_tensor_model_parallel_group():
    """Return the process group used for tensor-model-parallel collectives.

    Returns:
        ``None``, which ``torch.distributed`` collectives interpret as "use the
        default (world) process group". An integer such as ``1`` is not a valid
        group handle and makes calls like ``all_reduce(..., group=...)`` fail.

    NOTE(review): this treats every rank as a tensor-parallel peer; replace
    with a dedicated sub-group if other forms of parallelism are added.
    """
    return None

In [None]:
def _reduce(x):
    if dist.get_world_size() == 1:
        return x
    
    torch.distributed.all_reduce(x, group=get_tensor_model_parallel_group())

In [None]:
class CopyToModelParallelRegion(torch.autograd.Function):
    """Identity in the forward pass; passes the gradient through ``_reduce`` in backward."""

    @staticmethod
    def forward(ctx, x):
        """Return ``x`` untouched; nothing is saved on ``ctx``."""
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Return whatever ``_reduce`` produces for the incoming gradient."""
        reduced_grad = _reduce(grad_output)
        return reduced_grad

In [None]:
class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
    """Linear layer ``input @ weight.T + bias`` with a hand-written backward.

    In the backward pass the all-reduce of the input gradient is launched
    asynchronously so that it can overlap with the weight/bias gradient
    computation.
    """

    @staticmethod
    def forward(ctx, input, weight, bias):
        """Compute the affine map.

        Args:
            input: Tensor whose last dimension equals ``weight.shape[1]``.
            weight: 2-D tensor of shape ``(out_features, in_features)``.
            bias: Tensor of shape ``(out_features,)`` or ``None``.

        Returns:
            Tensor with the last dimension replaced by ``out_features``.
        """
        ctx.save_for_backward(input, weight)
        # Remember whether a bias took part so backward knows what to return.
        ctx.use_bias = bias is not None
        output = torch.matmul(input, weight.T)
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight, bias)`` in that order."""
        input, weight = ctx.saved_tensors

        # d(out)/d(input): (..., out) @ (out, in) -> (..., in)
        grad_input = grad_output.matmul(weight)

        # Each rank only holds a slice of the output columns, so the input
        # gradient is a partial sum that must be all-reduced. Start it
        # asynchronously; skip it entirely for single-process runs.
        handle = None
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            handle = dist.all_reduce(
                grad_input, group=get_tensor_model_parallel_group(), async_op=True
            )

        # Collapse all leading dimensions so the weight gradient is one matmul.
        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
        input_2d = input.reshape(-1, input.shape[-1])
        grad_weight = grad_output_2d.T.matmul(input_2d)
        grad_bias = grad_output_2d.sum(dim=0) if ctx.use_bias else None

        # The reduced input gradient must be complete before it is returned.
        if handle is not None:
            handle.wait()
        return grad_input, grad_weight, grad_bias

In [None]:
class ColumnParallelLinear(nn.Module):
    def __init__(self, input_size, output_size, gather_output):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        
        world_size = dist.get_world_size()
        self.output_size_per_partition = output_size // world_size
        
        self.weight = nn.Parameter(torch.empty(
            self.output_size_per_partition,
            self.input_size
        ))
        self.bias = nn.Parameter(torch.empty(
            self.output_size_per_partition,
            self.input_size
        ))
    
    def forward(self, input):
        input = rearrange(
            input,
            "sequence batch hidden -> (sequence batch hidden)"
        )
        
        input_parallel = CopyToModelParallelRegion.apply(input)
        
        output_parallel = ColumnParallelLinearWithAsyncAllreduce()
        
        output = _GatherFromModelParallelRegion()
        
        return output

In [None]:
class ParallelMLP(nn.Module):
    """Skeleton of the parallel MLP; only a placeholder attribute is set so far."""

    def __init__(self):
        # nn.Module's own initialiser must run first: it creates the
        # parameter/buffer/submodule registries that parameters(),
        # state_dict() and friends rely on.
        super().__init__()
        # TODO: placeholder value; presumably meant to become the h -> 4h
        # projection layer — confirm.
        self.dense_h_to_4h = 1

In [None]:
# Random 3-D sample tensor of shape (10, 5, 2).
# NOTE(review): the name shadows the built-in `input`; presumably laid out as
# (sequence, batch, hidden) — confirm.
input = torch.randn(10, 5, 2)

In [None]:
from typing import Sequence

class VocabUtility:
    """Helpers for splitting a vocabulary into ``world_size`` equal chunks.

    Every helper returns the ``(first, last)`` index pair of the chunk that
    belongs to ``rank``. The range is half-open: it covers ``[first, last)``.
    """

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int, rank, world_size: int
    ) -> Sequence[int]:
        """Return the index range of ``rank`` given the size of one chunk.

        ``world_size`` is accepted for signature symmetry but not used.
        """
        first = per_partition_vocab_size * rank
        return first, first + per_partition_vocab_size

    @staticmethod
    def vocab_range_from_global_vocab_size(
        global_vocab_size: int, rank: int, world_size: int
    ) -> Sequence[int]:
        """Return the index range of ``rank`` given the full vocabulary size.

        The chunk size is ``global_vocab_size // world_size`` (floor division).
        """
        chunk = global_vocab_size // world_size
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            chunk, rank, world_size
        )

In [None]:
VocabUtility.vocab_range_from_global_vocab_size(1000, 1, 4)

(250, 500)

In [None]:
def extract_range(n_embed, rank, world_size):
    """Return the half-open ``(start, end)`` slice of ``n_embed`` owned by ``rank``.

    The shard size is ``n_embed // world_size`` (floor division).
    """
    shard_size = n_embed // world_size
    offset = shard_size * rank
    return offset, offset + shard_size

In [None]:
extract_range(1000, 1, 4)

(250, 500)

In [None]:
import torch

In [None]:
# Draw 50 random integers in [0, 50) and lay them out as a 5 x 10 grid.
input = torch.randint(low=0, high=50, size=(1, 50)).view(5, 10)

In [None]:
input

tensor([[23, 34, 29, 45, 27, 27, 39, 28,  6,  4],
        [39, 11, 19,  7, 27, 22, 34, 31, 47, 25],
        [47, 22, 36,  1, 39, 47, 44, 48, 16, 36],
        [36, 23, 16, 41,  7, 48, 23, 41,  2, 36],
        [ 8,  0,  5, 23, 33, 33, 43, 25, 19, 42]])

In [None]:
# Boolean mask: True where an entry lies outside the closed interval [5, 30].
new_input = (input < 5) | (input > 30)

In [None]:
new_input

tensor([[False,  True, False,  True, False, False,  True, False, False,  True],
        [ True, False, False, False, False, False,  True,  True,  True, False],
        [ True, False,  True,  True,  True,  True,  True,  True, False,  True],
        [ True, False, False,  True, False,  True, False,  True,  True,  True],
        [False,  True, False, False,  True,  True,  True, False, False,  True]])

In [None]:
import torch

In [None]:
# Small 2-D sample tensor used below to compare ways of reading the last dimension.
x = torch.randn(4, 3)

In [None]:
x.shape

torch.Size([4, 3])

In [None]:
x.dim() - 1

1

In [None]:
# Index of the final axis: number of dimensions minus one.
last_dim = x.dim() - 1

In [None]:
last_dim

1

In [None]:
x.size()[last_dim]

3

In [None]:
x.shape[-1]

3

In [None]:
x.ndim - 1

1

In [None]:
from torch import nn

In [None]:
# Two leaf tensors wrapped as parameters so autograd tracks gradients for them.
a = nn.Parameter(torch.tensor([2., 3.]))
b = nn.Parameter(torch.tensor([6., 3.]))

In [None]:
# Element-wise Q = 3a^3 - b^2, hence dQ/da = 9a^2 and dQ/db = -2b.
Q = 3*a**3 - b**2

In [None]:
# Q is a vector, so backward() needs an explicit upstream gradient of the same
# shape; a vector of ones yields dQ/da and dQ/db element-wise.
external_grad = torch.tensor([1., 1.])
Q.backward(gradient=external_grad)

In [None]:
a.grad

tensor([36., 81.])