In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

## Utils

In [None]:
# Colours shared by the figures below.
# COLOR_P5 / COLOR_N5 / COLOR_ZERO are used for the eta=+5 / eta=-5 / eta=0 curves.
# NOTE: COLOR_P5 used to be assigned "#D55E00" and then immediately overwritten;
# only the value that was actually in effect is kept.
COLOR_P5 = "#C51B00"
COLOR_N5 = "#0072B2"
COLOR_ZERO = "#009E73"
COLOR_QA = "#000000"
COLOR_P5_STOCH = "#E69F00"
COLOR_P5_DET = "#56B4E9"

In [None]:
def analyze_simulation_runs(csv_paths):
    """
    Aggregate per-checkpoint statistics across several simulation runs.

    Each CSV is read with ``pandas.read_csv``. Row *i* of every file is treated
    as the same checkpoint, and every metric column is reduced across files.

    Parameters:
    csv_paths (list): Paths (or anything ``pandas.read_csv`` accepts) of the
                      per-run CSV files. Must be non-empty, and all files must
                      have the same number of rows.

    Returns:
    pandas.DataFrame: One row per checkpoint position (index named
                      'checkpoint_position') with columns '<metric>_mean' and
                      '<metric>_std' for every metric column. The std is the
                      population standard deviation (``numpy.std`` default,
                      ddof=0), so a single run yields 0 everywhere. If a
                      checkpoint-like column exists, its across-run mean is
                      inserted as the first column, 'avg_checkpoint_value'.

    Raises:
    ValueError: If ``csv_paths`` is empty or the files differ in row count.
    """
    csv_paths = list(csv_paths)
    if not csv_paths:
        # Previously this surfaced as an opaque IndexError on dataframes[0].
        raise ValueError("csv_paths must contain at least one CSV file")

    dataframes = [pd.read_csv(path) for path in csv_paths]

    # Rows are matched by position, so every run must have the same length.
    num_rows = len(dataframes[0])
    if any(len(df) != num_rows for df in dataframes):
        raise ValueError("All CSV files must have the same number of rows")

    # The first file defines the set of columns that is aggregated.
    columns = dataframes[0].columns.tolist()
    checkpoint_cols = ['checkpoint', 'step', 'iteration', 'epoch']

    def _is_checkpoint_column(col):
        # Matches the known names (case-insensitive) or anything starting with
        # "step". A column such as "_step" does NOT match and stays a metric.
        lowered = col.lower()
        return lowered in checkpoint_cols or lowered.startswith('step')

    data_columns = [col for col in columns if not _is_checkpoint_column(col)]

    # If every column looked like a checkpoint column, fall back to treating
    # everything after the first column (or the only column) as data.
    if not data_columns:
        data_columns = columns[1:] if len(columns) > 1 else columns

    results = {}
    for col in data_columns:
        # Shape (n_runs, n_rows); reduce over the run axis.
        col_data = np.array([df[col].values for df in dataframes])
        results[f'{col}_mean'] = np.mean(col_data, axis=0)
        results[f'{col}_std'] = np.std(col_data, axis=0)

    result_df = pd.DataFrame(results)
    result_df.index.name = 'checkpoint_position'

    # Only the first checkpoint-like column is averaged and reported.
    checkpoint_col = next((col for col in columns if _is_checkpoint_column(col)), None)
    if checkpoint_col is not None:
        checkpoint_values = np.array([df[checkpoint_col].values for df in dataframes])
        result_df.insert(0, 'avg_checkpoint_value', np.mean(checkpoint_values, axis=0))

    return result_df

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections.abc import Sequence


def _ratio_with_std(num_mean, num_std, den_mean, den_std):
    """
    Element-wise ratio R = A / B with its first-order propagated standard deviation.

    Uses σ_R = sqrt(σ_A² + (R·σ_B)²) / |B|, which is algebraically the same as
    |R|·sqrt((σ_A/A)² + (σ_B/B)²) but never divides by A. Points whose numerator
    mean is exactly zero therefore keep a finite error, and the result is never
    negative. Points where B is zero yield inf/NaN without emitting warnings.
    """
    num_mean = np.asarray(num_mean, dtype=float)
    num_std = np.asarray(num_std, dtype=float)
    den_mean = np.asarray(den_mean, dtype=float)
    den_std = np.asarray(den_std, dtype=float)
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = num_mean / den_mean
        ratio_std = np.sqrt(num_std ** 2 + (ratio * den_std) ** 2) / np.abs(den_mean)
    return ratio, ratio_std


