In [1]:
from notebooks.utils import fetch_runs, get_runs_data, differing_config

  "kl_div_prefix_1_teacher_val": "$A_0\,x_{t-4}$",
  "kl_div_prefix_2_teacher_val": "$A_0\,x_{t-4} + A_1\,x_{t-3}$",
  "kl_div_prefix_3_teacher_val": "$A_0\,x_{t-4} + A_1\,x_{t-3} + A_2\,x_{t-2}$",


In [2]:
# Metric names requested from get_runs_data for every selected run.
RUN_METRICS = [
    "val_loss",
    "teacher_val_loss",
    "val_best",
    "kl_div_unigram_learned_val",
    "kl_div_bigram_learned_val",
    "kl_div_trigram_learned_val",
]

# Runs are looked up by entity, project and tag.
runs = fetch_runs(
    entity="r-alvarezlucendo16",
    project="incremental-learning",
    tags_any=["ICLR-minimal-dataset"],
)
df = get_runs_data(runs, metrics=RUN_METRICS)

# Print the run names as a quick check of what was selected.
for run in runs:
    print(run.name)

youthful-fog-1711
glowing-armadillo-1712
faithful-oath-1717
bumbling-brook-1718
devout-thunder-1721
upbeat-glade-1723
solar-sun-1725
celestial-sky-1728
noble-dragon-1729
fresh-forest-1730


In [3]:
import matplotlib.pyplot as plt

# Plot the excess best validation loss (`val_best` minus the teacher's
# validation loss) against the step index for a few training-set sizes.
# NOTE(review): the saved output of this cell is
# `KeyError: 'cfg.dataset.number.train'` -- confirm that `df` really carries
# this config column before relying on the figure.
groups = df.groupby(["_run_name", "cfg.dataset.number.train"])[["val_loss", "teacher_val_loss", "val_best"]]
# Non-negative sizes first in ascending order, negative sizes last.
groups = sorted(groups, key=lambda x: (x[0][1] < 0, x[0][1]))
max_step = 1000
shown_sizes = (600, 2000, 9000, -3000)
for (run_id, num_train), data in groups:
    if num_train not in shown_sizes:
        continue

    teacher_val_loss = data["teacher_val_loss"].unique()[0]
    # Convert to an array explicitly: subtracting a scalar from a Python list
    # only works when NumPy's reflected operator happens to kick in.
    val_loss = data["val_best"].to_numpy(dtype=float)[:max_step] - teacher_val_loss

    if num_train < 0:
        # Negative sizes are drawn as the dashed black "online" reference.
        label, linestyle, color = "online", "--", "black"
    else:
        label, linestyle, color = str(num_train), "-", None
    plt.plot(val_loss, label=label, linewidth=2, linestyle=linestyle, color=color)

# Get handles/labels from the current axes, keeping one entry per label so
# that several runs with the same size do not produce duplicate legend items.
handles, labels = [], []
for handle, label in zip(*plt.gca().get_legend_handles_labels()):
    if label not in labels:
        handles.append(handle)
        labels.append(label)

# Figure-level legend above the plot
plt.gcf().legend(
    handles, labels,
    loc="upper center",
    ncol=max(len(labels), 1),  # ncol must stay positive even with no curves
    fontsize=14,
    framealpha=1,
    bbox_to_anchor=(0.52, 1.04)
)

# Make room for the top legend
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.xlabel("Step", fontsize=20)
plt.ylabel("Excess Best Loss", fontsize=20)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.show()
# NOTE(review): if the export below is re-enabled, it likely needs to run
# before plt.show(), otherwise an empty figure may be written -- confirm.
# plt.savefig("val-loss-dataset-exps.pdf", bbox_inches="tight")

KeyError: 'cfg.dataset.number.train'

In [None]:
# One KL-divergence figure per selected training-set size.
# Maps each plotted column to its legend entry; the names match the legend
# labels used for the same columns in the multi-panel figures further down.
kl_curve_labels = {
    "kl_div_unigram_learned_val": "4-gram",
    "kl_div_bigram_learned_val": "8-gram",
    "kl_div_trigram_learned_val": "12-gram",
}
groups = df.groupby(["_run_name", "cfg.dataset.number.train"])[list(kl_curve_labels)]
groups = sorted(groups, key=lambda x: x[0][1])
max_step = 1000

for (run_id, num_train), data in groups:
    # Filter first so no data is extracted for sizes that are never drawn.
    if num_train not in (400, 2000, 9000):
        continue

    fig, ax = plt.subplots()
    for column, curve_label in kl_curve_labels.items():
        # Curves need labels, otherwise the legend below would be empty.
        ax.plot(data[column].tolist()[:max_step], label=curve_label, linewidth=2)

    # Title tells the otherwise identical-looking figures apart.
    ax.set_title(f"{num_train} samples", fontsize=16)
    ax.set_xlabel("Step", fontsize=16)
    ax.set_ylabel("KL Divergence", fontsize=16)
    ax.legend(loc="upper right", fontsize=16, framealpha=1)
    plt.show()

