In [4]:
from collections import defaultdict
import json


def load_metrics(file):
    """Load a flat JSON metrics file and regroup its entries by data source.

    Keys are expected to look like ``'<metric>/<data_source>'`` where
    ``<metric>`` is one of ``known_metrics`` below. Keys that do not start with
    any known ``'<metric>/'`` prefix are reported with a message and skipped.

    The whole file is parsed as a single JSON document with ``json.load``
    (callers in this notebook pass ``*.jsonl`` paths; each is still read as one
    JSON object).

    Parameters
    ----------
    file : str | os.PathLike
        Path of the metrics file.

    Returns
    -------
    dict
        ``{metric: value}`` when exactly one data source was found, otherwise a
        ``defaultdict`` of ``{data_source: {metric: value}}`` (empty if nothing
        matched).
    """
    known_metrics = ['test_score_mean', 'test_score_sem', 'test_score_count', 'test_score_values']
    metrics_parsed = defaultdict(dict)
    with open(file, 'rt') as f:
        metrics_raw = json.load(f)
    for key, value in metrics_raw.items():
        for known in known_metrics:
            prefix = f'{known}/'
            if key.startswith(prefix):
                data_source = key[len(prefix):]
                metrics_parsed[data_source][known] = value
                break
        else:
            # for/else: reached only when no known prefix matched *this* key.
            # (A shared `matched` flag was previously never initialised or reset.)
            print(f'skipping {key}, unknown how to parse')
    if len(metrics_parsed) == 1:
        # Single data source: drop the extra nesting level.
        return next(iter(metrics_parsed.values()))
    return metrics_parsed

In [5]:
from pathlib import Path

def load_line(linespec):
    """Collect every metrics file of one experiment line into a nested dict.

    Files are searched under ``generations/<linespec.folder>`` with the glob
    ``**/<linespec.step>/**/*.parquet.<linespec.reward_fn>.jsonl``. The three
    directories directly above each file become the ``spec``, ``size`` and
    ``diff`` keys; the part of the file name before its first ``.`` becomes
    the split key.

    Returns an ``addict.Dict`` indexed ``[spec][size][diff][split]`` whose
    leaves are whatever ``load_metrics`` returns for that file.
    """
    base_dir = Path('generations', linespec.folder)
    pattern = f'**/{linespec.step}/**/*.parquet.{linespec.reward_fn}.jsonl'
    nested = addict.Dict()
    for metrics_path in base_dir.glob(pattern):
        *_, spec, size, diff, _filename = metrics_path.parts
        split_name = metrics_path.stem.partition('.')[0]
        nested[spec][size][diff][split_name] = load_metrics(metrics_path)
    return nested

def load_lines(lines):
    """Load every experiment line and key the results by each line's ``name``.

    Each element of ``lines`` is passed to ``load_line``; if two elements share
    a ``name``, the later one overwrites the earlier one.
    """
    by_name = addict.Dict()
    for line_spec in lines:
        by_name[line_spec.name] = load_line(line_spec)
    return by_name