def plot_with_error_bars_multi(
    stats_dfs: Sequence[pd.DataFrame],
    metric_names=("is_A_path_correct", "Answer Accuracy"),
    divide_metric=None,
    names=None,
    figsize=(3.0, 2.5),
    dpi=300,
    error_bar_interval=1,
    ylabel=None,
    loc="lower right",
    palette=None,
    line_colors=None,
    line_styles=None,
    show_legend=True,
    use_fill_between=False,
    xlim=None,
    ylim=None,
    highlight_max=False,
    max_marker="*",
    max_marker_size=100,
):
    """
    Plot mean ± 1 σ curves for one **or many** metrics across several runs,
    optionally dividing each metric by an additional "denominator" metric.

    The x-axis is the 1-based row position of each DataFrame and is labelled "Epoch".
    Lines are drawn run by run and, inside each run, metric by metric; `line_colors`
    and `line_styles` are consumed in that order.

    Parameters
    ----------
    stats_dfs : list[pandas.DataFrame]
        Non-empty list/tuple. Each DataFrame must contain columns `<metric>_mean`
        and `<metric>_std` for **every** metric passed in `metric_names`, plus for
        `divide_metric` if specified.
    metric_names : tuple[str, str] or list[tuple[str, str]]
        Single `(column_base, pretty_name)` or list of such pairs for multi-metric plots.
    divide_metric : tuple[str, str] or None
        If provided, a `(column_base, pretty_name)` pair. Each main metric's mean
        will be divided by this metric's mean, and its std propagated:
            R = A / B,
            σ_R = |R| * sqrt((σ_A/A)² + (σ_B/B)²).
    names : list[str] or None
        Labels for the runs (default: empty labels). When multiple metrics or a
        `divide_metric` are requested, each legend entry becomes
        "{run label}, {metric pretty_name}[ / {denominator pretty_name}]".
    figsize, dpi : tuple, int
        Matplotlib figure size and resolution.
    error_bar_interval : int, default 1
        Plot an error bar every *n* points along the curve (ignored if `use_fill_between=True`).
    ylabel : str or None
        Label for the y-axis. If None, the first metric's pretty name is used
        (followed by " / {denominator pretty_name}" when dividing).
    loc : str
        Location of the legend (matplotlib location code).
    palette : list/tuple or None
        List of colours. Used if `line_colors` is not provided; replaced by seaborn's
        "deep" palette when missing or shorter than the number of lines.
    line_colors : list or None
        Overrides `palette` if provided. Length must be runs × metrics.
    line_styles : list or None
        List of matplotlib linestyles for each line. Length must be runs × metrics.
    show_legend : bool, default True
        Whether to display the legend.
    use_fill_between : bool, default False
        If True, shade the 1σ band instead of plotting error bars.
    xlim : tuple or None
        Tuple (xmin, xmax) to set the x-axis limits explicitly.
    ylim : tuple or None
        Tuple (ymin, ymax) to set the y-axis limits explicitly. When None, a single
        undivided metric whose means all lie in [0, 1] gets limits (-0.05, 1.05).
    highlight_max : bool, default False
        If True, truncate each curve at its maximum and place a marker there.
        Curves without any non-NaN value are drawn in full and get no marker.
    max_marker : str, default "*"
        Marker style for the maximum point.
    max_marker_size : int, default 100
        Marker size (points^2) for the maximum point.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes

    Raises
    ------
    TypeError
        If `stats_dfs` is not a list or tuple.
    ValueError
        If `stats_dfs` or `metric_names` is empty, if `divide_metric` is malformed,
        if `names`, `line_colors` or `line_styles` have the wrong length, or if
        `error_bar_interval` is not positive while error bars are requested.
    KeyError
        If a DataFrame lacks a required `<metric>_mean` / `<metric>_std` column.
    """

    # ---------- Input validation ------------------------------------------------
    if not isinstance(stats_dfs, (list, tuple)):
        raise TypeError("stats_dfs must be a list or tuple of DataFrames.")
    n_runs = len(stats_dfs)
    if n_runs == 0:
        raise ValueError("`stats_dfs` must contain at least one DataFrame.")

    # Normalize metric_names to a list of (column_base, pretty_name) pairs
    if len(metric_names) == 0:
        raise ValueError(
            "`metric_names` must contain at least one (column_base, pretty_name) pair."
        )
    if isinstance(metric_names[0], str):
        metric_names = [metric_names]
    n_metrics = len(metric_names)

    # Normalize divide_metric
    denom_base = denom_display = None
    if divide_metric is not None:
        if not (isinstance(divide_metric, (list, tuple)) and len(divide_metric) == 2):
            raise ValueError("`divide_metric` must be a (column_base, pretty_name) tuple.")
        denom_base, denom_display = divide_metric

    if names is None:
        names = ["" for _ in range(n_runs)]
    elif len(names) != n_runs:
        raise ValueError("`names` must have the same length as `stats_dfs`.")

    # A zero step would only fail later, deep inside the slicing of the error bars.
    if not use_fill_between and error_bar_interval < 1:
        raise ValueError("`error_bar_interval` must be a positive integer.")

    # ---------- Palette/Color handling ------------------------------------------------
    total_lines = n_runs * n_metrics
    if line_colors is not None:
        if len(line_colors) != total_lines:
            raise ValueError(
                "`line_colors` must have length equal to number of runs * number of metrics."
            )
        palette = line_colors
    elif palette is None or len(palette) < total_lines:
        palette = sns.color_palette("deep", n_colors=total_lines)

    # ---------- Line style handling ---------------------------------------------------
    if line_styles is None:
        line_styles = ["-"] * total_lines
    elif len(line_styles) != total_lines:
        raise ValueError(
            "`line_styles` must have length equal to number of runs * number of metrics."
        )

    # ---------- Plotting --------------------------------------------------------
    # NOTE: these seaborn calls change the global matplotlib style for later figures too.
    sns.set_style("whitegrid")
    sns.set_context("paper", font_scale=1.0)
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    line_idx = 0  # position in `palette` / `line_styles`
    for run_idx, (df, run_label) in enumerate(zip(stats_dfs, names)):
        # validate denominator columns once per run
        if divide_metric is not None:
            denom_mean_col = f"{denom_base}_mean"
            denom_std_col = f"{denom_base}_std"
            if denom_mean_col not in df or denom_std_col not in df:
                raise KeyError(
                    f"DataFrame {run_idx} lacks columns '{denom_mean_col}' and/or '{denom_std_col}'."
                )
            denom_mean = df[denom_mean_col].values
            denom_std = df[denom_std_col].values

        for metric_base, metric_display in metric_names:
            mean_col = f"{metric_base}_mean"
            std_col = f"{metric_base}_std"
            if mean_col not in df or std_col not in df:
                raise KeyError(
                    f"DataFrame {run_idx} lacks columns '{mean_col}' and/or '{std_col}'."
                )

            x = np.arange(1, len(df) + 1)
            y = df[mean_col].values
            y_err = df[std_col].values

            # ---- apply divide-by-metric if requested ----
            if divide_metric is not None:
                y, y_err = _ratio_with_std(y, y_err, denom_mean, denom_std)

            # Determine truncation index if highlighting max
            max_idx = None
            if highlight_max:
                y_as_float = np.asarray(y, dtype=float)
                # nanargmax raises on empty / all-NaN input; such curves are left untruncated.
                if not np.isnan(y_as_float).all():
                    max_idx = int(np.nanargmax(y_as_float))

            if max_idx is not None:
                plot_x = x[: max_idx + 1]
                plot_y = y[: max_idx + 1]
                plot_y_err = y_err[: max_idx + 1]
            else:
                plot_x, plot_y, plot_y_err = x, y, y_err

            color = palette[line_idx]
            linestyle = line_styles[line_idx]
            line_idx += 1

            # Legend label: the plain run label for a single undivided metric,
            # otherwise "<run>, <metric>[ / <denominator>]".
            if n_metrics > 1 or divide_metric:
                run_prefix = f"{run_label}, " if run_label else ""
                denom_suffix = f" / {denom_display}" if divide_metric else ""
                label = f"{run_prefix}{metric_display}{denom_suffix}"
            else:
                label = run_label

            # Line plot
            ax.plot(
                plot_x,
                plot_y,
                linestyle=linestyle,
                color=color,
                linewidth=1.5,
                label=label,
            )

            # Error representation
            if use_fill_between:
                ax.fill_between(
                    plot_x,
                    plot_y - plot_y_err,
                    plot_y + plot_y_err,
                    color=color,
                    alpha=0.3,
                    linewidth=0,
                )
            else:
                ax.errorbar(
                    plot_x[::error_bar_interval],
                    plot_y[::error_bar_interval],
                    yerr=plot_y_err[::error_bar_interval],
                    fmt="none",
                    ecolor=color,
                    elinewidth=1.5,
                    capsize=4,
                    capthick=1.5,
                    alpha=0.7,
                )

            # Highlight maximum point
            if max_idx is not None:
                ax.scatter(
                    x[max_idx],
                    y[max_idx],
                    marker=max_marker,
                    color=color,
                    s=max_marker_size,
                    zorder=3,
                )

    # ---------- Axes styling ----------------------------------------------------
    ax.set_xlabel("Epoch", fontweight="bold")
    if ylabel is None:
        ylabel_text = f"{metric_names[0][1]}{f' / {denom_display}' if divide_metric else ''}"
    else:
        ylabel_text = ylabel
    ax.set_ylabel(ylabel_text, fontweight="bold")

    # Auto y-limits for single metric between 0 and 1
    if n_metrics == 1 and ylim is None and divide_metric is None:
        all_vals = np.concatenate([df[f"{metric_names[0][0]}_mean"].values for df in stats_dfs])
        if all_vals.min() >= 0 and all_vals.max() <= 1:
            ax.set_ylim(-0.05, 1.05)
    elif ylim is None:
        ax.set_ylim(auto=True)
    else:
        ax.set_ylim(ylim)

    if xlim is not None:
        ax.set_xlim(xlim)

    if show_legend:
        ax.legend(loc=loc, frameon=True, fancybox=True)

    sns.despine()
    plt.tight_layout()
    return fig, ax


## Load data

In [None]:
# Directory that holds every per-run CSV loaded in this cell.
CHECKPOINT_DIR = "checkpoints"


def load_runs(name_template, n_runs=5):
    """
    Aggregate the runs of one experiment with `analyze_simulation_runs`.

    `name_template` is a file stem containing one `{}` placeholder for the
    1-based run number; the files read are
    `CHECKPOINT_DIR/<name_template.format(i)>.csv` for i = 1 .. n_runs.
    """
    paths = [
        f"{CHECKPOINT_DIR}/{name_template.format(run)}.csv"
        for run in range(1, n_runs + 1)
    ]
    return analyze_simulation_runs(paths)


