In [25]:
from main import (grigora2, Worker, Group, expert_parallel_group_objective_function,
                  all_reduce_function, gamma_function, Expert, expert_assignment)
from typing import Dict, List
import numpy as np
import math

## Evaluation of Lysi (renamed from Grigora)
Let's perform realistic experiments using the sharding specification generated by Lysi.

Specifically, we will use this [script](https://github.com/osayamenja/Megatron-DeepSpeed/blob/main/examples_deepspeed/MoE/ds_pretrain_gpt_350M_MoE128.sh) to train GPT-3 16x350M on four Perlmutter [GPU nodes](https://docs.nersc.gov/systems/perlmutter/architecture/#gpu-nodes).

Below, we define the number of workers (`dim`) and the number of GPUs per node (`intra_node_width`).

In [26]:
dim = 16
intra_node_width = 4.0

Next, we build the adjacency matrix. We manually obtained the alpha and beta values below via NCCL-tests micro-benchmarks.
We anticipate automating this network profiling procedure.

In [27]:
adjacency = np.zeros((dim, dim, 2))
intra_node_cost = (0.009, 0.014)  # (ms, ms/MB)
inter_node_cost = (0.03, 0.054)
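For intuition, a pair such as `(0.009, 0.014)` can be read under the usual alpha-beta (latency plus inverse bandwidth) model, where sending a message of `n` MB costs roughly `alpha + beta * n` ms. The sketch below assumes that reading; `transfer_time_ms` is a hypothetical helper, not part of `main`, and the 16 MB message size is chosen only for illustration.

```python
def transfer_time_ms(alpha_ms: float, beta_ms_per_mb: float, size_mb: float) -> float:
    """Alpha-beta estimate: fixed latency plus a per-MB transfer cost (assumed model)."""
    return alpha_ms + beta_ms_per_mb * size_mb


# Values copied from the cell above: (alpha in ms, beta in ms/MB).
intra_node_cost = (0.009, 0.014)
inter_node_cost = (0.03, 0.054)

message_mb = 16  # illustrative message size
intra_ms = transfer_time_ms(*intra_node_cost, message_mb)
inter_ms = transfer_time_ms(*inter_node_cost, message_mb)
print(f"intra-node: {intra_ms:.3f} ms, inter-node: {inter_ms:.3f} ms")
```

Under this reading, an inter-node transfer of the same message costs nearly four times as much as an intra-node one.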

Note that each GPU connects to a separate NIC on Perlmutter; thus, there are only two types of links: intra-node NVLink and inter-node NIC connections.

In [28]:
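for ii in range(adjacency.shape[0]):
    for jj in range(adjacency.shape[0]):
        # Two distinct GPUs share a node when their ranks fall in the same block of `intra_node_width`.
        if ii != jj and math.floor(jj / intra_node_width) == math.floor(ii / intra_node_width):
            # intra-node (NVLink)
            adjacency[ii, jj] = intra_node_cost
        else:
            # inter-node (NIC); note that the diagonal (ii == jj) also lands here
            adjacency[ii, jj] = inter_node_cost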
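The same matrix can also be built without Python loops. The sketch below is an assumed equivalent using NumPy broadcasting; it redefines the constants so that it runs on its own, and `node_of`, `intra_mask`, and `adjacency_vec` are illustrative names. Like the loop above, it leaves the diagonal (a worker paired with itself) at the inter-node cost.

```python
import numpy as np

dim = 16
gpus_per_node = 4
intra_node_cost = (0.009, 0.014)  # (ms, ms/MB)
inter_node_cost = (0.03, 0.054)

# Node index of every rank: ranks 0-3 -> node 0, ranks 4-7 -> node 1, and so on.
node_of = np.arange(dim) // gpus_per_node

# True where two *distinct* ranks live on the same node.
same_node = node_of[:, None] == node_of[None, :]
intra_mask = same_node & ~np.eye(dim, dtype=bool)

# Broadcast the (dim, dim, 1) mask against the two (alpha, beta) pairs.
adjacency_vec = np.where(intra_mask[..., None], intra_node_cost, inter_node_cost)
print(adjacency_vec.shape, int(intra_mask.sum()))
```

Each of the 16 ranks has three intra-node peers, so the mask marks 48 directed intra-node links; every other entry carries the inter-node cost.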

Below, we define the theoretical tensor-core FLOPS of the A100 GPU, which makes up our testbed. We took the peak value from the official [documentation](https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf), *but* we obtained the realistic scaling factor from empirical measurements.

Note that this scaling factor is much lower than the 75% reported by [NVIDIA](https://forums.developer.nvidia.com/t/about-gpu-peak-performance/264462/5) but aligns with the [literature](https://ieeexplore.ieee.org/document/9415606), which reports 40% utilization for $4096\times4096$ matrices on the V100.

The expert GEMMs for GPT-3 MoE are described below. Note that $\bigotimes$ denotes matrix multiplication, $s$ the sequence length, $h$ the hidden size, and $b$ the batch size.
$$(s\cdot b,\; h)\bigotimes (h, \;4h) = (2048\cdot 4, \;1024) \bigotimes (1024, \;4096)$$

In [29]:
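# A100 peak dense tensor-core throughput: 312 TFLOPS = 312e12 FLOP/s = 312e9 FLOP/ms
a100_theoretical_flop_per_ms = 312 * 1E9
# Fraction of the peak we measured empirically
realistic_scaling_factor = 0.43
real_flops = int(math.ceil(realistic_scaling_factor * a100_theoretical_flop_per_ms))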

In [30]:
mem = 32
workers = []
for ii in range(adjacency.shape[0]):
    workers.append(Worker(ii, real_flops, mem))

We use the `dict` below, which maps each worker's rank to its `Worker` object, for the later expert allocation.

In [31]:
# needed for later: look up a Worker by its rank
workers_info: Dict[int, Worker] = {i: worker for i, worker in enumerate(workers)}

We define the experts below.

In [32]:
n_exp = 64
exp = []
experts = []
exp_flops = 16 * 4 * 2048 * (1024 ** 2)
for ii in range(n_exp):
    exp.append(exp_flops)
    experts.append(Expert(exp_flops, ii))
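One way to read the constant `16 * 4 * 2048 * (1024 ** 2)` is as $16\,b\,s\,h^2$. The sketch below assumes that each expert is a two-GEMM feed-forward block ($h \to 4h \to h$), that a multiply-add counts as 2 FLOPs, and that only the forward pass is counted. These assumptions reproduce the constant, but the notebook does not state them, so treat the breakdown as illustrative. The names `up_proj_flops`, `down_proj_flops`, and `ms_per_expert` are hypothetical.

```python
import math

s, b, h = 2048, 4, 1024  # sequence length, batch size, hidden size (from the GEMM shapes above)
tokens = s * b

# A GEMM of shapes (m, k) x (k, n) costs about 2 * m * k * n FLOPs.
up_proj_flops = 2 * tokens * h * (4 * h)    # (s*b, h) x (h, 4h)
down_proj_flops = 2 * tokens * (4 * h) * h  # (s*b, 4h) x (4h, h), assumed second GEMM
ffn_flops = up_proj_flops + down_proj_flops

# Matches the constant used for `exp_flops` in the cell above.
assert ffn_flops == 16 * b * s * h ** 2 == 16 * 4 * 2048 * (1024 ** 2)

# Same expression as `real_flops` above (FLOP per ms).
real_flops = int(math.ceil(0.43 * (312 * 1E9)))
ms_per_expert = ffn_flops / real_flops
print(f"{ffn_flops:,} FLOPs per expert, about {ms_per_expert:.2f} ms on one worker")
```

With these numbers, one expert takes on the order of a millisecond of compute per worker, which is the same order of magnitude as the link costs in the adjacency matrix.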

See this [file](grigora_manuscript.pdf) for more details.

In [33]:
p2p_buf_mb = 16
p2p_fr = 4
all_r_buf = 512

gamma_arguments = {Group.NUM_LAYERS: 24,
                   Group.GLOBAL_BATCH_SIZE: 256,
                   Group.MINI_BATCH_SIZE: 4,
                   Group.MOE_FREQUENCY: 2,
                   Group.RECOMPUTATION_AMOUNT: 1}

In [34]:
shard_spec, inv = grigora2(a=adjacency,
                           obj=expert_parallel_group_objective_function,
                           all_reduce_func=all_reduce_function,
                           gamma=gamma_function,
                           p2p_buffer_size=p2p_buf_mb,
                           p2p_freq=p2p_fr,
                           all_reduce_buffer_size=all_r_buf,
                           workers=workers,
                           expert_workload=exp,
                           gamma_args=gamma_arguments)
print(shard_spec.subsets())

[{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}]


As shown above, Lysi produces a sharding specification that partitions worker ranks into communication-optimal groups.
For the configuration defined above, with 64 experts and 16 GPUs, Lysi determines that all GPUs should form a single expert-parallel group.

Now let's optimally shard the experts across workers given this parallelism strategy.

In [35]:
for group in shard_spec.subsets():
    expert_parallel_workers: List[Worker] = []
    for worker in group:
        expert_parallel_workers.append(workers_info[worker])
    print(expert_assignment(expert_parallel_workers, experts))

{Worker(id=0, flops=134160000000, mem_capacity=32): {0, 61, 62, 63}, Worker(id=1, flops=134160000000, mem_capacity=32): {1, 58, 59, 60}, Worker(id=2, flops=134160000000, mem_capacity=32): {56, 57, 2, 55}, Worker(id=3, flops=134160000000, mem_capacity=32): {3, 52, 53, 54}, Worker(id=4, flops=134160000000, mem_capacity=32): {49, 50, 51, 4}, Worker(id=5, flops=134160000000, mem_capacity=32): {48, 5, 46, 47}, Worker(id=6, flops=134160000000, mem_capacity=32): {43, 44, 45, 6}, Worker(id=7, flops=134160000000, mem_capacity=32): {40, 41, 42, 7}, Worker(id=8, flops=134160000000, mem_capacity=32): {8, 37, 38, 39}, Worker(id=9, flops=134160000000, mem_capacity=32): {9, 34, 35, 36}, Worker(id=10, flops=134160000000, mem_capacity=32): {32, 33, 10, 31}, Worker(id=11, flops=134160000000, mem_capacity=32): {11, 28, 29, 30}, Worker(id=12, flops=134160000000, mem_capacity=32): {25, 26, 27, 12}, Worker(id=13, flops=134160000000, mem_capacity=32): {24, 13, 22, 23}, Worker(id=14, flops=134160000000, mem_c

As shown above, because the experts and workers are homogeneous, the best assignment gives every worker an equal number of experts, specifically $\frac{64}{16} = 4$ per worker.

This is a trivial case, especially considering the memory abundance per worker; a more challenging scenario arises when the experts and workers are heterogeneous. Our algorithm still finds an optimal assignment; we encourage you to experiment and verify this claim for yourself. 
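As a starting point for such experiments, the sketch below shows a simple greedy baseline for the heterogeneous case. It is not Lysi's assignment algorithm and makes no optimality claim: it ignores memory capacity and places each expert, largest first, on the worker that would finish it earliest. `greedy_assign` and its arguments are hypothetical and do not come from `main`.

```python
from typing import Dict, List


def greedy_assign(worker_flops: List[float], expert_flops: List[float]) -> Dict[int, List[int]]:
    """Greedy baseline: give each expert to the worker that would finish it earliest."""
    assignment: Dict[int, List[int]] = {w: [] for w in range(len(worker_flops))}
    finish = [0.0] * len(worker_flops)  # accumulated compute time per worker

    # Visit experts from most to least expensive (ties keep their original order).
    for e in sorted(range(len(expert_flops)), key=lambda i: -expert_flops[i]):
        # Pick the worker with the earliest completion time; break ties by rank.
        w = min(range(len(worker_flops)),
                key=lambda i: (finish[i] + expert_flops[e] / worker_flops[i], i))
        assignment[w].append(e)
        finish[w] += expert_flops[e] / worker_flops[w]
    return assignment


# Homogeneous case from this notebook: 16 equal workers, 64 equal experts.
homogeneous = greedy_assign([1.0] * 16, [1.0] * 64)
print({w: len(e) for w, e in homogeneous.items()})

# Heterogeneous toy case: worker 0 is twice as fast as worker 1.
heterogeneous = greedy_assign([2.0, 1.0], [1.0, 1.0, 1.0])
print(heterogeneous)
```

In the homogeneous case the baseline also gives every worker four experts; in the toy heterogeneous case the faster worker receives two of the three experts. Comparing such a baseline against the output of `expert_assignment` on heterogeneous inputs is one way to check the claim above.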