In [1]:
import numpy as np
import itertools
import os
from pathlib import Path
import pandas as pd
from collections import defaultdict
import seaborn as sns
import matplotlib.pyplot as plt

from plot_utils import *

In [2]:
# Enable IPython's autoreload extension in mode 2 so that edited modules are
# reloaded automatically before each cell executes (no kernel restart needed).
%load_ext autoreload
%autoreload 2

## Result-Loading and Plot Functions

In [9]:
# File-name templates for the per-step model-completion CSVs; the `{step}`
# placeholder is filled in with an evaluation step via `.format(step=...)`.
# In this section VAL_FILENAME is only forwarded as a `val_filename` argument,
# while TEST_FILENAME is the template read inside get_test_acc.
VAL_FILENAME = "validation_model_completions_step_{step}.csv"
TEST_FILENAME = "test_model_completions_step_{step}.csv"

def get_accuracy_from_df(df):
    """Return the sum of the ``is_correct`` column divided by the number of rows in ``df``."""
    n_correct = df['is_correct'].sum()
    n_rows = len(df)
    return n_correct / n_rows


def get_test_acc(results_dir, eval_steps, num_samples=None):
    """Return the test accuracy at the last evaluation step of one run.

    Reads ``<results_dir>/<TEST_FILENAME formatted with the last entry of
    eval_steps>`` and scores it with ``get_accuracy_from_df``.

    Args:
        results_dir: Directory expected to contain the test completions CSV.
        eval_steps: Sequence of evaluation steps; only the last entry is used.
        num_samples: If not None, only the first ``num_samples`` rows of each
            ``prompt`` group are kept before the accuracy is computed.

    Returns:
        The accuracy, or ``None`` when the CSV file does not exist.

    Raises:
        Any error other than a missing file (e.g. empty ``eval_steps``, an
        unparsable CSV, or a missing ``prompt`` / ``is_correct`` column) is
        propagated instead of being silently mapped to ``None``, so callers do
        not mistake it for an absent results file.
    """
    last_step = eval_steps[-1]
    csv_path = f"{results_dir}/{TEST_FILENAME.format(step=last_step)}"

    try:
        test_step_last_df = pd.read_csv(csv_path)
    except FileNotFoundError:
        # A run that has not produced this file yet is an expected situation.
        return None

    if num_samples is not None:
        # Group by 'prompt' and keep the first num_samples rows of each group.
        test_step_last_df = test_step_last_df.groupby('prompt').head(num_samples)

    return get_accuracy_from_df(test_step_last_df)


def load_all_results_from_format(
        exp_runs, base_exp_dir, eval_group_name, datasets, val_filename, test_filename,
        bootstrap_iters=(1, 2, 3, 4), trials=(0, 1, 2), seeds=(42, 43, 44, 45, 46), eval_steps=(520,), num_samples=None
    ):
    """Collect test accuracies for every (method, dataset) pair over a run grid.

    For each method in ``exp_runs`` and each dataset, every combination of the
    sweep variables is turned into a results directory and scored with
    ``get_test_acc``. Combinations for which ``get_test_acc`` returns ``None``,
    or whose processing raises, are reported with ``print`` and skipped.

    Args:
        exp_runs: Mapping of method name -> directory format string. The string
            is formatted with ``trial``, ``seed`` and ``eval_name``; for the
            ``"latent_bootstrap"`` method it is additionally given ``iter``.
        base_exp_dir: Directory that each formatted run path is joined onto.
        eval_group_name: Format string with a ``{dataset}`` field; the formatted
            value is passed to the directory format as ``eval_name``.
        datasets: Sequence of dataset names (iterated once per method).
        val_filename: Not used by this function; kept so existing calls work.
        test_filename: Not used by this function; kept so existing calls work.
        bootstrap_iters: ``iter`` values swept for ``"latent_bootstrap"`` only.
        trials: ``trial`` values swept for every method.
        seeds: ``seed`` values swept for every method.
        eval_steps: Forwarded to ``get_test_acc``.
        num_samples: Forwarded to ``get_test_acc``.

    Returns:
        Dict mapping ``(method, dataset)`` to a DataFrame with columns
        ``['iter', 'trial', 'seed', 'accuracy']`` for ``"latent_bootstrap"`` and
        ``['trial', 'seed', 'accuracy']`` for every other method. A pair with no
        successfully loaded run maps to an empty DataFrame.
    """
    results = {}

    for method, base_dir_format in exp_runs.items():
        # Only the bootstrap method has an extra `iter` axis; everything else
        # about the sweep is shared, so both cases go through the same loop.
        if method == "latent_bootstrap":
            sweep = {"iter": bootstrap_iters, "trial": trials, "seed": seeds}
        else:
            sweep = {"trial": trials, "seed": seeds}
        var_names = list(sweep)

        for dataset in datasets:
            rows = results.setdefault((method, dataset), [])

            for combo in itertools.product(*sweep.values()):
                run_vars = dict(zip(var_names, combo))
                try:
                    results_dir = os.path.join(
                        base_exp_dir,
                        base_dir_format.format(eval_name=eval_group_name.format(dataset=dataset), **run_vars),
                    )
                    test_acc = get_test_acc(results_dir, num_samples=num_samples, eval_steps=eval_steps)

                    if test_acc is None:
                        run_desc = ", ".join(f"{name} {value}" for name, value in run_vars.items())
                        print(f"File not found for method {method}, dataset {dataset}, {run_desc}: {results_dir}")
                        continue

                    rows.append((*combo, test_acc))
                except Exception as e:
                    # Best effort: one broken run must not abort the whole sweep.
                    combo_desc = "/".join(str(value) for value in combo)
                    print(f"Error processing {combo_desc}: {e}")

    df_results = {}
    for (method, dataset), rows in results.items():
        if method == 'latent_bootstrap':
            columns = ['iter', 'trial', 'seed', 'accuracy']
        else:
            columns = ['trial', 'seed', 'accuracy']
        df_results[(method, dataset)] = pd.DataFrame(rows, columns=columns)

    return df_results

