In [8]:
from memory import llama_activation

# ---- unit constants ----
GB = 1 << 30        # binary gigabyte (2**30 bytes); used for capacity and bandwidth
Tflops = 1.0e12     # 10**12 floating-point operations
Billion = 1.0e9     # unit for parameter counts

# ---- hardware parameters: 3D-DRAM card ----
dram3d_capacity = GB * 40       # memory per card, in bytes
card_compute = Tflops * 576     # peak throughput per card, in FLOP/s
c2c_bandwidth = GB * 400        # card-to-card link bandwidth, in bytes/s

# Alternative: H100-like setting (swap in to compare)
# dram3d_capacity = 80 * GB
# card_compute = 1000 * Tflops
# c2c_bandwidth = 450 * GB


"""model config"""
llama_config = {
    70 * Billion: {
        "hidden_size": 8192,
        "layer_num": 80,
        "head_num": 64,
        "kv_head": 8,
        "immediate_size": 12288,
        "vocab_size": 128256,


    },
    8 * Billion: {
        "hidden_size": 4096,
        "layer_num": 32,
        "head_num": 64,
        "kv_head": 8,
        "immediate_size": 6144,
        "vocab_size": 128256,
    },
}
"""common config"""
# params = 8 * Billion
# seq_len = 4096
# batch_size = 8
# card_num = 8

params = 70 * Billion
seq_len = 8192
batch_size = 8
card_num = 64

"""get model config"""

hidden_size = llama_config[params]["hidden_size"]
layer_num = llama_config[params]["layer_num"]
head_num = llama_config[params]["head_num"]
kv_head = llama_config[params]["kv_head"]
immediate_size = llama_config[params]["immediate_size"]
vocab_size = llama_config[params]["vocab_size"]

activation_mem = llama_activation(head_num=head_num, kv_head=kv_head, batch_size=batch_size, seq_len=seq_len, \
    hidden_size=hidden_size, immediate_size=immediate_size, layer_num=layer_num, vocab_size=vocab_size, is_print=False, use_flash_attention=True)
activation_mem = activation_mem * GB

optimizer_mem = 12 * params  # FP32 weights, FP32 momentum, FP32 2nd momentum
grad_mem = 2 * params
weights_mem = 2 * params

total_mem = activation_mem + optimizer_mem + grad_mem + weights_mem
print(f"Activation memory: {activation_mem/GB:.2f}GB")
print(f"Optimizer memory: {optimizer_mem/GB:.2f}GB")
print(f"Grad memory: {grad_mem/GB:.2f}GB")
print(f"Weights memory: {weights_mem/GB:.2f}GB")
print(f"Total memory need: {total_mem/GB:.2f}GB")

print(f"Total 3DDRAM Memory: {dram3d_capacity*card_num/GB:.2f}GB")

# ---- compute vs. communication per step ----
# NOTE(review): the structure (work split over card_num, 2 all-reduces per layer,
# hidden dim split by card_num below) looks like a tensor-parallel model — confirm.

# FLOPs of the parameter matmuls: 2 FLOPs (multiply + add) per parameter per token.
# This is a single pass over the weights; it does not include attention-score
# FLOPs (QK^T, scores·V), which grow with seq_len.
# NOTE(review): if this is meant to model a training step, a backward pass is
# not counted here — confirm the intended scope.
computation = batch_size * seq_len * 2 * params
# Assumes perfect division across all cards at peak throughput.
computation_time = computation / card_compute / card_num
print(f"computation_time: {computation_time:.4f}s")

# Total bytes all-reduced: one [batch, seq, hidden] activation (2 bytes per
# element) per all-reduce, 2 all-reduces per layer, over all layers.
communication = batch_size * seq_len * hidden_size * 2 * 2 * layer_num  # BF16=2Byte, 2 All-reduce/layer
print(f"communication: {communication/GB:.2f}GB")
# Size in bytes of a single all-reduce buffer.
one_time_comm = batch_size * seq_len * hidden_size * 2
print(f"one_time_comm: {one_time_comm/GB:.4f}GB")

# The trailing `* 2` presumably models ring all-reduce, where each card moves
# ~2(p-1)/p times the buffer size; this uses the large-p limit of 2.
communication_time = communication / c2c_bandwidth * 2

print(f"communication_time: {communication_time:.4f}s")

# Values below 1 mean communication takes longer than compute.
print(f"compute/communication ratio: {computation_time/communication_time:.2f}")

# FLOPs for one chunk: a [batch * seq/card_num, hidden] x [hidden, hidden/card_num]
# matmul, presumably the unit of work overlapped with one chunk transfer.
# Uses true division, so the result is a float.
one_comutation = batch_size * (seq_len/card_num) * hidden_size  * (hidden_size/card_num) * 2
# Printed as a FLOP count in units of 10**12 (not a rate), despite the label.
print(f"one_comutation: {one_comutation/Tflops:.9f}TFlops")
one_comutation_time = one_comutation / card_compute
print(f"one_comutation_time: {one_comutation_time*1000:.9f} ms")

# Bytes of one sequence chunk of activations (2 bytes per element, as above).
one_chunk = batch_size * (seq_len/card_num) * hidden_size * 2
print(f"one_chunk: {one_chunk/GB:.9f}GB")
# Plain link transfer time for one chunk; no all-reduce factor is applied here.
one_chunk_time = one_chunk / c2c_bandwidth
print(f"one_chunk_time: {one_chunk_time*1000:.9f} ms")


Activation memory: 1134.31GB
Optimizer memory: 782.31GB
Grad memory: 130.39GB
Weights memory: 130.39GB
Total memory need: 2177.39GB
Total 3DDRAM Memory: 2560.00GB
computation_time: 0.2489s
communication: 160.00GB
one_time_comm: 1.0000GB
communication_time: 0.8000s
compute/communication ratio: 0.31
one_comutation: 0.002147484TFlops
one_comutation_time: 0.003728270 ms
one_chunk: 0.015625000GB
one_chunk_time: 0.039062500 ms