In [6]:
from typing import List, Tuple
from __future__ import annotations
def latex_table(col_headers: List[str],
                row_headers: List[str],
                rows: List[List[Tuple[float, float]]],
                digits: int = 2,
                *,
                times100: bool = False,
                caption: str | None = None,
                label: str | None = None,
                use_booktabs: bool = True,
                escape_underscores: bool = True,
                table_font_size: str | None = None
                ) -> str:
    """
    Build a LaTeX table from lists of (mean, spread) tuples.

    Each cell is rendered as ``<mean> {\\tiny$\\pm$ <spread>}``; the caller
    decides what *spread* is (e.g. a standard deviation or 2*SEM).

    Features
    --------
    • Escapes underscores in column/row headers and the caption
      (other LaTeX special characters are passed through unchanged).
    • Shows “± spread” in \\tiny.
    • Bold-faces the best mean in each row (highest value; the first one on
      ties). NaN means are never bolded; a row whose means are all NaN, or a
      table without columns, gets no bold cell.
    • Optional ×100 percentage display.
    • Optional global font-size for the *whole table* (table_font_size).

    Parameters
    ----------
    col_headers : list of str
        Column titles; an empty header cell is added above the row titles.
    row_headers : list of str
        One title per row; must have the same length as ``rows``.
    rows : list of list of (float, float)
        ``rows[i][j]`` is the (mean, spread) pair for row i, column j.
    digits : int
        Number of decimals shown for both mean and spread.
    times100 : bool
        Multiply mean and spread by 100 before formatting.
    caption, label : str | None
        If either is given, the tabular is wrapped in a ``table`` float with
        ``[ht]`` placement and ``\\centering``. The label is not escaped.
    use_booktabs : bool
        Use ``\\toprule``/``\\midrule``/``\\bottomrule`` (booktabs package);
        otherwise ``\\hline`` for all three rules.
    escape_underscores : bool
        Replace ``_`` with ``\\_`` in headers and caption.
    table_font_size : str | None
        LaTeX size command *without backslash*, e.g. 'small', 'footnotesize'.
        Wrapped around the tabular in a local group { ... }.

    Returns
    -------
    str
        The LaTeX source, without a trailing newline.

    Raises
    ------
    ValueError
        If ``len(row_headers) != len(rows)`` or any row does not contain
        exactly ``len(col_headers)`` tuples.
    """
    import math  # only needed for the NaN check in _best_index

    if len(row_headers) != len(rows):
        raise ValueError("row_headers length must equal number of row lists")
    for r in rows:
        if len(r) != len(col_headers):
            raise ValueError("each row must have len(col_headers) tuples")

    # ---------- helpers -----------------------------------------------------
    esc = (lambda s: s.replace('_', r'\_')) if escape_underscores else (lambda s: s)
    fmt = f"{{:.{digits}f}}"

    def _fmt(x: float) -> str:
        return fmt.format(x * 100 if times100 else x)

    def _cell(mean: float, std: float, bold=False) -> str:
        txt = f"{_fmt(mean)} {{\\tiny$\\pm$ {_fmt(std)}}}"
        return rf"\textbf{{{txt}}}" if bold else txt

    def _best_index(row):
        # Index of the first maximal non-NaN mean, or None if there is none.
        # A plain max() over all indices can return a leading NaN (every
        # comparison against NaN is False) and fails on an empty row.
        candidates = [i for i, (mean, _spread) in enumerate(row) if not math.isnan(mean)]
        return max(candidates, key=lambda i: row[i][0]) if candidates else None

    # ---------- build tabular ----------------------------------------------
    border = (r"\toprule", r"\midrule", r"\bottomrule") if use_booktabs \
             else (r"\hline",) * 3
    col_spec = "l" + "c" * len(col_headers)

    lines = [
        border[0],
        " & ".join([""] + [esc(h) for h in col_headers]) + r" \\",
        border[1],
    ]

    for rh, row in zip(row_headers, rows):
        best_idx = _best_index(row)
        cells = [_cell(m, s, j == best_idx) for j, (m, s) in enumerate(row)]
        lines.append(" & ".join([esc(rh)] + cells) + r" \\")
    lines.append(border[2])

    tabular = '\n'.join([
        rf"\begin{{tabular}}{{{col_spec}}}",
        *lines,
        r"\end{tabular}"
    ])

    # ---------- optional font-size wrapper ----------------------------------
    if table_font_size:
        tabular = f"{{\\{table_font_size}\n{tabular}\n}}"

    # ---------- optional full table environment -----------------------------
    if caption or label:
        parts = [r"\begin{table}[ht]", r"\centering", tabular]
        if caption:
            parts.append(rf"\caption{{{esc(caption)}}}")
        if label:
            parts.append(rf"\label{{{label}}}")
        parts.append(r"\end{table}")
        return '\n'.join(parts)

    return tabular


In [7]:
import yaml
import addict

