In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb

# Shared W&B public-API client; load_run() issues its run queries through it.
api = wandb.Api()

In [None]:
# Entity and project that load_run() searches; they are joined into the
# API path "<WANDB_USERNAME>/<WANDB_PROJECT>".
# NOTE(review): the literals below look like placeholders to replace before running — confirm.
WANDB_USERNAME = "your-username"
WANDB_PROJECT = "your-wandb-project"

def load_run(name):
    """Return the single W&B run whose display name equals ``name``.

    The lookup goes through the module-level ``api`` client and is scoped to
    the project ``WANDB_USERNAME/WANDB_PROJECT``.

    Args:
        name: Display name used as the ``displayName`` filter.

    Returns:
        The one run that matched the filter.

    Raises:
        ValueError: If the filter matches no run, or more than one run.
    """
    project_path = f"{WANDB_USERNAME}/{WANDB_PROJECT}"
    matches = api.runs(path=project_path, filters={"displayName": name})
    match_count = len(matches)

    # Exactly one match is required; anything else is reported to the caller.
    if match_count == 0:
        raise ValueError(f"run {WANDB_PROJECT}/{name} not found")
    if match_count > 1:
        raise ValueError(f"Too many runs {WANDB_PROJECT}/{name}: {match_count}")

    return matches[0]

In [None]:
# Legend label -> (W&B run display name, line colour).
# The plotting cell passes the first tuple element to load_run() and the second
# to draw2d() as `color`.
# NOTE(review): every run name is an empty string here — presumably placeholders
# that must be filled in with real display names before running; confirm.
RUNS = {
    "No filter": ("", "darkgray"),
    "Weak filter": ("", "C9"),
    "Strict filter": ("", "C0"),
    "SGTM (Ours)": ("", "C1"),
}


# Category whose "eval_calibrated.forget.<FORGET_CATEGORY>.loss" column draw2d()
# plots on the y-axis.
FORGET_CATEGORY = "STEM.Biology"

# retain_type -> categories whose "eval_calibrated.<retain_type>.<category>.loss"
# columns draw2d() averages for the x-axis.
CATEGORIES = {
    "retain": ["History_and_Society", "Culture", "Geography"],
    "adjacent": ["STEM.Earth_and_environment", "STEM.Chemistry", "STEM.Medicine_&_Health"],
}

In [None]:
def draw2d(run, label, retain_type, steps, color=None):
    """Plot one run's retain-loss vs. forget-loss trajectory on the current axes.

    Each selected history row becomes one point: x is the average of the
    ``eval_calibrated.<retain_type>.<category>.loss`` columns for the
    categories in ``CATEGORIES[retain_type]``, y is the
    ``eval_calibrated.forget.<FORGET_CATEGORY>.loss`` column. The points are
    joined by a line and the last one is highlighted with a star.

    Args:
        run: Object whose ``history()`` returns a pandas DataFrame containing
            ``_step``, ``eval.retain.Geography.loss`` and the calibrated loss
            columns named above.
        label: Legend label for the line.
        retain_type: Key into ``CATEGORIES`` selecting which categories are
            averaged for the x-axis.
        steps: Collection of ``_step`` values to keep.
        color: Optional matplotlib colour for the line; the star marker reuses
            whatever colour the line ends up with.

    Raises:
        KeyError: If ``retain_type`` is not in ``CATEGORIES`` or a required
            column is missing from the history.
        ValueError: If no history row survives the filtering, so there is
            nothing to plot.
    """
    retain_columns = [f"eval_calibrated.{retain_type}.{cat}.loss" for cat in CATEGORIES[retain_type]]
    forget_column = f"eval_calibrated.forget.{FORGET_CATEGORY}.loss"

    # NOTE(review): W&B documents run.history() as returning a sampled history
    # by default; confirm that every requested step survives that sampling.
    history = run.history()

    # Rows without this column are dropped — presumably steps where no
    # evaluation was logged (TODO confirm) — then only the requested steps are kept.
    evaluated = history.dropna(subset=["eval.retain.Geography.loss"])
    selected = evaluated[evaluated["_step"].isin(steps)]

    if selected.empty:
        raise ValueError(
            f"No history rows to plot for {label!r} (retain_type={retain_type!r}) at the requested steps"
        )

    # skipna=False keeps a missing category loss visible as NaN instead of
    # silently averaging over the remaining categories.
    x_data = selected[retain_columns].mean(axis=1, skipna=False).tolist()
    y_data = selected[forget_column].tolist()

    (line,) = plt.plot(x_data, y_data, "o-", markersize=12, alpha=1.0, label=label, linewidth=6, color=color)
    # Star marks the last plotted point, in the same colour as the line just drawn.
    plt.scatter(x_data[-1], y_data[-1], marker="*", s=600, color=line.get_color(), edgecolors='black', linewidths=1, zorder=10)

In [None]:
sns.set_palette("muted")

# Steps to plot: 0, 1000, ..., 8000 plus step 9689 (presumably the final
# training step — confirm). The last plotted step gets the star marker.
# NOTE(review): np.arange excludes its stop value, so step 9000 is not
# plotted — confirm that this is intended.
STEPS = np.concatenate([np.arange(0, 9000, step=1000), [9689]])

# Resolve every run once up front so both panels reuse the same run objects
# instead of querying the W&B API again for each panel.
loaded_runs = [(label, load_run(name=name), color) for label, (name, color) in RUNS.items()]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# One entry per panel: (axes, retain_type for draw2d, x-axis label, x-axis limits).
panels = [
    (ax1, "retain", "Culture / Geography / History Loss", (2.54, 3.1)),
    (ax2, "adjacent", "Medicine / Chemistry / Environment Loss", (2.86, 3.5)),
]

for ax, retain_type, x_label, x_limits in panels:
    # draw2d() draws through the pyplot state machine, so select the axes first.
    plt.sca(ax)
    for label, run, color in loaded_runs:
        draw2d(run, label=label, retain_type=retain_type, steps=STEPS, color=color)

    plt.xlabel(x_label, fontsize=16)
    plt.ylabel("Biology Loss (→ better)", fontsize=16)
    plt.legend(loc="lower right", fontsize=16)
    plt.grid(alpha=0.3)
    plt.xlim(*x_limits)
    plt.ylim(2.2, 4.0)

plt.tight_layout()
plt.show()