In [1]:
from utils import (
    get_dataset_size_from_model_size, calculate_total_steps, calculate_total_flops, calculate_flops_per_step,
    calculate_num_h100s_per_step, calculate_total_time_to_train_a_model,
    compute_minimum_latency_between_clusters, calculate_total_minimum_comm_latency_to_train_a_model,
    get_time_per_step, get_training_cost
)
from utils import (
    convert_to_petaflops, convert_to_exaflops, convert_seconds_to_days,
    convert_to_xt_format, convert_to_million_format, convert_to_billion_format, convert_seconds_to_years
)
from constants import UTILIZED_BFLOAT16_FLOPS, H100_COST_PER_HOUR, H100_COST_PER_GPU
import pandas as pd
import numpy as np

In [2]:
# target_model_sizes = [70*10**9, 100*10**9, 500*10**9, 1000*10**9, 10000*10**9, 100000*10**9]
target_model_sizes = [70*10**9]

#### Critical batch size

In [3]:
from utils import get_critical_batch_size, convert_to_million_format

In [4]:
critical_batch_sizes = [get_critical_batch_size(x) * 4096 for x in target_model_sizes]

In [5]:
df_gbs = pd.DataFrame({
    'Model Size (B params)': [convert_to_xt_format(x) for x in target_model_sizes],
    'Critical Batch Size (tokens)': [f'{x/1e6:1f}M' for x in critical_batch_sizes]
})

In [6]:
df_gbs

Unnamed: 0,Model Size (B params),Critical Batch Size (tokens)
0,70.00B,4.218658M


#### Compute

In [7]:
from constants import CTX_LENGTH

In [8]:
# 100b to 100T
# global_batch_sizes = [x*10**6 for x in [2, 10, 40]]
global_batch_sizes = [x*10**6 for x in [1, 2, 4, 8, 16, 32]]
global_batch_sizes.insert(0, CTX_LENGTH)

In [9]:
global_batch_sizes

[8192, 1000000, 2000000, 4000000, 8000000, 16000000, 32000000]

In [10]:
time_per_step = 1 # the total time of a fwd, and bwd pass

In [11]:
from constants import VANILA_TRAINING_CONFIG, LLAMA3_70B_CONFIG, H100_MEMORY
from copy import deepcopy
from transformer_mem_functional import calculate_memory_requirements

In [12]:
dataframes_compute = []

