In [1]:
import os
import sys
from pathlib import Path
import numpy as np      
import math
import argparse
from enum import Enum

In [24]:
# Model shape presets, keyed by model name.
MODEL_CLASSES = {
    "Gpt3_175B": {
        "microbatch_size": 1, "batch_size": 1, "max_seq_len": 2048, "vocab_size": 64000,
        "model_dim": 12288, "hidden_dim": 49152, "num_layers": 96, "attn_heads": 96, "out_tokens": 32,
    },
    "llama_65B": {
        "microbatch_size": 1, "batch_size": 1, "max_seq_len": 2048, "vocab_size": 32000,
        "model_dim": 8192, "hidden_dim": 22016, "num_layers": 80, "attn_heads": 64, "out_tokens": 32,
    },
}

# Per-GPU hardware description.
# NOTE(review): units are not stated in the source; 'lds_cap' is treated as KB
# (multiplied by 1024 by the LDS-fitting code), 'hbm_cap' presumably GB — confirm.
system_config = {
    "mi300": {
        'hbm_cap': 192, 'lds_cap': 64,
        'num_cus': 304,
        # The model functions read 'fp16_rate'; 'f16_rate' is kept for any
        # existing code that uses the original spelling.
        'f16_rate': 2048, 'fp16_rate': 2048,
        'fmax': 1.5, 'l2_rdbw': 16384,
        'l2_wrbw': 8192, 'hbm_bw': 4403, 'lds_bw': 128, 'fma_rate': 256,
    }
}
launch_setup_eff_loss = 5   # percent, subtracted from GEMM efficiency
NUM_GPUS = 8
Bpe = 2   # bytes per element (fp16)
stream_bw_attainment = .8  # efficiency

class MODELCOMPONENTS(Enum):
    """Memory components; model_parameters() compares against their .value.

    MOES is declared but model_parameters() does not handle it yet.
    """

    Weights = "weights"
    Activation = "activation"
    KVcache = "KVcache"
    MOES = "Moes"


class MODELPHASE(Enum):
    """Inference phase: PREFILL covers the prompt, TOKEN a generation step."""

    PREFILL = 0
    TOKEN = 1


class MODELPIPELINE(Enum):
    """Kind of workload being modelled."""

    INFERENCE = 0
    TRAINING = 1

# Fraction of each memory component placed at each level of the hierarchy.
# For every component the Gpu/Cpu/Disk fractions sum to 1.0 (0.5 + 0.2 + 0.3).
ProbMem = {
    # GPU memory
    'KVCacheGpu': 0.5,
    'WeightsGpu': 0.5,
    'ActGpu': 0.5,
    # CPU memory
    'KVCacheCpu': 0.2,
    'WeightsCpu': 0.2,
    'ActCpu': 0.2,
    # disk
    'KVCacheDisk': 0.3,
    'WeightsDisk': 0.3,
    'ActDisk': 0.3,
}

# H100 reference numbers
# NVLink = 990 GB/s; InfiniBand = 448 GB/s (75% achievable target)
# capacity = 80 GB
# bf16 = 1979 TFLOPS / 2 (dense)
# memory BW = 3.3 TB/s

# Candidate (M, N) output-tile shapes tried by the GEMM model.
tile_sizes = [(64, 64), (128, 128), (256, 256), (256, 128)]
# Candidate K-tile depths, largest first: 2**10 down to 2**5.
tilek_sizes = [1 << shift for shift in range(10, 4, -1)]



def gemm_l2_bw(gemm_k: int, tile_size : tuple = (128,128), bpe:int =2, alu_rate: int = 2048):
    """L2 bandwidth (bytes per cycle) needed to keep one GEMM output tile busy.

    An output tile of ``tile_size[0] x tile_size[1]`` over reduction depth
    ``gemm_k`` reads both input panels once and writes the output tile once.
    It performs ``2 * tm * tn * gemm_k`` FLOPs at ``alu_rate`` FLOPs per cycle.

    Args:
        gemm_k: GEMM reduction length K.
        tile_size: (M, N) output-tile shape.
        bpe: bytes per element.
        alu_rate: FLOPs per cycle available to the tile.

    Returns:
        Required bytes per cycle (float).

    Raises:
        ValueError: if ``gemm_k`` or a tile dimension is not positive.
    """
    tm, tn = tile_size[0], tile_size[1]
    if gemm_k <= 0 or tm <= 0 or tn <= 0:
        raise ValueError("gemm_k and tile dimensions must be positive")
    num_flops = tm * tn * gemm_k * 2
    bytes_to_read = bpe * (tm * gemm_k + tn * gemm_k)
    bytes_to_write = bpe * (tm * tn)
    # True division. Floor-dividing the cycle count gave 0 cycles for small
    # tiles (ZeroDivisionError) and truncated the bandwidth estimate otherwise.
    compute_cycles = num_flops / alu_rate
    return (bytes_to_read + bytes_to_write) / compute_cycles

def gemm_lds_bw(bpe: int = 2, tile_k: int =128,  tile_size : tuple = (128,128), alu_rate: int =2048):
    """LDS bandwidth (bytes per cycle) needed for one K-step of a GEMM output tile.

    Modelled as 4 waves per workgroup and 1 workgroup per CU. The input panels
    are charged twice (the original note: a square tile needs 2x the read
    size), and the output tile once.

    Args:
        bpe: bytes per element.
        tile_k: depth of the K tile staged in LDS.
        tile_size: (M, N) output-tile shape.
        alu_rate: FLOPs per cycle available to the tile.

    Returns:
        Required bytes per cycle (float).

    Raises:
        ValueError: if ``tile_k`` or a tile dimension is not positive.
    """
    tm, tn = tile_size[0], tile_size[1]
    if tile_k <= 0 or tm <= 0 or tn <= 0:
        raise ValueError("tile_k and tile dimensions must be positive")
    num_flops = tm * tn * tile_k * 2
    bytes_to_read = 2 * bpe * (tm * tile_k + tn * tile_k)
    bytes_to_write = bpe * (tm * tn)
    # True division avoids a zero cycle count for small tiles and truncation.
    compute_cycles = num_flops / alu_rate
    return (bytes_to_read + bytes_to_write) / compute_cycles
 
