In [1]:
from utils import calculate_weight_memory, calculate_kv_cache
from utils import convert_bytes_to_gigabytes, convert_bytes_to_megabytes
from constants import LLAMA3_400B_CONFIG
from name import TrainingConfig, Datatype, Transformer
from transformer_mem_functional import calculate_memory_requirements

In [5]:
# Echo the model description imported from `constants` so the exact
# dimensions fed into the memory estimates below are visible in the notebook.
LLAMA3_400B_CONFIG

Transformer(name='llama3 405B', n_layers=126, hidden_size=16384, n_heads=128, n_key_value_heads=8, ctx_length=8192)

In [8]:
# Baseline run settings: every size/count field is 1, both activation options
# are off and zero1 is 0. The optimizer-state dtypes are not passed here, so
# whatever defaults TrainingConfig declares for them are used.
training_config = TrainingConfig(
    tp_size=1,
    pp_size=1,
    num_gpus=1,
    partition_activations=False,
    zero1=0,
    checkpoint_activations=False,
    batch_size_per_replicas=1,
    weight_dtype=Datatype.BFLOAT16,
    gradient_dtype=Datatype.BFLOAT16,
)

In [9]:
# Memory breakdown for the imported Llama 3 configuration under the baseline
# `training_config`; the returned mapping is displayed as the cell output.
calculate_memory_requirements(
    transformer=LLAMA3_400B_CONFIG,
    config=training_config,
)

{'Model Memory': 756.015380859375,
 'KV Cache Memory': 63.0,
 'Gradient Memory': 756.015380859375,
 'Activation Memory': 5575.5,
 'Optimizer Memory': 4536.09228515625,
 'Communication Memory': 0.0,
 'Miscellaneous Memory': 0.0,
 'Total Training Memory (GB)': 11623.623046875,
 'Total Inference Memory (GB)': 819.015380859375}

In [7]:
# Memory breakdown for an ad-hoc 44-layer model with activation partitioning,
# activation checkpointing and zero1 all switched on, and float32 optimizer
# states. The result is displayed as returned, with no unit conversion.
# NOTE(review): these dimensions look like a ~20B GPT-NeoX-style model --
# confirm, and consider a more descriptive name than "x".
calculate_memory_requirements(
    transformer=Transformer(
        name="x",
        n_layers=44,
        hidden_size=6144,
        n_heads=64,
        n_key_value_heads=64,  # same as n_heads
        ctx_length=2048,
    ),
    config=TrainingConfig(
        tp_size=1,
        pp_size=1,
        num_gpus=1,
        partition_activations=True,
        zero1=1,
        checkpoint_activations=True,
        batch_size_per_replicas=1,
        weight_dtype=Datatype.BFLOAT16,
        gradient_dtype=Datatype.BFLOAT16,
        optim_first_state_dtype=Datatype.FLOAT32,
        optim_second_state_dtype=Datatype.FLOAT32,
    ),
)

{'Model Memory': 39864827904.0,
 'KV Cache Memory': 2214592512.0,
 'Gradient Memory': 39864827904.0,
 'Activation Memory': 18824036352.0,
 'Optimizer Memory': 159459311616.0,
 'Communication Memory': 3000000000.0,
 'Miscellaneous Memory': 0,
 'Total Training Memory (GB)': 261013003776.0,
 'Total Inference Memory (GB)': 42079420416.0}

In [9]:
# Same estimate as the raw call for the ad-hoc 44-layer model, but every entry
# of the returned mapping is passed through convert_bytes_to_gigabytes before
# being displayed.
{
    label: convert_bytes_to_gigabytes(raw_value)
    for label, raw_value in calculate_memory_requirements(
        transformer=Transformer(
            name="x",
            n_layers=44,
            hidden_size=6144,
            n_heads=64,
            n_key_value_heads=64,  # same as n_heads
            ctx_length=2048,
        ),
        config=TrainingConfig(
            tp_size=1,
            pp_size=1,
            num_gpus=1,
            partition_activations=True,
            zero1=1,
            checkpoint_activations=True,
            batch_size_per_replicas=1,
            weight_dtype=Datatype.BFLOAT16,
            gradient_dtype=Datatype.BFLOAT16,
            optim_first_state_dtype=Datatype.FLOAT32,
            optim_second_state_dtype=Datatype.FLOAT32,
        ),
    ).items()
}

{'Model Memory': '39.865 GB',
 'KV Cache Memory': '2.215 GB',
 'Gradient Memory': '39.865 GB',
 'Activation Memory': '18.824 GB',
 'Optimizer Memory': '159.459 GB',
 'Communication Memory': '3.000 GB',
 'Miscellaneous Memory': '0.000 GB',
 'Total Training Memory (GB)': '261.013 GB',
 'Total Inference Memory (GB)': '42.079 GB'}

In [6]:
# Llama 3 estimate with activation partitioning, activation checkpointing and
# zero1 enabled and float32 optimizer states; every entry of the returned
# mapping is passed through convert_bytes_to_gigabytes before being displayed.
{
    label: convert_bytes_to_gigabytes(raw_value)
    for label, raw_value in calculate_memory_requirements(
        transformer=LLAMA3_400B_CONFIG,
        config=TrainingConfig(
            tp_size=1,
            pp_size=1,
            num_gpus=1,
            partition_activations=True,
            zero1=1,
            checkpoint_activations=True,
            batch_size_per_replicas=1,
            weight_dtype=Datatype.BFLOAT16,
            gradient_dtype=Datatype.BFLOAT16,
            optim_first_state_dtype=Datatype.FLOAT32,
            optim_second_state_dtype=Datatype.FLOAT32,
        ),
    ).items()
}

{'Model Memory': '811.765 GB',
 'KV Cache Memory': '67.646 GB',
 'Gradient Memory': '811.765 GB',
 'Activation Memory': '574.989 GB',
 'Optimizer Memory': '3247.061 GB',
 'Communication Memory': '3.000 GB',
 'Miscellaneous Memory': '0.000 GB',
 'Total Training Memory (GB)': '5448.581 GB',
 'Total Inference Memory (GB)': '879.411 GB'}

In [5]:
# Unconverted (key, value) pairs of the Llama 3 estimate, using the same
# settings as the GB-formatted cell.
# NOTE(review): the trailing .items() makes the cell display a dict_items view
# rather than a dict; drop it if the plain mapping is what is wanted here.
calculate_memory_requirements(
    transformer=LLAMA3_400B_CONFIG,
    config=TrainingConfig(
        tp_size=1,
        pp_size=1,
        num_gpus=1,
        partition_activations=True,
        zero1=1,
        checkpoint_activations=True,
        batch_size_per_replicas=1,
        weight_dtype=Datatype.BFLOAT16,
        gradient_dtype=Datatype.BFLOAT16,
        optim_first_state_dtype=Datatype.FLOAT32,
        optim_second_state_dtype=Datatype.FLOAT32,
    ),
).items()

dict_items([('Model Memory', 811765334016.0), ('KV Cache Memory', 67645734912.0), ('Gradient Memory', 811765334016.0), ('Activation Memory', 574988746752.0), ('Optimizer Memory', 3247061336064.0), ('Communication Memory', 3000000000.0), ('Miscellaneous Memory', 0), ('Total Training Memory (GB)', 5448580750848.0), ('Total Inference Memory (GB)', 879411068928.0)])