In [1]:
import copy
import glob
import json
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import yaml
from matplotlib.ticker import AutoMinorLocator, LogLocator

from plot_utils import *

In [2]:
# Enable IPython's autoreload extension; mode 2 re-imports edited modules
# before each cell executes, so changes to local helper modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Plot Functions

In [3]:
def plot_runs(
        runs, metrics=METRICS, warmstart_ckpt=None, use_token_measures=True,
        plot_warmstart=False, plot_warmstart_horizontal_line=False, n_cols=4, unify_legend=True, legend_n_col=None, unit_figsize=None, log_x_scale=False, log_y_scale=False,
        plot_raw_token_adjusted_x_axis=False, plot_line_kwargs=None, run_labels_map=None, display_result=False,
        vertical_lines=None, load_metric_kwargs=None,
    ):
    """Draw a grid of line plots, one subplot per metric, comparing several runs.

    Metrics are loaded with ``load_run_metrics`` and drawn with
    ``seaborn.lineplot`` (one hue per run). Rows whose ``step`` is NaN are
    treated as baselines and drawn as gray dashed horizontal lines.

    Args:
        runs: Mapping from run name to the run location; it is copied and
            forwarded to ``load_run_metrics``. The key order defines the color
            order of the runs.
        metrics: Metric column names; one subplot is created per entry.
        warmstart_ckpt: Optional dict with keys ``"path"`` and ``"step"``
            describing the warmstart checkpoint. Only read from disk when
            ``plot_warmstart`` is True.
        use_token_measures: Use ``"num_tokens"`` (True) or ``"step"`` (False)
            as the x-axis column.
        plot_warmstart: Include the warmstart run in the figure.
        plot_warmstart_horizontal_line: If True, the warmstart run is drawn as
            a dashed horizontal line at its value at the warmstart step.
            If False, the warmstart curve is drawn and every other run whose
            name does not end with ``"(scratch)"`` is shifted right by the
            warmstart step / token count.
        n_cols: Number of subplot columns.
        unify_legend: If True, per-axes legends are removed and one
            figure-level legend is placed above the grid.
        legend_n_col: Number of legend columns for the unified legend;
            defaults to the number of runs.
        unit_figsize: ``(width, height)`` of a single subplot; defaults to
            ``(6, 4)``.
        log_x_scale: Use a base-10 log x-axis.
        log_y_scale: Use a log y-axis.
        plot_raw_token_adjusted_x_axis: Rescale the x-axis to
            ``raw_tokens_adjusted_<measure>`` using each run's mixing ratio
            (parsed from ``_mix=<float>`` in the run name) and its entry in
            ``SYNTHETIC_RAW_TOKEN_RATIO_MAP``.
        plot_line_kwargs: Dict keyed by line type (``"average"``,
            ``"envelope"``, ``"smoothed"``) whose values are extra keyword
            arguments for ``seaborn.lineplot``. Only the listed types are
            drawn. Defaults to ``{"average": {"errorbar": "se"}}``.
        run_labels_map: Optional mapping from run name to legend label.
        display_result: If True, ``display`` the rows at the largest x value
            of every run. Only takes effect while drawing ``"envelope"`` /
            ``"smoothed"`` lines.
        vertical_lines: Optional list of keyword-argument dicts for
            ``Axes.axvline``. ``linewidth`` is always forced to 2; the dicts
            passed in are not modified.
        load_metric_kwargs: Extra keyword arguments for ``load_run_metrics``.

    Returns:
        The matplotlib ``Figure`` holding the grid.

    Raises:
        ValueError: If ``metrics`` is empty, or if no ``config.yaml`` can be
            found for the warmstart checkpoint.
        AssertionError: If ``vertical_lines`` is not a list of dicts, or if a
            run name matches no key of ``SYNTHETIC_RAW_TOKEN_RATIO_MAP`` when
            the raw-token-adjusted axis is requested.
    """
    n_metrics = len(metrics)
    if n_metrics == 0:
        raise ValueError("metrics must contain at least one metric to plot")

    # "envelope" / "smoothed" lines are overlays computed by post_process_metrics.
    additional_plot_types = ["envelope", "smoothed"]
    if plot_line_kwargs is None:
        plot_line_kwargs = {
            "average": {  # average of all trials
                "errorbar": "se",
            },
        }
    has_overlay_lines = any(plot_line_type in plot_line_kwargs for plot_line_type in additional_plot_types)

    # gather all metrics from all runs at all steps
    load_runs = runs.copy()
    if warmstart_ckpt is not None and plot_warmstart:
        load_runs[WARMSTART_NAME] = warmstart_ckpt["path"]
        potential_paths = [warmstart_ckpt["path"], warmstart_ckpt["path"]+"_trial_0"]

        warmstart_path = None
        for path in potential_paths:
            if os.path.exists(os.path.join(BASE_EXP_DIR, path, "config.yaml")):
                warmstart_path = os.path.join(BASE_EXP_DIR, path)
                break

        if warmstart_path is None:
            raise ValueError(f"Warmstart path not found for {warmstart_ckpt['path']}")

        with open(os.path.join(warmstart_path, "config.yaml"), "r") as f:
            # NOTE(review): FullLoader can construct more than plain data types;
            # yaml.safe_load would be preferable if these configs are ever untrusted.
            warmstart_config = yaml.load(f, Loader=yaml.FullLoader)

        warmstart_step = warmstart_ckpt["step"]
        loaded_warmstart_ckpt = True
    else:
        warmstart_step = 0
        loaded_warmstart_ckpt = False

    if loaded_warmstart_ckpt and not plot_warmstart_horizontal_line:
        # tokens consumed up to the warmstart step = batch size * sequence length * dp shards * steps
        warmstart_tokens = warmstart_config["data"]["batch_size"] * warmstart_config["data"]["seq_len"] * warmstart_config["distributed"]["dp_shard"] * warmstart_step
        warmstart_synth_raw_data_ratio = SYNTHETIC_RAW_TOKEN_RATIO_MAP["warmstart"]
    else:
        warmstart_tokens = 0
        warmstart_synth_raw_data_ratio = 1.0

    load_metrics = list(set(metrics))
    if load_metric_kwargs is None:
        load_metric_kwargs = {}
    all_metrics = load_run_metrics(load_runs, load_metrics=load_metrics, warmstart_ckpt=warmstart_ckpt, **load_metric_kwargs)

    run_order = list(runs.keys())
    warmstart_run = None
    if plot_warmstart:
        if not plot_warmstart_horizontal_line:
            # shift the steps of all runs by the step of the warmstart checkpoint
            shift_mask = (all_metrics["run"] != WARMSTART_NAME) & (~all_metrics["run"].str.endswith("(scratch)"))
            all_metrics.loc[shift_mask, "step"] += warmstart_step
            all_metrics.loc[shift_mask, "num_tokens"] += warmstart_tokens
        else:
            warmstart_run = all_metrics[all_metrics["run"] == WARMSTART_NAME]
            # remove it from all_metrics; copy so later column assignments do not
            # operate on a slice of the loaded frame
            all_metrics = all_metrics[all_metrics["run"] != WARMSTART_NAME].copy()

        run_order.append(WARMSTART_NAME)

    # Create a grid plot, each for one metric, and different runs in one plot for comparison
    n_rows = (n_metrics + n_cols - 1) // n_cols

    if unit_figsize is None:
        unit_figsize = (6, 4)
    # squeeze=False always yields a 2-D array of axes, so a 1x1 grid can be flattened too
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(unit_figsize[0]*n_cols, unit_figsize[1]*n_rows), squeeze=False)
    axes = axes.flatten()  # Flatten to make indexing easier

    if plot_raw_token_adjusted_x_axis:
        # compute raw tokens for different methods and mixing ratios
        all_metrics["mixing_ratio"] = all_metrics["run"].str.extract(r"_mix=(\d+\.\d+)", expand=False).astype(float)

        # For runs without a mix parameter, set the mixing ratio to appropriate defaults
        mask_no_mix = all_metrics["mixing_ratio"].isna()
        mask_baseline = all_metrics["run"].str.contains("raw", case=False)
        all_metrics.loc[mask_baseline & mask_no_mix, "mixing_ratio"] = 1.0  # baselines use all raw data
        all_metrics.loc[~mask_baseline & mask_no_mix, "mixing_ratio"] = 0.0  # non-baselines without mix use all synthetic

        # Use the name mapping in SYNTHETIC_RAW_TOKEN_RATIO_MAP
        def _get_synth_raw_data_ratio(run_name):
            """Return the synthetic/raw token ratio of the first map key contained in run_name."""
            if "pretrained" in run_name:
                return 1.0
            matched_keys = list(filter(lambda x: x in run_name, SYNTHETIC_RAW_TOKEN_RATIO_MAP.keys()))
            assert len(matched_keys) > 0, f"No matched keys found for {run_name}: {matched_keys}"
            if len(matched_keys) > 1:
                print(f"Warning: Multiple matched keys found for {run_name}: {matched_keys}, using the first one: {matched_keys[0]}")
            return SYNTHETIC_RAW_TOKEN_RATIO_MAP[matched_keys[0]]

        all_metrics["synth_raw_data_ratio"] = all_metrics["run"].apply(_get_synth_raw_data_ratio)
        for key in ["step", "num_tokens"]:
            if key == "num_tokens":
                warmstart_baseline = warmstart_tokens
            else:
                warmstart_baseline = warmstart_step

            # equation: raw token/step count x, mixing ratio r, final token/step count y, synthetic/raw ratio s
            # y = x * r + x * s * (1 - r)
            # x = y / (r + s * (1 - r))
            # The warmstart portion is converted with the warmstart ratio, the remainder with the run's own ratio.
            all_metrics[f"raw_tokens_adjusted_{key}"] = warmstart_baseline / warmstart_synth_raw_data_ratio + (all_metrics[key] - warmstart_baseline) / (all_metrics["mixing_ratio"] + (1 - all_metrics["mixing_ratio"]) * (all_metrics["synth_raw_data_ratio"]))

    # x-axis column and its label are the same for every subplot
    measure = "num_tokens" if use_token_measures else "step"
    x_name = f"raw_tokens_adjusted_{measure}" if plot_raw_token_adjusted_x_axis else measure
    x_axis_labels = {
        "num_tokens": "Total Training Tokens",
        "step": "Training Steps",
        "raw_tokens_adjusted_num_tokens": "Effective Raw Tokens Seen",
        "raw_tokens_adjusted_step": "Effective Raw Training Steps",
    }

    plot_style_kwargs = {
        "linewidth": 2.5,
        # fade the per-trial average when an overlay line is drawn on top of it
        "alpha": 0.6 if has_overlay_lines else 1.0,
    }

    # position of every run in run_order; drives color assignment
    hue_order_map = {run: position for position, run in enumerate(run_order)}

    if vertical_lines is not None:
        assert isinstance(vertical_lines, list), "vertical_lines must be a list"
        for line in vertical_lines:
            assert isinstance(line, dict), "vertical_lines must be a list of dictionaries"

    for idx, metric in enumerate(metrics):
        ax = axes[idx]

        # The row selections are re-derived for every metric on purpose:
        # post_process_metrics is opaque here and may touch all_metrics.

        # Plot horizontal lines for baseline runs (rows without a step)
        baseline_data = all_metrics[all_metrics['step'].isna()]
        for run in baseline_data['run'].unique():
            run_baseline = baseline_data[baseline_data['run'] == run]
            if run_baseline.empty:
                continue
            ax.axhline(
                y=run_baseline[metric].iloc[0],
                color="gray",
                linestyle='--', label=f"{run}",
                **plot_style_kwargs
            )

        # Then plot the rest of the data (rows that do have a step)
        plot_data = all_metrics[all_metrics['step'].notna()]
        # seaborn expects a list for hue_order; sort by run_order so colors follow
        # the caller's run order (a run missing from run_order raises KeyError)
        hue_order = sorted(plot_data["run"].unique(), key=hue_order_map.__getitem__)

        if "average" in plot_line_kwargs:
            sns.lineplot(
                data=plot_data,
                x=x_name,
                y=metric,
                hue="run",
                ax=ax,
                hue_order=hue_order,
                marker="o",
                **plot_line_kwargs["average"],
                **plot_style_kwargs,
            )

        # Runs that are drawn as horizontal lines never get an overlay curve.
        horizontal_only_runs = set(baseline_data["run"].unique())
        if warmstart_run is not None:
            horizontal_only_runs.update(warmstart_run["run"].unique())

        for plot_line_type in additional_plot_types:
            if plot_line_type not in plot_line_kwargs:
                continue

            # Reuse the colors of the lines already on this axis
            color_map = {line.get_label(): line.get_color() for line in ax.get_lines()}

            extra_metrics = post_process_metrics(all_metrics, process_type=plot_line_type, metric_cols=[metric], sort_by_col=x_name)

            for run in run_order:
                if run in horizontal_only_runs:
                    continue

                run_data = extra_metrics[extra_metrics['run'] == run]

                if display_result:
                    # display the last step result
                    last_step_data = run_data[run_data[x_name] == run_data[x_name].max()]
                    display(last_step_data[["run", "step", "num_tokens", metric]])

                sns.lineplot(
                    data=run_data, x=x_name, y=metric,
                    ax=ax,
                    color=color_map[run] if "average" in plot_line_kwargs else None,
                    linewidth=3,
                    label=None if "average" in plot_line_kwargs else run,
                    **plot_line_kwargs[plot_line_type],
                )

        if vertical_lines is not None:
            for line in vertical_lines:
                # build a fresh kwargs dict so the caller's dict is left untouched
                ax.axvline(**{**line, "linewidth": 2})

        if plot_warmstart and plot_warmstart_horizontal_line:
            ax.axhline(y=warmstart_run[warmstart_run["step"] == warmstart_step][metric].iloc[0], color=sns.color_palette()[hue_order_map[WARMSTART_NAME]], linestyle="--", label=f"{WARMSTART_NAME}", linewidth=2)

        if run_labels_map is not None:
            # Swap raw run names for their display labels
            for line in ax.get_lines():
                current_label = line.get_label()
                if current_label in run_labels_map:
                    line.set_label(run_labels_map[current_label])

        # Drop the legend seaborn created; it still shows the raw run names
        if ax.get_legend() is not None:
            ax.get_legend().remove()
        if not unify_legend:
            # Create new legend with updated labels
            ax.legend()
            for handle in ax.get_legend().get_lines():
                handle.set_alpha(1.0)

        ax.grid(True)

        if log_x_scale:
            ax.set_xscale("log", base=10)

        if log_y_scale:
            ax.set_yscale("log")
            ax.grid(True, which='both', axis='y')   # add back the grid line

        if "nll" in metric or "elbo" in metric:
            ax.set_yscale("log")
            # disable scientific notation
            ax.yaxis.set_major_formatter(plt.FormatStrFormatter('%.2f'))
            ax.yaxis.set_minor_formatter(plt.FormatStrFormatter('%.2f'))
            ax.grid(True, which='both', axis='y')   # add back the grid line

            if "nll" in metric:
                ax.set_ylabel("NLL")
            else:
                ax.set_ylabel("ELBO w/ 4 Samples")
        else:
            ax.set_ylabel("Accuracy")
        ax.set_title(ALL_EVAL_METRIC_LABEL_MAP.get(metric, metric))
        ax.set_xlabel(x_axis_labels[x_name])

    # Remove empty subplots if any
    for idx in range(n_metrics, len(axes)):
        fig.delaxes(axes[idx])

    if unify_legend:
        # Add extra height to accommodate the legend at the top
        fig.set_size_inches(unit_figsize[0]*n_cols, unit_figsize[1]*n_rows + 1)
        # Create unified legend at the top from the first subplot's entries
        handles, labels = axes[0].get_legend_handles_labels()

        if has_overlay_lines:
            # the plotted lines are semi-transparent; use opaque proxies in the legend
            new_handles = [
                plt.Line2D([], [], color=handle.get_color(), label=handle.get_label(), alpha=1.0)
                for handle in handles
            ]
        else:
            new_handles = handles

        fig.legend(new_handles, labels, bbox_to_anchor=(0.5, 1.0), loc='lower center', ncol=legend_n_col or len(run_order), fontsize=SIZE_MEDIUM)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.3, hspace=0.5)  # Adjust these values as needed

    return fig