# Read the experiment-line specs, wrap each in an addict.Dict, and load all metrics.
with open('lines.yaml', 'rt') as spec_file:
    line_specs_raw = yaml.safe_load(spec_file)
lines = [addict.Dict(entry) for entry in line_specs_raw]
data = load_lines(lines)
print(data.keys())

dict_keys(['specific_prompt_strict', 'specific_prompt_lenient', 'generic_prompt_strict', 'generic_prompt_lenient', 'specific_prompt_filter', 'generic_prompt_filter', 'specific_prompt_sft_loss', 'generic_prompt_sft_loss', 'specific_prompt_sft_reward', 'generic_prompt_sft_reward'])


In [None]:

from matplotlib import pyplot as plt
import numpy as np


# Per-line display settings (insertion order drives legend order and table
# column order), plus shared y-limits and the alpha of the SEM shading.
# Columns: data key, legend name, short table header, colour, line style.
plot_config = {
    "lines": {
        key: {'name': name, 'short': short, 'color': color, 'linestyle': style}
        for key, name, short, color, style in (
            ('specific_prompt_strict',     'rl_strict_pspec',  'rl_stct_spc', 'red',    'dashed'),
            ('specific_prompt_filter',     'rl_filter_pspec',  'rl_flt_spc',  'gold',   'dashed'),
            ('specific_prompt_lenient',    'rl_soft_pspec',    'rl_sof_spc',  'orange', 'dashed'),
            ('generic_prompt_strict',      'rl_strict_pgen',   'rl_stct_gen', 'red',    'solid'),
            ('generic_prompt_filter',      'rl_filter_pgen',   'rl_flt_gen',  'gold',   'solid'),
            ('generic_prompt_lenient',     'rl_soft_pgen',     'rl_sof_gen',  'orange', 'solid'),
            ('specific_prompt_sft_loss',   'sft_loss_pspec',   'sft_los_spc', 'blue',   'dashed'),
            ('generic_prompt_sft_loss',    'sft_loss_pgen',    'sft_los_gen', 'green',  'solid'),
            ('specific_prompt_sft_reward', 'sft_reward_pspec', 'sft_rwd_spc', 'purple', 'dashed'),
            ('generic_prompt_sft_reward',  'sft_reward_pgen',  'sft_rwd_gen', 'brown',  'solid'),
        )
    },
    'ymin': 0.0,
    'ymax': 1.0,
    'shade_alpha': 0.2,
}

def round_err(dat, decimals: int = 2):
    """Round every value in ``dat`` to ``decimals`` decimal places.

    Used on scores before they are compared against thresholds near 1, so e.g.
    0.99999 becomes 1.0 (presumably to absorb small numerical error in the
    reported scores — confirm with whoever produces them).

    Parameters
    ----------
    dat : iterable of float
        Values to round.
    decimals : int, default 2
        Number of decimal places to keep. The default reproduces the
        original behaviour, ``round(100 * d) / 100``.

    Returns
    -------
    list of float
    """
    scale = 10 ** decimals
    return [round(scale * value) / scale for value in dat]

# Threshold curves: for every (spec, size, difficulty) cell, plot the fraction of
# dev examples whose rounded score exceeds a threshold x, one curve per line.
from pathlib import Path

# Difficulties available per spec (n5v2 only goes up to d3).
PLOT_DIFFICULTIES = {
    'n10v2': ['d0', 'd1', 'd2', 'd3', 'd4', 'd5'],
    'n15v2': ['d0', 'd1', 'd2', 'd3', 'd4', 'd5'],
    'n5v2': ['d0', 'd1', 'd2', 'd3'],
}
PLOT_SIZES = ['2k']
N_THRESHOLDS = 100


