# LEMMA Solver for AIMO3

Lemma-Level Decomposition and Verification for Olympiad Mathematics

**Strategy:**
1. Decompose problem into verifiable lemmas
2. Verify each lemma with Python execution
3. Backtrack and repair on verification failure
4. Synthesize final answer from verified lemmas

**Model Configuration:**
- For testing: set `USE_SMALL_MODEL = True` (downloads the small `Qwen/Qwen2.5-1.5B-Instruct` model from Hugging Face)
- For competition: set `USE_SMALL_MODEL = False` (uses the attached gpt-oss-120b model at `KAGGLE_MODEL_PATH`)

In [None]:
# ============================================================
# CONFIGURATION - Modify these for your setup
# ============================================================

# --- Model selection ---------------------------------------
# True  -> load SMALL_MODEL_NAME (quick local testing)
# False -> load the model stored at KAGGLE_MODEL_PATH
USE_SMALL_MODEL = True  # Set to False on Kaggle with GPT-120B
SMALL_MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"  # 1.5B parameters, fast download
KAGGLE_MODEL_PATH = "/kaggle/input/gpt-oss-120b/transformers/default/1"

# --- Inference settings ------------------------------------
MAX_TOKENS_PER_TURN, MAX_TURNS = 4096, 32
TEMPERATURE, TOP_P = 0.7, 0.95

# --- Lemma verification ------------------------------------
MAX_LEMMA_RETRIES, MAX_LEMMAS = 3, 8

# --- Parallel attempts (1 = pure sequential, 4-8 = hybrid) --
PARALLEL_ATTEMPTS = 4

# --- Timeouts, in seconds ----------------------------------
PROBLEM_TIMEOUT = 300  # per problem
PYTHON_TIMEOUT = 10    # per code execution

# Echo the settings that matter most; one print call, one line per entry.
print(
    "Configuration:",
    f"  Use small model: {USE_SMALL_MODEL}",
    f"  Parallel attempts: {PARALLEL_ATTEMPTS}",
    f"  Max turns per attempt: {MAX_TURNS}",
    f"  Max lemma retries: {MAX_LEMMA_RETRIES}",
    sep="\n",
)

In [None]:
# ============================================================
# IMPORTS
# ============================================================

import os
import re
import sys
import json
import math
import time
import queue
import threading
import subprocess
import warnings
from typing import Optional, List, Dict, Tuple, Any
from dataclasses import dataclass, field
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
import polars as pl

warnings.filterwarnings('ignore')

print("Base imports done.")

In [None]:
# ============================================================
# INSTALL VLLM (if needed)
# ============================================================

# Probe for vLLM; if the import fails, install it with pip using the
# interpreter that is running this notebook, then import the names used below.
try:
    import vllm
    print("vLLM already installed")
except ImportError:
    print("Installing vLLM...")
    # check_call raises CalledProcessError if pip exits with a non-zero status.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "vllm"])
    print("vLLM installed")

# LLM is the engine wrapper and SamplingParams the per-request settings
# used by LLMInterface later in this notebook.
from vllm import LLM, SamplingParams
print("vLLM imported successfully")

In [None]:
# ============================================================
# JUPYTER SANDBOX FOR CODE EXECUTION
# ============================================================

# KernelManager is needed by JupyterSandbox; install jupyter_client on demand
# (with the current interpreter's pip) if it is not importable yet.
try:
    from jupyter_client import KernelManager
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "jupyter_client"])
    from jupyter_client import KernelManager

