In [128]:
def get_activation_size(batch_size, hidden_size, seq_len, FP, TP):
    """Return the per-TP-rank size of one activation tensor, scaled by 1024**2.

    Computes ``batch_size * seq_len * hidden_size * FP``, divides it by ``TP``
    and then by ``1024 ** 2``. With ``FP`` given in bytes per element the
    result is in MiB.

    Args:
        batch_size: Number of sequences in the batch.
        hidden_size: Model hidden dimension.
        seq_len: Sequence length in tokens.
        FP: Bytes per element.
        TP: Tensor-parallel degree the tensor is divided across.

    Returns:
        The size as a float (true division is used throughout).
    """
    bytes_per_mib = 1024 ** 2
    total_bytes = batch_size * seq_len * hidden_size * FP
    return total_bytes / TP / bytes_per_mib

In [129]:
def get_embedding_size(vocab_size, hidden_size, FP, TP):
    """Return the per-TP-rank size of the embedding table, scaled by 1024**2.

    Computes ``vocab_size * hidden_size * FP``, divides it by ``TP`` and then
    by ``1024 ** 2``. With ``FP`` given in bytes per element the result is in
    MiB.

    Args:
        vocab_size: Number of rows in the embedding table.
        hidden_size: Model hidden dimension (columns of the table).
        FP: Bytes per element.
        TP: Tensor-parallel degree the table is divided across.

    Returns:
        The size as a float (true division is used throughout).
    """
    bytes_per_mib = 1024 ** 2
    table_bytes = vocab_size * hidden_size * FP
    return table_bytes / TP / bytes_per_mib

In [130]:
def get_tp_peer_size(batch_size, seq_len, hidden_size, FP, layer_per_stage):
    """Return the tensor-parallel traffic estimate for one stage, scaled by 1024**2.

    Computes ``2 * batch_size * seq_len * hidden_size * layer_per_stage * FP``
    and divides by ``1024 ** 2``. With ``FP`` given in bytes per element the
    result is in MiB. Unlike the other size helpers, the value is not divided
    by the tensor-parallel degree.

    NOTE(review): the leading factor 2 presumably counts two activation
    exchanges per layer -- confirm against the intended communication model.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Sequence length in tokens.
        hidden_size: Model hidden dimension.
        FP: Bytes per element.
        layer_per_stage: Number of layers in one stage (may be a float).

    Returns:
        The size as a float.
    """
    bytes_per_mib = 1024 ** 2
    # Keep the multiplication order unchanged: layer_per_stage may be a float.
    stage_bytes = 2 * batch_size * seq_len * hidden_size * layer_per_stage * FP
    return stage_bytes / bytes_per_mib

In [131]:
def get_flops_per_layer(hidden_size, seq_len, batch_size):
    """Return the FLOP estimate for a single layer.

    The estimate is the sum of two terms:

    * ``4 * seq_len**2 * hidden_size * batch_size`` (quadratic in sequence length)
    * ``10 * seq_len * hidden_size**2 * batch_size`` (quadratic in hidden size)

    NOTE(review): commonly cited per-layer transformer estimates use a
    coefficient of 24 on the ``seq_len * hidden_size**2`` term; confirm that
    10 is the intended value here.

    Args:
        hidden_size: Model hidden dimension.
        seq_len: Sequence length in tokens.
        batch_size: Number of sequences in the batch.

    Returns:
        The FLOP count (an int when all arguments are ints).
    """
    quadratic_in_seq = 4 * seq_len ** 2 * hidden_size * batch_size
    quadratic_in_hidden = 10 * seq_len * hidden_size ** 2 * batch_size
    return quadratic_in_seq + quadratic_in_hidden

In [132]:
# Model configuration (labelled "GPT3 large" by the author).
# NOTE(review): H = 1024 with L = 24 matches the "Medium" row of the GPT-3
# paper's model table rather than "Large" -- confirm which model is intended.
B = 4 # batch size (passed as batch_size to the size/FLOP helpers)
H = 1024 # hidden size
S = 2048 # sequence length
FP = 2 # bytes per element (2 bytes per float)
V = 50256 # vocabulary size
L = 24 # number of layers

# Cluster configuration
BW = 10000 # bandwidth in MB/s; show_info divides sizes by this to get seconds
TOP_FLOPS = 312 * 10 ** 12 # peak rate: 312 TFLOP/s for A100 FP16
util_reduce = 0.6 # fraction of peak assumed achievable (60% utilization)
FLOPS = TOP_FLOPS * util_reduce # effective FLOP/s used for compute-time estimates


In [133]:
def show_info(TP, PP):
    """Print transfer-time and compute-time estimates for one parallel layout.

    Reads the module-level constants ``B``, ``H``, ``S``, ``FP``, ``V``, ``L``,
    ``BW`` and ``FLOPS`` and prints five label/value pairs, all in
    milliseconds:

    * activation transfer time (``get_activation_size`` over ``BW``)
    * embedding transfer time (``get_embedding_size`` over ``BW``)
    * tensor-parallel transfer time for ``L / PP`` layers
      (``get_tp_peer_size`` over ``BW``)
    * compute time of one layer (``get_flops_per_layer`` over ``FLOPS``)
    * compute time of one stage (per-layer time times ``L / PP``)

    Args:
        TP: Tensor-parallel degree, forwarded to the size helpers.
        PP: Pipeline-parallel degree; ``L / PP`` layers are assigned per stage.

    Returns:
        None. Results are only printed.
    """
    def transfer_ms(size):
        # size / bandwidth is in seconds; scale to milliseconds.
        return size / BW * 1000

    stage_layers = L / PP

    print("Activation time (ms)")
    print(transfer_ms(get_activation_size(B, H, S, FP, TP)))

    print("Embedding time (ms)")
    print(transfer_ms(get_embedding_size(V, H, FP, TP)))

    print("TP time (ms)")
    print(transfer_ms(get_tp_peer_size(B, S, H, FP, stage_layers)))

    print("Per layer Total time (ms)")
    layer_ms = get_flops_per_layer(H, S, B) / FLOPS * 1000
    print(layer_ms)

    print("Per stage comp time (ms)")
    # Same operation order as before (multiply by L, then divide by PP).
    print(layer_ms * L / PP)

In [141]:
# Parallelism configuration for this run
PP = 8 # pipeline parallelism: show_info computes layers per stage as L / PP
TP = 2 # tensor parallelism: divisor applied to activation and embedding sizes
show_info(TP,PP)

Activation time (ms)
0.8
Embedding time (ms)
4.9078124999999995
TP time (ms)
9.6
Per layer Total time (ms)
0.8259552492307692
Per stage comp time (ms)
2.4778657476923076
