In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from analysis.utils import fetch_runs, get_runs_data, differing_config

In [None]:
# Select the runs tagged "ICLR-Fig1" and gather the two validation-loss
# metrics used by the cells below into `df`.
runs = fetch_runs(tags_any=["ICLR-Fig1"])
df = get_runs_data(
    runs,
    metrics=["val_loss", "teacher_val_loss"],
)

In [None]:
# Reference loss of the teacher; the plotting cell subtracts it from each
# student's val_loss. NaN entries are dropped first because `unique()` keeps
# NaN, so a missing value in the first row would otherwise become the
# reference and turn every curve into NaN.
teacher_losses = df["teacher_val_loss"].dropna().unique()
if len(teacher_losses) == 0:
    raise ValueError("No non-NaN 'teacher_val_loss' values found in df")
if len(teacher_losses) > 1:
    # One reference is applied to all runs, so make any ambiguity visible
    # instead of silently picking a value.
    print(f"Warning: {len(teacher_losses)} distinct teacher val losses found; using the first")
# Deliberately left as the scalar returned by `unique()` (not cast to float).
teacher_val_loss = teacher_losses[0]
print(f"Teacher val loss: {teacher_val_loss}")

In [None]:
# One val_loss series per (run name, query-init scale) pair. dropna=False keeps
# rows whose scale is missing as a group of their own instead of discarding them.
plt.figure(figsize=(10, 6))
groups = df.groupby(
    ["_run_name", "cfg.student.query_init_scale"],
    dropna=False,
)["val_loss"]

def sort_key(item):
    """Return the sort key for one `(group_key, group)` pair.

    `group_key[1]` is read as the scale. A missing scale (NaN or None) maps to
    `(0, 0)` and therefore sorts first; any other scale maps to
    `(1, -float(scale))`, so larger scales sort before smaller ones.
    """
    key = item[0]
    scale = key[1]
    if pd.isna(scale):
        return (0, 0)
    return (1, -float(scale))

# Draw order (and hence legend order): missing-scale baseline first, then
# decreasing scale.
groups = sorted(groups, key=sort_key)
# The np.nan key only matches the np.nan object itself (NaN != NaN), so NaN
# scales coming from the data are routed through the pd.isna branch below.
labels_map = {1e-3: "1e-3", 1e-6: "1e-6", 1e-9: "1e-9", np.nan: "Baseline", 0: "0"}

MAX_STEPS = 1000  # only the first MAX_STEPS values of each series are drawn

for (name, query_init_scale), group in groups:
    # Excess loss = student val_loss minus the teacher reference. Subtracting
    # from an ndarray is element-wise for any scalar type; the previous
    # `list - scalar` form only worked when the scalar was a NumPy type.
    data = (group.to_numpy() - teacher_val_loss)[:MAX_STEPS]
    # pd.isna (as used in sort_key) also accepts None; np.isnan raises on it.
    if pd.isna(query_init_scale):
        plt.plot(data, label=labels_map[np.nan], linestyle="--", color="black", linewidth=2)
    else:
        # Scales missing from labels_map fall back to their numeric value
        # instead of aborting the figure with a KeyError.
        label = labels_map.get(query_init_scale, f"{float(query_init_scale):g}")
        plt.plot(data, label=label, linewidth=3)

plt.ylabel("Excess Loss", fontsize=22)
plt.xlabel("Training Step", fontsize=22)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.legend(fontsize=22)

# savefig does not create missing directories, so make sure the target exists.
figure_path = Path("figures/minimal-transformer-init_scales.pdf")
figure_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_path, bbox_inches="tight")

In [None]:
# `errors="ignore"` keeps this cell re-runnable: `df` is reassigned here, so on
# a second execution the column is already gone and a plain drop raises KeyError.
# NOTE(review): presumably the column is removed because differing_config
# cannot handle its values — confirm.
df = df.drop(columns=["cfg.teacher.span_lengths"], errors="ignore")
differing_config(df)