In [24]:
# Run key -> display label used for legend entries.
# The key order matters: RUN_ORDER below is derived from it and fixes each run's color.
RUN_MAP = {
    "latent_bootstrap": 'Latent Bootstrap', 
    "raw_flops_matched_baseline": 'Raw Train-FLOP-Match', 
    "warmstart": 'Latent Warmstart', 
    "raw_token_matched_baseline": 'Raw Token-Match',
}

# Dataset key -> subplot title.
DATASET_MAP = {
    "math": 'MATH (Fine-tuned)',
    "gsm8k": 'GSM8K (Fine-tuned)',
}

# A run's position in RUN_ORDER is its index into COLOR_PALETTE, so a given run
# is drawn in the same color in every panel.
RUN_ORDER = list(RUN_MAP.keys())
# The palette seaborn returns when color_palette() is called without arguments.
COLOR_PALETTE = sns.color_palette()

# Plot the results from model selection
def plot_fine_tuning_results(selected_results, datasets, unify_legend=True):
    """Draw one accuracy panel per dataset: bootstrap curve plus baseline bands.

    Each panel shows the ``latent_bootstrap`` accuracy as a line over bootstrap
    iteration (seaborn line plot with ``errorbar="se"``) and every available
    baseline run as a dashed horizontal line at its mean accuracy, with a shaded
    band of +/- one standard error (sample std divided by sqrt of row count).

    Args:
        selected_results: Mapping of ``(run_name, dataset)`` -> DataFrame with an
            ``accuracy`` column. The ``latent_bootstrap`` frame must also have an
            ``iter`` column. Baseline entries may be missing or empty; those are
            reported with ``print`` and skipped.
        datasets: Sequence of dataset keys (each must be a key of
            ``DATASET_MAP``); one subplot column is created per entry.
        unify_legend: If True, per-panel legends are dropped and a single
            figure-level legend is placed above the panels. Otherwise every
            panel keeps its own legend.

    Returns:
        The matplotlib figure containing all panels.

    Raises:
        KeyError: If ``selected_results`` has no ``latent_bootstrap`` entry for
            one of the datasets, or a dataset is not in ``DATASET_MAP``.
        AssertionError: If a ``latent_bootstrap`` frame lacks the ``iter`` column.
    """
    baseline_runs = ('warmstart', 'raw_flops_matched_baseline', 'raw_token_matched_baseline')
    n_cols, n_rows = len(datasets), 1
    unit_figsize = (6, 4)
    # squeeze=False always yields a 2-D array, so the single-panel case needs no special handling.
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(unit_figsize[0]*n_cols, unit_figsize[1]*n_rows),
        squeeze=False,
    )
    axes = axes.ravel()

    for ax, dataset in zip(axes, datasets):
        bootstrap_df = selected_results[('latent_bootstrap', dataset)]
        assert 'iter' in bootstrap_df.columns, f"No iter column in bootstrap data for {dataset}"

        sns.lineplot(
            x='iter', y='accuracy', data=bootstrap_df, ax=ax, errorbar="se",
            marker="o", markersize=10, linewidth=3,
            label=RUN_MAP['latent_bootstrap'],
            color=COLOR_PALETTE[RUN_ORDER.index('latent_bootstrap')],
        )

        # Capture the x-range of the bootstrap curve before the bands widen it.
        x_min, x_max = ax.get_xlim()

        # Plot the other experiment types as horizontal lines.
        for exp_type in baseline_runs:
            # .get(): a run type that was never loaded must not abort the whole figure.
            df = selected_results.get((exp_type, dataset))
            if df is None or len(df) == 0:
                print(f"No {exp_type} data for {dataset}")
                continue

            mean_test_acc = df['accuracy'].mean()
            se_test_acc = df['accuracy'].std() / np.sqrt(len(df))

            run_color = COLOR_PALETTE[RUN_ORDER.index(exp_type)]
            ax.axhline(y=mean_test_acc, color=run_color, linestyle='--', label=RUN_MAP[exp_type], linewidth=3)

            # Shaded standard-error band, drawn wider than the curve so it still
            # spans the whole panel after the x-limits are tightened below.
            ax.fill_between(
                [x_min - 0.5, x_max + 0.5],
                mean_test_acc - se_test_acc,
                mean_test_acc + se_test_acc,
                color=run_color, alpha=0.2)

        # Snap the x-limits to the integer range currently shown, with a small margin.
        x_min, x_max = ax.get_xlim()
        ax.set_xlim(np.ceil(x_min) - 0.1, np.floor(x_max) + 0.1)

        ax.set_xlabel('Bootstrap Iteration')
        ax.set_ylabel('Accuracy')
        ax.set_title(DATASET_MAP[dataset])
        ax.grid(True, linestyle='--', alpha=0.7)

        all_bootstrap_iters = bootstrap_df['iter'].unique()
        ax.set_xticks(all_bootstrap_iters)
        ax.set_xticklabels(all_bootstrap_iters)

        if unify_legend:
            # Drop any per-panel legend; a single figure-level legend is added after the loop.
            panel_legend = ax.get_legend()
            if panel_legend is not None:
                panel_legend.remove()
        else:
            ax.legend(loc='best')

    if unify_legend:
        # Add extra height so the legend above the panels has room.
        fig.set_size_inches(unit_figsize[0]*n_cols, unit_figsize[1]*n_rows + 1)
        # Merge legend entries across all panels (first occurrence of each label
        # wins) so a run plotted only in a later panel is still listed.
        by_label = {}
        for ax in axes:
            for handle, label in zip(*ax.get_legend_handles_labels()):
                by_label.setdefault(label, handle)
        # NOTE(review): SIZE_LARGE is not defined in this section; presumably it
        # comes from `from plot_utils import *` - confirm.
        fig.legend(
            list(by_label.values()), list(by_label.keys()),
            bbox_to_anchor=(0.5, 1.0), loc='lower center', ncol=2, fontsize=SIZE_LARGE,
        )

    fig.tight_layout()
    return fig