for is_train_for_20T in [False, True]:
    for global_batch_size in global_batch_sizes:
        data_compute = {
            "Model Size (Params)": [],
            "Dataset Size (Tokens)": [],
            "gbs": [],
            "Total Steps": [],
            "Total FLOPs": [],
            "FLOPs per Step": [],
            "H100 GPUs Needed": [],
            "Total H100s cost": [],
            "training_time_and_mem_with_0_grad_accum": [],
            "training_time_and_mem_with_5_grad_accum": [],
            "training_time_and_mem_with_10_grad_accum": [],
            "training_time_and_mem_with_100_grad_accum": [],
            # "Total Training Time with 1000 grad accum": []
        }
        
        for model_size in target_model_sizes:
            if is_train_for_20T is True:
                if model_size > 100*10**9:
                    # NOTE: only check 20T tokens for model that is less than 10B
                    continue
                
                dataset_size = 20000*10**9
            else:
                dataset_size = get_dataset_size_from_model_size(model_size)

            def compute_training_memory_with_grad_accum(gbs, num_grad_accum):
                training_config = deepcopy(VANILA_TRAINING_CONFIG)
                training_config.batch_size_per_replicas = gbs // num_grad_accum
                training_config.checkpoint_activations = True

                training_memory = calculate_memory_requirements(
                    transformer=LLAMA3_70B_CONFIG,
                    config=training_config
                )["total_training_mem"][0]
                return training_memory
            
            training_mem_with_0_grad_accum = compute_training_memory_with_grad_accum(global_batch_size, 1)
            training_mem_with_5_grad_accum = compute_training_memory_with_grad_accum(global_batch_size, 5)
            training_mem_with_10_grad_accum = compute_training_memory_with_grad_accum(global_batch_size, 10)
            training_mem_with_100_grad_accum = compute_training_memory_with_grad_accum(global_batch_size, 100)

            num_gpus_with_5_grad_accum = training_mem_with_5_grad_accum / H100_MEMORY
            num_gpus_with_10_grad_accum = training_mem_with_10_grad_accum / H100_MEMORY
            num_gpus_with_100_grad_accum = training_mem_with_100_grad_accum / H100_MEMORY
                
            total_steps = calculate_total_steps(model_size, dataset_size, global_batch_size)
            total_flops = calculate_total_flops(model_size, dataset_size)
            flops_per_step = calculate_flops_per_step(model_size, global_batch_size)
            h100s_per_step = calculate_num_h100s_per_step(model_size, global_batch_size, UTILIZED_BFLOAT16_FLOPS)
    
            time_per_step = get_time_per_step(model_size, global_batch_size, UTILIZED_BFLOAT16_FLOPS)
            total_time = calculate_total_time_to_train_a_model(model_size, dataset_size, global_batch_size, time_per_step)
            total_training_cost = get_training_cost(model_size, dataset_size, global_batch_size, UTILIZED_BFLOAT16_FLOPS, H100_COST_PER_HOUR)
            
            data_compute["Model Size (Params)"].append(convert_to_xt_format(model_size))
            data_compute["Dataset Size (Tokens)"].append(convert_to_xt_format(dataset_size))
            data_compute["gbs"].append(f'{global_batch_size/1e6}M')
            data_compute["Total Steps"].append("{:,}".format(total_steps))
            data_compute["Total FLOPs"].append(convert_to_exaflops(total_flops))
            data_compute["FLOPs per Step"].append(convert_to_petaflops(flops_per_step))
            data_compute["H100 GPUs Needed"].append(h100s_per_step)
            # data_compute["Total H100s cost"].append(convert_to_billion_format(h100s_per_step * H100_COST_PER_GPU))
            data_compute["Total H100s cost"].append(convert_to_million_format(total_training_cost))
            data_compute["training_time_and_mem_with_0_grad_accum"].append(convert_seconds_to_days(total_time))
            # data_compute["training_time_and_mem_with_5_grad_accum"].append(f"{convert_seconds_to_years(total_time*5)} | {h100s_per_step/5} gpus | {num_gpus_with_5_grad_accum} gpus")
            data_compute["training_time_and_mem_with_5_grad_accum"].append(f"{convert_seconds_to_years(total_time*5)} | {h100s_per_step/5} gpus")
            data_compute["training_time_and_mem_with_10_grad_accum"].append(f"{convert_seconds_to_years(total_time*10)} | {h100s_per_step/10} gpus")
            data_compute["training_time_and_mem_with_100_grad_accum"].append(f"{convert_seconds_to_years(total_time*100)} | {h100s_per_step/100} gpus")
            # data_compute["Total Training Time with 1000 grad accum"].append(f"{convert_seconds_to_years(total_time*1000)} - {h100s_per_step/1000} gpus")
        
        df = pd.DataFrame(data_compute)
        # Add batch size information
        # df['Global Batch Size'] = f'{global_batch_size/1e6}M'
        dataframes_compute.append(df)
    
    final_df_compute = pd.DataFrame()
    for i, df in enumerate(dataframes_compute):
        final_df_compute = pd.concat([final_df_compute, df])

In [13]:
final_df_compute

