In [5]:
##inferencing modeling for 1.8T parameter

#what is performance metric to measure model performance
# throught put  n tokens/s 
# latency   1 / micro-second

#problem statement
# maximizes GPU utilization (tokens per second per GPU) through batching different user requests without incurring additional costs
# maximizes user interactivity (repsonse time = time to wait for response, measured tokens per second per user)

#inference context 128K-1M
#prefill context 8192

#number of GPU/s per Node
num_GPU_per_node = 8

#nvidia
mem_per_stack = 24*1e9
num_stacks = 8
chip_config_name = ["b100","b200"]   # per GPU is two devices 

capacity_per_config_per_GPU = num_stacks * mem_per_stack

#memory capacity requirement
data_type = ["fp16","fp8","fp4"]

#benchmark for inference 
# TTFT = time to first token  (user interactivity)
# time per output token (tpot): time to generate an output token for each user  100 ms / token /user   -> 10 tokens/sec/user 
# latency = ttft + (tpot)*(total_output_tokens)
# throughput = number of output tokens / second 

#metrics from profiling inference servicing
# MEMBW.utilization (BWU) = (KV cache size + model parameter -size) / TPOT 
# TPOT = (KV cache size + model parameter size) / BWU


#higher parallelism lowers BWU ??
# for fixes batch-size, move lower amount of data in each GPU results into lower BW achieved

#latency reduction through parallelism
#higher GPU(s) with fixed batch-size offers insignificant reduction in latency due to lower memory chunks movement and communication overhead
#larger batch should help alleviate this problem

#increasing batch-size increases TTFT and TPOT 

#The formula to calculate KV cache size is
#batch_size * seqlen * (d_model/n_heads) * n_layers * 2  * 2 (bytes per Float16) * n_kv_heads



#legends

# n - number of tokens
# d - attention dimension
# h_kv - number of kv heads
# h_q  - number of q heads
# p - total number of GPUs
# l - number of layers
# p_j  - parallelism degree for parallelism strategy j (TP,PP,EP,DP)
# F_x  - Flops for x tensor or layer op
# R_x  - bytes of read for x tensor or layer op
# AI_x - Arithmetic intensity of layer x or tensor op x
# C - chunk size
# T - execution time
# T_p  - prefill latency
# T_d  - decode latency


In [6]:
import torch
import numpy as np
import math
from dataclasses import dataclass