class JupyterSandbox:
    """Run Python snippets in a dedicated Jupyter kernel process.

    Each instance starts its own kernel on a private block of five ports,
    pre-imports a set of math libraries, and exposes ``execute`` to run code
    and collect its stdout/stderr.
    """

    # Class-wide port allocator: every instance takes the next 5 ports.
    _port_lock = threading.Lock()
    _next_port = 50000

    # Imports made available to executed code; run at start-up and after
    # every ``reset`` so both paths stay in sync.
    _BOOTSTRAP_CODE = '''
import math
import numpy as np
import sympy as sp
from sympy import symbols, expand, factor, simplify, solve, Eq
import itertools
from collections import defaultdict, Counter
from fractions import Fraction
import functools
import random
'''

    @classmethod
    def _get_ports(cls, count: int = 5) -> List[int]:
        """Reserve and return ``count`` consecutive port numbers."""
        with cls._port_lock:
            ports = list(range(cls._next_port, cls._next_port + count))
            cls._next_port += count
            return ports

    def __init__(self, timeout: float = 10.0):
        """Start a kernel and pre-load the math libraries.

        Args:
            timeout: Default per-execution time limit in seconds.
        """
        self.timeout = timeout
        self._km = None
        self._client = None

        ports = self._get_ports(5)
        env = os.environ.copy()
        env.update({
            'PYDEVD_DISABLE_FILE_VALIDATION': '1',
            'PYTHONWARNINGS': 'ignore',
            'MPLBACKEND': 'Agg'
        })

        self._km = KernelManager()
        self._km.shell_port = ports[0]
        self._km.iopub_port = ports[1]
        self._km.stdin_port = ports[2]
        self._km.hb_port = ports[3]
        self._km.control_port = ports[4]

        self._km.start_kernel(env=env)
        self._client = self._km.blocking_client()
        self._client.start_channels()
        self._client.wait_for_ready(timeout=30)

        # Initialize with math libraries
        self.execute(self._BOOTSTRAP_CODE + 'sp.init_printing()\n')

    def execute(self, code: str, timeout: Optional[float] = None) -> Dict[str, Any]:
        """Execute ``code`` in the kernel and return its outcome.

        Args:
            code: Source to run.
            timeout: Time limit in seconds; ``None`` uses the instance
                default. An explicit ``0`` is honored (immediate timeout).

        Returns:
            A dict with keys ``success`` (bool), ``output`` (captured stdout)
            and ``error`` (stderr/traceback text, ``'Timeout'``, or ``None``).
            Any stderr output or raised exception marks the run as failed.
        """
        limit = self.timeout if timeout is None else timeout

        msg_id = self._client.execute(code, store_history=False)
        stdout_parts = []
        stderr_parts = []

        # Monotonic clock: the deadline must not move if the wall clock is adjusted.
        deadline = time.monotonic() + limit
        while True:
            if time.monotonic() >= deadline:
                self._km.interrupt_kernel()
                # Keep whatever was printed before the limit hit.
                return {
                    'success': False,
                    'output': ''.join(stdout_parts),
                    'error': 'Timeout',
                }

            try:
                msg = self._client.get_iopub_msg(timeout=1.0)
            except queue.Empty:
                continue

            # Ignore messages that belong to other (e.g. interrupted) requests.
            if msg.get('parent_header', {}).get('msg_id') != msg_id:
                continue

            msg_type = msg.get('msg_type')
            content = msg.get('content', {})

            if msg_type == 'stream':
                text = content.get('text', '')
                if content.get('name') == 'stdout':
                    stdout_parts.append(text)
                else:
                    stderr_parts.append(text)
            elif msg_type == 'error':
                stderr_parts.append('\n'.join(content.get('traceback', [])))
            elif msg_type == 'execute_result':
                text = content.get('data', {}).get('text/plain', '')
                if text:
                    stdout_parts.append(text + '\n')
            elif msg_type == 'status' and content.get('execution_state') == 'idle':
                # The kernel reports idle once this request has finished.
                break

        stdout = ''.join(stdout_parts)
        stderr = ''.join(stderr_parts)

        if stderr:
            return {'success': False, 'output': stdout, 'error': stderr}
        return {'success': True, 'output': stdout.strip(), 'error': None}

    def reset(self):
        """Clear the kernel namespace and re-run the bootstrap imports."""
        self.execute('%reset -f')
        self.execute(self._BOOTSTRAP_CODE)

    def close(self):
        """Stop the client channels and shut the kernel down.

        Safe to call more than once: handles are dropped after the first call.
        """
        if self._client:
            self._client.stop_channels()
            self._client = None
        if self._km:
            self._km.shutdown_kernel(now=True)
            self._km = None

# Test sandbox: start a throwaway kernel, run one statement, and always
# shut the kernel down again - even if execution raises - so no kernel
# process is left behind.
print("Testing Jupyter Sandbox...")
_test_sandbox = JupyterSandbox(timeout=5)
try:
    _result = _test_sandbox.execute("print(2 + 3)")
    print(f"Test result: {_result}")
finally:
    _test_sandbox.close()

# Only claim success when the run actually succeeded.
if _result['success']:
    print("Sandbox working!")
else:
    print("Sandbox test FAILED - see test result above")

In [None]:
# ============================================================
# LEMMA FRAMEWORK PROMPTS
# ============================================================