Unnamed: 0,Model Size (Params),Dataset Size (Tokens),gbs,Total Steps,Total FLOPs,FLOPs per Step,H100 GPUs Needed,Total H100s cost,training_time_and_mem_with_0_grad_accum,training_time_and_mem_with_5_grad_accum,training_time_and_mem_with_10_grad_accum,training_time_and_mem_with_100_grad_accum
0,70.00B,1.4T,0.008192M,170898437,"588,000.0 EFLOPs",3.44064 PFLOPs,8.0,0.8m,2150.4 days,29.44 years | 1.6 gpus,58.87 years | 0.8 gpus,588.74 years | 0.08 gpus
0,70.00B,1.4T,1.0M,1400000,"588,000.0 EFLOPs",420.0 PFLOPs,1061.0,0.8m,16.2 days,0.22 years | 212.2 gpus,0.44 years | 106.1 gpus,4.44 years | 10.61 gpus
0,70.00B,1.4T,2.0M,700000,"588,000.0 EFLOPs",840.0 PFLOPs,2123.0,0.8m,8.1 days,0.11 years | 424.6 gpus,0.22 years | 212.3 gpus,2.22 years | 21.23 gpus
0,70.00B,1.4T,4.0M,350000,"588,000.0 EFLOPs","1,680.0 PFLOPs",4246.0,0.8m,4.1 days,0.06 years | 849.2 gpus,0.11 years | 424.6 gpus,1.11 years | 42.46 gpus
0,70.00B,1.4T,8.0M,175000,"588,000.0 EFLOPs","3,360.0 PFLOPs",8493.0,0.8m,2.0 days,0.03 years | 1698.6 gpus,0.06 years | 849.3 gpus,0.55 years | 84.93 gpus
0,70.00B,1.4T,16.0M,87500,"588,000.0 EFLOPs","6,720.0 PFLOPs",16986.0,0.8m,1.0 days,0.01 years | 3397.2 gpus,0.03 years | 1698.6 gpus,0.28 years | 169.86 gpus
0,70.00B,1.4T,32.0M,43750,"588,000.0 EFLOPs","13,440.0 PFLOPs",33973.0,0.8m,0.5 days,0.01 years | 6794.6 gpus,0.01 years | 3397.3 gpus,0.14 years | 339.73 gpus
0,70.00B,20.0T,0.008192M,2441406250,"8,400,000.0 EFLOPs",3.44064 PFLOPs,8.0,11.8m,30719.9 days,420.53 years | 1.6 gpus,841.06 years | 0.8 gpus,8410.64 years | 0.08 gpus
0,70.00B,20.0T,1.0M,20000000,"8,400,000.0 EFLOPs",420.0 PFLOPs,1061.0,11.8m,231.6 days,3.17 years | 212.2 gpus,6.34 years | 106.1 gpus,63.42 years | 10.61 gpus
0,70.00B,20.0T,2.0M,10000000,"8,400,000.0 EFLOPs",840.0 PFLOPs,2123.0,11.8m,115.8 days,1.58 years | 424.6 gpus,3.17 years | 212.3 gpus,31.69 years | 21.23 gpus


##### Communication time of data parallelism

In [14]:
from constants import FP8_BYTES, BFLOAT16_BYTES
from utils import convert_bytes_to_terabytes
from utils import calculate_comm_time_given_comm_volume, convert_bytes_to_gigabytes
from constants import NVLINK_MAX_TOTAL_BANDWIDTH

Assume that fwd+bwd pass of a single replicas takes 1 seconds

In [15]:
# comm_bandwidths = [0.5*1024**3, 1*1024**3, 4*1024**3] # bytes/sec
# comm_bandwidths = [40*1024**3, NVLINK_MAX_TOTAL_BANDWIDTH] # bytes/sec
comm_bandwidths = [10*1024**3] # bytes/sec
# divide by 10 data centers
# comm_bandwidths = [(4/10)*1024**3] # bytes/sec
# cluster_sizes = [1024, 10240, 102400]
NUM_DATA_CENTERS = 10

