
# Results Accuracy Plots

Use these cells to generate checkpoint accuracy plots from aggregated evaluation JSON files under `results/`. Select one or more result files, choose which accuracy metrics to display, and the notebook will render a separate plot for each dataset (e.g., `amc23`, `gsm8k`, `math500`).


In [None]:

from __future__ import annotations

import json
from statistics import mean
from pathlib import Path
from typing import Dict, Iterable, List, Mapping, Sequence, Tuple

import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import clear_output, display


In [None]:

def resolve_results_dir() -> Path:
    """Locate the ``results`` directory near the current working directory.

    Candidates are tried in order: ``<cwd>/results``, ``<cwd's parent>/results``
    and ``<resolved cwd>/../results``. Each one is resolved to an absolute path,
    and the first that is an existing directory is returned.

    Raises:
        FileNotFoundError: if none of the candidates is a directory.
    """
    cwd = Path.cwd()
    search_order = (
        cwd / "results",
        cwd.parent / "results",
        cwd.resolve() / "../results",
    )
    # Lazily resolve so we stop at the first hit, exactly like a loop would.
    found = next(
        (
            resolved
            for resolved in (entry.resolve() for entry in search_order)
            if resolved.exists() and resolved.is_dir()
        ),
        None,
    )
    if found is None:
        raise FileNotFoundError(
            "Could not locate a `results` directory relative to the current working directory."
        )
    return found

RESULTS_DIR = resolve_results_dir()
AVAILABLE_FILES = sorted(RESULTS_DIR.glob("*.json"))

# Stop early when there is nothing to plot.
if not AVAILABLE_FILES:
    raise FileNotFoundError(f"No JSON files found in {RESULTS_DIR}")

# Report what was discovered: a header line followed by one indented line per file.
listing = "\n".join(f"  - {json_file.name}" for json_file in AVAILABLE_FILES)
print(f"Found {len(AVAILABLE_FILES)} result files in {RESULTS_DIR}:\n{listing}")

# Metrics used when the caller does not choose any explicitly.
DEFAULT_METRICS: Sequence[str] = ("ans_acc", "for_acc", "both_acc")


In [None]:

# Raw payload of one results file as returned by `json.load`:
# checkpoint label -> dataset name -> metric name -> value.
# `compute_means` only uses labels whose text after the last "-" parses as an int.
JsonDict = Dict[str, Mapping[str, Mapping[str, object]]]
# Output of `compute_means`: dataset name -> metric name -> checkpoint number -> mean value.
Aggregated = Dict[str, Dict[str, Dict[int, float]]]


def load_results(path: Path) -> JsonDict:
    """Read *path* as UTF-8 text and return the parsed JSON document.

    Raises whatever reading or JSON decoding raises (e.g. ``OSError``,
    ``json.JSONDecodeError``).
    """
    document_text = path.read_text(encoding="utf-8")
    return json.loads(document_text)


def compute_means(raw: JsonDict) -> Aggregated:
    """Compute mean accuracy values per dataset/metric/checkpoint.

    *raw* maps checkpoint labels to ``{dataset: {metric: values}}``. A label is
    used only if the text after its last ``-`` parses as an int (for example
    ``"checkpoint-100"`` -> ``100``). Other top-level keys are ignored.

    For every metric whose value is a list-like of numbers, the mean of its
    numeric entries is recorded. The following are skipped:

    - the verbose ``"examples"`` entry;
    - scalar values and strings;
    - sequences with no numeric items;
    - dataset or metric payloads that are not mappings.

    Returns:
        ``{dataset: {metric: {checkpoint: mean}}}``. Datasets without any usable
        metric are omitted entirely.
    """
    aggregated: Aggregated = {}
    for checkpoint_label, datasets in raw.items():
        try:
            checkpoint = int(checkpoint_label.split("-")[-1])
        except ValueError:
            continue
        # A numeric-looking key can still carry non-dataset data; skip it
        # instead of crashing on `.items()`.
        if not isinstance(datasets, Mapping):
            continue
        for dataset_name, metrics in datasets.items():
            if not isinstance(metrics, Mapping):
                continue
            for metric_name, values in metrics.items():
                if metric_name == "examples":  # skip verbose sample data
                    continue
                # str/bytes are Sequences too, but never hold numeric entries.
                if isinstance(values, (str, bytes)) or not isinstance(values, Sequence):
                    continue
                numeric_values = [v for v in values if isinstance(v, (int, float))]
                if not numeric_values:
                    continue
                # Create the dataset entry lazily so datasets with no usable
                # metrics do not show up as empty entries.
                metric_entry = aggregated.setdefault(dataset_name, {}).setdefault(metric_name, {})
                metric_entry[checkpoint] = mean(numeric_values)
    return aggregated


