In [1]:
from analysis.utils import fetch_runs, get_runs_data, differing_config
import matplotlib.pyplot as plt
from matplotlib import cycler
from matplotlib.ticker import MultipleLocator
import numpy as np



In [2]:
# Per-step metrics used by the figures further down in this notebook.
tracked_metrics = [
    "val_loss",
    "teacher_val_loss",
    "val_best",
    "kl_div_unigram_learned_val",
    "kl_div_bigram_learned_val",
    "kl_div_trigram_learned_val",
]

# Select the runs via the `tags_any` filter, then gather the requested metrics
# for them into `df`, the frame every later cell reads from.
# NOTE(review): the output stored with this cell is a ValueError ("dictionary
# update sequence element #0 has length 1; 2 is required"); it presumably comes
# from inside fetch_runs/get_runs_data -- confirm the cell runs cleanly on a
# fresh kernel before trusting the cells below.
runs = fetch_runs(tags_any=["ICLR-minimal-dataset"])
df = get_runs_data(runs, metrics=tracked_metrics)

# Echo the name of each fetched run, one per line, as a quick sanity check.
for fetched_run in runs:
    print(fetched_run.name)

ValueError: dictionary update sequence element #0 has length 1; 2 is required

In [None]:
# Remove the `cfg.teacher.span_lengths` column before comparing run configs
# (presumably because its values are awkward to compare -- TODO confirm).
# `df` is reassigned under the same name, so on a second execution of this cell
# the column is already gone; errors="ignore" keeps the cell safe to re-run
# instead of raising KeyError.
df = df.drop(columns=["cfg.teacher.span_lengths"], errors="ignore")
# Last expression of the cell, so its return value is rendered as the output.
differing_config(df)

## Validation Loss by Dataset Size

In [None]:
# Plot the excess best validation loss (`val_best` minus the teacher's
# validation loss) over training steps for a hand-picked set of dataset sizes.
groups = df.groupby(["_run_name", "cfg.dataset.number.train"])[
    ["val_loss", "teacher_val_loss", "val_best"]
]
# Sort by dataset size, with negative sizes (drawn as the "online" curve) last.
groups = sorted(groups, key=lambda x: (x[0][1] < 0, x[0][1]))
max_step = 1000  # only the first `max_step` logged points are drawn

# Legend entries are collected by hand so that several runs sharing a dataset
# size contribute a single entry (same approach as the combined figure).
handles, labels = [], []

for (run_id, num_train), data in groups:
    if num_train not in [600, 2000, 9000, -3000]:
        continue

    # Explicit float array: subtracting a scalar from a Python list only works
    # when the scalar happens to be a NumPy type (reflected operator) and
    # raises TypeError for a plain Python float.
    val_best = data["val_best"].to_numpy(dtype=float)
    # The first unique value is used as the run's teacher loss; this assumes it
    # is constant within a run -- TODO confirm.
    teacher_val_loss = data["teacher_val_loss"].unique()[0]
    excess_loss = (val_best - teacher_val_loss)[:max_step]

    is_online = num_train < 0
    label = "online" if is_online else str(num_train)
    linestyle = "--" if is_online else "-"
    color = "black" if is_online else None  # None -> next colour in the cycle

    (line,) = plt.plot(excess_loss, label=label, linewidth=2, linestyle=linestyle, color=color)
    if label not in labels:
        handles.append(line)
        labels.append(label)

plt.gcf().legend(
    handles,
    labels,
    loc="upper center",
    ncol=max(1, len(labels)),  # ncol must stay positive even if nothing was plotted
    fontsize=14,
    framealpha=1,
    bbox_to_anchor=(0.52, 1.04),
)
# Leave the top 5% of the figure free for the figure-level legend.
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.xlabel("Step", fontsize=20)
plt.ylabel("Excess Best Loss", fontsize=20)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.savefig("figures/val-loss-dataset-exps.pdf", bbox_inches="tight")

## KL Divergence by Dataset Size

In [None]:
# One panel per selected dataset size, each showing the three learned-model KL
# divergence curves of the first matching run.
kl_columns = ["kl_div_unigram_learned_val", "kl_div_bigram_learned_val", "kl_div_trigram_learned_val"]
legend_labels = ["4-gram", "8-gram", "12-gram"]

