In [6]:
import json
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.interpolate import PchipInterpolator

In [4]:
# Read the CSV that was exported from wandb
data = pd.read_csv("data.csv")

# Keep only the columns whose names do not end in __MIN, __MAX, or _step
selected_cols = [
    col for col in data.columns if not col.endswith(("__MIN", "__MAX", "_step"))
]
data = data.loc[:, selected_cols]

print(f"Data columns: {data.columns}")

Data columns: Index(['step', '(muon_moonlight+adam)_bs=2048_lr=0.01_sp=1.0 - val/loss',
       '(adam+adam)_bs=2048_lr=0.0016_sp=1.0 - val/loss',
       '(dion+lion)_bs=2048_lr=0.01_sp=1.0 - val/loss',
       '(muon_moonlight+adam)_bs=1024_lr=0.01_sp=1.0 - val/loss',
       '(dion+lion)_bs=1024_lr=0.01_sp=1.0 - val/loss',
       '(adam+adam)_bs=1024_lr=0.002_sp=1.0 - val/loss'],
      dtype='object')


In [8]:
# Get data exported to CSV


def get_column(data, name):
    """Return the non-null values of the single column whose name contains ``name``.

    Args:
        data: DataFrame to search.
        name: Substring to look for in the column names.

    Returns:
        A NumPy array holding the column's values with missing entries removed.
        Because missing entries are dropped, the result can be shorter than
        the frame and is no longer aligned row-for-row with other columns.

    Raises:
        AssertionError: If no column, or more than one column, contains ``name``.
    """
    matches = [col for col in data.columns if name in col]
    assert matches, f"Column {name} not found"
    # A substring can hit several columns; report that as ambiguity rather
    # than as a missing column.
    assert len(matches) == 1, f"Column {name} is ambiguous, matches: {matches}"
    return data[matches[0]].dropna().values


# Step axis plus one loss series per optimizer for the bs=1024 runs.
# Each pattern is matched as a substring and must select exactly one column.
step_120 = get_column(data, "step")
dion_120, muon_120, adam_120 = (
    get_column(data, pattern)
    for pattern in (
        "(dion+lion)_bs=1024",
        "(muon_moonlight+adam)_bs=1024",
        "(adam+adam)_bs=1024",
    )
)

# The bs=2048 runs reuse the same step column.
step_350 = step_120
dion_350, muon_350, adam_350 = (
    get_column(data, pattern)
    for pattern in (
        "(dion+lion)_bs=2048",
        "(muon_moonlight+adam)_bs=2048",
        "(adam+adam)_bs=2048",
    )
)

In [48]:
# Get data from log files


