# Search-Augmented ACE: Search Strategy Comparison

**Hypothesis**: Different search strategies over ACE-style playbook space trade off exploration, exploitation, and efficiency differently. We compare fourteen conditions, ranging from no-evolution baselines through flat bandits to tree search.

**Fourteen conditions** (matched LLM budget ~100-120 calls):
1. **Majority Vote** — no evolution, 2 samples per problem, majority answer (null hypothesis)
2. **Best-of-N** — no evolution, N samples per problem, pick highest-confidence answer (rejection sampling baseline)
3. **Greedy ACE** — sequential generate -> reflect -> curate (standard ACE baseline)
4. **Thompson Sampling** — flat bandit over a pool of curated playbook variants
5. **UCB Bandit** — flat bandit with deterministic UCB1 selection (vs Thompson's stochastic sampling)
6. **PUCT-Mean** — tree search, Q = mean reward
7. **PUCT-EMA** — tree search, Q = exponential moving average (alpha=0.4)
8. **PUCT-Bayesian** — tree search, Q = Beta posterior mean
9. **PUCT-Variance** — tree search, Q = mean + c*sqrt(var/n) (arXiv:2512.21648)
10. **Beam Search** — width-K beam over playbook variants, prune to top-K each round
11. **Dynamic Thompson** — Thompson with periodic arm addition from best arm
12. **AB-MCTS** — Adaptive Progressive Widening: Thompson-sampled wider-vs-deeper per node
13. **Thompson-Disc** — Discounted Thompson Sampling (gamma=0.95) for non-stationary playbook performance
14. **Discounted MCTS** — PUCT tree search with gamma=0.95 discounted Q-values for non-stationary nodes
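Thompson Sampling (condition 4) and Thompson-Disc (condition 13) both keep a Beta posterior per playbook variant. The following is a minimal sketch, not the notebook's implementation. The class name, the Beta(1, 1) prior, and the fractional-reward update (a batch accuracy `r` counts as `r` successes and `1 - r` failures) are all assumptions. Each update first discounts every arm's evidence toward the prior by `gamma` and then credits the played arm. Setting `gamma = 1.0` recovers standard Thompson sampling.

```python
import random
from typing import List


class DiscountedThompson:
    """Hypothetical discounted Beta-Bernoulli Thompson sampler over playbook variants (arms)."""

    def __init__(self, n_arms: int, gamma: float = 0.95, seed: int = 0):
        self.gamma = gamma
        self.alpha: List[float] = [1.0] * n_arms  # Beta(1, 1) prior for every arm
        self.beta: List[float] = [1.0] * n_arms
        self.rng = random.Random(seed)

    def select(self) -> int:
        # Draw one sample from each arm's posterior and play the arm with the largest draw.
        draws = [self.rng.betavariate(a, b) for a, b in zip(self.alpha, self.beta)]
        return max(range(len(draws)), key=draws.__getitem__)

    def update(self, arm: int, reward: float) -> None:
        # Discount all arms' evidence toward the prior, then credit the played arm.
        # A fractional reward r (e.g. batch accuracy) adds r successes and 1 - r failures.
        for i in range(len(self.alpha)):
            self.alpha[i] = 1.0 + self.gamma * (self.alpha[i] - 1.0)
            self.beta[i] = 1.0 + self.gamma * (self.beta[i] - 1.0)
        self.alpha[arm] += reward
        self.beta[arm] += 1.0 - reward


# Toy run: arm 2 always succeeds, arms 0 and 1 always fail.
bandit = DiscountedThompson(n_arms=3, gamma=0.95, seed=42)
for _ in range(200):
    arm = bandit.select()
    bandit.update(arm, 1.0 if arm == 2 else 0.0)
```

Under this update an arm's pseudo-counts can never exceed `1 + 1 / (1 - gamma)` (21 for `gamma = 0.95`). The posterior therefore never becomes fully confident, so a variant whose accuracy shifts after curation can still regain or lose its lead.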

**Setup**: Qwen2.5-7B-Instruct via vLLM on A100 (bfloat16, prefix caching, async parallel eval).

**GPU Optimization**: All conditions use AsyncOpenAI with semaphore-based concurrency (at most 64 in-flight requests). Majority Vote submits all 100 calls at once and lets the semaphore throttle them. Greedy ACE issues the 5 generate calls of each curate interval in parallel, then the 5 reflect calls. PUCT uses virtual loss to select and evaluate several leaves in parallel. vLLM runs with `--enable-prefix-caching`, so the playbook system prompt shared across problems is cached (~13% throughput gain).
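For reference, the tree-search rules implemented in the PUCT cells below can be written as follows. Let $p$ be a parent with children $\mathcal{C}(p)$ and $j$ one of those children. $N(\cdot)$ denotes visit counts, $V(\cdot)$ pending virtual losses, and $r_1, \dots, r_n$ the observed batch rewards of $j$. Q is shown in mean mode, and $c_{\text{puct}} = 1.5$ and $\alpha_{\text{pw}} = 0.5$ are the code's defaults.

$$
\mathrm{score}(j) = Q(j) + c_{\text{puct}} \cdot \frac{1}{|\mathcal{C}(p)|} \cdot \frac{\sqrt{N(p)}}{1 + N(j)}
$$

$$
\text{expand } p \iff |\mathcal{C}(p)| < \max\!\left(1, \left\lceil N(p)^{\alpha_{\text{pw}}} \right\rceil\right)
$$

$$
Q_{\text{eff}}(j) = \frac{\sum_{i=1}^{n} r_i}{n + V(j)}, \qquad N_{\text{eff}}(j) = N(j) + V(j)
$$

During parallel selection, `select_batch_with_virtual_loss` scores children with $N_{\text{eff}}$ and $Q_{\text{eff}}$ in place of $N$ and $Q$. A leaf already claimed in the current round therefore looks like a zero-reward visit, which pushes the next selection elsewhere. With $\alpha_{\text{pw}} = 0.5$, a node may hold one child while $N(p) \le 1$, two for $2 \le N(p) \le 4$, and three for $5 \le N(p) \le 9$.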

## 1. Setup & Dependencies

In [None]:
%%capture
!pip install vllm==0.6.6 openai==1.58.1 datasets==3.2.0 matplotlib==3.9.3 numpy==1.26.4 nest_asyncio==1.6.0

In [None]:
import subprocess
import time
import os
import signal
import json
import re
import copy
import math
import random
import pickle
import asyncio
import nest_asyncio
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional
from collections import defaultdict, Counter
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from openai import OpenAI, AsyncOpenAI

# Allow nested event loops (required for Colab/Jupyter)
nest_asyncio.apply()

SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Checkpoint directory
CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True)

# Concurrency for async LLM calls (tuned for 7B on A100 40GB)
MAX_CONCURRENT_LLM = 64

In [None]:
# Launch vLLM server in background — optimized for A100
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
VLLM_PORT = 8000

vllm_proc = subprocess.Popen(
    [
        "python", "-m", "vllm.entrypoints.openai.api_server",
        "--model", MODEL_NAME,
        "--port", str(VLLM_PORT),
        "--max-model-len", "8192",
        "--gpu-memory-utilization", "0.95",
        "--dtype", "bfloat16",
        "--max-num-seqs", "1024",
        "--max-num-batched-tokens", "16384",
        "--enable-prefix-caching",
    ],
    # Log to a file: an unread PIPE fills up and can block the server on write
    stdout=open("vllm_server.log", "w"),
    stderr=subprocess.STDOUT,
)
print(f"vLLM server PID: {vllm_proc.pid}")

# Wait for server to be ready
client = OpenAI(base_url=f"http://localhost:{VLLM_PORT}/v1", api_key="dummy")
aclient = AsyncOpenAI(base_url=f"http://localhost:{VLLM_PORT}/v1", api_key="dummy")

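VLLM_STARTUP_TIMEOUT_S = 600  # generous: the first launch also downloads the model weights
start = time.time()
while True:
    if vllm_proc.poll() is not None:  # server crashed (e.g. OOM) instead of starting
        raise RuntimeError(f"vLLM server exited early with code {vllm_proc.returncode}")
    try:
        client.models.list()
        print(f"vLLM ready after {time.time() - start:.0f}s")
        break
    except Exception:
        if time.time() - start > VLLM_STARTUP_TIMEOUT_S:
            raise RuntimeError(f"vLLM server did not start within {VLLM_STARTUP_TIMEOUT_S}s")
        time.sleep(1)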

## 2. GSM8K Data Loading & Parsing

In [None]:
from datasets import load_dataset

ds = load_dataset("openai/gsm8k", "main", split="test")

def extract_gsm8k_answer(answer_text: str) -> str:
    """Extract numeric answer from GSM8K '#### <number>' format."""
    match = re.search(r"####\s*(-?[\d,]+\.?\d*)", answer_text)
    if match:
        return match.group(1).replace(",", "").strip()
    # Fallback: last number in text (leading digit required so a lone comma never matches)
    nums = re.findall(r"-?\d[\d,]*\.?\d*", answer_text)
    if nums:
        return nums[-1].replace(",", "")
    return ""

problems = []
for item in ds:
    problems.append({
        "question": item["question"],
        "answer": extract_gsm8k_answer(item["answer"]),
        "full_answer": item["answer"],
    })

# Shuffle deterministically, then keep the first 50 problems
rng = random.Random(SEED)
rng.shuffle(problems)
problems = problems[:50]

# --- Data validation ---
assert len(problems) == 50, f"Expected 50 problems, got {len(problems)}"
assert all(p["answer"] for p in problems), "Found empty ground truth answers"
assert all(p["question"].strip() for p in problems), "Found empty questions"
for p in problems:
    try:
        float(p["answer"].replace(",", ""))
    except ValueError:
        raise ValueError(f"Non-numeric ground truth: {p['answer']}")

print(f"Loaded {len(problems)} GSM8K problems (all validated)")
print(f"Example: Q='{problems[0]['question'][:80]}...' A={problems[0]['answer']}")

## 3. Core Components: Playbook, Generator, Reflector, Curator

In [None]:
# --- Playbook representation ---

@dataclass
class Bullet:
    id: str
    section: str  # STRATEGIES, COMMON_MISTAKES, SOLUTION_PATTERNS
    content: str
    helpful: int = 0
    harmful: int = 0

    def to_str(self) -> str:
        return f"[{self.id}] helpful={self.helpful} harmful={self.harmful} :: {self.content}"


@dataclass
class Playbook:
    bullets: List[Bullet] = field(default_factory=list)
    _next_id: int = 1

    def add(self, section: str, content: str) -> str:
        prefix = {"STRATEGIES": "str", "COMMON_MISTAKES": "err", "SOLUTION_PATTERNS": "sol"}.get(section, "gen")
        bid = f"{prefix}-{self._next_id:05d}"
        self._next_id += 1
        self.bullets.append(Bullet(id=bid, section=section, content=content))
        return bid

    def remove(self, bid: str):
        self.bullets = [b for b in self.bullets if b.id != bid]

    def update(self, bid: str, content: str):
        for b in self.bullets:
            if b.id == bid:
                b.content = content
                return

    def tag(self, bid: str, label: str):
        for b in self.bullets:
            if b.id == bid:
                if label == "helpful":
                    b.helpful += 1
                elif label == "harmful":
                    b.harmful += 1

    def to_str(self) -> str:
        sections = defaultdict(list)
        for b in self.bullets:
            sections[b.section].append(b.to_str())
        parts = []
        for sec in ["STRATEGIES", "COMMON_MISTAKES", "SOLUTION_PATTERNS"]:
            if sections[sec]:
                parts.append(f"## {sec}")
                parts.extend(sections[sec])
        return "\n".join(parts) if parts else "(empty playbook)"

    def copy(self) -> "Playbook":
        return copy.deepcopy(self)

    @property
    def size(self) -> int:
        return len(self.bullets)


def make_initial_playbook() -> Playbook:
    pb = Playbook()
    pb.add("STRATEGIES", "Break word problems into step-by-step arithmetic.")
    pb.add("STRATEGIES", "Identify what quantity the question asks for before computing.")
    pb.add("COMMON_MISTAKES", "Watch for unit conversions (hours to minutes, etc).")
    return pb

print(make_initial_playbook().to_str())

In [None]:
# --- LLM call wrappers (sync + async) ---

call_counter = defaultdict(int)  # track calls by role
_llm_semaphore = asyncio.Semaphore(MAX_CONCURRENT_LLM)

def llm_call(system: str, user: str, role: str = "generate", temperature: float = 0.7, max_tokens: int = 512) -> str:
    """Single synchronous LLM call via vLLM OpenAI-compatible API."""
    call_counter[role] += 1
    try:
        resp = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return resp.choices[0].message.content.strip()
    except Exception as e:
        print(f"LLM call failed ({role}): {e}")
        return ""

async def llm_call_async(system: str, user: str, role: str = "generate",
                         temperature: float = 0.7, max_tokens: int = 512) -> str:
    """Async LLM call with semaphore-based concurrency control."""
    call_counter[role] += 1
    async with _llm_semaphore:
        try:
            resp = await aclient.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": user},
                ],
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return resp.choices[0].message.content.strip()
        except Exception as e:
            print(f"Async LLM call failed ({role}): {e}")
            return ""

def reset_call_counter():
    global call_counter
    call_counter = defaultdict(int)

def get_call_counts() -> Dict[str, int]:
    return dict(call_counter)
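`llm_call_async` bounds concurrency by wrapping each request in `async with _llm_semaphore`. The pattern can be checked in isolation with a minimal sketch that uses a hypothetical stand-in coroutine (`fake_llm_call`), which sleeps instead of calling vLLM. As in Majority Vote, 100 tasks are submitted at once through `asyncio.gather`, yet at most `limit` of them hold the semaphore at the same time.

```python
import asyncio


async def peak_concurrency(n_tasks: int, limit: int) -> int:
    """Submit n_tasks coroutines at once; return the most that ever held the semaphore together."""
    sem = asyncio.Semaphore(limit)  # created inside the running event loop
    in_flight = 0
    peak = 0

    async def fake_llm_call(i: int) -> int:
        nonlocal in_flight, peak
        async with sem:  # same gating as llm_call_async
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(0.01)  # stand-in for request latency
            in_flight -= 1
        return i

    results = await asyncio.gather(*(fake_llm_call(i) for i in range(n_tasks)))
    assert results == list(range(n_tasks))  # gather returns results in submission order
    return peak


peak = asyncio.run(peak_concurrency(n_tasks=100, limit=64))
```

The semaphore only caps how many requests are outstanding. vLLM still batches whatever is in flight, so raising `limit` trades client-side queueing for server-side queueing.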

In [None]:
# --- Generator (sync + async) ---

def _build_generate_prompts(question: str, playbook: Playbook) -> Tuple[str, str]:
    """Build system/user prompts for generation."""
    pb_text = playbook.to_str()
    system = (
        "You are a math problem solver. Use the playbook strategies below to help solve the problem.\n"
        "When you use a specific strategy, mention its ID (e.g., [str-00001]).\n"
        "Show your work step-by-step, then give the final numeric answer on its own line as: #### <number>\n\n"
        f"PLAYBOOK:\n{pb_text}"
    )
    user = f"Solve this problem:\n{question}"
    return system, user

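def _parse_generate_response(raw: str) -> Tuple[str, List[str]]:
    """Extract answer and bullet references from raw LLM response."""
    answer = ""
    m = re.search(r"####\s*(-?[\d,]+\.?\d*)", raw)
    if m:
        answer = m.group(1).replace(",", "").strip()
    else:
        # Fallback: last number in the text. Drop bullet citations like [str-00001] first
        # (their digits would otherwise match), and require a leading digit so a lone
        # comma does not count as a number.
        text = re.sub(r"\[\w+-\d+\]", " ", raw)
        nums = re.findall(r"-?\d[\d,]*\.?\d*", text)
        if nums:
            answer = nums[-1].replace(",", "")
    bullets_used = re.findall(r"\[(\w+-\d+)\]", raw)
    return answer, bullets_used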

def generate(question: str, playbook: Playbook) -> Tuple[str, List[str], str]:
    """Generate a solution to a math problem using the playbook (sync)."""
    system, user = _build_generate_prompts(question, playbook)
    raw = llm_call(system, user, role="generate")
    answer, bullets_used = _parse_generate_response(raw)
    return answer, bullets_used, raw

async def generate_async(question: str, playbook: Playbook) -> Tuple[str, List[str], str]:
    """Generate a solution to a math problem using the playbook (async)."""
    system, user = _build_generate_prompts(question, playbook)
    raw = await llm_call_async(system, user, role="generate")
    answer, bullets_used = _parse_generate_response(raw)
    return answer, bullets_used, raw
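A bare "last number" fallback of the form `-?[\d,]+\.?\d*` has two failure modes on generator output. It matches the digits inside a bullet citation such as `[str-00001]`, and it matches a lone comma. The self-contained check below illustrates both. The helper `last_number` and the `STRICT_NUM` variant are hypothetical and mirror only the fallback branch, not the `####` path.

```python
import re

LOOSE_NUM = r"-?[\d,]+\.?\d*"     # bare "last number" pattern
STRICT_NUM = r"-?\d[\d,]*\.?\d*"  # hypothetical variant: must start with a digit
CITATION = r"\[\w+-\d+\]"         # bullet reference such as [str-00001]


def last_number(raw: str, pattern: str, strip_citations: bool) -> str:
    """Hypothetical fallback: last match of `pattern`, commas removed."""
    text = re.sub(CITATION, " ", raw) if strip_citations else raw
    nums = re.findall(pattern, text)
    return nums[-1].replace(",", "") if nums else ""


cited = "She earns 18 dollars, so the answer is 18 (using [str-00001])."
chatty = "The answer is 18. Hope this helps, friend"

loose_cited = last_number(cited, LOOSE_NUM, strip_citations=False)    # picks up the citation digits
loose_chatty = last_number(chatty, LOOSE_NUM, strip_citations=False)  # a lone comma becomes ""
strict_cited = last_number(cited, STRICT_NUM, strip_citations=True)
strict_chatty = last_number(chatty, STRICT_NUM, strip_citations=True)
```

The strict variant can return a trailing period (for example `"18."`), which `float()` still parses, so `answers_match` treats it as 18.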

In [None]:
# --- Answer comparison ---

def answers_match(predicted: str, ground_truth: str) -> bool:
    """Compare numeric answers with tolerance."""
    try:
        p = float(predicted.replace(",", ""))
        g = float(ground_truth.replace(",", ""))
        return abs(p - g) < 1e-3
    except (ValueError, TypeError):
        return predicted.strip() == ground_truth.strip()

In [None]:
# --- Reflector (sync + async) ---

def _build_reflect_prompts(question: str, raw_response: str, predicted: str, ground_truth: str,
                           bullets_used: List[str], playbook: Playbook) -> Tuple[str, str]:
    """Build system/user prompts for reflection."""
    correct = answers_match(predicted, ground_truth)
    feedback = "CORRECT" if correct else f"INCORRECT (predicted {predicted}, expected {ground_truth})"

    bullets_text = ""
    for b in playbook.bullets:
        if b.id in bullets_used:
            bullets_text += f"  {b.to_str()}\n"

    system = (
        "You are a math reasoning analyst. Analyze whether the solution approach was correct "
        "and whether the playbook strategies used were helpful or harmful.\n"
        "For each bullet ID used, output a JSON line: {\"id\": \"str-00001\", \"tag\": \"helpful\"}\n"
        "Tags must be one of: helpful, harmful, neutral.\n"
        "End with a brief reflection paragraph."
    )
    user = (
        f"Problem: {question}\n\n"
        f"Solution attempt:\n{raw_response}\n\n"
        f"Result: {feedback}\n\n"
        f"Playbook bullets used:\n{bullets_text}"
    )
    return system, user

def _parse_reflect_response(raw: str, bullets_used: List[str], correct: bool) -> Dict[str, str]:
    """Parse bullet tags from reflection response."""
    tags = {}
    for m in re.finditer(r'"id"\s*:\s*"([^"]+)".*?"tag"\s*:\s*"(helpful|harmful|neutral)"', raw):
        tags[m.group(1)] = m.group(2)
    if not tags:
        for bid in bullets_used:
            tags[bid] = "helpful" if correct else "neutral"
    return tags

def reflect(question: str, raw_response: str, predicted: str, ground_truth: str,
            bullets_used: List[str], playbook: Playbook) -> Tuple[str, Dict[str, str]]:
    """Reflect on the solution attempt (sync)."""
    system, user = _build_reflect_prompts(question, raw_response, predicted, ground_truth, bullets_used, playbook)
    raw = llm_call(system, user, role="reflect", temperature=0.3)
    correct = answers_match(predicted, ground_truth)
    tags = _parse_reflect_response(raw, bullets_used, correct)
    return raw, tags

async def reflect_async(question: str, raw_response: str, predicted: str, ground_truth: str,
                        bullets_used: List[str], playbook: Playbook) -> Tuple[str, Dict[str, str]]:
    """Reflect on the solution attempt (async)."""
    system, user = _build_reflect_prompts(question, raw_response, predicted, ground_truth, bullets_used, playbook)
    raw = await llm_call_async(system, user, role="reflect", temperature=0.3)
    correct = answers_match(predicted, ground_truth)
    tags = _parse_reflect_response(raw, bullets_used, correct)
    return raw, tags
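The tag regex in `_parse_reflect_response` needs `"id"` to appear before `"tag"` on the same line (there is no `re.DOTALL`). A JSON line whose keys come in the other order is therefore skipped silently. The following hedged sketch shows a more tolerant parser. It assumes the reflector emits one JSON object per line, as the prompt requests, and the function name `parse_tags_json_lines` is hypothetical. Each line is decoded with `json.loads`, so key order does not matter.

```python
import json
import re
from typing import Dict

TAG_RE = re.compile(r'"id"\s*:\s*"([^"]+)".*?"tag"\s*:\s*"(helpful|harmful|neutral)"')
VALID_TAGS = {"helpful", "harmful", "neutral"}


def parse_tags_json_lines(raw: str) -> Dict[str, str]:
    """Hypothetical alternative: decode each {...} line as JSON; key order is irrelevant."""
    tags: Dict[str, str] = {}
    for line in raw.splitlines():
        line = line.strip().rstrip(",")
        if not (line.startswith("{") and line.endswith("}")):
            continue  # skip the free-text reflection paragraph
        try:
            obj = json.loads(line)
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict) and obj.get("tag") in VALID_TAGS and isinstance(obj.get("id"), str):
            tags[obj["id"]] = obj["tag"]
    return tags


reflection = (
    '{"id": "str-00001", "tag": "helpful"}\n'
    '{"tag": "harmful", "id": "err-00003"}\n'
    "The unit-conversion reminder did not apply here."
)
regex_tags = {m.group(1): m.group(2) for m in TAG_RE.finditer(reflection)}
json_tags = parse_tags_json_lines(reflection)
```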

In [None]:
# --- Curator ---

MAX_BULLETS = 20

def _build_curate_prompts(playbook: Playbook, reflection: str, question: str) -> Tuple[str, str]:
    """Build system/user prompts for curation."""
    system = (
        "You are a playbook curator for math problem solving. Based on the reflection, "
        "propose operations to improve the playbook.\n"
        "Output a JSON array of operations:\n"
        '[{"op": "ADD", "section": "STRATEGIES", "content": "new insight"},\n'
        ' {"op": "UPDATE", "id": "str-00001", "content": "refined text"},\n'
        ' {"op": "DELETE", "id": "err-00002"}]\n'
        "Sections: STRATEGIES, COMMON_MISTAKES, SOLUTION_PATTERNS\n"
        f"Max bullets: {MAX_BULLETS}. Current: {playbook.size}.\n"
        "Only propose operations that are clearly supported by the reflection. Keep it minimal."
    )
    user = (
        f"Current playbook:\n{playbook.to_str()}\n\n"
        f"Problem context: {question[:200]}\n\n"
        f"Reflection:\n{reflection}"
    )
    return system, user


def _apply_curate_ops(pb: Playbook, raw: str, original_playbook: Playbook) -> Playbook:
    """Parse the first JSON array of operations in `raw` and apply it to `pb` in order."""
    ops = []
    json_match = re.search(r'\[\s*\{.*?\}\s*\]', raw, re.DOTALL)
    if json_match:
        try:
            ops = json.loads(json_match.group())
        except json.JSONDecodeError:
            pass
    for op in ops:
        try:
            if op.get("op") == "ADD" and pb.size < MAX_BULLETS:
                pb.add(op.get("section", "STRATEGIES"), op.get("content", ""))
            elif op.get("op") == "UPDATE" and op.get("id"):
                pb.update(op["id"], op.get("content", ""))
            elif op.get("op") == "DELETE" and op.get("id"):
                pb.remove(op["id"])
        except Exception:
            pass  # skip malformed operations (e.g. non-dict entries)
    # Safety: if the curator emptied the playbook, reset it
    if pb.size == 0:
        pb = make_initial_playbook()
        pb._next_id = original_playbook._next_id
    return pb


def curate(playbook: Playbook, reflection: str, question: str) -> Playbook:
    """
    Curate the playbook based on reflection (sync).
    Returns a new (copied) playbook with operations applied.
    """
    pb = playbook.copy()
    system, user = _build_curate_prompts(pb, reflection, question)
    raw = llm_call(system, user, role="curate", temperature=0.3)
    return _apply_curate_ops(pb, raw, playbook)


async def curate_async(playbook: Playbook, reflection: str, question: str) -> Playbook:
    """Async version of curate. Uses llm_call_async to avoid blocking the event loop."""
    pb = playbook.copy()
    system, user = _build_curate_prompts(pb, reflection, question)
    raw = await llm_call_async(system, user, role="curate", temperature=0.3, max_tokens=256)
    return _apply_curate_ops(pb, raw, playbook)
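The curator step has two parts. The first `[{ ... }]` span is pulled out of a possibly chatty response with a non-greedy `re.DOTALL` regex, and then ADD/UPDATE/DELETE are applied in order. The minimal sketch below runs that parse-then-apply step on a made-up curator response. It uses a plain `{id: content}` dict as a stand-in for `Playbook`, and the helpers `extract_ops` and `apply_ops` are hypothetical.

```python
import json
import re
from typing import Dict, List

ARRAY_RE = re.compile(r"\[\s*\{.*?\}\s*\]", re.DOTALL)  # same pattern as the curator
PREFIX = {"STRATEGIES": "str", "COMMON_MISTAKES": "err", "SOLUTION_PATTERNS": "sol"}


def extract_ops(raw: str) -> List[dict]:
    """Return the operations in the first JSON array found in raw, or [] if none parse."""
    match = ARRAY_RE.search(raw)
    if not match:
        return []
    try:
        ops = json.loads(match.group())
    except json.JSONDecodeError:
        return []
    return [op for op in ops if isinstance(op, dict)]


def apply_ops(bullets: Dict[str, str], ops: List[dict], next_id: int, max_bullets: int = 20) -> int:
    """Apply ops in order to a {id: content} dict; return the next free numeric id."""
    for op in ops:
        kind = op.get("op")
        if kind == "ADD" and len(bullets) < max_bullets:
            prefix = PREFIX.get(op.get("section", "STRATEGIES"), "gen")
            bullets[f"{prefix}-{next_id:05d}"] = op.get("content", "")
            next_id += 1
        elif kind == "UPDATE" and op.get("id") in bullets:
            bullets[op["id"]] = op.get("content", "")
        elif kind == "DELETE" and op.get("id"):
            bullets.pop(op["id"], None)
    return next_id


raw = (
    "Based on the reflection, here are my edits:\n"
    '[{"op": "ADD", "section": "COMMON_MISTAKES", "content": "Check whether the question asks for a total or a difference."},\n'
    ' {"op": "DELETE", "id": "err-00003"}]\n'
    "These keep the playbook minimal."
)
bullets = {
    "str-00001": "Break word problems into step-by-step arithmetic.",
    "err-00003": "Watch for unit conversions (hours to minutes, etc).",
}
ops = extract_ops(raw)
next_id = apply_ops(bullets, ops, next_id=4)
```

Because the match is non-greedy, curator content that itself contains `}]` would cut the array short and make `json.loads` fail, leaving the playbook unchanged for that round.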

## 4. Search Strategies

In [None]:
# --- Shared tracking ---

@dataclass
class RunLog:
    """Tracks per-problem results for a single condition."""
    correct: List[bool] = field(default_factory=list)
    playbook_sizes: List[int] = field(default_factory=list)
    call_counts: Dict[str, int] = field(default_factory=dict)
    final_playbook: Optional[Playbook] = None

    @property
    def running_accuracy(self) -> List[float]:
        acc = []
        total = 0
        for i, c in enumerate(self.correct):
            total += int(c)
            acc.append(total / (i + 1))
        return acc

    @property
    def final_accuracy(self) -> float:
        """Accuracy over the last 20 problems (end-of-run performance), not the whole run."""
        if not self.correct:
            return 0.0
        tail = self.correct[-20:]
        return sum(tail) / len(tail)

def save_checkpoint(name: str, log: RunLog):
    """Save a condition's RunLog to disk for crash recovery."""
    path = CHECKPOINT_DIR / f"{name}.pkl"
    with open(path, "wb") as f:
        pickle.dump(log, f)
    print(f"  Checkpoint saved: {path}")

def load_checkpoint(name: str) -> Optional[RunLog]:
    """Load a condition's RunLog from disk if it exists."""
    path = CHECKPOINT_DIR / f"{name}.pkl"
    if path.exists():
        with open(path, "rb") as f:
            log = pickle.load(f)
        print(f"  Loaded checkpoint: {path} ({len(log.correct)} problems)")
        return log
    return None
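`running_accuracy` is cumulative from the first problem, whereas `final_accuracy` averages only the last 20 results. A run that improves late can therefore report a much higher `final_accuracy`. Here is a small worked example on a synthetic 50-problem correctness sequence (made-up data, for illustration only), with the two properties re-implemented as plain functions:

```python
from typing import List


def running_accuracy(correct: List[bool]) -> List[float]:
    """Cumulative accuracy after each problem (same logic as RunLog.running_accuracy)."""
    acc, total = [], 0
    for i, c in enumerate(correct):
        total += int(c)
        acc.append(total / (i + 1))
    return acc


def tail_accuracy(correct: List[bool], window: int = 20) -> float:
    """Accuracy over the last `window` problems (same logic as RunLog.final_accuracy)."""
    if not correct:
        return 0.0
    tail = correct[-window:]
    return sum(tail) / len(tail)


# Synthetic run: 30 problems at 50% accuracy, then 20 problems at 90%.
early = [i % 2 == 0 for i in range(30)]  # 15 correct out of 30
late = [i != 0 for i in range(10)] * 2   # 18 correct out of 20
correct = early + late

overall = running_accuracy(correct)[-1]  # (15 + 18) / 50
tail = tail_accuracy(correct)            # 18 / 20
```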

In [None]:
# --- Strategy 1: Greedy ACE (async batched within curate intervals) ---

CURATE_EVERY = 5

def run_greedy(problems: List[dict]) -> RunLog:
    """Sequential generate -> reflect -> curate loop, with async batching.

    Within each curate interval (5 problems), we fire all generate calls in parallel,
    then all reflect calls in parallel, then curate once (sync, since it's 1 call).
    """
    reset_call_counter()
    log = RunLog()
    pb = make_initial_playbook()

    async def _run():
        nonlocal pb
        for batch_start in range(0, len(problems), CURATE_EVERY):
            batch = problems[batch_start:batch_start + CURATE_EVERY]

            # Parallel generate
            gen_tasks = [generate_async(prob["question"], pb) for prob in batch]
            gen_results = await asyncio.gather(*gen_tasks)

            # Parallel reflect
            ref_tasks = []
            for prob, (answer, bullets_used, raw) in zip(batch, gen_results):
                ref_tasks.append(reflect_async(
                    prob["question"], raw, answer, prob["answer"], bullets_used, pb
                ))
            ref_results = await asyncio.gather(*ref_tasks)

            # Process results sequentially (tag bullets, log)
            last_reflection = ""
            last_question = ""
            for prob, (answer, bullets_used, raw), (reflection, tags) in zip(batch, gen_results, ref_results):
                correct = answers_match(answer, prob["answer"])
                log.correct.append(correct)
                log.playbook_sizes.append(pb.size)
                for bid, label in tags.items():
                    pb.tag(bid, label)
                last_reflection = reflection
                last_question = prob["question"]

            # Curate once per interval (single async call)
            if last_reflection:
                pb = await curate_async(pb, last_reflection, last_question)

            done = batch_start + len(batch)
            if done % 10 == 0 or done == len(problems):
                acc = sum(log.correct) / len(log.correct)
                print(f"  Greedy [{done}/{len(problems)}] acc={acc:.2%} bullets={pb.size}")

    asyncio.run(_run())
    log.call_counts = get_call_counts()
    log.final_playbook = pb
    return log
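Greedy ACE spends one generate and one reflect call per problem, plus one curate call per interval of `CURATE_EVERY` problems. That is how it lands inside the ~100-120 call budget. The arithmetic can be checked with the hypothetical helper below. It assumes every interval yields a non-empty reflection so curation always runs; a failed reflect call returns `""` and skips that interval's curation.

```python
import math


def greedy_call_budget(n_problems: int, curate_every: int) -> dict:
    """Expected LLM calls for the batched greedy loop, assuming curation runs every interval."""
    n_curate = math.ceil(n_problems / curate_every)
    return {
        "generate": n_problems,
        "reflect": n_problems,
        "curate": n_curate,
        "total": 2 * n_problems + n_curate,
    }


budget = greedy_call_budget(n_problems=50, curate_every=5)
```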

In [None]:
# --- Strategy 2: PUCT-ACE with virtual loss, async eval, Q-estimator variants ---

@dataclass
class MCTSNode:
    playbook: Playbook
    parent: Optional["MCTSNode"] = None
    children: List["MCTSNode"] = field(default_factory=list)
    visits: int = 0
    reward_history: List[float] = field(default_factory=list)
    results: List[bool] = field(default_factory=list)
    # Virtual loss for parallel leaf selection
    virtual_loss_count: int = 0
    # AB-MCTS: Beta posterior for wider-vs-deeper decision
    expand_alpha: float = 1.0
    expand_beta: float = 1.0

    @staticmethod
    def _q_from_history(history: List[float], mode: str) -> float:
        """Q estimate from a reward history under the given estimator."""
        if not history:
            return 0.5  # optimistic prior for unvisited nodes
        n = len(history)
        if mode == "mean":
            return sum(history) / n
        elif mode == "ema":
            alpha = 0.4
            q = history[0]
            for r in history[1:]:
                q = alpha * r + (1 - alpha) * q
            return q
        elif mode == "bayesian":
            # Beta(1, 1) prior: posterior mean (s + 1) / (n + 2)
            return (sum(history) + 1) / (n + 2)
        elif mode == "variance":
            # Variance-aware Q (inspired by arXiv:2512.21648)
            # UCB-V style: mean + c * sqrt(variance / n), with c = 0.5
            mean_q = sum(history) / n
            if n < 2:
                return mean_q + 0.5  # high uncertainty bonus when few samples
            var_q = sum((r - mean_q) ** 2 for r in history) / n
            return mean_q + 0.5 * math.sqrt(var_q / n)
        return 0.5

    def q_value(self, mode: str = "mean") -> float:
        return self._q_from_history(self.reward_history, mode)

    def effective_q(self, mode: str = "mean") -> float:
        """Q-value adjusted for virtual losses (used during parallel selection).

        Each pending virtual loss counts as an observed reward of 0 and the same
        estimator (`mode`) is applied, so parallel selection respects q_mode.
        In "mean" mode this is sum(rewards) / (n + virtual_losses).
        """
        if self.virtual_loss_count == 0:
            return self.q_value(mode)
        padded = self.reward_history + [0.0] * self.virtual_loss_count
        return self._q_from_history(padded, mode)

    @property
    def effective_visits(self) -> int:
        return self.visits + self.virtual_loss_count

    def add_virtual_loss(self):
        self.virtual_loss_count += 1

    def remove_virtual_loss(self):
        self.virtual_loss_count = max(0, self.virtual_loss_count - 1)


def puct_select(node: MCTSNode, c_puct: float = 1.5, q_mode: str = "mean",
                use_virtual_loss: bool = False) -> MCTSNode:
    """Select best child via PUCT, recurse to leaf."""
    if not node.children:
        return node
    n_parent = node.effective_visits if use_virtual_loss else node.visits
    best, best_score = None, -float("inf")
    for child in node.children:
        prior = 1.0 / len(node.children)
        if use_virtual_loss:
            exploit = child.effective_q(q_mode)
            n_child = child.effective_visits
        else:
            exploit = child.q_value(q_mode)
            n_child = child.visits
        explore = c_puct * prior * math.sqrt(n_parent) / (1 + n_child)
        score = exploit + explore
        if score > best_score:
            best_score = score
            best = child
    return puct_select(best, c_puct, q_mode, use_virtual_loss)


def select_batch_with_virtual_loss(root: MCTSNode, k: int, c_puct: float = 1.5,
                                    q_mode: str = "mean") -> List[Tuple[MCTSNode, List[MCTSNode]]]:
    """Select k diverse leaves using virtual loss to discourage repeat selection."""
    selections = []
    for _ in range(k):
        node = root
        path = []
        while node.children:
            n_parent = node.effective_visits
            best, best_score = None, -float("inf")
            for child in node.children:
                prior = 1.0 / len(node.children)
                exploit = child.effective_q(q_mode)
                explore = c_puct * prior * math.sqrt(n_parent) / (1 + child.effective_visits)
                score = exploit + explore
                if score > best_score:
                    best_score = score
                    best = child
            best.add_virtual_loss()
            path.append(best)
            node = best
        selections.append((node, path))
    return selections


def backprop(node: MCTSNode, reward: float):
    """Backpropagate reward up the tree."""
    while node is not None:
        node.visits += 1
        node.reward_history.append(reward)
        node = node.parent


def should_expand(node: MCTSNode, pw_alpha: float = 0.5) -> bool:
    """Progressive widening: ceil(visits^alpha) children allowed."""
    max_children = max(1, math.ceil(node.visits ** pw_alpha))
    return len(node.children) < max_children


async def _eval_leaf_async(leaf: MCTSNode, batch: List[dict]) -> Tuple[List[bool], str, str, Playbook]:
    """Evaluate a leaf's playbook on a batch of problems asynchronously."""
    eval_pb = leaf.playbook.copy()

    # Parallel generate all problems in batch
    gen_tasks = [generate_async(prob["question"], eval_pb) for prob in batch]
    gen_results = await asyncio.gather(*gen_tasks)

    # Parallel reflect all
    ref_tasks = []
    for prob, (answer, bullets_used, raw) in zip(batch, gen_results):
        ref_tasks.append(reflect_async(
            prob["question"], raw, answer, prob["answer"], bullets_used, eval_pb
        ))
    ref_results = await asyncio.gather(*ref_tasks)

    # Apply tags
    batch_correct = []
    last_reflection = ""
    last_question = ""
    for prob, (answer, bullets_used, raw), (reflection, tags) in zip(batch, gen_results, ref_results):
        correct = answers_match(answer, prob["answer"])
        batch_correct.append(correct)
        for bid, label in tags.items():
            eval_pb.tag(bid, label)
        last_reflection = reflection
        last_question = prob["question"]

    return batch_correct, last_reflection, last_question, eval_pb


def run_puct(problems: List[dict], batch_size: int = 3, c_puct: float = 1.5,
             q_mode: str = "mean") -> RunLog:
    """PUCT tree search with virtual loss parallel leaf selection and async eval."""
    reset_call_counter()
    log = RunLog()
    root = MCTSNode(playbook=make_initial_playbook())
    prob_idx = 0
    # Number of leaves to evaluate in parallel
    n_parallel = min(4, max(1, len(problems) // (batch_size * 4)))

    async def _run():
        nonlocal prob_idx

        while prob_idx < len(problems):
            # How many leaves can we evaluate given remaining problems?
            remaining = len(problems) - prob_idx
            n_leaves = min(n_parallel, remaining // batch_size) if remaining >= batch_size else 1

            if n_leaves <= 1:
                # Single leaf evaluation (no virtual loss needed)
                leaf = puct_select(root, c_puct, q_mode)
                batch_end = min(prob_idx + batch_size, len(problems))
                batch = problems[prob_idx:batch_end]

                batch_correct, last_reflection, last_question, eval_pb = \
                    await _eval_leaf_async(leaf, batch)

                prob_idx = batch_end
                reward = sum(batch_correct) / len(batch_correct)
                leaf.results.extend(batch_correct)
                log.correct.extend(batch_correct)
                for _ in batch_correct:
                    log.playbook_sizes.append(eval_pb.size)

                if should_expand(leaf):
                    new_pb = await curate_async(eval_pb, last_reflection, last_question)
                    child = MCTSNode(playbook=new_pb, parent=leaf)
                    leaf.children.append(child)

                backprop(leaf, reward)
            else:
                # Parallel multi-leaf evaluation with virtual loss
                selections = select_batch_with_virtual_loss(root, n_leaves, c_puct, q_mode)

                # Build batches for each leaf
                eval_tasks = []
                leaf_batches = []
                for leaf, path in selections:
                    batch_end = min(prob_idx + batch_size, len(problems))
                    if prob_idx >= len(problems):
                        break
                    batch = problems[prob_idx:batch_end]
                    prob_idx = batch_end
                    leaf_batches.append((leaf, path, batch))
                    eval_tasks.append(_eval_leaf_async(leaf, batch))

                # Evaluate all leaves in parallel
                eval_results = await asyncio.gather(*eval_tasks)

                # Process results: remove virtual losses, backprop, expand
                for (leaf, path, batch), (batch_correct, last_ref, last_q, eval_pb) in \
                        zip(leaf_batches, eval_results):
                    # Remove virtual losses
                    for node in path:
                        node.remove_virtual_loss()

                    reward = sum(batch_correct) / len(batch_correct)
                    leaf.results.extend(batch_correct)
                    log.correct.extend(batch_correct)
                    for _ in batch_correct:
                        log.playbook_sizes.append(eval_pb.size)

                    if should_expand(leaf):
                        new_pb = await curate_async(eval_pb, last_ref, last_q)
                        child = MCTSNode(playbook=new_pb, parent=leaf)
                        leaf.children.append(child)

                    backprop(leaf, reward)

            done = len(log.correct)
            if done % 10 <= batch_size * n_parallel or done == len(problems):
                acc = sum(log.correct) / len(log.correct)
                depth = _tree_depth(root)
                size = _tree_size(root)
                branches = _branch_count(root)
                print(f"  PUCT-{q_mode} [{done}/{len(problems)}] acc={acc:.2%} "
                      f"depth={depth} nodes={size} branches={branches}")

    asyncio.run(_run())
    log.call_counts = get_call_counts()
    best_leaf = _best_leaf(root, q_mode)
    log.final_playbook = best_leaf.playbook
    return log


def _tree_depth(node: MCTSNode) -> int:
    if not node.children:
        return 0
    return 1 + max(_tree_depth(c) for c in node.children)


def _tree_size(node: MCTSNode) -> int:
    return 1 + sum(_tree_size(c) for c in node.children)


def _branch_count(node: MCTSNode) -> int:
    count = 1 if len(node.children) > 1 else 0
    return count + sum(_branch_count(c) for c in node.children)


def _best_leaf(node: MCTSNode, q_mode: str = "mean") -> MCTSNode:
    if not node.children:
        return node
    best = max(node.children, key=lambda c: c.q_value(q_mode) if c.visits > 0 else -1)
    return _best_leaf(best, q_mode)
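The parallel branch of the PUCT runner relies on virtual loss: each path picked for the current wave is temporarily charged with a pending visit worth zero reward, so the next selection in the same wave is pushed toward other leaves. Below is a minimal, self-contained sketch of that mechanism. It assumes an AlphaZero-style PUCT score over a single layer of children. `ToyNode`, `select_batch` and `clear_virtual_loss` are hypothetical stand-ins, not the notebook's `MCTSNode` or `select_batch_with_virtual_loss`, whose exact scoring is not shown here.

```python
import math
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyNode:
    """Hypothetical stand-in for MCTSNode with only what PUCT selection needs."""
    name: str
    prior: float = 1.0
    results: List[float] = field(default_factory=list)
    virtual_loss: int = 0

    @property
    def visits(self) -> int:
        # In-flight evaluations count as visits while they are pending
        return len(self.results) + self.virtual_loss

    def q(self) -> float:
        # Pending visits contribute reward 0, i.e. they are treated as losses
        return sum(self.results) / self.visits if self.visits else 0.0


def puct_score(child: ToyNode, parent_visits: int, c_puct: float) -> float:
    # Assumed AlphaZero-style PUCT: exploitation term plus prior-weighted exploration bonus
    return child.q() + c_puct * child.prior * math.sqrt(parent_visits) / (1 + child.visits)


def select_batch(children: List[ToyNode], k: int, c_puct: float = 1.5) -> List[str]:
    """Pick k children for concurrent evaluation, adding a virtual loss after each pick."""
    picks = []
    for _ in range(k):
        parent_visits = max(1, sum(c.visits for c in children))
        best = max(children, key=lambda c: puct_score(c, parent_visits, c_puct))
        best.virtual_loss += 1  # discourages re-picking it before its result arrives
        picks.append(best.name)
    return picks


def clear_virtual_loss(children: List[ToyNode]) -> None:
    # Called once the real rewards have been backed up
    for c in children:
        c.virtual_loss = 0


children = [ToyNode("a", results=[1, 1, 1]), ToyNode("b", results=[1, 0]), ToyNode("c")]
print(select_batch(children, k=3))  # ['c', 'a', 'c']
```

The never-visited `c` wins the first slot; its virtual loss then lets the high-Q `a` take the second, and `c` returns only because its exploration bonus still dominates. Without the virtual-loss line, the state never changes between picks and all three slots would go to `c`.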

In [None]:
# --- Strategy 3: Batch Thompson Sampling (Optimized) ---

THOMPSON_SEED_PROBLEMS = 6
THOMPSON_N_VARIANTS = 5
THOMPSON_BATCH_SIZE = 5

def run_thompson(problems: List[dict]) -> RunLog:
    """Batch Thompson Sampling with async seed and batched parallel generate+reflect.

    Budget: 50 generate + 50 reflect + 5 curate = 105 calls.
    """
    reset_call_counter()
    log = RunLog()
    base_pb = make_initial_playbook()
    pool = [base_pb]

    seed = problems[:THOMPSON_SEED_PROBLEMS]
    remaining = problems[THOMPSON_SEED_PROBLEMS:]

    async def _run():
        nonlocal pool

        # --- Phase 1: Seed (Parallel) ---
        gen_tasks = [generate_async(prob["question"], base_pb) for prob in seed]
        gen_results = await asyncio.gather(*gen_tasks)

        ref_tasks = []
        for prob, (answer, bullets_used, raw) in zip(seed, gen_results):
            ref_tasks.append(reflect_async(
                prob["question"], raw, answer, prob["answer"], bullets_used, base_pb
            ))
        ref_results = await asyncio.gather(*ref_tasks)

        reflections = []
        for prob, (answer, bullets_used, raw), (reflection, tags) in zip(seed, gen_results, ref_results):
            correct = answers_match(answer, prob["answer"])
            log.correct.append(correct)
            log.playbook_sizes.append(base_pb.size)
            for bid, label in tags.items():
                base_pb.tag(bid, label)
            reflections.append((reflection, prob["question"]))

        # Curate variants (parallel)
        pool = await asyncio.gather(*[
            curate_async(base_pb, ref_text, ref_q)
            for ref_text, ref_q in reflections[:THOMPSON_N_VARIANTS]
        ])
        pool = list(pool)

        print(f"  Thompson: created {len(pool)} playbook variants from {len(seed)} seed problems")

        # --- Phase 2: Batch Thompson Sampling (Parallel) ---
        alphas = [1.0] * len(pool)
        betas_param = [1.0] * len(pool)

        for batch_start in range(0, len(remaining), THOMPSON_BATCH_SIZE):
            batch = remaining[batch_start : batch_start + THOMPSON_BATCH_SIZE]

            # Select arms for batch based on current posterior
            selected_indices = []
            selected_pbs = []
            for _ in batch:
                ts_samples = [np.random.beta(a, b) for a, b in zip(alphas, betas_param)]
                chosen = int(np.argmax(ts_samples))
                selected_indices.append(chosen)
                selected_pbs.append(pool[chosen])

            # Parallel Generation
            gen_tasks = [
                generate_async(prob["question"], pb)
                for prob, pb in zip(batch, selected_pbs)
            ]
            gen_results = await asyncio.gather(*gen_tasks)

            # Parallel Reflection
            ref_tasks = []
            for prob, (ans, bullets, raw), pb in zip(batch, gen_results, selected_pbs):
                ref_tasks.append(reflect_async(
                    prob["question"], raw, ans, prob["answer"], bullets, pb
                ))
            ref_results = await asyncio.gather(*ref_tasks)

            # Batch Update
            for i, (prob, (ans, bullets, raw), (ref, tags)) in enumerate(zip(batch, gen_results, ref_results)):
                correct = answers_match(ans, prob["answer"])
                log.correct.append(correct)

                chosen_idx = selected_indices[i]
                pb = selected_pbs[i]
                log.playbook_sizes.append(pb.size)

                for bid, label in tags.items():
                    pb.tag(bid, label)

                if correct:
                    alphas[chosen_idx] += 1.0
                else:
                    betas_param[chosen_idx] += 1.0

            # Logging: print whenever this batch crosses a multiple of 10 problems
            done = THOMPSON_SEED_PROBLEMS + batch_start + len(batch)
            if done // 10 > (done - len(batch)) // 10 or done == len(problems):
                acc = sum(log.correct) / len(log.correct)
                pulls = [int(a + b - 2) for a, b in zip(alphas, betas_param)]
                best_var = int(np.argmax([a / (a + b) for a, b in zip(alphas, betas_param)]))
                print(f"  Thompson [{done}/{len(problems)}] acc={acc:.2%} pulls={pulls} best=variant-{best_var}")

        best_idx = int(np.argmax([a / (a + b) for a, b in zip(alphas, betas_param)]))
        log.final_playbook = pool[best_idx]

    asyncio.run(_run())
    log.call_counts = get_call_counts()
    return log
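The batch loop in `run_thompson` draws one Thompson sample per problem slot from a posterior that stays frozen until the whole batch returns, then applies all outcomes at once. Below is a sketch of that select-then-update cycle. It uses the standard library's `random.Random.betavariate` in place of `np.random.beta`, and the function names are hypothetical.

```python
import random
from typing import List


def select_arms_for_batch(alphas: List[float], betas: List[float],
                          batch_size: int, rng: random.Random) -> List[int]:
    """One Thompson draw per batch slot, all from the same posterior snapshot."""
    picks = []
    for _ in range(batch_size):
        samples = [rng.betavariate(a, b) for a, b in zip(alphas, betas)]
        picks.append(max(range(len(samples)), key=samples.__getitem__))
    return picks


def update_posteriors(alphas: List[float], betas: List[float],
                      picks: List[int], outcomes: List[bool]) -> None:
    """Apply the whole batch's Bernoulli outcomes at once (delayed feedback)."""
    for arm, correct in zip(picks, outcomes):
        if correct:
            alphas[arm] += 1.0
        else:
            betas[arm] += 1.0


rng = random.Random(0)
alphas, betas = [1.0] * 3, [1.0] * 3
picks = select_arms_for_batch(alphas, betas, batch_size=5, rng=rng)
update_posteriors(alphas, betas, picks, outcomes=[True, False, True, True, False])
```

Because each slot gets an independent draw, arms with overlapping posteriors tend to split a batch, while a clearly dominant arm takes every slot. That per-slot randomness is what keeps batched Thompson Sampling exploring despite the delayed feedback.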

In [None]:
# --- Strategy 3b: Dynamic Thompson Sampling (async, pool grows over time) ---

THOMPSON_DYN_SEED_PROBLEMS = 6
THOMPSON_DYN_N_VARIANTS = 5
ARM_ADD_INTERVAL = 10

def run_thompson_dynamic(problems: List[dict]) -> RunLog:
    """Thompson Sampling with dynamic arm addition (async).

    Inspired by OPTS (arXiv:2503.01163) bandit-based strategy selection.
    Note: OPTS uses TS for mutation operator selection; our dynamic arm addition
    from the best arm is a novel extension for the playbook evolution setting.

    Budget: ~109 calls (17 seed + 88 bandit + ~4 dynamic curate).
    """
    reset_call_counter()
    log = RunLog()

    base_pb = make_initial_playbook()
    pool = [base_pb]

    seed = problems[:THOMPSON_DYN_SEED_PROBLEMS]
    remaining = problems[THOMPSON_DYN_SEED_PROBLEMS:]

    async def _run():
        nonlocal pool

        # --- Phase 1: Seed (parallel generate+reflect) ---
        gen_tasks = [generate_async(prob["question"], base_pb) for prob in seed]
        gen_results = await asyncio.gather(*gen_tasks)

        ref_tasks = []
        for prob, (answer, bullets_used, raw) in zip(seed, gen_results):
            ref_tasks.append(reflect_async(
                prob["question"], raw, answer, prob["answer"], bullets_used, base_pb
            ))
        ref_results = await asyncio.gather(*ref_tasks)

        reflections = []
        for prob, (answer, bullets_used, raw), (reflection, tags) in zip(seed, gen_results, ref_results):
            correct = answers_match(answer, prob["answer"])
            log.correct.append(correct)
            log.playbook_sizes.append(base_pb.size)
            for bid, label in tags.items():
                base_pb.tag(bid, label)
            reflections.append((reflection, prob["question"]))

        pool = await asyncio.gather(*[
            curate_async(base_pb, ref_text, ref_q)
            for ref_text, ref_q in reflections[:THOMPSON_DYN_N_VARIANTS]
        ])
        pool = list(pool)

        print(f"  Thompson-Dyn: created {len(pool)} initial variants from {len(seed)} seed problems")

        # --- Phase 2: Thompson Sampling with dynamic arm addition ---
        alphas = [1.0] * len(pool)
        betas_param = [1.0] * len(pool)
        last_reflections = {}

        for i, prob in enumerate(remaining):
            if i > 0 and i % ARM_ADD_INTERVAL == 0:
                posterior_means = [a / (a + b) for a, b in zip(alphas, betas_param)]
                best_arm = int(np.argmax(posterior_means))
                best_pb = pool[best_arm]

                if best_arm in last_reflections:
                    ref_text, ref_q = last_reflections[best_arm]
                else:
                    ref_text, ref_q = next(iter(last_reflections.values()), reflections[-1])

                new_variant = await curate_async(best_pb, ref_text, ref_q)
                pool.append(new_variant)
                alphas.append(1.0)
                betas_param.append(1.0)
                print(f"  Thompson-Dyn: added arm {len(pool)-1} (curated from best arm {best_arm}, "
                      f"posterior={posterior_means[best_arm]:.2f})")

            ts_samples = [np.random.beta(a, b) for a, b in zip(alphas, betas_param)]
            chosen = int(np.argmax(ts_samples))
            pb = pool[chosen]

            answer, bullets_used, raw = await generate_async(prob["question"], pb)
            correct = answers_match(answer, prob["answer"])
            log.correct.append(correct)
            log.playbook_sizes.append(pb.size)

            reflection, tags = await reflect_async(
                prob["question"], raw, answer, prob["answer"], bullets_used, pb
            )
            for bid, label in tags.items():
                pb.tag(bid, label)

            last_reflections[chosen] = (reflection, prob["question"])

            if correct:
                alphas[chosen] += 1.0
            else:
                betas_param[chosen] += 1.0

            if (THOMPSON_DYN_SEED_PROBLEMS + i + 1) % 10 == 0:
                acc = sum(log.correct) / len(log.correct)
                pulls = [int(a + b - 2) for a, b in zip(alphas, betas_param)]
                posterior_means = [a / (a + b) for a, b in zip(alphas, betas_param)]
                best_var = int(np.argmax(posterior_means))
                print(f"  Thompson-Dyn [{THOMPSON_DYN_SEED_PROBLEMS + i + 1}/{len(problems)}] "
                      f"acc={acc:.2%} arms={len(pool)} pulls={pulls} best=variant-{best_var}")

        posterior_means = [a / (a + b) for a, b in zip(alphas, betas_param)]
        best_idx = int(np.argmax(posterior_means))
        log.final_playbook = pool[best_idx]

    asyncio.run(_run())
    log.call_counts = get_call_counts()
    return log
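The budgets quoted in the bandit docstrings (105 calls for the static pool, about 109 with dynamic arm addition) follow from counting one generate and one reflect per problem, one curate per seed variant, and one curate per added arm. Below is a small sketch of that arithmetic. It assumes 50 problems, as the docstrings do, and counts every generate, reflect and curate as one LLM call; `bandit_call_budget` is a hypothetical helper.

```python
def bandit_call_budget(n_problems: int, n_seed: int, n_variants: int,
                       arm_add_interval: int = 0) -> int:
    """Count LLM calls for the seed-then-bandit runners."""
    seed_calls = 2 * n_seed + n_variants      # seed generate + reflect, then curate variants
    n_remaining = n_problems - n_seed
    bandit_calls = 2 * n_remaining            # one generate + one reflect per remaining problem
    added_arms = 0
    if arm_add_interval > 0:
        # Dynamic Thompson adds an arm at i = interval, 2*interval, ... (never at i = 0)
        added_arms = sum(1 for i in range(1, n_remaining) if i % arm_add_interval == 0)
    return seed_calls + bandit_calls + added_arms


print(bandit_call_budget(50, 6, 5))      # static pool
print(bandit_call_budget(50, 6, 5, 10))  # with an arm added every 10 problems
```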

In [None]:
# --- Strategy 5: Adaptive Progressive Widening (AB-MCTS-inspired, async) ---
# Inspired by Sakana AI (arXiv:2503.04412). Our implementation differs from the
# original paper: we use regret-based Beta updates instead of separate GEN/CONT
# node types with backed-up score distributions. Renamed from "AB-MCTS" to
# "Adaptive Progressive Widening" to reflect this distinction; the function name
# (`run_ab_mcts`) and the "AB-MCTS" log prefix keep the old name.
#
# Key fix: removed counterfactual assumption in deepen+regress case. Previously
# we incremented expand_alpha when deepening didn't improve Q, assuming expansion
# would have helped — but we never observed that counterfactual. Now we only
# update the posterior when we have direct evidence about the chosen action.

def run_ab_mcts(problems: List[dict], batch_size: int = 3, c_puct: float = 1.5) -> RunLog:
    """Adaptive Progressive Widening with Thompson-sampled expand/deepen (async).

    Each node uses Beta(expand_alpha, expand_beta) to decide wider vs deeper.
    Posterior updated only on direct evidence (no counterfactual assumptions):
      - Went wider AND reward > prev_q -> expand_alpha += 1 (expanding helped)
      - Went wider AND reward <= prev_q -> expand_beta += 1 (expanding didn't help)
      - Went deeper AND reward > prev_q -> expand_beta += 1 (deepening helped)
      - Went deeper AND reward <= prev_q -> no update (no evidence either way)

    Uses Bayesian Q-estimator. Budget: 50 gen + 50 ref + variable curate <= ~117.
    """
    reset_call_counter()
    log = RunLog()
    q_mode = "bayesian"
    root = MCTSNode(playbook=make_initial_playbook())
    prob_idx = 0
    expand_count = 0
    deepen_count = 0

    async def _run():
        nonlocal prob_idx, expand_count, deepen_count

        while prob_idx < len(problems):
            leaf = puct_select(root, c_puct, q_mode)
            prev_q = leaf.q_value(q_mode)

            batch_end = min(prob_idx + batch_size, len(problems))
            batch = problems[prob_idx:batch_end]
            prob_idx = batch_end

            batch_correct, last_reflection, last_question, eval_pb = \
                await _eval_leaf_async(leaf, batch)

            reward = sum(batch_correct) / len(batch_correct)
            leaf.results.extend(batch_correct)
            log.correct.extend(batch_correct)
            for _ in batch_correct:
                log.playbook_sizes.append(eval_pb.size)

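            # Thompson-sampled wider-vs-deeper decision. `reward` was measured on `leaf`
            # before this decision, so the updates below compare the leaf's own batch
            # reward with its prior Q; the new child (if any) has not been evaluated yet.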
            ts_sample = np.random.beta(leaf.expand_alpha, leaf.expand_beta)
            go_wider = ts_sample > 0.5 and leaf.visits > 0

            if go_wider:
                expand_count += 1
                new_pb = await curate_async(eval_pb, last_reflection, last_question)
                child = MCTSNode(playbook=new_pb, parent=leaf)
                leaf.children.append(child)
                backprop(leaf, reward)

                # Direct evidence: did expanding help?
                if reward > prev_q:
                    leaf.expand_alpha += 1
                else:
                    leaf.expand_beta += 1
            else:
                deepen_count += 1
                backprop(leaf, reward)

                # Direct evidence only: did deepening help?
                if reward > prev_q:
                    leaf.expand_beta += 1
                # No update if deepening didn't help — we have no evidence
                # that expanding would have been better (counterfactual).

            done = len(log.correct)
            if done % 10 <= batch_size or done == len(problems):
                acc = sum(log.correct) / len(log.correct)
                depth = _tree_depth(root)
                size = _tree_size(root)
                print(f"  AB-MCTS [{done}/{len(problems)}] acc={acc:.2%} "
                      f"depth={depth} nodes={size} wider={expand_count} deeper={deepen_count}")

    asyncio.run(_run())
    log.call_counts = get_call_counts()
    best_leaf = _best_leaf(root, q_mode)
    log.final_playbook = best_leaf.playbook
    print(f"  AB-MCTS final: wider={expand_count} deeper={deepen_count} "
          f"ratio={expand_count/(expand_count+deepen_count):.2f}")
    return log
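The wider-vs-deeper rule in `run_ab_mcts` reduces to a per-node Beta posterior plus a four-case update. Below, it is restated as two pure functions so the cases can be checked in isolation. The names are hypothetical, and `random.Random.betavariate` stands in for `np.random.beta`.

```python
import random
from typing import Tuple


def choose_wider(expand_alpha: float, expand_beta: float, visits: int,
                 rng: random.Random) -> bool:
    """Go wider if a Beta(expand_alpha, expand_beta) draw exceeds 0.5 and the node was visited."""
    return rng.betavariate(expand_alpha, expand_beta) > 0.5 and visits > 0


def update_expand_posterior(expand_alpha: float, expand_beta: float, went_wider: bool,
                            reward: float, prev_q: float) -> Tuple[float, float]:
    """Four-case update from the run_ab_mcts docstring; returns new (alpha, beta)."""
    improved = reward > prev_q
    if went_wider:
        # Wider: improvement favors expanding, no improvement counts against it
        return (expand_alpha + 1, expand_beta) if improved else (expand_alpha, expand_beta + 1)
    # Deeper: credit only on improvement; no counterfactual update otherwise
    return (expand_alpha, expand_beta + 1) if improved else (expand_alpha, expand_beta)
```

Because the deeper-and-no-improvement case leaves the posterior untouched, `expand_beta` only grows on observed outcomes. A node whose deepening keeps failing therefore keeps its current wider/deeper odds rather than drifting toward expansion.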

In [None]:
# --- Strategy 3c: Discounted Thompson Sampling (async, fixed timing) ---
# From the non-stationary bandits literature (arXiv:2305.10718):
# Discount is applied AFTER the posterior update, not before.
# Floor lowered from 1.0 to 0.1 to allow proper forgetting.

THOMPSON_DISC_GAMMA = 0.95
THOMPSON_DISC_SEED_PROBLEMS = 6
THOMPSON_DISC_N_VARIANTS = 5

def run_thompson_discounted(problems: List[dict]) -> RunLog:
    """Discounted Thompson Sampling with corrected discount timing (async).

    Discount applied AFTER posterior update (per arXiv:2305.10718), not before.
    Floor at 0.1 (not 1.0) to allow proper exponential forgetting.
    Effective lookback window: 1/(1-0.95) = 20 problems.

    Budget: 105 calls (17 seed + 88 bandit).
    """
    reset_call_counter()
    log = RunLog()

    base_pb = make_initial_playbook()
    pool = [base_pb]

    seed = problems[:THOMPSON_DISC_SEED_PROBLEMS]
    remaining = problems[THOMPSON_DISC_SEED_PROBLEMS:]

    async def _run():
        nonlocal pool
        # --- Phase 1: Seed (parallel generate+reflect) ---
        gen_tasks = [generate_async(prob["question"], base_pb) for prob in seed]
        gen_results = await asyncio.gather(*gen_tasks)

        ref_tasks = []
        for prob, (answer, bullets_used, raw) in zip(seed, gen_results):
            ref_tasks.append(reflect_async(
                prob["question"], raw, answer, prob["answer"], bullets_used, base_pb
            ))
        ref_results = await asyncio.gather(*ref_tasks)

        reflections = []
        for prob, (answer, bullets_used, raw), (reflection, tags) in zip(seed, gen_results, ref_results):
            correct = answers_match(answer, prob["answer"])
            log.correct.append(correct)
            log.playbook_sizes.append(base_pb.size)
            for bid, label in tags.items():
                base_pb.tag(bid, label)
            reflections.append((reflection, prob["question"]))

        pool = await asyncio.gather(*[
            curate_async(base_pb, ref_text, ref_q)
            for ref_text, ref_q in reflections[:THOMPSON_DISC_N_VARIANTS]
        ])
        pool = list(pool)

        print(f"  Thompson-Disc: created {len(pool)} playbook variants from {len(seed)} seed problems")
        print(f"  Thompson-Disc: gamma={THOMPSON_DISC_GAMMA}, effective horizon={1/(1-THOMPSON_DISC_GAMMA):.0f} problems")

        # --- Phase 2: Discounted Thompson Sampling ---
        alphas = [1.0] * len(pool)
        betas_param = [1.0] * len(pool)

        for i, prob in enumerate(remaining):
            # Thompson sample from the current posterior (already discounted in earlier steps)
            ts_samples = [np.random.beta(a, b) for a, b in zip(alphas, betas_param)]
            chosen = int(np.argmax(ts_samples))
            pb = pool[chosen]

            answer, bullets_used, raw = await generate_async(prob["question"], pb)
            correct = answers_match(answer, prob["answer"])
            log.correct.append(correct)
            log.playbook_sizes.append(pb.size)

            reflection, tags = await reflect_async(
                prob["question"], raw, answer, prob["answer"], bullets_used, pb
            )
            for bid, label in tags.items():
                pb.tag(bid, label)

            # Update Beta posterior for the chosen variant
            if correct:
                alphas[chosen] += 1.0
            else:
                betas_param[chosen] += 1.0

            # Discount ALL arms AFTER update (per arXiv:2305.10718)
            for k in range(len(pool)):
                alphas[k] *= THOMPSON_DISC_GAMMA
                betas_param[k] *= THOMPSON_DISC_GAMMA
                # Floor at 0.1 (not 1.0) to allow proper forgetting
                alphas[k] = max(alphas[k], 0.1)
                betas_param[k] = max(betas_param[k], 0.1)

            if (THOMPSON_DISC_SEED_PROBLEMS + i + 1) % 10 == 0:
                acc = sum(log.correct) / len(log.correct)
                # Under discounting, alpha + beta is a decayed effective sample size
                # (floored at 0.1 per parameter), not a pull count
                eff_n = [round(a + b, 1) for a, b in zip(alphas, betas_param)]
                best_var = int(np.argmax([a / (a + b) for a, b in zip(alphas, betas_param)]))
                print(f"  Thompson-Disc [{THOMPSON_DISC_SEED_PROBLEMS + i + 1}/{len(problems)}] "
                      f"acc={acc:.2%} eff_n={eff_n} best=variant-{best_var}")

        best_idx = int(np.argmax([a / (a + b) for a, b in zip(alphas, betas_param)]))
        log.final_playbook = pool[best_idx]

    asyncio.run(_run())
    log.call_counts = get_call_counts()
    return log
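With gamma = 0.95 and the outcome added before the decay, an arm pulled on every step has alpha + beta converging to roughly gamma / (1 - gamma) = 19 instead of growing without bound. An arm that is never pulled settles at the 0.1 floor on both parameters. Under discounting, alpha + beta is therefore best read as a decayed effective sample size rather than a pull count. The sketch below replays the update rule in isolation; `discounted_update` is a hypothetical helper mirroring the loop body of `run_thompson_discounted`.

```python
from typing import List


def discounted_update(alphas: List[float], betas: List[float], chosen: int,
                      correct: bool, gamma: float = 0.95, floor: float = 0.1) -> None:
    """Same order as run_thompson_discounted: add the outcome, then decay every arm."""
    if correct:
        alphas[chosen] += 1.0
    else:
        betas[chosen] += 1.0
    for k in range(len(alphas)):
        alphas[k] = max(alphas[k] * gamma, floor)
        betas[k] = max(betas[k] * gamma, floor)


alphas, betas = [1.0, 1.0], [1.0, 1.0]
for _ in range(500):
    discounted_update(alphas, betas, chosen=0, correct=True)
# alphas[0] approaches gamma / (1 - gamma) = 19; every other parameter sits at the 0.1 floor
```

The floor keeps a neglected arm highly uncertain: Beta(0.1, 0.1) is U-shaped, so its draws land near 0 or 1 and it keeps being re-tried instead of being written off.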

In [None]:
# --- Strategy 4: Majority Vote (fully parallel, no evolution) ---

MAJORITY_N_SAMPLES = 2

def run_majority_vote(problems: List[dict]) -> RunLog:
    """No playbook evolution. Static initial playbook. All generate calls in parallel.

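    Budget: 50 * 2 = 100 generate calls. No reflect or curate.
    With 2 samples, Counter.most_common breaks a 1-1 tie in favor of the first
    answer, so the vote always returns the first sample's answer.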
    """
    reset_call_counter()
    log = RunLog()
    pb = make_initial_playbook()

    async def _run():
        # Fire all len(problems) * MAJORITY_N_SAMPLES generate calls at once (100 for
        # 50 problems); there are no dependencies between them
        tasks = []
        for prob in problems:
            for _ in range(MAJORITY_N_SAMPLES):
                tasks.append(generate_async(prob["question"], pb))

        all_results = await asyncio.gather(*tasks)

        # Group results by problem (every MAJORITY_N_SAMPLES consecutive results)
        for i, prob in enumerate(problems):
            answers = []
            for j in range(MAJORITY_N_SAMPLES):
                answer, _, _ = all_results[i * MAJORITY_N_SAMPLES + j]
                answers.append(answer)

            vote_counts = Counter(answers)
            majority_answer = vote_counts.most_common(1)[0][0]
            correct = answers_match(majority_answer, prob["answer"])
            log.correct.append(correct)
            log.playbook_sizes.append(pb.size)

        acc = sum(log.correct) / len(log.correct)
        print(f"  MajVote [done] acc={acc:.2%}")

    asyncio.run(_run())
    log.call_counts = get_call_counts()
    log.final_playbook = pb
    return log
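`Counter.most_common` lists equal counts in the order they were first encountered, so a 1-1 split between two samples always resolves to the first one. The sketch below makes that concrete; `majority_answer` is a hypothetical helper equivalent to the two voting lines in `run_majority_vote`.

```python
from collections import Counter
from typing import List


def majority_answer(answers: List[str]) -> str:
    # most_common orders equal counts by first occurrence, so ties go to the earliest answer
    return Counter(answers).most_common(1)[0][0]


print(majority_answer(["3", "5"]))       # tie -> first sample
print(majority_answer(["3", "5", "5"]))  # a real majority overrules the first sample
```

With N = 2 the second sample can never change the result. N = 3 is the smallest setting in which voting can overrule a single sample.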

In [None]:
# --- Strategy 4b: Best-of-N (Rejection Sampling, fully parallel) ---

BEST_OF_N_SAMPLES = 2

def run_best_of_n(problems: List[dict]) -> RunLog:
    """No playbook evolution. Generate N solutions per problem, pick the one with
    highest answer frequency (like majority vote but framed as rejection sampling).

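    This is the "does a smart tree search actually beat simply generating N random
    variations and picking the best one?" baseline. Because it selects by frequency,
    it makes the same choice as Majority Vote at the same N (with N=2, the
    first-occurrence tie-break always keeps the first sample's answer).

    Budget: 50 * 2 = 100 generate calls. No reflect or curate.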
    """
    reset_call_counter()
    log = RunLog()
    pb = make_initial_playbook()

    async def _run():
        # Fire all generate calls in parallel
        tasks = []
        for prob in problems:
            for _ in range(BEST_OF_N_SAMPLES):
                tasks.append(generate_async(prob["question"], pb))

        all_results = await asyncio.gather(*tasks)

        for i, prob in enumerate(problems):
            answers = [
                all_results[i * BEST_OF_N_SAMPLES + j][0]  # (answer, bullets_used, raw)
                for j in range(BEST_OF_N_SAMPLES)
            ]

            # Keep the most frequent answer (ties broken by first occurrence). Best-of-N
            # conceptually selects the "best" sample; here consistency stands in for a
            # confidence score, which makes the choice identical to majority vote.
            vote_counts = Counter(answers)
            best_answer = vote_counts.most_common(1)[0][0]
            correct = answers_match(best_answer, prob["answer"])
            log.correct.append(correct)
            log.playbook_sizes.append(pb.size)

        acc = sum(log.correct) / len(log.correct)
        print(f"  Best-of-N [done] acc={acc:.2%}")

    asyncio.run(_run())
    log.call_counts = get_call_counts()
    log.final_playbook = pb
    return log
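As written, Best-of-N selects by answer frequency and therefore makes the same choice as Majority Vote. A rejection-sampling baseline distinct from voting would instead rank the N samples by a per-sample quality signal. The sketch below assumes such a signal exists as a hypothetical `score` function over the raw solution text (for example a verifier or a reward model; neither is part of this notebook). It uses the answer and raw text that `generate_async` already returns.

```python
from typing import Callable, List, Tuple


def best_of_n(candidates: List[Tuple[str, str]],
              score: Callable[[str], float]) -> str:
    """Return the answer of the highest-scoring (answer, raw_solution) candidate.

    `score` is a hypothetical quality signal on the raw solution text; max() keeps
    the first candidate among ties, mirroring the first-occurrence rule used above.
    """
    best_answer, _ = max(candidates, key=lambda cand: score(cand[1]))
    return best_answer


# Toy scorer for illustration only: rewards solutions that state a check step
toy_score = lambda raw: float(raw.count("check"))
print(best_of_n([("3", "quick guess"), ("5", "derive, then check")], toy_score))
```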

In [None]:
# --- Strategy 6: Beam Search over Playbook Variants ---
# The "missing link" between Greedy (width=1) and MCTS (complex tree).
# Beam search maintains K candidate playbooks, evaluates all on each batch,
# prunes to top-K, then curates from the survivors.

BEAM_WIDTH = 3
BEAM_EVAL_BATCH = 5

def run_beam_search(problems: List[dict], beam_width: int = BEAM_WIDTH,
                    eval_batch: int = BEAM_EVAL_BATCH) -> RunLog:
    """Beam search over playbook space.

    Maintains K playbook candidates. Each round:
    1. Evaluate all K beams on the same batch of problems (parallel)
    2. Rank beams by batch accuracy
    3. Keep top-K
    4. Curate each surviving beam to produce next-generation candidates

    Hypothesis: captures ~80% of tree search gains at ~20% of the cost.

    Budget (50 problems, K=3, batch=5): seed 5 gen + 5 ref + 3 curate, then
    9 rounds x (15 gen + 15 ref + 3 curate) = 310 calls, roughly 3x the matched
    budget, because every beam generates and reflects on every problem.
    """
    reset_call_counter()
    log = RunLog()

    # Initialize the beam with K curated variants of the base playbook (one per seed reflection, cycling if K exceeds the seed count)
    base_pb = make_initial_playbook()

    async def _run():
        # Seed phase: generate+reflect on first few problems to get diverse curations
        seed = problems[:eval_batch]
        remaining = problems[eval_batch:]

        gen_tasks = [generate_async(prob["question"], base_pb) for prob in seed]
        gen_results = await asyncio.gather(*gen_tasks)

        ref_tasks = []
        for prob, (answer, bullets_used, raw) in zip(seed, gen_results):
            ref_tasks.append(reflect_async(
                prob["question"], raw, answer, prob["answer"], bullets_used, base_pb
            ))
        ref_results = await asyncio.gather(*ref_tasks)

        reflections = []
        for prob, (answer, bullets_used, raw), (reflection, tags) in zip(seed, gen_results, ref_results):
            correct = answers_match(answer, prob["answer"])
            log.correct.append(correct)
            log.playbook_sizes.append(base_pb.size)
            for bid, label in tags.items():
                base_pb.tag(bid, label)
            reflections.append((reflection, prob["question"]))

        # Create initial beam from diverse curations
        beams = []
        _beam_variants = await asyncio.gather(*[
            curate_async(base_pb, reflections[i % len(reflections)][0], reflections[i % len(reflections)][1])
            for i in range(beam_width)
        ])
        for variant in _beam_variants:
            beams.append({"playbook": variant, "score": 0.0, "total_correct": 0, "total_seen": 0})

        print(f"  Beam: initialized {len(beams)} beams from {len(seed)} seed problems")

        # Main beam search loop
        prob_idx = 0
        while prob_idx < len(remaining):
            batch_end = min(prob_idx + eval_batch, len(remaining))
            batch = remaining[prob_idx:batch_end]
            prob_idx = batch_end

            # Evaluate ALL beams on the same batch (parallel across beams AND problems)
            all_eval_tasks = []
            for beam in beams:
                beam_tasks = [generate_async(prob["question"], beam["playbook"]) for prob in batch]
                all_eval_tasks.extend(beam_tasks)

            all_gen_results = await asyncio.gather(*all_eval_tasks)

            # Reflect on all results
            all_ref_tasks = []
            idx = 0
            for beam in beams:
                for prob in batch:
                    answer, bullets_used, raw = all_gen_results[idx]
                    all_ref_tasks.append(reflect_async(
                        prob["question"], raw, answer, prob["answer"], bullets_used, beam["playbook"]
                    ))
                    idx += 1

            all_ref_results = await asyncio.gather(*all_ref_tasks)

            # Score each beam on this batch
            idx = 0
            beam_batch_scores = []
            beam_reflections = []
            for beam in beams:
                batch_correct = 0
                last_ref = ""
                last_q = ""
                for prob in batch:
                    answer, bullets_used, raw = all_gen_results[idx]
                    reflection, tags = all_ref_results[idx]
                    correct = answers_match(answer, prob["answer"])
                    if correct:
                        batch_correct += 1
                    for bid, label in tags.items():
                        beam["playbook"].tag(bid, label)
                    last_ref = reflection
                    last_q = prob["question"]
                    idx += 1

                beam["total_correct"] += batch_correct
                beam["total_seen"] += len(batch)
                beam["score"] = beam["total_correct"] / beam["total_seen"]
                beam_batch_scores.append(batch_correct / len(batch))
                beam_reflections.append((last_ref, last_q))

            # Log the beam we would have deployed *before* seeing this batch's labels.
            # Beams are still ordered by cumulative score from earlier rounds (ties on
            # the first round fall to creation order), so index 0 is the current leader.
            # Picking the argmax of this batch's accuracy would use the ground truth to
            # choose which answers count.
            deploy_idx = 0
            deployed_beam = beams[deploy_idx]
            gen_offset = deploy_idx * len(batch)
            for j, prob in enumerate(batch):
                answer, _, _ = all_gen_results[gen_offset + j]
                correct = answers_match(answer, prob["answer"])
                log.correct.append(correct)
                log.playbook_sizes.append(deployed_beam["playbook"].size)

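            # Attach each beam's latest reflection before sorting so the pairing survives
            for beam, reflection in zip(beams, beam_reflections):
                beam["last_reflection"] = reflection

            # Prune: keep top-K beams by cumulative score. Each survivor yields exactly
            # one successor below, so the population never exceeds K and this only reorders.
            beams.sort(key=lambda b: b["score"], reverse=True)
            beams = beams[:beam_width]

            # Curate each surviving beam from its own latest reflection
            new_beams = []
            _new_pbs = await asyncio.gather(*[
                curate_async(beam["playbook"], *beam["last_reflection"])
                for beam in beams
            ])
            for new_pb, beam in zip(_new_pbs, beams):
                new_beams.append({
                    "playbook": new_pb,
                    "score": beam["score"],
                    "total_correct": beam["total_correct"],
                    "total_seen": beam["total_seen"],
                })
            beams = new_beams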

            done = len(log.correct)
            if done % 10 <= eval_batch or done == len(problems):
                acc = sum(log.correct) / len(log.correct)
                scores = [f"{b['score']:.0%}" for b in beams]
                print(f"  Beam [{done}/{len(problems)}] acc={acc:.2%} beam_scores={scores}")

        # Final playbook = best beam
        beams.sort(key=lambda b: b["score"], reverse=True)
        log.final_playbook = beams[0]["playbook"]

    asyncio.run(_run())
    log.call_counts = get_call_counts()
    return log
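Because every beam generates and reflects on every remaining problem, the call count of `run_beam_search` grows linearly with the beam width. Below is a sketch of that count, assuming one LLM call per generate, reflect and curate; `beam_search_calls` is a hypothetical helper.

```python
import math


def beam_search_calls(n_problems: int, beam_width: int, eval_batch: int) -> int:
    """Count generate + reflect + curate calls made by run_beam_search."""
    n_seed = min(eval_batch, n_problems)
    calls = 2 * n_seed + beam_width           # seed generate + reflect, K initial curations
    n_remaining = n_problems - n_seed
    n_rounds = math.ceil(n_remaining / eval_batch)
    calls += 2 * beam_width * n_remaining     # every beam generates + reflects on every problem
    calls += beam_width * n_rounds            # one curation per surviving beam per round
    return calls


print(beam_search_calls(50, beam_width=3, eval_batch=5))
print(beam_search_calls(50, beam_width=1, eval_batch=5))
```

For the defaults (50 problems, K = 3, batch 5) this gives 310 calls. K = 1 gives 110, which falls inside the ~100-120 matched budget.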

In [None]:
# --- Strategy 7: UCB Bandit (deterministic counterpart to Thompson Sampling) ---
# Standard UCB1 (Auer et al., 2002) uses a deterministic upper confidence bound
# instead of Thompson Sampling's stochastic posterior sampling.
# Hypothesis: Thompson is usually better for exploration, but UCB is often more
# sample-efficient. Comparing them isolates the "randomness" variable.

UCB_SEED_PROBLEMS = 6
UCB_N_VARIANTS = 5
UCB_C = math.sqrt(2)  # standard UCB1 exploration constant

def run_ucb_bandit(problems: List[dict]) -> RunLog:
    """Flat bandit with UCB1 selection over playbook pool (async).

    UCB1 selects: argmax_k [ mean_reward_k + c * sqrt(ln(t) / n_k) ]
    where t = total pulls, n_k = pulls of arm k, c = sqrt(2).

    Budget: 50 generate + 50 reflect + 5 curate = 105 calls.
    """
    reset_call_counter()
    log = RunLog()

    base_pb = make_initial_playbook()
    pool = [base_pb]

    seed = problems[:UCB_SEED_PROBLEMS]
    remaining = problems[UCB_SEED_PROBLEMS:]

    async def _run():
        nonlocal pool

        # --- Phase 1: Seed (parallel generate+reflect) ---
        gen_tasks = [generate_async(prob["question"], base_pb) for prob in seed]
        gen_results = await asyncio.gather(*gen_tasks)

        ref_tasks = []
        for prob, (answer, bullets_used, raw) in zip(seed, gen_results):
            ref_tasks.append(reflect_async(
                prob["question"], raw, answer, prob["answer"], bullets_used, base_pb
            ))
        ref_results = await asyncio.gather(*ref_tasks)

        reflections = []
        for prob, (answer, bullets_used, raw), (reflection, tags) in zip(seed, gen_results, ref_results):
            correct = answers_match(answer, prob["answer"])
            log.correct.append(correct)
            log.playbook_sizes.append(base_pb.size)
            for bid, label in tags.items():
                base_pb.tag(bid, label)
            reflections.append((reflection, prob["question"]))

        pool = await asyncio.gather(*[
            curate_async(base_pb, ref_text, ref_q)
            for ref_text, ref_q in reflections[:UCB_N_VARIANTS]
        ])
        pool = list(pool)

        print(f"  UCB: created {len(pool)} playbook variants from {len(seed)} seed problems")

        # --- Phase 2: UCB1 selection ---
        rewards = [[] for _ in pool]  # per-arm reward history
        total_pulls = 0

        for i, prob in enumerate(remaining):
            # UCB1 selection: ensure each arm pulled at least once
            unpulled = [k for k in range(len(pool)) if not rewards[k]]
            if unpulled:
                chosen = unpulled[0]
            else:
                # t = total pulls so far (equals the sum of per-arm pull counts)
                ucb_scores = []
                for k in range(len(pool)):
                    mean_r = sum(rewards[k]) / len(rewards[k])
                    explore = UCB_C * math.sqrt(math.log(total_pulls) / len(rewards[k]))
                    ucb_scores.append(mean_r + explore)
                    ucb_scores.append(mean_r + explore)
                chosen = int(np.argmax(ucb_scores))

            pb = pool[chosen]
            answer, bullets_used, raw = await generate_async(prob["question"], pb)
            correct = answers_match(answer, prob["answer"])
            log.correct.append(correct)
            log.playbook_sizes.append(pb.size)

            reflection, tags = await reflect_async(
                prob["question"], raw, answer, prob["answer"], bullets_used, pb
            )
            for bid, label in tags.items():
                pb.tag(bid, label)

            rewards[chosen].append(1.0 if correct else 0.0)
            total_pulls += 1

            if (UCB_SEED_PROBLEMS + i + 1) % 10 == 0:
                acc = sum(log.correct) / len(log.correct)
                pulls = [len(r) for r in rewards]
                means = [f"{sum(r)/len(r):.0%}" if r else "?" for r in rewards]
                best_var = int(np.argmax([sum(r)/len(r) if r else 0 for r in rewards]))
                print(f"  UCB [{UCB_SEED_PROBLEMS + i + 1}/{len(problems)}] "
                      f"acc={acc:.2%} pulls={pulls} means={means} best=variant-{best_var}")

        best_idx = int(np.argmax([sum(r)/len(r) if r else 0 for r in rewards]))
        log.final_playbook = pool[best_idx]

    asyncio.run(_run())
    log.call_counts = get_call_counts()
    return log

In [None]:
# --- Strategy 8: Discounted MCTS (PUCT with gamma-decayed Q-values) ---
# Applies the same gamma=0.95 discount idea from Thompson-Disc but inside the
# MCTS tree nodes. Standard PUCT suffers from non-stationarity: if a child node
# improves via curation, the parent's average Q doesn't update fast enough.
# Down-weighting older rewards (so recent rewards dominate Q) addresses this.

DISC_MCTS_GAMMA = 0.95

def run_discounted_mcts(problems: List[dict], batch_size: int = 3, c_puct: float = 1.5) -> RunLog:
    """PUCT tree search with discounted Q-values for non-stationary nodes.

    Q-values use exponentially weighted mean: recent rewards count more.
    This addresses the problem where a node's playbook improves over time
    (via child curations) but the parent's stale Q-average doesn't reflect it.

    Unvisited nodes start at a neutral Q of 0.5; the final playbook is chosen by mean Q.
    Budget: 50 gen + 50 ref + at most ~17 curate (one per batch), so <= ~117 total.
    """
    reset_call_counter()
    log = RunLog()
    root = MCTSNode(playbook=make_initial_playbook())
    prob_idx = 0

    def discounted_q(node: MCTSNode) -> float:
        """Compute gamma-discounted Q: recent rewards weighted more heavily."""
        if not node.reward_history:
            return 0.5  # neutral prior: midpoint of the [0, 1] reward range
        # Exponentially discount: most recent reward has weight 1, next has gamma, etc.
        weights = []
        w = 1.0
        for _ in reversed(node.reward_history):
            weights.append(w)
            w *= DISC_MCTS_GAMMA
        weights.reverse()
        total_w = sum(weights)
        discounted = sum(r * w for r, w in zip(node.reward_history, weights))
        return discounted / total_w

    def disc_puct_select(node: MCTSNode, c: float = 1.5) -> MCTSNode:
        """PUCT selection using discounted Q."""
        if not node.children:
            return node
        n_parent = node.visits
        best, best_score = None, -float("inf")
        for child in node.children:
            prior = 1.0 / len(node.children)
            exploit = discounted_q(child)
            explore = c * prior * math.sqrt(n_parent) / (1 + child.visits)
            score = exploit + explore
            if score > best_score:
                best_score = score
                best = child
        return disc_puct_select(best, c)

    async def _run():
        nonlocal prob_idx

        while prob_idx < len(problems):
            leaf = disc_puct_select(root, c_puct)

            batch_end = min(prob_idx + batch_size, len(problems))
            batch = problems[prob_idx:batch_end]
            prob_idx = batch_end

            batch_correct, last_reflection, last_question, eval_pb = \
                await _eval_leaf_async(leaf, batch)

            reward = sum(batch_correct) / len(batch_correct)
            leaf.results.extend(batch_correct)
            log.correct.extend(batch_correct)
            for _ in batch_correct:
                log.playbook_sizes.append(eval_pb.size)

            if should_expand(leaf):
                new_pb = await curate_async(eval_pb, last_reflection, last_question)
                child = MCTSNode(playbook=new_pb, parent=leaf)
                leaf.children.append(child)

            backprop(leaf, reward)

            done = len(log.correct)
            if done % 10 <= batch_size or done == len(problems):
                acc = sum(log.correct) / len(log.correct)
                depth = _tree_depth(root)
                size = _tree_size(root)
                print(f"  Disc-MCTS [{done}/{len(problems)}] acc={acc:.2%} "
                      f"depth={depth} nodes={size}")

    asyncio.run(_run())
    log.call_counts = get_call_counts()
    best_leaf = _best_leaf(root, "mean")  # use mean for final selection
    log.final_playbook = best_leaf.playbook
    return log

## 5. Experiment Runner

In [None]:
print("=" * 60)
print("Running Majority Vote (no evolution baseline)...")
print("=" * 60)
majority_log = load_checkpoint("majority_vote")
if majority_log is None:
    majority_log = run_majority_vote(problems)
    save_checkpoint("majority_vote", majority_log)
print(f"MajVote done. Final acc (last 20): {majority_log.final_accuracy:.2%}")
print(f"Calls: {majority_log.call_counts}")

In [None]:
print("=" * 60)
print("Running Greedy ACE (baseline)...")
print("=" * 60)
greedy_log = load_checkpoint("greedy_ace")
if greedy_log is None:
    greedy_log = run_greedy(problems)
    save_checkpoint("greedy_ace", greedy_log)
print(f"Greedy done. Final acc (last 20): {greedy_log.final_accuracy:.2%}")
print(f"Calls: {greedy_log.call_counts}")

In [None]:
print("=" * 60)
print("Running Thompson Sampling...")
print("=" * 60)
thompson_log = load_checkpoint("thompson")
if thompson_log is None:
    thompson_log = run_thompson(problems)
    save_checkpoint("thompson", thompson_log)
print(f"Thompson done. Final acc (last 20): {thompson_log.final_accuracy:.2%}")
print(f"Calls: {thompson_log.call_counts}")

In [None]:
print("=" * 60)
print("Running Dynamic Thompson Sampling...")
print("=" * 60)
thompson_dyn_log = load_checkpoint("thompson_dyn")
if thompson_dyn_log is None:
    thompson_dyn_log = run_thompson_dynamic(problems)
    save_checkpoint("thompson_dyn", thompson_dyn_log)
print(f"Thompson-Dyn done. Final acc (last 20): {thompson_dyn_log.final_accuracy:.2%}")
print(f"Calls: {thompson_dyn_log.call_counts}")

In [None]:
puct_logs = {}
for q_mode in ["mean", "ema", "bayesian", "variance"]:
    ckpt_name = f"puct_{q_mode}"
    print("=" * 60)
    print(f"Running PUCT-{q_mode.upper()}...")
    print("=" * 60)
    log = load_checkpoint(ckpt_name)
    if log is None:
        log = run_puct(problems, q_mode=q_mode)
        save_checkpoint(ckpt_name, log)
    puct_logs[q_mode] = log
    print(f"PUCT-{q_mode} done. Final acc (last 20): {log.final_accuracy:.2%}")
    print(f"Calls: {log.call_counts}")
    print()

In [None]:
print("=" * 60)
print("Running AB-MCTS (Adaptive Progressive Widening)...")
print("=" * 60)
ab_mcts_log = load_checkpoint("ab_mcts")
if ab_mcts_log is None:
    ab_mcts_log = run_ab_mcts(problems)
    save_checkpoint("ab_mcts", ab_mcts_log)
print(f"AB-MCTS done. Final acc (last 20): {ab_mcts_log.final_accuracy:.2%}")
print(f"Calls: {ab_mcts_log.call_counts}")

In [None]:
print("=" * 60)
print("Running Discounted Thompson Sampling (gamma=0.95)...")
print("=" * 60)
thompson_disc_log = load_checkpoint("thompson_disc")
if thompson_disc_log is None:
    thompson_disc_log = run_thompson_discounted(problems)
    save_checkpoint("thompson_disc", thompson_disc_log)
print(f"Thompson-Disc done. Final acc (last 20): {thompson_disc_log.final_accuracy:.2%}")
print(f"Calls: {thompson_disc_log.call_counts}")

In [None]:
print("=" * 60)
print("Running Best-of-N (rejection sampling baseline)...")
print("=" * 60)
best_of_n_log = load_checkpoint("best_of_n")
if best_of_n_log is None:
    best_of_n_log = run_best_of_n(problems)
    save_checkpoint("best_of_n", best_of_n_log)
print(f"Best-of-N done. Final acc (last 20): {best_of_n_log.final_accuracy:.2%}")
print(f"Calls: {best_of_n_log.call_counts}")


In [None]:
print("=" * 60)
print("Running UCB Bandit...")
print("=" * 60)
ucb_log = load_checkpoint("ucb_bandit")
if ucb_log is None:
    ucb_log = run_ucb_bandit(problems)
    save_checkpoint("ucb_bandit", ucb_log)
print(f"UCB done. Final acc (last 20): {ucb_log.final_accuracy:.2%}")
print(f"Calls: {ucb_log.call_counts}")


In [None]:
print("=" * 60)
print("Running Beam Search (width=3)...")
print("=" * 60)
beam_log = load_checkpoint("beam_search")
if beam_log is None:
    beam_log = run_beam_search(problems)
    save_checkpoint("beam_search", beam_log)
print(f"Beam done. Final acc (last 20): {beam_log.final_accuracy:.2%}")
print(f"Calls: {beam_log.call_counts}")


In [None]:
print("=" * 60)
print("Running Discounted MCTS (gamma=0.95)...")
print("=" * 60)
disc_mcts_log = load_checkpoint("disc_mcts")
if disc_mcts_log is None:
    disc_mcts_log = run_discounted_mcts(problems)
    save_checkpoint("disc_mcts", disc_mcts_log)
print(f"Disc-MCTS done. Final acc (last 20): {disc_mcts_log.final_accuracy:.2%}")
print(f"Calls: {disc_mcts_log.call_counts}")


## 6. Analysis & Plotting

In [None]:
# Build results dict with all 14 conditions
results = {
    "Majority Vote": majority_log,
    "Best-of-N": best_of_n_log,
    "Greedy ACE": greedy_log,
    "Thompson": thompson_log,
    "UCB Bandit": ucb_log,
    "Thompson-Dyn": thompson_dyn_log,
    "Thompson-Disc": thompson_disc_log,
    "Beam Search": beam_log,
}
for q_mode, log in puct_logs.items():
    results[f"PUCT-{q_mode.upper()}"] = log
results["AB-MCTS"] = ab_mcts_log
results["Disc-MCTS"] = disc_mcts_log

COLORS = {
    "Majority Vote": "#9467bd",
    "Best-of-N": "#aec7e8",
    "Greedy ACE": "#1f77b4",
    "Thompson": "#8c564b",
    "UCB Bandit": "#c49c94",
    "Thompson-Dyn": "#e377c2",
    "Thompson-Disc": "#bcbd22",
    "Beam Search": "#98df8a",
    "PUCT-MEAN": "#ff7f0e",
    "PUCT-EMA": "#2ca02c",
    "PUCT-BAYESIAN": "#d62728",
    "PUCT-VARIANCE": "#7f7f7f",
    "AB-MCTS": "#17becf",
    "Disc-MCTS": "#dbdb8d",
}

fig, axes = plt.subplots(2, 2, figsize=(18, 12))
fig.suptitle("ACE Search Strategy Comparison (GSM8K, 50 problems)", fontsize=14, fontweight="bold")

# --- (a) Running accuracy curves ---
ax = axes[0, 0]
for name, log in results.items():
    acc = log.running_accuracy
    ax.plot(range(1, len(acc) + 1), acc, label=name, color=COLORS[name], linewidth=2)
ax.set_xlabel("Problems Solved")
ax.set_ylabel("Running Accuracy")
ax.set_title("(a) Running Accuracy Over Time")
ax.legend(fontsize=5, loc="lower right", ncol=2)
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

# --- (b) Final accuracy bar chart with bootstrap CI ---
ax = axes[0, 1]

def bootstrap_ci(data, n_boot=1000, ci=0.95, seed=0):
    """Percentile-bootstrap CI for the mean (seeded so error bars are reproducible)."""
    data = np.array(data, dtype=float)
    rng = np.random.default_rng(seed)
    means = [np.mean(rng.choice(data, size=len(data), replace=True)) for _ in range(n_boot)]
    lo = np.percentile(means, (1 - ci) / 2 * 100)
    hi = np.percentile(means, (1 + ci) / 2 * 100)
    return np.mean(data), lo, hi

names = list(results.keys())
means, lows, highs = [], [], []
for name in names:
    tail = results[name].correct[-20:]
    m, lo, hi = bootstrap_ci(tail)
    means.append(m)
    lows.append(m - lo)
    highs.append(hi - m)

bars = ax.bar(names, means, color=[COLORS[n] for n in names], yerr=[lows, highs], capsize=3)
ax.set_ylabel("Accuracy (last 20 problems)")
ax.set_title("(b) Final Accuracy Comparison")
ax.set_ylim(0, 1)
ax.tick_params(axis='x', rotation=45, labelsize=7)
plt.setp(ax.get_xticklabels(), ha="right")  # align rotated labels with their bars
for bar, m in zip(bars, means):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, f"{m:.0%}",
            ha="center", fontsize=5)

# --- (c) Playbook size over time ---
ax = axes[1, 0]
for name, log in results.items():
    sizes = log.playbook_sizes
    ax.plot(range(1, len(sizes) + 1), sizes, label=name, color=COLORS[name], linewidth=2)
ax.set_xlabel("Problems Solved")
ax.set_ylabel("Playbook Bullets")
ax.set_title("(c) Playbook Size Over Time")
ax.legend(fontsize=5, ncol=2)
ax.grid(True, alpha=0.3)

# --- (d) Cost breakdown ---
ax = axes[1, 1]
roles = ["generate", "reflect", "curate"]
x = np.arange(len(names))
width = 0.22
for i, role in enumerate(roles):
    counts = [results[n].call_counts.get(role, 0) for n in names]
    ax.bar(x + i * width, counts, width, label=role.capitalize())
ax.set_xticks(x + width)
ax.set_xticklabels(names, fontsize=7, rotation=45, ha="right")
ax.set_ylabel("LLM Calls")
ax.set_title("(d) LLM Call Breakdown")
ax.legend()

plt.tight_layout()
plt.savefig("search_ace_results.png", dpi=150, bbox_inches="tight")
plt.show()
print("Plots saved to search_ace_results.png")


In [None]:
# --- Statistical comparison ---

print("=" * 60)
print("Statistical Comparison (Bootstrap 95% CI on accuracy difference vs Greedy)")
print("=" * 60)

def bootstrap_diff_ci(a, b, n_boot=2000, seed=0):
    """Unpaired bootstrap of mean(b) - mean(a); returns (mean diff, 2.5th pct, 97.5th pct)."""
    a, b = np.array(a, dtype=float), np.array(b, dtype=float)
    rng = np.random.default_rng(seed)  # seeded so the reported CIs are reproducible
    diffs = []
    for _ in range(n_boot):
        ia = rng.integers(0, len(a), len(a))
        ib = rng.integers(0, len(b), len(b))
        diffs.append(np.mean(b[ib]) - np.mean(a[ia]))
    lo, hi = np.percentile(diffs, [2.5, 97.5])
    return np.mean(diffs), lo, hi

greedy_tail = greedy_log.correct[-20:]

# Each condition vs Greedy
comparisons = [
    ("Majority Vote", majority_log.correct[-20:]),
    ("Best-of-N", best_of_n_log.correct[-20:]),
    ("Thompson", thompson_log.correct[-20:]),
    ("UCB Bandit", ucb_log.correct[-20:]),
    ("Thompson-Dyn", thompson_dyn_log.correct[-20:]),
    ("Thompson-Disc", thompson_disc_log.correct[-20:]),
    ("Beam Search", beam_log.correct[-20:]),
]
for q_mode in ["mean", "ema", "bayesian", "variance"]:
    comparisons.append((f"PUCT-{q_mode.upper()}", puct_logs[q_mode].correct[-20:]))
comparisons.append(("AB-MCTS", ab_mcts_log.correct[-20:]))
comparisons.append(("Disc-MCTS", disc_mcts_log.correct[-20:]))

for name, tail in comparisons:
    mean_diff, lo, hi = bootstrap_diff_ci(greedy_tail, tail)
    sig = "YES" if lo > 0 else ("YES (worse)" if hi < 0 else "NO")
    print(f"{name} vs Greedy: diff={mean_diff:+.1%} 95%CI=[{lo:+.1%}, {hi:+.1%}] significant={sig}")

# Thompson vs UCB (isolate randomness variable)
print()
t_tail = thompson_log.correct[-20:]
ucb_tail = ucb_log.correct[-20:]
mean_diff, lo, hi = bootstrap_diff_ci(t_tail, ucb_tail)
sig = "YES" if lo > 0 else ("YES (worse)" if hi < 0 else "NO")
print(f"UCB vs Thompson: diff={mean_diff:+.1%} 95%CI=[{lo:+.1%}, {hi:+.1%}] significant={sig}")

# Beam Search vs Greedy (does width>1 help?)
beam_tail = beam_log.correct[-20:]
mean_diff, lo, hi = bootstrap_diff_ci(greedy_tail, beam_tail)
sig = "YES" if lo > 0 else ("YES (worse)" if hi < 0 else "NO")
print(f"Beam vs Greedy: diff={mean_diff:+.1%} 95%CI=[{lo:+.1%}, {hi:+.1%}] significant={sig}")

# Beam Search vs PUCT-Mean (beam vs tree)
pm_tail = puct_logs["mean"].correct[-20:]
mean_diff, lo, hi = bootstrap_diff_ci(pm_tail, beam_tail)
sig = "YES" if lo > 0 else ("YES (worse)" if hi < 0 else "NO")
print(f"Beam vs PUCT-Mean: diff={mean_diff:+.1%} 95%CI=[{lo:+.1%}, {hi:+.1%}] significant={sig}")

# Disc-MCTS vs PUCT-Mean (does discounting help in tree search?)
disc_tail = disc_mcts_log.correct[-20:]
mean_diff, lo, hi = bootstrap_diff_ci(pm_tail, disc_tail)
sig = "YES" if lo > 0 else ("YES (worse)" if hi < 0 else "NO")
print(f"Disc-MCTS vs PUCT-Mean: diff={mean_diff:+.1%} 95%CI=[{lo:+.1%}, {hi:+.1%}] significant={sig}")

# Thompson-Dyn vs static Thompson
print()
td_tail = thompson_dyn_log.correct[-20:]
mean_diff, lo, hi = bootstrap_diff_ci(t_tail, td_tail)
sig = "YES" if lo > 0 else ("YES (worse)" if hi < 0 else "NO")
print(f"Thompson-Dyn vs Thompson: diff={mean_diff:+.1%} 95%CI=[{lo:+.1%}, {hi:+.1%}] significant={sig}")

# Thompson-Disc vs static Thompson
tdisc_tail = thompson_disc_log.correct[-20:]
mean_diff, lo, hi = bootstrap_diff_ci(t_tail, tdisc_tail)
sig = "YES" if lo > 0 else ("YES (worse)" if hi < 0 else "NO")
print(f"Thompson-Disc vs Thompson: diff={mean_diff:+.1%} 95%CI=[{lo:+.1%}, {hi:+.1%}] significant={sig}")

# Thompson vs each PUCT
print()
for q_mode in ["mean", "ema", "bayesian", "variance"]:
    p_tail = puct_logs[q_mode].correct[-20:]
    mean_diff, lo, hi = bootstrap_diff_ci(t_tail, p_tail)
    sig = "YES" if lo > 0 else ("YES (worse)" if hi < 0 else "NO")
    print(f"PUCT-{q_mode.upper()} vs Thompson: diff={mean_diff:+.1%} 95%CI=[{lo:+.1%}, {hi:+.1%}] significant={sig}")

# AB-MCTS vs PUCT-Bayesian
print()
ab_tail = ab_mcts_log.correct[-20:]
pb_tail = puct_logs["bayesian"].correct[-20:]
mean_diff, lo, hi = bootstrap_diff_ci(pb_tail, ab_tail)
sig = "YES" if lo > 0 else ("YES (worse)" if hi < 0 else "NO")
print(f"AB-MCTS vs PUCT-Bayesian: diff={mean_diff:+.1%} 95%CI=[{lo:+.1%}, {hi:+.1%}] significant={sig}")

# PUCT-VARIANCE vs PUCT-Bayesian
pv_tail = puct_logs["variance"].correct[-20:]
mean_diff, lo, hi = bootstrap_diff_ci(pb_tail, pv_tail)
sig = "YES" if lo > 0 else ("YES (worse)" if hi < 0 else "NO")
print(f"PUCT-VARIANCE vs PUCT-Bayesian: diff={mean_diff:+.1%} 95%CI=[{lo:+.1%}, {hi:+.1%}] significant={sig}")

# Best-of-N vs Majority Vote (same budget, different selection)
bon_tail = best_of_n_log.correct[-20:]
mv_tail = majority_log.correct[-20:]
mean_diff, lo, hi = bootstrap_diff_ci(mv_tail, bon_tail)
sig = "YES" if lo > 0 else ("YES (worse)" if hi < 0 else "NO")
print(f"Best-of-N vs Majority Vote: diff={mean_diff:+.1%} 95%CI=[{lo:+.1%}, {hi:+.1%}] significant={sig}")
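
`bootstrap_diff_ci` resamples each condition independently, which ignores that all conditions run on the same shuffled problem list. If two tails are aligned problem by problem (an assumption: it should hold when a condition logs exactly one `correct` entry per problem, and is worth checking for conditions that evaluate several candidates per problem), a paired bootstrap that draws one set of problem indices per replicate usually gives a tighter interval. A minimal sketch; `paired_bootstrap_diff_ci` is a hypothetical helper, not part of the notebook:

```python
import numpy as np

def paired_bootstrap_diff_ci(a, b, n_boot=2000, ci=0.95, seed=0):
    """Paired bootstrap of mean(b - a) over aligned per-problem 0/1 outcomes.

    Hypothetical helper: assumes a[i] and b[i] are results on the same problem.
    Returns (mean diff across replicates, lower bound, upper bound).
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if a.shape != b.shape:
        raise ValueError("paired bootstrap needs aligned, equal-length inputs")
    rng = np.random.default_rng(seed)
    # One shared index draw per replicate keeps each problem's (a, b) pair together.
    idx = rng.integers(0, len(a), size=(n_boot, len(a)))
    diffs = (b[idx] - a[idx]).mean(axis=1)
    tail = (1 - ci) / 2 * 100
    lo, hi = np.percentile(diffs, [tail, 100 - tail])
    return float(diffs.mean()), float(lo), float(hi)
```

It is called like the unpaired version, e.g. `paired_bootstrap_diff_ci(greedy_tail, disc_tail)`, and the same `lo > 0` / `hi < 0` reading applies.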


## 7. Results Summary

### Experiment Design
- **Dataset**: GSM8K (50 problems, shuffled, seed=42)
- **Model**: Qwen2.5-7B-Instruct (local via vLLM, bfloat16)
- **14 conditions** spanning no-evolution -> sequential -> flat bandit -> beam search -> tree search -> adaptive tree search
- **GPU optimization**: AsyncOpenAI with Semaphore(64), prefix caching, virtual loss parallel PUCT

### Budget Comparison
| Condition | Generate | Reflect | Curate | Total |
|-----------|----------|---------|--------|-------|
| Majority Vote | 100 | 0 | 0 | 100 |
| Best-of-N | 100 | 0 | 0 | 100 |
| Greedy ACE | 50 | 50 | 10 | 110 |
| Thompson | 50 | 50 | 5 | 105 |
| UCB Bandit | 50 | 50 | 5 | 105 |
| Thompson-Disc | 50 | 50 | 5 | 105 |
| Thompson-Dyn | 50 | 50 | ~9 | ~109 |
| Beam Search | ~150 | ~150 | ~27 | ~327 |
| PUCT-Mean | 50 | 50 | <=17 | <=117 |
| PUCT-EMA | 50 | 50 | <=17 | <=117 |
| PUCT-Bayesian | 50 | 50 | <=17 | <=117 |
| PUCT-Variance | 50 | 50 | <=17 | <=117 |
| AB-MCTS | 50 | 50 | variable | <=117 |
| Disc-MCTS | 50 | 50 | <=17 | <=117 |
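
The `<=17` / `<=117` bounds for the tree conditions follow from one generate and one reflect call per problem plus at most one curate call per evaluated batch (in `run_discounted_mcts`, the `should_expand` branch runs at most once per batch). A quick check, assuming every tree condition uses the `batch_size=3` default of `run_discounted_mcts` (the PUCT batch size is not shown in this section); `tree_budget` is a hypothetical helper:

```python
import math

def tree_budget(n_problems: int, batch_size: int) -> dict:
    """Upper bound on LLM calls for batch-evaluated tree search (hypothetical helper)."""
    n_batches = math.ceil(n_problems / batch_size)
    calls = {
        "generate": n_problems,  # one generation per problem
        "reflect": n_problems,   # one reflection per problem
        "curate": n_batches,     # at most one expansion (curation) per batch
    }
    calls["total"] = sum(calls.values())
    return calls

print(tree_budget(50, 3))  # {'generate': 50, 'reflect': 50, 'curate': 17, 'total': 117}
```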

### Strategy Taxonomy
| Strategy | Search Type | Evolves Playbook? | Exploration Mechanism | Parallelization |
|----------|------------|-------------------|----------------------|------------------|
| Majority Vote | None | No | Sample diversity only | All 100 calls parallel |
| Best-of-N | None | No | Rejection sampling (consistency as confidence) | All 100 calls parallel |
| Greedy ACE | Sequential | Yes | None (greedy) | Batch 5 gen + 5 ref between curate |
| Thompson | Flat bandit | Pool of variants (fixed) | Beta posterior sampling | Seed phase parallel |
| UCB Bandit | Flat bandit | Pool of variants (fixed) | Deterministic UCB1 bound | Seed phase parallel |
| Thompson-Disc | Flat bandit | Pool of variants (fixed) | Discounted Beta posterior (gamma=0.95) | Seed phase parallel |
| Thompson-Dyn | Flat bandit | Pool grows dynamically | Beta posterior + periodic arm addition | Seed phase parallel |
| Beam Search | Width-K beam | Yes (prune + curate each round) | K parallel candidates, top-K survival | All K beams evaluated in parallel |
| PUCT-Mean | Tree | Yes (progressive widening) | UCB with mean Q | Virtual loss, K=4 leaves |
| PUCT-EMA | Tree | Yes (progressive widening) | UCB with recency-weighted Q (alpha=0.4) | Virtual loss, K=4 leaves |
| PUCT-Bayesian | Tree | Yes (progressive widening) | UCB with shrinkage-to-0.5 Q | Virtual loss, K=4 leaves |
| PUCT-Variance | Tree | Yes (progressive widening) | UCB with variance-aware Q (arXiv:2512.21648) | Virtual loss, K=4 leaves |
| AB-MCTS | Tree (adaptive) | Yes (adaptive widening) | UCB + Thompson-sampled wider-vs-deeper | Async batch eval |
| Disc-MCTS | Tree (discounted) | Yes (progressive widening) | UCB with gamma-discounted Q (gamma=0.95) | Async batch eval |

### Key Questions
1. **Does evolution help at all?** Majority Vote vs Greedy — if MajVote wins, sampling > evolution
2. **Does smart selection beat random?** Best-of-N vs Majority Vote — same budget, different selection
3. **Flat vs tree?** Thompson vs PUCT — is tree structure worth the overhead at 50 problems?
4. **Beam vs tree?** Beam Search vs PUCT — does the simpler beam approach capture most of tree search gains?
5. **Stochastic vs deterministic bandits?** Thompson vs UCB — does randomized exploration beat deterministic UCB1?
6. **Does dynamic pool help?** Thompson-Dyn vs Thompson — does periodic arm addition improve over frozen pool?
7. **Does discounting help in bandits?** Thompson-Disc vs Thompson — does forgetting old observations improve adaptation?
8. **Does discounting help in trees?** Disc-MCTS vs PUCT-Mean — does gamma-decay in tree nodes help non-stationarity?
9. **Which Q estimator?** Mean vs EMA vs Bayesian vs Variance — does exploration strategy matter for PUCT?
10. **Fixed vs adaptive branching?** AB-MCTS vs PUCT-Bayesian — does Thompson-sampled wider/deeper beat fixed progressive widening?
11. **Free lunch?** Does any strategy consistently beat Greedy with matched budget?

### New Condition Designs

#### Best-of-N (Rejection Sampling)
Generates N solutions per problem with the static initial playbook, picks the most frequently occurring answer (consistency as a proxy for confidence). Same budget as Majority Vote (100 generate calls). Tests whether a "smart" tree search actually beats simply generating many random variations and picking the most consistent one.

#### Beam Search (Width-K)
The "missing link" between Greedy (width=1) and MCTS (complex tree). Maintains K=3 candidate playbooks. Each round: (1) evaluate all K beams on the same problem batch in parallel, (2) rank by cumulative accuracy, (3) prune to top-K, (4) curate each survivor. Hypothesis: captures ~80% of tree search gains with much simpler machinery (no Q estimates or backpropagation). It is not cheaper in LLM calls: the budget is ~327 calls, about 3x the other conditions, because all K beams are evaluated on every batch, so Beam Search comparisons are not budget-matched. The tradeoff is breadth of exploration vs. total compute.

#### UCB Bandit (Deterministic)
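Standard UCB1 (Auer et al., 2002) as the deterministic counterpart to Thompson Sampling. Selects: argmax_k [mean_reward_k + c * sqrt(ln(t) / n_k)], where t is the total number of pulls, n_k the pulls of arm k, and c = `UCB_C` (c = sqrt(2) recovers textbook UCB1); every arm is pulled once before the bound is used. Same seed/pool structure as Thompson. Isolates the "randomness" variable: Thompson uses stochastic posterior sampling, UCB uses a deterministic confidence bound. Which one wins here is an empirical question: Thompson Sampling is often competitive with or better than UCB1 in practice, while UCB1's choices are fully determined by the reward history and add no selection noise of their own.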
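
To make the "randomness" contrast concrete: for a fixed reward history, UCB1 always returns the same arm, while Thompson Sampling draws from each arm's posterior and can pick different arms on repeated calls. A sketch assuming a Beta(1, 1) prior for Thompson (the prior used by `run_thompson` is not shown here) and c = sqrt(2) for UCB1; both function names are hypothetical:

```python
import math
import numpy as np

def ucb1_scores(rewards, c=math.sqrt(2)):
    """UCB1 score per arm: mean + c * sqrt(ln t / n_k). Assumes every arm was pulled once."""
    t = sum(len(r) for r in rewards)
    return [sum(r) / len(r) + c * math.sqrt(math.log(t) / len(r)) for r in rewards]

def thompson_pick(rewards, rng):
    """Sample each arm's Beta(1 + successes, 1 + failures) posterior and pick the max."""
    draws = [rng.beta(1 + sum(r), 1 + len(r) - sum(r)) for r in rewards]
    return int(np.argmax(draws))

history = [[1, 1, 0], [1, 0], [1, 1, 1, 1]]  # per-arm 0/1 rewards
print(int(np.argmax(ucb1_scores(history))))   # 2 (scores ~1.88, ~1.98, ~2.05), on every call

rng = np.random.default_rng(0)
picks = [thompson_pick(history, rng) for _ in range(2000)]
print(np.bincount(picks, minlength=3))        # mostly arm 2, but spread across arms
```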

#### Discounted MCTS (gamma=0.95)
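Applies the non-stationarity fix from Thompson-Disc inside the MCTS tree. Standard PUCT suffers when a child node improves (via curation) but the parent's average Q is stale. Discounted MCTS uses an exponentially weighted mean Q: the most recent reward has weight 1, each older reward's weight is multiplied by gamma=0.95 per observation, and the weighted sum is normalized by the total weight. PUCT-EMA is also a recency weighting, so the difference is the forgetting rate rather than the absence of a fixed parameter: assuming the standard update Q <- (1 - alpha) * Q + alpha * r, alpha=0.4 shrinks older weights by 0.6 per observation, while gamma=0.95 keeps an effective window of about 1/(1-gamma) = 20 observations.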
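
A vectorized restatement of `discounted_q` from `run_discounted_mcts`, next to a plain mean and an EMA, on a node whose rewards jump from 0 to 1 (a playbook that improved). The EMA form Q <- (1 - alpha) * Q + alpha * r and its starting value of 0.5 are assumptions for illustration, not the notebook's PUCT-EMA code; `discounted_mean` and `ema` are hypothetical names:

```python
import numpy as np

def discounted_mean(history, gamma=0.95):
    """Gamma-weighted mean: newest reward weight 1, the one before gamma, then gamma**2, ..."""
    if len(history) == 0:
        return 0.5  # same neutral value discounted_q returns for unvisited nodes
    h = np.asarray(history, dtype=float)
    weights = gamma ** np.arange(len(h) - 1, -1, -1)  # oldest reward gets gamma**(n-1)
    return float(np.dot(weights, h) / weights.sum())

def ema(history, alpha=0.4, init=0.5):
    """Assumed EMA form: each new reward gets weight alpha, older weights shrink by 1 - alpha."""
    q = init
    for r in history:
        q = (1 - alpha) * q + alpha * r
    return q

history = [0.0] * 10 + [1.0] * 5  # ten bad batches, then five good ones
print(np.mean(history), discounted_mean(history), ema(history))  # ~0.33, ~0.42, ~0.92
```

The discounted mean moves toward the new regime, but far more slowly than an alpha=0.4 EMA, because its weights decay by 0.95 per observation instead of 0.6.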

### Previous Condition Designs

#### Thompson-Disc (non-stationary bandits, arXiv:2305.10718)
Thompson-Disc applies exponential discounting to Beta posteriors AFTER each update:
- After each round, multiply all arms' alpha and beta by gamma=0.95
- Floor at 0.1 (not 1.0) to allow proper forgetting
- Effective lookback window: 1/(1-gamma) = 20 problems
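
The three bullets above amount to the following update, sketched under assumptions: one 0/1 reward per problem, a Beta(1, 1) starting prior, and the discount applied to every arm right after the chosen arm's count update. `discounted_update` and `thompson_select` are hypothetical names, not the notebook's `run_thompson_discounted`:

```python
import numpy as np

def discounted_update(alpha, beta, chosen, reward, gamma=0.95, floor=0.1):
    """Add the chosen arm's 0/1 reward, then decay every arm's counts toward the floor."""
    alpha = np.array(alpha, dtype=float)  # copies, so the caller's arrays are untouched
    beta = np.array(beta, dtype=float)
    alpha[chosen] += reward
    beta[chosen] += 1.0 - reward
    alpha = np.maximum(alpha * gamma, floor)
    beta = np.maximum(beta * gamma, floor)
    return alpha, beta

def thompson_select(alpha, beta, rng):
    """Draw one sample per arm from Beta(alpha_k, beta_k) and pick the largest."""
    return int(np.argmax(rng.beta(alpha, beta)))

# Arm 0 wins 50 times in a row, then loses 30 times; arms 1 and 2 are never pulled.
alpha, beta = np.ones(3), np.ones(3)
for t in range(80):
    alpha, beta = discounted_update(alpha, beta, chosen=0, reward=1.0 if t < 50 else 0.0)
print(alpha[0] / (alpha[0] + beta[0]))  # ~0.2: the posterior mean tracks the recent losses
```

Without discounting, arm 0's posterior mean would still be 51/82 (about 0.62) after the losing streak. The never-pulled arms decay to Beta(0.1, 0.1), whose samples pile up near 0 and 1, so they get re-tried.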

#### AB-MCTS (Adaptive Progressive Widening, inspired by arXiv:2503.04412)
Replaces PUCT's fixed progressive widening with an adaptive decision per node:
- Each node maintains Beta(expand_alpha, expand_beta) posterior
- Thompson-sample to decide: go WIDER (curate new child) or go DEEPER (re-evaluate)
- Posterior updated only on direct evidence (no counterfactual assumptions)
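
A minimal sketch of the per-node wider-vs-deeper decision. Two details are assumptions, not taken from `run_ab_mcts`: the sampled theta is compared against 0.5 to choose WIDER, and "direct evidence" is read as: only a WIDER step updates the posterior, counting a success when the new child's reward beats the node's current mean. `WidenPosterior` is a hypothetical name:

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class WidenPosterior:
    """Per-node Beta posterior over 'widening pays off' (hypothetical sketch)."""
    expand_alpha: float = 1.0
    expand_beta: float = 1.0

    def choose(self, rng: np.random.Generator) -> str:
        # Thompson step: a single posterior sample decides this visit's action.
        theta = rng.beta(self.expand_alpha, self.expand_beta)
        return "wider" if theta > 0.5 else "deeper"

    def record_widen(self, child_reward: float, node_mean: float) -> None:
        # Only called after a WIDER step, whose new child was actually evaluated.
        if child_reward > node_mean:
            self.expand_alpha += 1.0
        else:
            self.expand_beta += 1.0

rng = np.random.default_rng(0)
node = WidenPosterior()
node.record_widen(child_reward=2 / 3, node_mean=0.5)  # the new child did better
print(node, node.choose(rng))
```

From the uniform Beta(1, 1) start, a fresh node widens or deepens with equal probability; each successful widening tilts it toward adding more children.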

#### PUCT-Variance (arXiv:2512.21648)
Variance-aware Q-estimator: Q = mean(rewards) + 0.5 * sqrt(var(rewards) / n)
- High variance nodes get explored more (UCB-V style)
- Hypothesis: better calibrated exploration for heteroscedastic rewards
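
The formula can be checked by hand. A sketch assuming population variance (`np.var` with its default ddof=0) and c=0.5 as written above, with rewards as per-batch accuracies in [0, 1]. `variance_q` is a hypothetical name, and an unvisited node would need a separate default:

```python
import math
import numpy as np

def variance_q(rewards, c=0.5):
    """Q = mean + c * sqrt(var / n) over a node's non-empty reward history."""
    r = np.asarray(rewards, dtype=float)
    if r.size == 0:
        raise ValueError("variance_q needs at least one reward")
    return float(r.mean() + c * math.sqrt(r.var() / r.size))

print(variance_q([1.0, 1.0, 1.0, 1.0]))  # 1.0: no variance, no bonus
print(variance_q([1.0, 0.0, 1.0, 1.0]))  # 0.75 + 0.5 * sqrt(0.1875 / 4), about 0.858
```

Two nodes with the same mean of 0.75, one steady and one noisy, get different Q values: the noisy one receives the larger bonus and is revisited sooner, and the bonus shrinks as n grows.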


In [None]:
# --- 8. Save & Export Results ---
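import os
import shutil
import time

try:
    from google.colab import files  # Colab-only helper for browser downloads
except ImportError:
    files = None  # outside Colab; the download step below falls back to a message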

def generate_report(results, filename="experiment_report.md"):
    with open(filename, "w") as f:
        f.write("# Search-Augmented ACE: Experiment Report\n\n")
        f.write(f"**Date:** {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")

        f.write("## 1. Final Accuracy Summary\n")
        f.write("| Strategy | Accuracy (Last 20) | Budget Used |\n")
        f.write("| :--- | :--- | :--- |\n")

        for name, log in results.items():
            acc = log.final_accuracy
            calls = sum(log.call_counts.values())
            f.write(f"| {name} | {acc:.2%} | {calls} |\n")

        f.write("\n## 2. Strategy Details\n")
        for name, log in results.items():
            f.write(f"### {name}\n")
            f.write(f"- **Final Playbook Size:** {log.playbook_sizes[-1] if log.playbook_sizes else 0} bullets\n")
            f.write(f"- **Call Breakdown:** {dict(log.call_counts)}\n")
            if log.final_playbook:
                f.write("- **Final Playbook Preview:**\n")
                content = log.final_playbook.to_str()
                # Truncate if too long for report
                preview = content[:500] + "..." if len(content) > 500 else content
                f.write(f"```\n{preview}\n```\n")
            f.write("\n")

# 1. Generate human-readable report
generate_report(results, "ace_experiment_summary.md")
print("Generated readable report: ace_experiment_summary.md")

# 2. Archive checkpoints + plot + report into one zip.
# The .pkl checkpoints let you reanalyze later without rerunning any LLM calls.
import zipfile

final_zip = "full_experiment_results.zip"
with zipfile.ZipFile(final_zip, "w", zipfile.ZIP_DEFLATED) as zf:
    for dirpath, _, fnames in os.walk("checkpoints"):
        for fname in fnames:
            zf.write(os.path.join(dirpath, fname))
    for extra in ("ace_experiment_summary.md", "search_ace_results.png"):
        if os.path.exists(extra):
            zf.write(extra)

print(f"Created archive: {final_zip}")

# 3. Download to your local machine (Colab only; elsewhere, copy the zip manually)
try:
    files.download(final_zip)
except Exception as e:
    print(f"Auto-download failed ({e!r}). Download '{final_zip}' manually from the file explorer on the left.")

In [None]:
# --- Cleanup: kill vLLM server ---
try:
    os.kill(vllm_proc.pid, signal.SIGTERM)
    print(f"Killed vLLM server (PID {vllm_proc.pid})")
except ProcessLookupError:
    print("vLLM server already stopped")