def collect_available_metrics(aggregated: Mapping[str, Mapping[str, Mapping[int, float]]]) -> List[str]:
    """Return every metric name that appears under any dataset, de-duplicated and sorted."""
    return sorted({metric for per_dataset in aggregated.values() for metric in per_dataset})


def gather_dataset_bounds(paths: Iterable[Path], metrics: Sequence[str] | None = None) -> Dict[str, Tuple[float, float]]:
    """Return the (min, max) mean value observed per dataset across *paths*.

    Only metrics listed in *metrics* are considered; ``None`` or an empty
    sequence means "all metrics". Files that fail to load or aggregate are
    skipped (best-effort), as are metrics with no checkpoint values.
    """
    wanted = set(metrics) if metrics else None
    bounds: Dict[str, Tuple[float, float]] = {}
    for results_path in paths:
        try:
            per_dataset = compute_means(load_results(results_path))
        except Exception:
            # Best-effort: an unreadable or malformed file just contributes nothing.
            continue
        for dataset_name, metric_map in per_dataset.items():
            for metric_name, by_checkpoint in metric_map.items():
                if wanted is not None and metric_name not in wanted:
                    continue
                observed = list(by_checkpoint.values())
                if not observed:
                    continue
                lo, hi = min(observed), max(observed)
                if dataset_name in bounds:
                    prev_lo, prev_hi = bounds[dataset_name]
                    lo, hi = min(lo, prev_lo), max(hi, prev_hi)
                bounds[dataset_name] = (lo, hi)
    return bounds


In [None]:

def plot_accuracy_curves(
    selected_files: Iterable[Path],
    metrics: Iterable[str] | None = None,
    datasets: Sequence[str] | None = None,
    y_limits: Mapping[str, Tuple[float, float]] | None = None,
) -> None:
    """Draw one accuracy-vs-checkpoint figure per dataset.

    Args:
        selected_files: results JSON files to plot. Each file contributes one
            line per metric, labelled ``"<file stem> · <metric>"``.
        metrics: metric names to draw. ``None`` or an empty iterable falls back
            to ``DEFAULT_METRICS``.
        datasets: dataset names to plot, in this order. Defaults to every
            dataset found in the selected files, sorted by name.
        y_limits: optional per-dataset ``(min, max)`` y-axis limits. Datasets
            without an entry use ``(0.0, 1.0)``.

    Prints a message and returns early when no files are given or no datasets
    are found. A dataset with nothing to draw produces no figure.
    """
    paths = list(selected_files)
    if not paths:
        print("Select at least one results file to plot.")
        return

    # Aggregate each file once; its stem becomes the legend prefix.
    per_file: Dict[str, Aggregated] = {}
    dataset_names: set[str] = set()
    for path in paths:
        aggregated = compute_means(load_results(path))
        per_file[path.stem] = aggregated
        dataset_names.update(aggregated.keys())

    if not dataset_names:
        print("No datasets found in the selected files.")
        return

    # An empty list *and* an empty generator (which is truthy) both fall back
    # to the defaults.
    chosen_metrics = list(metrics or ()) or list(DEFAULT_METRICS)
    dataset_order = list(datasets) if datasets else sorted(dataset_names)
    limits = {name: (float(lo), float(hi)) for name, (lo, hi) in (y_limits or {}).items()}

    for dataset_name in dataset_order:
        fig, ax = plt.subplots(figsize=(7, 4.5))
        plotted_any = False
        for label, dataset_data in per_file.items():
            metrics_for_dataset = dataset_data.get(dataset_name, {})
            for metric_name in chosen_metrics:
                metric_values = metrics_for_dataset.get(metric_name)
                if not metric_values:
                    continue
                checkpoints = sorted(metric_values)
                ax.plot(
                    checkpoints,
                    [metric_values[ckpt] for ckpt in checkpoints],
                    marker="o",
                    # "\u00b7" is a middle dot. It is written as an escape so the
                    # source cannot be mis-decoded into mojibake ("Â·") again.
                    label=f"{label} \u00b7 {metric_name}",
                )
                plotted_any = True
        if not plotted_any:
            # Nothing to show for this dataset: discard the empty figure.
            plt.close(fig)
            continue
        y_min, y_max = limits.get(dataset_name, (0.0, 1.0))
        if y_min >= y_max:
            # Keep the y-range non-empty even if the caller passed min >= max.
            y_max = y_min + max(0.01, abs(y_min) * 0.05)
        ax.set_title(f"{dataset_name} Accuracy vs. Checkpoint")
        ax.set_xlabel("Checkpoint")
        ax.set_ylabel("Accuracy")
        ax.set_ylim(y_min, y_max)
        ax.grid(alpha=0.3)
        ax.legend()
        fig.tight_layout()
        plt.show()


In [None]:
# Example: plot one specific results file with hand-picked y-axis ranges.
# Missing files are reported instead of raising, so running all cells still
# reaches the interactive controls.
selected = [RESULTS_DIR / "SMC_Self_1.5B_BASE_512_N16_D8_lr_typecosine_max_steps_7_beta0.01.json"]
existing = [p for p in selected if p.is_file()]
for missing in (p for p in selected if not p.is_file()):
    print(f"Skipping missing results file: {missing.name}")

if existing:
    plot_accuracy_curves(
        existing,
        metrics=["ans_acc"],
        y_limits={
            "gsm8k": (0.2, 0.5),
            "math500": (0.05, 0.2),
        },
    )

In [None]:

# File picker listing every JSON file discovered in RESULTS_DIR; selected values are Paths.
file_selector = widgets.SelectMultiple(
    options=[(json_file.name, json_file) for json_file in AVAILABLE_FILES],
    description="Results",
    rows=min(10, max(4, len(AVAILABLE_FILES))),
    layout=widgets.Layout(width="32%"),
)

# Union of metric names across all result files (best-effort scan).
metric_names: set[str] = set()
for result_file in AVAILABLE_FILES:
    try:
        file_means = compute_means(load_results(result_file))
    except Exception:
        # Files that fail to load or aggregate contribute no metrics.
        continue
    metric_names |= set(collect_available_metrics(file_means))

# Fall back to the defaults when no metric could be discovered.
metric_options = sorted(metric_names) or list(DEFAULT_METRICS)
# Preselect the default metrics that exist; if none do, preselect every option.
metric_default = tuple(
    name for name in DEFAULT_METRICS if name in metric_options
) or tuple(metric_options)

metric_selector = widgets.SelectMultiple(
    options=metric_options,
    value=metric_default,
    description="Metrics",
    rows=min(10, max(3, len(metric_options))),
    layout=widgets.Layout(width="28%"),
)

# dataset name -> {"slider", "min_box", "max_box", "container", "state"},
# populated by update_range_widgets.
range_controls: Dict[str, Dict[str, object]] = {}
range_container = widgets.VBox(
    children=(widgets.Label("Select results to configure y-axis ranges."),),
    layout=widgets.Layout(width="40%"),
)

