In [1]:
import copy
import glob
import json
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import yaml
from matplotlib.ticker import AutoMinorLocator, LogLocator

from plot_utils import *

In [2]:
# Reload imported modules before each cell runs, so edits to local modules take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

## Plot Functions

In [3]:
def plot_runs(
        runs,
        metrics=METRICS,
        warmstart_ckpt=None,
        plot_warmstart=False,
        plot_warmstart_horizontal_line=False,
        n_cols=4,
        unify_legend=True,
        legend_n_col=None,
        unit_figsize=None,
        log_x_scale=False,
        log_y_scale=False,
        use_token_measures=True,
        plot_line_kwargs=None,
        run_labels_map=None,
        display_result=False,
        vertical_lines=None,
        load_metric_kwargs=None,
        top_runs=None,
        plot_group_legend=False,
        group_legend_num_items_row=2,
    ):
    """Plot a grid of metric curves, one subplot per metric, comparing several runs.

    Metrics are loaded with ``load_run_metrics``; the returned frame must have the
    columns ``run``, ``step``, ``num_tokens`` and one column per metric. Rows whose
    ``step`` is NaN are treated as baselines and drawn as dashed horizontal lines
    (mean with a standard-error band); all other rows are drawn as curves.

    Args:
        runs: Dict mapping run name to the run path handed to ``load_run_metrics``.
            Key order defines the plotting order.
        metrics: Metric column names; one subplot is created per entry.
        warmstart_ckpt: Optional dict with keys ``"path"`` (directory under
            ``BASE_EXP_DIR``; a ``_trial_0`` suffix is also tried) and ``"step"``.
        plot_warmstart: Whether to include the warmstart run in the plot.
        plot_warmstart_horizontal_line: If true, the warmstart run is drawn as a
            dashed horizontal line at its value at the warmstart step. Otherwise the
            warmstart run is kept as a curve and every other run whose name does not
            end with ``"(scratch)"`` is shifted right by the warmstart step/tokens.
        n_cols: Number of subplot columns.
        unify_legend: If true, draw one figure-level legend above the subplots;
            otherwise each subplot gets its own legend.
        legend_n_col: Column count of the unified legend when ``plot_group_legend``
            is false. Defaults to the number of runs.
        unit_figsize: ``(width, height)`` of a single subplot; defaults to ``(6, 4)``.
        log_x_scale: Use a base-10 log x axis.
        log_y_scale: Use a log y axis (metrics containing "nll"/"elbo" always do).
        use_token_measures: Use ``num_tokens`` on the x axis instead of ``step``.
        plot_line_kwargs: Dict keyed by line type (``"average"``, ``"envelope"``,
            ``"smoothed"``) whose values are extra keyword arguments for
            ``sns.lineplot``. Defaults to averaged lines with standard-error bands.
        run_labels_map: Optional dict mapping run name to display label. Its key
            order also determines the palette index used for each run.
        display_result: If true, ``display`` the last-x row of every run for the
            envelope/smoothed line types.
        vertical_lines: Optional list of dicts, each passed as keyword arguments to
            ``ax.axvline`` (always drawn with ``linewidth=2``). Not modified.
        load_metric_kwargs: Extra keyword arguments for ``load_run_metrics``.
        top_runs: Run names that are drawn above the others (higher zorder).
        plot_group_legend: If true, legend labels of the form ``"Group, Item"`` are
            laid out under a bold header per group.
        group_legend_num_items_row: Number of entries stacked under a group header
            in each legend column.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: If ``metrics`` is empty, or if no ``config.yaml`` is found for
            the warmstart checkpoint.
    """
    additional_plot_types = ["envelope", "smoothed"]
    if plot_line_kwargs is None:
        plot_line_kwargs = {"average": {"errorbar": "se"}}  # average of all trials
    if load_metric_kwargs is None:
        load_metric_kwargs = {}
    if top_runs is None:
        top_runs = []
    if unit_figsize is None:
        unit_figsize = (6, 4)
    if len(metrics) == 0:
        raise ValueError("metrics must contain at least one metric to plot")

    has_extra_lines = any(line_type in plot_line_kwargs for line_type in additional_plot_types)

    # ------------------------------------------------------------------
    # Resolve the warmstart checkpoint (if requested) and load all metrics
    # ------------------------------------------------------------------
    load_runs = runs.copy()
    warmstart_step = 0
    warmstart_tokens = 0
    if warmstart_ckpt is not None and plot_warmstart:
        load_runs[WARMSTART_NAME] = warmstart_ckpt["path"]
        candidate_dirs = [warmstart_ckpt["path"], warmstart_ckpt["path"] + "_trial_0"]
        warmstart_path = next(
            (
                os.path.join(BASE_EXP_DIR, candidate)
                for candidate in candidate_dirs
                if os.path.exists(os.path.join(BASE_EXP_DIR, candidate, "config.yaml"))
            ),
            None,
        )
        if warmstart_path is None:
            raise ValueError(f"Warmstart path not found for {warmstart_ckpt['path']}")

        # NOTE(review): FullLoader can construct arbitrary Python objects; only point this
        # at trusted experiment configs (yaml.safe_load would be preferable if it suffices).
        with open(os.path.join(warmstart_path, "config.yaml"), "r") as f:
            warmstart_config = yaml.load(f, Loader=yaml.FullLoader)

        warmstart_step = warmstart_ckpt["step"]
        if not plot_warmstart_horizontal_line:
            # tokens consumed before the warmstart checkpoint was taken
            warmstart_tokens = (
                warmstart_config["data"]["batch_size"]
                * warmstart_config["data"]["seq_len"]
                * warmstart_config["distributed"]["dp_shard"]
                * warmstart_step
            )

    # de-duplicate while keeping a deterministic order
    load_metrics = list(dict.fromkeys(metrics))
    all_metrics = load_run_metrics(load_runs, load_metrics=load_metrics, warmstart_ckpt=warmstart_ckpt, **load_metric_kwargs)

    run_order = list(runs.keys())
    warmstart_run = None
    if plot_warmstart:
        if plot_warmstart_horizontal_line:
            # pull the warmstart rows out; they are drawn as a horizontal line below
            is_warmstart = all_metrics["run"] == WARMSTART_NAME
            warmstart_run = all_metrics[is_warmstart]
            all_metrics = all_metrics[~is_warmstart]
        else:
            # shift the steps of all continued runs by the step of the warmstart checkpoint
            continued = (all_metrics["run"] != WARMSTART_NAME) & (~all_metrics["run"].str.endswith("(scratch)"))
            all_metrics.loc[continued, "step"] += warmstart_step
            all_metrics.loc[continued, "num_tokens"] += warmstart_tokens
        run_order.append(WARMSTART_NAME)

    # Runs that must not be drawn as envelope/smoothed curves because they are the
    # warmstart horizontal line. Empty when the warmstart run is kept as a curve.
    warmstart_line_runs = set(warmstart_run["run"].unique()) if warmstart_run is not None else set()

    # ------------------------------------------------------------------
    # Colors: one palette index per run
    # ------------------------------------------------------------------
    if run_labels_map is not None:
        hue_levels = list(run_labels_map.keys())
        # runs without a display label get the indices after the labelled ones
        hue_levels += [run for run in run_order if run not in run_labels_map]
    else:
        hue_levels = run_order
    level_index = {run: position for position, run in enumerate(hue_levels)}
    hue_order_map = {run: level_index[run] for run in run_order}

    palette = sns.color_palette()

    def color_of(run):
        # wrap around instead of raising when there are more runs than palette colors
        if run in hue_order_map:
            return palette[hue_order_map[run] % len(palette)]
        return "gray"

    x_name = "num_tokens" if use_token_measures else "step"
    x_label = "Total Training Tokens" if use_token_measures else "Training Steps"
    plot_style_kwargs = {
        "linewidth": 2.5,
        # fade the raw averaged lines when an envelope/smoothed line is drawn on top
        "alpha": 0.6 if has_extra_lines else 1.0,
    }

    # ------------------------------------------------------------------
    # Create a grid plot, each for one metric, and different runs in one plot for comparison
    # ------------------------------------------------------------------
    n_rows = (len(metrics) + n_cols - 1) // n_cols
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(unit_figsize[0] * n_cols, unit_figsize[1] * n_rows),
        squeeze=False,
    )
    axes = axes.flatten()  # Flatten to make indexing easier

    for idx, metric in enumerate(metrics):
        ax = axes[idx]

        # Baselines (rows without a step) are drawn as horizontal lines
        baseline_data = all_metrics[all_metrics["step"].isna()]
        baseline_runs = set(baseline_data["run"].unique())
        for run in baseline_data["run"].unique():
            baseline_values = baseline_data.loc[baseline_data["run"] == run, metric]
            if baseline_values.empty:
                continue
            _draw_mean_band(ax, baseline_values, color=color_of(run), label=f"{run}", **plot_style_kwargs)

        # Then plot the rest of the data (rows that do have a step)
        plot_data = all_metrics[all_metrics["step"].notna()]
        hue_order = sorted(plot_data["run"].unique(), key=lambda run: hue_order_map[run])

        in_top_runs = plot_data["run"].isin(top_runs)
        plot_data_top = plot_data[in_top_runs]
        plot_data_bottom = plot_data[~in_top_runs]

        if "average" in plot_line_kwargs:
            shared_plot_kwargs = {
                "x": x_name,
                "y": metric,
                "hue": "run",
                "ax": ax,
                "hue_order": hue_order,
                "marker": "o",
                **plot_line_kwargs["average"],
                **plot_style_kwargs,
            }
            # Plot bottom lines first
            sns.lineplot(data=plot_data_bottom, zorder=3, **shared_plot_kwargs)

            # Plot top lines second with higher z-order
            if not plot_data_top.empty:
                sns.lineplot(data=plot_data_top, zorder=4, **shared_plot_kwargs)

        for plot_line_type in additional_plot_types:
            if plot_line_type not in plot_line_kwargs:
                continue

            if "average" in plot_line_kwargs:
                # reuse the colors seaborn picked for the averaged lines
                color_map = {line.get_label(): line.get_color() for line in ax.get_lines()}
            else:
                color_map = {run: color_of(run) for run in run_order}

            extra_metrics = post_process_metrics(all_metrics, process_type=plot_line_type, metric_cols=[metric], sort_by_col=x_name)

            for run in run_order:
                if run in baseline_runs or run in warmstart_line_runs:
                    continue

                run_data = extra_metrics[extra_metrics["run"] == run]

                if display_result:
                    # display the last step result
                    last_step_data = run_data[run_data[x_name] == run_data[x_name].max()]
                    display(last_step_data[["run", "step", "num_tokens", metric]])

                sns.lineplot(
                    data=run_data, x=x_name, y=metric,
                    ax=ax,
                    color=color_map[run],
                    linewidth=3,
                    label=None if "average" in plot_line_kwargs else run,
                    zorder=4 if run in top_runs else 3,
                    **plot_line_kwargs[plot_line_type],
                )

        if vertical_lines is not None:
            assert isinstance(vertical_lines, list), "vertical_lines must be a list"
            for line in vertical_lines:
                assert isinstance(line, dict), "vertical_lines must be a list of dictionaries"
                # build a new dict so the caller's dict is left untouched
                ax.axvline(**{**line, "linewidth": 2})

        if plot_warmstart and plot_warmstart_horizontal_line:
            at_warmstart_step = warmstart_run.loc[warmstart_run["step"] == warmstart_step, metric]
            _draw_mean_band(
                ax, at_warmstart_step,
                color=color_of(WARMSTART_NAME),
                label=f"{WARMSTART_NAME}", linewidth=2, zorder=2,
            )

        if run_labels_map is not None:
            # Replace raw run names by their display labels
            for line in ax.get_lines():
                current_label = line.get_label()
                if current_label in run_labels_map:
                    line.set_label(run_labels_map[current_label])

        # Drop whatever legend seaborn created; rebuild it per axis only if requested
        if ax.get_legend() is not None:
            ax.get_legend().remove()
        if not unify_legend:
            ax.legend()
            for handle in ax.get_legend().get_lines():
                handle.set_alpha(1.0)

        ax.grid(True)

        if log_x_scale:
            ax.set_xscale("log", base=10)

        if log_y_scale:
            ax.set_yscale("log")
            ax.grid(True, which='both', axis='y')   # add back the grid line

        if "nll" in metric or "elbo" in metric:
            ax.set_yscale("log")
            # disable scientific notation
            ax.yaxis.set_major_formatter(plt.FormatStrFormatter('%.2f'))
            ax.yaxis.set_minor_formatter(plt.FormatStrFormatter('%.2f'))
            ax.grid(True, which='both', axis='y')   # add back the grid line
            ax.set_ylabel("NLL" if "nll" in metric else "ELBO w/ 4 Samples")
        else:
            ax.set_ylabel("Accuracy")

        ax.set_title(ALL_EVAL_METRIC_LABEL_MAP.get(metric, metric))
        ax.set_xlabel(x_label)

    # Remove empty subplots if any
    for idx in range(len(metrics), len(axes)):
        fig.delaxes(axes[idx])

    if unify_legend:
        # Add extra height to accommodate the legend at the top
        fig.set_size_inches(unit_figsize[0] * n_cols, unit_figsize[1] * n_rows + 1)
        handles, labels = axes[0].get_legend_handles_labels()

        if has_extra_lines:
            # fully opaque proxies: the plotted averaged lines are drawn faded
            handles = [
                plt.Line2D([], [], color=handle.get_color(), label=handle.get_label(), alpha=1.0)
                for handle in handles
            ]

        if plot_group_legend:
            legend_handles, legend_labels, legend_cols, group_names = _grouped_legend_entries(
                handles, labels, group_legend_num_items_row
            )
            legend = fig.legend(
                legend_handles, legend_labels,
                bbox_to_anchor=(0.5, 1.0),
                loc='lower center',
                ncol=max(legend_cols, 1),
                fontsize=SIZE_MEDIUM,
            )
            # Make group labels bold
            for text in legend.get_texts():
                if text.get_text() in group_names:
                    text.set_weight('bold')
        else:
            fig.legend(
                handles, labels,
                bbox_to_anchor=(0.5, 1.0),
                loc='lower center',
                ncol=legend_n_col or len(run_order) or 1,
                fontsize=SIZE_MEDIUM,
            )

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.3, hspace=0.5)  # Adjust these values as needed

    return fig