In [20]:
##model parameters
@dataclass
class modelConfig:
    
    num_layers  : int = 120
    hidden_dim  : int = 10752
    vocab_size  : int = 100256
    mlp_factor  : int = 4
    num_layers  : int = 120
    context_len : int = 16384
    output_tokens : int = 1024
    total_experts : int = 16
    active_experts : int = 4
    fsdp_enable   : int = 1
    fsdp_factor : float = 0.15
    batch_size : int = 16
    num_microbatches : int = 4
    wt_bpe: int = 2
    act_bpe: int = 2
    kv_bpe: int = 2
    
    def baselayer_params(self):
        return 4 * self.hidden_dim * self.hidden_dim + 13 * self.hidden_dim

    def model_base(self):
        return  (4 * self.hidden_dim * self.hidden_dim + 13 * self.hidden_dim) * self.num_layers + (self.vocab_size + self.context_len) * self.hidden_dim
    
    def moe_params(self):
        return ( self.num_layers * self.total_experts) * (2 * self.mlp_factor * self.hidden_dim * self.hidden_dim)
    
    def qkv_flops(self,seqlen):
        #print(6 * seqlen * self.hidden_dim * self.hidden_dim)
        return  6 * seqlen * self.hidden_dim * self.hidden_dim
   
    def qkv_payload(self,seqlen):
        return  (seqlen*self.hidden_dim*self.act_bpe ,  3 * self.hidden_dim * self.hidden_dim * self.wt_bpe)
   
    def attn_flops(self,seqlen):
       return 2 * 2 * seqlen * seqlen * self.hidden_dim
   
    def kvcache_payload(self,seqlen):
        return (2*seqlen*self.hidden_dim*self.kv_bpe)
   
    def outproj_flops(self,seqlen):
       return 2 * seqlen * self.hidden_dim * self.hidden_dim
    #ln_flops = 8 * context_len * hidden_dim * hidden_factor * hidden_dim
    
    def outproj_payload(self,seqlen):
        return  (seqlen*self.hidden_dim*self.act_bpe ,   self.hidden_dim * self.hidden_dim * self.wt_bpe)
    
    def rope_flops(self,seqlen):
       return seqlen * self.hidden_dim  # rotary position matrix is pre-computed 
    
    def rope_payload(self,seqlen):
        return  (seqlen*self.hidden_dim*self.act_bpe)
    
    def moe_flops(self,seqlen):
    #normal  distribution of 'num_tot_experts' of 4 * context_len
        expert_capacity = self.active_experts * seqlen / self.total_experts
        return 2 * 2 * expert_capacity * self.hidden_dim * self.mlp_factor * self.hidden_dim   # 2 FFN
    
    def moe_payload(self,seqlen):
        expert_capacity = self.active_experts * seqlen / self.total_experts
        return (2*expert_capacity * self.hidden_dim*self.act_bpe, 2*self.hidden_dim * self.hidden_dim * self.wt_bpe)
    
    #unembedding_flops = 2 * context_len * hidden_dim * hidden_dim
    def layernorm_flops(self,seqlen):
        return 2 * 6 * self.hidden_dim * seqlen
    
    def layernorm_payload(self,seqlen):
        return seqlen * self.hidden_dim * self.act_bpe
    
    
    def topk_argmax(self,seqlen):
        return seqlen * self.active_experts * self.total_experts * (self.active_experts-1) * 6   # argmax , sorted max, historgram , atomic 
    
    def moe_topk_softmax(self,l2_latency,seqlen):
        ##algorithm
        
        ## gemm(seqlen,hidden_dim,total_experts)
        ## softmax(seqlen,total_experts)
        ## for topk_iter 0 to active_experts
        ##   for expert_iter 0 to num_experts//num_waves
        ##       for compare_iter 0 to topk_iter -1
        ##           read top_k[]   # l2 latency once and data kept in l1 for each iteration of topk_iter  l2 latency 
        ##           v_cmp + v_mov = 8 cycles
        ##       block_reduce(0) = 64
        ##   write out indices and topk_val #l2_latency  = 256 cycles
        ##   sync_threads()  = 64 cycles
        
        
        topk_cycles  = seqlen * self.active_experts * (l2_latency + self.total_experts  * 8 + 128 + l2_latency)   # 1 read + 1 write
        gemm_cycles  = seqlen * self.hidden_dim * 4 // l2_latency   # activation in l2  weights are amortized by 100% reuse
        softmax_cycles = seqlen * self.active_experts  * 12 // 64  # 64 ops / cycles
        return (gemm_cycles , softmax_cycles, topk_cycles)
    
    def layer_flops(self,seqlen):
        flops = self.qkv_flops(seqlen) + self.attn_flops(seqlen) + self.outproj_flops(seqlen) + self.moe_flops(seqlen)
        return flops
    
#per_layer_base = (4*hidden_dim*hidden_dim + 13*hidden_dim)
#model_base = (4*hidden_dim*hidden_dim + 13*hidden_dim)*num_layers + (VOCAB_SIZE+context_len)*hidden_dim
#moe_params = num_layers*num_tot_experts*(2*factor*hidden_dim*hidden_dim)


In [10]:
#per_layer_base = 2*(4*hidden_dim*hidden_dim ) #+ 13*hidden_dim)
#print(8*per_layer_base/1000/1000/1000)
#moe = 2*16*8*hidden_dim*hidden_dim/
#print(8*moe/1000/1000/1000)
#print(33792e3/1024/1024)

model = modelConfig()
print(model.layer_flops(1)*model.num_layers/1000/1000/1000)

166.47708672