# System message sent with every LLM call made by LemmaSolver. It is passed
# through verbatim (never via str.format), so the literal "{}" after \boxed
# needs no escaping here.
SYSTEM_PROMPT = """You are an expert mathematical problem solver competing at the International Mathematical Olympiad level.

Your approach follows the LEMMA framework:
1. **Decompose** the problem into verifiable sub-claims (lemmas)
2. **Prove** each lemma rigorously
3. **Verify** each lemma with Python code execution
4. **Synthesize** the final answer from verified lemmas

## Problem-Solving Protocol:

For each problem, you will:
1. **ANALYZE**: Understand what's given and what to find
2. **DECOMPOSE**: Break into 2-6 lemmas (intermediate claims)
3. **VERIFY**: Each lemma MUST be checked with Python code
4. **REPAIR**: If verification fails, fix the lemma and retry
5. **SYNTHESIZE**: Combine verified lemmas for final answer

## Critical Rules:
- Every lemma MUST have a corresponding Python verification
- If code execution fails or gives wrong result, you MUST fix the lemma
- Do not proceed to next lemma until current one is verified
- Final answer must be in \\boxed{} format
- Answer must be a non-negative integer 0-99999

## Available Tools:
- Python with math, numpy, sympy (symbolic), itertools, collections
- Use sympy for exact symbolic computation
- Use numpy for numerical verification
- Always print results to verify correctness
"""

# str.format template with one field: {problem}. The "**Lemma N**:" layout it
# requests is what LemmaSolver.decompose_problem parses with a regex.
DECOMPOSE_PROMPT = """Given this problem, decompose it into verifiable lemmas.

Problem: {problem}

Provide your decomposition in this format:

**Lemma 1**: [Statement of first sub-claim]
- Verification: [What Python code will verify this]

**Lemma 2**: [Statement of second sub-claim]  
- Verification: [What Python code will verify this]

...

**Final Step**: How to combine lemmas to get answer

Be specific about what each lemma claims and how to verify it."""

# str.format template with fields {problem}, {verified_lemmas} and {lemma}.
# The fenced python block it asks for is what extract_code pulls out of the
# model's reply.
PROVE_LEMMA_PROMPT = """Prove and verify this lemma.

Problem: {problem}

Verified Lemmas So Far:
{verified_lemmas}

Current Lemma to Prove: {lemma}

Your task:
1. Provide mathematical reasoning for this lemma
2. Write Python code to verify it
3. Execute the code and confirm it works
4. If verification fails, fix your reasoning and retry

Format your response as:

**Reasoning**: [Your mathematical argument]

**Verification Code**:
```python
# Your verification code here
print(result)
```

After seeing execution results, confirm if lemma is verified."""

# str.format template with fields {problem} and {verified_lemmas}.
# Literal braces must be doubled: an unescaped "{}" is parsed by str.format
# as a positional field and raises IndexError when only keyword arguments
# are supplied. After formatting, the model still sees "\boxed{}" and
# "\boxed{[number]}".
SYNTHESIZE_PROMPT = """Synthesize final answer from verified lemmas.

Problem: {problem}

Verified Lemmas:
{verified_lemmas}

Your task:
1. Combine the verified lemmas logically
2. Compute the final numerical answer
3. Verify the answer with Python if possible
4. Provide final answer in \\boxed{{}}

Format:
**Synthesis**: [How lemmas combine to answer]
**Verification**: [Python code to double-check]
**Final Answer**: \\boxed{{[number]}}"""

# str.format template used by LemmaSolver.prove_lemma after a failed
# verification. Fields: {lemma}, {previous_attempt}, {execution_result},
# {error_analysis}.
REPAIR_PROMPT = """Your previous attempt at this lemma failed verification.

Lemma: {lemma}

Your Previous Attempt:
{previous_attempt}

Execution Result:
{execution_result}

Error Analysis:
{error_analysis}

Fix your approach and provide:
1. Corrected reasoning
2. Fixed Python code
3. Verification that it now works"""

print("Prompts defined.")

In [None]:
# ============================================================
# LEMMA DATA STRUCTURES
# ============================================================

@dataclass
class Lemma:
    """Represents a single lemma/sub-claim.

    Created by ``LemmaSolver.decompose_problem`` and filled in by
    ``LemmaSolver.prove_lemma`` as the lemma is attempted.
    """
    id: int  # 1-based position of the lemma in the decomposition
    statement: str  # first line of the parsed lemma text
    verification_plan: str  # remaining lines of the parsed lemma text
    proof: str = ""  # latest raw model response for this lemma
    verification_code: str = ""  # latest code block extracted from the response
    execution_result: Optional[Dict] = None  # dict returned by the sandbox's execute()
    verified: bool = False  # set to True once the extracted code ran without error
    retries: int = 0  # zero-based index of the most recent attempt

@dataclass
class SolutionAttempt:
    """Tracks one complete solution attempt.

    NOTE(review): nothing in the visible part of this notebook instantiates
    this class; the field meanings below are inferred from names only.
    """
    attempt_id: int
    lemmas: List[Lemma] = field(default_factory=list)  # fresh list per instance
    final_answer: Optional[int] = None
    complete: bool = False
    entropy: float = 0.0  # presumably an answer-uncertainty score - TODO confirm