def _draw_mean_band(ax, values, color, **line_kwargs):
    """Draw a dashed horizontal line at the mean of ``values`` with a +/- standard-error band.

    The standard error is computed over the non-NaN entries of ``values`` only. The band
    is skipped when it is undefined (fewer than two values).
    """
    mean = values.mean()
    stderr = values.sem()
    ax.axhline(y=mean, color=color, linestyle="--", **line_kwargs)
    if np.isfinite(stderr):
        ax.axhspan(mean - stderr, mean + stderr, color=color, alpha=0.2)


def _grouped_legend_entries(handles, labels, items_per_column):
    """Arrange ``"Group, Item"`` legend entries into columns headed by their group label.

    Every legend column holds exactly ``1 + items_per_column`` entries (a header slot
    followed by the items, padded with invisible entries), so that a legend created with
    ``ncol`` equal to the returned column count keeps each group's header at the top of
    its columns. Labels without a ``", "`` separator are placed under an empty header.

    Returns:
        ``(legend_handles, legend_labels, n_columns, group_names)`` where ``group_names``
        is the set of non-empty group labels.
    """
    if items_per_column < 1:
        raise ValueError("group_legend_num_items_row must be at least 1")

    groups = defaultdict(list)  # group label -> [(handle, item label), ...]
    for handle, label in zip(handles, labels):
        group_label, separator, item_label = label.partition(", ")
        if not separator:
            group_label, item_label = "", label
        if all(item_label != seen for _, seen in groups[group_label]):
            groups[group_label].append((handle, item_label))

    def invisible_handle():
        return plt.Line2D([], [], color='none')

    legend_handles = []
    legend_labels = []
    n_columns = 0

    # Sort groups for consistent ordering
    for group_label in sorted(groups):
        items = list(groups[group_label])
        group_columns = (len(items) + items_per_column - 1) // items_per_column
        n_columns += group_columns

        # Pad so that every column of the group is completely filled
        n_missing = group_columns * items_per_column - len(items)
        items.extend((invisible_handle(), "") for _ in range(n_missing))

        # The header goes into the middle column of the group
        header_column = group_columns // 2
        for column in range(group_columns):
            legend_handles.append(invisible_handle())
            legend_labels.append(f"{group_label}" if column == header_column else "")

            for handle, item_label in items[column * items_per_column: (column + 1) * items_per_column]:
                legend_handles.append(handle)
                legend_labels.append(item_label)

    return legend_handles, legend_labels, n_columns, {name for name in groups if name}