def gemm_eff(bpe: int =2, Mtile: int = 128, Ntile: int = 128, Ktile:int = 128, K: int=1024, gpu_name: str = "mi300"):
    """Modelled efficiency (fraction of peak) of a GEMM using the given tiling.

    Efficiency is limited by whichever bandwidth runs short:
      * L2: per-CU share of ``l2_rdbw`` versus ``gemm_l2_bw`` for the tile.
      * LDS: ``lds_bw`` versus ``gemm_lds_bw`` for a ``Ktile``-deep tile.
    Both ratios are capped at 1. ``launch_setup_eff_loss`` percent is then
    subtracted as fixed launch overhead.

    Args:
        bpe: bytes per element.
        Mtile, Ntile: output-tile shape.
        Ktile: K-tile depth staged in LDS.
        K: full GEMM reduction length.
        gpu_name: key into ``system_config``.

    Returns:
        Efficiency as a float.
    """
    _config = system_config[gpu_name]
    # The config may spell this key 'f16_rate'; accept either spelling.
    alu_rate = _config['fp16_rate'] if 'fp16_rate' in _config else _config['f16_rate']
    tile = (Mtile, Ntile)  # tuple(Mtile, Ntile) raised TypeError
    l2_bw_required = gemm_l2_bw(K, tile_size=tile, bpe=bpe, alu_rate=alu_rate)
    # Pass the K-tile depth so the LDS model matches the chosen tiling.
    lds_bw_required = gemm_lds_bw(bpe=bpe, tile_k=Ktile, tile_size=tile, alu_rate=alu_rate)
    bw_per_cu = _config['l2_rdbw'] / _config['num_cus']
    # True division: floor division collapsed each ratio to 0 or 1.
    efficiency = min(bw_per_cu / l2_bw_required, 1)
    efficiency = min(efficiency, _config['lds_bw'] / lds_bw_required)
    return efficiency - (launch_setup_eff_loss / 100)
    
def flashattn_eff(num_wgs_cu:int = 2, bpe:int = 2, attn_dim: int = 128, qtile: int = 128, ktile: int = 128, vtile: int=128, gpu_name:str = "mi300"):
    """Modelled flash-attention efficiency for one (qtile x ktile) block step.

    The Q tile stays resident. Only K (for Q @ K^T) and V (for P @ V) tiles
    are streamed from L2, so each GEMM's efficiency is the per-CU L2 bandwidth
    over its requirement, capped at 1.

    With ``num_wgs_cu >= 2`` the result is the product of two rate ratios:
    softmax-versus-QK throughput (capped at 1) and the PV GEMM efficiency.
    With fewer workgroups the QK, softmax and PV cycles are summed (serialised).

    Args:
        num_wgs_cu: workgroups resident per CU.
        bpe: bytes per element.
        attn_dim: head dimension D.
        qtile, ktile, vtile: Q, K and V tile lengths.
        gpu_name: key into ``system_config``.

    Returns:
        Efficiency as a float.
    """
    _config = system_config[gpu_name]
    # The config may spell this key 'f16_rate'; accept either spelling.
    alu_rate = _config['fp16_rate'] if 'fp16_rate' in _config else _config['f16_rate']
    bw_per_cu = _config['l2_rdbw'] / _config['num_cus']

    # Softmax throughput in score elements per cycle.
    softmax_cycles = softmax_rate(qtile, ktile)
    softmax_per_cycle = qtile * ktile / softmax_cycles

    # Q @ K^T: the Q load is amortised over the block, so only K is read.
    qk_cycles = qtile * ktile * attn_dim * 2 / alu_rate
    kbytes_read = ktile * bpe * attn_dim
    qkgemm_eff = min(bw_per_cu / (kbytes_read / qk_cycles), 1)
    qk_elements_per_cycle = qtile * ktile / (qk_cycles / qkgemm_eff)

    # P @ V: the V tile is streamed from L2.
    pv_cycles = qtile * vtile * attn_dim * 2 / alu_rate
    vbytes_read = vtile * attn_dim * bpe
    pvgemm_eff = min(bw_per_cu / (vbytes_read / pv_cycles), 1)

    # True division throughout (floor division drove these ratios to 0).
    # Capped at 1 so a softmax faster than QK cannot push efficiency above 1.
    qksoftmax_eff = min(softmax_per_cycle / qk_elements_per_cycle, 1)

    if num_wgs_cu >= 2:
        return qksoftmax_eff * pvgemm_eff
    return (qk_cycles + pv_cycles) / ((qk_cycles / qkgemm_eff) + softmax_cycles + (pv_cycles / pvgemm_eff))
    
    