def extract_answer(text: str) -> Optional[int]:
    """Pull the final integer answer out of model output.

    Two sources are tried in order: the last ``\\boxed{...}`` in the text,
    then the last "final answer is N" phrase (case-insensitive). Thousands
    separators are removed. A source is skipped when its last match does not
    parse as an integer or lies outside 0..99999.

    Returns:
        The answer, or ``None`` when neither source yields a valid value.
    """
    sources = (
        (r'\\boxed\s*\{\s*([0-9,]+)\s*\}', 0),
        (r'final answer is:?\s*([0-9,]+)', re.IGNORECASE),
    )
    for regex, flags in sources:
        found = re.findall(regex, text, flags)
        if not found:
            continue
        # Only the last occurrence counts as the final answer.
        digits = found[-1].replace(',', '')
        try:
            candidate = int(digits)
        except ValueError:
            continue
        if 0 <= candidate <= 99999:
            return candidate
    return None

def extract_code(text: str) -> Optional[str]:
    """Return the body of the first fenced code block in ``text``.

    A fence tagged ``python`` is preferred; if none exists, the first
    untagged/any fence is used. The result is stripped of surrounding
    whitespace. Returns ``None`` when no fence is found.
    """
    tagged = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
    if tagged is not None:
        return tagged.group(1).strip()

    generic = re.search(r'```\s*(.*?)\s*```', text, re.DOTALL)
    if generic is not None:
        return generic.group(1).strip()

    return None

print("Data structures defined.")

In [None]:
# ============================================================
# LLM INTERFACE
# ============================================================

class LLMInterface:
    """Wrapper for vLLM inference.

    Loads one vLLM engine and exposes plain-prompt and chat-style generation.
    """

    # ParallelLemmaSolver calls this wrapper from several ThreadPoolExecutor
    # workers, all sharing one engine. The offline ``LLM`` engine is not
    # documented as safe for concurrent ``generate`` calls, so requests are
    # serialized here.
    _generate_lock = threading.Lock()

    def __init__(self, model_path: str, use_small: bool = True):
        """Load the model.

        Args:
            model_path: Hugging Face model id or local model directory.
            use_small: Selects the small-model engine settings when True,
                the large-model settings (longer context, fp8 KV cache)
                otherwise.
        """
        self.use_small = use_small

        if use_small:
            print(f"Loading small model: {model_path}")
            # Small model loads directly from HF
            self.llm = LLM(
                model=model_path,
                tensor_parallel_size=1,
                dtype="auto",
                trust_remote_code=True,
                max_model_len=8192,
                gpu_memory_utilization=0.9
            )
        else:
            print(f"Loading large model from: {model_path}")
            # Large model (GPT-120B) - adjust settings
            self.llm = LLM(
                model=model_path,
                tensor_parallel_size=1,
                dtype="auto",
                trust_remote_code=True,
                max_model_len=65536,
                gpu_memory_utilization=0.96,
                kv_cache_dtype="fp8_e4m3"
            )

        print("Model loaded successfully!")

    def generate(
        self,
        prompt: str,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: Optional[List[str]] = None
    ) -> str:
        """Generate a completion for a fully rendered prompt.

        Args:
            prompt: Text sent to the engine as-is.
            temperature: Sampling temperature.
            max_tokens: Maximum number of generated tokens.
            stop: Optional stop strings; ``None`` means no stop strings.

        Returns:
            The text of the first completion.
        """
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=TOP_P,
            max_tokens=max_tokens,
            stop=stop or []
        )

        with self._generate_lock:
            outputs = self.llm.generate(prompt, sampling_params)
        return outputs[0].outputs[0].text

    def _render_chat(self, system: str, user: str) -> Tuple[str, List[str]]:
        """Render a system+user exchange into a prompt and its stop strings.

        Uses the loaded model's own chat template so the prompt matches the
        format the model was trained on. If the tokenizer or template is not
        usable, falls back to generic ``<|system|>``/``<|user|>`` markers.
        """
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        try:
            tokenizer = self.llm.get_tokenizer()
            rendered = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            # Deliberate best-effort: any template problem must not stop
            # generation, so fall back to the simple marker format.
            rendered = None

        if isinstance(rendered, str) and rendered:
            # With the model's own template, generation ends at its EOS token.
            return rendered, []

        fallback = f"<|system|>\n{system}<|end|>\n<|user|>\n{user}<|end|>\n<|assistant|>\n"
        return fallback, ["<|end|>"]

    def generate_with_template(
        self,
        system: str,
        user: str,
        temperature: float = 0.7,
        max_tokens: int = 2048
    ) -> str:
        """Generate a reply to a system message plus one user message.

        Args:
            system: System instructions.
            user: User turn content.
            temperature: Sampling temperature.
            max_tokens: Maximum number of generated tokens.

        Returns:
            The generated assistant text.
        """
        prompt, stop = self._render_chat(system, user)
        return self.generate(prompt, temperature, max_tokens, stop=stop)