## Plots

In [4]:
# Run name -> experiment path, passed as the `runs` argument of plot_runs.
# NOTE(review): paths are presumably relative to BASE_EXP_DIR, like the warmstart path — confirm in plot_utils.
EXP_RUNS = {
    "latent_bootstrap_iter=1_mc=4": "train_bootstrap_fixed_data_bootstrap/train_bootstrap_fixed_data_bootstrap_setup=bootstrap_latents_iter=1_mc=4_scratch",
    "latent_bootstrap_iter=2_mc=4": "train_bootstrap_fixed_data_bootstrap/train_bootstrap_fixed_data_bootstrap_setup=bootstrap_latents_iter=2_mc=4_scratch",
    "latent_bootstrap_iter=3_mc=4": "train_bootstrap_fixed_data_bootstrap/train_bootstrap_fixed_data_bootstrap_setup=bootstrap_latents_iter=3_mc=4_scratch",
    "latent_bootstrap_iter=4_mc=4": "train_bootstrap_fixed_data_bootstrap/train_bootstrap_fixed_data_bootstrap_setup=bootstrap_latents_iter=4_mc=4_scratch",
    "raw_token_matched_baseline": "train_bootstrap_fixed_data_bootstrap/train_bootstrap_fixed_data_bootstrap_setup=raw_token_matched_scratch",
    "raw_flops_matched_baseline": "train_bootstrap_fixed_data_bootstrap/train_bootstrap_fixed_data_bootstrap_setup=raw_flops_matched_scratch",
}

