In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb

# W&B API client; load_run uses it (api.runs) to look up runs by display name.
api = wandb.Api()

In [None]:
# W&B account and project holding the runs; load_run queries the path
# "<WANDB_USERNAME>/<WANDB_PROJECT>". Replace the placeholders before running.
WANDB_USERNAME, WANDB_PROJECT = "your-username", "your-wandb-project"

In [None]:
def load_run(name):
    """Fetch the single W&B run whose display name is ``name``.

    Queries ``WANDB_USERNAME/WANDB_PROJECT`` through the module-level ``api``
    client, filtering on the run's display name.

    Args:
        name: Display name of the run to load.

    Returns:
        The one run matching ``name``.

    Raises:
        ValueError: if ``name`` is empty/None, if no run matches, or if more
            than one run in the project shares that display name.
    """
    # Unfilled RUNS entries have name "" -- fail fast with an actionable
    # message instead of querying W&B with an empty display-name filter.
    if not name:
        raise ValueError("run name is empty; fill in the wandb run names in RUNS")

    runs = api.runs(
        path=f"{WANDB_USERNAME}/{WANDB_PROJECT}",
        filters={"displayName": name},
    )
    n_matches = len(runs)
    if n_matches == 0:
        raise ValueError(f"run {WANDB_PROJECT}/{name} not found")
    # Display names are not required to be unique; refuse to guess.
    if n_matches > 1:
        raise ValueError(f"Too many runs {WANDB_PROJECT}/{name}: {n_matches}")

    return runs[0]

In [None]:
# Runs to compare. Fill in the W&B display name of each run before plotting.
# Each entry maps a legend label to (wandb_run_name, matplotlib color);
# insertion order is the order the curves are drawn and listed in the legend.
RUNS = {
    # legend label                      (wandb_run_name, color)
    "SGTM (Ours)":                      ("", "C1"),
    "Gradient Routing (Cloud et al.)":  ("", "C5"),
    "Perfect filter":                   ("", "C0"),
    "99% filter":                       ("", "C9"),
    "No filter":                        ("", "darkgray"),
}


In [None]:
def draw2d(run, label, steps=None, color=None, samples=1000, random_loss=8):
    """Plot one run's (retain loss, forget loss) trajectory on the current axes.

    The trajectory starts at the point ``(random_loss, random_loss)``, which
    stands in for the randomly initialised model. One point follows per kept
    history row, and a star marker is drawn on the last point.

    Args:
        run: W&B run object; only ``run.history(samples=...)`` is used.
        label: Legend label for the line.
        steps: Optional collection of ``_step`` values to keep. When falsy
            (None or empty), every evaluated row is plotted.
        color: Matplotlib color for the line. The star reuses the line's
            color, so when None it follows the active color cycle.
        samples: Number of history rows requested from ``run.history``.
            NOTE(review): W&B may return a *sample* of the history rather than
            every logged row. A step listed in ``steps`` can then be silently
            missing -- confirm against the run or raise this value.
        random_loss: Loss assumed for the untrained model (default 8).

    Returns:
        ``(x_data, y_data, flops)``. ``x_data`` holds the retain losses and
        ``y_data`` the forget losses; both are lists starting with
        ``random_loss``. ``flops`` is the ``train/flops`` column of the kept
        rows and has one fewer element than the lists, because it has no entry
        for the random-model point.
    """
    retain_column = 'eval_calibrated.retain.main.loss'
    forget_column = 'eval_calibrated.forget.main.loss'

    df = run.history(samples=samples)

    # Keep only rows where eval ran AND both plotted metrics are present.
    # A NaN in a plotted column would otherwise leave a gap in the line.
    # A NaN final point would also hide the end-of-line star.
    df = df.dropna(subset=['eval.forget.main.loss', retain_column, forget_column])

    if steps:
        df = df[df['_step'].isin(steps)]

    x_data = [random_loss] + df[retain_column].tolist()
    y_data = [random_loss] + df[forget_column].tolist()

    (line,) = plt.plot(x_data, y_data, label=label, markersize=12, marker='o',
                       linewidth=6, alpha=1.0, color=color)
    # Star at the final checkpoint, drawn above the line in the same colour.
    plt.scatter(x_data[-1], y_data[-1], marker='*', s=600,
                edgecolors='black', linewidth=1,
                color=line.get_color(), zorder=10)

    # Training FLOPs of the kept rows (presumably for drawing iso-FLOP lines).
    return (x_data, y_data, df["train/flops"])

# Forget loss vs. retain loss across training: one trajectory per entry in RUNS.
sns.set_palette("muted")
plt.figure(figsize=(6.2, 6.2))

# `_step` values of the evaluation checkpoints plotted for every run.
steps = [1000, 2000, 6000, 14000, 20000, 26000, 33000]

for i, (label, (name, color)) in enumerate(RUNS.items()):
    run = load_run(name)
    x_data, y_data, flops = draw2d(
        run,
        label=label,
        steps=steps,
        color=color,
    )

# Lower retain loss is better; higher forget loss is better.
plt.xlabel('Retain Loss (lower better)\n(English)', fontsize=16)
plt.ylabel('Forget Loss (higher better)\n(Spanish)', fontsize=16)

plt.legend(loc="lower right", framealpha=1.0, fontsize=16)
plt.grid(True, alpha=0.3)
# Fixed view window; points outside it (e.g. x > 2.0) are clipped from view.
plt.xlim((1.25, 2.0))
plt.ylim((0, 8))
plt.tight_layout()
plt.show()