In [None]:
import os
import ast
import csv
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy

import sys
sys.path.append(os.path.join(os.getcwd(), 'src'))
import config   

# Note: 0 <= alpha <= 1 weights the hybrid loss, from pure KL loss (alpha = 0) to pure next-token loss (alpha = 1).

In [None]:
# Load the logged CSVs of every run directory into one DataFrame.
# Each row gets a "source" label "<run dir>/<csv base name>"; downstream cells group,
# color and label plots by it.
load_all = True           # True: load every *.csv of each run; False: only those listed in `filenames`
exclude = []              # CSV names to skip (with or without the ".csv" extension)

# Override with the CSV_LOG_DIR environment variable when running elsewhere.
csv_log_dir = os.environ.get("CSV_LOG_DIR", "/scratch/klambert/model_log/single_logs")
subdirs = ["Seed 16-20250918-133806-7fab643-olmo1b-a0.5-t1-lr7p5en06", "Seed 42-20250918-133828-7fab643-olmo1b-a0.5-t1-lr7p5en06", "Seed 32-20250918-133811-7fab643-olmo1b-a0.5-t1-lr7p5en06"]
filenames = ["results.csv"]

# Compare extension-less base names so "results.csv" in `filenames` matches the file
# results.csv (previously it never matched, so load_all=False loaded nothing).
exclude_names = {os.path.splitext(f)[0] for f in exclude}
wanted_names = {os.path.splitext(f)[0] for f in filenames}

all_dfs = []
dfs = {}  # source label -> DataFrame, accumulated across ALL run directories

for subdir in subdirs:
    csv_full_dir = os.path.join(csv_log_dir, subdir)
    # sorted() makes the load order (and hence color assignment) deterministic
    csv_files = sorted(glob.glob(os.path.join(csv_full_dir, "*.csv")))
    if not csv_files:
        print(f"Warning: no CSV files found in {csv_full_dir}")

    for filepath in csv_files:
        name = os.path.splitext(os.path.basename(filepath))[0]
        if name in exclude_names:
            continue
        if not load_all and name not in wanted_names:
            continue

        df = pd.read_csv(filepath, low_memory=False)
        # Run directories share file names (e.g. every run has results.csv), so the label
        # must include the run directory or all runs collapse into a single group.
        source = f"{subdir}/{name}"
        df["source"] = source
        dfs[source] = df
        all_dfs.append(df)

if not all_dfs:
    raise FileNotFoundError(f"No CSV files were loaded from {csv_log_dir} for subdirs {subdirs}")

combined_df = pd.concat(all_dfs, ignore_index=True)

print(combined_df["source"].value_counts())


In [None]:
# Assign each run ("source") a distinct color from the inferno colormap.
# Sampling stops at 0.8 of the map, presumably to avoid its pale-yellow top end.
import matplotlib.colors as mcolors

sources = combined_df["source"].unique()

cmap = plt.get_cmap('inferno')

# Spread colors evenly over [0, 0.8]; max(1, ...) avoids dividing by zero for a single run.
color_map = {source: cmap(0.8 * i / max(1, len(sources) - 1)) for i, source in enumerate(sources)}

# Fallback used by the plotting cells for any source missing from the map.
color_map["default"] = "black"

print("Generated color map:")
for source, color in color_map.items():
    if source != "default":
        # Print hex strings (not RGBA tuples) so the lines can be pasted back as a
        # dict literal of valid matplotlib color strings.
        print(f'    "{source}": "{mcolors.to_hex(color)}",')