In [77]:
@dataclass
class GpuConfig:
    fp16_flops: int  #flops/cycle/cu
    num_cus : int 
    fp32_flops : int   #valu co-execution
    fma_flops : int # flops/cycle/cu 
    trans_flops : int  
    freq : float #Ghz
    hbm_bw: float  #Gbytes/s   #stream 
    hbm_capacity: int  #Gbytes
    gpus_per_node : int 
    number_of_nodes : int 
    l2_latency : int    #in clocks
    hbm_latency : int   #in clocks
    nic_cards : int
    intracommEff : float
    intercommEff : float
    topology : str  
    intra_bw : float      #Gbytes/s  unidirectional per node 7 links/node
    inter_bw : float      #Gbytes/s  #nic(s) card bw per node 
    intra_latency : int
    inter_latency : int
    reduction_bw : dict    #import tables for variouse sizes
    gather_bw : dict
    scatter_bw : dict
    all2all_bw : dict
    
    def f16flops(self):
        #print(f"f16_flops = {self.fp16_flops * self.num_cus * self.freq}")
        return self.fp16_flops * self.num_cus
    
    def f32flops(self):
        return self.fp32_flops * self.num_cus 
    
    def transflops(self):
        return self.trans_flops * self.num_cus
    
    def total_gpus(self):
        return self.gpus_per_node * self.number_of_nodes
    
    def set_all2all(self,all2all_bwidth : dict):
        self.all2all_bw = all2all_bwidth
    
    def get_all2all(self):
        return self.all2all_bw
    
    def set_reduction(self,reduce_bwidth : dict):
        self.reduction_bw= reduce_bwidth
    
    def get_reduction(self):
        return self.reduction_bw
    
    def set_gather(self,gather_bwidth : dict):
        self.gather_bw= gather_bwidth
    
    def get_gather(self):
        return self.gather_bw
    
    def set_scatter(self,scatter_bwidth: dict):
        self.scatter_bw= scatter_bwidth
    
    def get_scatter(self):
        return self.scatter_bw
    
    