selected_nums = [600, 2000, 9000]
max_step = 1000

# Per-(run, dataset size) groups, ordered by dataset size.
groups = sorted(
    df.groupby(["_run_name", "cfg.dataset.number.train"])[kl_columns],
    key=lambda item: item[0][1],
)

# Keep only the first group encountered for every selected size and stop as
# soon as each size has one.
picked = {}
for (run_name, n_train), frame in groups:
    if n_train in selected_nums:
        picked.setdefault(n_train, (run_name, frame))
    if len(picked) == len(selected_nums):
        break

fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
legend_lines = None

for panel, n_train in zip(axes, selected_nums):
    _, frame = picked[n_train]

    # Draw the three curves in column order, truncated to `max_step` points.
    panel_lines = [
        panel.plot(frame[column].tolist()[:max_step], linewidth=2)[0]
        for column in kl_columns
    ]
    # The shared legend reuses the line handles of the first panel.
    if legend_lines is None:
        legend_lines = panel_lines

    panel.set_title(f"$\\mathbf{{{n_train}}}\\text{{ samples}}$", fontsize=20)
    panel.set_xlabel("Step", fontsize=20)
    panel.xaxis.set_tick_params(labelsize=16)

# The y axis is shared, so only the left-most panel is labelled.
left_panel = axes[0]
left_panel.set_ylabel("KL Divergence", fontsize=20)
left_panel.yaxis.set_tick_params(labelsize=16)

fig.legend(
    legend_lines,
    legend_labels,
    loc="upper center",
    ncol=3,
    fontsize=18,
    framealpha=1,
    bbox_to_anchor=(0.5, 1.11),
)
# Leave the top 4% of the figure free for the figure-level legend.
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig("figures/kl-divergence-dataset-exps.pdf", bbox_inches="tight")

## Combined Figure

In [None]:
# Combined figure: one panel with the excess best validation loss for several
# dataset sizes, followed by three panels with the KL-divergence curves of one
# run per dataset size.
groups = df.groupby(["_run_name", "cfg.dataset.number.train"])[
    [
        "val_loss",
        "teacher_val_loss",
        "val_best",
        "kl_div_unigram_learned_val",
        "kl_div_bigram_learned_val",
        "kl_div_trigram_learned_val",
    ]
]
# Sort by dataset size, with negative sizes (drawn as the "online" curve) last.
groups = sorted(groups, key=lambda x: (x[0][1] < 0, x[0][1]))

kl_selected_nums = [600, 2000, 9000]
val_selected = [600, 2000, 9000, -3000]
max_step = 1000  # only the first `max_step` logged points are drawn

# First group encountered for each KL dataset size; stop once all are found.
picked = {}
for (run_id, num_train), data in groups:
    if num_train in kl_selected_nums and num_train not in picked:
        picked[num_train] = (run_id, data)
    if len(picked) == len(kl_selected_nums):
        break

fig, axes = plt.subplots(1, 4, figsize=(24, 6), sharey=False)
ax_val = axes[0]
ax_kl_list = axes[1:]

# The validation panel cycles through its own colours so its curves are not
# confused with the KL curves, which use fixed colours in every KL panel.
val_colors = ["tab:red", "tab:purple", "tab:brown"]
ax_val.set_prop_cycle(cycler(color=val_colors))
kl_colors = ["tab:blue", "tab:orange", "tab:green"]

# Val loss panel
val_handles, val_labels = [], []
seen_label = set()

for (run_id, num_train), data in groups:
    if num_train in val_selected:
        # Explicit float array: subtracting a scalar from a Python list only
        # works when the scalar happens to be a NumPy type (reflected operator)
        # and raises TypeError for a plain Python float.
        val_best = data["val_best"].to_numpy(dtype=float)
        # The first unique value is used as the run's teacher loss; this assumes
        # it is constant within a run -- TODO confirm.
        teacher_val_loss = data["teacher_val_loss"].unique()[0]
        y = (val_best - teacher_val_loss)[:max_step]

        label = "online" if num_train < 0 else str(num_train)
        linestyle = "--" if num_train < 0 else "-"
        color = "black" if num_train < 0 else None  # None -> next colour in the cycle

        (ln,) = ax_val.plot(y, label=label, linewidth=2, linestyle=linestyle, color=color)
        # One legend entry per label, even if several runs share a dataset size.
        if label not in seen_label:
            val_handles.append(ln)
            val_labels.append(label)
            seen_label.add(label)

