In [12]:
import pandas as pd
import numpy as np

In [13]:
A = 406.4
B = 410.7

alpha = 0.34
beta = 0.28

In [14]:
# loss as a function of params and data
# cf. eqn 10, https://arxiv.org/pdf/2203.15556.pdf

def calculate_loss(n, d, A_=A, B_=B, alpha_=alpha, beta_=beta):
    # n, d in billions
    n = n
    d = d
    model_size_contribution = A_/(n ** alpha_)
    dataset_size_contribution = B_/(d ** beta_)
    loss = 1.69 + model_size_contribution + dataset_size_contribution
    return loss, (model_size_contribution, dataset_size_contribution)

In [15]:
def calculate_flops(n, d):
    # n, d in billions
    n = n / 1e9
    d = d / 1e9
    return 6 * n * d

In [16]:
# Base model parameters
base_model_size = 70*10**9  # in billions
base_tokens = 1400*10**9  # in billions
base_loss, (base_msc, base_dsc) = calculate_loss(base_model_size, base_tokens)
base_flops = calculate_flops(base_model_size, base_tokens)

In [17]:
base_loss, base_msc, base_dsc

(1.9366454705587173, 0.08348729030772284, 0.1631581802509945)

In [18]:
from utils import convert_to_billion_format, convert_to_xt_format

In [23]:
# Create array of token values from 100B to 20T
tokens = np.array([x*10**9 for x in [1500, 2000, 4000, 10000, 20000]])
# model_sizes = np.array([x*10**9 for x in [70, 140, 280, 560, 1120]])
model_sizes = np.array([x*10**9 for x in [70, 100]])

# Calculate values for each combination
results = []
for m in model_sizes:
    for t in tokens:
        # t_billions = t * 1000  # convert trillion to billion
        loss, (msc, dsc) = calculate_loss(m, t)

        # Calculate percentage contributions
        overall_improvement = base_loss - loss
        msc_improvement = base_msc - msc
        dsc_improvement = base_dsc - dsc
        
        msc_percent = (msc_improvement / overall_improvement) * 100
        dsc_percent = (dsc_improvement / overall_improvement) * 100

        # Calculate FLOPS and percentage increase in FLOPS
        flops = calculate_flops(m, t)
        if flops < 0:
            assert 1 == 1
        flops_increase_percent = ((flops - base_flops) / base_flops)

        # Calculate rate of increase in FLOPS / rate of increase in overall loss
        flop_increase_ratio = flops/base_flops
        ratio_model_size_increase_vs_model_loss_increase = (m/base_model_size) / msc_improvement
        ratio_data_size_increase_vs_data_loss_increase = (t/base_tokens) / dsc_improvement
        # loss_decrease_ratio = overall_improvement
        # rate_increase_flops_loss = overall_improvement / flops_increase_percent

        results.append([
            convert_to_billion_format(m), convert_to_xt_format(t),
            loss, msc, dsc,
            overall_improvement, f"{msc_improvement:.3f} - {msc_percent:.2f}%", f"{dsc_improvement:.3f} - {dsc_percent:.2f}%",
            # msc_percent, dsc_percent,
            # flops,
            flops_increase_percent,
            ratio_model_size_increase_vs_model_loss_increase,
            ratio_data_size_increase_vs_data_loss_increase,
            # rate_increase_flops_loss
        ])

  ratio_model_size_increase_vs_model_loss_increase = (m/base_model_size) / msc_improvement


In [24]:
# Create DataFrame
df = pd.DataFrame(results, columns=[
    'Model Size (B)', 'Tokens (B)',
    'Loss', 'Model Size\'s loss', 'Dataset Size\'s loss',
    "Overall Improvement", "Model Size\'s improvement in loss", "Dataset Size\'s improvement in loss", 
    # 'Model Size Contribution %', 'Dataset Size Contribution %',
    # 'FLOPS',
    'FLOPS Increase %',
    "ratio_model_size_increase_vs_model_loss_increase",
    "ratio_data_size_increase_vs_data_loss_increase"
    # 'Rate of Increase (FLOPS / Loss)'
])

# Format numbers to 4 decimal places
df = df.round(7)

In [25]:
base_loss, base_msc, base_dsc

(1.9366454705587173, 0.08348729030772284, 0.1631581802509945)

In [26]:
df

Unnamed: 0,Model Size (B),Tokens (B),Loss,Model Size's loss,Dataset Size's loss,Overall Improvement,Model Size's improvement in loss,Dataset Size's improvement in loss,FLOPS Increase %,ratio_model_size_increase_vs_model_loss_increase,ratio_data_size_increase_vs_data_loss_increase
0,70.00B,1.5T,1.933524,0.083487,0.160036,0.003122,0.000 - 0.00%,0.003 - 100.00%,0.071429,inf,343.22603
1,70.00B,2.0T,1.921138,0.083487,0.147651,0.015507,0.000 - 0.00%,0.016 - 100.00%,0.428571,inf,92.123044
2,70.00B,4.0T,1.895091,0.083487,0.121604,0.041554,0.000 - 0.00%,0.042 - 100.00%,1.857143,inf,68.757071
3,70.00B,10.0T,1.867573,0.083487,0.094086,0.069072,0.000 - 0.00%,0.069 - 100.00%,6.142857,inf,103.411393
4,70.00B,20.0T,1.850976,0.083487,0.077488,0.08567,0.000 - 0.00%,0.086 - 100.00%,13.285714,inf,166.753115
5,100.00B,1.5T,1.923989,0.073953,0.160036,0.012656,0.010 - 75.34%,0.003 - 24.66%,0.530612,149.829497,343.22603
6,100.00B,2.0T,1.911604,0.073953,0.147651,0.025042,0.010 - 38.07%,0.016 - 61.93%,1.040816,149.829497,92.123044
7,100.00B,4.0T,1.885557,0.073953,0.121604,0.051089,0.010 - 18.66%,0.042 - 81.34%,3.081633,149.829497,68.757071
8,100.00B,10.0T,1.858039,0.073953,0.094086,0.078607,0.010 - 12.13%,0.069 - 87.87%,9.204082,149.829497,103.411393
9,100.00B,20.0T,1.841441,0.073953,0.077488,0.095204,0.010 - 10.01%,0.086 - 89.99%,19.408163,149.829497,166.753115
