In [2]:
import numpy as np
import scipy.optimize as opt

# Constants from the Chinchilla scaling law
A = 406.4
B = 410.7 
alpha = 0.34
beta = 0.28

# Loss function
def loss(params, tokens, A=A, B=B, alpha=alpha, beta=beta):
    return 1.69 + A/(params**alpha) + B/(tokens**beta)

# Compute budget function (in FLOPs)
def compute_budget(params, tokens):
    return 6 * params * tokens

# Objective function to minimize (loss)
def objective(x):
    params, tokens = x
    return loss(params, tokens)

# Budget constraint function 
def budget_constraint(x):
    params, tokens = x
    return budget_70b_20t - compute_budget(params, tokens)

# Compute budget for a 70B model trained on 20T tokens
budget_70b_20t = compute_budget(70e9, 200000e9)

# Bounds for model size and dataset size
bounds = ((7e10, 1e12), (1e12, 2e13))  # 70B to 1T params, 1T to 20T tokens

# Initial guess
x0 = np.array([7e10, 2e13])  # 70B params, 20T tokens

# Solve constrained optimization problem
result = opt.minimize(objective, x0, method='SLSQP', 
                      bounds=bounds,
                      constraints=({'type': 'ineq', 'fun': budget_constraint}))

# Extract optimal parameters
params_opt, tokens_opt = result.x

print(f"Optimal model size: {params_opt/1e9:.2f} billion parameters")
print(f"Optimal dataset size: {tokens_opt/1e12:.2f} trillion tokens")
print(f"Minimum loss: {loss(params_opt, tokens_opt):.3f}")

Optimal model size: 70.00 billion parameters
Optimal dataset size: 20.00 trillion tokens
Minimum loss: 1.851