In [None]:
# Print the run metadata logged in each loaded CSV.
# Assumes the "metadata" column holds a Python-literal dict string in at least one row
# (other rows NaN); CSVs without usable metadata are reported and skipped, not fatal.
for df in all_dfs:
    source = df["source"].iloc[0] if "source" in df.columns and len(df) else "unknown source"

    if "metadata" not in df.columns:
        print(f"\n==== NO METADATA COLUMN IN {source} ====")
        continue
    metadata_values = df["metadata"].dropna()
    if metadata_values.empty:
        print(f"\n==== NO METADATA ROWS IN {source} ====")
        continue

    metadata_str = metadata_values.iloc[0]
    try:
        # literal_eval only evaluates Python literals, so it is safe on logged text.
        metadata = ast.literal_eval(metadata_str)
    except (ValueError, SyntaxError) as exc:
        print(f"\n==== UNPARSEABLE METADATA IN {source} ({exc}) ====\n{metadata_str}")
        continue
    if not isinstance(metadata, dict):
        print(f"\n==== NON-DICT METADATA IN {source} ====\n{metadata!r}")
        continue

    print(f"\n==== METADATA FROM {metadata.get('Custom run name','N/A')} ====")
    print(f"Run ID String: {metadata.get('ID string', 'N/A')}")
    print(f"Run Description: {metadata.get('Description', 'N/A')}\n")

    for key, value in metadata.items():
        print(f"{key}: {value}")

In [None]:
# Student train loss (hybrid) over rounds
# Plots the hybrid training loss ("train_loss" on student / train / compute_loss rows) of
# every run, raw and moving-average smoothed. Uses `combined_df` and `color_map` from the
# earlier cells.
side_by_side = False  # False: overlay all runs on one axes; True: one log-scale subplot per run
x_min_loss = False    # Overlay only: also cap every run at the shortest run's length
min_len = 2500        # Overlay only: maximum number of logged points shown per run
log_y = False         # Overlay only: log-scale y axis

filter_size = 51      # Moving-average window; keep it odd so the window is centred
kernel = np.ones(filter_size) / filter_size
pad = (filter_size - 1) // 2  # Edge padding that keeps smoothed[i] centred on loss[i]

# Extract each run's series once (grouped by source, not the per-file `dfs` dict).
# NaNs are dropped because a single NaN turns the whole FFT convolution output into NaN.
losses_by_source = {
    name: df.loc[
        (df["role"] == "student") & (df["phase"] == "train") & (df["function"] == "compute_loss"),
        "train_loss",
    ].dropna().to_numpy()
    for name, df in combined_df.groupby("source")
}

# NOTE(review): x is the index of the logged row; it equals the global step only if exactly
# one compute_loss row is logged per step — confirm against the logger.
if not side_by_side:
    if x_min_loss:
        min_len = min([min_len, *(len(loss) for loss in losses_by_source.values())])

    fig, ax = plt.subplots(figsize=(12, 6))

    for name, loss in losses_by_source.items():
        loss = loss[:min_len]
        if loss.size == 0:
            continue  # np.pad(mode="edge") cannot pad an empty array
        # Edge-pad then "valid"-convolve: output has len(loss) points aligned with the raw
        # curve (plain "valid" mode would shift the smoothed curve left by `pad` steps).
        smoothed = scipy.signal.fftconvolve(np.pad(loss, pad, mode="edge"), kernel, mode="valid")

        color = color_map.get(name, color_map["default"])
        ax.plot(np.arange(len(loss)), loss, alpha=0.2, label=f"{name} (raw)", color=color)
        ax.plot(np.arange(len(smoothed)), smoothed, label=f"{name} (smoothed)", color=color)

    ax.set_title("Student Hybrid Training Loss Across Rounds")
    ax.set_xlabel("Global Step")
    ax.set_ylabel("Train Loss")
    ax.legend()
    ax.grid(True)
    if log_y:
        ax.set_yscale("log")
    plt.show()
else:
    num_files = len(losses_by_source)
    cols = min(num_files, 3)
    rows = (num_files + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 4 * rows), squeeze=False)
    flat_axes = axes.ravel()

    for ax, (name, loss) in zip(flat_axes, losses_by_source.items()):
        if len(loss) < filter_size:
            ax.set_title(f"{name} (too short)")
            ax.axis("off")
            continue

        smoothed = scipy.signal.fftconvolve(np.pad(loss, pad, mode="edge"), kernel, mode="valid")

        color = color_map.get(name, color_map["default"])
        ax.plot(np.arange(len(loss)), loss, alpha=0.3, label="Raw", color=color)
        ax.plot(np.arange(len(smoothed)), smoothed, label="Smoothed", color=color)

        ax.set_title(f"Training Loss: {name}")
        ax.set_xlabel("Global Step")
        ax.set_ylabel("Train Loss")
        # Per-axes: plt.yscale would only change the last active subplot.
        ax.set_yscale("log")
        ax.legend()

    # Hide any unused subplots
    for ax in flat_axes[num_files:]:
        ax.axis("off")

    fig.tight_layout()
    plt.show()