## Plots

In [27]:
# Method name -> run-directory template, relative to BASE_EXP_DIR.
# Placeholders filled by load_all_results_from_format: {trial}, {eval_name}, {seed},
# and {iter} for the "latent_bootstrap" method only.
# NOTE(review): the fixed numeric path segment (e.g. 0000033060) is presumably a
# checkpoint/step identifier - confirm.
# NOTE: with the "warmstart" entry disabled below, load_all_results_from_format
# produces no ("warmstart", <dataset>) result.
EXP_RUN_TO_BASE_DUMP_DIR_FORMAT = {
    "latent_bootstrap": "train_bootstrap_fixed_data_bootstrap/train_bootstrap_fixed_data_bootstrap_setup=bootstrap_latents_iter={iter}_mc=4_scratch_trial_{trial}/{eval_name}/0000033060/seed={seed}",
    "raw_token_matched_baseline": "train_bootstrap_fixed_data_bootstrap/train_bootstrap_fixed_data_bootstrap_setup=raw_token_matched_scratch_trial_{trial}/{eval_name}/0000009765/seed={seed}",
    "raw_flops_matched_baseline": "train_bootstrap_fixed_data_bootstrap/train_bootstrap_fixed_data_bootstrap_setup=raw_flops_matched_scratch_trial_{trial}/{eval_name}/0000033060/seed={seed}",
    # "warmstart": "train_bootstrap_fixed_data_warmstart/train_bootstrap_fixed_data_warmstart_latent=random_opt=cosine_lr=1e-4_240m_raw_trial_{trial}/{eval_name}/0000004096/seed={seed}",
}

# Root directory of the experiment logs (relative to the notebook's working directory).
BASE_EXP_DIR = "../exp_logs"
# Template for the {eval_name} path component; {dataset} is filled per dataset.
EVAL_GROUP_NAME = "finetune_eval_on_{dataset}"
# Datasets to load and plot; each must be a key of DATASET_MAP to get a panel title.
EVAL_DATASETS = ["math"]

# Keyword arguments for load_all_results_from_format, kept in one place so the
# sweep (iterations x trials x seeds) is easy to tune.
LOAD_RESULTS_KWARGS = {
    "exp_runs": EXP_RUN_TO_BASE_DUMP_DIR_FORMAT,
    "base_exp_dir": BASE_EXP_DIR,
    "eval_group_name": EVAL_GROUP_NAME,
    "datasets": EVAL_DATASETS,
    "val_filename": VAL_FILENAME,
    "test_filename": TEST_FILENAME,
    "bootstrap_iters": [1, 2, 3, 4],
    "trials": [0, 1, 2],
    "seeds": [42, 43, 44, 45, 46],
    "eval_steps": [520],  # only the last entry is read by get_test_acc
    "num_samples": None,  # None = score every row of the test CSV
}

In [None]:
# Load per-run test accuracies for every configured method/dataset pair;
# runs that cannot be loaded are printed and skipped.
formatted_results = load_all_results_from_format(**LOAD_RESULTS_KWARGS)

In [None]:
# One panel per dataset in EVAL_DATASETS, using the default unified legend.
fig = plot_fine_tuning_results(formatted_results, datasets=EVAL_DATASETS)