In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
from collections import defaultdict

# W&B public-API client; the cells below list the project's runs through it (api.runs).
api = wandb.Api()

In [2]:
def find_run(all_runs, prefix, state="finished"):
    """Return the most recently created run whose name starts with ``prefix``.

    Args:
        all_runs: Iterable of W&B runs (e.g. the result of ``api.runs(...)``).
        prefix: Run-name prefix to match (``str.startswith`` semantics).
        state: Only runs whose ``state`` equals this value are considered.
            Defaults to ``"finished"``, so runs in any other state are skipped.

    Returns:
        The matching run with the largest ``createdAt`` value, or ``None`` when
        no run matches.
    """
    # Use the public ``state`` property instead of the private ``_state`` field.
    matching = (
        run for run in all_runs
        if run.name.startswith(prefix) and run.state == state
    )
    # A single O(n) pass; no need to sort everything just to take the last item.
    return max(matching, key=lambda run: run.createdAt, default=None)

In [None]:
# Weights & Biases location of the runs to plot -- replace these placeholders.
WANDB_USER = "your-username"
WANDB_PROJECT = "your-wandb-project"

# Every run in the project; find_run() later narrows this down by name prefix.
all_runs = api.runs("/".join((WANDB_USER, WANDB_PROJECT)))

# Run-name prefixes for each method, filled in with .format(precision=..., recall=...).
# The trailing "_" stops a prefix such as "..._rec0.9_" from also matching "..._rec0.95_".
DATAFILTER_PATTERN = "ts_datafilter_pr{precision}_rec{recall}_"
SGTM_PATTERN = "ts_parameter_masking_mlp64_h1_ret25_pr{precision}_rec{recall}_"
GR_PATTERN = "ts_gradient_routing_mlp64_h1_ret25_pr{precision}_rec{recall}_"

In [None]:
sns.set_palette("muted")
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Fraction of forget-set examples left unlabeled. Each point is the run trained
# with precision 1.0 and recall = 1 - unlabeled_rate.
X_VALS = np.array([0, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0])

linewidth = 6
markersize = 12

PATTERNS = {
    "datafilter": DATAFILTER_PATTERN,
    "sgtm": SGTM_PATTERN,
    "gr": GR_PATTERN,
}

# Forget-set results go in the left panel, retain-set results in the right one.
split_axes = {"forget": axes[0], "retain": axes[1]}
split_ylabels = {
    "forget": "Forget Loss (higher better)\n(Spanish)",
    "retain": "Retain Loss (lower better)\n(English)",
}
# (PATTERNS key, legend label, colour), in drawing order.
series = [
    ("datafilter", "Data filtering", "C0"),
    ("gr", "Gradient Routing (Cloud et al.)", "C5"),
    ("sgtm", "SGTM (Ours)", "C1"),
]

# Look every run up once and read both eval losses from its summary, rather
# than repeating each find_run lookup per split. A missing run becomes NaN,
# which matplotlib draws as a gap in the line.
losses = {split: {rtype: [] for rtype in PATTERNS} for split in split_axes}
for unlabeled_rate in X_VALS:
    recall = round(1 - unlabeled_rate, 3)
    for rtype, pattern in PATTERNS.items():
        run_info = find_run(all_runs=all_runs, prefix=pattern.format(precision=1.0, recall=recall))
        for split, per_method in losses.items():
            if run_info is None:
                per_method[rtype].append(np.nan)
            else:
                per_method[rtype].append(run_info.summary['eval_calibrated'][split]['main']['loss'])

for split, ax in split_axes.items():
    for rtype, label, color in series:
        ax.plot(X_VALS * 100, losses[split][rtype], label=label, marker='o',
                markersize=markersize, linewidth=linewidth, color=color)

    ax.grid(True, alpha=0.3)
    ax.set_xlabel('% Undiscovered (forget set)', fontsize=16)
    # symlog keeps x=0 on the axis (linear below 1) while spacing 1..100 logarithmically.
    ax.set_xscale('symlog', linthresh=1)
    ax.set_yscale('log')
    ax.set_xticks([0, 1, 10, 100])
    ax.set_xticklabels(['0', '1', '10', '100'])
    ax.set_ylabel(split_ylabels[split], fontsize=16)

split_axes["forget"].set_yticks([1, 2, 4, 8], ['1', '2', '4', '8'])

# The retain panel is zoomed to 1.30-1.46: label 5 evenly spaced ticks and
# hide the log-scale minor ticks.
retain_ax = split_axes["retain"]
yticks = np.linspace(1.30, 1.46, 5)
retain_ax.set_yticks(yticks, [f"{x:.2f}" for x in yticks])
retain_ax.set_yticks([], minor=True)
retain_ax.legend(loc="upper left", fontsize=16)
retain_ax.set_ylim(1.30, 1.46)

plt.tight_layout()
plt.show()
