# Search-Augmented ACE: PUCT Q-Estimator Ablation

**Hypothesis**: Tree search (PUCT) over ACE-style playbooks can outperform greedy sequential evolution, but the Q estimator determines whether the tree actually explores or degenerates into a chain.

**Four conditions** (same 50 problems and the same number of generate/reflect calls; the number of curate calls depends on the strategy):
1. **Greedy ACE** — sequential generate → reflect → curate (baseline)
2. **PUCT-Mean** — Q = mean reward. Standard. Biased by early results.
3. **PUCT-EMA** — Q = exponential moving average (α=0.4). Reacts to recent performance.
4. **PUCT-Bayesian** — Q = Beta posterior mean. Shrinks toward 0.5 when data is scarce, which makes it the most exploratory of the three.

**Key fix**: Progressive widening. A node accepts a new child only while it has fewer than `ceil(visits^0.5)` children, so the number of branches grows with how often the node is visited instead of every step extending a single chain.
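
As a rough illustration of the widening rule, the sketch below prints how many children a node may hold at a given visit count, using the same `max(1, ceil(visits^alpha))` cap as `should_expand` later in the notebook. The helper name `max_children` is introduced here for illustration only and is not part of the experiment code.

```python
import math

def max_children(visits: int, pw_alpha: float = 0.5) -> int:
    """Child cap under progressive widening (illustrative helper, not used by the notebook)."""
    return max(1, math.ceil(visits ** pw_alpha))

# Show how the cap grows with the visit count.
for visits in [0, 1, 2, 4, 5, 9, 10]:
    print(f"visits={visits:2d} -> up to {max_children(visits)} children")
```

Because of the ceiling, the cap already rises to 2 at the second visit, so a node can start branching early in the run.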

**Setup**: Qwen2.5-7B-Instruct via vLLM on A100, 50 GSM8K problems.

## 1. Setup & Dependencies

In [None]:
%%capture
!pip install vllm openai datasets matplotlib numpy

In [None]:
import subprocess
import time
import os
import signal
import json
import re
import copy
import math
import random
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
from openai import OpenAI

SEED = 42
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# Launch vLLM server in background
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
VLLM_PORT = 8000

vllm_proc = subprocess.Popen(
    [
        "python", "-m", "vllm.entrypoints.openai.api_server",
        "--model", MODEL_NAME,
        "--port", str(VLLM_PORT),
        "--max-model-len", "4096",
        "--gpu-memory-utilization", "0.85",
        "--dtype", "auto",
    ],
    # Log to a file: a PIPE that is never read can fill its buffer and stall the server.
    stdout=open("vllm_server.log", "w"),
    stderr=subprocess.STDOUT,
)
print(f"vLLM server PID: {vllm_proc.pid}")

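# Wait for server to be ready. Loading a 7B model can take longer than a minute,
# so allow a generous timeout and stop early if the server process has died.
STARTUP_TIMEOUT_S = 600
client = OpenAI(base_url=f"http://localhost:{VLLM_PORT}/v1", api_key="dummy")
for attempt in range(STARTUP_TIMEOUT_S):
    if vllm_proc.poll() is not None:
        raise RuntimeError(f"vLLM server exited early with code {vllm_proc.returncode}")
    try:
        client.models.list()
        print(f"vLLM ready after ~{attempt + 1}s")
        break
    except Exception:
        time.sleep(1)
else:
    raise RuntimeError(f"vLLM server did not start within {STARTUP_TIMEOUT_S}s")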

## 2. GSM8K Data Loading & Parsing

In [None]:
from datasets import load_dataset

ds = load_dataset("openai/gsm8k", "main", split="test")

def extract_gsm8k_answer(answer_text: str) -> str:
    """Extract numeric answer from GSM8K '#### <number>' format."""
    match = re.search(r"####\s*(-?[\d,]+\.?\d*)", answer_text)
    if match:
        return match.group(1).replace(",", "").strip()
    # Fallback: last number in text
    nums = re.findall(r"-?[\d,]+\.?\d*", answer_text)
    if nums:
        return nums[-1].replace(",", "")
    return ""

problems = []
for item in ds:
    problems.append({
        "question": item["question"],
        "answer": extract_gsm8k_answer(item["answer"]),
        "full_answer": item["answer"],
    })

# Shuffle deterministically, then keep the first 50 problems of the shuffled list
rng = random.Random(SEED)
rng.shuffle(problems)
problems = problems[:50]
print(f"Loaded {len(problems)} GSM8K problems")
print(f"Example: Q='{problems[0]['question'][:80]}...' A={problems[0]['answer']}")

## 3. Core Components: Playbook, Generator, Reflector, Curator

In [None]:
# --- Playbook representation ---

@dataclass
class Bullet:
    id: str
    section: str  # STRATEGIES, COMMON_MISTAKES, SOLUTION_PATTERNS
    content: str
    helpful: int = 0
    harmful: int = 0

    def to_str(self) -> str:
        return f"[{self.id}] helpful={self.helpful} harmful={self.harmful} :: {self.content}"


@dataclass
class Playbook:
    bullets: List[Bullet] = field(default_factory=list)
    _next_id: int = 1

    def add(self, section: str, content: str) -> str:
        prefix = {"STRATEGIES": "str", "COMMON_MISTAKES": "err", "SOLUTION_PATTERNS": "sol"}.get(section, "gen")
        bid = f"{prefix}-{self._next_id:05d}"
        self._next_id += 1
        self.bullets.append(Bullet(id=bid, section=section, content=content))
        return bid

    def remove(self, bid: str):
        self.bullets = [b for b in self.bullets if b.id != bid]

    def update(self, bid: str, content: str):
        for b in self.bullets:
            if b.id == bid:
                b.content = content
                return

    def tag(self, bid: str, label: str):
        for b in self.bullets:
            if b.id == bid:
                if label == "helpful":
                    b.helpful += 1
                elif label == "harmful":
                    b.harmful += 1

    def to_str(self) -> str:
        sections = defaultdict(list)
        for b in self.bullets:
            sections[b.section].append(b.to_str())
        parts = []
        for sec in ["STRATEGIES", "COMMON_MISTAKES", "SOLUTION_PATTERNS"]:
            if sections[sec]:
                parts.append(f"## {sec}")
                parts.extend(sections[sec])
        return "\n".join(parts) if parts else "(empty playbook)"

    def copy(self) -> "Playbook":
        return copy.deepcopy(self)

    @property
    def size(self) -> int:
        return len(self.bullets)


def make_initial_playbook() -> Playbook:
    pb = Playbook()
    pb.add("STRATEGIES", "Break word problems into step-by-step arithmetic.")
    pb.add("STRATEGIES", "Identify what quantity the question asks for before computing.")
    pb.add("COMMON_MISTAKES", "Watch for unit conversions (hours to minutes, etc).")
    return pb

print(make_initial_playbook().to_str())

In [None]:
# --- LLM call wrapper ---

call_counter = defaultdict(int)  # track calls by role

def llm_call(system: str, user: str, role: str = "generate", temperature: float = 0.7, max_tokens: int = 1024) -> str:
    """Single LLM call via vLLM OpenAI-compatible API."""
    call_counter[role] += 1
    try:
        resp = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return resp.choices[0].message.content.strip()
    except Exception as e:
        print(f"LLM call failed ({role}): {e}")
        return ""

def reset_call_counter():
    global call_counter
    call_counter = defaultdict(int)

def get_call_counts() -> Dict[str, int]:
    return dict(call_counter)

In [None]:
# --- Generator ---

def generate(question: str, playbook: Playbook) -> Tuple[str, List[str], str]:
    """
    Generate a solution to a math problem using the playbook.
    Returns (extracted_answer, bullets_used, raw_response).
    """
    pb_text = playbook.to_str()
    system = (
        "You are a math problem solver. Use the playbook strategies below to help solve the problem.\n"
        "When you use a specific strategy, mention its ID (e.g., [str-00001]).\n"
        "Show your work step-by-step, then give the final numeric answer on its own line as: #### <number>\n\n"
        f"PLAYBOOK:\n{pb_text}"
    )
    user = f"Solve this problem:\n{question}"
    raw = llm_call(system, user, role="generate")

    # Extract answer
    answer = ""
    m = re.search(r"####\s*(-?[\d,]+\.?\d*)", raw)
    if m:
        answer = m.group(1).replace(",", "").strip()
    else:
        nums = re.findall(r"-?[\d,]+\.?\d*", raw)
        if nums:
            answer = nums[-1].replace(",", "")

    # Extract bullet references
    bullets_used = re.findall(r"\[(\w+-\d+)\]", raw)

    return answer, bullets_used, raw

In [None]:
# --- Answer comparison ---

def answers_match(predicted: str, ground_truth: str) -> bool:
    """Compare numeric answers with tolerance."""
    try:
        p = float(predicted.replace(",", ""))
        g = float(ground_truth.replace(",", ""))
        return abs(p - g) < 1e-3
    except (ValueError, TypeError):
        return predicted.strip() == ground_truth.strip()

In [None]:
# --- Reflector ---

def reflect(question: str, raw_response: str, predicted: str, ground_truth: str,
            bullets_used: List[str], playbook: Playbook) -> Tuple[str, Dict[str, str]]:
    """
    Reflect on the solution attempt.
    Returns (reflection_text, bullet_tags) where bullet_tags maps bullet_id -> 'helpful'|'harmful'|'neutral'.
    """
    correct = answers_match(predicted, ground_truth)
    feedback = "CORRECT" if correct else f"INCORRECT (predicted {predicted}, expected {ground_truth})"

    bullets_text = ""
    for b in playbook.bullets:
        if b.id in bullets_used:
            bullets_text += f"  {b.to_str()}\n"

    system = (
        "You are a math reasoning analyst. Analyze whether the solution approach was correct "
        "and whether the playbook strategies used were helpful or harmful.\n"
        "For each bullet ID used, output a JSON line: {\"id\": \"str-00001\", \"tag\": \"helpful\"}\n"
        "Tags must be one of: helpful, harmful, neutral.\n"
        "End with a brief reflection paragraph."
    )
    user = (
        f"Problem: {question}\n\n"
        f"Solution attempt:\n{raw_response}\n\n"
        f"Result: {feedback}\n\n"
        f"Playbook bullets used:\n{bullets_text}"
    )
    raw = llm_call(system, user, role="reflect", temperature=0.3)

    # Parse bullet tags
    tags = {}
    for m in re.finditer(r'"id"\s*:\s*"([^"]+)".*?"tag"\s*:\s*"(helpful|harmful|neutral)"', raw):
        tags[m.group(1)] = m.group(2)

    # If no tags parsed but we know correctness, apply heuristic
    if not tags:
        for bid in bullets_used:
            tags[bid] = "helpful" if correct else "neutral"

    return raw, tags

In [None]:
# --- Curator ---

MAX_BULLETS = 20

def curate(playbook: Playbook, reflection: str, question: str) -> Playbook:
    """
    Curate the playbook based on reflection.
    Returns a new (copied) playbook with operations applied.
    """
    pb = playbook.copy()
    pb_text = pb.to_str()

    system = (
        "You are a playbook curator for math problem solving. Based on the reflection, "
        "propose operations to improve the playbook.\n"
        "Output a JSON array of operations:\n"
        '[{"op": "ADD", "section": "STRATEGIES", "content": "new insight"},\n'
        ' {"op": "UPDATE", "id": "str-00001", "content": "refined text"},\n'
        ' {"op": "DELETE", "id": "err-00002"}]\n'
        f"Sections: STRATEGIES, COMMON_MISTAKES, SOLUTION_PATTERNS\n"
        f"Max bullets: {MAX_BULLETS}. Current: {pb.size}.\n"
        "Only propose operations that are clearly supported by the reflection. Keep it minimal."
    )
    user = (
        f"Current playbook:\n{pb_text}\n\n"
        f"Problem context: {question[:200]}\n\n"
        f"Reflection:\n{reflection}"
    )
    raw = llm_call(system, user, role="curate", temperature=0.3)

    # Parse operations from JSON array
    ops = []
    # Try to find JSON array in response
    json_match = re.search(r'\[\s*\{.*?\}\s*\]', raw, re.DOTALL)
    if json_match:
        try:
            ops = json.loads(json_match.group())
        except json.JSONDecodeError:
            pass

    # Apply operations
    for op in ops:
        try:
            if op.get("op") == "ADD" and pb.size < MAX_BULLETS:
                pb.add(op.get("section", "STRATEGIES"), op.get("content", ""))
            elif op.get("op") == "UPDATE" and op.get("id"):
                pb.update(op["id"], op.get("content", ""))
            elif op.get("op") == "DELETE" and op.get("id"):
                pb.remove(op["id"])
        except Exception:
            pass

    # Safety: if curator emptied the playbook, reset
    if pb.size == 0:
        pb = make_initial_playbook()
        pb._next_id = playbook._next_id

    return pb

## 4. Search Strategies

In [None]:
# --- Shared tracking ---

@dataclass
class RunLog:
    """Tracks per-problem results for a single condition."""
    correct: List[bool] = field(default_factory=list)
    playbook_sizes: List[int] = field(default_factory=list)
    call_counts: Dict[str, int] = field(default_factory=dict)
    final_playbook: Optional[Playbook] = None

    @property
    def running_accuracy(self) -> List[float]:
        acc = []
        total = 0
        for i, c in enumerate(self.correct):
            total += int(c)
            acc.append(total / (i + 1))
        return acc

    @property
    def final_accuracy(self) -> float:
        if not self.correct:
            return 0.0
        # Accuracy on last 20 problems
        tail = self.correct[-20:]
        return sum(tail) / len(tail)

In [None]:
# --- Strategy 1: Greedy ACE ---

CURATE_EVERY = 5

def run_greedy(problems: List[dict]) -> RunLog:
    """Sequential generate → reflect → curate loop."""
    reset_call_counter()
    log = RunLog()
    pb = make_initial_playbook()
    reflections_buffer = []

    for i, prob in enumerate(problems):
        answer, bullets_used, raw = generate(prob["question"], pb)
        correct = answers_match(answer, prob["answer"])
        log.correct.append(correct)
        log.playbook_sizes.append(pb.size)

        # Reflect
        reflection, tags = reflect(
            prob["question"], raw, answer, prob["answer"], bullets_used, pb
        )
        for bid, label in tags.items():
            pb.tag(bid, label)
        reflections_buffer.append((reflection, prob["question"]))

        # Curate periodically
        if (i + 1) % CURATE_EVERY == 0 and reflections_buffer:
            # Only the most recent reflection reaches the curator;
            # earlier reflections in the buffer are discarded.
            ref_text, ref_q = reflections_buffer[-1]
            pb = curate(pb, ref_text, ref_q)
            reflections_buffer = []

        if (i + 1) % 10 == 0:
            acc = sum(log.correct) / len(log.correct)
            print(f"  Greedy [{i+1}/{len(problems)}] acc={acc:.2%} bullets={pb.size}")

    log.call_counts = get_call_counts()
    log.final_playbook = pb
    return log

In [None]:
# --- Strategy 2: PUCT-ACE with Q-estimator variants ---

@dataclass
class MCTSNode:
    playbook: Playbook
    parent: Optional["MCTSNode"] = None
    children: List["MCTSNode"] = field(default_factory=list)
    visits: int = 0
    reward_history: List[float] = field(default_factory=list)
    results: List[bool] = field(default_factory=list)

    def q_value(self, mode: str = "mean") -> float:
        if not self.reward_history:
            return 0.5  # neutral prior for unvisited nodes
        if mode == "mean":
            return sum(self.reward_history) / len(self.reward_history)
        elif mode == "ema":
            # Exponential moving average — recent batches matter more
            alpha = 0.4
            q = self.reward_history[0]
            for r in self.reward_history[1:]:
                q = alpha * r + (1 - alpha) * q
            return q
        elif mode == "bayesian":
            # Posterior mean of Beta(s+1, n-s+1), treating each batch reward in [0, 1]
            # as a fractional success. Shrinks toward 0.5 with few observations.
            s = sum(self.reward_history)
            n = len(self.reward_history)
            return (s + 1) / (n + 2)
        return 0.5


def puct_select(node: MCTSNode, c_puct: float = 1.5, q_mode: str = "mean") -> MCTSNode:
    """Descend via PUCT to the first node that progressive widening still lets expand."""
    # Stopping only at childless nodes would make should_expand() always True,
    # and the tree would degenerate into a chain. Stop at any node under its cap.
    if should_expand(node):
        return node
    n_parent = node.visits
    best, best_score = None, -float("inf")
    prior = 1.0 / len(node.children)  # uniform prior over siblings
    for child in node.children:
        exploit = child.q_value(q_mode)
        explore = c_puct * prior * math.sqrt(n_parent) / (1 + child.visits)
        score = exploit + explore
        if score > best_score:
            best_score = score
            best = child
    return puct_select(best, c_puct, q_mode)


def backprop(node: MCTSNode, reward: float):
    """Backpropagate reward up the tree."""
    while node is not None:
        node.visits += 1
        node.reward_history.append(reward)
        node = node.parent


def should_expand(node: MCTSNode, pw_alpha: float = 0.5) -> bool:
    """Progressive widening: expand only when node needs more children.
    Branching factor grows as ceil(visits^alpha). With alpha=0.5:
      visits=1 → 1 child allowed
      visits=4 → 2 children
      visits=9 → 3 children
    This forces the tree to RE-EVALUATE existing nodes before branching.
    """
    max_children = max(1, math.ceil(node.visits ** pw_alpha))
    return len(node.children) < max_children


def run_puct(problems: List[dict], batch_size: int = 3, c_puct: float = 1.5,
             q_mode: str = "mean") -> RunLog:
    """PUCT tree search over playbook versions with progressive widening."""
    reset_call_counter()
    log = RunLog()
    root = MCTSNode(playbook=make_initial_playbook())
    prob_idx = 0

    while prob_idx < len(problems):
        # Select leaf via PUCT
        leaf = puct_select(root, c_puct, q_mode)

        # Evaluate on a batch using a COPY (don't mutate tree node)
        eval_pb = leaf.playbook.copy()
        batch_end = min(prob_idx + batch_size, len(problems))
        batch = problems[prob_idx:batch_end]
        batch_correct = []
        last_reflection = ""
        last_question = ""

        for prob in batch:
            answer, bullets_used, raw = generate(prob["question"], eval_pb)
            correct = answers_match(answer, prob["answer"])
            batch_correct.append(correct)
            log.correct.append(correct)
            log.playbook_sizes.append(eval_pb.size)

            reflection, tags = reflect(
                prob["question"], raw, answer, prob["answer"], bullets_used, eval_pb
            )
            for bid, label in tags.items():
                eval_pb.tag(bid, label)
            last_reflection = reflection
            last_question = prob["question"]

        prob_idx = batch_end
        reward = sum(batch_correct) / len(batch_correct)
        leaf.results.extend(batch_correct)

        # Progressive widening: only expand if this node needs more children
        if should_expand(leaf):
            new_pb = curate(eval_pb, last_reflection, last_question)
            child = MCTSNode(playbook=new_pb, parent=leaf)
            leaf.children.append(child)
            backprop(child, reward)
        else:
            # Just backprop to re-score existing node (no curate call)
            backprop(leaf, reward)

        # Print roughly every 10 problems: fires on the batch that crosses a multiple of 10
        if prob_idx % 10 < batch_size:
            acc = sum(log.correct) / len(log.correct)
            depth = _tree_depth(root)
            size = _tree_size(root)
            branches = _branch_count(root)
            print(f"  PUCT-{q_mode} [{prob_idx}/{len(problems)}] acc={acc:.2%} "
                  f"depth={depth} nodes={size} branches={branches}")

    log.call_counts = get_call_counts()
    best_leaf = _best_leaf(root, q_mode)
    log.final_playbook = best_leaf.playbook
    return log


def _tree_depth(node: MCTSNode) -> int:
    if not node.children:
        return 0
    return 1 + max(_tree_depth(c) for c in node.children)


def _tree_size(node: MCTSNode) -> int:
    return 1 + sum(_tree_size(c) for c in node.children)


def _branch_count(node: MCTSNode) -> int:
    """Count nodes with >1 child (actual branching points)."""
    count = 1 if len(node.children) > 1 else 0
    return count + sum(_branch_count(c) for c in node.children)


def _best_leaf(node: MCTSNode, q_mode: str = "mean") -> MCTSNode:
    if not node.children:
        return node
    best = max(node.children, key=lambda c: c.q_value(q_mode) if c.visits > 0 else -1)
    return _best_leaf(best, q_mode)

## 5. Experiment Runner

In [None]:
print("=" * 60)
print("Running Greedy ACE (baseline)...")
print("=" * 60)
greedy_log = run_greedy(problems)
print(f"Greedy done. Final acc (last 20): {greedy_log.final_accuracy:.2%}")
print(f"Calls: {greedy_log.call_counts}")

In [None]:
puct_logs = {}
for q_mode in ["mean", "ema", "bayesian"]:
    print("=" * 60)
    print(f"Running PUCT-{q_mode.upper()}...")
    print("=" * 60)
    log = run_puct(problems, q_mode=q_mode)
    puct_logs[q_mode] = log
    print(f"PUCT-{q_mode} done. Final acc (last 20): {log.final_accuracy:.2%}")
    print(f"Calls: {log.call_counts}")
    print()

## 6. Analysis & Plotting

In [None]:
results = {"Greedy ACE": greedy_log}
for q_mode, log in puct_logs.items():
    results[f"PUCT-{q_mode.upper()}"] = log

COLORS = {
    "Greedy ACE": "#1f77b4",
    "PUCT-MEAN": "#ff7f0e",
    "PUCT-EMA": "#2ca02c",
    "PUCT-BAYESIAN": "#d62728",
}

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle("PUCT Q-Estimator Ablation on ACE Playbook Search (GSM8K)", fontsize=14, fontweight="bold")

# --- (a) Running accuracy curves ---
ax = axes[0, 0]
for name, log in results.items():
    acc = log.running_accuracy
    ax.plot(range(1, len(acc) + 1), acc, label=name, color=COLORS[name], linewidth=2)
ax.set_xlabel("Problems Solved")
ax.set_ylabel("Running Accuracy")
ax.set_title("(a) Running Accuracy Over Time")
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

# --- (b) Final accuracy bar chart with bootstrap CI ---
ax = axes[0, 1]

def bootstrap_ci(data, n_boot=1000, ci=0.95):
    data = np.array(data, dtype=float)
    means = [np.mean(np.random.choice(data, size=len(data), replace=True)) for _ in range(n_boot)]
    lo = np.percentile(means, (1 - ci) / 2 * 100)
    hi = np.percentile(means, (1 + ci) / 2 * 100)
    return np.mean(data), lo, hi

names = list(results.keys())
means, lows, highs = [], [], []
for name in names:
    tail = results[name].correct[-20:]
    m, lo, hi = bootstrap_ci(tail)
    means.append(m)
    lows.append(m - lo)
    highs.append(hi - m)

bars = ax.bar(names, means, color=[COLORS[n] for n in names], yerr=[lows, highs], capsize=5)
ax.set_ylabel("Accuracy (last 20 problems)")
ax.set_title("(b) Final Accuracy Comparison")
ax.set_ylim(0, 1)
ax.tick_params(axis='x', rotation=15)
for bar, m in zip(bars, means):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, f"{m:.1%}",
            ha="center", fontsize=9)

# --- (c) Playbook size over time ---
ax = axes[1, 0]
for name, log in results.items():
    sizes = log.playbook_sizes
    ax.plot(range(1, len(sizes) + 1), sizes, label=name, color=COLORS[name], linewidth=2)
ax.set_xlabel("Problems Solved")
ax.set_ylabel("Playbook Bullets")
ax.set_title("(c) Playbook Size Over Time")
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# --- (d) Cost breakdown ---
ax = axes[1, 1]
roles = ["generate", "reflect", "curate"]
x = np.arange(len(names))
width = 0.2
for i, role in enumerate(roles):
    counts = [results[n].call_counts.get(role, 0) for n in names]
    ax.bar(x + i * width, counts, width, label=role.capitalize())
ax.set_xticks(x + width)
ax.set_xticklabels(names, fontsize=8, rotation=15)
ax.set_ylabel("LLM Calls")
ax.set_title("(d) LLM Call Breakdown")
ax.legend()

plt.tight_layout()
plt.savefig("search_ace_results.png", dpi=150, bbox_inches="tight")
plt.show()
print("Plots saved to search_ace_results.png")

In [None]:
# --- Statistical comparison ---

print("=" * 60)
print("Statistical Comparison (Bootstrap 95% CI on accuracy difference vs Greedy)")
print("=" * 60)

def bootstrap_diff_ci(a, b, n_boot=2000):
    a, b = np.array(a, dtype=float), np.array(b, dtype=float)
    diffs = []
    for _ in range(n_boot):
        ia = np.random.choice(len(a), len(a), replace=True)
        ib = np.random.choice(len(b), len(b), replace=True)
        diffs.append(np.mean(b[ib]) - np.mean(a[ia]))
    lo, hi = np.percentile(diffs, [2.5, 97.5])
    return np.mean(diffs), lo, hi

greedy_tail = greedy_log.correct[-20:]

for q_mode in ["mean", "ema", "bayesian"]:
    name = f"PUCT-{q_mode.upper()}"
    tail = puct_logs[q_mode].correct[-20:]
    mean_diff, lo, hi = bootstrap_diff_ci(greedy_tail, tail)
    sig = "YES" if lo > 0 else ("YES (worse)" if hi < 0 else "NO")
    print(f"{name} vs Greedy: diff={mean_diff:+.1%} 95%CI=[{lo:+.1%}, {hi:+.1%}] significant={sig}")

# Pairwise among PUCT variants
print()
for a_mode, b_mode in [("mean", "ema"), ("mean", "bayesian"), ("ema", "bayesian")]:
    a_tail = puct_logs[a_mode].correct[-20:]
    b_tail = puct_logs[b_mode].correct[-20:]
    mean_diff, lo, hi = bootstrap_diff_ci(a_tail, b_tail)
    sig = "YES" if lo > 0 else ("YES (worse)" if hi < 0 else "NO")
    print(f"PUCT-{b_mode.upper()} vs PUCT-{a_mode.upper()}: diff={mean_diff:+.1%} 95%CI=[{lo:+.1%}, {hi:+.1%}] significant={sig}")

## 7. Results Summary

### Experiment
- **Dataset**: GSM8K (50 problems, shuffled)
- **Model**: Qwen2.5-7B-Instruct (local via vLLM)
- **Conditions**: Greedy ACE, PUCT-Mean, PUCT-EMA, PUCT-Bayesian

### What changed from v1
- **Removed ES-ACE**: Too slow (4x budget even after round-robin fix). Population-based search needs efficiency improvements (MAP-Elites, crossover) before it's viable.
- **Fixed PUCT chain degeneration**: Added progressive widening (`ceil(visits^0.5)` children). Nodes must be re-evaluated before branching, forcing actual tree structure.
- **Q-estimator ablation**: The original PUCT used mean Q, which couldn't distinguish nodes when accuracy was uniformly high (~90%). Three variants test whether different Q estimators produce meaningfully different tree behavior.
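
To make the "uniformly high accuracy" problem concrete, the sketch below scores two sibling nodes with the same formula as `puct_select` (uniform prior, `c_puct = 1.5`). The Q values and visit counts are hypothetical numbers chosen for illustration, not results from this notebook, and `puct_score` is a helper introduced only for this example.

```python
import math

def puct_score(q: float, child_visits: int, parent_visits: int,
               n_siblings: int, c_puct: float = 1.5) -> float:
    """PUCT score with a uniform prior over siblings (illustrative helper)."""
    prior = 1.0 / n_siblings
    explore = c_puct * prior * math.sqrt(parent_visits) / (1 + child_visits)
    return q + explore

# Hypothetical siblings: A has the higher Q, B has fewer visits.
score_a = puct_score(q=0.93, child_visits=7, parent_visits=10, n_siblings=2)
score_b = puct_score(q=0.90, child_visits=3, parent_visits=10, n_siblings=2)
print(f"A: {score_a:.3f}  B: {score_b:.3f}")
```

With these numbers the Q gap is 0.03 while the exploration gap is roughly 0.3, so B is selected and visit counts, not Q, decide the choice. This is the regime in which the choice of Q estimator can only matter if it separates nodes by more than the exploration term.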

### Q-Estimator Properties
| Estimator | Formula | Exploration Bias | When It Helps |
|-----------|---------|-----------------|---------------|
| **Mean** | `sum(r) / n` | Neutral | Stable environments, many visits |
| **EMA** | `alpha * r_new + (1-alpha) * q_old` | Recency | Non-stationary playbooks (curator changes things) |
| **Bayesian** | `(s+1)/(n+2)` | Toward 0.5 | Low data (shrinks extreme Q with few visits) |
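
The three formulas in the table can be compared side by side on small reward histories. This is a minimal standalone sketch that mirrors the logic of `MCTSNode.q_value`; the function names and the two example histories are illustrative assumptions.

```python
from typing import List

def q_mean(rewards: List[float]) -> float:
    """Plain average of all batch rewards."""
    return sum(rewards) / len(rewards)

def q_ema(rewards: List[float], alpha: float = 0.4) -> float:
    """Exponential moving average, seeded with the first reward."""
    q = rewards[0]
    for r in rewards[1:]:
        q = alpha * r + (1 - alpha) * q
    return q

def q_bayes(rewards: List[float]) -> float:
    """Beta posterior mean with a uniform prior: (s + 1) / (n + 2)."""
    return (sum(rewards) + 1) / (len(rewards) + 2)

histories = {
    "single perfect batch": [1.0],
    "improving": [0.0, 0.0, 1.0, 1.0],
}
for name, hist in histories.items():
    print(f"{name}: mean={q_mean(hist):.3f} ema={q_ema(hist):.3f} bayes={q_bayes(hist):.3f}")
```

On the improving history the mean and Bayesian estimates both land at 0.5 while the EMA reaches 0.64. On the single perfect batch only the Bayesian estimate is pulled back from 1.0, to about 0.67.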

### Key Questions
1. Does progressive widening produce actual branching (branches > 0)?
2. Which Q estimator yields the best final accuracy?
3. Does any PUCT variant beat Greedy, or does the tree overhead hurt at 50 problems?

### Limitations
- 50 problems is small; exploration likely needs a larger set (e.g. 200+) before it can pay off
- Uniform PUCT prior (no LLM-based prior yet)
- Single seed; should repeat with multiple seeds
- Progressive widening alpha=0.5 is a hyperparameter; not tuned
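
For reference, the selection score computed in `puct_select` can be written as below. Here $Q(c)$ is the child's value under the chosen estimator, $N_{\mathrm{parent}}$ and $N_c$ are visit counts, $c_{\mathrm{puct}}$ defaults to 1.5, and $P(c)$ is the uniform prior named in the limitations. An LLM-based prior would presumably replace only the $P(c)$ term.

```latex
\mathrm{score}(c) = Q(c) + c_{\mathrm{puct}} \, P(c) \, \frac{\sqrt{N_{\mathrm{parent}}}}{1 + N_c},
\qquad P(c) = \frac{1}{\lvert \mathrm{children}(\mathrm{parent}) \rvert}
```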

In [None]:
# --- Cleanup: kill vLLM server ---
try:
    os.kill(vllm_proc.pid, signal.SIGTERM)
    print(f"Killed vLLM server (PID {vllm_proc.pid})")
except ProcessLookupError:
    print("vLLM server already stopped")