In [1]:
from utils import calculate_weight_memory, calculate_kv_cache
from utils import convert_bytes_to_gigabytes, convert_bytes_to_megabytes
from name import TrainingConfig, Datatype, Transformer
from transformer_mem_functional import calculate_memory_requirements
from constants import LLAMA3_70B_CONFIG, LLAMA3_400B_CONFIG, VANILA_TRAINING_CONFIG, H100_MEMORY, A100_MEMORY, V100_MEMORY, MI250X_MEMORY

In [2]:
# global_batch_sizes = [x*10**6 for x in [2, 4, 10, 40]]
global_batch_sizes = [x*10**6 for x in [1, 2, 4, 8, 16, 32]]

### Total training memory

In [3]:
from copy import deepcopy
import pandas as pd
from utils import convert_to_million_format, convert_bytes_to_terabytes
from dataclasses import asdict

In [4]:
[f"{k}: {v}" for k, v in asdict(VANILA_TRAINING_CONFIG).items()]

['tp_size: 1',
 'pp_size: 1',
 'num_gpus: 1',
 'ctx_length: 8192',
 'partition_activations: False',
 'zero1: 1',
 'checkpoint_activations: False',
 'batch_size_per_replicas: 1',
 'weight_dtype: Datatype.BFLOAT16',
 'act_dtype: Datatype.BFLOAT16',
 'gradient_dtype: Datatype.BFLOAT16',
 'optim_first_state_dtype: Datatype.FLOAT32',
 'optim_second_state_dtype: Datatype.FLOAT32',
 'master_weight_dtype: Datatype.FLOAT32']

### The total memory for training/inference

In [5]:
memory_data = []