# Initialize LLM (with small model for testing)
# USE_SMALL_MODEL picks both the model location and the engine settings
# inside LLMInterface.
model_path = SMALL_MODEL_NAME if USE_SMALL_MODEL else KAGGLE_MODEL_PATH
llm = LLMInterface(model_path, use_small=USE_SMALL_MODEL)

In [None]:
# ============================================================
# LEMMA-BASED SOLVER
# ============================================================

class LemmaSolver:
    """Main solver implementing LEMMA framework.

    Pipeline: decompose the problem into lemmas, prove/verify each lemma by
    executing model-written code in the sandbox, then synthesize the answer.
    """

    def __init__(self, llm_interface: LLMInterface, sandbox: JupyterSandbox):
        """Store the LLM wrapper and the sandbox used for code execution."""
        self.llm = llm_interface
        self.sandbox = sandbox

    def decompose_problem(self, problem: str) -> List[Lemma]:
        """Step 1: Decompose problem into lemmas.

        Asks the model for a "**Lemma N**: ..." list and parses it. At most
        MAX_LEMMAS lemmas are kept. If nothing parses, a single catch-all
        lemma is returned so the pipeline can still proceed.
        """
        prompt = DECOMPOSE_PROMPT.format(problem=problem)
        response = self.llm.generate_with_template(
            SYSTEM_PROMPT,
            prompt,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS_PER_TURN
        )

        # Each match runs from one "**Lemma N**" header up to the next
        # header or the end of the response.
        lemma_pattern = r'\*\*Lemma\s*(\d+)\*\*:?\s*(.+?)(?=\*\*Lemma|$)'
        matches = re.findall(lemma_pattern, response, re.DOTALL | re.IGNORECASE)

        if not matches:
            # Fallback: create single lemma
            return [Lemma(
                id=1,
                statement="Solve the problem directly",
                verification_plan="Compute final answer"
            )]

        lemmas = []
        # Lemmas are renumbered 1..k; the number the model wrote is ignored.
        for i, (_num, content) in enumerate(matches[:MAX_LEMMAS], 1):
            lines = content.strip().split('\n')
            lemmas.append(Lemma(
                id=i,
                statement=lines[0].strip('- '),
                verification_plan='\n'.join(lines[1:]).strip()
            ))

        return lemmas

    def prove_lemma(
        self,
        problem: str,
        lemma: Lemma,
        verified_lemmas: List[Lemma]
    ) -> bool:
        """Step 2: Prove and verify a single lemma.

        Makes up to MAX_LEMMA_RETRIES attempts. Each attempt asks the model
        for reasoning plus code, runs the first code block in the sandbox,
        and accepts the lemma when the run succeeds. Failed attempts feed a
        repair prompt into the next attempt. ``lemma`` is updated in place.

        Returns:
            True if the lemma was verified, False otherwise.
        """
        verified_text = '\n'.join([
            f"Lemma {l.id}: {l.statement}\nProof: {l.proof[:200]}..."
            for l in verified_lemmas
        ]) if verified_lemmas else "None"

        prompt = PROVE_LEMMA_PROMPT.format(
            problem=problem,
            lemma=lemma.statement,
            verified_lemmas=verified_text
        )

        for retry in range(MAX_LEMMA_RETRIES):
            lemma.retries = retry

            response = self.llm.generate_with_template(
                SYSTEM_PROMPT,
                prompt,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS_PER_TURN
            )

            lemma.proof = response

            # Reset per attempt: without this, a response containing no code
            # would read an unbound name on the first attempt, or reuse the
            # previous attempt's result on later ones.
            result = None

            # Extract and execute verification code
            code = extract_code(response)
            if code:
                lemma.verification_code = code
                result = self.sandbox.execute(code, timeout=PYTHON_TIMEOUT)
                lemma.execution_result = result

                if result['success'] and not result['error']:
                    lemma.verified = True
                    return True

            # If failed, prepare repair prompt
            if result is None:
                error_msg = 'No code found'
            else:
                error_msg = result.get('error') or 'No output'
            prompt = REPAIR_PROMPT.format(
                lemma=lemma.statement,
                previous_attempt=response,
                execution_result=error_msg,
                error_analysis="Code execution failed or gave incorrect result"
            )

        return False

    def synthesize_answer(
        self,
        problem: str,
        verified_lemmas: List[Lemma]
    ) -> Optional[int]:
        """Step 3: Synthesize final answer.

        Asks the model to combine the verified lemmas. The answer is taken
        from the response text; if none is found but the response contains
        code, the code is run and the last integer it prints is used.

        Returns:
            An integer in 0..99999, or None if no answer could be obtained.
        """
        verified_text = '\n'.join([
            f"Lemma {l.id}: {l.statement}\n"
            f"Verified: {(l.execution_result or {}).get('output', 'N/A')}"
            for l in verified_lemmas
        ]) if verified_lemmas else "None"

        prompt = SYNTHESIZE_PROMPT.format(
            problem=problem,
            verified_lemmas=verified_text
        )

        response = self.llm.generate_with_template(
            SYSTEM_PROMPT,
            prompt,
            temperature=0.3,  # Lower temp for final answer
            max_tokens=MAX_TOKENS_PER_TURN
        )

        # Try to extract answer
        answer = extract_answer(response)

        # Fall back to executing the response's code when no answer was stated.
        code = extract_code(response)
        if code and answer is None:
            result = self.sandbox.execute(code, timeout=PYTHON_TIMEOUT)
            if result['success']:
                # Try to extract number from output
                numbers = re.findall(r'\b\d+\b', result['output'])
                if numbers:
                    val = int(numbers[-1])
                    if 0 <= val <= 99999:
                        answer = val

        return answer

    def solve(self, problem: str, timeout: float = PROBLEM_TIMEOUT) -> Dict[str, Any]:
        """Full LEMMA solving pipeline.

        Args:
            problem: Problem statement.
            timeout: Time budget in seconds, checked between lemmas and
                before synthesis.

        Returns:
            A dict with keys ``answer`` (int or None), ``lemmas`` (verified
            lemmas), ``total_lemmas``, ``verified_count`` and ``time``
            (elapsed seconds). The same keys are present on timeout.
        """
        start_time = time.time()

        # Step 1: Decompose
        lemmas = self.decompose_problem(problem)
        print(f"  Decomposed into {len(lemmas)} lemmas")

        # Step 2: Prove each lemma
        verified_lemmas = []
        for lemma in lemmas:
            if time.time() - start_time > timeout:
                print(f"  Timeout! Moving to synthesis")
                break

            print(f"  Proving Lemma {lemma.id}: {lemma.statement[:50]}...")
            success = self.prove_lemma(problem, lemma, verified_lemmas)

            if success:
                print(f"    ✓ Verified (retries: {lemma.retries})")
                verified_lemmas.append(lemma)
            else:
                print(f"    ✗ Failed after {lemma.retries + 1} attempts")
                # Continue anyway, might still get answer

        # Step 3: Synthesize
        if time.time() - start_time > timeout:
            print(f"  Timeout before synthesis!")
            # Same shape as the normal return so callers can read the
            # bookkeeping keys without special-casing timeouts.
            return {
                'answer': None,
                'lemmas': verified_lemmas,
                'total_lemmas': len(lemmas),
                'verified_count': len(verified_lemmas),
                'time': time.time() - start_time
            }

        print(f"  Synthesizing from {len(verified_lemmas)}/{len(lemmas)} verified lemmas...")
        answer = self.synthesize_answer(problem, verified_lemmas)

        elapsed = time.time() - start_time
        print(f"  Done in {elapsed:.1f}s, answer: {answer}")

        return {
            'answer': answer,
            'lemmas': verified_lemmas,
            'total_lemmas': len(lemmas),
            'verified_count': len(verified_lemmas),
            'time': elapsed
        }