# CoT experiments (5 runs each)
p5_cot = load_runs("p5_cot_ibac_{}")
n5_cot = load_runs("n5_cot_ibac_{}")
qa_cot = load_runs("qa_cot_{}")

# Token experiments, including the two sampling-temperature variants of the zero setting
p5_tokens = load_runs("p5_tokens_ibac_{}")
n5_tokens = load_runs("n5_tokens_ibac_{}")
zero_tokens = load_runs("zero_tokens_ibac_{}")
zero_tokens_temp_0_5 = load_runs("zero_tokens_temp_0_5_ibac_{}")
zero_tokens_temp_1_0 = load_runs("zero_tokens_temp_1_0_ibac_{}")

# Redundancy experiments (note the run number sits before the `_temp_1_0` suffix)
p5_stoch = load_runs("p5_stoch_redundancy_ibac_{}")
p5_stoch_temp = load_runs("p5_stoch_redundancy_ibac_{}_temp_1_0")
p5_det = load_runs("p5_det_redundancy_ibac_{}")

# 32M-token experiments
zero_32_tokens = load_runs("zero_32M_tokens_ibac_{}")
n5_32_tokens = load_runs("n5_32M_tokens_ibac_{}")
p5_32_tokens = load_runs("p5_32M_tokens_ibac_{}")

# Scheduler experiments (3 runs each)
p5_32M_tokens_3_layers_scheduler = load_runs("p5_32M_tokens_3_layers_scheduler_{}", n_runs=3)
p5_32M_tokens_6_layers_scheduler = load_runs("p5_32M_tokens_6_layers_scheduler_{}", n_runs=3)
p5_16M_tokens_3_layers_scheduler = load_runs("p5_16M_tokens_3_layers_scheduler_{}", n_runs=3)

## p5 vs n5 200k CoT is_A_path_correct

In [None]:
# Default metric (is_A_path_correct) for eta=-5 vs eta=+5, shaded ±1σ band,
# each curve cut at its maximum and marked there.
fig, ax = plot_with_error_bars_multi(
    [n5_cot, p5_cot],
    names=[r"$\eta=-5$", r"$\eta=+5$"],
    line_colors=[COLOR_N5, COLOR_P5],
    use_fill_between=True,
    show_legend=False,
    highlight_max=True,
)
fig.savefig('p5_n5_path_correctness_200k_fig1.pdf', bbox_inches='tight')

## p5 vs n5 200k CoT confidence

In [None]:
# Next-token confidence (avg_prob_1) for the two CoT settings, y-axis fixed to [0.7, 1.0].
fig, ax = plot_with_error_bars_multi(
    [p5_cot, n5_cot],
    names=[r"Efficient Trace", r"Inefficient Trace"],
    metric_names=("avg_prob_1", "Next-Token Confidence"),
    line_colors=[COLOR_P5, COLOR_N5],
    use_fill_between=True,
    ylim=(0.7, 1.0),
    show_legend=True,
)
fig.savefig('p5_n5_token_confidence_128M_fig1.pdf', bbox_inches='tight')

## p5 vs QA 350k CoT

In [None]:
# Named hex colours available to the plotting cells below.
ORANGE = "#E69F00"
SKYBLUE = "#56B4E9"   # sky-blue
BG = "#009E73"        # bluish-green
YEL = "#F0E442"       # yellow
BLUE = "#0072B2"
VERMIL = "#D55E00"    # vermilion
REDPURP = "#CC79A7"   # reddish-purple
BLACK = "#000000"

In [None]:
import matplotlib.patches as mpatches

# One colour per graph depth, shared by the curves and by the legend swatches.
legend_labels = ['graph depth 3', 'graph depth 5', 'graph depth 7']
legend_colors = ['#5D6D7E', '#8E44AD', '#CC79A7']

# Lines are drawn run by run: the three QA curves come first (dashed),
# followed by the three eta=+5 curves (solid), each set cycling through the depth colours.
fig, ax = plot_with_error_bars_multi(
    [qa_cot, p5_tokens],
    names=[r"QA", r"$\eta=+5$"],
    metric_names=[
        ("is_A_path_correct_l3", r"$d=3$"),
        ("is_A_path_correct_l5", r"$d=5$"),
        ("is_A_path_correct_l7", r"$d=7$"),
    ],
    line_colors=legend_colors + legend_colors,
    line_styles=["--"] * 3 + ["-"] * 3,
    ylabel="Answer Accuracy",
    xlim=(0, 25),
    use_fill_between=True,
    show_legend=False,
)

# Square colour swatches, one per depth, replace the automatic legend.
patches = []
for label, color in zip(legend_labels, legend_colors):
    patches.append(
        mpatches.Rectangle(
            (0, 0), 1, 1,
            facecolor=color,
            edgecolor='black',
            linewidth=0.5,
            label=label,
        )
    )

custom_legend = ax.legend(
    handles=patches,
    loc="lower right",
    frameon=True,
    fancybox=True,
    title=None,
    handlelength=1.0,
    handleheight=1.0,
)
fig.savefig('p5_qa_350k_fig2.pdf', bbox_inches='tight')

## p5 128M submetrics

In [None]:
# Four sub-metrics of the p5_tokens runs on one axis; the y-axis label is left blank.
submetrics = [
    ("is_A_path_correct", "path optimal"),
    ("is_A_path_possible", "path possible"),
    ("is_A_cost_consistent", "cost consistent"),
    ("is_A_cost_optimal", "cost optimal"),
]
fig, ax = plot_with_error_bars_multi(
    [p5_tokens],
    metric_names=submetrics,
    line_colors=[COLOR_P5, "#E69F00", "green", COLOR_N5],
    ylabel="",
    use_fill_between=True,
    show_legend=True,
)
fig.savefig('p5_tokens_submetrics_fig2.pdf', bbox_inches='tight')

## Cost analysis

