In [None]:
import wandb

wandb.login()

api = wandb.Api()

# W&B workspace to read runs from; point these at your own entity/project.
entity_name = "natalieabreu"  # Replace with your wandb entity/username
project_name = "dion"  # Replace with your project name

# Keep only runs whose state is "finished", materialized as a list because
# later cells iterate over it again.
runs = [
    run
    for run in api.runs(f"{entity_name}/{project_name}")
    if run.state == "finished"
]

for run in runs:
    print(f"Run ID: {run.id}, Name: {run.name}, State: {run.state}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def plot(
    runs,
    wandb_api=None,
    entity=None,
    project=None,
    out_path="lion_vs_adam.pdf",
):
    """Plot validation loss vs. step for each selected run, save it as a PDF and show it.

    Each row of ``runs`` is fetched from W&B by its ``"Run ID"``, and the
    run's logged ``val/loss`` is drawn against ``step`` on log-log axes. The
    legend label is derived from the row's ``"Optimizer"`` and
    ``"Adam for scalar"`` values (e.g. ``("dion", True)`` -> ``"Dion+Adam"``).
    Combinations not in the label table are shown as ``"Unknown"``.

    Args:
        runs: DataFrame with at least the columns ``"Run ID"``,
            ``"Optimizer"`` and ``"Adam for scalar"``.
        wandb_api: W&B API client used to fetch runs. Defaults to the
            notebook-level ``api`` (looked up at call time).
        entity: W&B entity. Defaults to the notebook-level ``entity_name``.
        project: W&B project. Defaults to the notebook-level ``project_name``.
        out_path: File the figure is saved to, in PDF format.

    Side effects:
        Changes the global matplotlib rc settings for tick, axis-label and
        legend font sizes; draws on the current pyplot figure; writes
        ``out_path``; calls ``plt.show()``.
    """
    # Resolve at call time so ``plot(df)`` keeps using the notebook globals,
    # while callers can now pass these dependencies explicitly.
    wandb_api = api if wandb_api is None else wandb_api
    entity = entity_name if entity is None else entity
    project = project_name if project is None else project

    large_font = 14
    small_font = 12
    plt.rc("xtick", labelsize=small_font)
    plt.rc("ytick", labelsize=small_font)
    plt.rc("axes", labelsize=large_font)
    plt.rc("legend", fontsize=large_font)

    # Keys are f"{optimizer}_{adam_for_scalar}"; per this table, True is
    # labelled "+Adam" and False "+Lion".
    labels = {
        "dion_True": "Dion+Adam",
        "dion_False": "Dion+Lion",
        "muon_True": "Muon+Adam",
        "muon_False": "Muon+Lion",
    }

    # One line per selected run.
    for _, row in runs.iterrows():
        run = wandb_api.run(f"{entity}/{project}/{row['Run ID']}")
        history = run.history(keys=["step", "val/loss"])
        # A run that never logged these keys can come back as an empty frame
        # (possibly without these columns); skip it instead of raising KeyError.
        if history.empty:
            continue

        lbl_key = f"{row['Optimizer']}_{row['Adam for scalar']}"
        plt.plot(
            history["step"],
            history["val/loss"],
            label=labels.get(lbl_key, "Unknown"),
        )

    plt.xscale("log")
    plt.yscale("log")

    ax = plt.gca()
    # Hide log-scale minor ticks; only the fixed major ticks below are shown.
    ax.yaxis.set_minor_locator(plt.NullLocator())
    ax.xaxis.set_minor_locator(plt.NullLocator())

    # Loss axis: ticks every 0.2 from 3.0 to 4.0, printed with one decimal.
    plt.yticks(np.arange(3, 4.1, 0.2))
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: "{:.1f}".format(x)))
    plt.ylim(3, 4.1)

    # Step axis: ticks every 500 up to 3000. Start at 500 rather than 0:
    # 0 has no position on a log axis, and set_ticks may expand the view
    # limits to include every tick it is given.
    plt.xticks(np.arange(5e2, 3.1e3, 5e2))
    ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: "{:.0f}".format(x)))

    plt.grid(axis="both")
    plt.xlabel("Step")
    plt.ylabel("Validation Loss")
    plt.legend()
    plt.savefig(out_path, bbox_inches="tight", format="pdf")
    plt.show()

In [None]:
import pandas as pd

def _summarize_run(run):
    """Flatten one W&B run into a table row.

    The row holds selected config hyperparameters (with the fallback values
    given below when a key is missing) plus the run summary's ``val/loss``.
    """
    cfg = run.config
    return {
        "Run ID": run.id,
        "Optimizer": cfg.get("optimizer", None),
        "Learning Rate": cfg.get("lr", None),
        "Num Iterations": cfg.get("num_iterations", None),
        "LR Schedule": cfg.get("lr_schedule", None),
        "Mu": cfg.get("mu", 0.95),
        "Weight Decay": cfg.get("weight_decay", 0),
        "Sparsity": cfg.get("sparsity", 1),
        "Loss": run.summary.get("val/loss", None),
        "Batch size": cfg.get("batch_size", None),
        "Training set": cfg.get("input_bin", None),
        "Adam for scalar": cfg.get("use_adam_for_scalar", True),
    }


# One record per finished run, assembled into a single DataFrame.
data = [_summarize_run(run) for run in runs]
runs_df = pd.DataFrame(data)

print(runs_df)

In [None]:
# Selection criteria: a fixed batch size, sparsity and iteration count, and
# only the Muon and Dion optimizers.
bsz = 1024
sparsity = 1
num_iterations = 3000
optimizers = ["muon", "dion"]

# Keep only rows that satisfy every criterion at once.
filtered_df = runs_df.loc[
    runs_df["Optimizer"].isin(optimizers)
    & runs_df["Num Iterations"].eq(num_iterations)
    & runs_df["Sparsity"].eq(sparsity)
    & runs_df["Batch size"].eq(bsz)
]

# Number of candidate runs for each "Adam for scalar" setting.
print(filtered_df["Adam for scalar"].value_counts())

In [None]:
# For each (Optimizer, Adam for scalar) combination, keep the run with the
# lowest final validation loss.
# - Rows without a logged loss are dropped first, so a group whose losses are
#   all missing is skipped instead of making idxmin fail.
# - Selecting the rows returned by the grouped idxmin keeps every column,
#   including the grouping keys that `plot` reads. The old
#   `groupby(...).apply(lambda x: x.loc[...])` form relies on apply operating
#   on the grouping columns, which pandas has deprecated.
scored_df = filtered_df.dropna(subset=["Loss"])
best_row_labels = scored_df.groupby(["Optimizer", "Adam for scalar"])["Loss"].idxmin()
lowest_loss_per_group = scored_df.loc[best_row_labels]

# Display the results
print(lowest_loss_per_group)

In [None]:
# Draw, save and show the comparison figure for the best run of each group.
plot(runs=lowest_loss_per_group)