def softmax_rate(qtile: int = 128, ktile: int=128):
    """Modelled cycle count of one softmax pass over a qtile x ktile score block.

    The block is split over 4 waves. The pipeline is:
    max3 -> cross-lane max reduction -> x - max (fma) -> exp -> row sum
    -> row-sum reduction -> scaling.
    Each stage's cost is derived from its throughput per cycle per SIMD.

    NOTE(review): every per-element op is charged the cycles of a full SIMD
    instruction (e.g. 4 cycles at 16 elements/cycle). If the counts are meant
    per element rather than per wave instruction, the total may be overstated
    by the wave width. Confirm against the intended model.
    """
    per_wave = qtile * ktile // 4
    stage_cycles = (
        (per_wave // 2) * 4,  # max3 = max(x, max(y, z)): half as many ops as elements; 16 elements/cycle/SIMD
        48,                   # max reduction across 32 threads via ds_bpermute (latency)
        per_wave * 2,         # x - max via fma: 32 ops/cycle/SIMD
        per_wave * 16,        # exp2f32: 4 ops/cycle/SIMD
        per_wave * 2,         # row sum: 32 ops/cycle/SIMD
        48,                   # row-sum reduction via ds_bpermute
        per_wave * 4,         # scaling: 16 ops/cycle/SIMD
    )
    return sum(stage_cycles)

def ret_ktile(bpe: int=2, tile_sizes:tuple = (128,128), gpu_name: str = "mi300"):
    """Return the first K-tile depth from ``tilek_sizes`` whose inputs fit in half the LDS.

    ``tilek_sizes`` is ordered largest first, so this is the deepest fitting
    tile. A depth ``k`` fits when the A and B panels,
    ``(tile_sizes[0] + tile_sizes[1]) * bpe * k`` bytes, use at most half of
    ``lds_cap`` KB (presumably leaving room for double buffering — confirm).

    Raises:
        ValueError: if no candidate depth fits. The previous ``assert(0)`` was
            stripped under ``python -O``, and the function then returned None.
    """
    budget_bytes = (system_config[gpu_name]['lds_cap'] * 1024) // 2
    panel_bytes_per_k = (tile_sizes[0] + tile_sizes[1]) * bpe
    for k_depth in tilek_sizes:
        if panel_bytes_per_k * k_depth <= budget_bytes:
            return k_depth
    raise ValueError(f"no K-tile depth in {tilek_sizes} fits tile {tile_sizes} in the LDS of {gpu_name}")

def find_bestfaperformance(q_len: int, k_len: int, attn_dim: int, num_heads: int , bpe: int = 2, gpu_name: str = "mi300"):
    """Modelled flash-attention efficiency for one GPU's share of heads.

    Work is sharded 2-D over (heads x query tiles) across the GPU's CUs. The
    query tile is the largest of 128/64/32/16 rows not exceeding the rows
    available per CU slice of a head; if none fits, the whole ``q_len`` is
    used. K/V tiles are sized so that two workgroups' ``attn_dim``-wide tiles
    fill the LDS, while Q stays resident in registers.

    Args:
        q_len: query length.
        k_len: key length (not used by the current model).
        attn_dim: head dimension D.
        num_heads: heads handled by this GPU.
        bpe: bytes per element.
        gpu_name: key into ``system_config``.

    Returns:
        Efficiency from ``flashattn_eff``.

    Raises:
        ValueError: if ``q_len`` is not positive.
    """
    if q_len <= 0:
        raise ValueError("q_len must be positive")
    cfg = system_config[gpu_name]
    total_cus = cfg['num_cus']
    # CUs per head. At least one, so more heads than CUs no longer divides by zero.
    cus_per_head = max(1, total_cus // num_heads)
    rows_per_cu = q_len // cus_per_head
    qtile_size = next((t for t in (128, 64, 32, 16) if rows_per_cu >= t), q_len)

    q_tiles = q_len // qtile_size
    num_fatiles = q_tiles * num_heads
    # Flash-attention tiles per CU. 0 means fewer tiles than CUs, which
    # selects flashattn_eff's single-workgroup branch.
    fatiles_percu = num_fatiles // total_cus

    # LDS holds K/V tiles for 2 workgroups; size taken from the config (64 KB on mi300).
    lds_bytes = cfg['lds_cap'] * 1024
    kvtile_size = lds_bytes // (2 * attn_dim * bpe)

    return flashattn_eff(fatiles_percu, bpe, attn_dim, qtile_size, kvtile_size, kvtile_size, gpu_name)

    


def find_bestgemmperformance(bpe: int =2, M: int = None, N: int =None, K:int =None, gpu_name: str = "mi300"):
    """Pick the output-tile shape with the best modelled GEMM score.

    For each shape in ``tile_sizes``:
      * choose its K-tile with ``ret_ktile``;
      * compute ``gemm_eff`` and ``granu_loss``;
      * score it as ``eff * granu_loss``.
    The first shape with the highest score wins, matching ``np.argmax`` tie-breaking.

    Args:
        bpe: bytes per element.
        M, N, K: GEMM problem size (all required).
        gpu_name: key into ``system_config``.

    Returns:
        The winning entry: a dict with keys 'tile_size', 'eff' and 'granu_loss'.

    Raises:
        ValueError: if any of M, N or K is missing.
    """
    if K is None or M is None or N is None:
        raise ValueError("M, N and K must all be provided")

    candidates = []
    for tile_size in tile_sizes:
        ktile = ret_ktile(bpe, tile_size, gpu_name)
        candidates.append({
            'tile_size': tile_size,
            'eff': gemm_eff(bpe, tile_size[0], tile_size[1], ktile, K, gpu_name),
            'granu_loss': granu_loss(gpu_name, M, N, tile_size[0], tile_size[1]),
        })

    # The old loop indexed each entry dict with an int (_item[idx]) -> KeyError.
    return max(candidates, key=lambda entry: entry['eff'] * entry['granu_loss'])

def granu_loss(gpu_name: str = 'mi300', gemm_m : int = 1024, gemm_n : int = 1024, mtile:int =128, ntile:int =128, num_cus: int = None):
    """Tile and wave quantisation efficiency of an M x N GEMM output.

    The output is covered by ``ceil(M/mtile) * ceil(N/ntile)`` tiles. With one
    tile per CU at a time, they run in ``ceil(tiles / num_cus)`` waves. The
    result is the useful output elements divided by the element slots those
    waves provision:

        (M * N) / (waves * num_cus * mtile * ntile)

    1.0 means no loss. Smaller values mean padded partial tiles or idle CUs in
    the last wave.

    NOTE(review): the previous body could never return: ``tile_size`` was
    undefined and ``num_cus`` was a dict. It appeared to intend a tiles-per-CU
    count. This efficiency form matches its use as a multiplier on ``eff`` in
    find_bestgemmperformance — confirm.

    Args:
        gpu_name: key into ``system_config``, used when ``num_cus`` is None.
        gemm_m, gemm_n: GEMM output shape.
        mtile, ntile: output-tile shape.
        num_cus: optional CU count override.

    Raises:
        ValueError: if any dimension or ``num_cus`` is not positive.
    """
    import math

    if num_cus is None:
        num_cus = system_config[gpu_name]['num_cus']
    if min(gemm_m, gemm_n, mtile, ntile, num_cus) <= 0:
        raise ValueError("GEMM dims, tile dims and num_cus must be positive")
    # ceil: a partial tile still occupies a CU
    m_tiles = math.ceil(gemm_m / mtile)
    n_tiles = math.ceil(gemm_n / ntile)
    waves = math.ceil(m_tiles * n_tiles / num_cus)
    return (gemm_m * gemm_n) / (waves * num_cus * mtile * ntile)


In [26]:
# Peak-memory assertions for the individual system components,
# for running large models on a single GPU + CPU memory + disk cache.

def GPU_PeakMem(Phase : MODELPHASE = MODELPHASE.TOKEN.value, ModelConfig: dict = None, probMem: dict = None, batch_size: int =1, sysCfg: dict = None) -> bool:
    """Check that the GPU share of a model's memory fits in GPU memory.

    ``probMem`` gives the fraction of weights, KV cache and activations placed
    on the GPU. The peak is:
      * the GPU-resident share for all layers, plus
      * one layer's working set: the non-resident share of its weights and
        activations, plus the largest single-op buffer.
    All quantities are in elements.

    Args:
        Phase: ``MODELPHASE.PREFILL.value`` or ``MODELPHASE.TOKEN.value``
            (a ``MODELPHASE`` member is also accepted). TOKEN processes one token.
        ModelConfig: model shape. Reads "hidden_dim1" (or "model_dim"),
            "hidden_dim2" (or "hidden_dim"), "max_seq_len", "out_tokens",
            "num_layers" and "num_heads" (or "attn_heads").
        probMem: placement fractions ("WeightsGpu", "KVCacheGpu", "ActGpu").
        batch_size: number of sequences.
        sysCfg: must provide 'gpuMemCapacity', in the same unit as the estimate.

    Returns:
        True if the estimate is strictly below ``sysCfg['gpuMemCapacity']``.

    Raises:
        ValueError: if ``probMem`` or ``ModelConfig`` is None.
        AssertionError: if the estimate does not fit.
    """
    # Validate the argument (the old check tested the global ProbMem instead).
    if probMem is None:
        raise ValueError("ProbMem is None")
    if ModelConfig is None:
        raise ValueError("ModelConfig is None")
    if isinstance(Phase, MODELPHASE):
        Phase = Phase.value
    is_prefill = Phase == MODELPHASE.PREFILL.value

    def _dim(key, fallback):
        # MODEL_CLASSES entries use model_dim / hidden_dim / attn_heads naming.
        return ModelConfig[key] if key in ModelConfig else ModelConfig[fallback]

    dim1 = _dim("hidden_dim1", "model_dim")
    dim2 = _dim("hidden_dim2", "hidden_dim")
    seqlen = ModelConfig["max_seq_len"]
    outToken = ModelConfig["out_tokens"]
    numLayers = ModelConfig["num_layers"]
    numHeads = _dim("num_heads", "attn_heads")

    kv_len = seqlen + outToken
    tokens = seqlen if is_prefill else 1  # token generation processes a single token
    # Weights of one layer: 4*dim1^2 attention + 2*dim1*dim2 MLP.
    layer_weights = 4 * dim1**2 + 2 * dim1 * dim2

    # Resident share. The placement fraction applies to ALL of the layer's
    # weights, matching the (1 - fraction) term of the working memory below.
    GpuMem_resident_layer = probMem["WeightsGpu"] * layer_weights + probMem["KVCacheGpu"] * 2 * kv_len * dim1
    resident_act = probMem["ActGpu"] * batch_size * tokens * dim1
    GpuMem_resident = GpuMem_resident_layer * numLayers + resident_act

    # Per-op working buffers. QKV: 3 projections + input, each tokens x dim1.
    QKV = batch_size * 4 * tokens * dim1
    if is_prefill:
        # Q/K inputs, S x S scores per head, and the attention output
        attention = probMem["KVCacheGpu"] * batch_size * (seqlen*dim1 + seqlen*seqlen*numHeads + seqlen*dim1)
    else:
        # one query row against the (seqlen + outToken) KV cache
        attention = probMem["KVCacheGpu"] * batch_size * (dim1 + kv_len*numHeads + kv_len*dim1)
    Embed = batch_size * 2 * tokens * dim1
    MLP = batch_size * tokens * (dim1 + dim2)

    workingMem = (1 - probMem["WeightsGpu"]) * layer_weights \
        + (1 - probMem["ActGpu"]) * batch_size * seqlen * dim1 \
        + max(QKV, attention, Embed, MLP)

    PeakMemory = workingMem + GpuMem_resident
    # Explicit raise so the check survives `python -O`.
    if not PeakMemory < sysCfg['gpuMemCapacity']:
        raise AssertionError(f"GPU peak memory {PeakMemory} does not fit in {sysCfg['gpuMemCapacity']}")
    return True
    


def CPU_PeakMem(Phase : MODELPHASE = MODELPHASE.TOKEN.value, 
                ModelConfig: dict = None, probMem: dict = None, 
                batch_size: int =1, 
                sysCfg: dict = None) -> bool:        
    """Check that the CPU share of a model's memory fits in the given capacity.

    Peak = CPU-resident share for all layers, plus the non-resident share of
    one layer's weights and activations that passes through CPU memory
    (working memory staged from disk). All quantities are in elements.

    Args:
        Phase: ``MODELPHASE.PREFILL.value`` or ``MODELPHASE.TOKEN.value``
            (a ``MODELPHASE`` member is also accepted).
        ModelConfig: model shape. Reads "hidden_dim1" (or "model_dim"),
            "hidden_dim2" (or "hidden_dim"), "max_seq_len", "out_tokens"
            and "num_layers".
        probMem: placement fractions ("WeightsCpu", "KVCacheCpu", "ActCpu").
        batch_size: number of sequences.
        sysCfg: must provide 'gpuMemCapacity' (see NOTE below).

    Returns:
        True if the estimate is strictly below the capacity.

    Raises:
        ValueError: if ``probMem`` or ``ModelConfig`` is None.
        AssertionError: if the estimate does not fit.
    """
    # Validate the argument (the old check tested the global ProbMem instead).
    if probMem is None:
        raise ValueError("ProbMem is None")
    if ModelConfig is None:
        raise ValueError("ModelConfig is None")
    if isinstance(Phase, MODELPHASE):
        Phase = Phase.value
    is_prefill = Phase == MODELPHASE.PREFILL.value

    def _dim(key, fallback):
        # MODEL_CLASSES entries use model_dim / hidden_dim naming.
        return ModelConfig[key] if key in ModelConfig else ModelConfig[fallback]

    dim1 = _dim("hidden_dim1", "model_dim")
    dim2 = _dim("hidden_dim2", "hidden_dim")
    seqlen = ModelConfig["max_seq_len"]
    outToken = ModelConfig["out_tokens"]
    numLayers = ModelConfig["num_layers"]

    tokens = seqlen if is_prefill else 1
    layer_weights = 4 * dim1**2 + 2 * dim1 * dim2  # attention + MLP weights of one layer

    # Resident share; the fraction covers all of the layer's weights.
    CpuMem_resident_layer = probMem["WeightsCpu"] * layer_weights + probMem["KVCacheCpu"] * 2 * (seqlen + outToken) * dim1
    # The token phase previously used the GPU activation fraction here.
    resident_act = probMem["ActCpu"] * batch_size * tokens * dim1
    CpuMem_resident = CpuMem_resident_layer * numLayers + resident_act

    # Non-resident weights/activations of one layer passing through CPU memory.
    workingMem = (1 - probMem["WeightsCpu"]) * layer_weights \
        + (1 - probMem["ActCpu"]) * batch_size * seqlen * dim1

    PeakMemory = workingMem + CpuMem_resident
    # NOTE(review): compares against 'gpuMemCapacity' as before; a CPU-memory
    # capacity is probably intended here — confirm the sysCfg key.
    if not PeakMemory < sysCfg['gpuMemCapacity']:
        raise AssertionError(f"CPU peak memory {PeakMemory} does not fit in {sysCfg['gpuMemCapacity']}")
    return True
    
def disk_PeakMem(Phase : MODELPHASE = MODELPHASE.TOKEN.value, 
                 ModelConfig: dict = None, 
                 probMem: dict = None, 
                 batch_size: int =1, 
                 sysCfg: dict = None) -> bool:
    """Check that the disk share of a model's memory fits in the given capacity.

    Peak = disk-resident share of weights and KV cache for all layers, plus
    the disk share of the prompt activations. All quantities are in elements.
    ``Phase`` is accepted for signature parity but does not affect the result.

    Args:
        ModelConfig: model shape. Reads "hidden_dim1" (or "model_dim"),
            "hidden_dim2" (or "hidden_dim"), "max_seq_len", "out_tokens"
            and "num_layers".
        probMem: placement fractions ("WeightsDisk", "KVCacheDisk", "ActDisk").
        batch_size: number of sequences.
        sysCfg: must provide 'gpuMemCapacity' (see NOTE below).

    Returns:
        True if the estimate is strictly below the capacity.

    Raises:
        ValueError: if ``probMem`` or ``ModelConfig`` is None.
        AssertionError: if the estimate does not fit.
    """
    # Validate the argument (the old check tested the global ProbMem instead).
    if probMem is None:
        raise ValueError("ProbMem is None")
    if ModelConfig is None:
        raise ValueError("ModelConfig is None")

    def _dim(key, fallback):
        # MODEL_CLASSES entries use model_dim / hidden_dim naming.
        return ModelConfig[key] if key in ModelConfig else ModelConfig[fallback]

    dim1 = _dim("hidden_dim1", "model_dim")
    dim2 = _dim("hidden_dim2", "hidden_dim")
    seqlen = ModelConfig["max_seq_len"]
    outToken = ModelConfig["out_tokens"]
    numLayers = ModelConfig["num_layers"]

    layer_weights = 4 * dim1**2 + 2 * dim1 * dim2  # attention + MLP weights of one layer
    disk_resident_layer = probMem["WeightsDisk"] * layer_weights + probMem["KVCacheDisk"] * 2 * (seqlen + outToken) * dim1
    resident_act = probMem["ActDisk"] * batch_size * seqlen * dim1
    PeakMemory = disk_resident_layer * numLayers + resident_act

    # NOTE(review): compares against 'gpuMemCapacity' as before; a disk
    # capacity is probably intended here — confirm the sysCfg key.
    if not PeakMemory < sysCfg['gpuMemCapacity']:
        raise AssertionError(f"disk peak memory {PeakMemory} does not fit in {sysCfg['gpuMemCapacity']}")
    return True
    
def activationmem_per_layer(batch_size:int, seq_len: int, model_dim1: int, model_dim2: int, FA: bool= True):
    """Largest activation tensor (in elements) produced inside one transformer layer.

    Candidates: layer input, fused Q/K/V (3 x model_dim1 wide), attention
    output, MLP1 output (model_dim2 wide), MLP2 output and layer-norm output.
    With FlashAttention no QK score or dropout tensor is kept, so none is
    counted. ``FA`` is accepted but does not change the result.
    """
    tokens = batch_size * seq_len
    candidate_sizes = (
        tokens * model_dim1,       # layer input
        tokens * model_dim1 * 3,   # Q, K, V tensors
        tokens * model_dim1,       # attention output projection
        tokens * model_dim2,       # MLP1 output
        tokens * model_dim1,       # MLP2 output
        tokens * model_dim1,       # layer norm
    )
    return max(candidate_sizes)


def kvcache_per_layer(Phase: MODELPHASE = MODELPHASE.PREFILL.value,
                      batch_size:int =1, 
                      outTokenSize: int=1, 
                      seq_len:int = 1, 
                      hidden_dim:int =1):
    """KV-cache size (in elements) of one layer.

    Two tensors (K and V) of ``batch_size x tokens x hidden_dim``:
      * PREFILL: tokens is ``seq_len``;
      * any other phase: tokens is ``seq_len + outTokenSize``.
    ``Phase`` may be the enum's ``.value`` or a ``MODELPHASE`` member.
    """
    if isinstance(Phase, MODELPHASE):
        Phase = Phase.value
    # `.value`, not `.Value`: the latter raised AttributeError on every call.
    if Phase == MODELPHASE.PREFILL.value:
        cached_tokens = seq_len
    else:
        cached_tokens = seq_len + outTokenSize
    return 2 * (cached_tokens * batch_size * hidden_dim)
   
        
def model_parameters(model_name:str, component:str):
    """Size (in elements) of one memory component of a model in MODEL_CLASSES.

    Components (``MODELCOMPONENTS`` value strings or members):
      * "weights": whole-model parameters,
        ``(4*M^2 + 2*M*H + 2*M) * num_layers + (max_seq_len + vocab_size) * M``
        with M = model_dim and H = hidden_dim.
      * "activation": peak activation of one layer (activationmem_per_layer).
      * "KVcache": KV cache of one layer (kvcache_per_layer, PREFILL phase).

    Raises:
        ValueError: for an unknown model or an unsupported component.
    """
    if model_name not in MODEL_CLASSES:
        raise ValueError(f" Model {model_name} not supported yet")
    modelDict = MODEL_CLASSES[model_name]
    if isinstance(component, MODELCOMPONENTS):
        component = component.value

    model_dim = modelDict["model_dim"]
    hidden_dim = modelDict["hidden_dim"]
    # MODEL_CLASSES has no "seq_len" key; the sequence length is "max_seq_len".
    seq_len = modelDict["max_seq_len"]

    if component == MODELCOMPONENTS.Weights.value:
        per_layer = 4 * model_dim**2 + 2 * model_dim * hidden_dim + 2 * model_dim
        return per_layer * modelDict["num_layers"] + (seq_len + modelDict["vocab_size"]) * model_dim

    if component == MODELCOMPONENTS.Activation.value:
        return activationmem_per_layer(batch_size=modelDict["batch_size"], seq_len=seq_len,
                                       model_dim1=model_dim, model_dim2=hidden_dim)

    if component == MODELCOMPONENTS.KVcache.value:
        # K and V are model_dim wide; hidden_dim here is the MLP width.
        return kvcache_per_layer(batch_size=modelDict["batch_size"], seq_len=seq_len,
                                 hidden_dim=model_dim, outTokenSize=modelDict["out_tokens"])

    raise ValueError(f" not supported {component} yet")

def qkv_projection_compute(modelCfg: dict, batch_size: int =1, Phase: MODELPHASE = MODELPHASE.TOKEN.value):
    """FLOPs of the fused Q/K/V projection of one layer.

    ``2 * batch_size * tokens * model_dim * 3*model_dim``, where tokens is
    ``modelCfg['max_seq_len']`` in PREFILL and 1 in TOKEN phase. The old body
    ignored ``Phase`` and always returned the prefill cost. ``Phase`` may be the
    enum's ``.value`` or a ``MODELPHASE`` member.
    """
    if isinstance(Phase, MODELPHASE):
        Phase = Phase.value
    tokens = modelCfg['max_seq_len'] if Phase == MODELPHASE.PREFILL.value else 1
    return batch_size*3*2*tokens*modelCfg['model_dim']*modelCfg['model_dim']

def MLP_compute(modelCfg: dict, batch_size: int =1, Phase: MODELPHASE = MODELPHASE.TOKEN.value):
    """FLOPs of the two model_dim x hidden_dim MLP GEMMs of one layer.

    tokens is ``modelCfg['max_seq_len']`` in PREFILL and 1 in TOKEN phase. The
    old body ignored ``Phase`` and always returned the prefill cost. ``Phase``
    may be the enum's ``.value`` or a ``MODELPHASE`` member.
    """
    if isinstance(Phase, MODELPHASE):
        Phase = Phase.value
    tokens = modelCfg['max_seq_len'] if Phase == MODELPHASE.PREFILL.value else 1
    return 2*2*batch_size*tokens*modelCfg['model_dim']*modelCfg['hidden_dim']

def out_compute(modelCfg: dict, batch_size: int =1, Phase: MODELPHASE = MODELPHASE.TOKEN.value):
    """FLOPs of the attention output projection (model_dim x model_dim) of one layer.

    tokens is ``modelCfg['max_seq_len']`` in PREFILL and 1 in TOKEN phase. The
    old body ignored ``Phase`` and always returned the prefill cost. ``Phase``
    may be the enum's ``.value`` or a ``MODELPHASE`` member.
    """
    if isinstance(Phase, MODELPHASE):
        Phase = Phase.value
    tokens = modelCfg['max_seq_len'] if Phase == MODELPHASE.PREFILL.value else 1
    return batch_size*2*tokens*modelCfg['model_dim']*modelCfg['model_dim']

def flashAttention_compute(modelCfg: dict, batch_size: int =1, Phase: MODELPHASE = MODELPHASE.TOKEN.value, CMasking:bool = True):
    """FLOPs of the attention core (Q @ K^T and P @ V) of one layer.

    ``2 GEMMs * 2 FLOPs * batch_size * query_tokens * max_seq_len * model_dim``:
      * query_tokens is ``max_seq_len`` in PREFILL and 1 in TOKEN phase;
      * the result is halved when causal masking (``CMasking``) is on.
    ``Phase`` may be the enum's ``.value`` or a ``MODELPHASE`` member. The old
    ``.Value`` lookup raised AttributeError.

    NOTE(review): the causal halving is also applied in the TOKEN phase, where
    a single new query attends to every cached key — confirm this is intended.
    """
    if isinstance(Phase, MODELPHASE):
        Phase = Phase.value
    seq_len = modelCfg['max_seq_len']
    query_tokens = seq_len if Phase == MODELPHASE.PREFILL.value else 1
    FA = 2*2*batch_size*query_tokens*seq_len*modelCfg['model_dim']
    return FA//2 if CMasking else FA

def GPUPeakMem(model_name:str):
    """Single-GPU memory estimate (in elements) for a model in MODEL_CLASSES.

    Sums the whole-model weights, the peak activation of one layer, and the KV
    cache of every layer. ``model_parameters`` returns the KV cache per layer,
    so it is scaled by ``num_layers``; the old body counted a single layer.

    Raises:
        ValueError: for a model not in MODEL_CLASSES (from model_parameters).
    """
    weights = model_parameters(model_name, "weights")
    activation = model_parameters(model_name, "activation")
    kvcache_per_layer_elems = model_parameters(model_name, "KVcache")
    num_layers = MODEL_CLASSES[model_name]["num_layers"]
    return weights + activation + kvcache_per_layer_elems * num_layers

In [None]:
#MESH shard (x,y)

# ----------------
# |0 | 1 | 2 | 3 |
# |4 | 5 | 6 | 7 | 
# ----------------

# pipeline parallelism = X dimension = 4 GPU(s): layers_per_rank = num_layers/num_ranks
# model parallelism    = Y dimension = 2 GPU(s): tensors_per_rank = hidden_dim/num_ranks, model_dim/num_ranks; heads_per_rank = attn_heads/num_ranks
# data parallelism     = X dimension = 4 GPU(s): batchsize_per_rank = batch_size/num_ranks



#2D sharding for 8 GPU(s)

mesh_shard = (4,2)  # (x,y)
mesh_dict = {'X': mesh_shard[0], 'Y': mesh_shard[1]}
# weights: hidden_dim, model-dim = Y
# activation: batch and model_dim  = X

# Shard specs map a tensor dimension to a mesh axis ('X' or 'Y') or '_' (not sharded).
#B = batchsize
#N = number of heads
#M = model-dimension
#S = sequence_len
#H = hidden_dimension
# 'model_dim' was the malformed axis "_'"; it is not sharded.
embedding_shard_BSM = {'batch_size' : 'X', 'max_seq_len' : '_', 'model_dim': '_'}
activation_shard_BSM = {'batch_size'  : 'X', 'max_seq_len' : '_', 'model_dim' : '_'}
activation_shard_BSND = {'batch_size' : 'X', 'max_seq_len': '_', 'attn_heads': 'Y', 'attn_dim' : '_'}
activation_shard_BSH = {'batch_size'  : 'X', 'max_seq_len' : '_', 'hidden_dim' : 'Y'}
activation_shard_BNSS = {'batch_size' : 'X', 'attn_heads' : 'Y', 'max_seq_len' : '_'}

from dataclasses import dataclass
#first key is summation dimension second key is free0/free1 dimension
weights_shard_MM  = {'model_dim'  : '_', 'model_dim1' : '_'}
weights_shard_MND = {'model_dim'  : '_', 'attn_heads'  : 'Y', 'attn_dim' : '_'} 
weights_shard_NDM = {'attn_heads' : 'Y', 'attn_dim'  : '_', 'model_dim'  : 'X'}
weights_shard_MH  = {'model_dim'  : 'X', 'hidden_dim' : 'Y'}
# MLP1 reduces over hidden_dim, so hidden_dim is the first (summation) key.
weights_shard_HM  = {'hidden_dim' : 'Y', 'model_dim'  : 'X'}
embedding_weight =  {'vocab_size'  : '_', 'model_dim' : '_'}

@dataclass
class transformer_2dshard:
    """Per-layer cost model for a transformer sharded on a 2-D GPU mesh.

    The class attributes describe how each op's operands are sharded. A shard
    spec maps a tensor dimension name to a mesh axis ('X' or 'Y', sized by
    ``mesh_dict``) or to '_' (not sharded).

    NOTE(review): no attribute is annotated, so ``@dataclass`` generates no
    fields. The hand-written ``__init__`` is kept, but the generated
    ``__eq__`` makes any two instances compare equal.
    """
    embedding_activations = {'input0' : embedding_shard_BSM, 'output' : embedding_shard_BSM}
    # weights M=8192 N=64 D=128; activation B=1x2048x8192 -> (1x2048x128, 64)
    qkv_activations = {'input0' : activation_shard_BSM , 'input1' : weights_shard_MM, 'output' : activation_shard_BSND}
    flash_attention_qk = {'input0' : activation_shard_BSND , 'input1' : activation_shard_BSND, 'output' : activation_shard_BNSS}
    flash_attention_sv = {'input0' : activation_shard_BSND , 'input1' : activation_shard_BNSS, 'output' : activation_shard_BSND}
    out_projection =  {'input0' : activation_shard_BSND , 'input1' :  weights_shard_NDM , 'output' : activation_shard_BSM}
    #FIXME sequence_len sharding 
    layer_norm = {'input' : activation_shard_BSM , 'output' : activation_shard_BSM}
    mlp_0 = {'input0': activation_shard_BSM, 'input1' : weights_shard_MH, 'output' : activation_shard_BSH}
    gelu  = {'input': activation_shard_BSH, 'output' : activation_shard_BSH}
    # The activation operand was keyed 'input1' twice, so it was silently dropped.
    mlp_1 = {'input0': activation_shard_BSH, 'input1' : weights_shard_HM, 'output' : activation_shard_BSM}

    def __init__(self,model_name:str = "llama_65B",
                 bpe : int = 2, 
                 num_gpus : int =1, shard_topology : tuple = None, 
                 parallelism_cfg: dict = None, 
                 sysCfg: dict = None,
                 tileSizes: list[tuple] = None):
        """Set up the sharding plan.

        Args:
            model_name: key into MODEL_CLASSES.
            bpe: bytes per element.
            num_gpus: number of GPUs (recorded only).
            shard_topology: (x, y) mesh shape; required.
            parallelism_cfg: dict with 'tensorParallelism', 'dataParallelism'
                and 'pipelineParallelism' flags. Defaults to tensor + pipeline.
            sysCfg: system table keyed by GPU name; defaults to ``system_config``.
            tileSizes: candidate tile sizes (recorded only).

        Raises:
            ValueError: for an unknown model, a missing topology, or all three
                parallelism kinds enabled at once (only 2 mesh axes exist).
        """
        if model_name not in MODEL_CLASSES.keys():
            raise ValueError("Unsupported model")
        if shard_topology is None:
            raise ValueError("shard_topology (x, y) is required")

        self.bpe = bpe
        self.modelParameters = MODEL_CLASSES[model_name]
        self.model_name = model_name
        self.numGPUs = num_gpus
        self.shard_topology = shard_topology
        self.systemCfg = sysCfg if sysCfg is not None else system_config
        self.gpu_name = "mi300"
        # Peak fp16 FLOPs per cycle for the whole GPU.
        self.fp16rate = self._gpu_cfg()['num_cus'] * self._cu_fp16_rate()

        if parallelism_cfg is None:
            parallelism_cfg = {'tensorParallelism': True,
                               'dataParallelism': False,
                               'pipelineParallelism': True}
        self.parallelismDict = parallelism_cfg

        self.tileSizeChoices = tileSizes
        self.max_seq_len = self.modelParameters['max_seq_len']

        par = self.parallelismDict
        # Only a 2-D mesh is supported, so at most two parallelism kinds may be on.
        if par['tensorParallelism'] and par['dataParallelism'] and par['pipelineParallelism']:
            raise ValueError("at most two of tensor/data/pipeline parallelism can be enabled on a 2-D mesh")

        # Pipeline parallelism owns the X axis unless data parallelism is enabled.
        use_pipeline = par['pipelineParallelism'] and not par['dataParallelism']
        self.pipeline_num_ranks = self.shard_topology[0] if use_pipeline else 1
        self.layers_per_rank = self.modelParameters['num_layers'] // self.pipeline_num_ranks
        self.num_microbatches = self.modelParameters['batch_size'] // self.modelParameters['microbatch_size']

        if self.pipeline_num_ranks == 1:
            # X axis carries data parallelism; batch_size is the global batch.
            self.data_num_ranks = self.shard_topology[0]
            self.micro_batchsize = self.modelParameters['batch_size'] // self.shard_topology[0]
        else:
            # Pipelined: each stage processes one micro-batch at a time.
            self.data_num_ranks = 1
            self.micro_batchsize = self.modelParameters['microbatch_size']

        self.tensor_num_ranks = self.shard_topology[1]

    def _gpu_cfg(self):
        """Hardware entry of ``self.systemCfg`` for ``self.gpu_name``."""
        return self.systemCfg[self.gpu_name]

    def _cu_fp16_rate(self):
        """Per-CU fp16 rate; accepts either the 'fp16_rate' or 'f16_rate' key."""
        cfg = self._gpu_cfg()
        return cfg['fp16_rate'] if 'fp16_rate' in cfg else cfg['f16_rate']

    def peakmem_check(self):
        """Raise AssertionError if the per-GPU model footprint exceeds HBM capacity.

        ``GPUPeakMem`` returns elements, converted here to bytes with
        ``self.bpe``. ``hbm_cap`` is assumed to be in GB (1e9 bytes) — TODO
        confirm. With pipeline parallelism the footprint is split across the
        X axis, because the layers are.
        """
        memUsage = GPUPeakMem(self.model_name) * self.bpe
        memCapacity = self._gpu_cfg()['hbm_cap'] * 1e9
        if self.parallelismDict['pipelineParallelism'] and not self.parallelismDict['dataParallelism']:
            memUsage = memUsage // self.shard_topology[0]
        if not memUsage <= memCapacity:
            raise AssertionError(f"per-GPU memory {memUsage} B exceeds capacity {memCapacity} B")

    def recv_payloadsize(self):
        """Elements of one [batch, max_seq_len, model_dim] activation received between ranks."""
        return self.modelParameters['batch_size'] * self.modelParameters['model_dim'] * self.modelParameters['max_seq_len']

    def layer_ops(self,seq_len: int):
        """Estimate compute cycles for the attention part of one layer.

        Models the fused QKV GEMM, flash attention and the output-projection
        GEMM, plus a placeholder all-reduce cost when heads are sharded. MLP,
        GELU and layer-norm are not modelled yet.

        Args:
            seq_len: query tokens processed by this layer.

        Returns:
            Sum of the modelled cycle counts (float).
        """
        num_heads = self.modelParameters['attn_heads']
        head_dim = self.modelParameters['model_dim'] // num_heads

        # --- QKV projection: [B, S, M] x [M, 3M] ---
        act_tensor = {}
        wt_tensor = {}
        for key in self.qkv_activations['input0']:
            if key == 'batch_size':
                act_tensor['batch_size'] = self.micro_batchsize
            elif key == 'max_seq_len':
                act_tensor['M'] = seq_len
            elif key == 'model_dim':
                act_tensor['K'] = self.modelParameters['model_dim']
            else:
                raise ValueError("unknown key in activation_shard_BSM tensor")

        for key, value in self.qkv_activations['input1'].items():
            if key == 'model_dim':
                wt_tensor['K'] = self.modelParameters['model_dim']
            elif key == 'model_dim1':
                wt_tensor['N'] = self.modelParameters['model_dim']
                if value != '_':
                    wt_tensor['N'] = wt_tensor['N'] // self.tensor_num_ranks
            else:
                raise ValueError("unknown key in activation_shard_BSM tensor")

        print(f"QKV_Projection:: GEMM problem sizes M={act_tensor['M']} N= {3*wt_tensor['N']} K= {act_tensor['K']}")
        efficiency = find_bestgemmperformance(self.bpe, act_tensor['M'], 3*wt_tensor['N'], act_tensor['K'], self.gpu_name)
        qkv_cycles = (2*act_tensor['M']*3*wt_tensor['N']*act_tensor['K']) // (efficiency['eff']*self.fp16rate)

        # --- Flash attention: heads are split over 'Y' when sharded ---
        q_tensor = {}
        k_tensor = {}
        v_tensor = {}
        heads_num_ranks = 1
        for key, value in self.flash_attention_qk['input0'].items():
            if key == 'batch_size':
                q_tensor['batch_size'] = self.micro_batchsize
            elif key == 'max_seq_len':
                q_tensor['M'] = seq_len
            elif key == 'attn_dim':
                q_tensor['K'] = head_dim
            elif key == 'attn_heads':
                heads_num_ranks = 1 if value == '_' else mesh_dict[value]
                q_tensor['h'] = num_heads
            else:
                raise ValueError("unknown key in activation_shard_BSND tensor")

        for key in self.flash_attention_qk['input1']:
            if key == 'batch_size':
                k_tensor['batch_size'] = self.micro_batchsize
            elif key == 'max_seq_len':
                # Keys form the N dimension of Q @ K^T (read below as k_tensor['N']).
                k_tensor['N'] = seq_len
            elif key == 'attn_dim':
                k_tensor['K'] = head_dim
            elif key == 'attn_heads':
                k_tensor['h'] = num_heads
            else:
                raise ValueError("unknown key in activation_shard_BSND tensor")

        # P @ V reduces over the key/value sequence length.
        v_tensor['K'] = k_tensor['N']
        q_tensor['h'] = q_tensor['h'] // heads_num_ranks

        fa_eff = find_bestfaperformance(q_tensor['M'], k_tensor['N'], q_tensor['K'], q_tensor['h'], self.bpe, self.gpu_name)
        # Per head: Q @ K^T (M x D x N) and P @ V (M x N x D), 2 FLOPs per MAC.
        fa_flops = 2*(q_tensor['h']*(q_tensor['M']*q_tensor['K']*k_tensor['N'] + v_tensor['K']*q_tensor['M']*q_tensor['K']))
        fa_cycles = fa_flops // (self.fp16rate*fa_eff)

        # --- Output projection: [B,S,N,D] x [N,D,M] -> [B,S,M], sharded on model_dim ---
        o_tensor = {}
        wo_tensor = {}
        tensor_num_ranks = 1
        for key, value in self.out_projection['input0'].items():
            if key == 'batch_size':
                o_tensor['batch_size'] = self.micro_batchsize
            elif key == 'max_seq_len':
                o_tensor['M'] = seq_len
            elif key == 'attn_dim':
                o_tensor['K'] = head_dim
            elif key == 'attn_heads':
                heads_num_ranks = 1 if value == '_' else mesh_dict[value]
                o_tensor['h'] = num_heads
            else:
                raise ValueError("unknown key in activation_shard_BSND tensor")

        for key, value in self.out_projection['input1'].items():
            if key == 'model_dim':
                wo_tensor['N'] = self.modelParameters['model_dim']
                tensor_num_ranks = 1 if value == '_' else self.tensor_num_ranks
            elif key == 'attn_dim':
                wo_tensor['K'] = head_dim
            elif key == 'attn_heads':
                heads_num_ranks = 1 if value == '_' else mesh_dict[value]
                wo_tensor['h'] = num_heads
            else:
                raise ValueError("unknown key in activation_shard_BSND tensor")

        # FIXME communication & reduction: gathering the sharded heads before
        # the output projection is a fixed placeholder cost for now.
        all_reduce_cycles = 10 if heads_num_ranks > 1 else 0

        wo_tensor['N'] = wo_tensor['N'] // tensor_num_ranks
        efficiency = find_bestgemmperformance(self.bpe, o_tensor['M'], wo_tensor['N'], o_tensor['K']*o_tensor['h'], self.gpu_name)
        out_cycles = (2*o_tensor['M']*wo_tensor['N']*o_tensor['K']*o_tensor['h']) // (efficiency['eff']*self.fp16rate)

        # TODO MLP with 2-D sharding:
        #   MLP0 w = MxHy, act = BSMx; ALL-GATHER Mx -> M; GELU
        #   MLP1 w = HyMx, act = BSHy; ALL-GATHER Mx -> M
        #   output = BSHy HyM = BSM (partials) = BSMy (reduce & scatter)
        return qkv_cycles + fa_cycles + out_cycles + all_reduce_cycles

    def prefill_forward(self):
        """Compute cycles of this pipeline rank's layers for one prefill pass (work in progress).

        rank[0] additionally holds the embedding and rank[n-1] the final norm;
        both are ignored for now. The received activation payload is computed,
        but it is not yet converted to time (payload // (link_BW * efficiency)).
        """
        par = self.parallelismDict
        num_ranks = self.shard_topology[0] if par['pipelineParallelism'] and not par['dataParallelism'] else 1
        layers_per_rank = self.modelParameters['num_layers'] // num_ranks

        # Every rank receives activations from the previous stage (not timed yet).
        _recv_payload = self.recv_payloadsize()
        layer_latency = self.layer_ops(self.max_seq_len)
        return layers_per_rank * layer_latency

# all_gather
# scatter_gather
# all_reduce

# pipeline/model parallelism: num_layers/num_ranks
# tensor parallelism = MLP/MoE
# TODO
# fix the compute graph with sharding
# calculate each layer's time for prefill & generation
# calculate inference time




In [None]:
import numpy as np
from typing import List, Dict
# Language-model FLOPs per layer:
# 3 QKV projections + 1 output projection + 2 MLP matrices (4x wide),
# at 2 FLOPs per multiply-accumulate: 2 * (3 + 4*2 + 1) * d_model^2 = 2 * 12 * d_model^2.
d_model = 1024
per_layer_flops = 2*12*pow(d_model,2)

kv_cache_flops = 2*2*pow(d_model,2)

# NOTE(review): 304 CUs * 2048 FLOPs/cycle = 622.6e3 FLOPs per cycle. Dividing
# by 1e3 gives TFLOPS only for a 1 GHz clock — confirm the intended clock.
mi300_fp16_flops = 304*2048/1e3  #Tflops
mi300_BW = 4.3e12 
mi300_capacity = 192e9  # bytes (compared against model_size = params * bpe)
bpe=2

#52B model
num_layers = 64
d_model = 8192
num_head = 64
# Formerly `model_parameters`, which overwrote the model_parameters() function
# from an earlier cell and broke GPUPeakMem() once this cell had run.
model_num_params = 52e9


#work request parameters
# List[int], not List(int): calling typing.List raises TypeError at runtime.
seqlen : List[int] = [1024,2048,4096,8192,16384]
batch_size = np.arange(1,512,8,dtype=int)

#parallelism parameters
N = 4    # number of accelerators


model_size = model_num_params*bpe
# One KV-cache size per entry of `seqlen`. A plain list * int repeats the list
# instead of scaling it, and then `capacity - (float + list)` raised TypeError.
kv_cache_size = 2*bpe*(np.asarray(seqlen)*d_model*num_layers)
token_size = mi300_capacity - (model_size+kv_cache_size)

#pipeline parallelism
#split layers among N accelerators

#model parallelism (tensor parallelism)
#attention - split #heads N accelerators
#MLP - split tensors(weight) N accelerators
# num_heads = h  hd = d_model/h We
# Wq,Wk,Wv  (num_heads,d_modelxhd) 
# each layer algorithm
#  S= Q@K => P = softmax(S) => P = P @ V => O = P @ W_o   
# send d_model/N to N-1 accelerators 
# receive d_model/N from N-1 accelerators
# add d_model/N vectors from N-1 accelerators 

#MLP weights 4*d_model / N split
#MLP1 = d_model @ 4*d_model/N => 4*d_model/N @ d_model 
# send 4*d_model/N to N-1 accelerators
# receive 4*d_model/N from accelerators
# in total each layer require 4*communication ops

#batch=1 case
# mem_time (math is bound by memory BW) = bpe*model_num_params / (N*mi300_BW)
# communication = 4*num_layers*8us (latency due to BS=1)

#batch = large
# compute = 2*P/(N*...)  TODO: complete






In [14]:
from typing import List, Optional, Dict

# Sequence lengths to sweep: 1024 doubled four times (1024 ... 16384).
seqlen: List[int] = [1024 * 2**step for step in range(5)]