plot_button = widgets.Button(
    description="Plot",
    button_style="primary",
    tooltip="Generate plots",
)
output = widgets.Output()


def _show_range_message(text: str) -> None:
    """Drop all per-dataset range controls and show *text* in the range panel instead."""
    range_controls.clear()
    range_container.children = (widgets.Label(text),)


def _padded_slider_limits(low: float, high: float) -> Tuple[float, float]:
    """Return slider (min, max) limits padded around the observed [low, high] range."""
    span = high - low
    # A flat range (low == high) still needs a non-zero slider width.
    padding = span * 0.1 if span > 0 else max(0.05, abs(low) * 0.1 + 0.05)
    return low - padding, high + padding


def _make_range_control(
    dataset_name: str,
    low: float,
    high: float,
    slider_min: float,
    slider_max: float,
) -> Dict[str, object]:
    """Create a linked range slider plus min/max text boxes for one dataset.

    Returns the control dict stored in ``range_controls``, with keys
    ``slider``, ``min_box``, ``max_box``, ``container`` and ``state``.
    """
    # Re-entrancy guard: while one widget pushes values into the others, their
    # observers must not echo the change back. It is always reset in `finally`
    # so an exception cannot leave the control permanently frozen.
    state = {"updating": False}
    slider = widgets.FloatRangeSlider(
        description="",
        min=slider_min,
        max=slider_max,
        step=0.01,
        value=(low, high),
        layout=widgets.Layout(width="95%"),
        continuous_update=False,
    )
    min_box = widgets.FloatText(
        value=low,
        description="min",
        layout=widgets.Layout(width="50%"),
    )
    max_box = widgets.FloatText(
        value=high,
        description="max",
        layout=widgets.Layout(width="50%"),
    )

    def sync_from_slider(change):
        if state["updating"]:
            return
        state["updating"] = True
        try:
            lo, hi = change["new"]
            min_box.value = lo
            max_box.value = hi
        finally:
            state["updating"] = False

    def sync_from_min(change):
        if state["updating"]:
            return
        value = change["new"]
        if value is None:
            return
        state["updating"] = True
        try:
            new_min = float(value)
            # Widen the slider so a typed value outside its limits is representable.
            if new_min < slider.min:
                slider.min = new_min
            if new_min >= slider.max:
                slider.max = new_min + 0.1
            new_max_input = max_box.value
            new_max = float(new_max_input) if new_max_input is not None else new_min + 0.001
            # Keep max strictly above min.
            if new_max <= new_min:
                new_max = new_min + 0.001
            if new_max > slider.max:
                slider.max = new_max
            slider.value = (new_min, new_max)
            max_box.value = new_max
        finally:
            state["updating"] = False

    def sync_from_max(change):
        if state["updating"]:
            return
        value = change["new"]
        if value is None:
            return
        state["updating"] = True
        try:
            new_max = float(value)
            # Widen the slider so a typed value outside its limits is representable.
            if new_max > slider.max:
                slider.max = new_max
            if new_max <= slider.min:
                slider.min = new_max - 0.1
            new_min_input = min_box.value
            new_min = float(new_min_input) if new_min_input is not None else new_max - 0.001
            # Keep min strictly below max.
            if new_min >= new_max:
                new_min = new_max - 0.001
            if new_min < slider.min:
                slider.min = new_min
            slider.value = (new_min, new_max)
            min_box.value = new_min
        finally:
            state["updating"] = False

    slider.observe(sync_from_slider, names="value")
    min_box.observe(sync_from_min, names="value")
    max_box.observe(sync_from_max, names="value")

    container = widgets.VBox(
        [
            widgets.Label(dataset_name),
            slider,
            widgets.HBox([min_box, max_box]),
        ],
        layout=widgets.Layout(width="95%"),
    )
    return {
        "slider": slider,
        "min_box": min_box,
        "max_box": max_box,
        "container": container,
        "state": state,
    }


