In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import wandb

# Shared W&B API client; `load_run` below issues its run queries through it.
api = wandb.Api()

In [None]:
# W&B project that the plotting cell passes to `load_run`.
# Placeholder value: replace with your own project name.
WANDB_PROJECT = "your-wandb-project"
# W&B user/entity prefix; `load_run` builds the lookup path as
# "<WANDB_USER>/<project>". Placeholder value: replace with your own.
WANDB_USER = "your-username"

def load_run(project, name):
    """Fetch the single W&B run whose display name is ``name``.

    The query goes through the module-level ``api`` client, against the path
    ``"<WANDB_USER>/<project>"``, filtering on ``displayName``.

    Args:
        project: W&B project name, appended to ``WANDB_USER`` to form the path.
        name: Display name the run must have.

    Returns:
        The one run that matched.

    Raises:
        ValueError: If the query matched no run, or more than one run.
    """
    matches = api.runs(
        path=f"{WANDB_USER}/{project}",
        filters={"displayName": name},
    )
    n_matches = len(matches)

    # Exactly one hit is the only acceptable outcome; anything else is an error.
    if n_matches == 1:
        return matches[0]
    if n_matches == 0:
        raise ValueError(f"run {project}/{name} not found")
    raise ValueError(f"Too many runs {project}/{name}: {n_matches}")

In [None]:
# Forget-loss curves during finetuning, one line per unlearning method, with
# the unfiltered model's final forget loss drawn as a horizontal baseline.
sns.set_palette("muted")

# Evaluation checkpoints to plot: steps 0, 25, ..., 475.
STEPS = np.arange(500, step=25)
# History column that holds the forget-set loss.
COLUMN = "eval.forget.STEM.Biology.loss"
# `Run.history()` returns a *sampled* history (500 rows by default), which can
# silently drop evaluation rows on runs that logged more rows than that.
# Request a generous number of rows so the checkpoints in STEPS are kept.
HISTORY_SAMPLES = 10_000

# label -> (W&B run display name, matplotlib colour).
# The run names are blank placeholders and must be filled in before running.
runs = {
    "RMU": ("", "C6"),
    "Weak filter":  ("", "C9"),
    "Strict filter":  ("", "C0"),
    "SGTM (Ours)": ("", "C1"),
}

# Baseline: the final forget loss stored in the unfiltered run's summary
# (its run name is a blank placeholder as well).
no_filter_run = load_run(project=WANDB_PROJECT, name="")
no_filter_loss = no_filter_run.summary['final/eval']["forget"]["STEM.Biology"]["loss"]


fig, ax = plt.subplots(figsize=(6, 6))
for label, (name, color) in runs.items():
    run = load_run(project=WANDB_PROJECT, name=name)
    df = run.history(samples=HISTORY_SAMPLES)

    # Fail with a message that names the run instead of a bare KeyError.
    missing = [col for col in ("_step", COLUMN) if col not in df.columns]
    if missing:
        raise KeyError(
            f"history of run {name!r} ({label}) is missing column(s): {missing}"
        )

    # Keep only rows at the chosen checkpoints that actually carry a loss
    # value. The mask already excludes NaNs, so no further dropna is needed
    # and steps/values are guaranteed to have the same length.
    filtered_df = df[df["_step"].isin(STEPS) & df[COLUMN].notna()]
    steps = filtered_df["_step"].to_numpy()
    values = filtered_df[COLUMN].to_numpy()

    ax.plot(steps, values, linewidth=4, label=label, color=color)

# zorder=0 keeps the baseline behind the method curves.
ax.axhline(y=no_filter_loss, color='slategrey', linestyle='--', linewidth=4, label='No filter', zorder=0)

ax.set_xlabel('Finetuning Step', fontsize=16)
ax.set_ylabel("Forget Loss (higher better)", fontsize=16)
ax.grid(True, alpha=0.3)
ax.set_ylim(2, 4.0)
ax.legend(fontsize=16)
fig.tight_layout()
plt.show()