print("LemmaSolver class defined.")

In [None]:
# ============================================================
# PARALLEL ATTEMPT SOLVER (Hybrid Approach)
# ============================================================

class ParallelLemmaSolver:
    """Runs multiple LEMMA attempts in parallel, votes on results."""

    def __init__(self, llm_interface: LLMInterface, num_workers: int = 4):
        """Create the worker sandboxes.

        Args:
            llm_interface: Shared LLM wrapper used by every attempt.
            num_workers: Number of sandboxes and of worker threads.

        Raises:
            ValueError: If ``num_workers`` is less than 1.
        """
        if num_workers < 1:
            raise ValueError("num_workers must be at least 1")

        self.llm = llm_interface
        self.num_workers = num_workers
        self.sandbox_pool = []
        # Sandboxes that are currently free. Attempts check one out and hand
        # it back when done, so two running attempts never share a kernel.
        self._idle_sandboxes = queue.Queue()

        # Initialize sandbox pool
        print(f"Initializing {num_workers} sandboxes...")
        for _ in range(num_workers):
            sandbox = JupyterSandbox(timeout=PYTHON_TIMEOUT)
            self.sandbox_pool.append(sandbox)
            self._idle_sandboxes.put(sandbox)
        print("Sandboxes ready!")

    def _solve_single(
        self,
        attempt_id: int,
        problem: str,
        deadline: float
    ) -> Dict[str, Any]:
        """Run one attempt on an exclusively held sandbox.

        Args:
            attempt_id: Index of this attempt, copied into the result.
            problem: Problem statement.
            deadline: Absolute ``time.time()`` value after which the attempt
                should stop.

        Returns:
            The ``LemmaSolver.solve`` result dict plus ``attempt_id``.
        """
        # Blocks until a sandbox is free. Picking by attempt_id modulo the
        # pool size could hand a still-busy kernel to a second attempt when
        # there are more attempts than sandboxes.
        sandbox = self._idle_sandboxes.get()
        try:
            sandbox.reset()

            solver = LemmaSolver(self.llm, sandbox)
            time_left = max(0, deadline - time.time())

            result = solver.solve(problem, timeout=time_left)
            result['attempt_id'] = attempt_id
            return result
        finally:
            self._idle_sandboxes.put(sandbox)

    def solve(
        self,
        problem: str,
        num_attempts: int = PARALLEL_ATTEMPTS,
        timeout: float = PROBLEM_TIMEOUT
    ) -> int:
        """Parallel solving with voting.

        Args:
            problem: Problem statement.
            num_attempts: Number of independent attempts; 1 runs inline
                without a thread pool.
            timeout: Shared time budget in seconds for all attempts.

        Returns:
            The winning answer, or 0 when no attempt produced a valid one.
        """
        print(f"\nSolving: {problem[:80]}...")
        print(f"Running {num_attempts} parallel attempts...")

        deadline = time.time() + timeout
        results = []

        if num_attempts == 1:
            # Sequential
            results.append(self._solve_single(0, problem, deadline))
        else:
            # Parallel
            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                futures = {
                    executor.submit(
                        self._solve_single, i, problem, deadline
                    ): i for i in range(num_attempts)
                }

                for future in as_completed(futures):
                    try:
                        result = future.result()
                    except Exception as e:
                        # One failed attempt must not sink the others.
                        print(f"  Attempt failed: {e}")
                        continue
                    results.append(result)
                    # .get(): do not drop a result because a bookkeeping
                    # key is missing from it.
                    print(f"  Attempt {result.get('attempt_id')}: answer={result.get('answer')}, "
                          f"verified={result.get('verified_count', 0)}/{result.get('total_lemmas', 0)}")

        # Voting
        return self._vote(results)

    def _vote(self, results: List[Dict]) -> int:
        """Weighted voting on results.

        Each valid answer (integer in 0..99999) gets weight
        ``verified_count / total_lemmas + 0.1``; weights of equal answers
        add up. Ties go to the answer seen first.

        Returns:
            The answer with the highest total weight, or 0 if none is valid.
        """
        if not results:
            return 0

        # Collect answers with weights
        answer_weights = defaultdict(float)

        for r in results:
            ans = r.get('answer')
            if ans is not None and 0 <= ans <= 99999:
                # Weight by verification ratio
                verified = r.get('verified_count', 0)
                total = r.get('total_lemmas', 1) or 1  # never divide by zero
                weight = (verified / total) + 0.1  # Base weight even if no lemmas
                answer_weights[ans] += weight

        if not answer_weights:
            print("No valid answers! Returning 0")
            return 0

        # Display vote table
        vote_df = pd.DataFrame(list(answer_weights.items()), columns=['Answer', 'Weight'])
        # Stable sort keeps first-seen order among equal weights.
        vote_df = vote_df.sort_values('Weight', ascending=False, kind='mergesort')
        print("\nVoting results:")
        try:
            display(vote_df)
        except NameError:
            # `display` only exists inside IPython/Jupyter.
            print(vote_df)

        # Read from the column, not the row: a row Series would upcast the
        # integer answer to float because Weight is a float column.
        best_answer = int(vote_df['Answer'].iloc[0])
        print(f"\nFinal Answer: {best_answer}")

        return best_answer

    def close(self):
        """Shut down every sandbox in the pool."""
        for sandbox in self.sandbox_pool:
            sandbox.close()

