In [254]:
import os
import sys
from pathlib import Path
import numpy as np      
import math
import argparse
from enum import Enum

In [537]:
#LLM Inference roofline modeling for AMD GPU Mi300
#infernce has two passes. prefill + token generation
# time taken = prefill + token_gen
# prefill (prompt processing) = run through once forward pass for entire prompt sequence length
# decode = generate out token length for each step
# for large batch-size, both prefill and token is bound by linear layer flops
# usually inference has smaller batch-size , hence token generation phase dominated
# by bandwidth bound linear layer.

#llama_65B model configurtion  parameters
max_seq_len = 2048
model_dim = 8192
hidden_dim = 22016
num_layers = 80
num_heads = 64
out_tokens = 32
vocab_size = 32000
model_size = 65  # 1e9 Billion

def qkv_proj(seqlen:int =1):
    return 2*3*model_dim*seqlen

def out_proj(seqlen:int =1):
    return 2*model_dim*model_dim*seqlen

def mlp(seqlen:int = 1):
    return 4*model_dim*hidden_dim*seqlen

def fa(seqlen:int = 1):
    return (2*2*seqlen*seqlen*model_dim)

#forward pass flops calculation
qkv_flops = qkv_proj(max_seq_len)
out_flops = out_proj(max_seq_len)
mlp_flops = mlp(max_seq_len)
#use flash attention for  self-attention layer in both prefill and token generation   
fa_flops  = fa(max_seq_len)
#total_flops = num_layers * (qkv_proj+out_proj+mlp+fa_flops)

#required flops for 65B model 
total_flops = num_layers * (qkv_flops+out_flops+mlp_flops+fa_flops)
print("****************************************************")
print("******  MI300 benchmarks (roofline) ****************")
print("****************************************************")
#mi300 configuraiton fp16 precision
memory_bw = 5.3 * 64*128/8  # (64 bits/ channel 128 channel 5.3 bit rate)
achieved_bw = .8 * memory_bw / 1000
print(f"acheiveable memory bw (Gbytes/s) {achieved_bw}")
compute_flops = 304*2048*1.5/1000    # core frequency = 1.5Ghz
gemm_eff = 0.85
print(f"achieveable compute (tera) flops = {compute_flops}")
#acheived memory bw (Gbytes/s) 4341.76
#achieveable compute (tera) flops = 933.888

#capacity 
MI300_hbm_capacity = 192   #GBytes/GPU
#capcity required by model (weights + KV cache)
#kv cache - token generation phase require past attention KV values for every
#step of the phase (source GPT2- paper), instead regenerating KV tensors, store them 
# in cache / hbm and read back for every step.

model_capacity = 65*2   # fp16 precision
kv_cache_capacity = 2*2*model_dim*num_layers*(max_seq_len+out_tokens)/1024/1024/1024   #per max-seq-len
print(f"kv_cache capacity (Gbytes) ={kv_cache_capacity}")
#kv_cache capacity (Gbytes) =5.078125

batch_size_max = np.floor((MI300_hbm_capacity-(model_capacity+kv_cache_capacity)))
print(f" maximum batch_size we could use (per GPU) = {batch_size_max}")
#maximum batch_size we could use (per GPU) = 56.0
print("")

#using 1 GPU , we could fit 65B model and still do get decent  performance 
# with ~50 batch-size.
# there are two reasons, we want to go upto 8 GPu(s) are,
# 1. higher batch-size to make problem size as compute bound
# 2. apply (tensor,pipeline, data) parallelism to inferenece to get higher thoughput

#parallelism trade-offs
# sharding layer,weights, data across n GPUs would help us get higher througput but it also
# introduces new parameters that needs to be part of calculations namely
# communication payload between GPU(S), reduction , synchronization, gather and communication latency
# 

# pipeline parallelism
# each GPU is mapped to 'n' layers (total_layers/ pipeline rank devices) 
# each needs to send and recv inputs from neighbur devices through communication
# network. 

# (MLP) tensor parallelism
# weights of linear layer are split in the middle and mapped to GPU devices
# after the completion of linear layer, all reduction op must be performed.
# MLP layer split among 'n' devices, after second layer all devices 
# communicate hidden_dim parameters for reducation and scatter back.

# attention layer parallelism
# heads mapped to different devices, all devices communciate model_dim/N parameters


#Communication perameters
# payload (bytes xfered between devices)
num_devices = 8
payload_size = model_dim * (num_devices-1) / num_devices   #per token
communication_latency = 12   # micro-seconds 
communication_bw = 448       # GBytes/s
communication_eff = .7       # based on communication primitives micro-kernel

#Batch_size =16 (latency bound case)
# model performance is bounded by memory bw and communcation latency

#token_genearation =- 65B model batch_size=16
seq_len = 1
kvcache_read = (num_layers*model_dim*max_seq_len*2*2 / 1e9) / (num_devices *achieved_bw)

compute_time = np.round(2*model_size*1e9*seq_len / (num_devices *achieved_bw*1e12)  * 1e3,1)   # convert to ms
communication_time = np.round(2*num_layers*communication_latency / 1e3,1)      # convert to ms

#print(f"kvache_read_time {kvcache_read}")
print("Using model-size as flops (Approximate)")
print(f"compute_time {compute_time}")
print(f"communcation time {communication_time}")
#missing embedding , unembedding and actation layers 

#ignoring kvcache read time taken per token 
print(f"1 - token generation (ms): {max(compute_time,communication_time)} ")

#using flops calculator
qkv_flops = qkv_proj(seq_len)
out_flops = out_proj(seq_len)
mlp_flops = mlp(seq_len)
#use flash attention for  self-attention layer in both prefill and token generation   
fa_flops  = fa(seq_len)
total_flops = num_layers * (qkv_flops+out_flops+mlp_flops+fa_flops)
#print(num_layers*24*model_dim*model_dim)
#print(f"total_flops {total_flops}")
print("")
print("Using flops-calculator:")
compute_time = np.round((total_flops*seq_len)/(num_devices *achieved_bw*1e12)*1e3,4)
print(f"1token generation (ms) {compute_time} ")
print("")



#large_batcsize = 512 tokens
seq_len = 2048
qkv_flops = qkv_proj(seq_len)
out_flops = out_proj(seq_len)
mlp_flops = mlp(seq_len)
#use flash attention for  self-attention layer in both prefill and token generation   
fa_flops  = fa(seq_len)
total_flops = num_layers * (qkv_flops+out_flops+mlp_flops+fa_flops)
#total_flops = total_flops/(1e9)
total_flops = total_flops
print(total_flops)
total_flops = total_flops/1000/1000/1000
compute_time = np.round(2*total_flops *1e9 / (num_devices*compute_flops*1e12*gemm_eff) * 1e3,2)
communication_time = np.round((seq_len * 2 * 4 * num_layers * model_dim) / (1e9*communication_bw * communication_eff) *1e3,1)
print("Using flops-calculator:")
print(f"compute_time {np.round(compute_time,1)} (ms)")
print(f"communication time {np.round(communication_time,1)} (ms)")
print(f"{seq_len}-token generation time = {max(compute_time,communication_time)} (ms)")


****************************************************
******  MI300 benchmarks (roofline) ****************
****************************************************
acheiveable memory bw (Gbytes/s) 4.34176
achieveable compute (tera) flops = 933.888
kv_cache capacity (Gbytes) =5.078125
 maximum batch_size we could use (per GPU) = 56.0

Using model-size as flops (Approximate)
compute_time 3.7
communcation time 1.9
1 - token generation (ms): 3.7 

Using flops-calculator:
1token generation (ms) 1.9709 

151190901882880
Using flops-calculator:
compute_time 47.6 (ms)
communication time 34.2 (ms)
2048-token generation time = 47.62 (ms)


In [514]:
import math