def _refresh_range_control(
    control: Dict[str, object],
    low: float,
    high: float,
    slider_min: float,
    slider_max: float,
) -> None:
    """Widen an existing control's slider for new data while keeping the user's chosen range."""
    slider = control["slider"]
    min_box = control["min_box"]
    max_box = control["max_box"]
    state = control["state"]
    state["updating"] = True
    try:
        # Only ever grow the slider limits so a previously chosen range stays valid.
        slider.min = min(slider_min, slider.min)
        slider.max = max(slider_max, slider.max)
        slider.step = 0.01
        current_min = min_box.value if min_box.value is not None else low
        current_max = max_box.value if max_box.value is not None else high
        if current_min >= current_max:
            current_max = current_min + 0.001
        slider.value = (float(current_min), float(current_max))
        # Mirror back whatever value the slider ended up holding (it may adjust
        # the assignment to its limits).
        min_box.value = slider.value[0]
        max_box.value = slider.value[1]
    finally:
        state["updating"] = False


def update_range_widgets(*_):
    """Rebuild the per-dataset y-axis range panel for the current selections.

    Positional arguments are accepted and ignored, so this can be called
    directly or from an observer. For each dataset found in the selected
    files, a labelled range slider and min/max boxes are shown. Bounds come
    from the selected metrics, or from all metrics when none are selected.

    - Controls for datasets that stay selected are reused, so user-edited
      ranges survive. Their slider limits only ever widen.
    - Controls for datasets that disappear from the selection are dropped.
    """
    selected_paths = list(file_selector.value)
    selected_metrics = list(metric_selector.value)

    if not selected_paths:
        _show_range_message("Select results to configure y-axis ranges.")
        return

    bounds = gather_dataset_bounds(selected_paths, selected_metrics)
    if not bounds:
        _show_range_message("No datasets found for the current selection.")
        return

    children = []
    for dataset_name in sorted(bounds):
        low, high = bounds[dataset_name]
        slider_min, slider_max = _padded_slider_limits(low, high)
        control = range_controls.get(dataset_name)
        if control is None:
            control = _make_range_control(dataset_name, low, high, slider_min, slider_max)
            range_controls[dataset_name] = control
        else:
            _refresh_range_control(control, low, high, slider_min, slider_max)
        children.append(control["container"])

    # Forget controls for datasets no longer present in the selection.
    for dataset_name in [name for name in range_controls if name not in bounds]:
        range_controls.pop(dataset_name)

    range_container.children = tuple(children)


def on_plot_button_clicked(_):
    """Handle a Plot click: read the selections and y-range boxes, then redraw into `output`.

    Range boxes left empty (``None``) are ignored for their dataset. When a
    box's min is not below its max, max is nudged to ``min + 0.001``.
    """
    with output:
        clear_output(wait=True)
        chosen_paths = list(file_selector.value)
        if not chosen_paths:
            print("Please select at least one results file.")
            return
        chosen_metrics = list(metric_selector.value)
        limits = {}
        for dataset_name, control in range_controls.items():
            lo_raw, hi_raw = control["min_box"].value, control["max_box"].value
            if lo_raw is None or hi_raw is None:
                continue
            lo, hi = float(lo_raw), float(hi_raw)
            limits[dataset_name] = (lo, lo + 0.001 if lo >= hi else hi)
        plot_accuracy_curves(chosen_paths, metrics=chosen_metrics, y_limits=limits)


def _on_selection_change(change):
    """Selector observer: rebuild the y-range controls only for ``value`` changes."""
    if change.get("name") != "value":
        return
    update_range_widgets()


# Rebuild the y-range panel whenever either selector's value changes.
for selector in (file_selector, metric_selector):
    selector.observe(_on_selection_change, names="value")

# Populate the range panel once for the initial selection.
update_range_widgets()

plot_button.on_click(on_plot_button_clicked)

# Layout: selectors and range panel side by side, then the button, then plot output.
for panel in (
    widgets.HBox([file_selector, metric_selector, range_container]),
    plot_button,
    output,
):
    display(panel)