In [None]:
import matplotlib.pyplot as plt

# Side-by-side KL-divergence panels, one per selected training-set size.
# Relies on `groups` as left behind by an earlier cell.
selected_nums = [400, 2000, 9000]
max_step = 1000

# Remember the first run encountered for each selected size.
picked = {}
for (run_id, num_train), data in groups:
    if num_train in selected_nums:
        picked.setdefault(num_train, (run_id, data))
    if len(picked) == len(selected_nums):
        break

ncols = len(selected_nums)
# squeeze=False always yields a 2-D grid, so a single column needs no special case.
fig, axes = plt.subplots(1, ncols, figsize=(4*ncols, 4), sharey=True, squeeze=False)
axes = axes[0]

curve_columns = (
    "kl_div_unigram_learned_val",
    "kl_div_bigram_learned_val",
    "kl_div_trigram_learned_val",
)

handles, labels = None, None
for ax, num_train in zip(axes, selected_nums):
    if num_train not in picked:
        # No run with this size: hide the panel instead of leaving it empty.
        ax.set_visible(False)
        continue

    run_id, data = picked[num_train]
    panel_lines = [
        ax.plot(data[column].tolist()[:max_step], linewidth=2)[0]
        for column in curve_columns
    ]

    # The shared legend takes its handles from the first drawn panel.
    if handles is None:
        handles, labels = panel_lines, ["4-gram", "8-gram", "12-gram"]

    ax.set_title(f"$\\mathbf{{{num_train}}}\\text{{ samples}}$", fontsize=20)
    ax.set_xlabel("Step", fontsize=20)
    ax.xaxis.set_tick_params(labelsize=16)

# The y axis is shared, so only the leftmost panel is labelled.
axes[0].set_ylabel("KL Divergence", fontsize=20)
axes[0].yaxis.set_tick_params(labelsize=16)

fig.legend(
    handles,
    labels,
    loc="upper center",
    ncol=3,
    fontsize=18,
    framealpha=1,
    bbox_to_anchor=(0.5, 1.11),
)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()
# plt.savefig("figures/dataset-size/results/kl-divergence-dataset-exps.pdf", bbox_inches="tight")

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib import cycler
from matplotlib.ticker import MultipleLocator
import numpy as np

# Combined figure: one excess-loss panel followed by three KL-divergence
# panels (one per selected training-set size), saved as a PDF.

# ---- Prepare grouped data once ----
groups = df.groupby(["_run_name", "cfg.dataset.number.train"])[
    ["val_loss", "teacher_val_loss", "val_best",
     "kl_div_unigram_learned_val", "kl_div_bigram_learned_val", "kl_div_trigram_learned_val"]
]
# Non-negative sizes first in ascending order, negative sizes last.
groups = sorted(groups, key=lambda x: (x[0][1] < 0, x[0][1]))

# ---- What to show where ----
kl_selected_nums = [600, 2000, 9000]      # columns 2-4 (KL)
val_selected = [600, 2000, 9000, -3000]   # column 1 (val-loss incl. online=-3000)
max_step = 1000
output_path = Path("figures/dataset-size/results/dataset-experiments.pdf")

# ---- Pick one run per selected num_train for the KL panels ----
picked = {}
for (run_id, num_train), data in groups:
    if num_train in kl_selected_nums and num_train not in picked:
        picked[num_train] = (run_id, data)
    if len(picked) == len(kl_selected_nums):
        break

# ---- Figure & axes ----
# Use a bit wider figure so we can squeeze KL axes together without cramping labels
fig, axes = plt.subplots(1, 4, figsize=(24, 6), sharey=False)
ax_val = axes[0]
ax_kl_list = axes[1:]

val_colors = ["tab:red", "tab:purple", "tab:brown"]  # for 600, 2000, 9000
ax_val.set_prop_cycle(cycler(color=val_colors))
# Pin every non-negative size to its colour so that a missing or repeated
# size cannot shift the colours of the remaining curves.
val_color_by_num = dict(zip([n for n in val_selected if n >= 0], val_colors))

# Colors for KL curves (fixed, distinct from val colors)
kl_colors = ["tab:blue", "tab:orange", "tab:green"]   # uni, bi, tri (4/8/12-gram)

# ---- Val-loss panel (first axis) ----
val_handles = []
val_labels = []
seen_label = set()