def plot_threshold_curves(spec, size, d, tmin, tmax, subname, out_dir='plots'):
    """Plot threshold-accuracy curves for one (spec, size, difficulty) and save a PDF.

    For each line in ``plot_config['lines']``, the dev ``test_score_values``
    are rounded with ``round_err`` and compared against thresholds
    ``x = 1 - logspace(log10(tmin), log10(tmax), N_THRESHOLDS)``. The curve is
    the fraction of examples whose score is strictly greater than ``x``, with a
    ±1 SEM band.

    Writes ``<out_dir>/dev.<spec>.<size>.<d>.<subname>.threshold.pdf``, creating
    ``out_dir`` if needed. Raises ValueError if a configured line has no data.
    """
    # Thresholds are log-spaced in t = 1 - x, so they are dense close to 1.
    x = 1 - np.logspace(np.log10(tmin), np.log10(tmax), N_THRESHOLDS)
    fig, ax = plt.subplots()
    n_examples = None
    for line, line_config in plot_config['lines'].items():
        if len(data[line]) == 0:
            raise ValueError(line)
        dev_metrics = data[line][spec][size][d]['dev']
        scores = np.array(round_err(dev_metrics['test_score_values']))
        # hits[i, j] is True when example j clears threshold x[i].
        hits = scores[None, :] > x[:, None]
        y_mean = hits.mean(-1)
        y_sem = hits.std(-1) / np.sqrt(hits.shape[-1])
        ax.plot(x, y_mean,
                linestyle=line_config['linestyle'],
                color=line_config['color'],
                label=line_config['name'])
        ax.fill_between(x, y_mean - y_sem, y_mean + y_sem,
                        color=line_config['color'], alpha=plot_config['shade_alpha'])
        # As before, the title reports n from the last configured line.
        n_examples = dev_metrics['test_score_count']
    ax.legend()
    ax.set_xlabel('threshold')
    ax.set_ylabel('accuracy')
    ax.set_title(f'{spec} {d} (n={n_examples})')
    ax.set_ylim(plot_config['ymin'], plot_config['ymax'])
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path / f'dev.{spec}.{size}.{d}.{subname}.threshold.pdf')
    # Close explicitly: plt.figure() + plt.clf() left every figure open.
    plt.close(fig)


for tmin, tmax, subname in [(0.001, 0.1, "gt0.9"), (0.001, 1.0, "full")]:
    for spec, difficulties in PLOT_DIFFICULTIES.items():
        for size in PLOT_SIZES:
            for d in difficulties:
                plot_threshold_curves(spec, size, d, tmin, tmax, subname)

# Summary table: accuracy at a single threshold. One row per spec/difficulty,
# one column per configured line (same order as plot_config['lines']).
ttable = 0.01  # a score must exceed 1 - ttable to count as correct
TABLE_DIFFICULTIES = {
    'n5v2': ['d0', 'd1', 'd2', 'd3'],
    'n10v2': ['d0', 'd1', 'd2', 'd3', 'd4', 'd5'],
    'n15v2': ['d0', 'd1', 'd2', 'd3', 'd4', 'd5'],
}
table_rows = []
table_row_headers = []
table_column_headers = [line['short'] for line in plot_config['lines'].values()]

for spec, difficulties in TABLE_DIFFICULTIES.items():
    for size in ['2k']:
        for d in difficulties:
            row = []
            for line in plot_config['lines']:
                scores = np.array(round_err(data[line][spec][size][d]['dev']['test_score_values']))
                correct = scores > (1 - ttable)
                y_mean = correct.mean(-1)
                y_sem = correct.std(-1) / np.sqrt(correct.shape[-1])
                # Spread column is 2*SEM (reported in the caption as a 95% CI).
                row.append((float(y_mean), 2 * float(y_sem)))
            table_rows.append(row)
            table_row_headers.append(f'{spec}-{d}')

# Raw f-strings: a plain "\textsc" would turn "\t" into a TAB character.
table_caption = (
    rf"RL vs SFT, accuracy (average \textsc{{correct}}) at threshold ${ttable:g}$, "
    rf"see \cref{{eqn:correct}}. The system with the best mean in each row is bolded. "
    r"Range denotes 95\% confidence interval."
)
print(latex_table(table_column_headers, table_row_headers, table_rows, times100=True,
                  caption=table_caption, label="tab:rl_vs_sft", table_font_size='tiny'))