In [None]:
# Student Training Eval Loss Across Rounds
# Plots the student's eval loss ("eval_loss" on student / eval / eval_step rows) of every
# run, raw and moving-average smoothed. Uses `combined_df` and `color_map` from the
# earlier cells.
side_by_side = False  # False: overlay all runs on one axes; True: one log-scale subplot per run
x_min_loss = False    # Overlay only: display values on the x-axis only up to the minimum number of logged lines across files
min_len = 200         # Overlay only: maximum number of logged points shown per run
log_y = False         # Overlay only: log-scale y axis

filter_size = 11      # Moving-average window; keep it odd so the window is centred
kernel = np.ones(filter_size) / filter_size
pad = (filter_size - 1) // 2  # Edge padding that keeps smoothed[i] centred on loss[i]

# Hard-coded teacher eval loss drawn as a reference line in the overlay plot.
# NOTE(review): presumably measured on the same eval set — update it if that set changes.
baseline_teacher_loss = 0.7968094515065487

# Extract each run's series once (grouped by source, not the per-file `dfs` dict).
# NaNs are dropped because a single NaN turns the whole FFT convolution output into NaN.
losses_by_source = {
    name: df.loc[
        (df["role"] == "student") & (df["phase"] == "eval") & (df["function"] == "eval_step"),
        "eval_loss",
    ].dropna().to_numpy()
    for name, df in combined_df.groupby("source")
}

# NOTE(review): x is the index of the logged eval row, not necessarily the global step.
if not side_by_side:
    if x_min_loss:
        min_len = min([min_len, *(len(loss) for loss in losses_by_source.values())])

    fig, ax = plt.subplots(figsize=(12, 6))

    for name, loss in losses_by_source.items():
        loss = loss[:min_len]
        if loss.size == 0:
            continue  # np.pad(mode="edge") raises on an empty array (run with no eval rows)
        # Edge-pad then "valid"-convolve so the smoothed curve is aligned with the raw one.
        smoothed = scipy.signal.fftconvolve(np.pad(loss, pad, mode="edge"), kernel, mode="valid")

        color = color_map.get(name, color_map["default"])
        ax.plot(np.arange(len(loss)), loss, alpha=0.2, label=f"{name} (raw)", color=color)
        ax.plot(np.arange(len(smoothed)), smoothed, label=f"{name} (smoothed)", color=color)

    ax.axhline(baseline_teacher_loss, color="black", linestyle="--", linewidth=2,
               label=f"teacher baseline ({baseline_teacher_loss:.3f})")

    ax.set_title("Student Training Eval Loss Across Rounds")
    ax.set_xlabel("Global Step")
    ax.set_ylabel("Validation Loss")
    ax.legend()
    ax.grid(True)
    if log_y:
        ax.set_yscale("log")
    plt.show()
else:
    num_files = len(losses_by_source)
    cols = min(num_files, 3)
    rows = (num_files + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 4 * rows), squeeze=False)
    flat_axes = axes.ravel()

    for ax, (name, loss) in zip(flat_axes, losses_by_source.items()):
        if len(loss) < filter_size:
            ax.set_title(f"{name} (too short)")
            ax.axis("off")
            continue

        # Same edge-padded smoothing as the overlay branch, so both views agree.
        smoothed = scipy.signal.fftconvolve(np.pad(loss, pad, mode="edge"), kernel, mode="valid")

        color = color_map.get(name, color_map["default"])
        ax.plot(np.arange(len(loss)), loss, alpha=0.3, label="Raw", color=color)
        ax.plot(np.arange(len(smoothed)), smoothed, label="Smoothed", color=color)

        ax.set_title(f"Validation Loss: {name}")
        ax.set_xlabel("Global Step")
        ax.set_ylabel("Validation Loss")
        # Per-axes: plt.yscale would only change the last active subplot.
        ax.set_yscale("log")
        ax.legend()

    # Hide any unused subplots
    for ax in flat_axes[num_files:]:
        ax.axis("off")

    fig.tight_layout()
    plt.show()

