In [58]:
from torch import nn

##### Example 1

In [57]:
import torch
import torch.nn.functional as F

In [10]:
# Seed PyTorch's RNG so the randomly initialised router and the random inputs
# created in the following cells are reproducible on "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Toy problem dimensions used throughout Example 1.
batch_size = 10  # sequences per batch
seq_len = 100    # tokens per sequence
d_model = 16     # hidden size of each token vector
n_experts = 5    # number of experts the router chooses between

In [11]:
# Router ("switch"): a linear layer that maps each d_model-sized token vector
# to n_experts scores, one per expert.
switch = nn.Linear(d_model, n_experts)

In [3]:
# Random stand-in activations with shape (batch_size, seq_len, d_model).
inputs = torch.randn(batch_size, seq_len, d_model)

In [47]:
inputs.shape

torch.Size([10, 100, 16])

In [48]:
switch

Linear(in_features=16, out_features=5, bias=True)

`inputs` has the following shape:

- 10 represents the batch size,
- 100 represents the sequence length,
- 16 represents the hidden size.

Use the top-1 router to determine which expert each token in the batch is sent to.

In [49]:
# Collapse all leading dimensions so that each row is a single token's
# d_model-sized vector; the router is applied per token.
flattened_inputs = inputs.view(-1, d_model)

In [50]:
flattened_inputs.shape

torch.Size([1000, 16])

In [51]:
# Router scores -> per-token probability distribution over the experts
# (softmax is taken over the last dimension, i.e. across experts).
probs = F.softmax(switch(flattened_inputs), dim=-1)

In [52]:
probs.shape

torch.Size([1000, 5])

In [53]:
# Top-1 routing: keep only the index of the highest-probability expert per token.
idxs = torch.max(probs, dim=-1).indices

In [54]:
idxs[:3]

tensor([0, 4, 1])

In [55]:
idxs.shape

torch.Size([1000])

##### Example 1.1

In [77]:
# For every expert, collect the row positions (into the flattened batch) of the
# tokens that the router assigned to it.
indexes_list = [
    (idxs == expert_id).nonzero(as_tuple=True)[0]
    for expert_id in range(n_experts)
]

In [78]:
indexes_list