In [6]:
# batch_size = gbs // training_config.ctx_length
# for gbs in global_batch_sizes:
batch_sizes = [gbs // VANILA_TRAINING_CONFIG.ctx_length for gbs in global_batch_sizes]
batch_sizes.insert(0, 1)
global_batch_sizes.insert(0, VANILA_TRAINING_CONFIG.ctx_length)

In [7]:
NUM_GPUS_PER_REPLICAS = None
for ckp_act in [False, True]:
    for batch_size, gbs in zip(batch_sizes, global_batch_sizes):
        training_config = deepcopy(VANILA_TRAINING_CONFIG)
        training_config.batch_size_per_replicas = batch_size
        training_config.checkpoint_activations = ckp_act
    
        memory_requirements = calculate_memory_requirements(
            transformer=LLAMA3_70B_CONFIG,
            config=training_config
        )
        keys_that_convert_to_tb = ["activation_mem", "total_training_mem"]
        
        memory_requirements_gb = {}
        memory_requirements_gb["name"] = LLAMA3_70B_CONFIG.name
        memory_requirements_gb['ckp_act'] = ckp_act
        memory_requirements_gb['global_batch_size'] = convert_to_million_format(gbs)
        memory_requirements_gb['batch_size'] = batch_size
        # memory_requirements_gb = {k: convert_bytes_to_gigabytes(v) if not k in keys_that_convert_to_tb else convert_bytes_to_terabytes(v) for k, v in memory_requirements.items()}
        memory_requirements_gb.update(
            {k: f"{(convert_bytes_to_gigabytes(v) if k not in keys_that_convert_to_tb else convert_bytes_to_terabytes(v))} - {percent}%"
             for k, (v, percent) in memory_requirements.items()}
        )
    
        memory_requirements_gb["nums_gpus_for_training_without_grad_accum_(h100,a100,v100,MI250X)"] = (
            f"{int(memory_requirements['total_training_mem'][0] // H100_MEMORY):,}",
            f"{int(memory_requirements['total_training_mem'][0] // A100_MEMORY):,}",
            f"{int(memory_requirements['total_training_mem'][0] // V100_MEMORY):,}",
            f"{int(memory_requirements['total_training_mem'][0] // MI250X_MEMORY):,}",
        )
        memory_requirements_gb["nums_h100_for_inference"] = int(int(memory_requirements["total_inference_mem"][0]) // H100_MEMORY),
        memory_data.append(memory_requirements_gb)

In [8]:
memory_df = pd.DataFrame(memory_data)

In [9]:
memory_df

Unnamed: 0,name,ckp_act,global_batch_size,batch_size,model_mem,activation_mem,kv_cache_mem,grad_mem,optim_mem,total_training_mem,total_inference_mem,"nums_gpus_for_training_without_grad_accum_(h100,a100,v100,MI250X)",nums_h100_for_inference
0,llama3 70b,False,8192,1,128.854 GB - 4.4%,1.901 TB - 64.83%,21.475 GB - 0%,128.854 GB - 4.4%,773.126 GB - 26.37%,2.931 TB - 100%,150.329 GB - 0%,"(34, 42, 170, 21)","(1,)"
1,llama3 70b,False,1.0m,122,128.854 GB - 0.06%,231.864 TB - 99.56%,21.475 GB - 0%,128.854 GB - 0.06%,773.126 GB - 0.33%,232.895 TB - 100%,150.329 GB - 0%,"(2,711, 3,389, 13,556, 1,694)","(1,)"
2,llama3 70b,False,2.0m,244,128.854 GB - 0.03%,463.728 TB - 99.78%,21.475 GB - 0%,128.854 GB - 0.03%,773.126 GB - 0.17%,464.758 TB - 100%,150.329 GB - 0%,"(5,410, 6,763, 27,052, 3,381)","(1,)"
3,llama3 70b,False,4.0m,488,128.854 GB - 0.01%,927.455 TB - 99.89%,21.475 GB - 0%,128.854 GB - 0.01%,773.126 GB - 0.08%,928.486 TB - 100%,150.329 GB - 0%,"(10,809, 13,511, 54,045, 6,755)","(1,)"
4,llama3 70b,False,8.0m,976,128.854 GB - 0.01%,1854.910 TB - 99.94%,21.475 GB - 0%,128.854 GB - 0.01%,773.126 GB - 0.04%,1855.941 TB - 100%,150.329 GB - 0%,"(21,606, 27,007, 108,030, 13,503)","(1,)"
5,llama3 70b,False,16.0m,1953,128.854 GB - 0.0%,3711.721 TB - 99.97%,21.475 GB - 0%,128.854 GB - 0.0%,773.126 GB - 0.02%,3712.752 TB - 100%,150.329 GB - 0%,"(43,222, 54,027, 216,110, 27,013)","(1,)"
6,llama3 70b,False,32.0m,3906,128.854 GB - 0.0%,7423.443 TB - 99.99%,21.475 GB - 0%,128.854 GB - 0.0%,773.126 GB - 0.01%,7424.474 TB - 100%,150.329 GB - 0%,"(86,432, 108,040, 432,161, 54,020)","(1,)"
7,llama3 70b,True,8192,1,128.854 GB - 10.62%,0.183 TB - 15.04%,21.475 GB - 0%,128.854 GB - 10.62%,773.126 GB - 63.72%,1.213 TB - 100%,150.329 GB - 0%,"(14, 17, 70, 8)","(1,)"
8,llama3 70b,True,1.0m,122,128.854 GB - 0.55%,22.269 TB - 95.58%,21.475 GB - 0%,128.854 GB - 0.55%,773.126 GB - 3.32%,23.300 TB - 100%,150.329 GB - 0%,"(271, 339, 1,356, 169)","(1,)"
9,llama3 70b,True,2.0m,244,128.854 GB - 0.28%,44.539 TB - 97.74%,21.475 GB - 0%,128.854 GB - 0.28%,773.126 GB - 1.7%,45.570 TB - 100%,150.329 GB - 0%,"(530, 663, 2,652, 331)","(1,)"


### Activation memory's breakdown (the overall memory consumption, without activation recomputation)

In [58]:
from transformer_mem_functional import calculate_activation_memory

In [23]:
act_mem_data = []

In [24]:
for batch_size, gbs in zip(batch_sizes, global_batch_sizes):
    training_config.batch_size_per_replicas = batch_size

    memory_requirements = calculate_activation_memory(
        transformer=LLAMA3_70B_CONFIG,
        config=training_config
    )[1]
    
    memory_requirements_gb = {}
    memory_requirements_gb["name"] = LLAMA3_70B_CONFIG.name
    memory_requirements_gb['global_batch_size'] = convert_to_million_format(gbs)
    memory_requirements_gb['batch_size'] = batch_size
    memory_requirements_gb.update(
        {k: f"{(convert_bytes_to_gigabytes(v) if k not in keys_that_convert_to_tb else convert_bytes_to_terabytes(v))} - {percent}%"
         for k, (v, percent) in memory_requirements.items()}
    )
    act_mem_data.append(memory_requirements_gb)

In [25]:
act_mem_df = pd.DataFrame(act_mem_data)

In [26]:
act_mem_df

Unnamed: 0,name,global_batch_size,batch_size,linear_proj_input,attn_qkv_matmul,attn_qk_scores,attn_softmax,attn_dropout,attn_v,attn_drop_mask,mlp,ln
0,llama3 70b,8192,1,0.134 GB - 0.56%,0.134 GB - 0.56%,0.268 GB - 1.13%,8.590 GB - 36.16%,4.295 GB - 18.08%,8.724 GB - 36.72%,0.067 GB - 0.28%,1.275 GB - 5.37%,0.268 GB - 1.13%
1,llama3 70b,2.0m,244,32.749 GB - 0.56%,32.749 GB - 0.56%,65.498 GB - 1.13%,2095.944 GB - 36.16%,1047.972 GB - 18.08%,2128.693 GB - 36.72%,16.375 GB - 0.28%,311.117 GB - 5.37%,65.498 GB - 1.13%
2,llama3 70b,4.0m,488,65.498 GB - 0.56%,65.498 GB - 0.56%,130.997 GB - 1.13%,4191.888 GB - 36.16%,2095.944 GB - 18.08%,4257.386 GB - 36.72%,32.749 GB - 0.28%,622.233 GB - 5.37%,130.997 GB - 1.13%
3,llama3 70b,10.0m,1220,163.746 GB - 0.56%,163.746 GB - 0.56%,327.491 GB - 1.13%,10479.720 GB - 36.16%,5239.860 GB - 18.08%,10643.466 GB - 36.72%,81.873 GB - 0.28%,1555.583 GB - 5.37%,327.491 GB - 1.13%
4,llama3 70b,40.0m,4882,655.251 GB - 0.56%,655.251 GB - 0.56%,1310.502 GB - 1.13%,41936.061 GB - 36.16%,20968.030 GB - 18.08%,42591.312 GB - 36.72%,327.625 GB - 0.28%,6224.884 GB - 5.37%,1310.502 GB - 1.13%


### Datacenters

In [27]:
# Extract minimal GPUs per cluster for batch_size = 1
min_gpus_str = memory_df[memory_df['batch_size'] == 1]["nums_gpus_for_training_without_grad_accum_(h100,a100,v100,MI250X)"].item()
min_h100, min_a100, min_v100, min_mi250x = [int(x.replace(",", "")) for x in min_gpus_str]

ValueError: can only convert an array of size 1 to a Python scalar

The minimum number of GPUs required for each cluster using the following hardware: H100, A100, V100, and MI250X.

In [None]:
min_gpus_str

In [28]:
# Extract total GPUs required for global_batch_size = 2.0m
total_gpus_str = memory_df[memory_df['global_batch_size'] == "2.0m"]["nums_gpus_for_training_without_grad_accum_(h100,a100,v100,MI250X)"].item()
total_h100, total_a100, total_v100, total_mi250x = [int(x.replace(",", "")) for x in total_gpus_str]

ValueError: can only convert an array of size 1 to a Python scalar

In [29]:
gpu_names = ['H100', 'A100', 'V100', 'MI250X']
base_gpus_per_cluster = [min_h100, min_a100, min_v100, min_mi250x]
totals = [total_h100, total_a100, total_v100, total_mi250x]

NameError: name 'min_h100' is not defined

The maximum number of clusters, assuming each cluster trains with a batch size of 1, given a global batch size of 2m

In [30]:
all_dfs = []

factors = [1, 5, 10, 15]
for factor in factors:
    # Scale the gpus_per_cluster by the factor
    scaled_gpus_per_cluster = [g * factor for g in base_gpus_per_cluster]
    
    # Compute maximum number of clusters for each GPU type
    max_clusters = [t // g for t, g in zip(totals, scaled_gpus_per_cluster)]
    
    df = pd.DataFrame({
        'gpu_name': gpu_names,
        # 'factor': factor,
        'gpus_per_cluster': scaled_gpus_per_cluster,
        'maximum_number_of_datacenters': max_clusters
    })
    all_dfs.append(df)

# Concatenate all the DataFrames for a final result
result_df = pd.concat(all_dfs, ignore_index=True)

NameError: name 'base_gpus_per_cluster' is not defined

In [31]:
result_df

NameError: name 'result_df' is not defined