In [None]:
# Student eval KL loss over rounds
# Plots the student's KL eval loss ("eval_kl_loss" on student / eval / eval_step rows) of
# every run, raw and moving-average smoothed. Uses `combined_df` and `color_map` from the
# earlier cells.
side_by_side = False  # False: overlay all runs on one axes; True: one log-scale subplot per run
x_min_loss = False    # Overlay only: display values on the x-axis only up to the minimum number of logged lines across files
min_len = 200         # Overlay only: maximum number of logged points shown per run
log_y = False         # Overlay only: log-scale y axis

filter_size = 11      # Moving-average window; keep it odd so the window is centred
kernel = np.ones(filter_size) / filter_size
pad = (filter_size - 1) // 2  # Edge padding that keeps smoothed[i] centred on loss[i]

# Extract each run's series once (grouped by source, not the per-file `dfs` dict).
# NaNs are dropped because a single NaN turns the whole FFT convolution output into NaN.
losses_by_source = {
    name: df.loc[
        (df["role"] == "student") & (df["phase"] == "eval") & (df["function"] == "eval_step"),
        "eval_kl_loss",
    ].dropna().to_numpy()
    for name, df in combined_df.groupby("source")
}

# NOTE(review): x is the index of the logged eval row, not necessarily the global step.
if not side_by_side:
    if x_min_loss:
        min_len = min([min_len, *(len(loss) for loss in losses_by_source.values())])

    fig, ax = plt.subplots(figsize=(12, 6))

    for name, loss in losses_by_source.items():
        loss = loss[:min_len]
        if loss.size == 0:
            continue  # np.pad(mode="edge") raises on an empty array (run with no eval rows)
        # Edge-pad then "valid"-convolve so the smoothed curve is aligned with the raw one.
        smoothed = scipy.signal.fftconvolve(np.pad(loss, pad, mode="edge"), kernel, mode="valid")

        color = color_map.get(name, color_map["default"])
        ax.plot(np.arange(len(loss)), loss, alpha=0.2, label=f"{name} (raw)", color=color)
        ax.plot(np.arange(len(smoothed)), smoothed, label=f"{name} (smoothed)", color=color)

    ax.set_title("Student Training KL Eval Loss Across Rounds")
    ax.set_xlabel("Global Step")
    ax.set_ylabel("KL Validation Loss")
    ax.legend()
    ax.grid(True)
    if log_y:
        ax.set_yscale("log")
    plt.show()
else:
    num_files = len(losses_by_source)
    cols = min(num_files, 3)
    rows = (num_files + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 4 * rows), squeeze=False)
    flat_axes = axes.ravel()

    for ax, (name, loss) in zip(flat_axes, losses_by_source.items()):
        if len(loss) < filter_size:
            ax.set_title(f"{name} (too short)")
            ax.axis("off")
            continue

        # Same edge-padded smoothing as the overlay branch, so both views agree.
        smoothed = scipy.signal.fftconvolve(np.pad(loss, pad, mode="edge"), kernel, mode="valid")

        color = color_map.get(name, color_map["default"])
        ax.plot(np.arange(len(loss)), loss, alpha=0.3, label="Raw", color=color)
        ax.plot(np.arange(len(smoothed)), smoothed, label="Smoothed", color=color)

        ax.set_title(f"KL Validation Loss: {name}")
        ax.set_xlabel("Global Step")
        ax.set_ylabel("KL Loss")
        # Per-axes: plt.yscale would only change the last active subplot.
        ax.set_yscale("log")
        ax.legend()

    # Hide any unused subplots
    for ax in flat_axes[num_files:]:
        ax.axis("off")

    fig.tight_layout()
    plt.show()