In [None]:
from matplotlib import cm, colors
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` accepts the same (name, lut) arguments and works on old and new versions.
base = plt.get_cmap('YlGnBu_r', 256)          # 256-level version
# Keep only the upper 80 % of the map (0.2 → 1.0)
trim = colors.LinearSegmentedColormap.from_list(
        'YlGnBu_r_trim', base(np.linspace(0.2, 1.0, 256)))

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Sequence, Optional, Union

def _font_size_in_points(size):
    """Convert an rcParams font size (a number or a name such as 'medium') to points."""
    # Local import: only this helper needs it, and matplotlib is already a notebook dependency.
    from matplotlib.font_manager import FontProperties

    return FontProperties(size=size).get_size_in_points()


def plot_accuracy_heatmaps(
    tables: Sequence[Union[np.ndarray, pd.DataFrame]],
    titles: Optional[Sequence[str]] = None,
    figsize=(4.0, 4.0),
    dpi=300,
    cmap='viridis',
    vmin=0.0,
    vmax=1.0,
    xlabel='Addend',
    ylabel='Addend',
    font_scale: float = 1.0,
    label_fontsize: Optional[float] = None,
    title_fontsize: Optional[float] = None,
    cbar_label: str = "Accuracy"
):
    """
    Plot an arbitrary number of accuracy tables side by side as heatmaps
    with a shared x-axis label, a shared colorbar and consistent font sizing.

    Parameters
    ----------
    tables : sequence of 2D arrays or DataFrames
        Your accuracy matrices. Length must be >= 1. DataFrames contribute only
        their values; anything else is converted with ``numpy.asarray``.
    titles : sequence of str, optional
        Titles for each subplot. Defaults to ['Table 1', 'Table 2', ...].
    figsize : tuple
        Overall figure size.
    dpi : int
        Figure resolution.
    cmap : str
        Matplotlib colormap.
    vmin, vmax : float
        Value range for the color scaling.
    xlabel, ylabel : str
        Axis labels (ylabel only on leftmost plot, xlabel shown once below all plots).
    font_scale : float
        Seaborn context font scale to match other plots. Also multiplies the
        default label, title and tick font sizes.
    label_fontsize : float, optional
        Explicit font size for the axis and colorbar labels (overrides the
        font_scale-based default if provided).
    title_fontsize : float, optional
        Explicit font size for titles (overrides the font_scale-based default if provided).
    cbar_label : str
        Label for the colorbar.

    Returns
    -------
    fig, axes : matplotlib Figure and list of Axes

    Raises
    ------
    ValueError
        If `tables` is empty, `titles` has a different length, or a table is not 2-D.

    Notes
    -----
    Column tick labels are 1-based ("1", "2", ...) whereas row tick labels are
    0-based ("0", "1", ...); rows are only labelled on the leftmost heatmap.
    """
    n = len(tables)
    if n < 1:
        raise ValueError("Need at least one table to plot.")
    if titles is None:
        titles = [f"Table {i+1}" for i in range(n)]
    if len(titles) != n:
        raise ValueError("`titles` must have the same length as `tables`.")

    # Convert inputs to arrays and reject non-matrices before any figure is created
    mats = [
        np.asarray(table.values if isinstance(table, pd.DataFrame) else table)
        for table in tables
    ]
    for position, mat in enumerate(mats):
        if mat.ndim != 2:
            raise ValueError(
                f"Table {position} must be 2-dimensional, got {mat.ndim} dimension(s)."
            )

    # Font sizes. rcParams may hold names such as 'medium', so resolve them to
    # points before scaling. The tick size uses the real `font_scale`: truncating
    # it to an int turned 0.8 into 0 and 1.5 into 1.
    base_label_fs = label_fontsize or (
        font_scale * _font_size_in_points(plt.rcParams['axes.labelsize'])
    )
    base_title_fs = title_fontsize or (
        font_scale * _font_size_in_points(plt.rcParams['axes.titlesize'])
    )
    tick_fs = font_scale * _font_size_in_points(plt.rcParams['xtick.labelsize'])

    # Seaborn styling (changes the global matplotlib style)
    sns.set_style("whitegrid")
    sns.set_context("paper", font_scale=font_scale)

    # Create figure & axes
    fig, axes = plt.subplots(
        ncols=n,
        figsize=figsize,
        dpi=dpi,
        gridspec_kw={'wspace': 0.05},
        constrained_layout=True
    )
    # if only one axis, wrap it
    if n == 1:
        axes = [axes]

    # Turn off major gridlines
    for ax in axes:
        ax.grid(False, which="major")

    # Plot each heatmap
    for idx, (ax, mat, title) in enumerate(zip(axes, mats, titles)):
        im = ax.imshow(mat, cmap=cmap, vmin=vmin, vmax=vmax, interpolation='nearest')

        rows, cols = mat.shape
        # X ticks & labels on every plot
        ax.set_xticks(np.arange(cols))
        ax.set_xticklabels([str(i+1) for i in range(cols)],
                           fontweight="bold", fontsize=tick_fs)

        # Y ticks & ylabel only on first plot
        if idx == 0:
            ax.set_yticks(np.arange(rows))
            ax.set_yticklabels([str(i) for i in range(rows)],
                               fontweight="bold", fontsize=tick_fs)
            ax.set_ylabel(ylabel, fontweight="bold", fontsize=base_label_fs)
        else:
            ax.set_yticks([])

        # Minor gridlines between cells
        ax.set_xticks(np.arange(cols+1)-0.5, minor=True)
        ax.set_yticks(np.arange(rows+1)-0.5, minor=True)
        ax.grid(which="minor", color="w", linestyle='-', linewidth=2)
        ax.tick_params(which="minor", length=0)

        # Title
        ax.set_title(title, fontweight="bold", fontsize=base_title_fs)

    # Shared x-axis label
    fig.supxlabel(xlabel, fontweight="bold", fontsize=base_label_fs)

    # Shared colorbar on the right (all heatmaps use the same cmap/vmin/vmax,
    # so the last image is representative)
    cbar = fig.colorbar(im, ax=axes, fraction=0.046, pad=0.04)
    cbar.ax.set_ylabel(cbar_label, rotation=-90, va="bottom",
                       fontweight="bold", fontsize=base_label_fs)
    cbar.ax.tick_params(labelsize=tick_fs)

    sns.despine(left=True, bottom=True)

    return fig, axes


In [None]:
# Accuracy tables of one run at epochs 1, 2 and 4; the first CSV column is dropped.
acc1, acc2, acc3 = [
    pd.read_csv(f"checkpoints/p5_tokens__ibac_1_epoch_{epoch}_accuracy_table.csv").iloc[:, 1:]
    for epoch in (1, 2, 4)
]

fig, axes = plot_accuracy_heatmaps(
    tables=[acc1, acc2, acc3],
    titles=["Epoch 1", "Epoch 2", "Epoch 4"],
    label_fontsize=18.0,
    title_fontsize=16.0,
)
fig.savefig('p5_cost_accuracy_fig2.pdf', bbox_inches='tight')

## n_CoT_steps

In [None]:
# Number of CoT steps for the zero setting at three sampling temperatures
# (same colour, different line styles).
fig, ax = plot_with_error_bars_multi(
    [zero_tokens, zero_tokens_temp_0_5, zero_tokens_temp_1_0],
    names=[r"T=0", r"T=0.5", r"T=1.0"],
    metric_names=[("n_CoT_steps", "CoT steps")],
    line_colors=[COLOR_ZERO] * 3,
    line_styles=["-", "--", ":"],
    loc="upper right",
    use_fill_between=True,
    show_legend=True,
)
ax.set_ylim(20, 80)

# Horizontal reference lines with a coloured eta label just past the right axis edge.
y_vals = [33, 43, 58]
adjs = [5, 4, 5]  # per-label horizontal offset beyond the right x-limit
labels = [r"$\eta=+5$", r"$\eta=0$", r"$\eta=-5$"]
cs = [COLOR_P5, COLOR_ZERO, COLOR_N5]
x_text = ax.get_xlim()[1] # adjust as needed for spacing
for y, label, adj, col in zip(y_vals, labels, adjs, cs):
    ax.axhline(y, color="black", ls="--", linewidth=1.5, alpha=0.5, zorder=1)
    ax.text(x=x_text + adj, y=y, s=label, color=col,
            ha='right', va='center', fontsize=10)

fig.savefig('cot_steps.pdf', bbox_inches='tight')

In [None]:
# Default metric for the zero setting at the three sampling temperatures (figure is not saved).
fig, ax = plot_with_error_bars_multi(
    [zero_tokens, zero_tokens_temp_0_5, zero_tokens_temp_1_0],
    names=[r"T=0", r"T=0.5", r"T=1.0"],
    line_colors=[COLOR_ZERO] * 3,
    line_styles=["-", "--", ":"],
    loc="lower right",
    use_fill_between=True,
    show_legend=True,
)

## 32M tokens vs 128M tokens

In [None]:
# Each eta setting appears twice in a row: first the run labelled 128M (solid),
# then its 32M counterpart (dashed), both in the same colour.
fig, ax = plot_with_error_bars_multi(
    [n5_tokens, n5_32_tokens, zero_tokens, zero_32_tokens, p5_tokens, p5_32_tokens],
    names=[
        r"$\eta=-5 128M$", r"$\eta=-5$ 32M T",
        r"$\eta=0 128M$", r"$\eta=0$ 32M T",
        r"$\eta=5$", r"$\eta=5$ 32M T",
    ],
    line_colors=[COLOR_N5, COLOR_N5, COLOR_ZERO, COLOR_ZERO, COLOR_P5, COLOR_P5],
    line_styles=["-", "--"] * 3,
    use_fill_between=True,
    highlight_max=True,
    show_legend=False,
)

legend_labels = [r"$\eta=-5$", r"$\eta=+5$", r"$\eta=0$"]
legend_colors = [COLOR_N5, COLOR_P5, COLOR_ZERO]

# Square colour swatches, one per eta value, replace the automatic legend.
patches = []
for label, color in zip(legend_labels, legend_colors):
    patches.append(
        mpatches.Rectangle(
            (0, 0), 1, 1,
            facecolor=color,
            edgecolor='black',
            linewidth=0.5,
            label=label,
        )
    )

custom_legend = ax.legend(
    handles=patches,
    loc="lower right",
    frameon=True,
    fancybox=True,
    title=None,
    handlelength=1.0,
    handleheight=1.0,
)
fig.savefig('n5_p5_zero_32M_128M_fig4.pdf', bbox_inches='tight')

In [None]:
# Next-token confidence (avg_prob_1) of the three token runs, y-axis fixed to [0.6, 1.0].
fig, ax = plot_with_error_bars_multi(
    [n5_tokens, zero_tokens, p5_tokens],
    names=[r"$\eta=-5 128M$", r"$\eta=0 128M$", r"$\eta=5$"],
    metric_names=[("avg_prob_1", "Next-Token Confidence")],
    line_colors=[COLOR_N5, COLOR_ZERO, COLOR_P5],
    line_styles=["-"] * 3,
    ylim=(0.6, 1.0),
    use_fill_between=True,
    show_legend=False,
)
fig.savefig('n5_p5_zero_128M_confidence_fig4.pdf', bbox_inches='tight')

## Redundancy

In [None]:
# Default metric for the baseline and the three redundancy variants.
# The two stochastic runs share a colour; the temperature variant is dashed.
fig, ax = plot_with_error_bars_multi(
    [p5_tokens, p5_stoch, p5_stoch_temp, p5_det],
    line_colors=[COLOR_P5, "#8C4C36", "#8C4C36", "#E69F00"],
    line_styles=["-", "-", "--", "-"],
    use_fill_between=True,
    show_legend=False,
)
fig.savefig('redundancy_fig5.pdf', bbox_inches='tight')

## train loss

In [None]:
# 128M
# Each training-loss CSV is aggregated on its own (a one-file list), so every
# `*_std` column of these tables is 0 and `*_mean` equals the raw values.
p5_loss_1, p5_loss_2, p5_loss_3, p5_loss_4, p5_loss_5 = [
    analyze_simulation_runs([f"checkpoints/p5_tokens_{run}_train_loss.csv"])
    for run in range(1, 6)
]

n5_loss_1, n5_loss_2, n5_loss_3, n5_loss_4, n5_loss_5 = [
    analyze_simulation_runs([f"checkpoints/n5_tokens_{run}_train_loss.csv"])
    for run in range(1, 6)
]

In [None]:
def _block_mean(values, window):
    """Average consecutive, non-overlapping chunks of `window` samples.

    The last chunk may be shorter than `window`; it is averaged over the
    samples it actually contains.
    """
    values = np.asarray(values)
    return np.array(
        [values[start:start + window].mean() for start in range(0, len(values), window)]
    )


def plot_loss(
    stats_dfs: Sequence[pd.DataFrame],
    metric_names=('is_A_path_correct', 'Answer Correctness'),
    names=None,
    x_col="_step_mean",
    figsize=(3.0, 2.5),
    dpi=300,
    error_bar_interval=1,
    ylabel=None,
    loc='lower right',
    palette=None,
    line_colors=None,
    line_styles=None,
    show_legend=True,
    use_fill_between=False,
    no_errorbars=False,
    ylim=None,
    smooth_window=5,
    max_points=20,
    x_range=(0.0, 20.0),
    legend_group_size=5,
    legend_labels=None,
):
    """
    Plot smoothed mean ± 1 σ curves for one **or many** metrics across several
    runs on linear axes.

    For every run the x values and each metric's mean/std columns are averaged
    over consecutive blocks of `smooth_window` rows.  The block-averaged x
    values are then rescaled linearly onto `x_range` (using the min/max of the
    *whole* series) and only the first `max_points` points are drawn.  The
    x-axis is labelled "Epoch".

    Parameters
    ----------
    stats_dfs : list or tuple of pandas.DataFrame
        One frame per run.  Each must contain ``<metric>_mean`` and
        ``<metric>_std`` columns for every requested metric.
    metric_names : (str, str) or list of (str, str)
        ``(column_base, display_name)`` pairs.  The display name of the first
        metric is the default y-axis label.
    names : list[str] or None
        One name per run (defaults to "Run 1", "Run 2", ...).  Used as legend
        labels when their count matches the number of legend entries, and in
        error messages.
    x_col : str or None
        Name of the column that contains the x-values (e.g. steps, epochs).
        If None, the DataFrame index is used.
    figsize, dpi : figure size in inches and resolution.
    error_bar_interval : int
        Draw an error bar on every n-th point (only without `use_fill_between`).
    ylabel : str or None
        Overrides the y-axis label.
    loc : str
        Legend location.
    palette, line_colors, line_styles : list or None
        `line_colors` / `line_styles` need one entry per drawn curve
        (runs × metrics).  `palette` is used when `line_colors` is None and it
        is long enough; otherwise seaborn's "deep" palette is used.
    show_legend : bool
        Whether to draw the grouped legend.
    use_fill_between : bool
        Draw a shaded ±σ band instead of discrete error bars.
    no_errorbars : bool
        Suppress every uncertainty visual.
    ylim : tuple or None
        y-axis limits.
    smooth_window : int, default 5
        Number of consecutive rows averaged into one plotted point.
    max_points : int or None, default 20
        Number of leading (block-averaged) points to draw; None draws all.
    x_range : (float, float), default (0.0, 20.0)
        Target range of the rescaled x values.
    legend_group_size : int, default 5
        Number of consecutive curves bundled into one legend entry.
    legend_labels : list[str] or None
        Explicit labels for the legend entries.  When None, `names` is used if
        its length matches the number of entries, otherwise the historical
        fallback ``[r'$\\eta=5$', r'$\\eta=-5$']``.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes

    Raises
    ------
    TypeError
        If `stats_dfs` is not a list or tuple.
    ValueError
        If `stats_dfs` is empty, if `names` / `line_colors` / `line_styles`
        have the wrong length, or if `smooth_window` / `legend_group_size`
        are smaller than 1.
    KeyError
        If a frame lacks the mean or std column of a requested metric.
    """
    # ---------- Input validation ------------------------------------------------
    if not isinstance(stats_dfs, (list, tuple)):
        raise TypeError("stats_dfs must be a list or tuple of DataFrames.")
    if len(stats_dfs) == 0:
        raise ValueError("Need at least one DataFrame to plot.")
    if smooth_window < 1:
        raise ValueError("`smooth_window` must be a positive integer.")
    if legend_group_size < 1:
        raise ValueError("`legend_group_size` must be a positive integer.")

    # Normalise `metric_names` → list[(base, pretty)]
    if isinstance(metric_names[0], str):
        metric_names = [metric_names]
    n_metrics = len(metric_names)
    n_runs    = len(stats_dfs)

    if names is None:
        names = [f'Run {i + 1}' for i in range(n_runs)]
    elif len(names) != n_runs:
        raise ValueError("`names` must have the same length as `stats_dfs`.")

    # ---------- Colour & style bookkeeping --------------------------------------
    total_lines = n_runs * n_metrics
    if line_colors is not None:
        if len(line_colors) != total_lines:
            raise ValueError("`line_colors` length must equal runs × metrics.")
        palette = line_colors
    elif palette is None or len(palette) < total_lines:
        palette = sns.color_palette('deep', n_colors=total_lines)

    if line_styles is None:
        line_styles = ['-'] * total_lines
    elif len(line_styles) != total_lines:
        raise ValueError("`line_styles` length must equal runs × metrics.")

    # ---------- Plotting ---------------------------------------------------------
    sns.set_style("whitegrid")
    sns.set_context("paper", font_scale=1.0)

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    lo, hi = x_range
    curve_handles = []   # main curves only, in drawing order (used by the legend)
    colour_idx = 0
    for df, run_label in zip(stats_dfs, names):
        # Resolve x values only once per run
        raw_x = df[x_col].values if x_col is not None else df.index.values
        x = _block_mean(raw_x, smooth_window)
        span = x.max() - x.min()
        if span == 0:
            # Degenerate range (single point / constant x): avoid 0/0 → NaN.
            x = np.full(x.shape, float(lo))
        else:
            x = (x - x.min()) / span * (hi - lo) + lo
        # Truncation happens *after* rescaling, so the kept points keep the
        # positions they had within the full series.
        x = x[:max_points]

        for metric_col_base, metric_display in metric_names:
            mean_col = f'{metric_col_base}_mean'
            std_col  = f'{metric_col_base}_std'
            if mean_col not in df or std_col not in df:
                raise KeyError(
                    f"DataFrame for '{run_label}' lacks '{mean_col}' and/or '{std_col}'."
                )

            y     = _block_mean(df[mean_col].values, smooth_window)[:max_points]
            y_err = _block_mean(df[std_col].values, smooth_window)[:max_points]
            color = palette[colour_idx]
            style = line_styles[colour_idx]
            colour_idx += 1

            # main curve
            (curve,) = ax.plot(
                x, y,
                linestyle=style, color=color, linewidth=2.5, alpha=0.7,
                label='_nolegend_'
            )
            curve_handles.append(curve)

            # uncertainty visuals
            if not no_errorbars:
                if use_fill_between:
                    ax.fill_between(
                        x, y - y_err, y + y_err,
                        color=color, alpha=0.3, linewidth=0
                    )
                else:
                    # integer indices *into the arrays* for error bars
                    err_idx = np.arange(len(x))[::error_bar_interval]
                    ax.errorbar(
                        x[err_idx], y[err_idx], yerr=y_err[err_idx],
                        fmt='none', ecolor=color, elinewidth=1.5,
                        capsize=4, capthick=1.5, alpha=0.7
                    )

    # ---------- Axes & legend tweaks --------------------------------------------
    ax.set_xlabel('Epoch', fontweight='bold')
    ax.set_ylabel(metric_names[0][1] if ylabel is None else ylabel,
                  fontweight='bold')

    if show_legend:
        from matplotlib.legend_handler import HandlerTuple

        # Bundle every `legend_group_size` curves into one legend entry.  Only
        # the handles returned by `ax.plot` are used, so error-bar caps (which
        # are Line2D artists too) can never end up inside a bundle.
        handles = [tuple(curve_handles[i:i + legend_group_size])
                   for i in range(0, len(curve_handles), legend_group_size)]
        if legend_labels is not None:
            labels = list(legend_labels)
        elif names and len(names) == len(handles):
            labels = names
        else:
            # Historical fallback used when `names` does not map 1:1 to bundles.
            labels = [r'$\eta=5$', r'$\eta=-5$']

        # A tuple handle is rendered by drawing all of its member lines.
        ax.legend(handles=handles,
                  labels=labels,
                  handler_map={tuple: HandlerTuple(ndivide=None)},
                  loc=loc, frameon=True, fancybox=True)

    if ylim is not None:
        ax.set_ylim(ylim)

    sns.despine()
    plt.tight_layout()
    return fig, ax

In [None]:
# Training loss of the five eta = 5 seeds (red shades) and the five eta = -5
# seeds (blue shades).  plot_loss bundles every five curves into one legend
# entry, so the ten per-seed names do not map 1:1 onto the two legend entries.
fig, ax = plot_loss(
    [
        p5_loss_1, p5_loss_2, p5_loss_3, p5_loss_4, p5_loss_5,
        n5_loss_1, n5_loss_2, n5_loss_3, n5_loss_4, n5_loss_5,
    ],
    names=[
        r"$\eta=5$, seed 1",
        r"$\eta=5$, seed 2",
        r"$\eta=5$, seed 3",
        r"$\eta=5$, seed 4",
        r"$\eta=5$, seed 5",
        r"$\eta=-5$, seed 1",
        r"$\eta=-5$, seed 2",
        r"$\eta=-5$, seed 3",
        r"$\eta=-5$, seed 4",
        r"$\eta=-5$, seed 5",
    ],
    metric_names=[("train/loss", "Training loss")],
    line_colors=[
        # eta = 5: seed 1 (original red) → progressively darker shades
        "#C51B00", "#A81800", "#8C1500", "#701200", "#541000",
        # eta = -5: seed 1 (original blue) → progressively darker shades
        "#0072B2", "#005A8F", "#00426B", "#002B48", "#001324",
    ],
    loc="upper right",
    show_legend=True,
    use_fill_between=True,
)
fig.savefig('./n5_p5_128M_train_loss_fig4.pdf', bbox_inches='tight')

# Supplementary

In [None]:
# Repeated CoT steps, eta = -5 vs eta = 5.
# `divide_metric` presumably normalises the metric by "n_CoT_steps" -- see
# plot_with_error_bars_multi (defined outside this section) to confirm.
fig, ax = plot_with_error_bars_multi(
    [n5_tokens, p5_tokens],
    names=[r"$\eta=-5$", r"$\eta=5$"],
    metric_names=[("repeated_CoT_steps", "repeated_CoT_steps")],
    divide_metric=("n_CoT_steps", "cot steps"),
    line_colors=[COLOR_N5, COLOR_P5],
    use_fill_between=True,
    show_legend=False,
    ylabel="",
)
fig.savefig('n5_p5_repeated_cot_steps_perc.pdf', bbox_inches='tight')

In [None]:
# "sub_prob_optimal_CoT_steps" metric, eta = -5 vs eta = 5.
# `divide_metric` presumably normalises the metric by "n_CoT_steps" -- see
# plot_with_error_bars_multi (defined outside this section) to confirm.
fig, ax = plot_with_error_bars_multi(
    [n5_tokens, p5_tokens],
    names=[r"$\eta=-5$", r"$\eta=5$"],
    metric_names=[("sub_prob_optimal_CoT_steps", "sub_prob_optimal_CoT_steps")],
    divide_metric=("n_CoT_steps", "cot steps"),
    line_colors=[COLOR_N5, COLOR_P5],
    use_fill_between=True,
    show_legend=False,
    ylabel="",
)
fig.savefig('n5_p5_sub_prob_optimal_perc.pdf', bbox_inches='tight')

In [None]:
# "consistent_CoT_steps" metric, eta = -5 vs eta = 5.
# `divide_metric` presumably normalises the metric by "n_CoT_steps" -- see
# plot_with_error_bars_multi (defined outside this section) to confirm.
fig, ax = plot_with_error_bars_multi(
    [n5_tokens, p5_tokens],
    names=[r"$\eta=-5$", r"$\eta=5$"],
    metric_names=[("consistent_CoT_steps", "consistent_CoT_steps")],
    divide_metric=("n_CoT_steps", "cot steps"),
    line_colors=[COLOR_N5, COLOR_P5],
    use_fill_between=True,
    show_legend=False,
    ylabel="",
)
fig.savefig('n5_p5_consistent_cot_steps_perc.pdf', bbox_inches='tight')

In [None]:
# "CoT_path_possible" metric, eta = -5 vs eta = 5.
# `divide_metric` presumably normalises the metric by "n_CoT_steps" -- see
# plot_with_error_bars_multi (defined outside this section) to confirm.
fig, ax = plot_with_error_bars_multi(
    [n5_tokens, p5_tokens],
    names=[r"$\eta=-5$", r"$\eta=5$"],
    metric_names=[("CoT_path_possible", "CoT_path_possible")],
    divide_metric=("n_CoT_steps", "cot steps"),
    line_colors=[COLOR_N5, COLOR_P5],
    use_fill_between=True,
    show_legend=False,
    ylabel="",
)
fig.savefig('n5_p5_cot_path_possible.pdf', bbox_inches='tight')

In [None]:
# "CoT_steps_skipped_sub_prob" metric, eta = -5 vs eta = 5.
# `divide_metric` presumably normalises the metric by "n_CoT_steps" -- see
# plot_with_error_bars_multi (defined outside this section) to confirm.
fig, ax = plot_with_error_bars_multi(
    [n5_tokens, p5_tokens],
    names=[r"$\eta=-5$", r"$\eta=5$"],
    metric_names=[("CoT_steps_skipped_sub_prob", "CoT_steps_skipped_sub_prob")],
    divide_metric=("n_CoT_steps", "cot steps"),
    line_colors=[COLOR_N5, COLOR_P5],
    use_fill_between=True,
    show_legend=False,
    ylabel="",
)
fig.savefig('n5_p5_cot_steps_skipped.pdf', bbox_inches='tight')

In [None]:
# "syntax_errors" metric, eta = -5 vs eta = 5 (no `divide_metric` is passed here).
fig, ax = plot_with_error_bars_multi(
    [n5_tokens, p5_tokens],
    names=[r"$\eta=-5$", r"$\eta=5$"],
    metric_names=[("syntax_errors", "syntax_errors")],
    line_colors=[COLOR_N5, COLOR_P5],
    use_fill_between=True,
    show_legend=False,
    ylabel="",
)
fig.savefig('n5_p5_syntax_errors.pdf', bbox_inches='tight')

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter


def plot_histograms(
    data: Sequence[pd.Series] | Sequence[np.ndarray],
    labels=None,
    bins=30,
    density=False,
    cumulative=False,
    figsize=(3.0, 2.5),
    dpi=300,
    palette=None,
    colors=None,
    alpha=0.6,
    edgecolor="black",
    linewidth=1.0,
    xlabel=None,
    ylabel=None,
    title=None,
    loc="upper right",
    show_legend=True,
    ylim=None,
    xlim=None,
):
    """
    Plot one or more overlaid histograms with y-axis shown as percentages.

    Parameters
    ----------
    data : list of pandas.Series or numpy.ndarray
        Each element provides the values to histogram. All series will be overlaid.
    labels : list[str] or None
        Legend labels for each dataset.
    bins : int or sequence, default 30
        Number of bins (or explicit bin edges) for the histograms.
    density : bool, default False
        If True, show probability density instead of counts (in which case y-axis percentages are of density values).
    cumulative : bool, default False
        If True, plot the cumulative histogram.
    figsize : tuple(float, float), default (3.0, 2.5)
        Figure size in inches.
    dpi : int, default 300
        Figure resolution.
    palette : list or None
        A list of colors to cycle through. Ignored if `colors` is provided.
    colors : list or None
        Explicit list of colors for each histogram. Overrides `palette`.
    alpha : float, default 0.6
        Transparency for each histogram patch.
    edgecolor : str, default "black"
        Color of the bar edges.
    linewidth : float, default 1.0
        Width of the bar edges.
    xlabel : str or None
        Label for the x-axis.
    ylabel : str or None
        Label for the y-axis. If None, will default to "%".
    title : str or None
        Plot title.
    loc : str, default "upper right"
        Legend location.
    show_legend : bool, default True
        Whether to draw the legend.
    ylim : tuple or None
        y-axis limits as (min, max).
    xlim : tuple or None
        x-axis limits as (min, max).

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    """
    # Input validation
    if not isinstance(data, (list, tuple)):
        raise TypeError("`data` must be a list or tuple of arrays or Series.")
    n = len(data)
    if labels is None:
        labels = [f"Series {i+1}" for i in range(n)]
    elif len(labels) != n:
        raise ValueError("`labels` must have the same length as `data`.")

    # Color handling
    if colors is not None:
        if len(colors) != n:
            raise ValueError("`colors` must have same length as `data`.")
        palette = colors
    elif palette is None or len(palette) < n:
        palette = sns.color_palette("deep", n_colors=n)

    # Plot setup
    sns.set_style("whitegrid")
    sns.set_context("paper", font_scale=1.0)
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Plot each histogram
    for i, series in enumerate(data):
        arr = np.asarray(series).ravel()
        # For raw counts, convert to percent of total; for density, leave as-is
        if density:
            hist_kwargs = {"density": True}
        else:
            hist_kwargs = {"weights": np.ones_like(arr) / arr.size * 100}

        ax.hist(
            arr,
            bins=bins,
            cumulative=cumulative,
            color=palette[i],
            alpha=alpha,
            edgecolor=edgecolor,
            linewidth=linewidth,
            label=labels[i],
            **hist_kwargs,
        )

    # Axes styling
    if xlabel:
        ax.set_xlabel(xlabel, fontweight="bold")
    if ylabel:
        ax.set_ylabel(ylabel, fontweight="bold")
    else:
        ax.set_ylabel("%", fontweight="bold")
    if title:
        ax.set_title(title, fontweight="bold")

    # Format y-axis labels as percentages
    ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: f"{x:.0f}%"))

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    if show_legend:
        ax.legend(loc=loc, frameon=True, fancybox=True)

    sns.despine()
    plt.tight_layout()
    return fig, ax


In [None]:
# Integer values of the "parameter" column, starting at row label 7.
# NOTE(review): rows with index label < 7 are skipped -- presumably they are
# not CoT-step counts; confirm against the CSV layout.
cot_steps_n5 = (
    pd.read_csv("checkpoints/cot_steps_efficiency_n5.csv")
    .loc[7:, "parameter"]
    .astype(int)
)
cot_steps_p5 = (
    pd.read_csv("checkpoints/cot_steps_efficiency_p5.csv")
    .loc[7:, "parameter"]
    .astype(int)
)

In [None]:
# Distribution of CoT step counts for eta = -5 (opaque bars, no legend).
# plot_histograms scales bar heights to percentages and formats the y ticks
# with a "%" suffix.
fig, ax = plot_histograms(
    [cot_steps_n5],
    colors=[COLOR_N5],
    alpha=1.0,
    xlabel="CoT steps",
    ylabel="occurrences",
    show_legend=False,
)

fig.savefig('cot_steps_n5.pdf', bbox_inches='tight')

In [None]:
# Distribution of CoT step counts for eta = 5 (opaque bars, no legend).
# plot_histograms scales bar heights to percentages and formats the y ticks
# with a "%" suffix.
fig, ax = plot_histograms(
    [cot_steps_p5],
    colors=[COLOR_P5],
    alpha=1.0,
    xlabel="CoT steps",
    ylabel="occurrences",
    show_legend=False,
)

fig.savefig('cot_steps_p5.pdf', bbox_inches='tight')

In [None]:
# eta = 5, 32M tokens: 3-layer (solid) vs 6-layer (dashed) runs, same colour.
fig, ax = plot_with_error_bars_multi(
    [p5_32M_tokens_3_layers_scheduler, p5_32M_tokens_6_layers_scheduler],
    line_colors=[COLOR_P5, COLOR_P5],
    line_styles=["-", "--"],
    use_fill_between=True,
    show_legend=False,
)
fig.savefig('p5_3_6_layers.pdf', bbox_inches='tight')

In [None]:
# eta = 5, 3 layers: 16M-token (dash-dot) vs 32M-token (solid) runs, same colour.
fig, ax = plot_with_error_bars_multi(
    [p5_16M_tokens_3_layers_scheduler, p5_32M_tokens_3_layers_scheduler],
    line_colors=[COLOR_P5, COLOR_P5],
    line_styles=["-.", "-"],
    use_fill_between=True,
    show_legend=False,
)
fig.savefig('p5_16M_32M.pdf', bbox_inches='tight')

In [None]:
# Load both sum tables, dropping the first CSV column.
# NOTE(review): the dropped column is presumably a saved index -- confirm.
sum_table_n5, sum_table_p5 = (
    pd.read_csv(csv_path).iloc[:, 1:]
    for csv_path in (
        "checkpoints/sum_table_n5.csv",
        "checkpoints/sum_table_p5.csv",
    )
)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Sequence, Optional, Union

def plot_log1p_heatmaps(
    tables: Sequence[Union[np.ndarray, pd.DataFrame]],
    titles: Optional[Sequence[str]] = None,
    figsize=(3.0, 4.0),
    dpi=300,
    cmap='viridis',
    vmin: Optional[float] = 0.0,
    vmax: Optional[float] = None,
    xlabel='Addend',
    ylabel='Addend',
    font_scale: float = 1.0,
    label_fontsize: Optional[float] = None,
    title_fontsize: Optional[float] = None,
    cbar_label: str = "log(count + 1)"
):
    """
    Draw one heatmap per table, side by side, showing ``log(x + 1)`` of every
    cell instead of the raw value.  All panels share one colour scale and one
    colour bar.

    Parameters
    ----------
    tables : sequence of 2-D numpy arrays or pandas DataFrames
        The raw (untransformed) tables; at least one is required.
    titles : sequence of str or None
        One title per table ("Table 1", "Table 2", ... by default).
    figsize, dpi : figure size in inches and resolution.
    cmap : str
        Matplotlib colormap name.
    vmin, vmax : float or None
        Colour limits in log1p space.  When `vmax` is None the largest
        transformed value across all tables is used.
    xlabel, ylabel : str
        Shared x label (figure-level) and the y label of the first panel.
    font_scale : float
        Passed to ``sns.set_context("paper", font_scale=...)``; label, title
        and tick sizes default to the sizes that context sets.
    label_fontsize, title_fontsize : float or None
        Explicit overrides for the axis-label and title font sizes.
    cbar_label : str
        Label of the colour bar.

    Returns
    -------
    fig, axes : matplotlib Figure and the list/array of panel Axes

    Raises
    ------
    ValueError
        If `tables` is empty, if `titles` has the wrong length, or if a table
        is not two-dimensional.
    """
    n = len(tables)
    if n < 1:
        raise ValueError("Need at least one table to plot.")
    if titles is None:
        titles = [f"Table {i+1}" for i in range(n)]
    if len(titles) != n:
        raise ValueError("`titles` must have the same length as `tables`.")

    # Convert & transform inputs to arrays via log1p
    mats = []
    for pos, tbl in enumerate(tables):
        arr = tbl.values if isinstance(tbl, pd.DataFrame) else np.asarray(tbl)
        if arr.ndim != 2:
            # Fail early with a clear message instead of a shape-unpacking
            # error further down.
            raise ValueError(
                f"Table {pos + 1} must be 2-dimensional, got {arr.ndim} dimension(s)."
            )
        mats.append(np.log1p(arr))

    # If vmax not set, derive from max(log1p(x)) across all mats
    if vmax is None:
        vmax = max(mat.max() for mat in mats)

    # Seaborn styling.  This must run *before* the font sizes are read:
    # matplotlib's default rcParams hold strings such as 'medium', which
    # cannot be scaled, and reading them first would make the result depend
    # on whatever context an earlier cell happened to set.
    sns.set_style("whitegrid")
    sns.set_context("paper", font_scale=font_scale)

    # Font sizes (the context above already applied `font_scale`)
    base_label_fs = label_fontsize or float(plt.rcParams['axes.labelsize'])
    base_title_fs = title_fontsize or float(plt.rcParams['axes.titlesize'])
    tick_fs       = float(plt.rcParams['xtick.labelsize'])

    # Create figure & axes
    fig, axes = plt.subplots(
        ncols=n,
        figsize=figsize,
        dpi=dpi,
        gridspec_kw={'wspace': 0.05},
        constrained_layout=True
    )
    # normalize axes array
    if n == 1:
        axes = [axes]

    # Turn off major gridlines
    for ax in axes:
        ax.grid(False, which="major")

    # Plot each heatmap
    for idx, (ax, mat, title) in enumerate(zip(axes, mats, titles)):
        im = ax.imshow(mat, cmap=cmap, vmin=vmin, vmax=vmax, interpolation='nearest')

        rows, cols = mat.shape
        # NOTE(review): x tick labels start at 1 while y tick labels start at
        # 0 -- kept as in the original; confirm this matches the table layout.
        ax.set_xticks(np.arange(cols))
        ax.set_xticklabels([str(i+1) for i in range(cols)],
                           fontweight="bold", fontsize=tick_fs)

        if idx == 0:
            # Only the left-most panel carries y ticks and the y label.
            ax.set_yticks(np.arange(rows))
            ax.set_yticklabels([str(i) for i in range(rows)],
                               fontweight="bold", fontsize=tick_fs)
            ax.set_ylabel(ylabel, fontweight="bold", fontsize=base_label_fs)
        else:
            ax.set_yticks([])

        # Minor gridlines between cells
        ax.set_xticks(np.arange(cols+1)-0.5, minor=True)
        ax.set_yticks(np.arange(rows+1)-0.5, minor=True)
        ax.grid(which="minor", color="w", linestyle='-', linewidth=2)
        ax.tick_params(which="minor", length=0)

        ax.set_title(title, fontweight="bold", fontsize=base_title_fs)

    # Shared x-axis label
    fig.supxlabel(xlabel, fontweight="bold", fontsize=base_label_fs)

    # Shared colorbar on the right (all panels use the same vmin/vmax, so the
    # last image is representative)
    cbar = fig.colorbar(im, ax=axes, fraction=0.046, pad=0.04)
    cbar.ax.set_ylabel(cbar_label, rotation=-90, va="bottom",
                       fontweight="bold", fontsize=base_label_fs)
    cbar.ax.tick_params(labelsize=tick_fs)

    sns.despine(left=True, bottom=True)

    return fig, axes


# Side-by-side log(count + 1) heatmaps of the two sum tables, with enlarged
# title and axis-label fonts.
fig, axes = plot_log1p_heatmaps(
    [sum_table_n5, sum_table_p5],
    titles=[r"$\eta=-5$", r"$\eta=+5$"],
    label_fontsize=18.0,
    title_fontsize=16.0,
)
fig.savefig('sum_table_costs.pdf', bbox_inches='tight')

# Getting losses from wandb

In [None]:
# Authenticate against Weights & Biases and create the public-API client
# (`api`) that the export cell below uses.
import wandb

wandb.login()
api = wandb.Api()

In [None]:
# Export the logged training loss of one W&B run to a CSV file.
WANDB_RUN_PATH = "[put here run id]"  # placeholder: replace with the run to export

run = api.run(WANDB_RUN_PATH)
# NOTE(review): per the W&B public API docs, Run.history() returns a *sampled*
# history by default (`samples` argument) -- confirm that is what the loss
# plots should be based on.
history = run.history()
loss_data = history[["_step", "train/loss"]]
loss_data.to_csv("checkpoints/p5_tokens_3_ibac_train_loss.csv", index=False)