# Warmstart checkpoint consumed by plot_runs: "path" is looked up under BASE_EXP_DIR
# (also with a "_trial_0" suffix) and "step" is the training step of the checkpoint.
WARMSTART_CHECKPOINT = {
    "path": "train_bootstrap_fixed_data_warmstart/train_bootstrap_fixed_data_warmstart_latent=random_opt=cosine_lr=1e-4_240m_raw",
    "step": 4069,
}


# Run name -> legend label. plot_runs uses the key order to pick each run's palette index,
# and its grouped legend splits each label at ", " into a group header and an item label.
# NOTE(review): "latent_warmstart" is presumably the value of WARMSTART_NAME — confirm in plot_utils.
RUN_LABELS_MAP = {
    "latent_bootstrap_iter=1_mc=4": "Latent Bootstrap, Iteration 1",
    "latent_bootstrap_iter=2_mc=4": "Latent Bootstrap, Iteration 2",
    "latent_bootstrap_iter=3_mc=4": "Latent Bootstrap, Iteration 3",
    "latent_bootstrap_iter=4_mc=4": "Latent Bootstrap, Iteration 4",
    "raw_token_matched_baseline": "Raw Baseline, Raw-Token-Match",
    "raw_flops_matched_baseline": "Raw Baseline, Train-FLOP-Match",
    "latent_warmstart": "Latent Bootstrap, Iteration 0 (Warmstart)",
}