[tensor([  0,   6,   9,  11,  19,  21,  23,  24,  25,  27,  35,  36,  37,  40,
          42,  43,  47,  51,  53,  54,  61,  68,  71,  78,  85,  90,  94,  96,
         107, 112, 115, 116, 123, 130, 131, 134, 137, 141, 143, 146, 147, 156,
         159, 163, 165, 168, 176, 177, 180, 182, 183, 189, 193, 196, 199, 201,
         211, 212, 213, 218, 228, 231, 232, 235, 236, 237, 238, 239, 241, 244,
         249, 255, 259, 262, 264, 268, 269, 271, 276, 279, 280, 288, 292, 295,
         297, 299, 304, 305, 311, 314, 323, 327, 329, 332, 333, 334, 335, 336,
         341, 343, 345, 350, 351, 356, 362, 363, 365, 366, 376, 381, 383, 393,
         395, 397, 401, 404, 415, 416, 420, 421, 423, 424, 425, 429, 434, 435,
         443, 447, 449, 450, 458, 463, 465, 467, 468, 469, 485, 488, 492, 499,
         500, 501, 510, 519, 525, 526, 532, 534, 543, 545, 558, 561, 563, 566,
         567, 568, 576, 579, 591, 597, 599, 606, 609, 610, 612, 616, 622, 624,
         625, 628, 629, 633, 639, 640, 643, 644, 661

In [83]:
len(indexes_list) == n_experts

True

In [85]:
indexes_list[0][:3]

tensor([0, 6, 9])

In [None]:
# Dispatch step: gather the token vectors routed to each expert.
# Experts that received no tokens from the router are skipped, so
# `expert_inputs` only has entries for experts with work to do.
expert_inputs = {}
for i in range(n_experts):
    if len(indexes_list[i]) == 0:
        continue
    expert_inputs[i] = flattened_inputs[indexes_list[i]]

##### Example 2

In [60]:
# Total number of ranks in this example.
world_size = 16
# Number of data-parallel groups built in the next cell; it is also the stride
# between consecutive ranks inside each group.
dp_group_size = 2

In [61]:
# Group g holds ranks g, g + dp_group_size, g + 2 * dp_group_size, ... below world_size.
data_parallel_groups = [
    list(range(first_rank, world_size, dp_group_size))
    for first_rank in range(dp_group_size)
]

In [62]:
data_parallel_groups

[[0, 2, 4, 6, 8, 10, 12, 14], [1, 3, 5, 7, 9, 11, 13, 15]]

In [65]:
# Number of ranks in each data-parallel group (16 // 2 = 8 with the values above).
dp_world_size = world_size // dp_group_size

In [66]:
# Number of ranks in each expert-parallel group (the slice width used in the next cell).
expert_parallel_size = 4

In [67]:
expert_parallel_info = None
for dp_ranks in data_parallel_groups:
    # Cut each data-parallel group into consecutive chunks of
    # `expert_parallel_size` ranks; each chunk is one expert-parallel group.
    for start in range(0, dp_world_size, expert_parallel_size):
        print(dp_ranks[start : start + expert_parallel_size])

[0, 2, 4, 6]
[8, 10, 12, 14]
[1, 3, 5, 7]
[9, 11, 13, 15]


In [None]:
# NOTE(review): `experts` is not defined in any cell visible here (earlier cells
# define `n_experts`), so unless it is defined elsewhere this line raises
# NameError -- confirm the intended name.
# This also rebinds `inputs`, so the original tensor is no longer reachable
# under that name afterwards.
inputs = inputs.view(-1, experts)

In [None]:
# Softmax over the last dimension: each row becomes a probability distribution.
probs = F.softmax(inputs, dim=-1)

In [None]:
# Index of the largest probability in each row (top-1 choice).
idxs = torch.max(probs, dim=-1).indices

In [None]:
# ready, running, failed, succeeded, cooldown, blacklisted

In [86]:
import torchvision.transforms as tfms

In [None]:
tfms.Compose

In [None]:
# biocompatible, elasticity, reliable

In [87]:
def _get_expert_parallel_ranks(world_size, model_parallel_size_, expert_parallel_size_):
    """Generate expert parallel and expert data parallel group ranks list.

        Example - E + M + D parallel
        world_size = 16
        model_degree = 2
        expert_degree = 4 # number of experts in same group
        mp_group = [0, 1], [2,3], [4,5] ...
        data_parallel_group =[0,2,4,6,8,10, 12,14],                 [1,3,5,7,9,11,13,15]
        expert_parallel_group = [0,2,4,6], [8,10,12,14]             [1,3,5,7], [9,11,13,15]
        expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14],    [1,9],[3,11],[5,13],[7,15]

    Args:
        world_size (int): Distributed world size.
        model_parallel_size_ (int): Model parallel group size.
        expert_parallel_size_ (int): Expert parallel group size.

    Returns:
        Expert parallel group ranks and Expert data parallel group ranks list.
    """
    _ensure_divisibility(world_size, model_parallel_size_)
    dp_world_size = world_size // model_parallel_size_
    _ensure_divisibility(dp_world_size, expert_parallel_size_)

    # Generate data parallel groups
    data_parallel_groups = []
    dp_group_size = model_parallel_size_
    for i in range(dp_group_size):
        data_parallel_groups.append(list(range(i, world_size, dp_group_size)))

    expert_parallel_groups = []
    expert_data_parallel_groups = []
    for dp_ranks in data_parallel_groups:
        # partition of expert parallel groups, e.g. [0,2,4,6], [8,10,12,14]
        part_ep_groups = []
        for i in range(0, dp_world_size, expert_parallel_size_):
            part_ep_groups.append(dp_ranks[i:i + expert_parallel_size_])
        expert_parallel_groups.extend(part_ep_groups)

        # zip part_ep_groups get expert data parallel ranks, e.g [0,8],[2,10],[4,12],[6,14]
        for expert_dp_ranks in zip(*part_ep_groups):
            expert_data_parallel_groups.append(list(expert_dp_ranks))

    return expert_parallel_groups, expert_data_parallel_groups

In [88]:
# Model parallel group size; read as a module-level global by
# `_create_expert_data_and_model_parallel` below.
model_parallel_size = 2

In [89]:
# Total number of ranks; read as a module-level global by
# `_create_expert_data_and_model_parallel` below.
world_size = 16

In [90]:
# Rank whose group membership is checked; read as a module-level global by
# `_create_expert_data_and_model_parallel` below.
rank = 0

In [91]:
# NOTE(review): not referenced by any cell shown here -- confirm whether it is still needed.
data_parallel_size = 2

In [92]:
# NOTE(review): not referenced by any cell shown here -- confirm whether it is still needed.
data_parallel_rank = 1

In [None]:

def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu):
    """
        Create expert and data parallel groups based on MPU (model parallel) group.

        Note: Caller of this function is responsible to check if the groups already exist.

        Example - E + M + D parallel
        world_size = 16
        model_degree = 2
        expert_degree = 4 # number of experts in same group
        mp_group = [0, 1], [2,3], [4,5] ...
        data_parallel_group =[0,2,4,6,8,10, 12,14],                 [1,3,5,7,9,11,13,15]
        expert_parallel_group = [0,2,4,6], [8,10,12,14]             [1,3,5,7], [9,11,13,15]
        expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14],    [1,9],[3,11],[5,13],[7,15]

        Args:
            expert_parallel_size_ (int): Expert parallel group size.
            mpu: Object providing `get_data_parallel_group()` and
                `get_model_parallel_group()`.

        Returns:
            None. Newly created groups are stored under the key
            ``f"ep_size_{expert_parallel_size_}"`` in `_EXPERT_PARALLEL_GROUP`
            and `_EXPERT_DATA_PARALLEL_GROUP`, for the groups containing `rank`.

        Notebook globals read by this function: `world_size`, `rank` and
        `model_parallel_size`.
        NOTE(review): `dist`, `_EXPERT_PARALLEL_GROUP` and
        `_EXPERT_DATA_PARALLEL_GROUP` are not defined in the cells shown here;
        they must exist before this function is called.
    """
    # The notebook stores the model parallel degree in the global
    # `model_parallel_size`; bind it to the name used below. Previously
    # `model_parallel_size_` was referenced without ever being defined.
    model_parallel_size_ = model_parallel_size

    # Get world size and rank. Ensure some consistencies.
    # NOTE(review): the returned groups are not used below; the calls are kept
    # in case the getters validate that `mpu` has been initialised.
    _DATA_PARALLEL_GROUP = mpu.get_data_parallel_group()
    _MODEL_PARALLEL_GROUP = mpu.get_model_parallel_group()

    group_name = f"ep_size_{expert_parallel_size_}"

    # Only create groups if they don't already exist
    # Need to check conditions outside the group creation loop because of the way torch.dist group creation works
    if group_name in _EXPERT_DATA_PARALLEL_GROUP or group_name in _EXPERT_PARALLEL_GROUP:
        return

    expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(
        world_size, model_parallel_size_, expert_parallel_size_)

    # `new_group` is called for every rank list, but only the group that
    # contains this process's rank is recorded.
    for ranks in expert_parallel_groups:
        group = dist.new_group(ranks)
        if rank in ranks:
            _EXPERT_PARALLEL_GROUP[group_name] = group

    for ranks in expert_data_parallel_groups:
        group = dist.new_group(ranks)
        if rank in ranks:
            _EXPERT_DATA_PARALLEL_GROUP[group_name] = group