MODEL_CLASSES = {
    "Gpt3_175B"  : {"microbatch_size": 1, "batch_size": 1, "max_seq_len": 2048, "vocab_size": 64000, "model_dim": 12288, "hidden_dim": 49152, "num_layers" : 96, "attn_heads" : 96, "out_tokens": 32},
    "llama_65B"  : {"microbatch_size": 1, "batch_size": 1, "max_seq_len": 2048,"vocab_size": 32000, "model_dim": 8192, "hidden_dim": 22016, "num_layers" : 80, "attn_heads" : 64, "out_tokens": 32},
}
system_config = {
    "mi300" : { 'hbm_cap': 192, 'lds_cap': 64,
               'num_cus': 304,
               'fp16_rate' : 2048, 'fmax' : 1.5 , 'l2_rdbw' : 16384,
               'l2_wrbw': 8192, 'hbm_bw' : 6.4*8192, 'lds_bw' : 128, 
               'fma_rate' : 256, 'comm_bw' : 448 , 'comm_eff' : .7, 'comm_latency' : 12}
}

ProbMem = dict()
ProbMem['KVCacheGpu'] = 0.5
ProbMem['WeightsGpu'] = 0.5
ProbMem['ActGpu'] = 0.5
ProbMem['KVCacheCpu'] = 0.2
ProbMem['WeightsCpu'] = 0.2
ProbMem['ActCpu'] = 0.2
ProbMem['KVCacheDisk'] = 0.3
ProbMem['WeightsDisk'] = 0.3
ProbMem['ActDisk'] = 0.3

launch_setup_eff_loss = 5 
NUM_GPUS = 8
Bpe = 2   #fp16
stream_bw_attainment = .8  #efficiency

class MODELCOMPONENTS(Enum):
    Weights = "weights"
    Activation = "activation"
    KVcache  = "KVcache"
    MOES = "Moes"
    
class MODELPHASE(Enum):
    PREFILL = 0
    TOKEN = 1    
    
class MODELPIPELINE(Enum):
    INFERENCE = 0
    TRAINING = 1


tile_sizes = [(16,16),(16,32),(32,16),(32,32),(32,64),(64,32),(64,64),(64,128),(128,64),(128,128),(256,256),(256,128)]
tilek_sizes = [1024,512,256,128,64,32]

#memory_bw formula
def membw_acheived():
  membw_rate = system_config['mi300']['hbm_bw']
  eff = 0.65  # realistic efficiency that can be achieved
  return membw_rate*eff/8

def membound_gemm(M: int, N:int, K:int, gpu_name: str = "mi300"):
    #check size
    # convert size into per CU buffer 
    # using FMA ops vs MFMA ops
    # we need ~150 CU(s) to sustain HBM BW. 
    _config = system_config[gpu_name]
    number_of_elements = N*K if M<16 else M*K
    gemm_m = N if M<16 else M
    num_tiles_per_cu = gemm_m/16/_config['num_cus']
    # do i need splitK?
    if num_tiles_per_cu <=1:
        split_k = 4
    elif num_tiles_per_cu <=2:
        split_k = 2
    else:
        split_k = 1
    number_of_elements = number_of_elements/1024/1024/1024
    time_duration = number_of_elements*2/(1e6*membw_acheived()*.85)
    return time_duration