for (run_id, num_train), data in groups:
    if num_train not in val_selected:
        continue

    teacher_val_loss = data["teacher_val_loss"].unique()[0]
    # Explicit array conversion: list-minus-scalar only works by accident of
    # NumPy's reflected operator.
    y = data["val_best"].to_numpy(dtype=float)[:max_step] - teacher_val_loss

    if num_train < 0:
        label = "online"
        linestyle = "--"
        color = "black"   # stays black, outside the cycler
    else:
        label = str(num_train)
        linestyle = "-"
        # None (size without a pinned colour) falls back to ax_val's cycle.
        color = val_color_by_num.get(num_train)

    ln, = ax_val.plot(y, label=label, linewidth=2, linestyle=linestyle, color=color)
    # One legend entry per label, even if several runs share a size.
    if label not in seen_label:
        val_handles.append(ln)
        val_labels.append(label)
        seen_label.add(label)

ax_val.set_xlabel("Step", fontsize=24)
ax_val.set_ylabel("Best Excess Loss", fontsize=24)
ax_val.tick_params(labelsize=20)

# ---- KL panels (next three axes) ----
kl_handles, kl_labels = None, None
for ax, num_train in zip(ax_kl_list, kl_selected_nums):
    run = picked.get(num_train)
    if run is None:
        # No run with this size: hide the panel instead of leaving it empty.
        ax.set_visible(False)
        continue

    _, data = run
    kl_uni = data["kl_div_unigram_learned_val"].tolist()[:max_step]
    kl_bi  = data["kl_div_bigram_learned_val"].tolist()[:max_step]
    kl_tri = data["kl_div_trigram_learned_val"].tolist()[:max_step]

    ln1, = ax.plot(kl_uni, linewidth=2, color=kl_colors[0])
    ln2, = ax.plot(kl_bi,  linewidth=2, color=kl_colors[1])
    ln3, = ax.plot(kl_tri, linewidth=2, color=kl_colors[2])

    # mark the step of best validation loss
    raw_val = data["val_best"].to_numpy(dtype=float)[:max_step]
    # np.nanargmin raises on an empty or all-NaN slice, so guard both cases.
    if raw_val.size > 0 and not np.isnan(raw_val).all():
        best_step = int(np.nanargmin(raw_val))
        ax.axvline(best_step, linestyle=":", linewidth=2, color="grey", alpha=0.9)

    if kl_handles is None:
        kl_handles, kl_labels = [ln1, ln2, ln3], ["4-gram", "8-gram", "12-gram"]

    ax.set_title(f"$\\mathbf{{{num_train}}}\\,\\text{{samples}}$", fontsize=28)
    ax.set_xlabel("Step", fontsize=24)
    ax.yaxis.set_major_locator(MultipleLocator(1.0))
    ax.tick_params(labelsize=20)

ax_kl_list[0].set_ylabel("KL Divergence", fontsize=24)

# ---- Layout first so axes positions are final-ish ----
plt.tight_layout(rect=[0, 0, 1, 0.9])
gap = 0.02  # small gap between KL panels
# Use current width of first KL axis for all
first_pos = ax_kl_list[0].get_position()
kl_w = first_pos.width
# Keep the left x0 of the first KL axis; pack the next two right next to it
x0 = first_pos.x0
y0 = first_pos.y0
h  = first_pos.height

for i, ax in enumerate(ax_kl_list):
    ax.set_position([x0 + i * (kl_w + gap), y0, kl_w, h])

# ---- Recompute positions for legend anchors after manual positioning ----
val_pos = ax_val.get_position()
val_center_x = val_pos.x0 + val_pos.width / 2
val_top_y = val_pos.y1

kl_left = min(ax.get_position().x0 for ax in ax_kl_list)
kl_right = max(ax.get_position().x1 for ax in ax_kl_list)
kl_top_y = max(ax.get_position().y1 for ax in ax_kl_list)
kl_center_x = (kl_left + kl_right) / 2

# ---- Legends above their sections ----
# Each legend is skipped when its section has no curves to describe.
if val_handles:
    fig.legend(
        val_handles, val_labels,
        loc="lower center",
        ncol=4,
        fontsize=22,
        framealpha=1,
        bbox_to_anchor=(val_center_x, val_top_y + 0.07),
        bbox_transform=fig.transFigure,
        columnspacing=0.65,
    )

if kl_handles is not None:
    fig.legend(
        kl_handles, kl_labels,
        loc="lower center",
        ncol=len(kl_labels),
        fontsize=22,
        framealpha=1,
        bbox_to_anchor=(kl_center_x, kl_top_y + 0.09),
        bbox_transform=fig.transFigure,
    )

# plt.show()
# Create the target directory first so a fresh checkout can write the PDF.
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