ax_val.set_xlabel("Step", fontsize=24)
ax_val.set_ylabel("Best Excess Loss", fontsize=24)
ax_val.tick_params(labelsize=20)

# KL panels
kl_handles, kl_labels = None, None
for ax, num_train in zip(ax_kl_list, kl_selected_nums):
    _, data = picked[num_train]
    kl_uni = data["kl_div_unigram_learned_val"].tolist()[:max_step]
    kl_bi = data["kl_div_bigram_learned_val"].tolist()[:max_step]
    kl_tri = data["kl_div_trigram_learned_val"].tolist()[:max_step]

    ln1, = ax.plot(kl_uni, linewidth=2, color=kl_colors[0])
    ln2, = ax.plot(kl_bi, linewidth=2, color=kl_colors[1])
    ln3, = ax.plot(kl_tri, linewidth=2, color=kl_colors[2])

    # Dotted vertical marker at the step where `val_best` is lowest within the
    # plotted window. np.nanargmin raises ValueError on an all-NaN slice, so the
    # marker is skipped in that case instead of aborting the whole figure.
    raw_val = np.asarray(data["val_best"].tolist()[:max_step], dtype=float)
    if raw_val.size > 0 and not np.all(np.isnan(raw_val)):
        best_step = int(np.nanargmin(raw_val))
        ax.axvline(best_step, linestyle=":", linewidth=2, color="grey", alpha=0.9)

    # The shared KL legend reuses the line handles of the first KL panel.
    if kl_handles is None:
        kl_handles, kl_labels = [ln1, ln2, ln3], ["4-gram", "8-gram", "12-gram"]

    ax.set_title(f"$\\mathbf{{{num_train}}}\\,\\text{{samples}}$", fontsize=28)
    ax.set_xlabel("Step", fontsize=24)
    ax.yaxis.set_major_locator(MultipleLocator(1.0))  # y ticks at multiples of 1.0
    ax.tick_params(labelsize=20)

ax_kl_list[0].set_ylabel("KL Divergence", fontsize=24)

# Lay out the axes first, leaving the top 10% of the figure for the legends;
# the manual repositioning below starts from the positions computed here.
plt.tight_layout(rect=[0, 0, 1, 0.9])

# Re-place the KL panels side by side: each gets the first KL panel's size and
# consecutive panels are separated by `gap` (figure-fraction units).
gap = 0.02
first_pos = ax_kl_list[0].get_position()
kl_w, x0, y0, h = first_pos.width, first_pos.x0, first_pos.y0, first_pos.height
for i, ax in enumerate(ax_kl_list):
    ax.set_position([x0 + i * (kl_w + gap), y0, kl_w, h])

# Anchor points for the two legends, read back after the repositioning above:
# the horizontal centre of the validation panel and of the KL panel group.
val_pos = ax_val.get_position()
val_center_x = val_pos.x0 + val_pos.width / 2
kl_left = min(ax.get_position().x0 for ax in ax_kl_list)
kl_right = max(ax.get_position().x1 for ax in ax_kl_list)
kl_top_y = max(ax.get_position().y1 for ax in ax_kl_list)
kl_center_x = (kl_left + kl_right) / 2

# Legend for the validation panel; its lower-centre point sits at the anchor,
# given in figure coordinates.
fig.legend(
    val_handles,
    val_labels,
    loc="lower center",
    ncol=4,
    fontsize=22,
    framealpha=1,
    bbox_to_anchor=(val_center_x, val_pos.y1 + 0.07),
    bbox_transform=fig.transFigure,
    columnspacing=0.65,
)

# Legend shared by the three KL panels, centred over the panel group.
fig.legend(
    kl_handles,
    kl_labels,
    loc="lower center",
    ncol=3,
    fontsize=22,
    framealpha=1,
    bbox_to_anchor=(kl_center_x, kl_top_y + 0.09),
    bbox_transform=fig.transFigure,
)

plt.savefig("figures/dataset-experiments.pdf", bbox_inches="tight")