# Initialize parallel solver
# The worker (and sandbox) count is capped at 4 regardless of PARALLEL_ATTEMPTS.
parallel_solver = ParallelLemmaSolver(llm, num_workers=min(PARALLEL_ATTEMPTS, 4))
print("\nSolver ready!")

In [None]:
# ============================================================
# TEST WITH SIMPLE PROBLEMS
# ============================================================

# Smoke-test problems with small, easily checked answers.
test_problems = [
    # Raw string: in a normal string literal "\t" is a TAB, which would turn
    # "\times" into "<TAB>imes" before the model ever sees it.
    r"What is $0 \times 10$?",
    "Solve $4 + x = 4$ for $x$.",
    "What is the sum of the first 5 positive integers?",
    "If a triangle has sides 3, 4, and 5, what is its perimeter?",
]

print("Testing with simple problems...\n")

for problem in test_problems:
    print("="*60)
    try:
        # One sequential attempt with a short budget keeps this check quick.
        answer = parallel_solver.solve(problem, num_attempts=1, timeout=60)
        print(f"Result: {answer}\n")
    except Exception as e:
        # Keep going so one failing problem does not hide the others.
        print(f"Error: {e}\n")
    print()

print("\nSimple tests complete!")

In [None]:
# ============================================================
# KAGGLE COMPETITION INTERFACE
# ============================================================