data_mem = {
    "Model Size (Params)": [],
    "Global batch size": [],
    # "Number of datacenters": [],
    "Total bfloat16 gradient storage": [],
    "Total fp8 gradient storage": [],
    "Bandwidth": [],
    "Total communication time in bfloat16 - comm/compute ratio": [],
    "Total communication time in fp8 - comm/compute ratio": [],
    "Total GPU idle cost for bfloat16 comm": [],
    "Total GPU idle cost for fp8 comm": [],
    "DiLoCo's total communication time in bfloat16 (500 inner steps) - comm/compute ratio": []
}
# for cluster_size in cluster_sizes:
for bandwidth in comm_bandwidths:
    for global_batch_size in global_batch_sizes:
        for model_size in target_model_sizes:
            _ds_size = 20_000*10**9
            # _ds_size = 100*10**9
            total_steps = calculate_total_steps(model_size, dataset_size=_ds_size, global_batch_size=global_batch_size)
            time_per_step = get_time_per_step(model_size, global_batch_size, UTILIZED_BFLOAT16_FLOPS)
            total_time = calculate_total_time_to_train_a_model(model_size, _ds_size, global_batch_size, time_per_step)
            
            # h100s_per_step = calculate_num_h100s_per_step(model_size, global_batch_size, UTILIZED_BFLOAT16_FLOPS)
            # num_clusters = h100s_per_step // cluster_size

            bfloat16_grad_comm_volume = model_size * BFLOAT16_BYTES
            fp8_grad_comm_volume = model_size * FP8_BYTES
            
            bfloat16_total_comm_time = calculate_comm_time_given_comm_volume(bfloat16_grad_comm_volume, bandwidth, NUM_DATA_CENTERS) * total_steps
            fp8_total_comm_time = calculate_comm_time_given_comm_volume(fp8_grad_comm_volume, bandwidth, NUM_DATA_CENTERS) * total_steps
            bfloat16_diloco_total_comm_time = calculate_comm_time_given_comm_volume(fp8_grad_comm_volume, bandwidth, NUM_DATA_CENTERS) * (total_steps / 500)

            bfloat16_comm_compute_ratio = (bfloat16_total_comm_time / (total_time + bfloat16_total_comm_time)) * 100
            fp8_comm_compute_ratio = (fp8_total_comm_time / (total_time + fp8_total_comm_time)) * 100
            bfloat16_diloco_comm_compute_ratio = (bfloat16_diloco_total_comm_time / (total_time + bfloat16_diloco_total_comm_time)) * 100
            
            bfloat16_total_gpu_idle_cost_comm = ((bfloat16_total_comm_time * h100s_per_step) / (60*60)) / H100_COST_PER_HOUR
            fp8_total_gpu_idle_cost_comm = ((fp8_total_comm_time * h100s_per_step) / (60*60)) / H100_COST_PER_HOUR
            
            data_mem["Model Size (Params)"].append(convert_to_xt_format(model_size))
            data_mem["Global batch size"].append(f'{global_batch_size/1e6}M')
            # data_mem["Number of datacenters"].append(num_clusters)
            data_mem["Total bfloat16 gradient storage"].append(convert_bytes_to_terabytes(model_size * BFLOAT16_BYTES))
            data_mem["Total fp8 gradient storage"].append(convert_bytes_to_terabytes(model_size * FP8_BYTES))
            data_mem["Bandwidth"].append(f"{convert_bytes_to_gigabytes(bandwidth)}/s")
            data_mem["Total communication time in bfloat16 - comm/compute ratio"].append(f"{convert_seconds_to_days(bfloat16_total_comm_time)} / {convert_seconds_to_years(bfloat16_total_comm_time)} - {bfloat16_comm_compute_ratio:.2f}%")
            data_mem["Total communication time in fp8 - comm/compute ratio"].append(f"{convert_seconds_to_days(bfloat16_total_comm_time)} / {convert_seconds_to_years(fp8_total_comm_time)} - {fp8_comm_compute_ratio:.2f}%")
            data_mem["Total GPU idle cost for bfloat16 comm"].append(convert_to_billion_format(bfloat16_total_gpu_idle_cost_comm))
            data_mem["Total GPU idle cost for fp8 comm"].append(convert_to_billion_format(fp8_total_gpu_idle_cost_comm))
            if global_batch_size == 2*10**6:
                assert 1 == 1

            data_mem["DiLoCo's total communication time in bfloat16 (500 inner steps) - comm/compute ratio"].append(f"{convert_seconds_to_days(bfloat16_diloco_total_comm_time)}  - {bfloat16_diloco_comm_compute_ratio:.2f}%")
    
    df_mem = pd.DataFrame(data_mem)