def load_jsonl_log(file):
    """Load a JSON Lines file into a list with one parsed object per line.

    Args:
        file: Path of the file to read.

    Returns:
        A list with the decoded JSON value of every non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8 by specification, so do not rely on the platform's
    # default encoding.
    with open(file, "r", encoding="utf-8") as f:
        # Whitespace-only lines carry no record and would make json.loads raise.
        return [json.loads(line) for line in f if line.strip()]


def get_fields(data, fields):
    """Collect the requested fields from every row that contains all of them.

    Args:
        data: Sequence of mapping-like rows.
        fields: Sequence of keys to extract.

    Returns:
        A list with one NumPy array per entry of ``fields``, in the same
        order. Rows lacking any of the requested keys are skipped, so all
        returned arrays have the same length.
    """
    complete_rows = []
    for row in data:
        if all(field in row for field in fields):
            complete_rows.append(row)
    return [np.array([row[field] for row in complete_rows]) for field in fields]


# Only the step array from the adam log is kept for each model size; the step
# arrays from the muon and dion logs are discarded into `_`.
# NOTE(review): this presumes all three logs record val/loss at the same steps - confirm.
(step_1b, adam_1b), (_, muon_1b), (_, dion_1b) = (
    get_fields(load_jsonl_log(path), ["step", "val/loss"])
    for path in ("adam1b.log", "muon1b.log", "dion1b.log")
)

(step_3b, adam_3b), (_, muon_3b), (_, dion_3b) = (
    get_fields(load_jsonl_log(path), ["step", "val/loss"])
    for path in ("adam3b.log", "muon3b.log", "dion3b.log")
)

In [60]:
def create_interp(step, losses):
    """Build an interpolator that maps a loss value to a training step.

    Args:
        step: 1-D sequence of step numbers. It may be longer than ``losses``;
            only the first ``len(losses)`` entries are used, pairing
            ``step[i]`` with ``losses[i]`` by position.
        losses: 1-D sequence of loss values.

    Returns:
        A ``PchipInterpolator`` with loss on the x-axis and step on the
        y-axis, so calling it with a target loss yields an interpolated step.

    Raises:
        ValueError: If ``step`` has fewer entries than ``losses``, or if
            ``losses`` holds fewer than two distinct values.
    """
    step = np.asarray(step)
    losses = np.asarray(losses)
    if len(step) < len(losses):
        raise ValueError(
            f"step has {len(step)} entries but losses has {len(losses)}; "
            "every loss needs a matching step"
        )
    step = step[: len(losses)]

    # PchipInterpolator needs strictly increasing x values, so a loss value
    # that occurs more than once must be collapsed to a single knot.
    # np.unique returns the sorted distinct losses together with the index of
    # each value's first occurrence in the input, so the first recorded step
    # for that loss is the one kept. When all losses are distinct this is the
    # same ordering that argsort would produce.
    unique_losses, first_idx = np.unique(losses, return_index=True)
    if len(unique_losses) < 2:
        raise ValueError("need at least two distinct loss values to interpolate")
    return PchipInterpolator(unique_losses, step[first_idx])


def interpolate(dion_interp, muon_interp, adam_interp, target_loss):
    """Print the step at which each curve reaches ``target_loss``, plus speedups.

    Args:
        dion_interp: Callable mapping a loss value to a step count (dion run).
        muon_interp: Callable mapping a loss value to a step count (muon run).
        adam_interp: Callable mapping a loss value to a step count (adam run).
        target_loss: Loss value at which all three callables are evaluated.

    Returns:
        None. The step counts and the adam/dion and adam/muon step ratios are
        written to stdout.
    """
    # All three lookups run before anything is printed.
    n_dion, n_muon, n_adam = (
        lookup(target_loss) for lookup in (dion_interp, muon_interp, adam_interp)
    )

    print(f"Interpolated steps for target loss {target_loss}:")
    print(f"  dion: {n_dion:.0f}")
    print(f"  muon: {n_muon:.0f}")
    print(f"  adam: {n_adam:.0f}")
    # Speedup is the ratio of adam's step count to the other optimizer's.
    print(f"  dion speedup: {n_adam / n_dion}")
    print(f"  muon speedup: {n_adam / n_muon}")

In [91]:
# 120M model
# Total training steps 3000
# Fit one loss -> step interpolator per optimizer, then print the step at which
# each reaches the target loss and the step ratios relative to adam.
dion_120_interp, muon_120_interp, adam_120_interp = (
    create_interp(step_120, losses) for losses in (dion_120, muon_120, adam_120)
)
interpolate(dion_120_interp, muon_120_interp, adam_120_interp, target_loss=3.52)

Interpolated steps for target loss 3.52:
  dion: 1450
  muon: 2087
  adam: 2578
  dion speedup: 1.7785704666731437
  muon speedup: 1.2353678048364678


In [62]:
# 350M model
# Total training steps 4000
# Fit one loss -> step interpolator per optimizer, then print the step at which
# each reaches the target loss and the step ratios relative to adam.
dion_350_interp, muon_350_interp, adam_350_interp = (
    create_interp(step_350, losses) for losses in (dion_350, muon_350, adam_350)
)
interpolate(dion_350_interp, muon_350_interp, adam_350_interp, target_loss=3.23)

Interpolated steps for target loss 3.23:
  dion: 1997
  muon: 2555
  adam: 3443
  dion speedup: 1.7242330385248117
  muon speedup: 1.3477499525554368


In [88]:
# 120M model
# Total training steps 3000
# Same comparison at a different target loss; the interpolators are rebuilt
# here so the cell does not depend on another analysis cell having run.
dion_120_interp, muon_120_interp, adam_120_interp = (
    create_interp(step_120, losses) for losses in (dion_120, muon_120, adam_120)
)
interpolate(dion_120_interp, muon_120_interp, adam_120_interp, target_loss=3.56)

Interpolated steps for target loss 3.56:
  dion: 1241
  muon: 1722
  adam: 2397
  dion speedup: 1.932127365400116
  muon speedup: 1.3919431211352713


In [85]:
# 350M model
# Total training steps 4000
# Same comparison at a different target loss; the interpolators are rebuilt
# here so the cell does not depend on another analysis cell having run.
dion_350_interp, muon_350_interp, adam_350_interp = (
    create_interp(step_350, losses) for losses in (dion_350, muon_350, adam_350)
)
interpolate(dion_350_interp, muon_350_interp, adam_350_interp, target_loss=3.27)

Interpolated steps for target loss 3.27:
  dion: 1674
  muon: 2122
  adam: 3198
  dion speedup: 1.9100783620862307
  muon speedup: 1.5069891181768138


In [82]:
# 1B model
# Total training steps 7500
# All three loss series are paired with the same step array, step_1b.
dion_1b_interp, muon_1b_interp, adam_1b_interp = (
    create_interp(step_1b, losses) for losses in (dion_1b, muon_1b, adam_1b)
)
interpolate(dion_1b_interp, muon_1b_interp, adam_1b_interp, target_loss=2.81)

Interpolated steps for target loss 2.81:
  dion: 3578
  muon: 4353
  adam: 6024
  dion speedup: 1.6837048128307963
  muon speedup: 1.38405210941899


In [75]:
# 3B model
# Total training steps 7500
# All three loss series are paired with the same step array, step_3b.
dion_3b_interp, muon_3b_interp, adam_3b_interp = (
    create_interp(step_3b, losses) for losses in (dion_3b, muon_3b, adam_3b)
)
interpolate(dion_3b_interp, muon_3b_interp, adam_3b_interp, target_loss=2.63)

Interpolated steps for target loss 2.63:
  dion: 4069
  muon: 4215
  adam: 6096
  dion speedup: 1.4980010825189116
  muon speedup: 1.4459815582424187