import kaggle_evaluation.aimo_3_inference_server

def _first_cell(values):
    """Return the first value of a one-problem polars Series or DataFrame."""
    if isinstance(values, pl.DataFrame):
        # DataFrame.item needs both a row and a column index; calling it
        # with only a row raises ValueError.
        return values.item(0, 0)
    return values.item(0)


def predict(id_: pl.DataFrame, question: pl.DataFrame) -> pl.DataFrame:
    """Kaggle prediction function.

    Args:
        id_: Container holding the problem id in its first cell. Either a
            polars Series or a single-column DataFrame is accepted.
        question: Container holding the problem text in its first cell;
            same accepted types as ``id_``.

    Returns:
        A one-row DataFrame with columns ``id`` and ``answer``.
    """
    id_value = _first_cell(id_)
    problem_text = _first_cell(question)

    print(f"\n{'='*60}")
    print(f"Problem ID: {id_value}")
    print(f"{'='*60}")

    # Solve with parallel attempts
    answer = parallel_solver.solve(
        problem_text,
        num_attempts=PARALLEL_ATTEMPTS,
        timeout=PROBLEM_TIMEOUT
    )

    return pl.DataFrame({'id': id_value, 'answer': answer})

# For local testing without Kaggle server
def test_local():
    """Test locally with sample problems.

    Calls ``predict`` once per sample and prints the returned answer.
    """
    test_cases = [
        (1, "What is $2^{10}$?"),
        (2, "Find the remainder when $14^{2025}$ is divided by $100$."),
    ]

    for id_val, problem in test_cases:
        # Pass one-element Series: predict reads its inputs with `.item(0)`,
        # which works on a Series but raises ValueError on a DataFrame
        # (DataFrame.item needs both row and column).
        id_col = pl.Series('id', [id_val])
        question_col = pl.Series('question', [problem])
        result = predict(id_col, question_col)
        print(f"Answer: {result['answer'][0]}\n")

print("Kaggle predict function defined.")
print("\nTo test locally, run: test_local()")
print("For Kaggle submission, the predict() function will be used.")

In [None]:
# ============================================================
# MAIN ENTRY POINT
# ============================================================

# `or True` makes this branch run unconditionally, so the cell executes
# whether or not __name__ is "__main__".
if __name__ == "__main__" or True:  # Run in notebook
    
    # Check if running on Kaggle
    is_kaggle = os.path.exists('/kaggle')
    print(f"Running on Kaggle: {is_kaggle}")
    
    # The inference server is only started with the large model on Kaggle;
    # the small test model never serves predictions.
    if is_kaggle and not USE_SMALL_MODEL:
        # Start Kaggle inference server
        print("\nStarting Kaggle inference server...")
        inference_server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(
            predict
        )
        
        # Any non-empty value of this environment variable selects
        # competition mode.
        if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
            print("Running in competition mode")
            inference_server.serve()
        else:
            print("Running in local test mode")
            # NOTE(review): presumably replays the bundled test.csv through
            # predict() - confirm against the kaggle_evaluation package.
            inference_server.run_local_gateway(
                ('/kaggle/input/ai-mathematical-olympiad-progress-prize-3/test.csv',)
            )
    else:
        print("\nRun test_local() to test with sample problems")
        print("Or manually call predict() with DataFrames")

---

## Cleanup

Run this when done to free resources:

In [None]:
# Cleanup resources
# Best-effort: never raise from this cell, but report what happened instead
# of silently swallowing every error.
try:
    parallel_solver.close()
except NameError:
    # The solver cell was never run, so there is nothing to close.
    print("No solver to close")
except Exception as exc:
    print(f"Solver cleanup failed: {exc!r}")
else:
    print("Solvers closed")

import gc
gc.collect()
print("Garbage collected")

print("Cleanup complete!")