In [16]:
df_mem

Unnamed: 0,Model Size (Params),Global batch size,Total bfloat16 gradient storage,Total fp8 gradient storage,Bandwidth,Total communication time in bfloat16 - comm/compute ratio,Total communication time in fp8 - comm/compute ratio,Total GPU idle cost for bfloat16 comm,Total GPU idle cost for fp8 comm,DiLoCo's total communication time in bfloat16 (500 inner steps) - comm/compute ratio
0,70.00B,0.008192M,0.140 TB,0.070 TB,10.737 GB/s,3684295.7 days / 10087.05 years - 99.17%,3684295.7 days / 5043.53 years - 98.36%,"1,502.00B",751.00B,3684.3 days - 10.71%
1,70.00B,1.0M,0.140 TB,0.070 TB,10.737 GB/s,30181.8 days / 82.63 years - 99.24%,30181.8 days / 41.32 years - 98.49%,12.30B,6.15B,30.2 days - 11.53%
2,70.00B,2.0M,0.140 TB,0.070 TB,10.737 GB/s,15090.9 days / 41.32 years - 99.24%,15090.9 days / 20.66 years - 98.49%,6.15B,3.08B,15.1 days - 11.53%
3,70.00B,4.0M,0.140 TB,0.070 TB,10.737 GB/s,7545.4 days / 20.66 years - 99.24%,7545.4 days / 10.33 years - 98.49%,3.08B,1.54B,7.5 days - 11.53%
4,70.00B,8.0M,0.140 TB,0.070 TB,10.737 GB/s,3772.7 days / 10.33 years - 99.24%,3772.7 days / 5.16 years - 98.49%,1.54B,0.77B,3.8 days - 11.53%
5,70.00B,16.0M,0.140 TB,0.070 TB,10.737 GB/s,1886.4 days / 5.16 years - 99.24%,1886.4 days / 2.58 years - 98.49%,0.77B,0.38B,1.9 days - 11.53%
6,70.00B,32.0M,0.140 TB,0.070 TB,10.737 GB/s,943.2 days / 2.58 years - 99.24%,943.2 days / 1.29 years - 98.49%,0.38B,0.19B,0.9 days - 11.53%


#### Communication latency (theoretical minimum)