@dataclass
class shardConfig:
    dp_intra_degree : int =1
    pp_intra_degree : int =1
    tp_intra_degree : int =1
    ep_intra_degree : int =1
    #sp_intra_degree : int =1
    
    dp_inter_degree : int =1
    pp_inter_degree : int =1
    tp_inter_degree : int =1
    ep_inter_degree : int =1
    #sp_inter_degree : int =1 
    
    def tp_parallel_degree(self):
        return self.tp_intra_degree * self.tp_inter_degree
    def pp_parallel_degree(self):
        return self.pp_intra_degree * self.pp_inter_degree
    def dp_parallel_degree(self):
        return self.dp_intra_degree * self.dp_inter_degree
    def ep_parallel_degree(self):
        return self.ep_intra_degree  * self.ep_inter_degree
    #def sp_parallel_degree(self):
    #    return self.sp_intra_degree * self.sp_inter_degree
    
    def parse_shard(self,_str):
        shard_num = []
        shard_str = []
        substr_num = []
        substr_char = []
        for i in range(len(_str)):
            if (_str[i].isnumeric()):
                if(len(substr_char) == 2):
                    shard_str.append("".join(substr_char))
                    substr_char.clear()
                substr_num.append(_str[i])
            elif(_str[i].isalpha()):
                if (len(substr_num) >= 1):
                    shard_num.append(int("".join(substr_num)))
                    substr_num.clear()
                substr_char.append(_str[i])

        if (len(substr_num)>=1):
            shard_num.append(int("".join(substr_num)))
        if (len(substr_char) >=1):
            shard_str.append("".join(substr_char))
        return(shard_str,tuple(shard_num))    
    
    def reset(self):
        self.dp_intra_degree =1
        self.pp_intra_degree =1
        self.tp_intra_degree =1
        self.ep_intra_degree =1
    #sp_intra_degree : int =1
    
        self.dp_inter_degree  =1
        self.pp_inter_degree  =1
        self.tp_inter_degree  =1
        self.ep_inter_degree  =1
    
    def print_values(self):
        print(f" intra_cnfiguration ep={self.ep_intra_degree} dp={self.dp_intra_degree} tp={self.tp_intra_degree} pp={self.pp_intra_degree}")
        print(f" inter_cnfiguration ep={self.ep_inter_degree} dp={self.dp_inter_degree} tp={self.tp_inter_degree} pp={self.pp_inter_degree}")
    
    def setup_configuration(self, key: str, value: tuple):
        (shard_key,shard_value) = self.parse_shard(key)
        assert(shard_value == value)
        #print(shard_key)
        #print(shard_value)
        for idx,item in enumerate(shard_key):
            if (item == "dp"):
                (self.dp_intra_degree, self.dp_inter_degree) = (1, shard_value[idx]) if shard_value[idx] <=8 else (shard_value[idx] // 8, 8)
            elif(item == "pp"):
                (self.pp_intra_degree, self.pp_inter_degree) = (1, shard_value[idx]) if shard_value[idx] <=8 else (shard_value[idx] // 8, 8)
            elif(item == "tp"):
                (self.tp_intra_degree, self.tp_inter_degree) = (shard_value[idx],1) if shard_value[idx] <=8 else (8,shard_value[idx] // 8)
            elif(item == "ep"):
                (self.ep_intra_degree, self.ep_inter_degree) = (1, shard_value[idx]) if shard_value[idx] <=8 else (shard_value[idx] // 8, 8)
            else:
                raise ValueError(f" unsupported parallelism configuration given {item} supported = ep,tp,pp,dp")
                    
@dataclass
class layershard_configuration:
    intra_degree : tuple = (1,1,1,1,1)   #dp,pp,tp,ep,sp
    inter_degree : tuple = (1,1,1,1,1)
    activation_einsum : str = None
    weight_einsum : str = None
    
    def tp_parallel_degree(self):
        return self.intra_degree[2] * self.inter_degree[2]
    def pp_parallel_degree(self):
        return self.intra_degree[1] * self.inter_degree[1]
    def dp_parallel_degree(self):
        return self.intra_degree[0] * self.inter_degree[0]
    def ep_parallel_degree(self):
        return self.intra_degree[0] * self.inter_degree[0]
    def sp_parallel_degree(self):
        return self.intra_degree[4] * self.inter_degree[4]
    


#batchsize configurations
#batch_size = 32
#minibatch_size = batch_size // dp_parallel_degree
#microbatch_size = minibatch_size // pp_parallel_degree


#inter & intra bw  #NV
#network_card_bw = 200 ## Gbits/s
#number_cards_per_node = 1 
#network_card_latency = 100  #ns
#intra_gpu_bw = 4.2 ## tbps/s
#intra_network_latency = 100  #ns
#total_intra_gpu_bw = intra_gpu_bw * gpus_per_node
#total_inter_gpu_bw = network_card_bw * number_cards_per_node

@dataclass
class communication_config():
    hidden_dim : int
    context_len : int
    batch_size : int
    bpe: int
    inter_bw: int
    intra_bw: int
    inter_nodes: int
    intra_nodes: int
    inter_latency : int
    intra_latency : int
    topology: str = "ring"

    def payload(self):
        return self.batch_size * self.hidden_dim * self.context_len * self.bpe
    def payload_per_batch(self):
        return self.hidden_dim * self.context_len * self.bpe

@dataclass
class algoConfig():
    mode : str 
    qkv_gemm_eff : float
    attn_gemm_eff : float
    output_gemm_eff : float
    moe_gemm_eff : float
    ln_eff : float
    rope_eff : float
    topk_gating_eff : float
    gelu_eff : float

In [9]:
def gpu_capacity(model_size: int, 
                 output_tokens: int,
                 context_len : int,
                 num_layers: int,
                 hidden_dim: int,
                 dtype: str = "fp8",
                 batch_size: int = 16):

   if dtype == "fp8":
      parameter_byte = 1
   elif dtype == "fp4":
      parameter_byte = 0.5
   else:
      parameter_byte = 2

   parameter_size = parameter_byte*(model_size)
   kvcache_size = parameter_byte*batch_size * (context_len+output_tokens)*hidden_dim*num_layers*2    
   total_size = kvcache_size + parameter_size
   num_gpus = math.ceil(total_size / (8*24*1e9))
   return num_gpus


In [52]:
# Flops required for attention op  per layer

#F_a  = 4*n**2  * d * h_q 

#R_a  = 2 * d * h_kv * l * n

#Ai_a = 4 * n ** 2  * d  * h _q / 2 * d * h_kv * l * n 
#     = 2 * n * h_q / h_kv
#     = n * h_q/h_kv
     
## AI of attention layer is directly proportional to context length (n)
## for long context length inference, sharding tokens across all workers reduces AI
## when number of tokens becomes too small , attention becomes communication/memory bound.

## allocating more workers reduces latency and give better  TBT (time between token) but reduces hardware utilization

#design constraint 

In [None]:
#intra payload 

def TP_communication(N:int, latency: float, BW_LINK: float, payload: int,topology: str="ring"):
    """
     N : for intra-node calculation , intra_node_parallelism degree 
     payload : activation payload ;  payload sliced based on (inter/intra) parallelism
             : 
     BW_LINK : intra/inter link bandwidth 
             : intra : total_intra_bw // gpus_per_node
             : inter :  total_inter_bw // gpus_per_node
     latency : intra/inter link latency
    """
    
    factor = 1 
    
    if topology == "ring":
        num_steps = factor*(N-1)/N
    elif topology == "Tree":
        num_steps = factor*math.log2(N)    
    #network latency in (us)
    #latency calculation in ms
    latency_time = latency * 1e-3 * num_steps * N
    payload_time =  (num_steps * payload * 1e-9 / (BW_LINK) ) * 1e3   #ms
    print(f"payload = {payload}")
    print(f"TP payload_time {payload_time}")
    print(f"TP latency time {latency_time}")
    return latency_time + payload_time
        
def PP_communication(latency: float, BW_LINK: float, payload: int):
    """
    latency : inter/intra node latency
    BW_link  : BW per GPU (divide up total/num_gpus_per_node)
    """
    
    payload_time = 1e3*((payload * 1e-9)/BW_LINK)
    print(f"payload = {payload}")
    print(f"PP payload_time {payload_time}")
    print(f"PP latency time {latency}")
    return ((latency * 1e-3) + payload_time)

def EP_communication(payload: int, num_experts : int, total_gpus: int, 
                     gpus_per_node: int, latency_inter: int, BWIntra_LINK: int, BWInter_LINK: int):
    #experts mapped to inter and intra nodes??
    #tokens are euqally distributed to all number of expert nodes
    # probability of token mapped to intra nodes  num_gpus/num_experts
    """
     num_experts : total number of experts
     num_gpus  : gpus per node
     latency :  node -<>- node latency
     BWIntra_LINK : bw of intra links 
     BWInter_LINK : bw if inter nodes links
     payload : activation payload
     bpe: bytes per element (activation precision)
     
    """
    compute_experts = min(num_experts,total_gpus)
    
    prob_intra_ep =  gpus_per_node / compute_experts
    prob_inter_ep =  (1 - prob_intra_ep)
    communication_steps = ( compute_experts - gpus_per_node)/compute_experts
    #print(compute_experts,gpus_per_node)
    #print(f"communication_steps = {communication_steps}")
    #print(prob_intra_ep,prob_inter_ep)
    #probility of token assigned to intra-node experts = num_gpus/num_experts
    #probability of token assgined to inter-node experts = 1  - num_gpus/num_experts
    payload_intra = payload * prob_intra_ep
    payload_inter = payload * prob_inter_ep
    print(payload_inter,payload_intra)
    latency = 1e-3 * latency_inter * communication_steps * compute_experts / gpus_per_node  ## calculate actual number of GPUS actively involved in inter node communication
    payload_latency = 1e3 * communication_steps * (1e-9 * ((payload_intra / BWIntra_LINK) + (payload_inter / BWInter_LINK)))
    print(f"EP payload_time {payload_latency}")
    print(f"EP latency time {latency}")
    
    return (latency + payload_latency)

def pipeline_eff(pipeline_degree: int , num_micro_batches: int):
    fraction_idle_time = (pipeline_degree-1)/num_micro_batches
    pipeline_eff = (pipeline_degree -1) / (num_micro_batches+(pipeline_degree-1))
    return pipeline_eff,fraction_idle_time


def print_compute_str(profile_rec):
    print(f"qkv_time = {profile_rec['qkv_time']}")
    print(f"attn_time = {profile_rec['attn_time']}")
    print(f"outproj_time = {profile_rec['outproj_time']}")
    print(f"mlp_time = {profile_rec['moe_time']}")
    print(f"topk_softmax_time = {profile_rec['topk_softmax_time']}")
    print(f"rope_time = {profile_rec['rope_time']}")
    print(f"layernorm_time = {profile_rec['ln_time']}")
    

def compute_forward_pass(mode: str, seqlen: int, algoCfg: algoConfig, modelCfg : modelConfig, hwCfg: GpuConfig, layer_dict : dict, verbose: bool = False):
    if mode == "TTFT":   #prefill
        #apply qkv_flops tp 
        qkv_time = modelCfg.qkv_flops(seqlen) * 1e-9 / (algoCfg.qkv_gemm_eff * hwCfg.f16flops() ) * (1/hwCfg.freq)  * 1e3
        layer_dict["qkv_time"] = qkv_time
        attn_time = modelCfg.attn_flops(seqlen) * 1e-9 / (algoCfg.attn_gemm_eff * (hwCfg.f16flops())) * (1/hwCfg.freq) * 1e3
        layer_dict["attn_time"] = attn_time
        out_time = modelCfg.outproj_flops(seqlen) * 1e-9 / (algoCfg.output_gemm_eff * (hwCfg.f16flops())) * (1/hwCfg.freq) *  1e3
        layer_dict["outproj_time"] = out_time
        moe_time = modelCfg.moe_flops(seqlen) * 1e-9 / (algoCfg.moe_gemm_eff * hwCfg.f16flops()) *  (1/hwCfg.freq) * 1e3
        gemm_cycles, softmax_cycles , topk_cycles = modelCfg.moe_topk_softmax(hwCfg.l2_latency,seqlen)
        topk_softmax_moe= (1/hwCfg.freq) * 1e3 * ((gemm_cycles + softmax_cycles+topk_cycles) * 1e-9/(hwCfg.num_cus * 8))    ##occupancy 8
        layer_dict["moe_time"] = moe_time 
        layer_dict["topk_softmax_time"] = topk_softmax_moe
        hbm_bw_gfxclk = hwCfg.hbm_bw / hwCfg.freq 
        ln_time  = (modelCfg.layernorm_payload(seqlen) * 2) * 1e-9 / (algoCfg.ln_eff * hbm_bw_gfxclk)  * 1e3
        layer_dict["ln_time"] = ln_time
        rope_time = (modelCfg.rope_payload(seqlen)*2) * 1e-9/ (algoCfg.rope_eff * hbm_bw_gfxclk) * 1e3
        layer_dict["rope_time"] = rope_time
        compute_time= qkv_time+attn_time+out_time+moe_time+ln_time+rope_time+topk_softmax_moe
        print(f"forward pass time(ms)-prefill = {compute_time}")
        if (verbose):
            print_compute_str(layer_dict)
        return compute_time
    else:
        return 1

def fsdp_communication_overhead():
    #retune the factor based on model size.
    return 0.25

def communication_tp_forward(payload:int,
                             tp_intra_node:int, intra_latency: int, bw_intra:int,
                             tp_inter_node:int, inter_latency: int, bw_inter:int):
    ## tensor parallelism (inter and intra)
    tp_intra_time = TP_communication(tp_intra_node,intra_latency,bw_intra,payload,'ring') if tp_intra_node > 1 else 0
    tp_inter_time = TP_communication(tp_inter_node,inter_latency,bw_inter,payload,'ring') if tp_inter_node > 1 else 0
    
    return tp_inter_time + tp_intra_time

def communication_pp_forward(payload:int, latency: int, bw_inter: int , inter_nodes:int, bw_intra:int, intra_nodes:int):


    ## layer sliced across nodes so need to divide this by layers so multiplying by laers
    pp_inter_time = PP_communication(latency,bw_inter,payload) if inter_nodes > 1 else 0
    pp_intra_time = PP_communication(latency,bw_intra,payload) if intra_nodes > 1 else 0
    
    return max(pp_inter_time,pp_intra_time)


def communication_forward_pass(shardCfg: shardConfig, 
                               comCfg: communication_config, 
                               gpus_per_node:int , num_nodes:int,
                               modelCfg: modelConfig):
    
    assert(comCfg != None)     
    #fixed PP sharding config.
    
    gpus_in_config = shardCfg.dp_parallel_degree() * shardCfg.tp_parallel_degree() * shardCfg.pp_parallel_degree()
    nodes_in_config = gpus_in_config // (shardCfg.pp_intra_degree * shardCfg.tp_intra_degree * shardCfg.dp_intra_degree)
    gpus_in_intra = gpus_in_config//nodes_in_config
    
    pp_inter_parallelism = shardCfg.pp_inter_degree
    pp_intra_parallelsim = shardCfg.pp_intra_degree
    
    
    bw_inter = comCfg.inter_bw // gpus_in_intra
    bw_intra = comCfg.intra_bw 
    payload = comCfg.context_len * comCfg.hidden_dim * comCfg.bpe
    pp_communication_time =  communication_pp_forward(payload,
                                                      comCfg.inter_latency,
                                                      bw_inter,
                                                      pp_inter_parallelism,
                                                      bw_intra,
                                                      pp_intra_parallelsim)
    
    pp_communication_time = pp_communication_time//modelCfg.num_layers   # total time is multipled by layers * batch_size
    
    
    #Add per layer sharding configuration
    # QKV math split Weight[h,3h]dimension between intra & inter  h/intra, 3h/inter  
    # attention layer  BS * heads -> intra
    # output projection     
    
    bw_intra = comCfg.intra_bw
    tp_intra_time = TP_communication(shardCfg.tp_intra_degree,
                                     comCfg.intra_latency,bw_intra,payload//nodes_in_config,'ring') if shardCfg.tp_intra_degree > 1 else 0


    bw_inter = comCfg.inter_bw // (gpus_in_intra)
    tp_inter_time = TP_communication(shardCfg.tp_inter_degree,
                                     comCfg.inter_latency,bw_inter,(payload//gpus_in_intra),'ring') if  shardCfg.tp_inter_degree > 1 else 0
    
    tp_communication_time = tp_intra_time + tp_inter_time
    
    ep_communication_time = 0
    bw_intra = comCfg.intra_bw
    bw_inter = comCfg.inter_bw
    if (shardCfg.ep_inter_degree > 1 or shardCfg.ep_intra_degree > 1):
         ep_communication_time = EP_communication(payload,
                                                  modelCfg.total_experts,
                                                  gpus_in_config,
                                                  gpus_in_intra,
                                                  comCfg.inter_latency,
                                                  bw_intra,
                                                  bw_inter)
                                                  
    ## 2 tp , ep per layer
    forward_com_time = pp_communication_time + 2*tp_communication_time + 2*ep_communication_time  
    forward_com_time = (1 + fsdp_communication_overhead()) * forward_com_time
    return forward_com_time

def pipeline_bubble_time_forward(shardCfg: shardConfig,
                                 comCfg: communication_config,
                                 compute_time : float,
                                 hwCfg : GpuConfig,
                                 modelCfg : modelConfig
                                 ):
    ## n workers m micro_batches  idle slots = n-1 total slots = n+m-1  per pass
    ## bubble fraction time = n-1/ (m+n-1)
    
    ## n layers per model are equally distributed for pp_parallel_degree ; divide by num_layers for later calculations
    num_microbatches = modelCfg.num_microbatches
    #compute_time =  compute_forward_pass(seqlen,algoCfg,modelCfg,hwCfg)//(total_gpus * num_layers)  
    comtime = communication_forward_pass(shardCfg,comCfg,hwCfg.gpus_per_node,hwCfg.number_of_nodes,modelCfg)
    bubble_time = 1/num_microbatches * (shardCfg.pp_parallel_degree()-1) * (compute_time + comtime)
    return bubble_time


In [73]:
#mi300 configuration
hw_dict = {
        "mi300" : {"fp16_flops" : 2048, "fp32_flops" : 128, "fma_flops" : 256 , "freq" : 1.3 , "hbm_bw" : 4200, "num_cus" : 308, "trans_flops" : 16,
                    "hbm_capacity" : 160 , "gpus_per_node" : 8 , "number_of_nodes" : 8, "topology" : "ring", "intracommEff" : .7, "intercommEff" : .5, "l2_latency" : 320, "hbm_latency" : 640,
                    "intra_bw" : 7*64 , "nic_cards" : 8, "inter_bw" : 50 , "intra_latency" : 1 , "inter_latency" : 2 , "gather_bw" : {},
                    "scatter_bw" : {}, "all2all_bw" : {}, "reduction_bw" : {}},
 }

algo_dict = {
        
        "ttft" : {"qkv_gemm_eff" : .85 , "attn_gemm_eff" : .8 , "output_gemm_eff" : .85 , "moe_gemm_eff" : .75 , "ln_eff" : .65,
         "rope_eff" : .8 , "topk_gating_eff" : .6, "gelu_eff" : .7 },
        "tpot" : {},
}

prefillCfg = algo_dict["ttft"]

class shardConfig:
    dp_intra_degree : int =1
    pp_intra_degree : int =1
    tp_intra_degree : int =1
    ep_intra_degree : int =1
    sp_intra_degree : int =1
    
    dp_inter_degree : int =1
    pp_inter_degree : int =1
    tp_inter_degree : int =1
    ep_inter_degree : int =1
    sp_inter_degree : int =1 

tensor_parallel = [1,2,4,8,16,32,64]
pipeline_parallel = [1,2,4,8,16,32,64]
data_parallel = [1,2,4]
expert_parallel = [1,2,4,8,16]



In [72]:
#build shard connfigurations. 
#num gpus = 64

tensor_parallel = [1,2,4,8,16,32,64]
pipeline_parallel = [1,2,4,8,16,32,64]
data_parallel = [1,2,4]
expert_parallel = [1,2,4,8,16]


def return_key(inp: tuple):
    dp,pp,tp,ep = inp
    val = ()
    temp_list = list(val)
    key_str = " "
    if (tp > 1):
        key_str += "tp" + f"{tp}"
        temp_list.append(tp)
    if (pp > 1):
        key_str += "pp" + f"{pp}"
        temp_list.append(pp)
    if (dp > 1):
        key_str += "dp" + f"{dp}"
        temp_list.append(dp)
    if (ep > 1):
        key_str += "ep" + f"{ep}"
        temp_list.append(ep)
        
    val = tuple(temp_list)
    return(key_str,val)   
    
permute_list = [(dp,pp,tp,ep) for dp in tensor_parallel for pp in pipeline_parallel for tp in data_parallel for ep in expert_parallel]
#print(len(permute_list))

shard_permute = {}
num_gpus = 64
for item in permute_list:
    dp,pp,tp,ep = item
    if (tp*pp*dp*ep == num_gpus):
        (key_str,value) = return_key(item)
        shard_permute[key_str] = value

#remove one configuration from the dict
shard_cfg = {key: value for key,value in shard_permute.items() if len(value) > 1}
#print(len(shard_cfg))


In [78]:
def ttft_time(batch_size:int, seqlen:int):
    modelCfg = modelConfig()
    hwCfg = GpuConfig(**hw_dict["mi300"])
    ttftCfg = algoConfig("ttft",**prefillCfg)
    shardCfg  = shardConfig()

    ttft_dict = {}
    layercompute_dict = {"qkv_time":0, "attn_time" : 0 , "outproj_time" : 0, "moe_time" : 0, "ln_time" : 0 , "rope_time" : 0 }
    #computetime_spent = compute_forward_pass("TTFT",seqlen,ttftCfg,modelCfg,hwCfg,layercompute_dict)
    comCfg = communication_config(modelCfg.hidden_dim,
                                  seqlen,
                                  batch_size,
                                  modelCfg.wt_bpe,
                                  hwCfg.nic_cards * hwCfg.inter_bw * hwCfg.intercommEff,
                                  hwCfg.intra_bw * hwCfg.intracommEff,
                                  hwCfg.number_of_nodes,
                                  hwCfg.gpus_per_node,
                                  hwCfg.inter_latency,
                                  hwCfg.intra_latency,"ring")
    computetime_spent = compute_forward_pass("TTFT",seqlen,ttftCfg,modelCfg,hwCfg,layercompute_dict,True)
    for (key,val) in shard_cfg.items():
         shardCfg.setup_configuration(key,val)
         print(f"shard_topology : {key}")
         modelCfg.num_microbatches = shardCfg.pp_inter_degree * shardCfg.pp_intra_degree
         commtime_spent = communication_forward_pass(shardCfg,comCfg,hwCfg.gpus_per_node,hwCfg.number_of_nodes,modelCfg)
         parallelism_degree = shardCfg.pp_parallel_degree() * shardCfg.tp_parallel_degree() * shardCfg.dp_parallel_degree()
         computetime = computetime_spent//(parallelism_degree * modelCfg.num_layers)
         pipeline_time = pipeline_bubble_time_forward(shardCfg,comCfg,computetime,hwCfg,modelCfg) if (shardCfg.pp_inter_degree * shardCfg.pp_intra_degree) > 1 else 0
         time_batch_size_layer = (computetime_spent//(parallelism_degree) + commtime_spent + pipeline_time)
         print(time_batch_size_layer)
         time = batch_size * modelCfg.num_layers * time_batch_size_layer
         ttft_dict[key] = { key : shard_cfg[key], "latency" : time}
         shardCfg.reset()
         assert(0)
    return ttft_dict
    #for (key,val) in shard_cfg.items():
    #    shardCfg.setup_configuration(key,val)
    #    #shardCfg.print_values()
    #    shardCfg.reset()

In [83]:
ttft_time(16,16384)


forward pass time(ms)-prefill = 52.28390700194937
qkv_time = 16.304472562731384
attn_time = 17.59847832167832
outproj_time = 5.434824187577129
mlp_time = 12.318934825174825
topk_softmax_time = 0.019023024787712287
rope_time = 0.27262976000000005
layernorm_time = 0.33554432
shard_topology :  tp4ep16
payload = 352321536
TP payload_time 0.8426057142857145
TP latency time 0.003
4 4
communication_steps = 0.0
1.0 0.0
0.0 352321536.0
EP payload_time 0.0
EP latency time 0.0
15.114014285714287


AssertionError: 

In [None]:
def ttft_time(batch_size:int, seqlen:int):
    modelCfg = modelConfig()
    hwCfg = GpuConfig(**hw_dict["mi300"])
    ttftCfg = algoConfig("ttft",**prefillCfg)
    shardCfg  = shardConfig()