## Plots

In [4]:
# Runs compared in this notebook, keyed by a short run name.
# Each value is the run location that plot_runs receives for that run.
EXP_RUNS = dict(
    # Method compare
    latent_thought="train_synth_data_method_compare_warmstart/train_synth_data_method_compare_warmstart_latent=random_opt=cosine_lr=1e-4_latent_thought",
    raw_fresh="train_synth_data_method_compare_warmstart/train_synth_data_method_compare_warmstart_latent=null_opt=cosine_lr=1e-4_raw_fresh",
    raw_repeat="train_synth_data_method_compare_warmstart/train_synth_data_method_compare_warmstart_latent=null_opt=cosine_lr=1e-4_raw_repeat",

    wrap_base="train_synth_data_method_compare_warmstart/train_synth_data_method_compare_warmstart_latent=pure_opt=cosine_lr=1e-4_wrap_baseline_mix=0.0",
    wrap_cot="train_synth_data_method_compare_warmstart/train_synth_data_method_compare_warmstart_latent=pure_opt=cosine_lr=1e-4_wrap_cot_mix=0.0",
    # Disabled: "baseline_pretrained_tinyllama" -> "pretrained_hf_ckpts/TinyLlama/TinyLlama_v1.1-embd-resized"
)


# Legend labels shown in the figures, keyed by the short run name.
RUN_LABELS_MAP = dict(
    latent_thought="Latent Thought",
    raw_fresh="Raw-Fresh",
    raw_repeat="Raw-Repeat",
    wrap_base="WRAP-Orig",
    wrap_cot="WRAP-CoT",
    surface_cot="Latent Thought (Mix in Surface)",
    # Disabled: "baseline_pretrained_tinyllama" -> "Pretrained TinyLlama"
)