In [5]:
# Metrics shown by default, one subplot each.
PLOT_METRICS = [
    "finemath_4plus_val.elbo_4_per_token",
    "hendrycks_math_cot_synthetic.exact_match",
]

# Shared keyword arguments for plot_runs; cells below deep-copy and tweak them.
PLOT_KWARGS = {
    "metrics": PLOT_METRICS,
    "log_x_scale": True,
    "log_y_scale": False,
    "plot_warmstart": True,
    "plot_warmstart_horizontal_line": True,
    "warmstart_ckpt": WARMSTART_CHECKPOINT,
    "load_metric_kwargs": {
        "load_all_trials": True,
    },
    # Line types to draw; plot_runs also accepts "average" and "smoothed" keys.
    "plot_line_kwargs": {
        "envelope": {
            "errorbar": "se",
        },
    },
    "n_cols": 2,
    "run_labels_map": RUN_LABELS_MAP,
    "plot_group_legend": True,
    # This key used to be listed twice (4, then 3); 3 was the value that took effect.
    # plot_runs only reads it when plot_group_legend is False.
    "legend_n_col": 3,
    "top_runs": ["latent_bootstrap_iter=1_mc=4", "latent_bootstrap_iter=2_mc=4", "latent_bootstrap_iter=3_mc=4", "raw_token_matched_baseline"],
}


### Bootstrapping results

In [None]:
# All experiment runs with the shared plot settings (copied so the defaults stay untouched).
_plot_runs = dict(EXP_RUNS)
_plot_kwargs = copy.deepcopy(PLOT_KWARGS)

fig = plot_runs(_plot_runs, **_plot_kwargs)


### Exact performance curves

In [None]:
_plot_runs = dict(EXP_RUNS)
_plot_kwargs = copy.deepcopy(PLOT_KWARGS)

# Replace the shared line types by the trial average with standard-error bars.
_plot_kwargs["plot_line_kwargs"] = {"average": {"errorbar": "se"}}

fig = plot_runs(_plot_runs, **_plot_kwargs)



### Perplexity eval

In [None]:
# Only the runs whose name contains "bootstrap", on a single NLL subplot.
_plot_runs = {name: path for name, path in EXP_RUNS.items() if "bootstrap" in name}
_plot_kwargs = copy.deepcopy(PLOT_KWARGS)
_plot_kwargs.update(
    top_runs=[],
    n_cols=1,
    metrics=["finemath_4plus_val.nll_per_token"],
)

fig = plot_runs(_plot_runs, **_plot_kwargs)
