# Verified Archetype Discovery PoC

**Hypothesis**: Identifying "archetype" problems and presenting them first improves learning compared to a random curriculum. The key innovation is rigorous verification of strategies before trusting them.

**Game of 24**: Given 4 numbers, use +, -, *, / to make 24 (each number used exactly once).

**Approach**:
1. Embed problems and cluster to find representative archetypes
2. Rigorously verify strategies extracted from archetypes
3. Bootstrap playbook from verified strategies
4. Use LinUCB contextual bandit for curriculum selection
5. Compare archetype-first vs random curriculum
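
Step 4 names LinUCB, but the selection rule is not shown in this part of the notebook. A minimal sketch of disjoint LinUCB follows, assuming one arm per problem cluster and a small context vector. The names `LinUCBArm` and `select_arm`, and the default `alpha`, are illustrative assumptions and not the notebook's implementation.

```python
import math
from typing import List


class LinUCBArm:
    """One arm of disjoint LinUCB; stores A^-1 and b for a d-dimensional context."""

    def __init__(self, dim: int, alpha: float = 1.0):
        self.alpha = alpha
        # A starts as the identity matrix, so its inverse is also the identity.
        self.a_inv = [[1.0 if i == j else 0.0 for j in range(dim)] for i in range(dim)]
        self.b = [0.0] * dim

    def _a_inv_times(self, x: List[float]) -> List[float]:
        return [sum(row[j] * x[j] for j in range(len(x))) for row in self.a_inv]

    def ucb(self, x: List[float]) -> float:
        """Estimated reward plus exploration bonus for context x."""
        a_inv_x = self._a_inv_times(x)
        theta = self._a_inv_times(self.b)
        mean = sum(t * xi for t, xi in zip(theta, x))
        bonus = self.alpha * math.sqrt(sum(xi * v for xi, v in zip(x, a_inv_x)))
        return mean + bonus

    def update(self, x: List[float], reward: float) -> None:
        """Apply A += x x^T via the Sherman-Morrison formula, and b += reward * x."""
        a_inv_x = self._a_inv_times(x)
        denom = 1.0 + sum(xi * v for xi, v in zip(x, a_inv_x))
        for i in range(len(x)):
            for j in range(len(x)):
                self.a_inv[i][j] -= a_inv_x[i] * a_inv_x[j] / denom
        self.b = [bi + reward * xi for bi, xi in zip(self.b, x)]


def select_arm(arms: List[LinUCBArm], x: List[float]) -> int:
    """Pick the arm with the highest upper confidence bound (first one on ties)."""
    scores = [arm.ucb(x) for arm in arms]
    return scores.index(max(scores))


arms = [LinUCBArm(dim=2), LinUCBArm(dim=2)]
context = [1.0, 0.0]
arms[0].update(context, reward=1.0)  # arm 0 solved a problem in this context
arms[1].update(context, reward=0.0)  # arm 1 failed
print(select_arm(arms, context))
```

Keeping the inverse up to date with a rank-one update avoids a matrix inversion at every step, which matters little at this dimension but keeps the sketch free of external dependencies.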

**Setup**: Qwen2.5-7B-Instruct via vLLM on A100 (bfloat16, prefix caching, async parallel eval).

## 1. Setup & Dependencies

In [None]:
!pip install "numpy<2.0"



In [None]:
!pip install vllm==0.6.6 openai==1.58.1 datasets==3.2.0 matplotlib==3.9.3 numpy==1.26.4 nest_asyncio==1.6.0 sentence-transformers scikit-learn==1.4.0 scipy==1.12.0


In [None]:
import subprocess
import time
import os
import signal
import json
import re
import copy
import math
import random
import pickle
import asyncio
import itertools
import nest_asyncio
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Set, Any
from collections import defaultdict, Counter
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_distances
from scipy import stats
from sentence_transformers import SentenceTransformer
from openai import OpenAI, AsyncOpenAI

# Allow nested event loops (required for Colab/Jupyter)
nest_asyncio.apply()

SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Checkpoint directory
CHECKPOINT_DIR = Path("checkpoints_archetype")
CHECKPOINT_DIR.mkdir(exist_ok=True)

# Concurrency for async LLM calls
MAX_CONCURRENT_LLM = 128

print(f"Setup complete. Seed={SEED}")

In [None]:
# Launch vLLM server in background
VLLM_PORT = 8000
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

print("Launching vLLM server...")
vllm_log = open('/tmp/vllm_server.log', 'w')

vllm_proc = subprocess.Popen(
    [
        "python", "-m", "vllm.entrypoints.openai.api_server",
        "--model", MODEL_NAME,
        "--port", str(VLLM_PORT),
        "--max-model-len", "8192",
        "--gpu-memory-utilization", "0.95",
        "--dtype", "bfloat16",
        "--max-num-seqs", "1024",
        "--max-num-batched-tokens", "16384",
        "--enable-prefix-caching",
        "--disable-log-requests",
    ],
    stdout=vllm_log,
    stderr=subprocess.STDOUT,
)

# Wait for server to be ready
client = OpenAI(base_url=f"http://localhost:{VLLM_PORT}/v1", api_key="dummy")
aclient = AsyncOpenAI(base_url=f"http://localhost:{VLLM_PORT}/v1", api_key="dummy")
print("Waiting for vLLM server to be ready...")

for attempt in range(600):
    try:
        client.models.list()
        print(f"vLLM ready after {attempt + 1}s")
        break
    except Exception:
        time.sleep(1)
else:
    raise RuntimeError("vLLM server failed to start within 600s")

# Warmup: send a few small concurrent batches so one-time startup costs
# (connection setup, prefix-cache population, lazy initialization) are paid
# before the experiment starts. vLLM typically captures its CUDA graphs during
# server startup, so these requests are not what compiles them.
print("Warming up vLLM...")

async def warmup_vllm():
    """Send warmup requests at several batch sizes."""
    warmup_prompt = "What is 2+2? Answer briefly."

    for batch_size in [1, 4, 16, 32]:
        print(f"  Warming up batch size {batch_size}...")
        tasks = [
            aclient.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "user", "content": warmup_prompt}],
                max_tokens=16,
                temperature=0.0,
            )
            for _ in range(batch_size)
        ]
        try:
            await asyncio.gather(*tasks)
        except Exception as e:
            print(f"  Warmup batch {batch_size} failed (non-fatal): {e}")

    print("Warmup complete.")

asyncio.run(warmup_vllm())

print(f"vLLM server running on port {VLLM_PORT} (PID: {vllm_proc.pid})")
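
The readiness loop above keeps polling even if the server process has already crashed. A minimal sketch of a wait helper with a monotonic deadline and a liveness check follows. The names `wait_until_ready`, `probe`, and `is_alive` are hypothetical; in this notebook `probe` could wrap `client.models.list()` and `is_alive` could be `lambda: vllm_proc.poll() is None`.

```python
import time
from typing import Callable


def wait_until_ready(
    probe: Callable[[], object],
    is_alive: Callable[[], bool],
    timeout_s: float = 600.0,
    interval_s: float = 1.0,
) -> float:
    """Poll `probe` until it succeeds and return the elapsed seconds.

    Raises RuntimeError if the process exits or the deadline passes.
    """
    start = time.monotonic()
    while True:
        if not is_alive():
            raise RuntimeError("server process exited before becoming ready")
        try:
            probe()
            return time.monotonic() - start
        except Exception:
            if time.monotonic() - start >= timeout_s:
                raise RuntimeError(f"server not ready within {timeout_s:.0f}s") from None
            time.sleep(interval_s)
```

Failing fast on a dead process makes it clearer that the log file at `/tmp/vllm_server.log` is the place to look, rather than waiting out the full timeout.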

## 2. Game of 24 Data Loading & Validation

In [None]:
# --- Game of 24 Problem Generation & Validation ---
from concurrent.futures import ProcessPoolExecutor, as_completed

def generate_all_game24_problems(max_num: int = 13) -> List[Tuple[int, int, int, int]]:
    """Generate all unique Game of 24 problems with numbers 1-max_num."""
    problems = set()
    for a in range(1, max_num + 1):
        for b in range(a, max_num + 1):
            for c in range(b, max_num + 1):
                for d in range(c, max_num + 1):
                    problems.add((a, b, c, d))
    return list(problems)


def solve_24_exhaustive(numbers: Tuple[int, ...], target: float = 24.0) -> List[str]:
    """
    Find all distinct solutions to Game of 24 using exhaustive search.
    Returns list of expression strings that evaluate to target.

    Each intermediate value carries its own expression string, so the final
    expression is built structurally instead of by substring replacement
    (which can corrupt expressions, e.g. replacing the "2" inside "12").
    """
    solutions: Set[str] = set()
    ops = [('+', lambda a, b: a + b),
           ('-', lambda a, b: a - b),
           ('*', lambda a, b: a * b),
           ('/', lambda a, b: a / b)]

    def _search(items: Tuple[Tuple[float, str], ...]) -> None:
        # Base case: one value left; record its expression if it hits the target
        if len(items) == 1:
            value, expr = items[0]
            if abs(value - target) < 1e-9:
                solutions.add(expr)
            return

        # Try all ordered pairs of remaining values
        for i in range(len(items)):
            for j in range(len(items)):
                if i == j:
                    continue
                (a, a_expr), (b, b_expr) = items[i], items[j]
                remaining = tuple(items[k] for k in range(len(items)) if k != i and k != j)

                for op_str, op_func in ops:
                    if op_str == '/' and b == 0:
                        continue  # skip division by zero
                    result = op_func(a, b)
                    # Recurse with the combined value and its expression
                    _search(remaining + ((result, f"({a_expr} {op_str} {b_expr})"),))

    _search(tuple((n, str(n)) for n in numbers))
    return list(solutions)


def is_solvable_24(numbers: Tuple[int, ...]) -> bool:
    """Check if a Game of 24 problem is solvable."""
    return len(solve_24_exhaustive(numbers)) > 0


def count_solutions_24(numbers: Tuple[int, ...]) -> int:
    """Count the number of distinct solutions."""
    return len(solve_24_exhaustive(numbers))


def _check_problem(nums: Tuple[int, ...]) -> Optional[Dict]:
    """Worker function for parallel solvability check."""
    solutions = solve_24_exhaustive(nums)
    if solutions:
        return {
            "numbers": nums,
            "text": f"Use {nums[0]}, {nums[1]}, {nums[2]}, {nums[3]} to make 24",
            "n_solutions": len(solutions),
        }
    return None


# Generate solvable problems
print("Generating Game of 24 problems...")
all_problems = generate_all_game24_problems(max_num=13)
print(f"Total candidate problems: {len(all_problems)}")

# Filter to solvable problems using parallel processing
# The exhaustive search is CPU-bound, so spread it across worker processes
print("Filtering to solvable problems (parallel)...")
solvable_problems = []

# Use ProcessPoolExecutor for CPU-bound exhaustive search
N_WORKERS = min(8, os.cpu_count() or 4)
print(f"  Using {N_WORKERS} worker processes...")

with ProcessPoolExecutor(max_workers=N_WORKERS) as executor:
    futures = {executor.submit(_check_problem, nums): nums for nums in all_problems}
    done_count = 0
    for future in as_completed(futures):
        done_count += 1
        result = future.result()
        if result is not None:
            solvable_problems.append(result)
        if done_count % 200 == 0:
            print(f"  Processed {done_count}/{len(all_problems)}...")

print(f"Solvable problems: {len(solvable_problems)}")

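# Shuffle and take subset for experiment.
# as_completed yields results in nondeterministic order, so sort first to
# make the seeded shuffle reproducible across runs.
solvable_problems.sort(key=lambda p: p["numbers"])
rng = random.Random(SEED)
rng.shuffle(solvable_problems)
problems = solvable_problems[:200]  # Use 200 for clustering, 100 for eval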

print(f"\nUsing {len(problems)} problems for experiment")
print(f"Example: {problems[0]}")
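
The nested loops in `generate_all_game24_problems` enumerate multisets of four numbers drawn from 1 to 13, so the candidate count printed above can be checked in closed form: C(13 + 4 - 1, 4) = 1820. A brief standard-library check, independent of the notebook's function:

```python
import itertools
import math

max_num = 13

# Every non-decreasing 4-tuple over 1..max_num, i.e. every size-4 multiset.
multisets = list(itertools.combinations_with_replacement(range(1, max_num + 1), 4))

# Stars-and-bars count of size-4 multisets drawn from max_num values.
expected = math.comb(max_num + 4 - 1, 4)

print(len(multisets), expected)  # both are 1820
```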

In [None]:
# --- Known Hard Game of 24 Problems (for adversarial testing) ---
# These require non-obvious operations like fractions or specific orderings

HARD_PROBLEMS = [
    # Requires fractions
    {"numbers": (1, 5, 5, 5), "text": "Use 1, 5, 5, 5 to make 24", "solution": "5 * (5 - 1/5) = 24", "difficulty": "fraction"},
    {"numbers": (3, 3, 8, 8), "text": "Use 3, 3, 8, 8 to make 24", "solution": "8 / (3 - 8/3) = 24", "difficulty": "fraction"},
    {"numbers": (1, 3, 4, 6), "text": "Use 1, 3, 4, 6 to make 24", "solution": "6 / (1 - 3/4) = 24", "difficulty": "fraction"},
    {"numbers": (1, 4, 5, 6), "text": "Use 1, 4, 5, 6 to make 24", "solution": "4 / (1 - 5/6) = 24", "difficulty": "fraction"},
    {"numbers": (1, 6, 6, 8), "text": "Use 1, 6, 6, 8 to make 24", "solution": "6 / (1 - 6/8) = 24", "difficulty": "tricky"},
    # Requires specific ordering
    {"numbers": (2, 3, 5, 12), "text": "Use 2, 3, 5, 12 to make 24", "solution": "12 / (3 - 5/2) = 24", "difficulty": "ordering"},
    {"numbers": (1, 2, 7, 7), "text": "Use 1, 2, 7, 7 to make 24", "solution": "(7 * 7 - 1) / 2 = 24", "difficulty": "tricky"},
    {"numbers": (4, 4, 7, 7), "text": "Use 4, 4, 7, 7 to make 24", "solution": "(4 - 4/7) * 7 = 24", "difficulty": "tricky"},
    {"numbers": (3, 3, 7, 7), "text": "Use 3, 3, 7, 7 to make 24", "solution": "(3 + 3/7) * 7 = 24", "difficulty": "fraction"},
    {"numbers": (2, 5, 5, 10), "text": "Use 2, 5, 5, 10 to make 24", "solution": "(5 - 2/10) * 5 = 24", "difficulty": "tricky"},
]

# Verify and correct hard problems
verified_hard = []
for hp in HARD_PROBLEMS:
    nums = hp["numbers"]
    solutions = solve_24_exhaustive(nums)
    if solutions:
        hp["verified_solutions"] = solutions[:3]  # Keep up to 3
        hp["n_solutions"] = len(solutions)
        verified_hard.append(hp)
        print(f"{nums}: {len(solutions)} solutions - {solutions[0][:50]}...")
    else:
        print(f"{nums}: NO SOLUTION (removing from hard set)")

HARD_PROBLEMS = verified_hard
print(f"\nVerified {len(HARD_PROBLEMS)} hard problems")
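
Several of the hard problems depend on fractional intermediates, where floating-point evaluation needs a tolerance. As a cross-check, hand-written solutions can be evaluated exactly with `fractions.Fraction`. This is a minimal sketch; the `exact_value` helper is hypothetical and only handles +, -, *, / on integer literals.

```python
import ast
import operator
from fractions import Fraction

_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def exact_value(expr: str) -> Fraction:
    """Evaluate an arithmetic expression exactly using rational numbers."""
    def _eval(node: ast.AST) -> Fraction:
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and type(node.value) is int:
            return Fraction(node.value)
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](_eval(node.left), _eval(node.right))
        raise ValueError(f"unsupported syntax: {type(node).__name__}")

    return _eval(ast.parse(expr, mode="eval"))


for expr in ["5 * (5 - 1/5)", "8 / (3 - 8/3)", "6 / (1 - 3/4)", "(3 + 3/7) * 7"]:
    print(expr, "=", exact_value(expr))
```

With exact rationals the comparison `exact_value(expr) == 24` needs no epsilon, which removes any doubt about near-misses such as 23.999999.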

## 3. Problem Embedding & Clustering

In [None]:
# Load sentence transformer for embeddings
# GPU optimization: Use CUDA with bfloat16 for A100 (matches vLLM dtype)
print("Loading sentence transformer model (GPU, bfloat16)...")
embed_model = SentenceTransformer('all-MiniLM-L6-v2', device='cuda', model_kwargs={"torch_dtype": "bfloat16"})
print(f"Model loaded on {embed_model.device}")


def compute_numerical_features(nums: Tuple[int, ...]) -> np.ndarray:
    """Compute numerical features for a Game of 24 problem."""
    return np.array([
        sum(nums) / 52,  # sum normalized by max possible (13*4)
        np.prod(nums) / (13**4),  # product normalized
        max(nums) / 13,  # max normalized
        min(nums) / 13,  # min normalized
        len(set(nums)) / 4,  # uniqueness ratio
        (max(nums) - min(nums)) / 12,  # range normalized
        np.std(nums) / 5,  # std normalized
        sum(1 for n in nums if n % 2 == 0) / 4,  # even ratio
        sum(1 for n in nums if 24 % n == 0) / 4,  # divisor of 24 ratio
        1.0 if any(a * b == 24 for a, b in itertools.combinations(nums, 2)) else 0.0,  # has factor pair
    ]) * 2.0  # Scale up numerical importance


def compute_problem_features_batch(problems_list: List[Dict], batch_size: int = 128) -> np.ndarray:
    """
    Compute feature vectors for all problems using batched encoding.

    CUDA optimization: Batched GPU inference maximizes device utilization
    and minimizes kernel launch overhead. Increased batch_size for A100.
    """
    # Batch encode all texts at once (GPU-efficient)
    texts = [p["text"] for p in problems_list]
    print(f"  Batch encoding {len(texts)} texts...")
    text_embeddings = embed_model.encode(
        texts,
        batch_size=batch_size,
        convert_to_numpy=True,
        show_progress_bar=True,
    )

    # Compute numerical features (CPU, vectorized where possible)
    print("  Computing numerical features...")
    num_features = np.array([
        compute_numerical_features(p["numbers"]) for p in problems_list
    ])

    # Concatenate: [text_embedding, scaled_num_features]
    combined = np.concatenate([text_embeddings, num_features], axis=1)
    return combined


# Compute embeddings for all problems (batched for GPU efficiency)
print("Computing problem embeddings (batched)...")
embeddings = compute_problem_features_batch(problems, batch_size=128)

# Assign embeddings back to problems
for i, prob in enumerate(problems):
    prob["embedding"] = embeddings[i]

print(f"Embeddings shape: {embeddings.shape}")

In [None]:
# K-means clustering
N_CLUSTERS = 10

print(f"Clustering into {N_CLUSTERS} clusters...")
kmeans = KMeans(n_clusters=N_CLUSTERS, random_state=SEED, n_init=10)
cluster_labels = kmeans.fit_predict(embeddings)

# Assign cluster info to problems
for i, prob in enumerate(problems):
    prob["cluster"] = cluster_labels[i]
    # Distance to cluster center (for centrality)
    center = kmeans.cluster_centers_[cluster_labels[i]]
    prob["dist_to_center"] = np.linalg.norm(prob["embedding"] - center)

# Cluster statistics
cluster_sizes = Counter(cluster_labels)
print(f"\nCluster sizes: {dict(cluster_sizes)}")

# Show example from each cluster
print("\nCluster examples:")
for c in range(N_CLUSTERS):
    cluster_probs = [p for p in problems if p["cluster"] == c]
    if cluster_probs:
        ex = min(cluster_probs, key=lambda p: p["dist_to_center"])
        print(f"  Cluster {c}: {ex['numbers']} (size={len(cluster_probs)})")

## 4. Archetype Identification Pipeline

In [None]:
# --- 4.1 Embedding Centrality ---

def compute_centrality_score(problem: Dict, cluster_probs: List[Dict]) -> float:
    """
    Centrality = 1 - (distance to cluster center / max distance in cluster).
    Ranges from about 0 (farthest member) to 1 (at the center); higher is better.
    """
    # Normalize by max distance in cluster
    max_dist = max(p["dist_to_center"] for p in cluster_probs) + 1e-6
    return 1.0 - (problem["dist_to_center"] / max_dist)


# Compute centrality for all problems
for c in range(N_CLUSTERS):
    cluster_probs = [p for p in problems if p["cluster"] == c]
    for prob in cluster_probs:
        prob["centrality"] = compute_centrality_score(prob, cluster_probs)

print("Centrality scores computed.")
print(f"Example: {problems[0]['numbers']} centrality = {problems[0]['centrality']:.3f}")

In [None]:
# --- 4.2 Solution Diversity (via LLM) ---

call_counter = defaultdict(int)
_llm_semaphore = asyncio.Semaphore(MAX_CONCURRENT_LLM)

# Retry configuration
MAX_RETRIES = 3
RETRY_DELAY = 1.0


async def llm_call_async(
    system: str,
    user: str,
    role: str = "generate",
    temperature: float = 0.7,
    max_tokens: int = 512,  # Game of 24 answers are short; cap decoding length
    retries: int = MAX_RETRIES,
) -> str:
    """
    Async LLM call with semaphore-based concurrency control and retry logic.

    Args:
        system: System prompt.
        user: User prompt.
        role: Role tag for call counting.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens to generate.
        retries: Number of retries on failure.

    Returns:
        Generated text, or empty string on failure.
    """
    call_counter[role] += 1

    for attempt in range(retries):
        async with _llm_semaphore:
            try:
                resp = await aclient.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system", "content": system},
                        {"role": "user", "content": user},
                    ],
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                content = resp.choices[0].message.content
                return content.strip() if content else ""
            except Exception as e:
                if attempt < retries - 1:
                    await asyncio.sleep(RETRY_DELAY * (attempt + 1))
                else:
                    print(f"Async LLM call failed after {retries} attempts ({role}): {e}")
                    return ""


def llm_call(
    system: str,
    user: str,
    role: str = "generate",
    temperature: float = 0.7,
    max_tokens: int = 512,  # Game of 24 answers are short; cap decoding length
    retries: int = MAX_RETRIES,
) -> str:
    """
    Sync LLM call with retry logic.

    Args:
        system: System prompt.
        user: User prompt.
        role: Role tag for call counting.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens to generate.
        retries: Number of retries on failure.

    Returns:
        Generated text, or empty string on failure.
    """
    call_counter[role] += 1

    for attempt in range(retries):
        try:
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": user},
                ],
                temperature=temperature,
                max_tokens=max_tokens,
            )
            content = resp.choices[0].message.content
            return content.strip() if content else ""
        except Exception as e:
            if attempt < retries - 1:
                time.sleep(RETRY_DELAY * (attempt + 1))
            else:
                print(f"LLM call failed after {retries} attempts ({role}): {e}")
                return ""


GAME24_SYSTEM = """You are a Game of 24 solver. Given 4 numbers, find an expression using +, -, *, / that equals 24.
Each number must be used exactly once. Show your work step by step.

Format your final answer as: ANSWER: <expression> = 24

Example:
Numbers: 1, 2, 3, 4
Solution: 1 * 2 * 3 * 4 = 24
ANSWER: 1 * 2 * 3 * 4 = 24"""


async def get_llm_solutions(problem: Dict, n_samples: int = 5) -> List[str]:
    """Get N diverse solutions from LLM for a problem."""
    nums = problem["numbers"]
    user_prompt = f"Numbers: {nums[0]}, {nums[1]}, {nums[2]}, {nums[3]}\n\nFind an expression that equals 24."

    # CUDA optimization: max_tokens=512 since Game of 24 answers are short (~50 tokens)
    tasks = [
        llm_call_async(GAME24_SYSTEM, user_prompt, role="diversity", temperature=0.9, max_tokens=512)
        for _ in range(n_samples)
    ]
    responses = await asyncio.gather(*tasks, return_exceptions=True)

    # Extract expressions
    solutions = []
    for resp in responses:
        if isinstance(resp, Exception) or not resp:
            continue
        match = re.search(r"ANSWER:\s*(.+?)\s*=\s*24", resp, re.IGNORECASE)
        if match:
            solutions.append(match.group(1).strip())
        else:
            # Fallback: look for any expression = 24
            match2 = re.search(r"([\d\s+\-*/()]+)\s*=\s*24", resp)
            if match2:
                solutions.append(match2.group(1).strip())

    return solutions


async def _compute_diversity_for_problem(prob: Dict, n_samples: int = 5) -> Dict:
    """Compute diversity for a single problem. Returns a new dict with the results."""
    llm_solutions = await get_llm_solutions(prob, n_samples=n_samples)
    unique_solutions = set(llm_solutions)
    return {
        "numbers": prob["numbers"],
        "llm_solutions": llm_solutions,
        # Fraction of requested samples that produced a distinct expression string
        "diversity": len(unique_solutions) / n_samples,
    }


# Compute diversity for candidate archetypes (top 3 per cluster by centrality)
print("Computing solution diversity for archetype candidates...")
candidates = []
for c in range(N_CLUSTERS):
    cluster_probs = sorted(
        [p for p in problems if p["cluster"] == c],
        key=lambda p: -p["centrality"]
    )[:3]
    candidates.extend(cluster_probs)

print(f"Evaluating {len(candidates)} candidates...")


async def compute_all_diversity():
    """Compute diversity for all candidates IN PARALLEL."""
    # CUDA best practice: Launch all independent work concurrently
    tasks = [_compute_diversity_for_problem(prob) for prob in candidates]
    print(f"  Launching {len(tasks)} parallel diversity evaluations...")

    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Assign results back to candidates
    for prob, result in zip(candidates, results):
        if isinstance(result, Exception):
            print(f"  Warning: diversity computation failed for {prob['numbers']}: {result}")
            prob["llm_solutions"] = []
            prob["diversity"] = 0.0
        else:
            prob["llm_solutions"] = result["llm_solutions"]
            prob["diversity"] = result["diversity"]


asyncio.run(compute_all_diversity())
print(f"Diversity computed. LLM calls: {call_counter['diversity']}")
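
`set(llm_solutions)` treats two strings as different even when they differ only in spacing or in the order of commutative operands, which can inflate the diversity score. One possible normalization, offered as a sketch and not as the notebook's method, is to parse each expression and sort the operands of `+` and `*`. The `canonical_form` helper is hypothetical.

```python
import ast

_SYMBOLS = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*", ast.Div: "/"}


def canonical_form(expr: str) -> str:
    """Return a normalized string: no whitespace, commutative operands sorted."""
    def _norm(node: ast.AST) -> str:
        if isinstance(node, ast.Expression):
            return _norm(node.body)
        if isinstance(node, ast.Constant):
            return str(node.value)
        if isinstance(node, ast.BinOp):
            symbol = _SYMBOLS.get(type(node.op))
            if symbol is None:
                raise ValueError("unsupported operator")
            left, right = _norm(node.left), _norm(node.right)
            if symbol in "+*":
                # a + b and b + a (likewise a * b and b * a) map to one string
                left, right = sorted((left, right))
            return f"({left}{symbol}{right})"
        raise ValueError(f"unsupported syntax: {type(node).__name__}")

    return _norm(ast.parse(expr.strip(), mode="eval"))


samples = ["(1 + 2) * 8", "8*(2+1)", "8 * (1 + 2)", "(8 - 2) * 4 * 1"]
unique = {canonical_form(s) for s in samples}
print(len(unique))  # 2
```

This only handles pairwise commutativity. It does not flatten associative chains, so `(1 + 2) + 3` and `1 + (2 + 3)` still count as distinct.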

In [None]:
# --- 4.3 Structural Simplicity ---

def compute_solution_depth(expr: str) -> int:
    """Estimate the depth of a solution expression tree."""
    # Count nested parentheses as proxy for depth
    max_depth = 0
    current_depth = 0
    for char in expr:
        if char == '(':
            current_depth += 1
            max_depth = max(max_depth, current_depth)
        elif char == ')':
            current_depth -= 1
    # Also count operators as contributing to depth
    n_ops = sum(1 for c in expr if c in '+-*/')
    return max_depth + n_ops // 2


def compute_simplicity_score(problem: Dict) -> float:
    """
    Simplicity = inverse of minimum solution tree depth.
    Simpler problems (fewer operations) are better archetypes.
    """
    # Use ground truth solutions if available
    solutions = solve_24_exhaustive(problem["numbers"])
    if not solutions:
        return 0.0

    min_depth = min(compute_solution_depth(s) for s in solutions[:10])  # Check up to 10
    # Map depth to (0, 1]: depth 0 -> 1.0, depth 3 -> 0.5, depth 10 -> ~0.23
    return 1.0 / (1.0 + min_depth / 3.0)


# Compute simplicity for candidates
print("Computing simplicity scores...")
for prob in candidates:
    prob["simplicity"] = compute_simplicity_score(prob)

print(f"Simplicity computed.")
print(f"Example: {candidates[0]['numbers']} simplicity = {candidates[0]['simplicity']:.3f}")
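
To make the scale of the simplicity score concrete, the mapping from depth to score can be tabulated. `simplicity_from_depth` is a hypothetical helper that restates the formula used in `compute_simplicity_score`:

```python
def simplicity_from_depth(depth: int) -> float:
    """Same mapping as compute_simplicity_score: 1 / (1 + depth / 3)."""
    return 1.0 / (1.0 + depth / 3.0)


for depth in (0, 3, 6, 10):
    print(depth, round(simplicity_from_depth(depth), 3))
```

For example, an expression such as `((1 * 2) * (3 * 4))` has parenthesis depth 2 and three operators, so `compute_solution_depth` returns 2 + 3 // 2 = 3 and the score is 0.5.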

In [None]:
# --- 4.4 Combined Ranking ---

def compute_archetype_score(problem: Dict) -> float:
    """
    Combined score = geometric mean of centrality, diversity, and simplicity.
    Geometric mean penalizes being weak in any dimension.
    """
    centrality = problem.get("centrality", 0.5)
    diversity = problem.get("diversity", 0.5)
    simplicity = problem.get("simplicity", 0.5)

    # Floor each factor at a small epsilon so a single zero does not erase the score
    eps = 0.01
    return (max(centrality, eps) * max(diversity, eps) * max(simplicity, eps)) ** (1/3)


# Compute combined scores
for prob in candidates:
    prob["archetype_score"] = compute_archetype_score(prob)

# Select top-10 archetypes (ensuring cluster diversity)
archetypes = []
used_clusters = set()

# First pass: one per cluster
sorted_candidates = sorted(candidates, key=lambda p: -p["archetype_score"])
for prob in sorted_candidates:
    if prob["cluster"] not in used_clusters and len(archetypes) < 10:
        archetypes.append(prob)
        used_clusters.add(prob["cluster"])

# Second pass: fill remaining slots with best overall.
# Compare by identity: the dicts hold numpy arrays, so == comparison is fragile.
for prob in sorted_candidates:
    if len(archetypes) < 10 and not any(prob is a for a in archetypes):
        archetypes.append(prob)

print(f"\nSelected {len(archetypes)} archetypes:")
print("-" * 80)
for i, arch in enumerate(archetypes):
    print(f"{i+1}. {arch['numbers']} | cluster={arch['cluster']} | "
          f"score={arch['archetype_score']:.3f} "
          f"(cent={arch['centrality']:.2f}, div={arch['diversity']:.2f}, simp={arch['simplicity']:.2f})")
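
The effect of the geometric mean is easiest to see on two candidates with the same arithmetic mean. The sketch below restates the scoring formula in a hypothetical `geometric_mean3` helper and uses made-up scores:

```python
def geometric_mean3(a: float, b: float, c: float, eps: float = 0.01) -> float:
    """Geometric mean of three scores, each floored at eps."""
    return (max(a, eps) * max(b, eps) * max(c, eps)) ** (1 / 3)


balanced = (0.6, 0.6, 0.6)
lopsided = (1.0, 0.8, 0.0)  # same arithmetic mean, but zero diversity

for name, scores in [("balanced", balanced), ("lopsided", lopsided)]:
    arithmetic = sum(scores) / 3
    print(f"{name}: arithmetic={arithmetic:.3f} geometric={geometric_mean3(*scores):.3f}")
```

Both candidates average 0.600 arithmetically, but the lopsided one drops to 0.200 under the geometric mean, so a problem with no solution diversity is unlikely to be picked as an archetype.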

## 5. Verification Suite

In [None]:
# --- 5.1 & 5.2 Multi-path Consistency & Execution Validation ---

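# Allowed characters for safe expression evaluation.
# The decimal point is deliberately excluded: a literal like "2.4" would be
# counted as the two inputs 2 and 4 by the number check below.
_SAFE_EXPR_PATTERN = re.compile(r'^[\d\s+\-*/()]+$')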


def safe_eval_expression(expr: str, numbers: Tuple[int, ...]) -> Tuple[bool, float, str]:
    """
    Safely evaluate a math expression and verify it:
    1. Uses all 4 numbers exactly once
    2. Evaluates to 24

    Security: Uses AST parsing to ensure only arithmetic operations are performed.
    No function calls, attribute access, or arbitrary code execution allowed.

    Returns: (is_valid, result, error_msg)
    """
    import ast
    import operator

    try:
        # Clean expression
        expr_clean = expr.replace('x', '*').replace('X', '*')
        expr_clean = re.sub(r'\s+', '', expr_clean)

        # First check: only allowed characters
        if not _SAFE_EXPR_PATTERN.match(expr_clean):
            return False, 0.0, "Invalid characters in expression"

        # Extract numbers from expression
        expr_numbers = [int(n) for n in re.findall(r'\d+', expr_clean)]

        # Check if all numbers are used exactly once
        if sorted(expr_numbers) != sorted(numbers):
            return False, 0.0, f"Numbers mismatch: expected {sorted(numbers)}, got {sorted(expr_numbers)}"

        # Safe AST-based evaluation
        # Only allow: numbers, binary ops (+, -, *, /), unary minus, parentheses
        ALLOWED_OPS = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.USub: operator.neg,
        }

        def _eval_node(node):
            if isinstance(node, ast.Expression):
                return _eval_node(node.body)
            elif isinstance(node, ast.Constant):
                if isinstance(node.value, (int, float)):
                    return float(node.value)
                raise ValueError(f"Unsupported constant type: {type(node.value)}")
            elif isinstance(node, ast.BinOp):
                op_type = type(node.op)
                if op_type not in ALLOWED_OPS:
                    raise ValueError(f"Unsupported operator: {op_type.__name__}")
                left = _eval_node(node.left)
                right = _eval_node(node.right)
                if op_type == ast.Div and right == 0:
                    raise ZeroDivisionError("Division by zero")
                return ALLOWED_OPS[op_type](left, right)
            elif isinstance(node, ast.UnaryOp):
                op_type = type(node.op)
                if op_type not in ALLOWED_OPS:
                    raise ValueError(f"Unsupported unary operator: {op_type.__name__}")
                return ALLOWED_OPS[op_type](_eval_node(node.operand))
            else:
                raise ValueError(f"Unsupported AST node: {type(node).__name__}")

        tree = ast.parse(expr_clean, mode='eval')
        result = _eval_node(tree)

        # Check if result is 24
        if abs(result - 24) < 1e-6:
            return True, result, ""
        else:
            return False, result, f"Result is {result}, not 24"

    except ZeroDivisionError:
        return False, 0.0, "Division by zero"
    except (SyntaxError, ValueError) as e:
        return False, 0.0, str(e)
    except Exception as e:
        return False, 0.0, f"Evaluation error: {e}"


async def verify_multipath_consistency(problem: Dict, n_paths: int = 5) -> Dict:
    """
    Generate N independent solutions and check consistency.
    Returns verification results.
    """
    nums = problem["numbers"]
    user_prompt = f"Numbers: {nums[0]}, {nums[1]}, {nums[2]}, {nums[3]}\n\nFind an expression that equals 24."

    # Generate N solutions with different temperatures
    # CUDA optimization: max_tokens=512 for short Game of 24 answers
    tasks = [
        llm_call_async(GAME24_SYSTEM, user_prompt, role="verify", temperature=0.8 + 0.1 * i, max_tokens=512)
        for i in range(n_paths)
    ]
    responses = await asyncio.gather(*tasks, return_exceptions=True)

    valid_count = 0
    expressions = []

    for resp in responses:
        # Handle exceptions from asyncio.gather
        if isinstance(resp, Exception):
            expressions.append({"expr": "", "valid": False, "result": 0, "error": str(resp)})
            continue
        if not resp:
            expressions.append({"expr": "", "valid": False, "result": 0, "error": "Empty response"})
            continue

        # Extract expression
        match = re.search(r"ANSWER:\s*(.+?)\s*=\s*24", resp, re.IGNORECASE)
        if match:
            expr = match.group(1).strip()
        else:
            match2 = re.search(r"([\d\s+\-*/()]+)\s*=\s*24", resp)
            expr = match2.group(1).strip() if match2 else ""

        if expr:
            is_valid, result, error = safe_eval_expression(expr, nums)
            expressions.append({"expr": expr, "valid": is_valid, "result": result, "error": error})
            if is_valid:
                valid_count += 1
        else:
            expressions.append({"expr": "", "valid": False, "result": 0, "error": "No expression found"})

    consistency = valid_count / n_paths
    execution_rate = sum(1 for e in expressions if e["valid"] or e["result"] != 0) / n_paths

    return {
        "consistency": consistency,
        "execution_rate": execution_rate,
        "valid_count": valid_count,
        "total_paths": n_paths,
        "expressions": expressions,
    }
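
Answer extraction relies on two regular expressions: a strict `ANSWER:` pattern and a looser fallback. The snippet below shows what each captures. The sample responses and the `extract_expression` wrapper are illustrative and not part of the notebook.

```python
import re
from typing import Optional

PRIMARY = re.compile(r"ANSWER:\s*(.+?)\s*=\s*24", re.IGNORECASE)
FALLBACK = re.compile(r"([\d\s+\-*/()]+)\s*=\s*24")


def extract_expression(response: str) -> Optional[str]:
    """Return the candidate expression, preferring the ANSWER: line."""
    match = PRIMARY.search(response) or FALLBACK.search(response)
    return match.group(1).strip() if match else None


print(extract_expression("Step 1: 8 - 2 = 6\nANSWER: (8 - 2) * 4 * 1 = 24"))
print(extract_expression("So we get (8 - 2) * 4 * 1 = 24."))
print(extract_expression("I could not find a solution."))
```

The first two calls return `(8 - 2) * 4 * 1` and the third returns `None`. Note that the fallback returns the first `... = 24` it finds, which may be an intermediate step in the model's reasoning; the subsequent number check in `safe_eval_expression` is what rejects such partial expressions.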

In [None]:
# --- 5.3 Perturbation Testing ---

def generate_perturbations(numbers: Tuple[int, ...], n_perturb: int = 3) -> List[Tuple[int, ...]]:
    """
    Generate perturbed versions of a problem by shifting one number by +/-1 or +/-2.
    Only return distinct perturbations that are solvable.
    """
    perturbations = []
    nums = list(numbers)
    original = tuple(sorted(numbers))

    # Shift one number at a time, keeping values in the 1-13 range
    for i in range(4):
        for delta in [-1, 1, -2, 2]:
            new_nums = nums.copy()
            new_val = nums[i] + delta
            if 1 <= new_val <= 13:
                new_nums[i] = new_val
                new_tuple = tuple(sorted(new_nums))
                # Skip the original problem and anything already collected
                if new_tuple == original or new_tuple in perturbations:
                    continue
                if is_solvable_24(new_tuple):
                    perturbations.append(new_tuple)
                    if len(perturbations) >= n_perturb:
                        return perturbations

    return perturbations


async def test_perturbation_robustness(problem: Dict, strategy: str, n_perturb: int = 3) -> Dict:
    """
    Test if a strategy generalizes to perturbed problems.
    """
    perturbations = generate_perturbations(problem["numbers"], n_perturb)

    if not perturbations:
        return {"robustness": 1.0, "tested": 0, "passed": 0}  # No perturbations possible

    system_with_strategy = f"""{GAME24_SYSTEM}

Use this strategy: {strategy}"""

    # Cap generation at max_tokens=512; Game of 24 answers are short
    tasks = []
    for perturbed_nums in perturbations:
        user_prompt = f"Numbers: {perturbed_nums[0]}, {perturbed_nums[1]}, {perturbed_nums[2]}, {perturbed_nums[3]}\n\nFind an expression that equals 24."
        tasks.append(llm_call_async(system_with_strategy, user_prompt, role="perturb", temperature=0.7, max_tokens=512))

    responses = await asyncio.gather(*tasks, return_exceptions=True)

    passed = 0
    for resp, perturbed_nums in zip(responses, perturbations):
        if isinstance(resp, Exception) or not resp:
            continue
        match = re.search(r"ANSWER:\s*(.+?)\s*=\s*24", resp, re.IGNORECASE)
        if match:
            expr = match.group(1).strip()
            is_valid, _, _ = safe_eval_expression(expr, perturbed_nums)
            if is_valid:
                passed += 1

    robustness = passed / len(perturbations) if perturbations else 1.0
    return {"robustness": robustness, "tested": len(perturbations), "passed": passed}
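
To see what the perturbation step produces, here is a self-contained sketch. It assumes a brute-force solvability check (`can_make_24`) as a stand-in for the notebook's `is_solvable_24`, whose implementation is not shown in this section; the function name `perturb` is likewise hypothetical.

```python
from fractions import Fraction
from itertools import combinations
from typing import List, Tuple


def can_make_24(values: List[Fraction]) -> bool:
    """Brute force: repeatedly combine any two values until one remains."""
    if len(values) == 1:
        return values[0] == 24
    for i, j in combinations(range(len(values)), 2):
        a, b = values[i], values[j]
        rest = [v for k, v in enumerate(values) if k not in (i, j)]
        candidates = [a + b, a - b, b - a, a * b]
        if b != 0:
            candidates.append(a / b)
        if a != 0:
            candidates.append(b / a)
        if any(can_make_24(rest + [c]) for c in candidates):
            return True
    return False


def perturb(numbers: Tuple[int, ...], n_perturb: int = 3) -> List[Tuple[int, ...]]:
    """Shift one number by +/-1 or +/-2, keeping only distinct solvable neighbours."""
    base = tuple(sorted(numbers))
    found: List[Tuple[int, ...]] = []
    for i in range(len(numbers)):
        for delta in (-1, 1, -2, 2):
            new_val = numbers[i] + delta
            if not 1 <= new_val <= 13:
                continue
            candidate = tuple(sorted(numbers[:i] + (new_val,) + numbers[i + 1:]))
            if candidate == base or candidate in found:
                continue
            if can_make_24([Fraction(n) for n in candidate]):
                found.append(candidate)
                if len(found) >= n_perturb:
                    return found
    return found


neighbours = perturb((4, 6, 1, 1))
print(neighbours)
```

Exact `Fraction` arithmetic avoids floating-point false negatives on solutions that pass through non-integer intermediates, such as `8 / (3 - 8 / 3)`.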

In [None]:
# --- 5.4 Adversarial Probing ---

async def test_adversarial(strategy: str, hard_problems: List[Dict]) -> Dict:
    """
    Test strategy against known-hard problems.
    """
    if not hard_problems:
        return {"adversarial_score": 0.0, "solved": 0, "total": 0, "results": []}

    system_with_strategy = f"""{GAME24_SYSTEM}

Use this strategy: {strategy}"""

    # Cap generation at max_tokens=512; Game of 24 answers are short
    tasks = []
    for hp in hard_problems:
        nums = hp["numbers"]
        user_prompt = f"Numbers: {nums[0]}, {nums[1]}, {nums[2]}, {nums[3]}\n\nFind an expression that equals 24."
        tasks.append(llm_call_async(system_with_strategy, user_prompt, role="adversarial", temperature=0.7, max_tokens=512))

    responses = await asyncio.gather(*tasks, return_exceptions=True)

    solved = 0
    results = []
    for resp, hp in zip(responses, hard_problems):
        if isinstance(resp, Exception) or not resp:
            results.append({"numbers": hp["numbers"], "valid": False, "expr": "", "error": str(resp) if isinstance(resp, Exception) else "Empty response"})
            continue

        match = re.search(r"ANSWER:\s*(.+?)\s*=\s*24", resp, re.IGNORECASE)
        if match:
            expr = match.group(1).strip()
            is_valid, result, error = safe_eval_expression(expr, hp["numbers"])
            results.append({"numbers": hp["numbers"], "valid": is_valid, "expr": expr, "error": error})
            if is_valid:
                solved += 1
        else:
            results.append({"numbers": hp["numbers"], "valid": False, "expr": "", "error": "No expression found"})

    adversarial_score = solved / len(hard_problems)
    return {"adversarial_score": adversarial_score, "solved": solved, "total": len(hard_problems), "results": results}

In [None]:
# --- 5.5 Confidence Scoring ---

async def compute_full_verification(problem: Dict, strategy: str = None) -> Dict:
    """
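    Run full verification suite and compute confidence score.

    Confidence = consistency * execution_rate * robustness * (1 + adversarial_bonus)

    The first three factors lie in [0, 1] and adversarial_bonus in [0, 0.2],
    so confidence ranges over [0, 1.2] and is not clamped to 1.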
    """
    # Multi-path consistency
    multipath = await verify_multipath_consistency(problem, n_paths=5)

    # Extract strategy from valid solutions
    if strategy is None:
        valid_exprs = [e["expr"] for e in multipath["expressions"] if e["valid"]]
        if valid_exprs:
            # Use the first valid expression as the strategy
            strategy = f"For numbers like {problem['numbers']}, try: {valid_exprs[0]}"
        else:
            # Fallback to ground truth
            solutions = solve_24_exhaustive(problem["numbers"])
            if solutions:
                strategy = f"For numbers like {problem['numbers']}, try: {solutions[0]}"
            else:
                strategy = "Look for factor pairs that multiply to 24"

    # Perturbation robustness
    perturb = await test_perturbation_robustness(problem, strategy, n_perturb=3)

    # Adversarial testing (only on a subset)
    adversarial = await test_adversarial(strategy, HARD_PROBLEMS[:5])

    # Compute confidence
    consistency = multipath["consistency"]
    execution_rate = multipath["execution_rate"]
    robustness = perturb["robustness"]
    adversarial_bonus = adversarial["adversarial_score"] * 0.2  # Up to 20% bonus

    confidence = consistency * execution_rate * robustness * (1 + adversarial_bonus)

    return {
        "confidence": confidence,
        "consistency": consistency,
        "execution_rate": execution_rate,
        "robustness": robustness,
        "adversarial_bonus": adversarial_bonus,
        "strategy": strategy,
        "multipath": multipath,
        "perturbation": perturb,
        "adversarial": adversarial,
    }


# Run verification on all archetypes IN PARALLEL
# The verifications are independent, so launching them concurrently keeps the vLLM server busy
print("Running verification suite on archetypes (parallel)...")
print("=" * 80)

async def verify_all_archetypes():
    """Verify all archetypes concurrently."""
    # Launch all verifications in parallel
    tasks = [compute_full_verification(arch) for arch in archetypes]
    print(f"Launching {len(tasks)} parallel verification tasks...")

    verifications = await asyncio.gather(*tasks, return_exceptions=True)

    # Assign results back to archetypes
    for i, (arch, verification) in enumerate(zip(archetypes, verifications)):
        if isinstance(verification, Exception):
            print(f"  Archetype {i+1} ({arch['numbers']}): FAILED - {verification}")
            arch["verification"] = {"confidence": 0.0, "error": str(verification)}
        else:
            arch["verification"] = verification
            print(f"  Archetype {i+1} ({arch['numbers']}): conf={verification['confidence']:.3f} "
                  f"(cons={verification['consistency']:.2f}, exec={verification['execution_rate']:.2f}, "
                  f"rob={verification['robustness']:.2f}, adv={verification['adversarial_bonus']:.2f})")

asyncio.run(verify_all_archetypes())

print("\n" + "=" * 80)
print(f"Verification complete. Total LLM calls: {sum(call_counter.values())}")
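
As a worked example of the confidence formula, the sketch below recomputes the score from its four inputs. The function name `combine_confidence` and the input values are illustrative assumptions; only the formula and the 0.2 bonus weight come from the cell above.

```python
def combine_confidence(
    consistency: float,
    execution_rate: float,
    robustness: float,
    adversarial_score: float,
    bonus_weight: float = 0.2,
) -> float:
    """Multiplicative confidence with an additive adversarial bonus of up to 20%."""
    adversarial_bonus = adversarial_score * bonus_weight
    return consistency * execution_rate * robustness * (1 + adversarial_bonus)


# 4 of 5 paths valid, every path executed, 2 of 3 perturbations passed,
# 2 of 5 hard problems solved (all made-up inputs).
example = combine_confidence(0.8, 1.0, 2 / 3, 0.4)

# A single zero factor drives the whole score to zero.
zeroed = combine_confidence(0.8, 1.0, 0.0, 1.0)

# With every input at its maximum the score is 1.2, not 1.0.
upper_bound = combine_confidence(1.0, 1.0, 1.0, 1.0)

print(f"example={example:.3f} zeroed={zeroed:.3f} upper_bound={upper_bound:.3f}")
```

Because the factors multiply, one weak signal (for example, no perturbation passing) vetoes the archetype regardless of the others, while the adversarial term can only raise the score.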

## 6. Playbook Bootstrap from Verified Archetypes

In [None]:
# --- Playbook Data Structure ---

@dataclass
class PlaybookBullet:
    """A single strategy in the playbook."""
    id: str
    content: str
    confidence: float
    source_problem: Tuple[int, ...]
    helpful_count: int = 0
    harmful_count: int = 0

    def to_str(self) -> str:
        return f"[{self.id}] confidence={self.confidence:.2f} helpful={self.helpful_count} harmful={self.harmful_count} :: {self.content}"


@dataclass
class Playbook:
    """Collection of strategies for solving Game of 24."""
    bullets: List[PlaybookBullet] = field(default_factory=list)
    _next_id: int = 1

    def add(self, content: str, confidence: float, source_problem: Tuple[int, ...]) -> str:
        """Add a new bullet and return its ID."""
        bid = f"arch-{self._next_id:05d}"
        self._next_id += 1
        self.bullets.append(PlaybookBullet(
            id=bid, content=content, confidence=confidence, source_problem=source_problem
        ))
        return bid

    def tag(self, bid: str, label: str):
        """Tag a bullet as helpful or harmful."""
        for b in self.bullets:
            if b.id == bid:
                if label == "helpful":
                    b.helpful_count += 1
                elif label == "harmful":
                    b.harmful_count += 1
                break

    def update_confidence(self, bid: str, delta: float):
        """Adjust confidence for a bullet."""
        for b in self.bullets:
            if b.id == bid:
                b.confidence = max(0.0, min(1.0, b.confidence + delta))
                break

    def to_str(self) -> str:
        """Format playbook for LLM prompt."""
        if not self.bullets:
            return "(empty playbook)"
        lines = ["## GAME OF 24 STRATEGIES"]
        for b in sorted(self.bullets, key=lambda x: -x.confidence):
            lines.append(b.to_str())
        return "\n".join(lines)

    def copy(self) -> "Playbook":
        """Create a deep copy of the playbook."""
        return copy.deepcopy(self)

    @property
    def size(self) -> int:
        """Number of bullets in the playbook."""
        return len(self.bullets)

In [None]:
# --- Extract Strategies from Archetypes ---

STRATEGY_EXTRACTION_SYSTEM = """You are analyzing Game of 24 solutions to extract reusable strategies.

Given a solved problem, extract 1-2 general strategies that could help solve similar problems.
Focus on:
- Patterns involving multiplication/division to reach 24
- How to use specific numbers (like 6, 8, 4, 3, 2, 1)
- When to use fractions vs whole number operations
- Factor pairs that make 24: (1,24), (2,12), (3,8), (4,6)

Format each strategy as a single line starting with "STRATEGY:"
Make strategies general enough to apply to other problems.

Example output:
STRATEGY: When you have 6 and 4, try to make the other numbers equal 1 (e.g., 3/3=1) so 6*4*1=24
STRATEGY: If you have 8, look for ways to make 3 from the other numbers since 8*3=24"""


async def extract_strategies_from_archetype(arch: Dict) -> List[str]:
    """Extract general strategies from a verified archetype."""
    nums = arch["numbers"]
    solutions = solve_24_exhaustive(nums)[:3]  # Use up to 3 solutions

    if not solutions:
        print(f"WARNING: No valid solution found for archetype {nums}, skipping")
        return []

    user_prompt = f"""Problem: Use {nums[0]}, {nums[1]}, {nums[2]}, {nums[3]} to make 24

Valid solutions:
"""
    for sol in solutions:
        user_prompt += f"- {sol}\n"

    user_prompt += "\nExtract 1-2 reusable strategies from these solutions."

    response = await llm_call_async(STRATEGY_EXTRACTION_SYSTEM, user_prompt, role="extract", temperature=0.5)

    if not response:
        return []

    strategies = []
    for line in response.split("\n"):
        line = line.strip()
        if line.startswith("STRATEGY:"):
            strategy = line[len("STRATEGY:"):].strip()
            if strategy:
                strategies.append(strategy)

    return strategies


# Bootstrap playbook from verified archetypes
print("Bootstrapping playbook from verified archetypes...")

# Check if any archetypes passed verification (confidence >= 0.3)
verified_archetypes = [a for a in archetypes if a.get("verification", {}).get("confidence", 0) >= 0.3]
USE_RANDOM_CURRICULUM_FALLBACK = False
if not verified_archetypes:
    print("WARNING: All archetypes failed verification. Falling back to random curriculum.")
    USE_RANDOM_CURRICULUM_FALLBACK = True

playbook = Playbook()

# Add base strategies
BASE_STRATEGIES = [
    "Look for factor pairs of 24: (1,24), (2,12), (3,8), (4,6)",
    "If you have 8, try to make 3 from remaining numbers (8*3=24)",
    "If you have 6, try to make 4 from remaining numbers (6*4=24)",
    "Use division to create fractions when direct operations don't work",
    "Try (a+b)*(c-d) or (a-b)*(c+d) patterns for 24",
]
for strat in BASE_STRATEGIES:
    playbook.add(strat, confidence=0.7, source_problem=(0, 0, 0, 0))


async def bootstrap_playbook():
    for arch in archetypes:
        # Default to 0.0 (same as the verified_archetypes filter) so unverified archetypes are skipped
        confidence = arch.get("verification", {}).get("confidence", 0.0)
        if confidence < 0.3:
            print(f"  {arch['numbers']}: skipped (confidence {confidence:.2f} < 0.3)")
            continue

        # REQ-VER-1: Enforce consistency threshold of 0.6
        consistency = arch.get("verification", {}).get("consistency", 0)
        if consistency < 0.6:
            print(f"WARNING: Archetype {arch['numbers']} failed consistency check ({consistency:.2f} < 0.6)")
            continue

        strategies = await extract_strategies_from_archetype(arch)
        for strat in strategies:
            playbook.add(strat, confidence=confidence, source_problem=arch["numbers"])

        print(f"  {arch['numbers']}: extracted {len(strategies)} strategies (conf={confidence:.2f})")


asyncio.run(bootstrap_playbook())

print(f"\nPlaybook initialized with {playbook.size} strategies:")
print("-" * 80)
print(playbook.to_str())

## 7. LinUCB Contextual Bandit Implementation

In [None]:
# --- LinUCB Contextual Bandit ---

class LinUCB:
    """
    Linear Upper Confidence Bound algorithm for contextual bandits.

    Each arm maintains:
    - A_k: d x d matrix (initialized to reg_lambda * I)
    - b_k: d-dimensional vector (initialized to zero)
    - theta_k = A_k^{-1} b_k (estimated coefficients)

    Selection: argmax_k (theta_k^T x_t + alpha * sqrt(x_t^T A_k^{-1} x_t))

    Regularization: Uses lambda * I for numerical stability in matrix inversion.
    """

    def __init__(self, n_arms: int, d: int, alpha: float = 1.0, reg_lambda: float = 1.0):
        """
        Args:
            n_arms: Number of arms.
            d: Feature dimension.
            alpha: Exploration parameter (higher = more exploration).
            reg_lambda: Regularization parameter for matrix inversion.
        """
        self.n_arms = n_arms
        self.d = d
        self.alpha = alpha
        self.reg_lambda = reg_lambda

        # Initialize A with regularization (lambda * I) and b to zero for each arm
        self.A = [reg_lambda * np.eye(d) for _ in range(n_arms)]
        self.b = [np.zeros(d) for _ in range(n_arms)]

        # Track statistics
        self.arm_counts = np.zeros(n_arms)
        self.arm_rewards = np.zeros(n_arms)

    def _safe_inv(self, matrix: np.ndarray) -> np.ndarray:
        """Compute matrix inverse with regularization fallback."""
        try:
            return np.linalg.inv(matrix)
        except np.linalg.LinAlgError:
            # Add extra regularization if singular
            return np.linalg.inv(matrix + self.reg_lambda * np.eye(self.d))

    def select(self, context: np.ndarray) -> int:
        """
        Select arm with highest UCB value.

        Args:
            context: d-dimensional feature vector

        Returns:
            Selected arm index
        """
        ucb_values = np.zeros(self.n_arms)

        for k in range(self.n_arms):
            A_inv = self._safe_inv(self.A[k])
            theta_k = A_inv @ self.b[k]

            # UCB = theta^T x + alpha * sqrt(x^T A^{-1} x)
            exploitation = theta_k @ context
            variance_term = context @ A_inv @ context
            # Clamp variance term to avoid numerical issues
            variance_term = max(0.0, variance_term)
            exploration = self.alpha * np.sqrt(variance_term)
            ucb_values[k] = exploitation + exploration

        return int(np.argmax(ucb_values))

    def update(self, arm: int, context: np.ndarray, reward: float):
        """
        Update arm statistics after observing reward.

        A_k += x_t x_t^T
        b_k += r_t x_t
        """
        self.A[arm] += np.outer(context, context)
        self.b[arm] += reward * context
        self.arm_counts[arm] += 1
        self.arm_rewards[arm] += reward

    def get_theta(self, arm: int) -> np.ndarray:
        """Get estimated coefficients for an arm."""
        A_inv = self._safe_inv(self.A[arm])
        return A_inv @ self.b[arm]

    def get_ucb_gap(self, context: np.ndarray) -> float:
        """Get the gap between best and second-best UCB values."""
        ucb_values = []
        for k in range(self.n_arms):
            A_inv = self._safe_inv(self.A[k])
            theta_k = A_inv @ self.b[k]
            variance_term = max(0.0, context @ A_inv @ context)
            ucb = theta_k @ context + self.alpha * np.sqrt(variance_term)
            ucb_values.append(ucb)

        sorted_ucb = sorted(ucb_values, reverse=True)
        if len(sorted_ucb) > 1:
            return sorted_ucb[0] - sorted_ucb[1]
        return 0.0


print("LinUCB implementation ready.")
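
The update and selection rules can be traced by hand on a tiny case. The sketch below is not the notebook's implementation: it is a dependency-free variant that stores `A^{-1}` directly and refreshes it with the Sherman-Morrison identity instead of calling `np.linalg.inv` on every selection. The class name `ArmState` and its helpers are hypothetical.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def mat_vec(m: Matrix, v: Vector) -> Vector:
    """Matrix-vector product for plain nested lists."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


class ArmState:
    """One LinUCB arm that keeps A^{-1} instead of A (illustrative only)."""

    def __init__(self, d: int, reg_lambda: float = 1.0):
        # A starts as reg_lambda * I, so A^{-1} starts as (1 / reg_lambda) * I.
        self.A_inv: Matrix = [
            [1.0 / reg_lambda if i == j else 0.0 for j in range(d)] for i in range(d)
        ]
        self.b: Vector = [0.0] * d

    def ucb(self, x: Vector, alpha: float) -> float:
        # UCB = theta^T x + alpha * sqrt(x^T A^{-1} x), with theta = A^{-1} b
        theta = mat_vec(self.A_inv, self.b)
        variance = max(0.0, dot(x, mat_vec(self.A_inv, x)))
        return dot(theta, x) + alpha * math.sqrt(variance)

    def update(self, x: Vector, reward: float) -> None:
        # Sherman-Morrison for the rank-one update A <- A + x x^T:
        # (A + x x^T)^{-1} = A^{-1} - (A^{-1} x)(A^{-1} x)^T / (1 + x^T A^{-1} x)
        ax = mat_vec(self.A_inv, x)
        denom = 1.0 + dot(x, ax)
        for i in range(len(x)):
            for j in range(len(x)):
                self.A_inv[i][j] -= ax[i] * ax[j] / denom
        self.b = [b_i + reward * x_i for b_i, x_i in zip(self.b, x)]


arm = ArmState(d=2)
x = [1.0, 0.0]
ucb_before = arm.ucb(x, alpha=1.0)  # no data yet: pure exploration term
arm.update(x, reward=1.0)
ucb_after = arm.ucb(x, alpha=1.0)   # exploitation rises, exploration shrinks
print(ucb_before, ucb_after)
```

In this trace the exploration term for `x` drops from 1 to `sqrt(0.5)` after one observation, while the estimated reward rises from 0 to 0.5. With `d=13` and three arms the explicit inverse used in the class is cheap, so this variant matters mainly if the feature dimension grows.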

In [None]:
# --- Problem Feature Extraction for LinUCB ---

def extract_linucb_features(problem: Dict) -> np.ndarray:
    """
    Extract feature vector for LinUCB context.
    Uses the same numerical features as for clustering.
    """
    nums = problem["numbers"]

    features = [
        sum(nums) / 52,  # sum normalized
        np.prod(nums) / (13**4),  # product normalized
        max(nums) / 13,  # max normalized
        min(nums) / 13,  # min normalized
        len(set(nums)) / 4,  # uniqueness ratio
        (max(nums) - min(nums)) / 12,  # range normalized
        np.std(nums) / 5,  # std scaled by 5 (can exceed 1: population std reaches 6 for 1, 1, 13, 13)
        sum(1 for n in nums if n % 2 == 0) / 4,  # even ratio
        sum(1 for n in nums if 24 % n == 0) / 4,  # divisor of 24 ratio
        1.0 if any(a * b == 24 for a, b in itertools.combinations(nums, 2)) else 0.0,  # has factor pair
        # Additional features for curriculum learning
        problem.get("n_solutions", 5) / 20,  # solution count normalized
        problem.get("centrality", 0.5),  # cluster centrality
        1.0,  # bias term
    ]

    return np.array(features)


# Verify feature extraction
test_features = extract_linucb_features(problems[0])
print(f"Feature dimension: {len(test_features)}")
print(f"Example features: {test_features}")

## 8. Curriculum Loop (Archetype-first vs Random)

In [None]:
# --- Core Evaluation Functions ---

SOLVE_SYSTEM_TEMPLATE = """You are a Game of 24 solver. Given 4 numbers, find an expression using +, -, *, / that equals 24.
Each number must be used exactly once.

{playbook}

When you use a strategy from the playbook, mention its ID (e.g., [arch-00001]).
Show your reasoning, then give the final answer as: ANSWER: <expression> = 24"""


async def solve_with_playbook(problem: Dict, playbook: Playbook) -> Dict:
    """
    Solve a Game of 24 problem using the playbook.

    Returns dict with: correct, expression, result, error, bullets_used, response
    """
    nums = problem["numbers"]

    system = SOLVE_SYSTEM_TEMPLATE.format(playbook=playbook.to_str())
    user_prompt = f"Numbers: {nums[0]}, {nums[1]}, {nums[2]}, {nums[3]}\n\nFind an expression that equals 24."
    response = await llm_call_async(system, user_prompt, role="solve", temperature=0.7, max_tokens=512)

    # Handle empty response
    if not response:
        return {
            "correct": False,
            "expression": "",
            "result": 0.0,
            "error": "Empty LLM response",
            "bullets_used": [],
            "response": "",
        }

    # Extract answer
    match = re.search(r"ANSWER:\s*(.+?)\s*=\s*24", response, re.IGNORECASE)
    if match:
        expr = match.group(1).strip()
    else:
        match2 = re.search(r"([\d\s+\-*/()]+)\s*=\s*24", response)
        expr = match2.group(1).strip() if match2 else ""

    # Validate
    if expr:
        is_correct, result, error = safe_eval_expression(expr, nums)
    else:
        is_correct, result, error = False, 0.0, "No expression found in response"

    # Extract used bullet IDs
    bullets_used = re.findall(r"\[(arch-\d+)\]", response)

    return {
        "correct": is_correct,
        "expression": expr,
        "result": result,
        "error": error,
        "bullets_used": bullets_used,
        "response": response,
    }


def update_playbook_from_result(playbook: Playbook, result: Dict):
    """Update playbook bullet statistics based on result."""
    label = "helpful" if result["correct"] else "harmful"
    for bid in result["bullets_used"]:
        playbook.tag(bid, label)

In [None]:
# --- Run Log ---

@dataclass
class CurriculumRunLog:
    """Tracks results for a curriculum condition."""
    correct: List[bool] = field(default_factory=list)
    playbook_sizes: List[int] = field(default_factory=list)
    confidence_history: List[List[float]] = field(default_factory=list)
    ucb_gaps: List[float] = field(default_factory=list)
    problems_order: List[Tuple[int, ...]] = field(default_factory=list)
    call_counts: Dict[str, int] = field(default_factory=dict)
    final_playbook: Optional[Playbook] = None
    # Strategy churn tracking (REQ-EVAL-2)
    strategies_added: List[int] = field(default_factory=list)
    strategies_removed: List[int] = field(default_factory=list)

    @property
    def cumulative_accuracy(self) -> List[float]:
        """Compute cumulative accuracy at each step."""
        if not self.correct:
            return []
        acc = []
        total = 0
        for i, c in enumerate(self.correct):
            total += int(c)
            acc.append(total / (i + 1))
        return acc

    @property
    def final_accuracy(self) -> float:
        """Final accuracy over all problems."""
        if not self.correct:
            return 0.0
        return sum(self.correct) / len(self.correct)
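
The running-accuracy curve that `cumulative_accuracy` computes can also be written with `itertools.accumulate`. This is an equivalent sketch on a made-up outcome list, shown as a standalone function rather than a change to the dataclass.

```python
from itertools import accumulate
from typing import List


def cumulative_accuracy(correct: List[bool]) -> List[float]:
    """Accuracy after each step: running count of hits divided by steps so far."""
    running_hits = accumulate(int(c) for c in correct)
    return [hits / n for n, hits in enumerate(running_hits, start=1)]


# Made-up outcomes: hit, miss, hit, hit
curve = cumulative_accuracy([True, False, True, True])
print(curve)
```

The last entry of the curve equals `final_accuracy`, so the two properties stay consistent by construction.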

In [None]:
# --- Archetype-First Curriculum ---

CONFIDENCE_UPDATE_INTERVAL = 10
N_EVAL_PROBLEMS = 100
BATCH_SIZE = 128  # Mini-batch size for parallel LLM calls; matches MAX_CONCURRENT_LLM for full GPU utilization


def _chunked(iterable, size):
    """Yield successive chunks of specified size."""
    for i in range(0, len(iterable), size):
        yield iterable[i:i + size]


def _update_bullet_confidences(playbook: Playbook, decay: float = 0.7) -> int:
    """
    Update bullet confidences based on helpful/harmful counts.
    Returns number of strategies removed (pruned).
    """
    removed_count = 0
    bullets_to_keep = []
    for b in playbook.bullets:
        total = b.helpful_count + b.harmful_count
        if total > 0:
            ratio = b.helpful_count / total
            b.confidence = decay * b.confidence + (1 - decay) * ratio
            # Prune strategies with very low confidence after sufficient feedback
            if b.confidence < 0.1 and total >= 5:
                removed_count += 1
                continue
        bullets_to_keep.append(b)
    playbook.bullets = bullets_to_keep
    return removed_count


async def run_archetype_curriculum(
    all_problems: List[Dict],
    archetypes: List[Dict],
    initial_playbook: Playbook,
) -> CurriculumRunLog:
    """
    Run archetype-first curriculum with mini-batching for GPU efficiency.

    Mini-batching optimization:
    - Process BATCH_SIZE problems in parallel via asyncio.gather
    - Update playbook/bandit after each batch (not each problem)
    - Enables prefix caching (constant system prompt within batch)
    - LinUCB feedback is delayed to the end of each batch, so all selections within a batch use the same statistics

    Phases:
    1. Phase 1: Process archetypes first (warm start)
    2. Phase 2: LinUCB exploration-exploitation on remaining problems
    """
    log = CurriculumRunLog()
    pb = initial_playbook.copy()
    initial_strategy_count = pb.size

    # Track strategies added during bootstrap (already done before this function)
    log.strategies_added.append(initial_strategy_count)

    # Separate archetypes from other problems
    archetype_nums = {a["numbers"] for a in archetypes}
    non_archetypes = [p for p in all_problems if p["numbers"] not in archetype_nums]

    # Limit to N_EVAL_PROBLEMS total
    n_archetypes = min(10, N_EVAL_PROBLEMS // 10)
    eval_archetypes = archetypes[:n_archetypes]
    remaining_budget = N_EVAL_PROBLEMS - len(eval_archetypes)
    eval_non_archetypes = non_archetypes[:remaining_budget]

    # Initialize LinUCB
    N_ARMS = 3
    FEATURE_DIM = 13
    linucb = LinUCB(n_arms=N_ARMS, d=FEATURE_DIM, alpha=1.5, reg_lambda=1.0)

    print(f"\nPhase 1: Processing {len(eval_archetypes)} archetypes (batch_size={BATCH_SIZE})...")

    # Phase 1: Archetypes (batched)
    for batch in _chunked(eval_archetypes, BATCH_SIZE):
        # Solve batch in parallel (same playbook for all - enables prefix caching)
        tasks = [solve_with_playbook(arch, pb) for arch in batch]
        results = await asyncio.gather(*tasks)

        # Update logs and playbook after batch
        for arch, result in zip(batch, results):
            log.correct.append(result["correct"])
            log.playbook_sizes.append(pb.size)
            log.problems_order.append(arch["numbers"])
            log.confidence_history.append([b.confidence for b in pb.bullets])

            # Update playbook bullet counts
            update_playbook_from_result(pb, result)

            # Update LinUCB with archetype results (arm 2 = balanced)
            context = extract_linucb_features(arch)
            linucb.update(2, context, 1.0 if result["correct"] else 0.0)

        acc = sum(log.correct) / len(log.correct)
        print(f"  Archetypes {len(log.correct)}/{len(eval_archetypes)}: accuracy={acc:.2%}")

    print(f"\nPhase 2: LinUCB on {len(eval_non_archetypes)} problems (batch_size={BATCH_SIZE})...")

    # Phase 2: LinUCB exploration (batched)
    problems_done = 0
    for batch in _chunked(eval_non_archetypes, BATCH_SIZE):
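        # Pre-compute arm selections for batch (before solving).
        # NOTE: in this cell the selected arm is only fed back to linucb.update();
        # it does not change the problem order or the prompt.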
        batch_contexts = [extract_linucb_features(prob) for prob in batch]
        batch_arms = [linucb.select(ctx) for ctx in batch_contexts]

        # Record UCB gaps
        for ctx in batch_contexts:
            log.ucb_gaps.append(linucb.get_ucb_gap(ctx))

        # Solve batch in parallel (same playbook - enables prefix caching)
        tasks = [solve_with_playbook(prob, pb) for prob in batch]
        results = await asyncio.gather(*tasks)

        # Update logs, playbook, and bandit after batch
        for prob, result, ctx, arm in zip(batch, results, batch_contexts, batch_arms):
            log.correct.append(result["correct"])
            log.playbook_sizes.append(pb.size)
            log.problems_order.append(prob["numbers"])
            log.confidence_history.append([b.confidence for b in pb.bullets])

            # Update playbook and bandit
            update_playbook_from_result(pb, result)
            reward = 1.0 if result["correct"] else 0.0
            linucb.update(arm, ctx, reward)

        problems_done += len(batch)

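        # Retrospective confidence update and pruning whenever the batch crosses a
        # multiple of CONFIDENCE_UPDATE_INTERVAL (a bare modulo test never fires when
        # the running count skips over the multiples, e.g. with uneven batch sizes)
        if problems_done // CONFIDENCE_UPDATE_INTERVAL > (problems_done - len(batch)) // CONFIDENCE_UPDATE_INTERVAL: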
            removed = _update_bullet_confidences(pb)
            if removed > 0:
                log.strategies_removed.append(removed)
                print(f"  Pruned {removed} low-confidence strategies")

        total_done = len(eval_archetypes) + problems_done
        if total_done % 20 == 0:
            acc = sum(log.correct) / len(log.correct)
            print(f"  Problems {total_done}/{N_EVAL_PROBLEMS}: accuracy={acc:.2%}")

    log.final_playbook = pb
    log.call_counts = dict(call_counter)

    print(f"\nArchetype curriculum complete. Final accuracy: {log.final_accuracy:.2%}")
    print(f"Strategy churn: +{sum(log.strategies_added)} added, -{sum(log.strategies_removed)} removed")
    return log
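
The confidence decay in `_update_bullet_confidences` is an exponential moving average toward the helpful ratio. The sketch below isolates that rule to show how many consecutive all-harmful updates it takes before a bullet falls under the 0.1 pruning threshold. The function names are assumptions; the decay of 0.7, the threshold of 0.1, and the starting confidence of 0.7 (used for base strategies) are taken from the notebook.

```python
def ema_confidence(confidence: float, helpful: int, harmful: int, decay: float = 0.7) -> float:
    """Blend the old confidence with the observed helpful ratio."""
    total = helpful + harmful
    if total == 0:
        return confidence  # no feedback yet: leave the confidence unchanged
    return decay * confidence + (1 - decay) * (helpful / total)


def updates_until_pruned(start: float, decay: float = 0.7, threshold: float = 0.1) -> int:
    """Count updates needed to fall below the threshold when every use was harmful."""
    confidence, updates = start, 0
    while confidence >= threshold:
        confidence = ema_confidence(confidence, helpful=0, harmful=5, decay=decay)
        updates += 1
    return updates


one_step = ema_confidence(0.7, helpful=1, harmful=4)  # 0.7 * 0.7 + 0.3 * 0.2
n_updates = updates_until_pruned(0.7)
print(f"one_step={one_step:.2f}, updates until pruned={n_updates}")
```

With these settings a base strategy survives several update rounds of purely harmful feedback, so pruning is rare in a run where confidence updates fire only a handful of times.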

In [None]:
# --- Random Curriculum (Baseline) ---

async def run_random_curriculum(
    all_problems: List[Dict],
    initial_playbook: Playbook,
) -> CurriculumRunLog:
    """
    Run random curriculum baseline with mini-batching for GPU efficiency.

    Mini-batching optimization:
    - Process BATCH_SIZE problems in parallel via asyncio.gather
    - Update playbook after each batch (not each problem)
    - Enables prefix caching (constant system prompt within batch)
    """
    log = CurriculumRunLog()
    pb = initial_playbook.copy()
    initial_strategy_count = pb.size

    # Track initial strategies
    log.strategies_added.append(initial_strategy_count)

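    # Shuffle problems.
    # NOTE: this takes the first N_EVAL_PROBLEMS entries of all_problems, which is not
    # necessarily the same problem set that run_archetype_curriculum evaluates.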
    eval_problems = all_problems[:N_EVAL_PROBLEMS]
    rng = random.Random(SEED + 1)  # Different seed for random order
    shuffled = eval_problems.copy()
    rng.shuffle(shuffled)

    print(f"\nProcessing {len(shuffled)} problems in random order (batch_size={BATCH_SIZE})...")

    problems_done = 0
    for batch in _chunked(shuffled, BATCH_SIZE):
        # Solve batch in parallel (same playbook - enables prefix caching)
        tasks = [solve_with_playbook(prob, pb) for prob in batch]
        results = await asyncio.gather(*tasks)

        # Update logs and playbook after batch
        for prob, result in zip(batch, results):
            log.correct.append(result["correct"])
            log.playbook_sizes.append(pb.size)
            log.problems_order.append(prob["numbers"])
            log.confidence_history.append([b.confidence for b in pb.bullets])

            # Update playbook
            update_playbook_from_result(pb, result)

        problems_done += len(batch)

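        # Retrospective confidence update and pruning whenever the batch crosses a
        # multiple of CONFIDENCE_UPDATE_INTERVAL (a bare modulo test never fires when
        # the running count skips over the multiples, e.g. with uneven batch sizes)
        if problems_done // CONFIDENCE_UPDATE_INTERVAL > (problems_done - len(batch)) // CONFIDENCE_UPDATE_INTERVAL: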
            removed = _update_bullet_confidences(pb)
            if removed > 0:
                log.strategies_removed.append(removed)
                print(f"  Pruned {removed} low-confidence strategies")

        if problems_done % 20 == 0:
            acc = sum(log.correct) / len(log.correct)
            print(f"  Problems {problems_done}/{len(shuffled)}: accuracy={acc:.2%}")

    log.final_playbook = pb
    log.call_counts = dict(call_counter)

    print(f"\nRandom curriculum complete. Final accuracy: {log.final_accuracy:.2%}")
    print(f"Strategy churn: +{sum(log.strategies_added)} added, -{sum(log.strategies_removed)} removed")
    return log
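
When a counter advances by a whole batch at a time, a periodic trigger written as `problems_done % interval == 0` fires only if the running total happens to land exactly on a multiple of the interval. The sketch below compares that test with an interval-crossing test; the function names and the example sizes are illustrative assumptions.

```python
def crossed_interval(done_before: int, done_after: int, interval: int) -> bool:
    """True if at least one multiple of `interval` lies in (done_before, done_after]."""
    return done_after // interval > done_before // interval


def count_updates(n_problems: int, batch_size: int, interval: int, use_modulo: bool) -> int:
    """Simulate the batched loop and count how often the periodic update fires."""
    done, updates = 0, 0
    while done < n_problems:
        step = min(batch_size, n_problems - done)
        before, done = done, done + step
        if use_modulo:
            fired = done % interval == 0
        else:
            fired = crossed_interval(before, done, interval)
        updates += int(fired)
    return updates


# One batch of 98 problems: the modulo test never fires, the crossing test fires once.
print(count_updates(98, 128, 10, use_modulo=True), count_updates(98, 128, 10, use_modulo=False))

# Batches of 16 over 100 problems: modulo fires only at 80 and 100.
print(count_updates(100, 16, 10, use_modulo=True), count_updates(100, 16, 10, use_modulo=False))
```

The crossing test fires at most once per batch, so a batch larger than the interval still triggers a single update rather than one per multiple skipped.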

In [None]:
# --- Run Both Conditions ---

print("="*80)
print("RUNNING ARCHETYPE-FIRST CURRICULUM")
print("="*80)

# Reset call counter
call_counter.clear()

# Check for fallback mode (set in the playbook bootstrap cell in Section 6 if all archetypes failed verification)
if USE_RANDOM_CURRICULUM_FALLBACK:
    print("FALLBACK MODE: All archetypes failed verification, using random curriculum for both conditions")
    archetype_log = asyncio.run(run_random_curriculum(problems, playbook))
else:
    archetype_log = asyncio.run(run_archetype_curriculum(problems, archetypes, playbook))

print("\n" + "="*80)
print("RUNNING RANDOM CURRICULUM (BASELINE)")
print("="*80)

# Reset call counter
call_counter.clear()

random_log = asyncio.run(run_random_curriculum(problems, playbook))

print("\n" + "="*80)
print("COMPARISON SUMMARY")
print("="*80)
print(f"Archetype-first accuracy: {archetype_log.final_accuracy:.2%}")
print(f"Random baseline accuracy: {random_log.final_accuracy:.2%}")
print(f"Improvement: {(archetype_log.final_accuracy - random_log.final_accuracy)*100:.1f} percentage points")
print(f"\nStrategy churn (archetype): +{sum(archetype_log.strategies_added)} added, -{sum(archetype_log.strategies_removed)} removed")
print(f"Strategy churn (random): +{sum(random_log.strategies_added)} added, -{sum(random_log.strategies_removed)} removed")

## 8.5 Coverage Efficiency Ablation

Measure how many verified archetypes are needed for effective learning.
This answers the proposal question: "how many verified archetypes are needed?"
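
One way to turn the ablation into a single number is to report the smallest archetype count whose accuracy is within a tolerance of the best observed accuracy. The sketch below assumes that criterion; the helper name, the 0.02 tolerance, and the accuracy values are placeholders chosen for illustration, not results from this notebook.

```python
from typing import Dict


def smallest_sufficient_count(accuracy_by_count: Dict[int, float], tolerance: float = 0.02) -> int:
    """Smallest archetype count whose accuracy is within `tolerance` of the best."""
    best = max(accuracy_by_count.values())
    return min(n for n, acc in accuracy_by_count.items() if acc >= best - tolerance)


# PLACEHOLDER accuracies keyed by archetype count (made up, not measured).
placeholder_results = {0: 0.40, 2: 0.46, 4: 0.56, 6: 0.56, 8: 0.57, 10: 0.57}
sufficient = smallest_sufficient_count(placeholder_results)
print(f"Smallest sufficient archetype count: {sufficient}")
```

With a single run per count, differences of a few percentage points may be noise from sampling at `temperature=0.7`, so the tolerance should be chosen with that variance in mind.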

In [None]:
# --- Coverage Efficiency: Ablation on Number of Archetypes ---

ARCHETYPE_COUNTS = [0, 2, 4, 6, 8, 10]  # Vary number of archetypes used

async def run_archetype_ablation(
    all_problems: List[Dict],
    archetypes: List[Dict],
    initial_playbook: Playbook,
    n_archetypes: int,
) -> CurriculumRunLog:
    """
    Run curriculum with exactly n_archetypes in Phase 1.
    n_archetypes=0 skips the archetype warm start; unlike run_random_curriculum, problems are not shuffled and LinUCB still runs.
    """
    log = CurriculumRunLog()
    pb = initial_playbook.copy()
    log.strategies_added.append(pb.size)

    # Select subset of archetypes
    eval_archetypes = archetypes[:n_archetypes]

    # Remaining problems (exclude used archetypes)
    archetype_nums = {a["numbers"] for a in eval_archetypes}
    non_archetypes = [p for p in all_problems if p["numbers"] not in archetype_nums]
    remaining_budget = N_EVAL_PROBLEMS - len(eval_archetypes)
    eval_non_archetypes = non_archetypes[:remaining_budget]

    # Initialize LinUCB
    N_ARMS = 3
    FEATURE_DIM = 13
    linucb = LinUCB(n_arms=N_ARMS, d=FEATURE_DIM, alpha=1.5, reg_lambda=1.0)

    # Phase 1: Archetypes (if any)
    if eval_archetypes:
        for batch in _chunked(eval_archetypes, BATCH_SIZE):
            tasks = [solve_with_playbook(arch, pb) for arch in batch]
            results = await asyncio.gather(*tasks)
            for arch, result in zip(batch, results):
                log.correct.append(result["correct"])
                log.playbook_sizes.append(pb.size)
                log.problems_order.append(arch["numbers"])
                update_playbook_from_result(pb, result)
                context = extract_linucb_features(arch)
                linucb.update(2, context, 1.0 if result["correct"] else 0.0)

    # Phase 2: Remaining problems with LinUCB
    problems_done = 0
    for batch in _chunked(eval_non_archetypes, BATCH_SIZE):
        batch_contexts = [extract_linucb_features(prob) for prob in batch]
        batch_arms = [linucb.select(ctx) for ctx in batch_contexts]

        tasks = [solve_with_playbook(prob, pb) for prob in batch]
        results = await asyncio.gather(*tasks)

        for prob, result, ctx, arm in zip(batch, results, batch_contexts, batch_arms):
            log.correct.append(result["correct"])
            log.playbook_sizes.append(pb.size)
            log.problems_order.append(prob["numbers"])
            update_playbook_from_result(pb, result)
            linucb.update(arm, ctx, 1.0 if result["correct"] else 0.0)

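        # Refresh when this batch crosses a multiple of the interval; a plain
        # `%` check can skip updates if the batch size does not divide it.
        prev_done = problems_done
        problems_done += len(batch)
        if problems_done // CONFIDENCE_UPDATE_INTERVAL > prev_done // CONFIDENCE_UPDATE_INTERVAL:
            _update_bullet_confidences(pb)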

    log.final_playbook = pb
    return log


print("=" * 80)
print("COVERAGE EFFICIENCY ABLATION")
print("=" * 80)
print(f"Testing archetype counts: {ARCHETYPE_COUNTS}")

# Reset call counter
call_counter.clear()

ablation_results = {}

async def run_ablation():
    """Run all ablation conditions IN PARALLEL for GPU efficiency."""
    # Launch all ablation runs concurrently - GPU semaphore manages load
    tasks = [
        run_archetype_ablation(problems, archetypes, playbook, n_arch)
        for n_arch in ARCHETYPE_COUNTS
    ]
    print(f"  Launching {len(tasks)} ablation conditions in parallel...")
    results = await asyncio.gather(*tasks)

    # Collect results
    for n_arch, log in zip(ARCHETYPE_COUNTS, results):
        ablation_results[n_arch] = {
            "final_accuracy": log.final_accuracy,
            "cumulative_accuracy": log.cumulative_accuracy,
            "correct": log.correct,
        }
        print(f"  {n_arch} archetypes: {log.final_accuracy:.2%}")

asyncio.run(run_ablation())

# Summary
print("\n" + "=" * 80)
print("COVERAGE EFFICIENCY SUMMARY")
print("=" * 80)
print(f"{'# Archetypes':<15} {'Final Accuracy':<15} {'Δ vs 0 (pp)':<15}")
print("-" * 45)
baseline_acc = ablation_results[0]["final_accuracy"]
for n_arch in ARCHETYPE_COUNTS:
    acc = ablation_results[n_arch]["final_accuracy"]
    delta = acc - baseline_acc
    # Report the gap in percentage points, not as a relative percentage
    print(f"{n_arch:<15} {acc:<15.2%} {delta * 100:+.1f}")

# Archetype count with the highest final accuracy (ties go to the smallest count)
best_n = max(ARCHETYPE_COUNTS, key=lambda n: ablation_results[n]["final_accuracy"])
print(f"\nBest-performing archetype count: {best_n}")

## 9. Analysis & Plotting

In [None]:
# --- Plots: Learning Curves, Confidence, and Ablation ---

def calculate_mean_confidence_over_time(log: CurriculumRunLog) -> List[float]:
    """Mean bullet confidence per snapshot; 0.5 stands in for empty snapshots."""
    mean_confidences = []
    for conf_snapshot in log.confidence_history:
        if conf_snapshot:
            mean_confidences.append(np.mean(conf_snapshot))
        else:
            mean_confidences.append(0.5)
    # No history recorded at all: fall back to a flat 0.5 line
    return mean_confidences if mean_confidences else [0.5] * len(log.correct)


fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Plot 1: Cumulative accuracy over time
ax1 = axes[0, 0]
ax1.plot(archetype_log.cumulative_accuracy, label='Archetype-first', color='blue', linewidth=2)
ax1.plot(random_log.cumulative_accuracy, label='Random', color='red', linewidth=2, linestyle='--')
ax1.axvline(x=len(archetypes), color='green', linestyle=':', label='End of archetype phase')
ax1.set_xlabel('Problems Solved')
ax1.set_ylabel('Cumulative Accuracy')
ax1.set_title('Learning Curves: Archetype-first vs Random')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_ylim(0, 1)

# Plot 2: Strategy Confidence Over Time
ax2 = axes[0, 1]
arch_precision = calculate_mean_confidence_over_time(archetype_log)
rand_precision = calculate_mean_confidence_over_time(random_log)
ax2.plot(arch_precision, label='Archetype-first', color='blue', linewidth=2)
ax2.plot(rand_precision, label='Random', color='red', linewidth=2, linestyle='--')
ax2.set_xlabel('Problems Solved')
ax2.set_ylabel('Mean Strategy Confidence')
ax2.set_title('Strategy Confidence Over Time')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Plot 3: Coverage Efficiency (Ablation)
ax3 = axes[0, 2]
if ablation_results:
    n_archs = list(ablation_results.keys())
    accuracies = [ablation_results[n]["final_accuracy"] for n in n_archs]
    ax3.plot(n_archs, accuracies, 'bo-', linewidth=2, markersize=8)
    ax3.axhline(y=ablation_results[0]["final_accuracy"], color='red', linestyle='--',
                label=f'0-archetype baseline: {ablation_results[0]["final_accuracy"]:.1%}')
    ax3.set_xlabel('Number of Archetypes')
    ax3.set_ylabel('Final Accuracy')
    ax3.set_title('Coverage Efficiency: How Many Archetypes Needed?')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    ax3.set_xticks(n_archs)

# Plot 4: Final Confidence Distribution
ax4 = axes[1, 0]
if archetype_log.final_playbook and random_log.final_playbook:
    arch_confs = [b.confidence for b in archetype_log.final_playbook.bullets]
    rand_confs = [b.confidence for b in random_log.final_playbook.bullets]
    ax4.hist(arch_confs, bins=10, alpha=0.6, label='Archetype-first', color='blue')
    ax4.hist(rand_confs, bins=10, alpha=0.6, label='Random', color='red')
    ax4.set_xlabel('Strategy Confidence')
    ax4.set_ylabel('Count')
    ax4.set_title('Final Strategy Confidence Distribution')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

# Plot 5: UCB Gap (Exploration vs Exploitation)
ax5 = axes[1, 1]
if archetype_log.ucb_gaps:
    ax5.plot(archetype_log.ucb_gaps, color='blue', linewidth=1, alpha=0.7)
    mean_gap = np.mean(archetype_log.ucb_gaps)
    ax5.axhline(y=mean_gap, color='blue', linestyle='--', label=f'Mean: {mean_gap:.3f}')
    ax5.set_xlabel('Problem Index (Phase 2)')
    ax5.set_ylabel('UCB Gap')
    ax5.set_title('LinUCB Exploration-Exploitation Gap')
    ax5.legend()
    ax5.grid(True, alpha=0.3)

# Plot 6: Ablation Learning Curves
ax6 = axes[1, 2]
if ablation_results:
    colors = plt.cm.viridis(np.linspace(0, 1, len(ARCHETYPE_COUNTS)))
    for n_arch, color in zip(ARCHETYPE_COUNTS, colors):
        cum_acc = ablation_results[n_arch]["cumulative_accuracy"]
        ax6.plot(cum_acc, label=f'{n_arch} archetypes', color=color, linewidth=1.5, alpha=0.8)
    ax6.set_xlabel('Problems Solved')
    ax6.set_ylabel('Cumulative Accuracy')
    ax6.set_title('Learning Curves by Archetype Count')
    ax6.legend(fontsize=8)
    ax6.grid(True, alpha=0.3)
    ax6.set_ylim(0, 1)

plt.tight_layout()
plt.savefig('archetype_discovery_results.png', dpi=150)
plt.show()

print("Plots saved to archetype_discovery_results.png")

## 10. Results Summary & Statistical Tests
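
This section reports Cohen's h, which compares two proportions after an arcsine transform: `h = 2*arcsin(sqrt(p1)) - 2*arcsin(sqrt(p2))`. Cohen's conventional reading is 0.2 for a small effect, 0.5 for medium, and 0.8 for large. The sketch below works one example with made-up accuracies; `interpret_h` and the "negligible" label for |h| below 0.2 are illustrative assumptions, not names used elsewhere in this notebook.

```python
import math


def cohens_h(p1: float, p2: float) -> float:
    """Cohen's h effect size for two proportions."""
    return 2 * (math.asin(math.sqrt(p1)) - math.asin(math.sqrt(p2)))


def interpret_h(h: float) -> str:
    """Map |h| onto Cohen's conventional bands (0.2 / 0.5 / 0.8)."""
    magnitude = abs(h)
    if magnitude < 0.2:
        return "negligible"
    if magnitude < 0.5:
        return "small"
    if magnitude < 0.8:
        return "medium"
    return "large"


# Made-up accuracies for illustration only (not results from this notebook)
h = cohens_h(0.65, 0.50)
print(f"h = {h:.3f} -> {interpret_h(h)}")  # h = 0.305 -> small
```

A 15-point accuracy gap around the 50% mark therefore still counts as a small effect under these conventions.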

In [None]:
# --- Bootstrap Confidence Intervals ---

def bootstrap_ci(
    data: List[bool],
    n_bootstrap: int = 1000,
    ci: float = 0.95,
) -> Tuple[float, float, float]:
    """
    Calculate bootstrap confidence interval for accuracy.

    Returns: (mean, lower_bound, upper_bound)
    """
    data_arr = np.array(data, dtype=float)
    n = len(data_arr)

    rng = np.random.RandomState(SEED)
    bootstrap_means = [
        np.mean(rng.choice(data_arr, size=n, replace=True))
        for _ in range(n_bootstrap)
    ]

    alpha = 1 - ci
    lower = np.percentile(bootstrap_means, alpha / 2 * 100)
    upper = np.percentile(bootstrap_means, (1 - alpha / 2) * 100)
    mean = np.mean(data_arr)

    return mean, lower, upper


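def paired_comparison(arch_correct: List[bool], rand_correct: List[bool]) -> Dict:
    """
    Compare two conditions using a one-sided Wilcoxon signed-rank test
    on trailing-window accuracies, paired by step index.

    Caveat: consecutive windows overlap, so the paired samples are not
    independent and the p-value is optimistic. Treat it as descriptive.
    """
    window = 10

    def windowed_accuracy(correct_list):
        # Trailing window of at most `window` outcomes ending at step i
        return [
            np.mean(correct_list[max(0, i - window + 1):i + 1])
            for i in range(len(correct_list))
        ]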

    arch_windows = windowed_accuracy(arch_correct)
    rand_windows = windowed_accuracy(rand_correct)

    # Truncate to same length
    min_len = min(len(arch_windows), len(rand_windows))
    arch_windows = arch_windows[:min_len]
    rand_windows = rand_windows[:min_len]

    # Wilcoxon test
    try:
        stat, p_value = stats.wilcoxon(arch_windows, rand_windows, alternative='greater')
    except ValueError:
        # All differences are zero
        stat, p_value = 0.0, 1.0

    return {
        "statistic": stat,
        "p_value": p_value,
        "significant": p_value < 0.05,
    }


def cohens_h(p1: float, p2: float) -> float:
    """Cohen's h effect size for comparing two proportions."""
    return 2 * (np.arcsin(np.sqrt(p1)) - np.arcsin(np.sqrt(p2)))


# --- Results Summary ---

print("=" * 80)
print("STATISTICAL ANALYSIS")
print("=" * 80)

# Bootstrap CIs
arch_mean, arch_lower, arch_upper = bootstrap_ci(archetype_log.correct)
rand_mean, rand_lower, rand_upper = bootstrap_ci(random_log.correct)

print(f"\nArchetype-first Accuracy: {arch_mean:.2%} (95% CI: [{arch_lower:.2%}, {arch_upper:.2%}])")
print(f"Random Baseline Accuracy: {rand_mean:.2%} (95% CI: [{rand_lower:.2%}, {rand_upper:.2%}])")

# Wilcoxon test
wilcoxon_result = paired_comparison(archetype_log.correct, random_log.correct)
print(f"\nWilcoxon Signed-Rank Test:")
print(f"  Statistic: {wilcoxon_result['statistic']:.2f}")
print(f"  P-value: {wilcoxon_result['p_value']:.4f}")
print(f"  Significant (p < 0.05): {wilcoxon_result['significant']}")

# Effect size (Cohen's conventional cutoffs: 0.2 small, 0.5 medium, 0.8 large)
effect = cohens_h(arch_mean, rand_mean)
print(f"\nEffect Size (Cohen's h): {effect:.3f}")
if abs(effect) < 0.2:
    effect_interpretation = "Negligible effect"
elif abs(effect) < 0.5:
    effect_interpretation = "Small effect"
elif abs(effect) < 0.8:
    effect_interpretation = "Medium effect"
else:
    effect_interpretation = "Large effect"
print(f"  Interpretation: {effect_interpretation}")

# Coverage efficiency
def find_threshold_index(acc_list: List[float], threshold: float) -> Optional[int]:
    """Find first index where accuracy reaches threshold."""
    for i, a in enumerate(acc_list):
        if a >= threshold:
            return i
    return None

arch_50_idx = find_threshold_index(archetype_log.cumulative_accuracy, 0.5)
rand_50_idx = find_threshold_index(random_log.cumulative_accuracy, 0.5)

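# Indices are 0-based: assuming cumulative_accuracy[i] covers the first
# i + 1 problems, the threshold was reached after i + 1 problems.
print(f"\nCoverage Efficiency:")
print(f"  Archetype-first reached 50% acc after: {arch_50_idx + 1 if arch_50_idx is not None else 'N/A'} problems")
print(f"  Random reached 50% acc after: {rand_50_idx + 1 if rand_50_idx is not None else 'N/A'} problems")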

# Final playbook stats
print(f"\nFinal Playbook Stats:")
if archetype_log.final_playbook:
    total_tags = sum(b.helpful_count + b.harmful_count for b in archetype_log.final_playbook.bullets)
    print(f"  Archetype-first: {archetype_log.final_playbook.size} strategies, {total_tags} total tags")
if random_log.final_playbook:
    total_tags = sum(b.helpful_count + b.harmful_count for b in random_log.final_playbook.bullets)
    print(f"  Random: {random_log.final_playbook.size} strategies, {total_tags} total tags")
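
The two per-condition intervals above do not directly give an interval for the accuracy gap itself. A percentile bootstrap on the difference is one way to estimate it. This is a minimal sketch that resamples each condition independently, so it does not assume the two runs are paired by step. `bootstrap_diff_ci`, its defaults, and the toy data are assumptions for illustration, not results from this notebook.

```python
import numpy as np
from typing import Sequence, Tuple


def bootstrap_diff_ci(
    a: Sequence[bool],
    b: Sequence[bool],
    n_bootstrap: int = 2000,
    ci: float = 0.95,
    seed: int = 0,
) -> Tuple[float, float, float]:
    """Percentile bootstrap CI for mean(a) - mean(b), resampling each side independently."""
    a_arr = np.asarray(a, dtype=float)
    b_arr = np.asarray(b, dtype=float)
    rng = np.random.default_rng(seed)

    # One row per bootstrap replicate, drawn with replacement
    a_means = rng.choice(a_arr, size=(n_bootstrap, a_arr.size), replace=True).mean(axis=1)
    b_means = rng.choice(b_arr, size=(n_bootstrap, b_arr.size), replace=True).mean(axis=1)
    diffs = a_means - b_means

    alpha = 1 - ci
    lower, upper = np.percentile(diffs, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return float(a_arr.mean() - b_arr.mean()), float(lower), float(upper)


# Toy outcomes for illustration only: 14/20 correct vs 10/20 correct
toy_a = [True] * 14 + [False] * 6
toy_b = [True] * 10 + [False] * 10
point, lower, upper = bootstrap_diff_ci(toy_a, toy_b)
print(f"Difference: {point:+.2f} (95% CI: [{lower:+.2f}, {upper:+.2f}])")
```

An interval that contains 0 means the resampled data do not separate the two conditions at the chosen confidence level.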

In [None]:
# --- Final Summary ---

print("\n" + "=" * 80)
print("EXPERIMENT SUMMARY")
print("=" * 80)

# Heuristic labels for the summary text; the 0.5 and 0.1 cutoffs are arbitrary
if arch_mean > rand_mean:
    outperforms = "outperforms"
elif arch_mean < rand_mean:
    outperforms = "underperforms compared to"
else:
    outperforms = "matches"
transfer_quality = "reliable" if arch_mean > 0.5 else "limited"
ucb_effectiveness = "effectively" if archetype_log.ucb_gaps and np.mean(archetype_log.ucb_gaps) > 0.1 else "minimally"
significance_note = "(significant)" if wilcoxon_result['significant'] else "(not significant)"

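total_llm_calls = sum(archetype_log.call_counts.values()) + sum(random_log.call_counts.values())
# call_counter was cleared before the baseline and ablation runs, so these
# lookups only see calls recorded since the most recent clear()
verification_calls = call_counter.get('verify', 0) + call_counter.get('perturb', 0) + call_counter.get('adversarial', 0)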

print(f"""
Verified Archetype Discovery PoC Results
-----------------------------------------

Task: Game of 24 (mathematical reasoning)
Model: Qwen2.5-7B-Instruct via vLLM
Problems: {N_EVAL_PROBLEMS} total ({len(archetypes)} archetypes + {N_EVAL_PROBLEMS - len(archetypes)} others)

Archetype Selection:
- Clustered {len(problems)} problems into {N_CLUSTERS} clusters
- Selected top-10 archetypes by: centrality * diversity * simplicity
- Verified strategies with multi-path consistency, perturbation testing, adversarial probing

Results:
- Archetype-first accuracy: {archetype_log.final_accuracy:.2%}
- Random baseline accuracy: {random_log.final_accuracy:.2%}
- Improvement: {(archetype_log.final_accuracy - random_log.final_accuracy)*100:+.1f} percentage points
- Statistical significance: p = {wilcoxon_result['p_value']:.4f} {significance_note}

Key Findings:
1. Archetype-first curriculum {outperforms} random baseline
2. Verified strategies from archetypes provide {transfer_quality} transfer
3. LinUCB contextual bandit {ucb_effectiveness} balances exploration/exploitation

LLM Budget:
- Total calls: ~{total_llm_calls}
- Verification calls: ~{verification_calls}
""")

# Save results
results = {
    "archetype_log": {
        "correct": archetype_log.correct,
        "final_accuracy": archetype_log.final_accuracy,
        "cumulative_accuracy": archetype_log.cumulative_accuracy,
    },
    "random_log": {
        "correct": random_log.correct,
        "final_accuracy": random_log.final_accuracy,
        "cumulative_accuracy": random_log.cumulative_accuracy,
    },
    "statistics": {
        "arch_ci": (arch_mean, arch_lower, arch_upper),
        "rand_ci": (rand_mean, rand_lower, rand_upper),
        "wilcoxon": wilcoxon_result,
        "effect_size": effect,
    },
    "archetypes": [a["numbers"] for a in archetypes],
}

with open(CHECKPOINT_DIR / "final_results.pkl", "wb") as f:
    pickle.dump(results, f)

print(f"Results saved to {CHECKPOINT_DIR / 'final_results.pkl'}")

In [None]:
# --- Cleanup ---

print("Shutting down vLLM server...")
if vllm_proc:
    vllm_proc.terminate()
    try:
        vllm_proc.wait(timeout=10)
    except subprocess.TimeoutExpired:
        vllm_proc.kill()
        vllm_proc.wait()  # Reap the killed process
    print(f"vLLM server (PID {vllm_proc.pid}) terminated.")

print("\nNotebook complete.")