In [None]:
from analysis.utils import fetch_runs, get_runs_data
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Tag(s) identifying the ICLR minimal-dataset sweep; `tags_any` presumably
# matches runs carrying any of these tags (TODO confirm in analysis.utils).
SWEEP_TAGS = ["ICLR-minimal-dataset"]
runs = fetch_runs(tags_any=SWEEP_TAGS)

# Echo each fetched run's name so the reader can see which runs were picked up.
for run in runs:
    print(run.name)

In [None]:
# Validation-loss metrics for the n-gram baselines and for the student model.
ngram_metrics = ["unigram_val_loss", "bigram_val_loss", "trigram_val_loss"]
student_metrics = ["val_loss", "teacher_val_loss"]

# One get_runs_data call per metric group (n-gram first, then student).
ngram_data, student_data = (
    get_runs_data(runs, metrics=metric_group)
    for metric_group in (ngram_metrics, student_metrics)
)

In [None]:
# Build a long-format table of the best (minimum) validation loss per run and
# dataset size: one row per n-gram baseline, plus one row for the student.

# Label attached to each n-gram baseline metric in the table and the plot.
# NOTE(review): the metric keys say uni/bi/trigram but the labels say
# 4/8/12-gram — confirm which n each baseline actually uses before publishing.
ngram_labels = {
    "unigram_val_loss": "4-gram",
    "bigram_val_loss": "8-gram",
    "trigram_val_loss": "12-gram",
}
group_keys = ["_run_name", "cfg.dataset.number.train"]

results = {"dataset_size": [], "model": [], "best_val_loss": []}

# Series.min() skips NaN (e.g. steps where a metric was not logged, if
# get_runs_data returns such rows). The builtin min() over a list containing
# NaN is order-dependent: a leading NaN makes the result NaN.
for (_, dataset_size), group in ngram_data.groupby(group_keys):
    for metric, label in ngram_labels.items():
        results["dataset_size"].append(dataset_size)
        results["model"].append(label)
        results["best_val_loss"].append(group[metric].min())

for (_, dataset_size), group in student_data.groupby(group_keys):
    results["dataset_size"].append(dataset_size)
    results["model"].append("student")
    results["best_val_loss"].append(group["val_loss"].min())

df = pd.DataFrame(results)
print(df.to_markdown())

In [None]:
# Grouped bar chart: one group of bars per dataset size, one bar per model.
model_order = ["4-gram", "8-gram", "12-gram", "student"]

# Rows = dataset sizes (ascending), columns = models in legend order.
# DataFrame.pivot raises ValueError if a (dataset_size, model) pair occurs
# more than once, e.g. when several runs share the same dataset size.
pivot = (
    df.pivot(index="dataset_size", columns="model", values="best_val_loss")
    .sort_index()[model_order]
)

x = np.arange(len(pivot.index))
n_models = len(model_order)
# Ticks are 1 unit apart, so n_models * width must stay below 1 for the
# groups not to overlap.
width = 0.18
# Symmetric offsets so each group of bars is centred on its tick.
offsets = (np.arange(n_models) - (n_models - 1) / 2) * width

fig, ax = plt.subplots(figsize=(12, 6))
for offset, model in zip(offsets, model_order):
    ax.bar(x + offset, pivot[model].to_numpy(), width=width, label=model)

ax.set_xticks(x)
ax.set_xticklabels(pivot.index, rotation=0)
ax.set(
    xlabel="Dataset size",
    ylabel="Best validation loss",
    title="Best validation loss by dataset size",
)
ax.legend(loc="upper right")
fig.tight_layout()
plt.show()