# AIMO3 Baseline Notebook
## AI Mathematical Olympiad – Progress Prize 3

### Quick Start Guide
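1. **RUN_MODE Selection** (forced to `"submit_auto"` automatically when `KAGGLE_IS_COMPETITION_RERUN=1`):
   - `"local_ref"`: Debug mode - runs evaluation on `reference.csv` (10 problems)
   - `"submit_auto"`: Kaggle submission mode - uses `kaggle_evaluation` API
   - `"quick_tune"` / `"full_tune"`: Optuna tuning on a reference subset (quick: 5-10 trials, full: 20-50 trials)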

2. **Model Setup (Kaggle)**:
   - Add your model as a Kaggle Dataset/Model input
   - Update `MODEL_PATH` in CONFIG to point to `/kaggle/input/your-model-name`
   - Or use Kaggle's built-in models

3. **Telemetry**:
   - Logs saved to `/kaggle/working/aimo3_telemetry.jsonl`
   - Download after submission for analysis

## CELL A — CONFIG (Constants)

In [None]:
# ================================
# CELL A — CONFIG (Constants)
# ================================

import os

# ----- RUN MODE AUTO DETECTION -----
# RUN_MODE is forced to "submit_auto" when KAGGLE_IS_COMPETITION_RERUN == "1"
# (the hidden-test rerun); otherwise the manual choice below is used.
_is_kaggle_rerun = os.getenv("KAGGLE_IS_COMPETITION_RERUN") == "1"
if _is_kaggle_rerun:
    RUN_MODE = "submit_auto"
else:
    # Available modes:
    # "local_ref": Run evaluation on reference.csv (10 problems) - for debugging
    # "submit_auto": Kaggle submission mode - uses kaggle_evaluation API
    # "quick_tune": Fast Optuna tuning on reference subset (5-10 trials)
    # "full_tune": Full Optuna tuning on reference subset (20-50 trials)
    RUN_MODE = "local_ref"  # Change as needed

# ----- TIME BUDGET -----
TIME_BUDGET_SEC_PER_PROBLEM = 180  # Increased for complex problems

# ----- GENERATION PARAMS (defaults, can be tuned) -----
K_BASE = 8              # Candidates sampled per problem for voting
K_MAX_HARD = 16         # Max candidates for hard problems
TEMPERATURE_BASE = 0.7  # Base sampling temperature (diversity for voting)
TEMPERATURE_HARD = 0.9  # Temperature for exploration
MAX_NEW_TOKENS = 4096   # Generation cap, sized for long reasoning
TOP_P = 0.95            # Nucleus sampling
TOP_K = 50              # Top-k sampling

# ----- VOTING & SELECTION -----
# Early-stop threshold: fraction of candidates that must agree on the top
# answer. Lower values stop sooner, i.e. they require LESS consensus.
VOTING_THRESHOLD = 0.5
SELECTION_STRATEGY = "majority_vote"  # "majority_vote", "verifier_weighted", "consensus"

# ----- PATHS -----
# Kaggle provides a writable /kaggle/working. Elsewhere (local debugging) use
# the current directory instead of trying to create /kaggle at the filesystem
# root, which fails without root privileges.
_KAGGLE_WORKING_DIR = "/kaggle/working"
WORKING_DIR = _KAGGLE_WORKING_DIR if os.path.isdir(_KAGGLE_WORKING_DIR) else "."
LOG_PATH = os.path.join(WORKING_DIR, "aimo3_telemetry.jsonl")
CACHE_DIR = os.path.join(WORKING_DIR, "cache")
BEST_CONFIG_PATH = os.path.join(WORKING_DIR, "best_config.json")

# ----- MODEL CONFIG -----
# MODEL_PATH: local weights attached as a Kaggle input; used whenever it exists.
# MODEL_ID: Hugging Face Hub repo id. It is only used to download the model when
# MODEL_PATH is missing and internet is available (never in the competition
# rerun). It must be a hub id, not a local path, or that fallback cannot work.
MODEL_PATH = "/kaggle/input/qwq-32b-preview/transformers/default/1"  # QwQ-32B model
MODEL_ID = "Qwen/QwQ-32B-Preview"

# ----- PROMPT STYLE -----
# "strict_final": Strict FINAL: tag prompt (default)
# "tir": Tool-Integrated Reasoning (original)
# "concise": Direct answer only
PROMPT_STYLE = "strict_final"

# ----- MODE POLICY -----
MODE_POLICY = "stable"  # "stable" or "diverse"

# ----- DATA PATHS -----
REFERENCE_CSV_PATH = None
TEST_CSV_PATH = None

# Detect environment and set paths
if os.path.exists("/kaggle/input"):
    REFERENCE_CSV_PATH = "/kaggle/input/ai-mathematical-olympiad-progress-prize-3/reference.csv"
    TEST_CSV_PATH = "/kaggle/input/ai-mathematical-olympiad-progress-prize-3/test.csv"
    if not os.path.exists(REFERENCE_CSV_PATH):
        REFERENCE_CSV_PATH = "reference.csv"
        TEST_CSV_PATH = "test.csv"
else:
    REFERENCE_CSV_PATH = "reference.csv"
    TEST_CSV_PATH = "test.csv"

# Create working directories
_log_dir = os.path.dirname(LOG_PATH)
if _log_dir:
    os.makedirs(_log_dir, exist_ok=True)
if CACHE_DIR:
    os.makedirs(CACHE_DIR, exist_ok=True)

# ----- TUNING PARAMS -----
TUNE_TRIALS_QUICK = 10
TUNE_TRIALS_FULL = 30
TUNE_TIMEOUT_SEC = 3600  # 1 hour max for tuning

print(f"RUN_MODE: {RUN_MODE} (auto-detected rerun: {_is_kaggle_rerun})")
print(f"MODEL_PATH: {MODEL_PATH}")
print(f"PROMPT_STYLE: {PROMPT_STYLE}")
print(f"K_BASE: {K_BASE}, TEMPERATURE_BASE: {TEMPERATURE_BASE}")

## CELL B — IMPORTS + SEED CONTROL

In [None]:
# ================================
# CELL B — IMPORTS + SEED CONTROL
# ================================

import os
import re
import sys
import time
import json
import math
import random
import hashlib
import warnings
from typing import Optional, Tuple, List, Dict, Any, Union
from collections import Counter
from contextlib import redirect_stdout, redirect_stderr
import io

import pandas as pd

# Suppress all warnings to keep notebook output readable.
warnings.simplefilter("ignore")

# ----- SEED CONTROL -----
BASE_SEED = 42
IS_KAGGLE_RERUN = os.environ.get("KAGGLE_IS_COMPETITION_RERUN", "") == "1"

def get_seed_for_problem(problem_id: str = None) -> int:
    """
    Return the RNG seed to use for one problem.

    Under MODE_POLICY == "diverse" with a problem_id, the seed is BASE_SEED
    plus (first 32 bits of the MD5 hex digest of problem_id) % 10000. That
    differs per problem but is reproducible across runs. In every other case
    the seed is simply BASE_SEED.
    """
    if MODE_POLICY != "diverse" or problem_id is None:
        return BASE_SEED
    digest_prefix = hashlib.md5(problem_id.encode()).hexdigest()[:8]
    return BASE_SEED + int(digest_prefix, 16) % 10000

def set_seed(seed: int):
    """Seed Python's `random`, plus NumPy and PyTorch (CPU and every CUDA device) when installed."""
    random.seed(seed)

    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None:
        np.random.seed(seed)

    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

# Start from the base seed; solve_problem re-seeds per problem via get_seed_for_problem.
CURRENT_SEED = BASE_SEED
set_seed(CURRENT_SEED)

print("\n".join([
    f"IS_KAGGLE_RERUN: {IS_KAGGLE_RERUN}",
    f"MODE_POLICY: {MODE_POLICY}",
    f"BASE_SEED: {BASE_SEED}",
]))

## CELL C — LAZY MODEL LOADER

In [None]:
# ================================
# CELL C — LAZY MODEL LOADER
# ================================

# Process-wide cache written by load_model(). It holds the loaded model and
# tokenizer (None when unavailable) and the device string they were placed on.
# "loaded" records that a load was already attempted: it is set True even after
# a skipped or failed load, so load_model() never retries. "skip_reason" holds
# why loading was skipped or failed.
_model_cache = {"model": None, "tokenizer": None, "device": None, "loaded": False, "skip_reason": None}

def get_device():
    """Return the best torch backend: "cuda", then "mps", else "cpu" (also when torch is absent)."""
    try:
        import torch
    except ImportError:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"

def load_model():
    """
    Lazy load model and tokenizer.
    In Kaggle rerun: always use local_files_only=True.
    If MODEL_PATH is None/missing: skip load immediately and log warning.
    """
    global _model_cache
    
    if _model_cache["loaded"]:
        return _model_cache["model"], _model_cache["tokenizer"], _model_cache["device"]
    
    # Check if we should skip loading
    has_local_model = MODEL_PATH is not None and os.path.exists(MODEL_PATH)
    
    if IS_KAGGLE_RERUN and not has_local_model:
        # In rerun mode without local model, skip loading
        print("WARNING: No local model available (MODEL_PATH is None or missing)")
        print("Skipping model load - using fallback solver")
        _model_cache["model"] = None
        _model_cache["tokenizer"] = None
        _model_cache["device"] = "cpu"
        _model_cache["loaded"] = True
        _model_cache["skip_reason"] = "no_local_model_in_rerun"
        return None, None, "cpu"
    
    print("Loading model...")
    start_time = time.time()
    
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        
        device = get_device()
        _model_cache["device"] = device
        print(f"Using device: {device}")
        
        # Determine model source and local_files_only setting
        if has_local_model:
            model_source = MODEL_PATH
            local_only = True
            print(f"Loading from local path: {model_source}")
        elif IS_KAGGLE_RERUN:
            # Should not reach here due to check above, but safety fallback
            print("WARNING: Cannot load model - no local path and in rerun mode")
            raise RuntimeError("No local model available in rerun mode")
        else:
            # Local development with internet - can download
            model_source = MODEL_ID
            local_only = False
            print(f"Loading from HuggingFace: {model_source}")
        
        _model_cache["tokenizer"] = AutoTokenizer.from_pretrained(
            model_source, trust_remote_code=True, local_files_only=local_only
        )
        
        dtype = torch.float16 if device == "cuda" else torch.float32
        _model_cache["model"] = AutoModelForCausalLM.from_pretrained(
            model_source, torch_dtype=dtype,
            device_map="auto" if device == "cuda" else None,
            trust_remote_code=True, local_files_only=local_only
        )
        
        if device != "cuda":
            _model_cache["model"] = _model_cache["model"].to(device)
        
        _model_cache["loaded"] = True
        print(f"Model loaded in {time.time() - start_time:.2f}s")
        
    except Exception as e:
        print(f"WARNING: Could not load model: {e}")
        print("Using rule-based solver...")
        _model_cache["model"] = None
        _model_cache["tokenizer"] = None
        _model_cache["device"] = "cpu"
        _model_cache["loaded"] = True
        _model_cache["skip_reason"] = str(e)
    
    return _model_cache["model"], _model_cache["tokenizer"], _model_cache["device"]

def is_model_available():
    """Report whether an LLM is (or is likely to be) usable.

    After load_model() has run, this reflects whether a model object is
    actually cached. Before that, it only checks that MODEL_PATH exists on disk.
    """
    if not _model_cache["loaded"]:
        return bool(MODEL_PATH) and os.path.exists(MODEL_PATH)
    return _model_cache["model"] is not None

print("\n".join((
    "Model loader initialized (lazy loading)",
    f"MODEL_PATH: {MODEL_PATH}",
    f"IS_KAGGLE_RERUN: {IS_KAGGLE_RERUN}",
)))

## CELL D — TOOL-INTEGRATED REASONING + SAFE PYTHON EXECUTOR

In [None]:
# ================================
# CELL D — TIR-lite + SAFE PYTHON EXECUTOR + STRICT PROMPTS
# ================================

# Modules that sandboxed snippets may use; sympy is added only when installed.
ALLOWED_MODULES = set(
    "math fractions itertools functools collections decimal numbers cmath random statistics".split()
)

try:
    import sympy
except ImportError:
    pass
else:
    ALLOWED_MODULES.add("sympy")

def create_safe_globals():
    """Build a fresh globals dict for exec()-ing model-written snippets.

    The namespace contains a whitelist of builtins plus pre-loaded modules:
    math, fractions, itertools, functools, collections, decimal, numbers,
    cmath, random, statistics, and sympy when installed. It also provides the
    shortcuts Fraction and Decimal.

    `import` statements in a snippet go through a restricted `__import__`
    that admits only those modules and their submodules. So
    `import sympy as sp` and `from math import gcd` work, while `import os`
    raises ImportError. `class` statements and catching common exception
    types are supported.

    NOTE(review): this is a convenience guard, not a security boundary.
    Python object introspection can still reach arbitrary objects.
    """
    import builtins
    import cmath
    import collections
    import decimal
    import fractions
    import functools
    import itertools
    import math
    import numbers
    import random
    import statistics

    # Modules exposed by name; the restricted __import__ admits exactly these roots.
    preloaded = {
        "math": math,
        "fractions": fractions,
        "itertools": itertools,
        "functools": functools,
        "collections": collections,
        "decimal": decimal,
        "numbers": numbers,
        "cmath": cmath,
        "random": random,
        "statistics": statistics,
    }
    try:
        import sympy
    except ImportError:
        pass
    else:
        preloaded["sympy"] = sympy

    def _restricted_import(name, globals_=None, locals_=None, fromlist=(), level=0):
        # Relative imports and anything outside the pre-loaded set are refused.
        if level != 0 or name.partition(".")[0] not in preloaded:
            raise ImportError(f"import of {name!r} is not allowed in the sandbox")
        return builtins.__import__(name, globals_, locals_, fromlist, level)

    safe_builtins = {
        "abs": abs, "all": all, "any": any, "bin": bin, "bool": bool, "chr": chr,
        "dict": dict, "divmod": divmod, "enumerate": enumerate, "filter": filter,
        "float": float, "frozenset": frozenset, "hex": hex, "int": int,
        "isinstance": isinstance, "len": len, "list": list, "map": map, "max": max,
        "min": min, "oct": oct, "ord": ord, "pow": pow, "print": print, "range": range,
        "repr": repr, "reversed": reversed, "round": round, "set": set, "slice": slice,
        "sorted": sorted, "str": str, "sum": sum, "tuple": tuple, "type": type, "zip": zip,
        "True": True, "False": False, "None": None, "complex": complex,
        # Required by `class` statements and `import` statements respectively.
        "__build_class__": builtins.__build_class__,
        "__import__": _restricted_import,
        # Common iteration / object helpers.
        "next": next, "iter": iter, "hash": hash, "issubclass": issubclass,
        "object": object, "super": super, "property": property,
        "staticmethod": staticmethod, "classmethod": classmethod,
        # Exception types so snippets can raise and catch ordinary errors.
        "Exception": Exception, "ArithmeticError": ArithmeticError,
        "ZeroDivisionError": ZeroDivisionError, "OverflowError": OverflowError,
        "ValueError": ValueError, "TypeError": TypeError, "IndexError": IndexError,
        "KeyError": KeyError, "StopIteration": StopIteration,
        "RecursionError": RecursionError, "AssertionError": AssertionError,
        "NotImplementedError": NotImplementedError, "ImportError": ImportError,
    }

    safe_globals = {
        "__builtins__": safe_builtins,
        # A `class` body reads __name__ to set __module__.
        "__name__": "__sandbox__",
        **preloaded,
        "Fraction": fractions.Fraction,
        "Decimal": decimal.Decimal,
    }
    return safe_globals

# Single-line import statements (matched after any trailing comment is removed).
_IMPORT_STMT_RE = re.compile(r"^(\s*)import\s+(.+?)\s*$")
_FROM_IMPORT_STMT_RE = re.compile(r"^(\s*)from\s+([A-Za-z_][\w.]*)\s+import\s+(.+?)\s*$")
_DOTTED_NAME_RE = re.compile(r"[A-Za-z_]\w*(?:\.[A-Za-z_]\w*)*")


def _parse_import_item(item: str) -> Optional[Tuple[str, Optional[str]]]:
    """Split "name" or "name as alias" into (name, alias); None if malformed."""
    parts = item.split()
    if len(parts) == 1 and _DOTTED_NAME_RE.fullmatch(parts[0]):
        return parts[0], None
    if (len(parts) == 3 and parts[1] == "as"
            and _DOTTED_NAME_RE.fullmatch(parts[0]) and parts[2].isidentifier()):
        return parts[0], parts[2]
    return None


def _rewrite_import_line(line: str, allowed_roots: set) -> Optional[str]:
    """Rewrite one import of allowed modules into plain assignments.

    Returns the replacement line, or None when the line is not a recognized
    single-line import of allowed modules; the caller then keeps it verbatim.
    """
    statement = line.split("#", 1)[0]

    from_match = _FROM_IMPORT_STMT_RE.match(statement)
    if from_match:
        indent, module, names = from_match.groups()
        if module.split(".")[0] not in allowed_roots:
            return None
        if names == "*":
            # Names cannot be enumerated statically; rely on the pre-loaded module.
            return indent + "pass"
        if names.startswith("(") and names.endswith(")"):
            names = names[1:-1]
        assignments = []
        for item in names.split(","):
            parsed = _parse_import_item(item)
            if parsed is None or "." in parsed[0]:
                return None
            name, alias = parsed
            assignments.append(f"{alias or name} = {module}.{name}")
        return indent + "; ".join(assignments)

    import_match = _IMPORT_STMT_RE.match(statement)
    if import_match:
        indent, items = import_match.groups()
        assignments = []
        for item in items.split(","):
            parsed = _parse_import_item(item)
            if parsed is None or parsed[0].split(".")[0] not in allowed_roots:
                return None
            name, alias = parsed
            if alias:
                assignments.append(f"{alias} = {name}")
        # A plain `import math` needs nothing: the module is already bound.
        return indent + ("; ".join(assignments) if assignments else "pass")

    return None


def strip_imports(code: str, allowed=None) -> str:
    """Neutralize import statements that target pre-loaded modules.

    Each single-line import whose module root is allowed is rewritten so that
    the names it would bind still exist. The rewrite assumes the modules
    themselves are already present in the exec namespace:
      import math                  -> pass
      import sympy as sp           -> sp = sympy
      from math import gcd, lcm    -> gcd = math.gcd; lcm = math.lcm
      from math import *           -> pass
    Module names are compared exactly (no substring matching). Imports of
    other modules, relative imports, and multi-line imports are left untouched.
    Replacing instead of deleting keeps line numbers stable and never leaves
    an empty indented block.

    Args:
        code: Python source of the snippet.
        allowed: iterable of permitted top-level module names; defaults to
            ALLOWED_MODULES.

    Returns:
        The rewritten source, with the same number of lines.
    """
    allowed_roots = set(ALLOWED_MODULES if allowed is None else allowed)
    rewritten_lines = []
    for line in code.split("\n"):
        replacement = _rewrite_import_line(line, allowed_roots)
        rewritten_lines.append(line if replacement is None else replacement)
    return "\n".join(rewritten_lines)

def run_python(code: str, timeout_sec: float = 10.0) -> Tuple[bool, str]:
    """Execute a model-written Python snippet in the restricted namespace.

    - The snippet runs in ONE namespace dict, as a module would. With separate
      globals/locals, top-level functions could not see each other or recurse,
      because functions look names up in globals.
    - Imports are resolved by a guarded `__import__` that admits only
      ALLOWED_MODULES. Imports are no longer stripped, so aliases and
      `from X import y` bind their names correctly.
    - The timeout uses a SIGALRM interval timer, so fractional seconds work.
      It is only armed on the main thread of platforms that have SIGALRM.
      Elsewhere (e.g. a server worker thread, where signal.signal would raise)
      the snippet runs without a timeout instead of crashing. The alarm
      interrupts Python bytecode only; a single long C-level call cannot be
      interrupted until it returns.

    Returns:
        (ok, text). On success, text is the captured stdout/stderr, followed by
        "<name> = <value>" for the first of result/answer/ans/final/output the
        snippet defined. If there is no output at all, text is
        "Execution completed (no output)". On failure, text is
        "Timeout: ..." or "Error: <Type>: <message>".
    """
    import builtins
    import signal
    import threading

    namespace = create_safe_globals()
    sandbox_builtins = namespace["__builtins__"]
    if "__import__" not in sandbox_builtins:
        def _guarded_import(name, globals_=None, locals_=None, fromlist=(), level=0):
            if level != 0 or name.partition(".")[0] not in ALLOWED_MODULES:
                raise ImportError(f"import of {name!r} is not allowed in the sandbox")
            return builtins.__import__(name, globals_, locals_, fromlist, level)

        sandbox_builtins["__import__"] = _guarded_import

    output_capture = io.StringIO()

    def timeout_handler(signum, frame):
        raise TimeoutError("Code execution timed out")

    use_timer = (
        timeout_sec > 0
        and hasattr(signal, "SIGALRM")
        and threading.current_thread() is threading.main_thread()
    )
    old_handler = None
    if use_timer:
        old_handler = signal.signal(signal.SIGALRM, timeout_handler)
        signal.setitimer(signal.ITIMER_REAL, timeout_sec)

    try:
        with redirect_stdout(output_capture), redirect_stderr(output_capture):
            exec(code, namespace)

        output = output_capture.getvalue()

        for var_name in ["result", "answer", "ans", "final", "output"]:
            if var_name in namespace:
                val = namespace[var_name]
                if output:
                    output += f"\n{var_name} = {val}"
                else:
                    output = f"{var_name} = {val}"
                break

        return True, output if output else "Execution completed (no output)"

    except TimeoutError as e:
        return False, f"Timeout: {str(e)}"
    except Exception as e:
        return False, f"Error: {type(e).__name__}: {str(e)}"
    finally:
        # Always disarm the timer and restore the previous handler.
        if use_timer:
            signal.setitimer(signal.ITIMER_REAL, 0)
            signal.signal(signal.SIGALRM, old_handler)

def parse_python_block(text: str) -> Optional[str]:
    """Return the stripped body of a fenced code block in *text*, or None.

    Fence kinds are tried in priority order: a python fence, then a py fence,
    then a bare fence. Each is matched case-insensitively. The first kind that
    occurs anywhere in *text* wins, and its leftmost occurrence is returned.
    """
    fence_regexes = (r"```python\s*\n(.*?)```", r"```py\s*\n(.*?)```", r"```\s*\n(.*?)```")
    flags = re.DOTALL | re.IGNORECASE
    hits = (re.search(rx, text, flags) for rx in fence_regexes)
    first_hit = next((hit for hit in hits if hit is not None), None)
    return None if first_hit is None else first_hit.group(1).strip()

# ============================================================
# NEW STRICT FINAL PROMPT - Forces model to use FINAL: tag
# ============================================================
def get_strict_final_prompt(problem: str) -> str:
    """
    Build the default ("strict_final") solving prompt for *problem*.

    The prompt asks for step-by-step reasoning, optionally using ```python```
    blocks. It demands a closing `FINAL: <integer>` line (0-99999) with no
    other text on it, so that extract_answer's highest-priority FINAL tier
    applies instead of the weaker last-integer fallback.
    get_answer_only_prompt provides a separate second-pass prompt.

    NOTE: the prompt text itself contains the example "FINAL: 42". Answer
    extraction must therefore run on the model's continuation only, never on
    text that still includes the prompt.
    """
    return f"""You are solving a mathematical olympiad problem. Follow these rules EXACTLY:

PROBLEM:
{problem}

INSTRUCTIONS:
1. Work through the problem step by step, showing your reasoning.
2. If you need to compute something, you may use Python code in ```python ... ``` blocks.
3. After completing your solution, you MUST end with EXACTLY ONE line in this format:

FINAL: <your_integer_answer>

CRITICAL RULES:
- The FINAL line must contain ONLY the tag and a single non-negative integer
- The answer must be between 0 and 99999
- Do NOT put any other numbers on the FINAL line
- Do NOT include units, commas, or explanations on the FINAL line
- If the problem asks for a remainder mod N, give just that remainder

Example of correct final line: FINAL: 42
Example of INCORRECT: FINAL: The answer is 42
Example of INCORRECT: FINAL: 42 (mod 1000)

BEGIN YOUR SOLUTION:"""

def get_answer_only_prompt(problem: str, reasoning: str = "") -> str:
    """
    Build the second-pass prompt: given *problem* and earlier *reasoning*,
    ask the model to restate only the final integer.

    *reasoning* is inserted verbatim (no truncation), so a very long solution
    makes a very long prompt. The prompt ends with "FINAL:" so the model's
    continuation is the answer itself.
    """
    return f"""Based on the following problem and solution, extract ONLY the final numerical answer.

PROBLEM:
{problem}

SOLUTION/REASONING:
{reasoning}

What is the final integer answer? Respond with ONLY a single integer on one line, nothing else.
FINAL:"""

def get_tir_prompt(problem: str) -> str:
    """Build the original Tool-Integrated Reasoning ("tir") prompt for *problem*."""
    prompt_lines = [
        "You are a mathematical problem solver. Solve the following problem step by step.",
        "",
        "RULES:",
        "1. Think through the problem carefully.",
        "2. If you need to compute something, write Python code in a ```python ... ``` block.",
        "3. After your reasoning, provide your final answer as: FINAL: <integer>",
        "4. The answer must be an integer between 0 and 99999.",
        "",
        "PROBLEM:",
        f"{problem}",
        "",
        "SOLUTION:",
    ]
    return "\n".join(prompt_lines)

def get_concise_prompt(problem: str) -> str:
    """Build the "concise" prompt: ask directly for the integer answer, ending at "FINAL:"."""
    header = "Solve this math problem. Give only the final integer answer (0-99999)."
    return "\n".join([header, "", f"Problem: {problem}", "", "FINAL:"])

def get_prompt(problem: str, style: str = None) -> str:
    """Build the prompt for *style* (PROMPT_STYLE when falsy); unknown styles use strict_final."""
    builders = {
        "strict_final": get_strict_final_prompt,
        "tir": get_tir_prompt,
        "concise": get_concise_prompt,
    }
    chosen_style = style or PROMPT_STYLE
    builder = builders.get(chosen_style, get_strict_final_prompt)
    return builder(problem)

print("\n".join([
    "Safe Python executor initialized",
    f"Prompt style: {PROMPT_STYLE}",
    f"Available modules: {ALLOWED_MODULES}",
]))

## CELL E — ANSWER EXTRACTION + VALIDATION

In [None]:
# ================================
# CELL E — ANSWER EXTRACTION + VALIDATION (STRICT PRIORITY)
# ================================

def clean_number_string(s: str) -> str:
    """Normalize a captured number: trim it, drop commas and spaces, and discard one leading '-'.

    The sign is discarded because competition answers are non-negative.
    """
    compact = s.strip().replace(",", "").replace(" ", "")
    return compact[1:] if compact.startswith("-") else compact

# Integer token for the answer patterns: comma-grouped thousands ("12,345") or
# a plain digit run. The (?!\d) stops "1,2345" from being read as grouped thousands.
_ANSWER_INT = r"(\d{1,3}(?:,\d{3})+(?!\d)|\d+)"
# Terminator for contextual fallbacks: end of sentence or line, but NOT a
# decimal point, so "= 3.5" is not read as 3.
_CONTEXT_END = r"\s*(?:\.(?!\d)|\n)"

_FINAL_TAG_RE = re.compile(r"FINAL\s*:\s*" + _ANSWER_INT, re.IGNORECASE)
_ANSWER_TAG_RES = (re.compile(r"answer\s*:\s*" + _ANSWER_INT, re.IGNORECASE),)
_THE_ANSWER_IS_RES = (
    re.compile(r"[Tt]he\s+answer\s+is\s*:?\s*" + _ANSWER_INT),
    re.compile(r"[Tt]he\s+final\s+answer\s+is\s*:?\s*" + _ANSWER_INT),
    re.compile(r"[Aa]nswer\s*=\s*" + _ANSWER_INT),
)
_BOXED_RES = (re.compile(r"boxed\{" + _ANSWER_INT + r"\}"),)
# \b keeps "is"/"get" from matching inside words such as "this" or "target".
_LATE_CONTEXT_RES = (
    re.compile(r"=\s*" + _ANSWER_INT + r"\s*$", re.IGNORECASE),
    re.compile(r"=\s*" + _ANSWER_INT + _CONTEXT_END, re.IGNORECASE),
    re.compile(r"\bis\s+" + _ANSWER_INT + _CONTEXT_END, re.IGNORECASE),
    re.compile(r"\bget\s+" + _ANSWER_INT + _CONTEXT_END, re.IGNORECASE),
)


def _answer_token_to_int(token: str) -> int:
    """Convert a captured _ANSWER_INT token (digits, maybe comma-grouped) to int."""
    return int(token.replace(",", ""))


def _last_tagged_int(patterns, text: str) -> Optional[int]:
    """Integer captured by the right-most match of any of *patterns* in *text*, or None."""
    last = None
    for pattern in patterns:
        for match in pattern.finditer(text):
            if last is None or match.start() > last.start():
                last = match
    return None if last is None else _answer_token_to_int(last.group(1))


def extract_answer(text: str) -> Tuple[Optional[int], str]:
    """
    Extract an integer answer from model output using a strict priority order.

    Tiers, highest first; the first tier that matches decides the answer:
    1. FINAL: <int>           -> "FINAL"  (lines scanned bottom-up; last tag on the line)
    2. ANSWER: <int>          -> "ANSWER" (case-insensitive)
    3. The (final) answer is <int> / answer = <int> -> "THE_ANSWER_IS"
    4. \\boxed{<int>}          -> "BOXED"
    5. Within the last 500 characters, "= N", "is N." or "get N."
       -> "LASTINT_CONTEXT". Failing that, the last standalone integer
       -> "LASTINT".

    Tiers 2-4 use the LAST occurrence in the text. Reasoning models often
    state a provisional answer and later correct it, so the final statement
    is the most reliable one. Comma-grouped integers such as "12,345" are
    read whole.

    Returns:
        (answer, method). (None, "empty") for empty input; (None, "none")
        when no integer is found.
    """
    if not text:
        return None, "empty"

    text = text.strip()

    # Tier 1: the FINAL tag nearest the end wins.
    for line in reversed(text.split("\n")):
        tags = _FINAL_TAG_RE.findall(line)
        if tags:
            return _answer_token_to_int(tags[-1]), "FINAL"

    # Tiers 2-4: explicit answer statements, last occurrence wins.
    for patterns, method in (
        (_ANSWER_TAG_RES, "ANSWER"),
        (_THE_ANSWER_IS_RES, "THE_ANSWER_IS"),
        (_BOXED_RES, "BOXED"),
    ):
        value = _last_tagged_int(patterns, text)
        if value is not None:
            return value, method

    # Tier 5a: contextual integer near the end. Pattern order is the priority;
    # within the first pattern that matches, its last match wins.
    late_text = text[-500:]
    for pattern in _LATE_CONTEXT_RES:
        matches = list(pattern.finditer(late_text))
        if matches:
            return _answer_token_to_int(matches[-1].group(1)), "LASTINT_CONTEXT"

    # Tier 5b: last standalone integer anywhere.
    integers = re.findall(r"\b(\d+)\b", text)
    if integers:
        return int(integers[-1]), "LASTINT"

    return None, "none"

def validate_answer(answer: Optional[int]) -> Tuple[bool, int]:
    """Check that *answer* is an integer in the competition range [0, 99999].

    Non-int inputs (floats, numeric strings) are converted with int(); floats
    truncate toward zero.

    Returns:
        (is_valid, value). (True, answer) when in range. (False, clamped) for
        out-of-range integers, clamped into [0, 99999]. (False, 0) for None or
        anything int() cannot convert, including non-numeric strings, NaN and
        infinity (which previously raised).
    """
    if answer is None:
        return False, 0

    try:
        value = int(answer)
    except (ValueError, TypeError, OverflowError):
        return False, 0

    if 0 <= value <= 99999:
        return True, value

    # Clamp to valid range
    return False, max(0, min(99999, value))

def safe_extract_answer(text: str) -> Tuple[int, Dict[str, Any]]:
    """Extract and validate an answer from *text* using extract_answer and validate_answer.

    Returns (answer, metadata). Any missing, invalid or out-of-range answer is
    replaced by 0 and recorded under metadata["fallback_value"]. The
    "fallback_used" flag is also set when only a LASTINT-style heuristic found
    the number.
    """
    raw_answer, method = extract_answer(text)
    is_valid, final_answer = validate_answer(raw_answer)
    heuristic_parse = method in ("LASTINT", "LASTINT_CONTEXT")

    metadata = {
        "raw_answer": raw_answer,
        "method": method,
        "is_valid": is_valid,
        "fallback_used": (not is_valid) or heuristic_parse,
    }

    if is_valid:
        return final_answer, metadata

    metadata["fallback_value"] = 0
    return 0, metadata


print("Answer extraction functions initialized (strict priority)")

## CELL F — CANDIDATE GENERATION + SELF-CONSISTENCY

In [None]:
# ================================
# CELL F — CANDIDATE GENERATION + SELF-CONSISTENCY (IMPROVED)
# ================================

def _build_model_input(tokenizer, prompt: str, use_chat_template: bool) -> Tuple[str, bool]:
    """Return (text_to_tokenize, used_chat_template) for *prompt*.

    When requested and the tokenizer defines a chat template, the prompt is
    wrapped as a single user turn and the assistant turn is opened. This is
    the input format instruction-tuned checkpoints are trained on. Otherwise,
    or if templating fails, the raw prompt is used.
    """
    if use_chat_template and getattr(tokenizer, "chat_template", None):
        try:
            templated = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
            return templated, True
        except Exception:
            # Best effort: an unusable template falls back to the raw prompt.
            pass
    return prompt, False


def generate_one(problem: str, temperature: float = 0.7, max_new_tokens: int = None,
                 prompt_style: str = None, top_p: float = None, top_k: int = None,
                 use_chat_template: bool = True) -> Tuple[str, Dict[str, Any]]:
    """Generate one solution candidate for *problem*.

    Args:
        problem: problem statement inserted into the prompt.
        temperature: > 0 samples (with top_p/top_k); <= 0 decodes greedily.
        max_new_tokens, prompt_style, top_p, top_k: None means use the
            MAX_NEW_TOKENS / PROMPT_STYLE / TOP_P / TOP_K config defaults.
        use_chat_template: wrap the prompt with the tokenizer's chat template
            when it has one. Pass False to feed the raw prompt text.

    Returns:
        (response, meta). response is only the newly generated text. The
        prompt tokens are cut off by position, so prompt text (including its
        "FINAL: 42" example) can never leak into answer extraction. meta holds
        the generation settings. response is "" when no model is loaded or
        generation fails; on failure, meta["error"] holds the message.
    """
    model, tokenizer, device = load_model()

    # `is None` rather than `or`, so explicit zero values (e.g. top_k=0, which
    # disables top-k filtering in transformers) are not replaced by defaults.
    if max_new_tokens is None:
        max_new_tokens = MAX_NEW_TOKENS
    prompt_style = prompt_style or PROMPT_STYLE
    if top_p is None:
        top_p = TOP_P
    if top_k is None:
        top_k = TOP_K

    prompt = get_prompt(problem, prompt_style)

    meta = {
        "prompt_style": prompt_style,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "top_k": top_k
    }

    if model is None:
        return "", meta

    try:
        import torch

        model_input, templated = _build_model_input(tokenizer, prompt, use_chat_template)
        meta["chat_template"] = templated
        # A chat template already inserts the special tokens it needs.
        inputs = tokenizer(
            model_input, return_tensors="pt", add_special_tokens=not templated
        ).to(device)

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }

        if temperature > 0:
            gen_kwargs.update({
                "do_sample": True,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
            })
        else:
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)

        # Decoder-only generate() returns prompt + continuation; keep only the
        # continuation tokens.
        prompt_len = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

        return response, meta

    except Exception as e:
        meta["error"] = str(e)
        return "", meta

def execute_code_in_response(response: str) -> str:
    """Run the first fenced code block of *response* (if any) and append its result.

    Success appends "[Code Output]" and failure appends "[Code Error]", each
    followed by run_python's text. A response without code comes back unchanged.
    """
    code = parse_python_block(response)
    if not code:
        return response
    ok, output = run_python(code)
    header = "[Code Output]" if ok else "[Code Error]"
    return f"{response}\n\n{header}\n{output}"

def generate_candidates(problem: str, k: int = None, temperature_schedule: List[float] = None,
                       config: Dict = None) -> List[Dict[str, Any]]:
    """
    Sample *k* candidate solutions (K_BASE if falsy) and parse each one's answer.

    Without an explicit schedule, temperatures cycle through base, base+0.1
    and base+0.2, where base = config["temperature"] or TEMPERATURE_BASE.
    Indices past the end of a given schedule use TEMPERATURE_BASE. Generation
    settings come from *config*, falling back to the module defaults.

    Returns a list of dicts with answer, raw_text, processed_text (raw text
    plus any code-execution output), merged metadata, and candidate_idx.
    """
    k = k or K_BASE
    config = config or {}

    if temperature_schedule is None:
        base_temp = config.get("temperature", TEMPERATURE_BASE)
        temperature_schedule = [base_temp + 0.1 * (idx % 3) for idx in range(k)]

    gen_options = dict(
        max_new_tokens=config.get("max_new_tokens", MAX_NEW_TOKENS),
        prompt_style=config.get("prompt_style", PROMPT_STYLE),
        top_p=config.get("top_p", TOP_P),
        top_k=config.get("top_k", TOP_K),
    )

    def _temperature_for(idx):
        return temperature_schedule[idx] if idx < len(temperature_schedule) else TEMPERATURE_BASE

    candidates = []
    for idx in range(k):
        raw_text, gen_meta = generate_one(problem, temperature=_temperature_for(idx), **gen_options)
        processed_text = execute_code_in_response(raw_text)
        answer, answer_meta = safe_extract_answer(processed_text)
        candidates.append({
            "answer": answer,
            "raw_text": raw_text,
            "processed_text": processed_text,
            "metadata": {**gen_meta, **answer_meta},
            "candidate_idx": idx,
        })

    return candidates

def vote_candidates(candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Apply self-consistency voting to candidate answers.

    Candidates whose metadata marks the parse as invalid (is_valid False:
    empty output, nothing parsed, or out of range) carry the placeholder
    answer 0. They are left out of the tally so that several failed
    generations cannot out-vote real answers. If no candidate is valid, all
    are counted, which yields the placeholder 0.

    Returns a dict with:
        top_answer, top_count: most common counted answer and its count
            (ties keep first-seen order).
        total: number of candidates counted.
        vote_margin: (top_count - runner-up count) / total.
        entropy: Shannon entropy (bits) of the counted answers.
        answer_counts: up to 10 most common answers.
        methods_used: parse-method counts over ALL candidates.
        n_invalid: number of candidates flagged invalid.
    """
    if not candidates:
        return {"top_answer": 0, "top_count": 0, "total": 0, "vote_margin": 0.0, "entropy": 0.0,
                "answer_counts": {}, "methods_used": {}, "n_invalid": 0}

    valid = [c for c in candidates if c.get("metadata", {}).get("is_valid", True)]
    voters = valid or candidates
    n_invalid = len(candidates) - len(valid)

    counter = Counter(c["answer"] for c in voters)
    total = len(voters)

    most_common = counter.most_common()
    top_answer, top_count = most_common[0]

    second_count = most_common[1][1] if len(most_common) > 1 else 0
    vote_margin = (top_count - second_count) / total

    # Calculate entropy for confidence estimation
    entropy = 0.0
    for count in counter.values():
        p = count / total
        if p > 0:
            entropy -= p * math.log2(p)

    # Track which extraction methods were used (over every candidate)
    methods_used = Counter(c.get("metadata", {}).get("method", "none") for c in candidates)

    return {
        "top_answer": top_answer,
        "top_count": top_count,
        "total": total,
        "vote_margin": vote_margin,
        "entropy": entropy,
        "answer_counts": dict(counter.most_common(10)),
        "methods_used": dict(methods_used),
        "n_invalid": n_invalid,
    }

def select_best_answer(candidates: List[Dict[str, Any]], vote_result: Dict[str, Any],
                       strategy: str = None) -> Tuple[int, Dict[str, Any]]:
    """
    Select the final answer using *strategy* (SELECTION_STRATEGY if omitted).

    - "majority_vote": vote_result["top_answer"]; confidence is the vote margin.
    - "verifier_weighted": each candidate votes with a weight set by how
      reliable its parse method is (FINAL 1.0 down to LASTINT 0.3; unknown
      methods count 0.3). Candidates flagged invalid (metadata is_valid False,
      carrying the placeholder answer 0) are skipped unless every candidate
      is invalid.
    - "consensus": the majority answer when vote_margin >= 0.5, otherwise
      verifier_weighted.
    - anything else: the majority answer.

    Returns (answer, info dict describing how it was chosen).
    """
    strategy = strategy or SELECTION_STRATEGY

    if not candidates:
        return 0, {"strategy": strategy, "reason": "no_candidates"}

    if strategy == "majority_vote":
        return vote_result["top_answer"], {"strategy": strategy, "confidence": vote_result["vote_margin"]}

    elif strategy == "verifier_weighted":
        # Weight by parse method reliability
        method_weights = {"FINAL": 1.0, "ANSWER": 0.9, "THE_ANSWER_IS": 0.8, "BOXED": 0.7, "LASTINT_CONTEXT": 0.5, "LASTINT": 0.3}

        # Failed parses must not add weight to their placeholder answer.
        scored = [c for c in candidates if c.get("metadata", {}).get("is_valid", True)] or candidates

        weighted_votes = Counter()
        for c in scored:
            method = c.get("metadata", {}).get("method", "LASTINT")
            weight = method_weights.get(method, 0.3)
            weighted_votes[c["answer"]] += weight

        best_answer = weighted_votes.most_common(1)[0][0]
        return best_answer, {"strategy": strategy, "weighted_votes": dict(weighted_votes.most_common(5))}

    elif strategy == "consensus":
        # Require stronger consensus
        if vote_result["vote_margin"] >= 0.5:
            return vote_result["top_answer"], {"strategy": strategy, "reason": "strong_consensus"}
        else:
            # Fall back to verifier weighted
            return select_best_answer(candidates, vote_result, "verifier_weighted")

    else:
        return vote_result["top_answer"], {"strategy": "fallback_majority"}

def should_early_stop(vote_result: Dict[str, Any], threshold: float = None) -> bool:
    """Return True when the top answer's share of counted votes reaches *threshold*.

    threshold falls back to VOTING_THRESHOLD only when it is None. An explicit
    0.0 is honored rather than silently replaced by the default. With no votes
    (total == 0), this never stops.
    """
    if threshold is None:
        threshold = VOTING_THRESHOLD
    total = vote_result["total"]
    if total == 0:
        return False
    return vote_result["top_count"] / total >= threshold

print("\n".join([
    "Candidate generation functions initialized",
    f"Selection strategy: {SELECTION_STRATEGY}",
]))

## CELL G — VERIFIER

In [None]:
# ================================
# CELL G — VERIFIER (Rule-based + Optional LLM)
# ================================

# A modulus as written in a problem: an optional "$", a base integer, and an
# optional power such as "^5" or "^{5}", so "10^5" reads as 100000, not 10.
_MODULUS_EXPR = r"\$?\s*(\d+)(?:\s*\^\s*\{?\s*(\d+)\s*\}?)?"
# Tried in order; the first pattern that matches decides the modulus.
_MODULUS_PATTERNS = (
    re.compile(r"divided by\s+" + _MODULUS_EXPR),
    re.compile(r"modulo\s+" + _MODULUS_EXPR),
    re.compile(r"mod\s+" + _MODULUS_EXPR),
    re.compile(r"\(mod\s*" + _MODULUS_EXPR + r"\s*\)"),
    re.compile(r"\\pmod\s*\{\s*" + _MODULUS_EXPR),
)
_MOD_TRIGGER_WORDS = ("remainder", "modulo", "mod ", "pmod")


def rule_verifier(problem: str, answer: int) -> Dict[str, Any]:
    """Run cheap rule-based sanity checks on a proposed answer.

    - range_check: answer must lie in [0, 99999]. Failing it sets passed=False.
    - mod_check: only when the problem mentions a remainder or modulus.
      Records whether answer < the detected modulus (plain, "$N$", "10^5",
      "10^{5}", or "\\pmod{N}"). It is informational and does not change
      "passed".

    Returns:
        {"passed": bool, "reason": str, "checks": [(name, ok, message), ...]}
    """
    checks = []
    passed = True
    reason = "OK"

    if not (0 <= answer <= 99999):
        passed = False
        reason = f"Answer {answer} out of valid range [0, 99999]"
        checks.append(("range_check", False, reason))
    else:
        checks.append(("range_check", True, "In valid range"))

    problem_lower = problem.lower()
    if any(word in problem_lower for word in _MOD_TRIGGER_WORDS):
        for pattern in _MODULUS_PATTERNS:
            match = pattern.search(problem_lower)
            if not match:
                continue
            base = int(match.group(1))
            exponent = match.group(2)
            if exponent is not None and int(exponent) > 64:
                # For base >= 2 the modulus is astronomically large, so no
                # in-range answer can reach it; skip the huge power.
                checks.append(("mod_check", True, f"Answer {answer} < modulo {base}^{exponent}"))
                break
            mod_val = base if exponent is None else base ** int(exponent)
            if answer >= mod_val and mod_val < 100000:
                checks.append(("mod_check", False, f"Answer {answer} >= modulo {mod_val}"))
            else:
                checks.append(("mod_check", True, f"Answer {answer} < modulo {mod_val}"))
            break

    return {"passed": passed, "reason": reason, "checks": checks}

def llm_verifier(problem: str, answer: int) -> Dict[str, Any]:
    """Ask the loaded LLM whether *answer* looks plausible for *problem*.

    Greedy-decodes up to 100 tokens. The check fails only if the reply
    contains "INVALID" (case-insensitive). Without a model, or on any error,
    the check passes (fail-open) and "reason" says why.

    Returns:
        {"passed": bool, "reason": str, "response": first 200 chars of the reply}
    """
    model, tokenizer, device = load_model()

    if model is None:
        return {"passed": True, "reason": "LLM not available", "response": ""}

    prompt = f"""Given this math problem and proposed answer, quickly check if the answer could be correct.
If you find a clear error or contradiction, say INVALID. Otherwise say VALID.

Problem: {problem}

Proposed Answer: {answer}

Verification (VALID or INVALID):"""

    try:
        import torch
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            # Greedy decoding; temperature is omitted because it has no effect
            # when do_sample=False.
            outputs = model.generate(
                **inputs, max_new_tokens=100, do_sample=False, pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the new tokens. Slicing the decoded string by len(prompt)
        # misaligns whenever decoding does not reproduce the prompt text exactly.
        prompt_len = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        passed = "INVALID" not in response.upper()

        return {"passed": passed, "reason": "VALID" if passed else "INVALID found", "response": response[:200]}

    except Exception as e:
        return {"passed": True, "reason": f"Error: {str(e)}", "response": ""}

def verify_answer(problem: str, answer: int, use_llm: bool = False) -> Dict[str, Any]:
    """Run the rule-based verifier and, optionally, the LLM verifier on one answer.

    The LLM verifier is consulted only when ``use_llm`` is set AND the rule checks
    passed. ``final_passed`` is True unless a verifier that actually ran reported failure.

    Returns:
        {"rule_verifier": <rule result>, "llm_verifier": <llm result or None>,
         "final_passed": bool}
    """
    rule_outcome = rule_verifier(problem, answer)
    rule_ok = bool(rule_outcome["passed"])

    llm_outcome = None
    final_ok = rule_ok
    if rule_ok and use_llm:
        llm_outcome = llm_verifier(problem, answer)
        final_ok = bool(llm_outcome["passed"])

    return {"rule_verifier": rule_outcome, "llm_verifier": llm_outcome, "final_passed": final_ok}

print("Verifier functions initialized")

## CELL H — SOLVER ORCHESTRATOR

In [None]:
# ================================
# CELL H — SOLVER ORCHESTRATOR (IMPROVED)
# ================================

def solve_problem(problem_id: str, problem_text: str, config: Dict = None) -> Tuple[int, Dict[str, Any]]:
    """
    Main solver: sample candidate solutions, vote on their answers and return the winner.

    Steps: set the per-problem seed via set_seed(), generate up to ``k`` candidates
    (stopping early on strong consensus or when the per-problem time budget runs low),
    vote, select an answer, run the rule-based verifier when the vote margin is small
    (switching to the most-voted alternative that passes), then validate the result.

    Args:
        problem_id: Identifier used for seeding and telemetry.
        problem_text: Problem statement given to the model.
        config: Optional overrides. Keys read here: ``k``, ``temperature``,
            ``max_new_tokens``, ``selection_strategy``, ``prompt_style``; missing keys
            fall back to the module-level defaults. Other keys (e.g. ``top_p``) are only
            recorded in telemetry.

    Returns:
        ``(final_answer, telemetry)``. Exceptions raised while solving are caught: the
        answer falls back to 0 and ``error_message`` / ``used_fallback`` are recorded.
    """
    global CURRENT_SEED
    start_time = time.time()
    config = config or {}

    # Set seed based on mode policy
    CURRENT_SEED = get_seed_for_problem(problem_id)
    set_seed(CURRENT_SEED)

    # Get parameters from config or defaults
    k = config.get("k", K_BASE)
    temperature = config.get("temperature", TEMPERATURE_BASE)
    max_new_tokens = config.get("max_new_tokens", MAX_NEW_TOKENS)
    selection_strategy = config.get("selection_strategy", SELECTION_STRATEGY)
    # Loop-invariant: resolved once instead of on every generation call.
    prompt_style = config.get("prompt_style", PROMPT_STYLE)

    # Initialize telemetry
    telemetry = {
        "id": problem_id,
        "problem_hash": hashlib.md5(problem_text.encode()).hexdigest()[:8],
        "elapsed_sec": 0,
        "k_used": 0,
        "candidates_summary": [],
        "chosen_answer": 0,
        "vote_margin": 0.0,
        "vote_entropy": 0.0,
        "verifier_used": False,
        "verifier_pass": True,
        "tool_calls_count": 0,
        "parse_method_used": "none",
        "methods_distribution": {},
        "difficulty_mode": "NORMAL",
        "temperature_used": temperature,
        "seed_used": CURRENT_SEED,
        "run_policy": MODE_POLICY,
        "is_rerun": IS_KAGGLE_RERUN,
        "model_available": is_model_available(),
        "config_used": config,
        "selection_strategy": selection_strategy,
        "prompt_style": prompt_style,
    }

    try:
        # Temperature schedule cycles base, +0.1, +0.2 for diversity across candidates.
        temp_schedule = [temperature + 0.1 * (i % 3) for i in range(k)]
        telemetry["temperature_schedule"] = temp_schedule[:5]  # Log first 5

        # Generate candidates with time budget awareness (70% of the budget for generation).
        candidates = []
        time_per_candidate = (TIME_BUDGET_SEC_PER_PROBLEM * 0.7) / k

        for i in range(k):
            elapsed = time.time() - start_time
            remaining = TIME_BUDGET_SEC_PER_PROBLEM - elapsed

            if remaining < time_per_candidate * 0.5:
                break  # Stop if not enough time for another candidate

            raw_text, meta = generate_one(
                problem_text,
                temperature=temp_schedule[i],
                max_new_tokens=max_new_tokens,
                prompt_style=prompt_style,
            )
            processed_text = execute_code_in_response(raw_text)
            answer, answer_meta = safe_extract_answer(processed_text)

            candidates.append({
                "answer": answer,
                "raw_text_preview": raw_text[:500],
                "gen_metadata": {**meta, **answer_meta}
            })

            if parse_python_block(raw_text):
                telemetry["tool_calls_count"] += 1

            # Early stop once at least 3 candidates show strong consensus.
            if len(candidates) >= 3:
                partial_vote = vote_candidates(candidates)
                if should_early_stop(partial_vote, 0.8):
                    break

        telemetry["k_used"] = len(candidates)

        # Vote on candidates
        vote_result = vote_candidates(candidates)
        telemetry["vote_margin"] = vote_result["vote_margin"]
        telemetry["vote_entropy"] = vote_result["entropy"]
        telemetry["candidates_summary"] = [(a, c) for a, c in vote_result["answer_counts"].items()]
        telemetry["methods_distribution"] = vote_result.get("methods_used", {})

        # Select best answer
        chosen_answer, selection_info = select_best_answer(candidates, vote_result, selection_strategy)
        telemetry["selection_info"] = selection_info

        # Record the parse method of the first candidate that produced the chosen answer.
        for c in candidates:
            if c["answer"] == chosen_answer:
                telemetry["parse_method_used"] = c["gen_metadata"].get("method", "none")
                break

        # Optional verification for low confidence
        if vote_result["vote_margin"] < 0.3:
            verification = verify_answer(problem_text, chosen_answer, use_llm=False)
            telemetry["verifier_used"] = True
            telemetry["verifier_pass"] = verification["final_passed"]

            if not verification["final_passed"]:
                # Try the remaining answers from most- to least-voted; sorted() is stable,
                # so ties keep their original order. Plain dict iteration would follow
                # insertion order, which is not necessarily the "next best" answer.
                ranked = sorted(vote_result["answer_counts"].items(), key=lambda item: item[1], reverse=True)
                for answer, _count in ranked:
                    # Skip the rejected answer and unparsed (None) candidates.
                    if answer is None or answer == chosen_answer:
                        continue
                    alt_verify = verify_answer(problem_text, answer, use_llm=False)
                    if alt_verify["final_passed"]:
                        chosen_answer = answer
                        telemetry["switched_answer"] = True
                        break

        # Final validation
        is_valid, final_answer = validate_answer(chosen_answer)
        if not is_valid:
            telemetry["used_fallback"] = True
            final_answer = 0

        telemetry["chosen_answer"] = final_answer

    except Exception as e:
        # Never let one problem crash the run: fall back to 0 and record why.
        telemetry["error_message"] = str(e)
        final_answer = 0
        telemetry["chosen_answer"] = final_answer
        telemetry["used_fallback"] = True

    telemetry["elapsed_sec"] = time.time() - start_time

    return final_answer, telemetry

print("Solver orchestrator initialized (improved)")

## CELL I — TELEMETRY LOGGER

In [None]:
# ================================
# CELL I — TELEMETRY LOGGER
# ================================

def append_jsonl(filepath: str, data: Dict[str, Any]):
    """Append ``data`` as a single JSON line to ``filepath``.

    Parent directories are created when ``filepath`` has one. Values that json cannot
    encode natively are stringified via ``default=str``.
    """
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Serialize before opening so an unencodable record never leaves a partial line.
    line = json.dumps(data, default=str) + "\n"
    with open(filepath, "a", encoding="utf-8") as f:
        f.write(line)

def read_telemetry(filepath: str) -> List[Dict[str, Any]]:
    """Load every JSON record from a JSONL file.

    Returns an empty list when the file does not exist. Blank lines and lines that
    are not valid JSON are skipped rather than failing the whole read.
    """
    if not os.path.exists(filepath):
        return []

    parsed: List[Dict[str, Any]] = []
    with open(filepath, "r") as handle:
        for raw in handle:
            stripped = raw.strip()
            if not stripped:
                continue
            try:
                parsed.append(json.loads(stripped))
            except json.JSONDecodeError:
                continue
    return parsed

def log_telemetry(telemetry: Dict[str, Any]):
    """Append one telemetry record to the JSONL log at LOG_PATH."""
    append_jsonl(filepath=LOG_PATH, data=telemetry)

def print_telemetry_summary(telemetry_list: List[Dict[str, Any]]):
    """Print aggregate statistics over per-problem telemetry records.

    Args:
        telemetry_list: Records as produced by solve_problem(). Missing keys fall
            back to neutral defaults.

    Prints "No telemetry data available" and returns when the list is empty.
    """
    if not telemetry_list:
        print("No telemetry data available")
        return

    n = len(telemetry_list)  # guaranteed > 0 from here on
    total_time = sum(t.get("elapsed_sec", 0) for t in telemetry_list)
    avg_time = total_time / n

    parse_methods = Counter(t.get("parse_method_used", "none") for t in telemetry_list)
    parse_fail_rate = parse_methods.get("none", 0) / n

    k_values = [t.get("k_used", 0) for t in telemetry_list]
    avg_k = sum(k_values) / n

    # The pass rate only counts records where the verifier actually ran; unverified
    # records default to verifier_pass=True and would otherwise inflate the numerator.
    verified = [t for t in telemetry_list if t.get("verifier_used", False)]
    verifier_used = len(verified)
    verifier_pass = sum(1 for t in verified if t.get("verifier_pass", True))

    # solve_problem() stores the fallback flag as "used_fallback"; "fallback_used"
    # is accepted as well for compatibility.
    fallback_used = sum(
        1 for t in telemetry_list
        if t.get("used_fallback", False) or t.get("fallback_used", False)
    )

    print("\n" + "="*50)
    print("TELEMETRY SUMMARY")
    print("="*50)
    print(f"Total problems: {n}")
    print(f"Total time: {total_time:.2f}s")
    print(f"Avg time per problem: {avg_time:.2f}s")
    print(f"Parse fail rate: {parse_fail_rate:.2%}")
    print(f"Avg k_used: {avg_k:.1f}")
    print(f"K distribution: {Counter(k_values)}")
    print(f"Parse methods: {dict(parse_methods)}")
    print(f"Verifier usage: {verifier_used}/{n}")
    print(f"Verifier pass rate: {verifier_pass}/{verifier_used}")
    print(f"Fallback used: {fallback_used}/{n}")
    print("="*50)

# Announce the telemetry cell and where records are appended.
print("Telemetry logger initialized", f"Log path: {LOG_PATH}", sep="\n")

## CELL J — LOCAL HARNESS (Reference CSV Regression)

In [None]:
# ================================
# CELL J — LOCAL HARNESS (Reference CSV Regression)
# ================================

def run_reference_eval(csv_path: str = None, limit: int = None, config: Dict = None) -> Dict[str, Any]:
    """Evaluate solve_problem() against a labelled CSV, printing per-problem results.

    Args:
        csv_path: CSV with ``id``, ``problem`` and integer ``answer`` columns;
            defaults to REFERENCE_CSV_PATH.
        limit: If truthy, only the first ``limit`` rows are evaluated.
        config: Optional solve_problem() overrides (e.g. a tuned config). None keeps
            the module-level defaults, exactly as before this parameter existed.

    Returns:
        {"accuracy", "correct", "total", "results"} on success, or
        {"error": "File not found", "accuracy": 0.0} when the CSV does not exist.
        Each problem's telemetry is also appended via log_telemetry().
    """
    csv_path = csv_path or REFERENCE_CSV_PATH

    if not os.path.exists(csv_path):
        print(f"Reference CSV not found: {csv_path}")
        return {"error": "File not found", "accuracy": 0.0}

    print(f"\nRunning reference evaluation on: {csv_path}")
    print("="*60)

    df = pd.read_csv(csv_path)

    if limit:
        df = df.head(limit)

    n_problems = len(df)
    print(f"Evaluating {n_problems} problems...\n")

    results = []
    correct = 0
    telemetry_list = []

    # enumerate() gives a 1-based position that does not depend on the index labels.
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        problem_id = str(row["id"])
        problem_text = row["problem"]
        expected_answer = int(row["answer"])

        print(f"[{position}/{n_problems}] Problem {problem_id}...")

        predicted_answer, telemetry = solve_problem(problem_id, problem_text, config=config)
        is_correct = predicted_answer == expected_answer

        telemetry["expected_answer"] = expected_answer
        telemetry["is_correct"] = is_correct
        log_telemetry(telemetry)
        telemetry_list.append(telemetry)

        if is_correct:
            correct += 1
        status = "Y" if is_correct else "X"

        print(f"  {status} Predicted: {predicted_answer}, Expected: {expected_answer} ({telemetry['elapsed_sec']:.2f}s)")

        results.append({
            "id": problem_id, "predicted": predicted_answer,
            "expected": expected_answer, "correct": is_correct,
            "elapsed_sec": telemetry["elapsed_sec"],
        })

    accuracy = correct / n_problems if n_problems > 0 else 0.0

    print("\n" + "="*60)
    print(f"ACCURACY: {correct}/{n_problems} = {accuracy:.2%}")
    print("="*60)

    print_telemetry_summary(telemetry_list)

    return {"accuracy": accuracy, "correct": correct, "total": n_problems, "results": results}

print("Local harness initialized")

## CELL K — SUBMISSION GLUE (Kaggle Evaluation API)

In [None]:
# ================================
# CELL K — SUBMISSION GLUE (Kaggle Evaluation API)
# ================================

import sys
import os

kaggle_eval_paths = ["/kaggle/input/kaggle-evaluation", "/kaggle/input", ".", ".."]


def _register_kaggle_eval_path(candidates):
    """Move the first directory in ``candidates`` containing ``kaggle_evaluation`` to the front of sys.path.

    Existing copies of that entry are removed first, so re-running this cell does not
    keep growing sys.path with duplicates.

    Returns:
        The directory that was registered, or None when no candidate matches.
    """
    for candidate in candidates:
        if os.path.exists(os.path.join(candidate, "kaggle_evaluation")):
            sys.path[:] = [entry for entry in sys.path if entry != candidate]
            sys.path.insert(0, candidate)
            return candidate
    return None


# Doing the search in a helper also avoids leaking a loop variable (`path`) into the
# notebook's global namespace.
_kaggle_eval_root = _register_kaggle_eval_path(kaggle_eval_paths)

def predict(test_input: Union[pd.DataFrame, dict, pd.Series]) -> pd.DataFrame:
    """
    Kaggle prediction endpoint.

    Accepts a pandas DataFrame, a dict, or a pandas Series with ``id`` and ``problem``
    fields. Other objects exposing ``to_pandas()`` (e.g. a polars DataFrame) are
    converted to pandas first.

    Returns:
        DataFrame with columns 'id' and 'answer', one row per input row. The columns
        are present even when the input has no rows.

    Raises:
        ValueError: If the input type is unsupported or 'id'/'problem' is missing.
    """
    # Foreign dataframe libraries commonly provide to_pandas(); normalize to pandas first.
    if not isinstance(test_input, (pd.DataFrame, pd.Series, dict)) and hasattr(test_input, "to_pandas"):
        test_input = test_input.to_pandas()

    # Convert input to DataFrame if needed
    if isinstance(test_input, dict):
        test_df = pd.DataFrame([test_input])
    elif isinstance(test_input, pd.Series):
        test_df = pd.DataFrame([test_input.to_dict()])
    elif isinstance(test_input, pd.DataFrame):
        test_df = test_input
    else:
        raise ValueError(f"predict() expects DataFrame, dict, or Series, got {type(test_input)}")

    # Validate required columns
    required_cols = {"id", "problem"}
    missing_cols = required_cols - set(test_df.columns)
    if missing_cols:
        raise ValueError(f"Input missing required columns: {missing_cols}")

    results = []

    for _, row in test_df.iterrows():
        problem_id = str(row["id"])
        problem_text = str(row["problem"])

        answer, telemetry = solve_problem(problem_id, problem_text)
        try:
            log_telemetry(telemetry)
        except OSError as exc:
            # Telemetry is diagnostic only: a failed log write must not drop the answer.
            print(f"Warning: could not log telemetry for {problem_id}: {exc}")

        results.append({"id": problem_id, "answer": int(answer)})

    # Explicit columns keep the output schema intact even for empty input.
    return pd.DataFrame(results, columns=["id", "answer"])

def setup_and_serve():
    """
    Wrap ``predict`` in the AIMO3 inference server and start it.

    In a competition rerun (KAGGLE_IS_COMPETITION_RERUN == "1") the server's
    ``serve()`` is called directly. Otherwise ``run_local_gateway()`` is tried first,
    falling back to ``serve()`` when that method is missing or raises. Import
    failures and any other exception are printed and swallowed ("local-only mode").
    """
    try:
        from kaggle_evaluation.aimo_3_inference_server import AIMO3InferenceServer

        server = AIMO3InferenceServer(predict)
        competition_rerun = os.getenv("KAGGLE_IS_COMPETITION_RERUN") == "1"

        if competition_rerun:
            print("Starting inference server (competition mode)...")
            server.serve()
            return

        print("Attempting local gateway test...")
        if not hasattr(server, 'run_local_gateway'):
            print("run_local_gateway not available, using serve()...")
            server.serve()
            return

        try:
            server.run_local_gateway()
        except Exception as gateway_err:
            print(f"run_local_gateway failed: {gateway_err}")
            print("Falling back to direct serve()...")
            server.serve()

    except ImportError as import_err:
        print(f"kaggle_evaluation not available: {import_err}")
        print("Running in local-only mode")
    except Exception as unexpected:
        print(f"Error in setup_and_serve: {unexpected}")
        print("Running in local-only mode")

print("Submission glue initialized")
print("predict() accepts: DataFrame, dict, or Series")
print("Use setup_and_serve() to start the server")

## SELF TEST — Unit Tests

In [None]:
# ================================
# SELF TEST 1: Schema/Predict Unit Test
# ================================

def test_predict_schema():
    """Check that predict() returns one row with 'id' and an integer 'answer' in [0, 99999]."""
    print("\nTEST 1: Schema/Predict Unit Test")
    print("-"*40)

    test_df = pd.DataFrame({"id": ["test001"], "problem": ["What is $1+1$?"]})
    result_df = predict(test_df)

    assert "id" in result_df.columns, "Missing 'id' column"
    assert "answer" in result_df.columns, "Missing 'answer' column"
    assert len(result_df) == 1, f"Expected 1 row, got {len(result_df)}"

    answer = result_df["answer"].iloc[0]
    # Values read back from an integer pandas column are numpy integers (e.g. numpy.int64),
    # which are not instances of the builtin int; is_integer() accepts both.
    assert pd.api.types.is_integer(answer), f"Answer should be int, got {type(answer)}"
    assert 0 <= answer <= 99999, f"Answer {answer} out of range [0, 99999]"

    print("OK Output schema correct")
    print(f"OK Answer: {answer} (valid int in [0, 99999])")
    print("TEST 1 PASSED\n")
    return True

if RUN_MODE == "local_ref":
    try:
        test_predict_schema()
    except AssertionError as e:
        print(f"TEST 1 FAILED: {e}")
    except Exception as e:
        print(f"TEST 1 ERROR: {e}")

In [None]:
# ================================
# SELF TEST 2: Reference Evaluation
# ================================

def test_reference_eval():
    """Run the reference harness on the first two problems and sanity-check its summary dict."""
    print("\nTEST 2: Reference Evaluation")
    print("-"*40)

    summary = run_reference_eval(limit=2)

    for key in ("accuracy", "total"):
        assert key in summary, f"Missing '{key}' in result"
    assert summary["total"] == 2, f"Expected 2 problems, got {summary['total']}"

    print(f"OK Accuracy: {summary['accuracy']:.2%}")
    print("OK No crashes during evaluation")
    print("TEST 2 PASSED\n")
    return True

if RUN_MODE == "local_ref":
    try:
        if not os.path.exists(REFERENCE_CSV_PATH):
            print("Skipping TEST 2: reference.csv not found")
        else:
            test_reference_eval()
    except AssertionError as failure:
        print(f"TEST 2 FAILED: {failure}")
    except Exception as error:
        print(f"TEST 2 ERROR: {error}")

In [None]:
# ================================
# SELF TEST 3: Tool Executor
# ================================

def test_tool_executor():
    """Smoke-test run_python(): arithmetic, the math module, import handling and timeout_sec."""
    print("\nTEST 3: Tool Executor")
    print("-"*40)

    # Check 1: simple arithmetic; the value bound to `result` is expected in the output.
    ok, output = run_python("result = 2 + 3")
    assert ok, f"Execution failed: {output}"
    assert "5" in output, f"Expected '5' in output, got: {output}"
    print("OK Simple arithmetic: 2+3 = 5")

    # Check 2: `math` is used without an import (the executor is expected to pre-load it).
    ok, output = run_python("result = math.factorial(5)")
    assert ok, f"Execution failed: {output}"
    assert "120" in output, f"Expected '120' in output, got: {output}"
    print("OK Math module: factorial(5) = 120")

    # Check 3: code containing an import statement still runs (expected to be stripped).
    ok, output = run_python("import math\nresult = math.sqrt(16)")
    assert ok, f"Execution failed: {output}"
    assert "4" in output, f"Expected '4' in output, got: {output}"
    print("OK Import stripping works: sqrt(16) = 4")

    # Check 4: a trivial snippet must finish within a 1 s timeout. Nothing long-running is
    # executed, so the timeout path itself is NOT exercised; the message says only that.
    ok, output = run_python("x = 1", timeout_sec=1)
    assert ok, f"Simple code should not timeout: {output}"
    print("OK Short code completes within timeout_sec=1")

    print("TEST 3 PASSED\n")
    return True

if RUN_MODE == "local_ref":
    try:
        test_tool_executor()
    except AssertionError as e:
        print(f"TEST 3 FAILED: {e}")
    except Exception as e:
        print(f"TEST 3 ERROR: {e}")

In [None]:
# ================================
# SELF TEST 4: Answer Extraction
# ================================

def test_answer_extraction():
    """Table-driven check of extract_answer(): expected value and parse method per sample."""
    print("\nTEST 4: Answer Extraction")
    print("-"*40)

    cases = (
        ("ANSWER: 42", 42, "ANSWER"),
        ("The answer is 123", 123, "THE_ANSWER_IS"),
        ("After calculation, we get 456. ANSWER: 456", 456, "ANSWER"),
        ("Result: 789", 789, "LASTINT"),
        ("The final answer is 999", 999, "THE_ANSWER_IS"),
    )

    for sample, want_value, want_method in cases:
        got_value, got_method = extract_answer(sample)
        assert got_value == want_value, f"Expected {want_value}, got {got_value} for '{sample}'"
        assert got_method == want_method, f"Expected method {want_method}, got {got_method} for '{sample}'"
        print(f"OK '{sample[:30]}...' -> {got_value} ({got_method})")

    print("TEST 4 PASSED\n")
    return True

if RUN_MODE == "local_ref":
    try:
        test_answer_extraction()
    except AssertionError as failure:
        print(f"TEST 4 FAILED: {failure}")
    except Exception as error:
        print(f"TEST 4 ERROR: {error}")

In [None]:
# ================================
# SELF TEST 5: predict() Input Types
# ================================

def test_predict_input_types():
    """Check that predict() accepts DataFrame, dict and Series inputs and returns a DataFrame."""
    print("\nTEST 5: predict() Input Types")
    print("-"*40)

    cases = [
        ("DataFrame", pd.DataFrame({"id": ["test_df"], "problem": ["What is $2+2$?"]})),
        ("dict", {"id": "test_dict", "problem": "What is $3+3$?"}),
        ("Series", pd.Series({"id": "test_series", "problem": "What is $4+4$?"})),
    ]

    for label, payload in cases:
        out = predict(payload)
        assert isinstance(out, pd.DataFrame), "Result should be DataFrame"
        if label == "DataFrame":
            assert "answer" in out.columns, "Missing 'answer' column"
        else:
            assert len(out) == 1, "Should have 1 row"
        print(f"OK {label} input works")

    print("TEST 5 PASSED\n")
    return True

if RUN_MODE == "local_ref":
    try:
        test_predict_input_types()
    except AssertionError as failure:
        print(f"TEST 5 FAILED: {failure}")
    except Exception as error:
        print(f"TEST 5 ERROR: {error}")

## TUNING MODULE — Optuna Hyperparameter Optimization

In [None]:
# ================================
# TUNING MODULE — Optuna Hyperparameter Optimization
# ================================

def run_optuna_tuning(n_trials: int = 10, timeout: int = 3600) -> Dict[str, Any]:
    """
    Run Optuna hyperparameter tuning over every problem in reference.csv.

    Search space (kept in sync with ``objective`` below):
    - k: number of candidates (4-12)
    - temperature: base sampling temperature (0.3-0.9)
    - max_new_tokens: generation length (1024-3072, step 512)
    - prompt_style: "strict_final" or "tir"
    - selection_strategy: "majority_vote", "verifier_weighted" or "consensus"
    - top_p: nucleus sampling (0.85-0.98)

    Score = accuracy - 0.1 * (relative overrun of TIME_BUDGET_SEC_PER_PROBLEM).

    Returns:
        {"best_config", "best_score", "n_trials", "timestamp"} (also written to
        BEST_CONFIG_PATH), or a dict with an "error" key when the CSV is missing or
        empty, or when no trial completed.
    """
    try:
        import optuna
        from optuna.pruners import MedianPruner
        from optuna.samplers import TPESampler
    except ImportError:
        print("Installing optuna...")
        import subprocess
        import sys
        # Install into the running interpreter; a bare "pip" on PATH may belong to another Python.
        subprocess.run([sys.executable, "-m", "pip", "install", "optuna", "-q"], check=True)
        import optuna
        from optuna.pruners import MedianPruner
        from optuna.samplers import TPESampler

    # Load reference data
    if not os.path.exists(REFERENCE_CSV_PATH):
        print(f"Reference CSV not found: {REFERENCE_CSV_PATH}")
        return {"error": "File not found"}

    df = pd.read_csv(REFERENCE_CSV_PATH)
    n_problems = len(df)
    if n_problems == 0:
        print(f"Reference CSV is empty: {REFERENCE_CSV_PATH}")
        return {"error": "No reference problems"}
    print(f"Tuning on {n_problems} reference problems")

    def objective(trial):
        """Optuna objective: time-penalized accuracy of one sampled config on the reference set."""
        config = {
            "k": trial.suggest_int("k", 4, 12),
            "temperature": trial.suggest_float("temperature", 0.3, 0.9),
            "max_new_tokens": trial.suggest_int("max_new_tokens", 1024, 3072, step=512),
            "prompt_style": trial.suggest_categorical("prompt_style", ["strict_final", "tir"]),
            "selection_strategy": trial.suggest_categorical("selection_strategy",
                                                            ["majority_vote", "verifier_weighted", "consensus"]),
            # NOTE(review): solve_problem() does not forward config["top_p"] to generation,
            # so this dimension currently cannot influence the score.
            "top_p": trial.suggest_float("top_p", 0.85, 0.98),
        }

        correct = 0
        total_time = 0.0

        # enumerate() gives contiguous pruning steps regardless of the DataFrame index.
        for step, (_, row) in enumerate(df.iterrows()):
            problem_id = str(row["id"])
            problem_text = row["problem"]
            expected = int(row["answer"])

            predicted, telemetry = solve_problem(problem_id, problem_text, config=config)

            if predicted == expected:
                correct += 1

            total_time += telemetry.get("elapsed_sec", 0)

            # Report running accuracy so MedianPruner can stop clearly bad trials early.
            trial.report(correct / (step + 1), step)
            if trial.should_prune():
                raise optuna.TrialPruned()

        accuracy = correct / n_problems
        avg_time = total_time / n_problems

        # Penalize configs whose average time exceeds the per-problem budget.
        time_penalty = max(0, (avg_time - TIME_BUDGET_SEC_PER_PROBLEM) / TIME_BUDGET_SEC_PER_PROBLEM)
        return accuracy - 0.1 * time_penalty

    # Create study
    sampler = TPESampler(seed=BASE_SEED)
    pruner = MedianPruner(n_startup_trials=3, n_warmup_steps=2)

    study = optuna.create_study(
        direction="maximize",
        sampler=sampler,
        pruner=pruner,
        study_name="aimo3_tuning"
    )

    print(f"Starting Optuna tuning: {n_trials} trials, {timeout}s timeout")

    try:
        study.optimize(objective, n_trials=n_trials, timeout=timeout, show_progress_bar=True)
    except KeyboardInterrupt:
        print("Tuning interrupted")

    # study.best_params raises when no trial finished (interrupted early or all pruned).
    completed = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    if not completed:
        print("No trial completed; nothing to save")
        return {"error": "No completed trials", "n_trials": len(study.trials)}

    best_config = study.best_params
    best_score = study.best_value

    print(f"\nBest config (score={best_score:.4f}):")
    for name, value in best_config.items():
        print(f"  {name}: {value}")

    # Save best config
    result = {
        "best_config": best_config,
        "best_score": best_score,
        "n_trials": len(study.trials),
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    }

    with open(BEST_CONFIG_PATH, "w") as f:
        json.dump(result, f, indent=2)

    print(f"\nBest config saved to: {BEST_CONFIG_PATH}")

    return result

def load_best_config(path: str = None) -> Dict[str, Any]:
    """Load the best config saved by a previous tuning run.

    Args:
        path: JSON file to read; defaults to BEST_CONFIG_PATH.

    Returns:
        The saved "best_config" dict, or {} when the file is missing, unreadable,
        not valid JSON, or not shaped as expected. A bad file therefore cannot
        abort notebook startup.
    """
    if path is None:
        path = BEST_CONFIG_PATH

    if not os.path.exists(path):
        print("No saved config found, using defaults")
        return {}

    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError) as exc:
        print(f"Could not read saved config {path}: {exc}; using defaults")
        return {}

    if not isinstance(data, dict):
        print(f"Saved config {path} is not a JSON object; using defaults")
        return {}

    print(f"Loaded best config (score={data.get('best_score', 'N/A')})")
    best_config = data.get("best_config", {})
    return best_config if isinstance(best_config, dict) else {}

# Announce the tuning cell and the configured trial counts.
print(
    "Tuning module initialized",
    f"Quick tune: {TUNE_TRIALS_QUICK} trials",
    f"Full tune: {TUNE_TRIALS_FULL} trials",
    sep="\n",
)

## MAIN EXECUTION

In [None]:
# ================================
# MAIN EXECUTION
# ================================

# Tuned-config keys (as produced by run_optuna_tuning) mapped to the module-level defaults.
# solve_problem() reads these defaults at call time whenever its config lacks the key,
# which is how predict() calls it.
_TUNED_KEY_TO_GLOBAL = {
    "k": "K_BASE",
    "temperature": "TEMPERATURE_BASE",
    "max_new_tokens": "MAX_NEW_TOKENS",
    "prompt_style": "PROMPT_STYLE",
    "selection_strategy": "SELECTION_STRATEGY",
    # NOTE(review): solve_problem() itself does not read TOP_P; this only has an effect
    # if the generation helpers read the global.
    "top_p": "TOP_P",
}


def _apply_tuned_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Rebind module-level defaults from a tuned config; unknown keys are ignored.

    Returns:
        {global_name: value} for every default that was overridden.
    """
    applied = {}
    for key, value in (cfg or {}).items():
        name = _TUNED_KEY_TO_GLOBAL.get(key)
        if name is not None:
            globals()[name] = value
            applied[name] = value
    return applied


if __name__ == "__main__":
    print("\n" + "="*60)
    print("AIMO3 BASELINE NOTEBOOK - IMPROVED")
    print("="*60)
    print(f"RUN_MODE: {RUN_MODE}")
    print(f"CURRENT_SEED: {CURRENT_SEED}")
    print(f"IS_KAGGLE_RERUN: {IS_KAGGLE_RERUN}")
    print(f"PROMPT_STYLE: {PROMPT_STYLE}")
    print(f"K_BASE: {K_BASE}, TEMPERATURE_BASE: {TEMPERATURE_BASE}")
    print("="*60 + "\n")

    # Load the best config from a previous tuning run and make it the solver default;
    # previously it was only announced, never applied.
    best_config = load_best_config() if os.path.exists(BEST_CONFIG_PATH) else {}
    if best_config:
        _apply_tuned_config(best_config)
        print(f"Using tuned config: {best_config}")

    if RUN_MODE == "local_ref":
        print("Running in LOCAL/DEBUG mode...")
        print("Evaluating reference.csv...\n")

        # Start from a fresh log so the file reflects only this evaluation.
        if os.path.exists(LOG_PATH):
            os.remove(LOG_PATH)

        if os.path.exists(REFERENCE_CSV_PATH):
            result = run_reference_eval()
            print(f"\nFinal Accuracy: {result['accuracy']:.2%}")
        else:
            print(f"Reference CSV not found: {REFERENCE_CSV_PATH}")
            print("Running self-tests only...")

    elif RUN_MODE == "quick_tune":
        print("Running QUICK TUNE mode...")
        result = run_optuna_tuning(n_trials=TUNE_TRIALS_QUICK, timeout=TUNE_TIMEOUT_SEC // 2)

        # Validate with best config: make it the solver default and re-run the harness.
        if "best_config" in result:
            print("\nValidating with best config...")
            _apply_tuned_config(result["best_config"])
            validation = run_reference_eval()
            if "error" not in validation:
                print(f"\nValidation Accuracy: {validation['accuracy']:.2%}")

    elif RUN_MODE == "full_tune":
        print("Running FULL TUNE mode...")
        result = run_optuna_tuning(n_trials=TUNE_TRIALS_FULL, timeout=TUNE_TIMEOUT_SEC)

    elif RUN_MODE == "submit_auto":
        print("Running in SUBMISSION mode...")
        print("Starting inference server...\n")
        setup_and_serve()

    else:
        print(f"Unknown RUN_MODE: {RUN_MODE}")
        print("Valid options: 'local_ref', 'submit_auto', 'quick_tune', 'full_tune'")

print("\nNotebook execution complete.")