In [None]:
import wandb
import pandas as pd
import json

In [None]:
# Collect one (model label, test loss) row per finished "medium" run from the
# listed Weights & Biases projects. The rows are accumulated in `data` and
# later turned into a DataFrame with the column names in `columns`.
api = wandb.Api()

columns = ["model", "test loss"]
data = []

projects = ["DL2/erwin-cosmology-2"]

for project in projects:
    runs = api.runs(project)
    for run in runs:
        # Only completed runs are considered.
        if run.state != "finished":
            continue

        cfg = json.loads(run.json_config)

        # Keep only runs whose config reports size == "medium". A run without
        # a "size" entry is not "medium", so it is skipped instead of raising
        # KeyError and aborting the whole collection.
        if cfg.get("size", {}).get("value") != "medium":
            continue

        model_name = run.name

        # Map the run name onto the label used in the results table.
        if "full_k" in model_name:
            model_name = "Full Attention"
        elif "base" in model_name:
            model_name = "Erwin"
        else:
            # Remaining runs are expected to be named with exactly four
            # underscore-separated fields, the third being the top-k value.
            parts = model_name.split("_")
            if len(parts) != 4:
                raise ValueError(
                    f"Cannot parse top-k from run name {run.name!r}: "
                    "expected 4 underscore-separated fields"
                )
            topk = parts[2]
            model_name = f"NSA k={topk}"

        # Request a single history row that contains the test loss.
        # NOTE(review): presumably the test loss is logged once per run, so
        # one sampled row is enough -- confirm against the training script.
        df = run.history(samples=1, keys=["test/avg/loss"])

        # A finished run that never logged the test metric yields an empty
        # frame; indexing it with iloc[0] would raise IndexError.
        if df.empty or "test/avg/loss" not in df.columns:
            print(f"Skipping run {run.name!r}: no 'test/avg/loss' in history")
            continue

        test_loss = df.iloc[0]["test/avg/loss"]

        data.append([model_name, test_loss])

In [None]:
# Turn the collected rows into a frame and aggregate the test loss per model.
# NOTE: `data` is rebound from a list of rows to a DataFrame here, so this
# cell is not safe to re-run on its own; re-run the cell that fills `data`
# first.
data = pd.DataFrame(data, columns=columns)
data = data.sort_values(by=columns).reset_index(drop=True)
# Result has a two-level column index: ("test loss", "mean") and
# ("test loss", "std"). The std is the sample standard deviation, so it is
# NaN for a model that has only a single run.
data = data.groupby("model").agg(["mean", "std"])

# Print one LaTeX table row per model: `name & \(mean \pm std\) \\`.
for model_name, score in data.iterrows():
    avg = score.loc["test loss", "mean"]
    std = score.loc["test loss", "std"]
    # Every backslash is escaped explicitly; a bare "\p" is an invalid escape
    # sequence that newer Python versions warn about.
    print(f"{model_name} & \\({avg:.3f} \\pm {std:.3f}\\) \\\\")