Assumptions on communication
- No limit on banwidth
- Achieve speed of light
- Clostest surface distance between two points on the earth surface (assume you don't dig a crazy hole to go a straight line)

In [17]:
minimum_latency_between_jz_and_jc = compute_minimum_latency_between_clusters("JEAN_ZAY", "JOLIOT_CURIE")
minimum_latency_between_jz_and_ec = compute_minimum_latency_between_clusters("JEAN_ZAY", "EL_CAPITAN")

In [18]:
dataframes_comm = []

for global_batch_size in global_batch_sizes:
    data_comm = {
        "Model Size (Params)": [],
        "Dataset Size (Tokens)": [],
        "Global Batch Size": [],
        
        "Total minimum communication latency between JZ and JC": [],
        "Total GPU idle time for minimum comm between JZ and JC": [],
        "GPU idle cost during JZ-JC minimum communication latency": [],
        
        "Total minimum communication latency between JZ and EC": [],
        "Total GPU idle time for minimum comm between JZ and EC": [],
        "GPU idle cost during JZ-EC minimum communication latency": [],
    }
    
    for model_size in target_model_sizes:
        dataset_size = get_dataset_size_from_model_size(model_size)
        # Append to dictionary
        data_comm["Model Size (Params)"].append(convert_to_xt_format(model_size))
        data_comm["Dataset Size (Tokens)"].append(convert_to_xt_format(dataset_size))
        data_comm["Global Batch Size"].append(f'{global_batch_size/1e6}M')
        
        data_comm["Total minimum communication latency between JZ and JC"].append(calculate_total_minimum_comm_latency_to_train_a_model(model_size, global_batch_size, minimum_latency_between_jz_and_jc))
        data_comm["Total GPU idle time for minimum comm between JZ and JC"].append(convert_seconds_to_days(calculate_total_minimum_comm_latency_to_train_a_model(model_size, global_batch_size, minimum_latency_between_jz_and_jc) * h100s_per_step))
        data_comm["GPU idle cost during JZ-JC minimum communication latency"].append(convert_to_million_format((calculate_total_minimum_comm_latency_to_train_a_model(model_size, global_batch_size, minimum_latency_between_jz_and_jc) / (60 * 60)) * h100s_per_step * H100_COST_PER_HOUR))
        
        data_comm["Total minimum communication latency between JZ and EC"].append(calculate_total_minimum_comm_latency_to_train_a_model(model_size, global_batch_size, minimum_latency_between_jz_and_ec))
        data_comm["Total GPU idle time for minimum comm between JZ and EC"].append(convert_seconds_to_years(calculate_total_minimum_comm_latency_to_train_a_model(model_size, global_batch_size, minimum_latency_between_jz_and_ec) * h100s_per_step))
        data_comm["GPU idle cost during JZ-EC minimum communication latency"].append(convert_to_million_format((calculate_total_minimum_comm_latency_to_train_a_model(model_size, global_batch_size, minimum_latency_between_jz_and_ec) / (60 * 60)) * h100s_per_step * H100_COST_PER_HOUR))
    
    # Convert to DataFrame
    df_comm = pd.DataFrame(data_comm)
    # df_comm['Global Batch Size'] = f'{global_batch_size/1e6}M'
    dataframes_comm.append(df_comm)

final_df_comm = pd.DataFrame()
for i, df in enumerate(dataframes_comm):
    final_df_comm = pd.concat([final_df_comm, df])

TypeError: calculate_total_steps() missing 1 required positional argument: 'global_batch_size'

In [None]:
final_df_comm

#### Electricity

In [19]:
from constants import TOTAL_H100_WATT
from utils import convert_watts_to_megawatts, convert_watts_to_terawatts, calculate_electricity_consumption_of_an_h100

In [18]:
dataframes_elec = []

for global_batch_size in global_batch_sizes:
    data_elec = {
        "Model Size (Params)": [],
        "Dataset Size (Tokens)": [],
        "Global batch size": [],
        "Number of GPUs": [],
        "Total electricity per step (without grad accum)": [],
        "Total electricity for the entire training (without grad accum)": []
    }
    
    for model_size in target_model_sizes:
        dataset_size = get_dataset_size_from_model_size(model_size)
        h100s_per_step = calculate_num_h100s_per_step(model_size, global_batch_size, UTILIZED_BFLOAT16_FLOPS)
        time_per_step = get_time_per_step(model_size, global_batch_size, UTILIZED_BFLOAT16_FLOPS)
        total_time = calculate_total_time_to_train_a_model(model_size, global_batch_size, time_per_step)
        total_electricity_consumption = calculate_electricity_consumption_of_an_h100(TOTAL_H100_WATT, total_time) * h100s_per_step
        
        data_elec["Model Size (Params)"].append(convert_to_xt_format(model_size))
        data_elec["Dataset Size (Tokens)"].append(convert_to_xt_format(dataset_size))
        data_elec["Global batch size"].append(f'{global_batch_size/1e6}M')
        data_elec["Number of GPUs"].append(h100s_per_step)
        data_elec["Total electricity per step (without grad accum)"].append(convert_watts_to_megawatts(h100s_per_step * TOTAL_H100_WATT))
        data_elec["Total electricity for the entire training (without grad accum)"].append(f"{convert_watts_to_terawatts(total_electricity_consumption)}")
    
    df = pd.DataFrame(data_elec)
    # df['Global Batch Size'] = f'{global_batch_size/1e6}M'
    dataframes_elec.append(df)

final_df_elec = pd.DataFrame()
for i, df in enumerate(dataframes_elec):
    final_df_elec = pd.concat([final_df_elec, df])

In [19]:
# my calculation closes to 100k gpu cluster's electricity: https://semianalysis.com/2024/06/17/100000-h100-clusters-power-network/#power-challenges

In [None]:
final_df_elec