In [5]:
# Metric columns to plot; plot_runs draws one subplot per entry.
PLOT_METRICS = [
    "hendrycks_math_cot_synthetic.exact_match", "hendrycks_math_cot.exact_match",
    "gsm8k_cot_synthetic_alt.exact_match", "gsm8k_cot_alt.exact_match",
    # Disabled: "mmlu_cot_synthetic_stem.exact_match", "mmlu_cot_flan_stem.exact_match"
]
# Backward-compatible alias: the list was originally published under this
# misspelled name, and other cells may still refer to it.
PLOT_METRCS = PLOT_METRICS

# Shared keyword arguments for plot_runs; individual cells deep-copy this
# template and override single entries.
PLOT_KWARGS = {
    "metrics": PLOT_METRICS,
    "log_x_scale": True,
    "log_y_scale": False,
    "legend_n_col": 3,
    "plot_raw_token_adjusted_x_axis": False,
    "plot_warmstart": False,
    "plot_warmstart_horizontal_line": False,
    # forwarded verbatim to load_run_metrics
    "load_metric_kwargs": {
        "load_all_trials": True,
    },
    # only the line types listed here are drawn; "envelope" and "smoothed"
    # are the other types plot_runs understands
    "plot_line_kwargs": {
        "average": {
            "errorbar": "se",
        },
    },
    "run_labels_map": RUN_LABELS_MAP,
    "n_cols": 2,
}

### Comparison with baselines

In [None]:
# Plot from private copies so the shared EXP_RUNS / PLOT_KWARGS templates
# cannot be altered by this cell.
_plot_runs, _plot_kwargs = copy.deepcopy(EXP_RUNS), copy.deepcopy(PLOT_KWARGS)

fig = plot_runs(_plot_runs, **_plot_kwargs)

### Comparison normalized by the effective raw tokens seen 

In [None]:
# Same comparison as above, with the x-axis rescaled to the effective number
# of raw tokens each run has seen.
_plot_runs = copy.deepcopy(EXP_RUNS)
_plot_kwargs = {
    **copy.deepcopy(PLOT_KWARGS),
    "plot_raw_token_adjusted_x_axis": True,
}

fig = plot_runs(_plot_runs, **_plot_kwargs)