def gemm_l2_bw(gemm_k: int, tile_size : tuple = (128,128), bpe:int =2, alu_rate: int = 2048):
    num_flops = tile_size[0]*tile_size[1]*gemm_k*2
    bytes_to_read = bpe*(tile_size[0]*gemm_k + tile_size[1]*gemm_k)
    bytes_to_write = bpe*(tile_size[0]*tile_size[1])
    l2_bw_need =  (bytes_to_read + bytes_to_write)/(num_flops//alu_rate)
    return l2_bw_need

def gemm_lds_bw(bpe: int = 2, tile_k: int =128,  tile_size : tuple = (128,128), alu_rate: int =2048):
    num_flops = tile_size[0]*tile_size[1]*tile_k*2
    #4 waves / WG ; 1 WG/CU work dimension 
    #squared tile requried 2x read-size
    bytes_to_read = 2*bpe*(tile_size[0]*tile_k + tile_size[1]*tile_k)
    bytes_to_write = bpe*(tile_size[0]*tile_size[1])
    lds_bw_need =  (bytes_to_read + bytes_to_write)/(num_flops//alu_rate)
    return lds_bw_need
 
def gemm_eff(bpe: int =2, Mtile: int = 128, Ntile: int = 128,  K: int=1024, gpu_name: str = "mi300"):
    _config = system_config[gpu_name]
    l2_bw_required = gemm_l2_bw(K,tile_size=(Mtile,Ntile),bpe=bpe,alu_rate=_config['fp16_rate'])
    lds_bw_required = gemm_lds_bw(bpe=bpe,tile_size=(Mtile,Ntile),alu_rate=_config['fp16_rate'])
    bw_per_cu = _config['l2_rdbw']/_config['num_cus']
    #print(f"gemm_eff:: lds_bw_required= {lds_bw_required}")
    #print(f"gemm_eff:: l2_bw_required= {l2_bw_required}")
    gemm_eff = min(bw_per_cu/l2_bw_required, 1)
    gemm_eff = min(gemm_eff,_config['lds_bw']/lds_bw_required)
    gemm_eff = gemm_eff - (launch_setup_eff_loss/100)
    return gemm_eff
    
def  flashattn_eff(num_wgs_cu:int = 2, bpe:int = 2, attn_dim: int = 128, qtile: int = 128, ktile: int = 128, vtile: int=128, gpu_name:str = "mi300"):
     _config = system_config[gpu_name]
     softmax_cycles = softmax_rate(qtile,ktile)
     softmax_per_cycles = qtile*ktile//softmax_cycles
     #QKgEMM eff
     #qtile is read once in persistent in register or shared mem
     #qtile load is amortized with Ktensor*qtile*2*attn_dim math
     kbytes_read = ktile*bpe*attn_dim
     num_flops = qtile*ktile*attn_dim*2
     l2bw_req = kbytes_read/(num_flops//_config['fp16_rate'])
     bw_per_cu = math.floor(_config['l2_rdbw']/_config['num_cus'])
     #print(f"bw_per_cu = {bw_per_cu} kbytes_read = {kbytes_read} l2bw_req= {l2bw_req} num_flops = {num_flops}")
     qkgemm_eff = min(bw_per_cu/l2bw_req, 1)
     qk_cycles = num_flops//_config['fp16_rate']
     qk_elements_per_cycle = qtile*ktile/(qk_cycles/qkgemm_eff)
     #print(f"qk_elements_per_cycle {qk_elements_per_cycle}")
     vbytes_read = vtile*attn_dim*bpe
     num_flops= qtile*vtile*attn_dim*2
     l2bw_req = vbytes_read/(num_flops//_config['fp16_rate'])
     pvgemm_eff = min(bw_per_cu/l2bw_req,1)
     pv_cycles = num_flops//(_config['fp16_rate'])
     #print(f"pvgemm_eff = {pvgemm_eff}")
     #print(f"qkgemm_eff = {qkgemm_eff}")
     #print(f"softmax_cycles = {softmax_cycles}")
     #print(f"qk_cycles = {qk_cycles}")
     #print(f"pv_cycles = {pv_cycles}")
     qksoftmax_eff = min(softmax_per_cycles/qk_elements_per_cycle,1)
     #print(f"qksoftmax_eff = {qksoftmax_eff}")
     if (num_wgs_cu >=2):
         fa_eff = qksoftmax_eff * 0.8   #GEMM efficiencies are about .8 for head-dim=128 
     elif num_wgs_cu >= 1:
         fa_eff = (qk_cycles+pv_cycles)/((qk_cycles/qkgemm_eff)+softmax_cycles*(2-num_wgs_cu)+(pv_cycles/pvgemm_eff))
     else:
         fa_eff = num_wgs_cu*(qk_cycles+pv_cycles)/((qk_cycles/qkgemm_eff)+softmax_cycles+(pv_cycles/pvgemm_eff)) 
     print(f"fa_eff {fa_eff}")
     return fa_eff
    
    
def softmax_rate(qtile: int = 128, ktile: int=128):
    #number of ops required for softmax 
    #max3 -> reduction(wave32 reduction) -> max,  sub(x-max) -> exp() -> rowsum -> fma   
    #print(f" qtile = {qtile} ktile= {ktile}")
    number_softmax_per_wave = qtile*ktile//4/64    #64 threads/wave 4= simds/CU
    #print(f"number_softmax_per_wave = {number_softmax_per_wave}")
    number_max3_wave =  number_softmax_per_wave//2    #max3 = max(x,max(y,z))  
    _max3_cycles = number_max3_wave * 4   #16 elements/cycle/simd  
    _max3_reduction = 48  # max3 reduction across 32 threds using ds_bpermute (latency)
    #sub(x-max) using fma 
    fma_cycles = number_softmax_per_wave * 2   # 32 ops per cycle/simd  
    exp_cycles = number_softmax_per_wave * 16  # exp2f32()  4 ops per cycle/simd
    rowsum_cycles = number_softmax_per_wave * 2  # 32 ops per cycle/simd
    rowsum_reduction = 48  # ds_bpermute
    scaling_cycles = number_softmax_per_wave/2 * 4 # 16 ops per cycle/simd #pk_mul
    total_cycles = _max3_cycles + _max3_reduction + fma_cycles + exp_cycles
    total_cycles += rowsum_cycles + rowsum_reduction + scaling_cycles
    print(f"softmax_cycles = {total_cycles}")
    return total_cycles

def ret_ktile(bpe: int=2, tile_sizes:tuple = (128,128), gpu_name: str = "mi300"):
    for _val in tilek_sizes:
        if ((tile_sizes[0]+tile_sizes[1])*bpe*_val <=  (system_config[gpu_name]['lds_cap']*1024)//2):
            return _val
        else:
            continue
    assert(0)

def return_nearestpowerof2(N: int) -> int:
    if (N  and not (N and N-1)):
        return N
    else:
        a=1
        while(pow(2,a) < N):
            a +=1
        return pow(2,a)


def find_bestfaperformance(q_len: int, k_len: int, attn_dim: int, num_heads: int , bpe: int = 2, gpu_name: str="mi300"):
    
    #2D shard per GPU for num_heads x num_query
    num_cus = system_config[gpu_name]['num_cus'] / num_heads
    tile_size= q_len / num_cus
    tile_size= return_nearestpowerof2(tile_size)
    q_tiles = q_len/tile_size
    num_fatiles = q_tiles * num_heads
    fatiles_percu = np.round(num_fatiles / system_config[gpu_name]['num_cus'],1)
    qtile_size = tile_size
    #LDS memory to hold K/V elements; Q tile is persistent in register
    #64KB LDS size 2 WG(S) holding K elements  
    kvtile_size = 64*1024 // (2*attn_dim * bpe)
    #print(tile_size,num_fatiles,fatiles_percu)
    eff = flashattn_eff(fatiles_percu,bpe,attn_dim,qtile_size,kvtile_size,kvtile_size,gpu_name)
    return eff


def find_bestgemmperformance(bpe: int =2, M: int = None, N: int =None, K:int =None, gpu_name: str = "mi300"):
    assert(K!=None)
    assert(M!=None)
    assert(N!=None)
    
    if M < 16 or N< 16:
       cycles_taken=membound_gemm(M,N,K)
       #print(f"cycles_taken = {cycles_taken}")
       time_taken = cycles_taken/(system_config[gpu_name]['fmax'] * 1e6)
       return time_taken
    _efficiency_table = []
    for tile_size in tile_sizes:
        #ktile = ret_ktile(bpe,tile_size,gpu_name)
        #print(f"tile_size[0:1] = {tile_size[0],tile_size[1]}")
        _effmetrics = {}
        _effmetrics['tile_size'] = tile_size
        _effmetrics['eff'] = gemm_eff(bpe,tile_size[0],tile_size[1],K,gpu_name)
        _effmetrics['GPU_util'] = GPU_util(gpu_name,M,N,tile_size[0],tile_size[1])
        #print(f"tile_size[0:1] = {tile_size[0],tile_size[1]}  eff = {_effmetrics['eff']} GPU_util= {_effmetrics['GPU_util']} ")
        _efficiency_table.append(_effmetrics)
    
    assert(len(_efficiency_table)>0)
        
    pick_winner = []    
    for _item in _efficiency_table:
        metrics = _item
        pick_winner.append(metrics['eff']*metrics['GPU_util'])
    
    winner_idx = np.argmax(pick_winner)
    #print(f"GEMM_winner:: {_efficiency_table[winner_idx]['eff']}")
    eff = _efficiency_table[winner_idx]['GPU_util'] * _efficiency_table[winner_idx]['eff']
    fprate =  system_config[gpu_name]['num_cus'] * system_config[gpu_name]['fp16_rate']
    cycles_taken = 2*M*N*K/(eff*fprate)
    #print(f"cycles_taken = {cycles_taken}")
    write_cyles = M*N*bpe/(.8*27*304)
    cycles_taken = write_cyles + cycles_taken
    time_taken = cycles_taken/(system_config[gpu_name]['fmax'] * 1e6)  + 0.005 
    return time_taken
    
         
    #print(f"GEMM_winner:: {_efficiency_table[winner_idx]['eff']}")
    #return _efficiency_table[winner_idx]

def GPU_util(gpu_name: str = 'mi300', gemm_m : int = 1024, gemm_n : int = 1024, mtile:int =128, ntile:int =128):
    
    num_cus = system_config[gpu_name]['num_cus']
    m_tiles = gemm_m/mtile
    n_tiles = gemm_n/ntile
    tiles_per_cu = m_tiles*n_tiles / num_cus
    tiles_per_cu = tiles_per_cu/(math.ceil(tiles_per_cu))
    return tiles_per_cu

def layerNormCycles(bs:int = 1, seq_len:int = 1, model_dim:int = 1024):
    #layernorm ops
    # normalize on last two dimension of tensor
    #mean = torch.mean(tensor[-1,:,:],dim=-1)
    #var  = torch.square(tensor[-1,:,:]-mean).mean(dim=-1)
    # layer_norm = tensor[-1,:,:] - mean / (torch.sqrt(var)

    #welford online layernorm
    #1. pow(x), 2. sum 3. shuffle-op, 4.wglevelreduction, 5. sync() 6. final LN=2 ops.
    # use_dowrdx4 4096 bytes/inst 32 cycles/inst (issue_rate)
    # grid_size / GPU [batch_size * seq_len,1,1]
    #check BW if op-limited (calculate op cycles / payload size)
    # if op-cycles < payload_size_latency then return we should able to
    # stream benchmark otherwise fraction
    # because too many ops involved in LN , memclk might be running lower than optimal frequency (ignore that for now) 
    mean_cycles = (model_dim/64)*2  # 2 ops for mean
    var_cycles =  (model_dim/64)*2  # 2 ops for variance
    
    wave_level_reduction = np.log2(64) * 4
    WG_level_reduction = 48 + 16   # write to LDS and each wave does final mean, variance reduction 
                                   # 48 cycles pipe latency 
    final_LN  = (model_dim/64)*3
    
    total_cycles = final_LN + mean_cycles + var_cycles + wave_level_reduction + WG_level_reduction
    #print(f"total_cycles_per_token_per_cu={total_cycles}")
    bw_gpu = membw_acheived()
    #bw_gpu_per_cu_clk = bw_gpu / system_config['mi300']['fmax']
    #bytes_per_cu = bw_gpu_per_cu_clk/system_config['mi300']['num_cus']
    payload_latency = (model_dim*4*bs*seq_len/(1e6 *bw_gpu)) # layer_norm in float32
    
    return payload_latency
    
    #if total_cycles < payload_latency:
    #    return bw_gpu_per_cu_clk  
    #else:
        #mmight not have enough mem instructions in pipe to achieve BW
    #    return (payload_latency/total_cycles)*bw_gpu_per_cu_clk  
    
def geluCycles(model_dim: int = 1024):
    #gelu involves with transcendental ops and cubic-root ops.
    #gelu =   0.5*tensor * (1.0f + torch.tanh(torch.sqrt(2/numpy.pi)*(x+0.044715*x*x*x)))
    # tanh = exp(2.0*x)- 1.0f/ exp(2.0*x) + 1.0f
    # tanh(x) = 1 exp() + 1 rcp + 4 fma  + 2 conditional check = 48 cycles
    # x = 5 fma 
   # 3 fma + 1 add
   # 2 mov instructions for constants
   # total =  12 fma + 1 add + 1 exp + 1 rcp + few mov and conditional check instructions
   #     100 =  48 cycles + 4 cycles + 16 cycles + 8 cycles + 8 cycles + 16 cycles 
   #     
   # approx ~100 cycles
   # mi300 f32 floats rate = 256/cycle/cu 
   #                       = 256/100 cycles = 2.5 
   # GEMM -rate (floats/cycle/cu) = MN/2*M*N*K/2048 = 2048/2*K = 1024/K  
   #                       GEMM_K must be > 1024/2.5  for GELU is not critical path for MLP1

   num_float_elements = 256  #thread-items per WG 
   num_cycles_per_256 = 100   #approximately 100 cycles for 256 elements
   return (model_dim/num_float_elements)*num_cycles_per_256

### Inference modeling of LLM for AMD GPU
Below is detailed (somewhat)  modeling of LLM for inference tasks for AMD GPUs using 2D parallelism (pipeline and tensor parallelism). for each operator of layers, I am using VALU/Matrix ops requires to implement layer kernels using GPU multi-thread programming. algorithm are little more accurate than simple roofline (fma ops).

Analysis is done for mi300 , latest from AMD. Most of the perf numbers are already out in the publications.

In [441]:
#Peak memory assertions for various system Components
# for running  large models using single GPU + CPU memory + disk CAche

def GPU_PeakMem(Phase : MODELPHASE = MODELPHASE.TOKEN.value, ModelConfig: dict = None, probMem: dict = None, batch_size: int =1, sysCfg: dict() = None) -> bool:
 
    """ Provide GPU peak memory (in elements) required for given model size & model phase"""
    
    if probMem is None:
        raise ValueError("ProbMem is None")

    dim1 = ModelConfig["hidden_dim1"]
    dim2 = ModelConfig["hidden_dim2"]
    seqlen = ModelConfig["max_seq_len"] 
    outToken = ModelConfig["out_tokens"]
    numLayers = ModelConfig["num_layers"]
    numHeads = ModelConfig["num_heads"]
    
    GpuMem_resident_layer = probMem["WeightsGpu"] * 4*pow(dim1,2) + 2*dim1*dim2  + probMem["KVCacheGpu"]*2*(seqlen + outToken)*dim1
                                    
    resident_act = ProbMem["ActGpu"]*batch_size*seqlen*dim1 if Phase == MODELPHASE.PREFILL.value else ProbMem["ActGpu"]*batch_size*1*dim1
    GpuMem_resident = GpuMem_resident_layer * numLayers + resident_act
    
    #Working memory for each op
    #QKV tensor matrices seqlen x dim1 [3] + input seqlen * dim1  
    #for Token generation seqlen = 1
    QKV = batch_size*(4*seqlen*dim1) if Phase == MODELPHASE.TOKEN.value else batch_size*(4*dim1)
    if Phase == MODELPHASE.PREFILL.value:
        #QK inputs + attention outputs
        QK_attention = ProbMem["KVCacheGpu"]*batch_size*(seqlen*dim1 + seqlen*seqlen*numHeads + seqlen*dim1)
        QKV_attention = ProbMem["KVCacheGpu"]*batch_size*(seqlen*dim1 + seqlen*seqlen*numHeads + seqlen*dim1)
    else:
        #QK inputs Q seqlen =1 K seqlen (kv cache) 
        QK_attention = ProbMem["KVCacheGpu"]*batch_size*(dim1 + 1*(seqlen+outToken)*numHeads + (seqlen+outToken)*dim1)
        QKV_attention = ProbMem["KVCacheGpu"]*batch_size*(1*dim1 + 1*(seqlen+outToken)*numHeads + (seqlen+outToken)*dim1)   
        
    Embed = batch_size*(seqlen*dim1*2) if Phase == MODELPHASE.PREFILL.value else  batch_size * (2*dim1)
    MLP1 = batch_size*(seqlen*dim1 + seqlen*dim2) if Phase == MODELPHASE.PREFILL.value else batch_size*(1*dim1 + 1*dim2)
    MLP2 = batch_size*(seqlen*(dim1+dim2)) if Phase == MODELPHASE.PREFILL.value else batch_size*(1*dim1 + 1*dim2)
    
    #Total working memory by adding xfered weights stored higher hierarchy 
    # activation stored in hierachy memory
    # max(all the ops in one layer)
    workingMem = (1 - probMem["WeightsGpu"])*(4*pow(dim1,1) + 2*dim1*dim2) +\
                (1 - probMem["ActGpu"])*batch_size*seqlen*dim1 + \
                 max(QKV,QK_attention,QKV_attention,Embed,MLP1,MLP2)
                 
    PeakMemory = workingMem + GpuMem_resident
    assert(PeakMemory < sysCfg['gpuMemCapacity'])
    return True
    

def  CPU_PeakMem(Phase : MODELPHASE = MODELPHASE.TOKEN.value, 
                 ModelConfig: dict = None, probMem: dict = None, 
                 batch_size: int =1, 
                 sysCfg: dict = None) -> bool:        
    
    """ 
    Provide CPU peak memory (in elements) required for given model size & model phase
    """
    
    if ProbMem is None:
        raise ValueError("ProbMem is None")
    
    dim1 = ModelConfig["hidden_dim1"]
    dim2 = ModelConfig["hidden_dim2"]
    seqlen = ModelConfig["max_seq_len"] 
    outToken = ModelConfig["out_tokens"]
    numLayers = ModelConfig["num_layers"]
    #numHeads = ModelConfig["num_heads"]
    
    CpuMem_resident_layer = probMem["WeightsCpu"] * (4*pow(dim1,2)) + 2*dim1*dim2  + probMem["KVCacheCpu"]*2*(seqlen + outToken)*dim1
    resident_act = ProbMem["ActCpu"]*batch_size*seqlen*dim1 if Phase == MODELPHASE.PREFILL.value else ProbMem["ActGpu"]*batch_size*1*dim1
    CpuMem_resident = CpuMem_resident_layer * numLayers + resident_act
    
    #working memory stored in disk through CPU
    #Total working memory by adding xfered weights stored higher hierarchy 
    # activation stored in hierachy memory
    # max(all the ops in one layer)
    workingMem = (1 - probMem["WeightsCpu"])*(4*pow(dim1,1) + 2*dim1*dim2) +\
                (1 - probMem["ActCpu"])*batch_size*seqlen*dim1 
    
    PeakMemory = workingMem + CpuMem_resident
    assert(PeakMemory < sysCfg['gpuMemCapacity'])
    return True
    
def  disk_PeakMem(ModelConfig: dict = None, 
                  probMem: dict = None, 
                  batch_size: int =1, 
                  sysCfg: dict = None) -> bool:
    
    if ProbMem is None:
        raise ValueError("ProbMem is None")
    dim1 = ModelConfig["hidden_dim1"]
    dim2 = ModelConfig["hidden_dim2"]
    seqlen = ModelConfig["max_seq_len"] 
    outToken = ModelConfig["out_tokens"]
    numLayers = ModelConfig["num_layers"]
    
    disk_resident_layer = probMem["WeightsDisk"] * (4*pow(dim1,2)) + 2*dim1*dim2  + probMem["KVCacheDisk"]*2*(seqlen + outToken)*dim1 
    resident_act = ProbMem["ActDisk"]*batch_size*seqlen*dim1 
    PeakMemory = disk_resident_layer * numLayers + resident_act

    assert(PeakMemory < sysCfg['gpuMemCapacity'])
    return True
    
def activationmem_per_layer(batch_size:int, seq_len: int, model_dim1: int, model_dim2: int, FA: bool= True):
    #input seqlen*batch_size*model_dim 
    #Q,K,V tensor(s) = 3*batch_size*seq_len*model_dim
    #output projection = seqlen*batch_size*model_dim
    #MLP1 = seqlen*model_dim1*
    #MLP2 = seqlen*model_dim
    #layernorm = seqlen*batch_size*model_dim 
    # max (above all)
    # IF FA == True, no QK output, no dropout,..
    input = batch_size*seq_len*model_dim1
    QKV_tensor = batch_size*seq_len*model_dim1*3
    attn_output = batch_size*seq_len*model_dim1
    MLP1_output = batch_size*seq_len*model_dim2
    MLP2_output = batch_size*seq_len*model_dim1
    layer_norm = batch_size*seq_len*model_dim1
    
    activation_mem = max(input,QKV_tensor,attn_output,MLP1_output,MLP2_output,layer_norm)
    return activation_mem

def kvcache_per_layer(Phase: MODELPHASE = MODELPHASE.PREFILL.value,
                      batch_size:int =1, 
                      outTokenSize: int=1, 
                      seq_len:int = 1, 
                      hidden_dim:int =1):
    
    if Phase == MODELPHASE.PREFILL.value:
       return 2*(seq_len*batch_size*hidden_dim)
    else:   
       return (2*(seq_len+outTokenSize)*batch_size*hidden_dim)
   
        
def model_parameters(model_name:str, component:str):
    
    if model_name in MODEL_CLASSES:
        modelDict = MODEL_CLASSES[model_name]
    else:
        raise ValueError(f" Model {model_name} not supported yet")    
    if component == MODELCOMPONENTS.Weights.value:
        ## 4*hidden_dim**2 + 2*model_dim*hidden_dim + 2*hidden_dim * num_layers + 2*(seqlen + vocab_size)*model_dim  
        num_parameters_ = (4*pow(modelDict["model_dim"],2) + 2*modelDict["model_dim"]*modelDict["hidden_dim"] + 2*modelDict["model_dim"]) *modelDict["num_layers"]
        num_parameters_ +=  (modelDict["max_seq_len"]+modelDict["vocab_size"])*modelDict["model_dim"]
        return num_parameters_
    if component == MODELCOMPONENTS.Activation.value:
        activation_ = activationmem_per_layer(batch_size=modelDict["batch_size"], seq_len = modelDict["max_seq_len"],model_dim1=modelDict["model_dim"],model_dim2=modelDict["hidden_dim"])
        return activation_
        
    if component == MODELCOMPONENTS.KVcache.value:
        kvCache_ = kvcache_per_layer(batch_size=modelDict["batch_size"], seq_len = modelDict["max_seq_len"], hidden_dim=modelDict["hidden_dim"], outTokenSize=modelDict["out_tokens"])
        return kvCache_
        
    raise ValueError(f" not supported {component} yet")

def qkv_projection_compute(modelCfg: dict, batch_size: int =1, Phase: MODELPHASE = MODELPHASE.TOKEN.value):
    return batch_size*3*2*modelCfg['max_seq_len']*modelCfg['model_dim']*modelCfg['model_dim']

def MLP_compute(modelCfg: dict, batch_size: int =1, Phase: MODELPHASE = MODELPHASE.TOKEN.value):
    return 2*2*batch_size*modelCfg['max_seq_len']*modelCfg['model_dim']*modelCfg['hidden_dim']

def out_compute(modelCfg: dict, batch_size: int =1, Phase: MODELPHASE = MODELPHASE.TOKEN.value):
    return batch_size*2*modelCfg['max_seq_len']*modelCfg['model_dim']*modelCfg['model_dim']

def flashAttention_compute(modelCfg: dict, batch_size: int =1, Phase: MODELPHASE = MODELPHASE.TOKEN.value, CMasking:bool = True):
    if Phase == MODELPHASE.PREFILL.Value:       
        FA = 2*2*batch_size*modelCfg['max_seq_len']*modelCfg['max_seq_len']*modelCfg['model_dim']
        FA = FA//2 if CMasking == True else FA
        return FA
    else:
        FA = 2*2*batch_size*modelCfg['max_seq_len']*modelCfg['model_dim']
        FA = FA//2 if CMasking == True else FA
        return FA

def GPUPeakMem(model_name:str):
    weights= model_parameters(model_name,"weights")
    activation = model_parameters(model_name,"activation")
    kvcache = model_parameters(model_name,"KVcache")
    
    peakMem = weights+activation+kvcache
    return peakMem

In [534]:
#MESH shard (x,y)
#pipeline parallelism
#split layers among N accelerators

#model parallelism (tensor parallelism)
#attention - split #heads N accelerators
#MLP - split tensors(weight) N accelerators
# num_heads = h  hd = d_model/h We
# Wq,Wk,Wv  (num_heads,d_modelxhd) 
# each layer algorithm
#  S= Q@K => P = softmax(S) => P = P @ V => O = P @ W_o   
# send d_model/N to N-1 accelerators 
# receive d_model/N from N-1 accelerators
# add d_model/N vectors from N-1 accelerators 

#MLP weights 4*d_model / N split
#MLP1 = d_model @ 4*d_model/N => 4*d_model/N @ d_model 
# send 4*d_model/N to N-1 accelerators
# receive 4*d_model/N from accelerators
# in total each layer require 4*communication ops

#batch=1 case
# mem_time (math is boud by memory BW) = bpe*model_parameters/ (N*mi300_bw)
# communication = 4*num_layers*8us (latency due to BS=1)  #8 us latency

#batch = large
# compute = batch_size*2*P/(N*mi300_fp16_flops)
# communication =  B * 2 * 4  * (N-1) (model_dim) / (N * link_bw)

#flops:mem ratio should give about batch-size 
# flops : link_bw should give about embedding dimension

#parallelism communication happens at four steps
# want to parallelize communication and compute
#           QKV                      WO                       MLP0               MLP1
#  flops    3*B*(d_model*d_model)    B*(d_model_d_model)      4*B*(d_model)       4*B*(d_model*d_model)
#  commu    B*(d_model)              b*(d_model)              4*B*d_model         4*B*d_model
#  flop/comm  3*d_model              d_model                  4*d_model           4*d_model

# embedding dimenssion should be > flops:link BW  ratio of given configuration
# chooese batch_size that is optimal for compute-bound to optimze per request latency

#pipeline parallelism
#batchsize=1 case

#worker0   |       |        |         |
#          |  B    |        |         |
#          |       |        |         |
#-------------------------------------
#worker1   |       |        |         |
#          |       |   B    |         |
#          |       |        |         |
#--------------------------------------
#worker2   |       |        |         |
#          |       |        |   B     |
#          |       |        |         |

#N workers 1 Batch
#total slots = N*N
#idle slots  = N*(N-1)

#idle time = (N-1)/(N)
# utilization = 1- idle time

#worker0   |     |    |    |    |
#          | B0  | B1 |    |    |
#          |     |    |    |    |
#-------------------------------------
#worker1   |     |    |    |    |         
#          |     | B0 | B1 |    |         
#          |     |    |    |    |         
#--------------------------------------
#worker2   |     |    |    |    |
#          |     |    | B0 | B1 |
#          |     |    |    |    |

#K minibatches N workers 
# N+K-1 steps per worker
# N*(N+K-1) total steps
# N* (N-1) bubbles

#idle bubble time = (N-1) / (N+K-1)  

#L layers per worker

# idle bubble time = (N-1)/(L *(N+K-1))


#flops of the model = 24*n_layers * d_model*d_model
#attention layer flops
#layer norm for every attention layer
#de-embedding layer(s)
# activations , biases and dropouts

#list of operators Encoder
# Q,K,V
# input Bias
# QKT
# scaled-softmax
# out 
#output bias
# dropout
# residual add
# layerNorm
# MLP0
# MLP0 Bias
# GELU 
# MLP1 
# MLP1 bias
# Residual
# layerNorm

#add up all memory-bound ops / elements-ops vs GEMM ops
# 'other ops' : 'mem-ops'
# model size increases by x -times
# other_ops * x : x^2 * gemm_ops

# % other-ops latency of new model = other_ops(latency) * x / (x^2 * gemm_ops)
# if we get same gemm(S) latency as before then = other_ops(latency) / x
#actvcation writes out

#add communication latency
#batch-size = flops/bw 

#pipeline parallelism
#M= micro batches
#P= number of devices
#tf = forwardtime
#tb = backward time
# idle time efficiency = p * (p-1) / (m + p)

#ideal time = m *(tf+tb)
# buble time = (p-1)*(tf+tb) 
#buble time fraction =(p-1) * (tf+tb) / (m) * (tf+tb)
# more stages per GPU (v stages) 
# tf/v , tb/v
# bubble time reduction = p-1 * (tf+tb) / v
#tensor model parallelism require all-reduce reduction so use GPU(S) within 
# GPU server node pipeline parallelism mapped to inter GPU nodes

#communication cost per layer for tensor/model parallelism
# number of parameters = d
# number of devices    = p
# network-latency =  alpha
# network Bandwidth = B
# mini batch-size = b
#communcation cost/layer(all-gather) = alpha * log(P) + b*(P=1)*d / (p*B) 
#             backward =  2* (alpha*log(P) + b*(P-1)*/p*B)
#             backward (all-reduct)

# batch (data parallelism)
# using ring algo for all-reduce (backward) batch -parallel 
# each device does partial sum followed by all-redce 
# T (commu-batch)/ layer = 2*(alpha*log[P] + (P-1)/(P*B) * [Weights of the layer])

#convolution = Yo * Xi * kh * Kw convolutions production output Yo * Yh *Yw
# W = kh*kw*Xc*Yc
# di = Yc*Xh*Xw

# ratio of communcation volume = batch/model = 2*[Wi]/3*B*(di) 



# ----------------
# |0 | 1 | 2 | 3 |
# |4 | 5 | 6 | 7 | 
# ----------------

#pipeline parallelism = X dimension = 4 GPU(S)  layers_per_rank = layers/num_ranks  
#tensor parallelism    = Y dimension = 2 GPU(s)  tensors_per_rank  = hidden_dim/num_ranks model_dim/num_ranks  heads_per_rank = attn_heads/num_ranks
# data_parallelism    = X dimension = 4 GPU(s)  batchsize_per_rank = batch_size/num_ranks


#B = batchsize
#N = number of heads
#M = model-dimension
#S = sequence_len
#H = hidden_dimension
embedding_shard_BSM = {'batch_size' : 'X', 'seq_len' : '_', 'model_dim': "_'"}
activation_shard_BSM = {'batch_size'  : 'X', 'seq_len' : '_', 'model_dim' : 'Y'}
activation_shard_BSND = {'batch_size' : 'X', 'seq_len': '_', 'attn_heads': 'Y', 'attn_dim' : '_'}
activation_shard_BSH = {'batch_size'  : 'X', 'seq_len' : '_', 'hidden_dim' : 'Y'}
activation_shard_BNSS = {'batch_size' : 'X', 'attn_heads' : 'Y', 'max_seq_len' : '_'}

from dataclasses import dataclass
#first key is summation dimension second key is free0/free1 dimension
weights_shard_MM  = {'model_dim'  : '_', 'model_dim1' : '_'}
weights_shard_MND = {'model_dim'  : '_', 'attn_heads'  : 'Y', 'attn_dim' : '_'} 
weights_shard_NDM = {'attn_heads' : 'Y', 'attn_dim'  : '_', 'model_dim'  : 'X'}
weights_shard_MH  = {'model_dim'  : 'X', 'hidden_dim' : 'Y'}
weights_shard_HM  = {'model_dim'  : 'X', 'hidden_dim' : 'Y'}
embedding_weight =  {'vocab_size'  : '_', 'model_dim' : '_'}

@dataclass
class transformer_2dshard:
    embedding_activations = {'input0' : embedding_shard_BSM, 'output' : embedding_shard_BSM}
    qkv_activations = {'input0' : activation_shard_BSM , 'input1' : weights_shard_MM, 'output' : activation_shard_BSND}    # weights M=8192  N=64 D = 128   activation : B=1x2048x8192  (1x2048x128, 64) 
    flash_attention_qk = {'input0' : activation_shard_BSND , 'input1' : activation_shard_BSND, 'output' : activation_shard_BNSS}
    flash_attention_sv = {'input0' : activation_shard_BSND , 'input1' : activation_shard_BNSS, 'output' : activation_shard_BSND}
    out_projection =  {'input0' : activation_shard_BSND , 'input1' :  weights_shard_NDM , 'output' : activation_shard_BSM}
    #FIXME sequence_len sharding 
    layer_norm = {'input' : activation_shard_BSM , 'output' : activation_shard_BSM}
    mlp_0 = {'input0': activation_shard_BSM, 'input1' : weights_shard_MH, 'output' : activation_shard_BSH}
    gelu  = {'input': activation_shard_BSH, 'output' : activation_shard_BSH}
    mlp_1 = {'input1': activation_shard_BSH, 'input1' : weights_shard_HM, 'output' : activation_shard_BSM}
    
    def __init__(self,model_name:str = "llama_65B",
                 bpe : int = 2, 
                 num_gpus : int =8,  
                 sysCfg: system_config = None,
                 tileSizes: list[tuple] = None):
        
        self.bpe = bpe
        if model_name in MODEL_CLASSES.keys():
            self.modelParameters = MODEL_CLASSES[model_name]  
            self.model_name = model_name
        else:
            raise ValueError("Unsupported model")
        self.numGPUs = num_gpus
        
        #3D sharding for n GPU(s)
        self.shard_topology = ("pipeline","tensor")  
        self.devices = np.arange(num_gpus)
        self.mesh = np.array(self.devices).reshape(-1,2)
        
        self.systemCfg = sysCfg
        self.num_cus = self.systemCfg['mi300']['num_cus']
        self.fp16rate = self.num_cus*self.systemCfg['mi300']['fp16_rate']
        self.num_layers = self.modelParameters['num_layers']

        
        self.tileSizeChoices = tileSizes
        self.max_seq_len = self.modelParameters['max_seq_len']
        
        #support 2D mesh for sharding 
        self.pipeline_num_ranks = self.mesh.shape[0]
        self.layers_per_rank  = self.modelParameters['num_layers'] // self.pipeline_num_ranks
        self.num_microbatches = self.modelParameters['batch_size']//self.modelParameters['microbatch_size']
            
        self.tensor_num_ranks = self.mesh.shape[1]
        
        

    def peakmem_check(self):
        memUsage = GPUPeakMem(self.model_name)//1e9
        #pipeline parallelism takes batch_size dimension
        memCapacity = self.systemCfg['mi300']['hbm_cap']
        memUsage = memUsage//len(self.mesh[1])
        assert(memUsage <= memCapacity)
        
    def recv_payloadsize(self):
        return self.modelParameters['batch_size'] * self.modelParameters['model_dim'] * self.modelParameters['max_seq_len']
    
    def layer_ops(self,seq_len: int):
        
        #qkv_projections
        act_tensor = {}
        wt_tensor = {}
        
        for key, _ in self.qkv_activations['input0'].items():
            if key == 'batch_size':
                act_tensor['batch_size'] = self.modelParameters['microbatch_size']
            elif key == 'seq_len':
                act_tensor['M'] = seq_len
            elif key == 'model_dim':
                act_tensor['K'] = self.modelParameters['model_dim'] 
            else:
                raise ValueError(f"unknown key in activation_shard_BSM tensor")
                
        for key, value in self.qkv_activations['input1'].items():
            if key == 'model_dim':
                wt_tensor['K'] = self.modelParameters['model_dim']
            elif key == 'model_dim1':
                wt_tensor['N'] = self.modelParameters['model_dim']
                if value != '_':
                    wt_tensor['N'] = wt_tensor['N']//self.tensor_num_ranks
            else:
                raise ValueError(f"unknown key in activation_shard_BSM tensor")        
        #print(f"fp_rate {self.fp16rate}")
        print(f"QKV_Projection:: GEMM problem sizes M={act_tensor['M']} N= {3*wt_tensor['N']} K= {act_tensor['K']}")
        #returns efficiency
        #efficiency = find_bestgemmperformance(self.bpe,act_tensor['M'],3*wt_tensor['N'],act_tensor['K'],"mi300")
        qkv_duration = np.round(find_bestgemmperformance(self.bpe,act_tensor['M'],3*wt_tensor['N'],act_tensor['K'],"mi300") / 2,4)
        #print(f"qkvgemm_eff : {efficiency['eff']} utilization={efficiency['GPU_util']}")
        #qkv_cycles = ((2*act_tensor['M'])*3*wt_tensor['N']*act_tensor['K'])/(efficiency['GPU_util']*efficiency['eff']*self.fp16rate)
        #qkv_cycles = qkv_cycles/self.tensor_num_ranks/1500/1000 
        print(f"QKV_Projection:: time_duration(ms): [{qkv_duration}]")
        #flash attention 
        #calculate num_heads per GPU
        #case 1:  split heads in 'Y' dimension (tensor parallelism)
        q_tensor = {}
        k_tensor = {}
        v_tensor = {}

        for key, value in self.flash_attention_qk['input0'].items():
            if key == 'batch_size':
                q_tensor['batch_size'] = self.modelParameters['microbatch_size']
            elif key == 'seq_len':
                q_tensor['M'] = seq_len
            elif key == 'attn_dim':
                q_tensor['K'] = self.modelParameters['model_dim'] // self.modelParameters['attn_heads']
            elif key == 'attn_heads':
                heads_num_ranks = self.tensor_num_ranks
                q_tensor['h']  = self.modelParameters['attn_heads']
            else:
                raise ValueError(f"unknown key in activation_shard_BSND tensor")
        
        for key, _ in self.flash_attention_qk['input1'].items():
            if key == 'batch_size':
                k_tensor['batch_size'] = self.modelParameters['microbatch_size']
                v_tensor['batch_size'] = self.modelParameters['microbatch_size']
            elif key == 'seq_len':
                k_tensor['N'] = seq_len
                v_tensor['N'] = seq_len
            elif key == 'attn_dim':
                k_tensor['K'] = self.modelParameters['model_dim'] // self.modelParameters['attn_heads']
                v_tensor['K'] = self.modelParameters['model_dim'] // self.modelParameters['attn_heads']
            elif key == 'attn_heads':
                k_tensor['h']  = self.modelParameters['attn_heads']
                v_tensor['h']  = self.modelParameters['attn_heads']
            else:
                raise ValueError(f"unknown key in activation_shard_BSND tensor")
                

        #shard heads in 'Y' dimension     
        q_tensor['h'] = q_tensor['h']//heads_num_ranks
        print(f"FA problem sizes : q_len = {q_tensor['M']} k_len = {k_tensor['N']}, attn_dim={q_tensor['K']}, num_heads={q_tensor['h']}")
        fa_eff = find_bestfaperformance(q_tensor['M'],k_tensor['N'],q_tensor['K'],q_tensor['h'],self.bpe,"mi300")
        fa_flops = 2*(q_tensor['batch_size']*q_tensor['h']*(q_tensor['M']*q_tensor['K']*k_tensor['N'] + v_tensor['K']*q_tensor['M']*v_tensor['N']))
        fa_cycles = np.round(fa_flops/(self.fp16rate*fa_eff),4)
        fa_duration = np.round(fa_cycles/(self.systemCfg['mi300']['fmax']*1e6),4)
        print(f"FlashAttention::  seqlen={seq_len} attn_heads= {self.modelParameters['attn_heads']} attn_dim= {q_tensor['K']}")
        print(f"flash attention:: duration(ms)=[{fa_duration}]") 
        #attention result projection output to [B,S,M]
        #input [B,S,N,D] x [N,D,M] -> [B,S,M]
        #shard on model-dimension
        #gather for reduction
        o_tensor = {}
        wo_tensor = {}
        
        for key, value in self.out_projection['input0'].items():
            if key == 'batch_size':
                o_tensor['batch_size'] = self.modelParameters['microbatch_size']
            elif key == 'seq_len':
                o_tensor['M'] = seq_len
            elif key == 'attn_dim':
                o_tensor['K'] = self.modelParameters['model_dim']
            elif key == 'attn_heads':
                heads_num_ranks = self.tensor_num_ranks
            else:
                print(key)
                raise ValueError(f"unknown key in activation_shard_BSND tensor")
            
        for key, value in self.out_projection['input1'].items():
            if key == 'model_dim':
                wo_tensor['N'] = self.modelParameters['model_dim'] // self.tensor_num_ranks
            elif key == 'attn_dim':
                wo_tensor['K'] = self.modelParameters['model_dim']
            elif key == 'attn_heads':
                heads_num_ranks = self.tensor_num_ranks
            else:
                raise ValueError(f"unknown key in weights_shard_NDM tensor")    
        
        #reshape BSND into BSM (gather primitive)    
        print(f"Out_Projection:: GEMM problem sizes M={o_tensor['M']} N= {wo_tensor['N']} K= {o_tensor['K']}")
        #efficiency = find_bestgemmperformance(self.bpe,o_tensor['M'],wo_tensor['N'],o_tensor['K'],"mi300")
        #print(efficiency['eff'])
        #print(efficiency['GPU_util'])
        #projout_cycles = (2*o_tensor['batch_size']*o_tensor['M']*wo_tensor['N']*o_tensor['K'])/(efficiency['GPU_util']*efficiency['eff']*self.fp16rate) 
        #projout_cycles = projout_cycles/1500/1000
        #print(f"out_projection cycles = {projout_cycles}")
        
        out_duration = np.round(find_bestgemmperformance(self.bpe,o_tensor['M'],wo_tensor['N'],o_tensor['K'],"mi300"),4)
        print(f"out_projection duration(ms) = {out_duration}")
        
        payload_size = (self.modelParameters['microbatch_size'] * seq_len)
        ln_duration = np.round(2*layerNormCycles(self.modelParameters['microbatch_size'],seq_len,self.modelParameters['model_dim']),4)
        print(f"ln- duration(ms) = {ln_duration}")
        bias_duration = payload_size/(self.num_cus*128)/(self.systemCfg['mi300']['fmax']*1e6)   #rough estimation
        #print(f"bias_duration = {bias_duration}")
        act_tensor = {}
        wt_tensor = {}
        for key, value in self.mlp_0['input0'].items():
            if key == 'batch_size':
                act_tensor['batch_size'] = self.modelParameters['microbatch_size']
            elif key == 'seq_len':
                act_tensor['M'] = seq_len
            elif key == 'model_dim':
                act_tensor['K'] = self.modelParameters['model_dim']
            else:
                raise ValueError(f"unknown key in activation_shard_BSM tensor")
            
        for key, value in self.mlp_0['input1'].items():
            if key == 'model_dim':
                wt_tensor['N'] = self.modelParameters['hidden_dim']
                self.tensor_num_ranks  = 1  if value == '_' else self.tensor_num_ranks
            elif key == 'hidden_dim':
                wt_tensor['K'] = self.modelParameters['model_dim'] 
            else:
                raise ValueError(f"unknown key in weights_shard_MH tensor")    
         
         
        print(f"MLP:: GEMM problem sizes M={act_tensor['M']} N= {wt_tensor['N']} K= {wt_tensor['K']}")    
        wt_tensor['N'] = wt_tensor['N'] // self.mesh.shape[1]
        #efficiency = find_bestgemmperformance(self.bpe,act_tensor['M'],wt_tensor['N'],wt_tensor['K'],"mi300")
        #mlp_cycles = 2*(2*act_tensor['M']*wt_tensor['N']*act_tensor['K'])//(efficiency['GPU_util']*efficiency['eff']*self.fp16rate)
        mlp_duration = np.round(find_bestgemmperformance(self.bpe,act_tensor['M'],wt_tensor['N'],wt_tensor['K'],"mi300"),4)
        #mlp_cycles = mlp_cycles/1500/1000
        print(f"MLP[0-1] duration(ms) :: {mlp_duration}")
        #add reduction 
        comm_bw_acheived = self.systemCfg['mi300']['comm_bw']*self.systemCfg['mi300']['comm_eff']
        comm_bw_per_cycle = comm_bw_acheived / self.systemCfg['mi300']['fmax']                                                                     
        
        #print(f"comm_bw_per_cycle = {comm_bw_per_cycle}")
        
        #gather cycles
        #communication _time
        latency_time = np.round(self.systemCfg['mi300']['comm_latency'] * math.log2(self.tensor_num_ranks) / 1000,4)
        #gather cycles
        gather_duration = np.round(((self.tensor_num_ranks-1)/(self.tensor_num_ranks))*self.layers_per_rank*self.modelParameters['model_dim']*self.bpe / (1e6*comm_bw_acheived),4) + latency_time
        #reduction cycles
        #scatter cycles (multiplied by layers per device)
        reduce_scatter_duration = np.round(act_tensor['M']*self.modelParameters['hidden_dim']*self.bpe / (1e6*comm_bw_acheived),4)
        print(f"reduce_scatter ={reduce_scatter_duration}")
        #bw_per_cycle = (membw_acheived()/self.systemCfg['mi300']['fmax'])
        #reduction by reading 
        reduce_scatter_reduction = 2*reduce_scatter_duration + (3*self.tensor_num_ranks*act_tensor['M']*self.modelParameters['hidden_dim']*self.bpe / (1e9*membw_acheived()) *1e3)  +  latency_time
        reduce_scatter_reduction = np.round(reduce_scatter_reduction,4)
        gelu_cycles = geluCycles(act_tensor['M']*wt_tensor['N']/self.num_cus)
        gelu_duration = np.round(gelu_cycles/(self.systemCfg['mi300']['fmax']*1e6),4)
        print(f"gelu duration(ms) = {gelu_duration}")
        print(f"reduce_scatter_reduction (ms)= {reduce_scatter_reduction}")
        total_time = qkv_duration + fa_duration  + 2*ln_duration + out_duration
        total_time += mlp_duration + gelu_duration + reduce_scatter_reduction + gather_duration
            
        print(f" total_time per layer = {total_time}")
        
        return total_time
        
    def forward_phase(self):
        #prefill phase
        print("\n")
        print("######################################")
        print("****** Prefill Phase *****************")
        print("######################################")
        prefill_time = self.layer_ops(self.max_seq_len) * (self.num_layers/self.pipeline_num_ranks)
        print(f"prefill (ms) = {prefill_time}")
        print("\n")
        print("######################################")
        print("****** decode Phase *****************")
        print("######################################")
        token_gen_time = self.layer_ops(16) * (self.num_layers/self.pipeline_num_ranks)
        print(f"token_gen= {token_gen_time}")
        total_time = np.round((prefill_time + token_gen_time),4)
        print("######################################")
        print("****** Total Time   *****************")
        print("######################################")
        print(f"total_time (ms) = {total_time}")
        #tokens_per_msecond = total_cycles/(self.systemCfg['mi300']['fmax'] * 1000)/1000
        #print(tokens_per_msecond)
        
    def pipeline_eff(self):
        #pipeline efficiency
        pipeline_idle_fraction = (self.pipeline_num_ranks-1)/(self.pipeline_num_ranks+self.num_microbatches-1)
        #layers per device
        pipeline_idle = pipeline_idle_fraction/self.layers_per_rank
        return pipeline_idle
        


In [536]:
llm2 = transformer_2dshard(sysCfg=system_config,tileSizes=tile_sizes)
time = llm2.forward_phase()




######################################
****** Prefill Phase *****************
######################################
QKV_Projection:: GEMM problem sizes M=2048 N= 24576 K= 8192
QKV_Projection:: time_duration(ms): [0.4905]
FA problem sizes : q_len = 2048 k_len = 2048, attn_dim=128, num_heads=32
softmax_cycles = 3168.0
fa_eff 0.5769014084507043
FlashAttention::  seqlen=2048 attn_heads= 64 attn_dim= 128
flash attention:: duration(ms)=[0.1276]
Out_Projection:: GEMM problem sizes M=2048 N= 4096 K= 8192
out_projection duration(ms) = 0.1907
ln- duration(ms) = 0.0315
MLP:: GEMM problem sizes M=2048 N= 22016 K= 8192
MLP[0-1] duration(ms) :: 0.4695
reduce_scatter =0.2876
gelu duration(ms) = 0.0193
reduce_scatter_reduction (ms)= 0.7142
 total_time per layer = 2.0873
prefill (ms) = 41.745999999999995


######################################
****** decode Phase *****************
######################################
QKV_Projection:: GEMM problem sizes M=16 N= 24576 K= 8192
QKV_Projection:: time_

In [None]:
import numpy as np


def welford_update(count, mean, M2, currValue):
    count += 1
    delta = currValue - mean
    mean += delta / count
    delta2 = currValue - mean
    M2 += delta * delta2
    return (count, mean, M2)


def naive_update(sum, sum_square, currValue):
    sum = sum + currValue
    sum_square = sum_square + currValue * currValue
    return (sum, sum_square)


x_arr = np.random.randn(100000).astype(np.float32)

welford_mean = 0
welford_m2 = 0
welford_count = 0
for i in range(len(x_arr)):
    new_val = x_arr[i]
    welford_count, welford_mean, welford_m2 = welford_update(welford_count, welford_mean, welford_m2, new_val)
print("Welford mean: ", welford_mean)
print("Welford var: ", welford_m2 / welford_count)

naive_sum = 0
naive_sum_square = 0
for i in range(len(x_arr)):
    new_val = x_arr[i]
    naive_sum, naive_sum_square = naive_update(naive_sum, naive_sum_square, new_val)
naive_mean = naive_sum / len(x_arr)
naive_var = naive_sum_square/ len(x_arr) - naive_mean*naive_mean
print("Naive mean: ", naive_mean)
print("Naive var: ", naive_var)