In [1]:
import json, re, pandas as pd
from collections import Counter

In [2]:
NUM_RE = re.compile(r"[+-]?\d+(?:/\d+)?")

In [3]:
def _latex_to_plain(s: str) -> str:
    s = re.sub(r"\\boxed{([^}]*)}", r"\1", s)                           # \boxed{…}
    s = re.sub(r"\\(?:d)?frac{([^}]*)}{([^}]*)}", r"\1/\2", s)          # \frac{a}{b}
    return s.replace(",", "").lstrip("$€£ ").strip()

In [4]:
def extract_numeric(text: str | None):
    """Return final numeric token - tries <answer>… first, else last number."""
    if not text:
        return None
    # closed tag
    m = re.search(r"<answer>(.*?)</answer>", text, re.I | re.S)
    if m and (n := NUM_RE.search(_latex_to_plain(m.group(1)))):
        return n.group(0)
    # open tag
    m = re.search(r"<answer>(.*)$", text, re.I | re.S)
    if m and (n := NUM_RE.search(_latex_to_plain(m.group(1)))):
        return n.group(0)
    # number before </answer>
    m = re.search(r"([+-]?\d+(?:/\d+)?)\s*</answer>", text, re.I | re.S)
    if m:
        return _latex_to_plain(m.group(1))
    # fallback
    nums = NUM_RE.findall(_latex_to_plain(text))
    return nums[-1] if nums else None

In [5]:
def _load_jsonl(path, tolerant=False):
    bad = 0; items = []
    with open(path, encoding="utf-8") as f:
        for ln_no, ln in enumerate(f, 1):
            ln = ln.strip()
            if not ln:
                continue
            try:
                items.append(json.loads(ln))
            except json.JSONDecodeError as e:
                bad += 1
                if not tolerant:
                    raise RuntimeError(f"{path}:{ln_no}\n{e}")
    if bad:
        print(f"[warn] skipped {bad} malformed lines")
    return items


In [6]:
def _load_jsonflex(path):
    """Accepts JSON-Lines *or* a single JSON array."""
    with open(path, encoding="utf-8") as f:
        first = f.read(1)
        f.seek(0)
        # case 1: JSON array ― file starts with “[”
        if first == "[":
            return json.load(f)            # returns list[dict]
        # case 2: JSONL ― stream line-by-line
        items = []
        for ln in f:
            ln = ln.strip()
            if ln:
                try:
                    items.append(json.loads(ln))
                except json.JSONDecodeError:
                    pass                   # or raise
        return items


In [7]:
def evaluate(pred_path: str, gold_path: str, show_errors: int = 5) -> float:
    pred, gold = _load_jsonflex(pred_path), _load_jsonflex(gold_path)
    if len(pred) != len(gold):
        print(f"[warn] len(pred)={len(pred)} ≠ len(gold)={len(gold)}")
    rows, correct = [], 0
    for i, (p, g) in enumerate(zip(pred, gold)):
        p_ans = extract_numeric(p.get("raw") or p.get("prediction"))
        g_ans = extract_numeric(g.get("answer") or g.get("solution"))
        ok = (p_ans == g_ans)
        correct += ok
        if not ok and len(rows) < show_errors:
            rows.append({"idx": i, "pred": p_ans, "gold": g_ans})
    acc = correct / len(pred) if pred else 0.0
    print(f"Accuracy: {correct}/{len(pred)}  ({acc*100:.2f}%)")
    if rows:
        display(pd.DataFrame(rows))
    return acc

In [8]:
evaluate("results/gsm_COT_20250710_0040.jsonl",  "data/benchmarks/gsm8k/test.jsonl")

Accuracy: 1085/1319  (82.26%)


Unnamed: 0,idx,pred,gold
0,5,50,64
1,7,120,160
2,8,180,45
3,12,21,13
4,15,4875,125


0.8225928733889311

In [9]:
evaluate("results/math_COT_20250710_0040.jsonl","data/benchmarks/math_test_all.jsonl")

Accuracy: 1630/5000  (32.60%)


Unnamed: 0,idx,pred,gold
0,0,1/11,11
1,2,11200,70
2,3,61,83
3,4,1230,1440
4,6,2,18


0.326

In [10]:
evaluate("results/gsm_STATIC_COT_20250712_1548.jsonl",  "data/benchmarks/gsm8k/test.jsonl")

Accuracy: 948/1319  (71.87%)


Unnamed: 0,idx,pred,gold
0,7,120,160
1,8,355,45
2,13,36,18
3,15,96,125
4,16,167,230


0.7187263078089462

In [11]:
evaluate("results/gsm_STATIC_COT_20250713_0140.jsonl",  "data/benchmarks/gsm8k/test.jsonl")

Accuracy: 948/1319  (71.87%)


Unnamed: 0,idx,pred,gold
0,7,120,160
1,8,355,45
2,13,36,18
3,15,96,125
4,16,167,230


0.7187263078089462

In [12]:
evaluate("results/gsm_STATIC_COT_20250713_1117.jsonl",  "data/benchmarks/gsm8k/test.jsonl")

Accuracy: 1029/1319  (78.01%)


Unnamed: 0,idx,pred,gold
0,4,3,20
1,7,120,160
2,12,12,13
3,17,450,57500
4,20,67,15


0.78013646702047

In [13]:
evaluate("results/gsm_DYNAMIC_COT_20250716_0825.jsonl",  "data/benchmarks/gsm8k/test.jsonl")

Accuracy: 1095/1319  (83.02%)


Unnamed: 0,idx,pred,gold
0,5,29,64
1,7,100,160
2,8,175,45
3,12,12,13
4,16,190,230


0.8301743745261562

In [14]:
evaluate("results/gsm_DYNAMIC_COT_20250715_1710.jsonl",  "data/benchmarks/gsm8k/test.jsonl")

Accuracy: 1064/1319  (80.67%)


Unnamed: 0,idx,pred,gold
0,2,50000,70000
1,5,50,64
2,7,0,160
3,8,0,45
4,12,12,13


0.8066717210007581