# ALG Sequential Solver for AIMO3

Adaptive Lemma Graph solver with:
1. **Problem Classification** - Model determines topic and complexity
2. **Topic-Specific DAG** - Build lemma graph based on problem type
3. **Dynamic Time Allocation** - Spend more time on hard problems
4. **Sequential Traversal** - No parallel threads, one rigorous proof path

Strategy:
- Simple problems (2-3 lemmas): ~60 seconds
- Medium problems (4-5 lemmas): ~180 seconds
- Hard problems (6-8 lemmas): ~480 seconds


In [None]:
# Uninstall these preinstalled packages before the offline install in a later cell
# (presumably to avoid dependency conflicts with the bundled wheels - TODO confirm).
%pip uninstall --yes 'keras' 'matplotlib' 'scikit-learn' 'tensorflow'


In [None]:
import warnings
warnings.simplefilter('ignore')

import os
import sys
import subprocess


In [None]:
def set_env(input_archive, temp_dir):
    """Unpack the wheel archive once and install the solver's packages offline.

    Args:
        input_archive: Path to a gzip-compressed tarball; it is expected to
            contain a ``wheels`` directory used as the pip ``--find-links`` source.
        temp_dir: Directory the archive is extracted into. Extraction is
            skipped when this path already exists.

    Raises:
        subprocess.CalledProcessError: If ``tar`` or ``pip`` exits non-zero.
    """
    already_extracted = os.path.exists(temp_dir)
    if not already_extracted:
        os.makedirs(temp_dir, exist_ok=True)
        extract_cmd = ['tar', '-xzf', input_archive, '-C', temp_dir]
        subprocess.run(extract_cmd, check=True)

    # Install strictly from the local wheel directory (no package index access).
    packages = ['unsloth', 'trl', 'vllm', 'openai_harmony']
    pip_cmd = [
        sys.executable, '-m', 'pip', 'install',
        '--no-index',
        '--find-links', f'{temp_dir}/wheels',
        *packages,
    ]
    subprocess.run(pip_cmd, check=True)


In [None]:
# Extract the bundled archive (only if the target directory does not exist yet)
# and install the offline packages from its wheels directory.
set_env(
    input_archive='/kaggle/input/aimo-3-utils/wheels.tar.gz',
    temp_dir='/kaggle/tmp/setup'
)


In [None]:
# Process-wide environment switches, set in one place before the heavy
# libraries are imported in the following cell.
os.environ.update({
    # Ask transformers not to load its TensorFlow / Flax backends.
    'TRANSFORMERS_NO_TF': '1',
    'TRANSFORMERS_NO_FLAX': '1',
    # Expose only GPU index 0 to CUDA.
    'CUDA_VISIBLE_DEVICES': '0',
    'TOKENIZERS_PARALLELISM': 'false',
    'TRITON_PTXAS_PATH': '/usr/local/cuda/bin/ptxas',
    # Lives under the directory that set_env() extracts the archive into.
    'TIKTOKEN_ENCODINGS_BASE': '/kaggle/tmp/setup/tiktoken_encodings',
})


In [None]:
# ============================================================
# IMPORTS
# ============================================================

import gc
import re
import json
import math
import time
import queue
import threading
import contextlib
from typing import Optional, List, Dict, Tuple, Any
from dataclasses import dataclass, field
from collections import defaultdict
from enum import Enum
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
import polars as pl

from openai import OpenAI

from openai_harmony import (
    HarmonyEncodingName,
    load_harmony_encoding,
    SystemContent,
    ReasoningEffort,
    ToolNamespaceConfig,
    Author,
    Message,
    Role,
    TextContent,
    Conversation
)

from transformers import set_seed
import kaggle_evaluation.aimo_3_inference_server

print('All imports done')


In [None]:
# ============================================================
# CONFIGURATION
# ============================================================

class CFG:
    """Static settings shared by the solver, the sandbox and the vLLM server."""

    # --- Model ---
    model_path: str = '/kaggle/input/gpt-oss-120b/transformers/default/1'
    served_model_name: str = 'gpt-oss'

    # --- Inference ---
    context_tokens: int = 65536
    temperature: float = 0.7
    top_p: float = 0.95
    max_tokens_per_turn: int = 4096

    # --- Time budgets in seconds, keyed by complexity ('default' is the fallback) ---
    time_budget = dict(simple=60, medium=180, hard=480, default=180)

    # --- Lemma graph ---
    max_lemmas: int = 8
    max_retries_per_lemma: int = 3

    # --- Python sandbox (seconds) ---
    sandbox_timeout: int = 30

    # --- vLLM server ---
    server_port: int = 8000
    server_timeout: int = 180  # seconds to wait for the server to come up

    # --- vLLM engine ---
    kv_cache_dtype: str = 'fp8_e4m3'
    dtype: str = 'auto'
    gpu_memory_utilization: float = 0.96
    batch_size: int = 256


print('Configuration loaded')


In [None]:
# Fix the random seed through transformers' set_seed helper.
set_seed(42)


In [None]:
# ============================================================
# SYSTEM PROMPTS
# ============================================================

# Template for the classification request. Placeholder: {problem}.
# The "exact format" section lists the line prefixes (TOPIC:, COMPLEXITY:, ...)
# that ParsingUtils.parse_classification matches on.
CLASSIFICATION_PROMPT = """Analyze this mathematical problem and provide:

1. **TOPIC**: Classify into one of: algebra, number_theory, combinatorics, geometry, analysis
2. **COMPLEXITY**: Rate as simple, medium, or hard based on:
   - Simple: Direct calculation, 1-2 concepts, < 5 minutes for expert
   - Medium: Multiple steps, 2-3 concepts, requires careful analysis
   - Hard: Deep insight needed, multiple advanced concepts, lengthy calculation
3. **KEY TECHNIQUES**: List 2-3 mathematical techniques likely needed
4. **ESTIMATED LEMMAS**: Predict number of sub-problems (2-8)

Problem: {problem}

Respond in this exact format:
TOPIC: <topic>
COMPLEXITY: <simple|medium|hard>
KEY_TECHNIQUES: <technique1>, <technique2>, <technique3>
ESTIMATED_LEMMAS: <number>
REASONING: <one sentence explaining classification>
"""

# Template for the lemma-graph request.
# Placeholders: {topic}, {complexity}, {problem}, {estimated_lemmas}.
# NOTE(review): ParsingUtils.parse_lemma_graph only recognises lines containing
# '**Lemma' / '**FINAL**', but this prompt does not ask for that markup - confirm
# the model reliably produces it.
LEMMA_GRAPH_PROMPT = """# PROTOCOL: ADAPTIVE LEMMA GRAPH (ALG)
You are an IMO Gold Medalist paired with a Symbolic Verification Engine.

## Problem Classification
Topic: {topic}
Complexity: {complexity}

## Your Task
Decompose this problem into a directed acyclic graph (DAG) of verifiable lemmas.

Problem: {problem}

## Lemma Types
- **structural**: Establish mathematical structure
- **reduction**: Simplify the problem
- **computational**: Requires calculation
- **verification**: Check constraints

## Output Format
Create {estimated_lemmas} lemmas leading to FINAL synthesis.

## Critical Rules
1. Each lemma MUST have a unique ID (L1, L2, ...)
2. Dependencies must form a valid DAG (no cycles)
3. FINAL must depend on all necessary lemmas
"""

print('Prompts defined')


In [None]:
# ============================================================
# DATA STRUCTURES
# ============================================================

class Complexity(Enum):
    """Difficulty levels; the string values are the keys used in CFG.time_budget."""
    SIMPLE = 'simple'
    MEDIUM = 'medium'
    HARD = 'hard'

class LemmaType(Enum):
    """Kinds of lemma; the first four match the types listed in LEMMA_GRAPH_PROMPT."""
    STRUCTURAL = 'structural'
    REDUCTION = 'reduction'
    COMPUTATIONAL = 'computational'
    VERIFICATION = 'verification'
    # Used for the 'FINAL' node created by the parser.
    SYNTHESIS = 'synthesis'

@dataclass
class ProblemClassification:
    """Parsed outcome of the classification step for a single problem."""
    topic: str
    complexity: Complexity
    key_techniques: List[str] = field(default_factory=list)
    estimated_lemmas: int = 4
    reasoning: str = ''

    def get_time_budget(self) -> float:
        """Return the configured time budget in seconds for this complexity.

        Falls back to the ``'default'`` entry of ``CFG.time_budget`` when the
        complexity value has no entry of its own.
        """
        level = self.complexity.value
        budgets = CFG.time_budget
        fallback = budgets['default']
        return budgets.get(level, fallback)

@dataclass
class Lemma:
    """One node of the lemma graph."""
    # Identifier; the parser produces 'L<number>' ids plus the special 'FINAL' node.
    id: str
    # Lemma text (the parser stores the whole matching line here).
    statement: str
    lemma_type: LemmaType
    # Ids of the lemmas this one depends on; read by LemmaGraph.get_dependency_order.
    dependencies: List[str] = field(default_factory=list)
    # The remaining fields only carry their defaults in the visible code.
    verification_strategy: str = ''
    proof: str = ''
    verification_code: str = ''
    execution_result: Optional[str] = None
    verified: bool = False

@dataclass
class LemmaGraph:
    """Lemmas of one problem, keyed by lemma id, with their dependency edges.

    Attributes are annotated with quoted forward references so the
    annotations are not evaluated when the class is created.
    """
    problem: str
    classification: 'ProblemClassification'
    lemmas: Dict[str, 'Lemma'] = field(default_factory=dict)

    def get_dependency_order(self) -> List[str]:
        """Return lemma ids in topological order (dependencies first).

        Uses Kahn's algorithm. Ties are resolved by insertion order of
        ``self.lemmas``. Dependencies naming an id that is not in the graph
        are ignored, and a dependency listed more than once counts as a
        single edge. Lemmas that are part of a cycle (including a lemma that
        depends on itself) can never become ready and are omitted from the
        result.

        Returns:
            List of lemma ids; shorter than ``len(self.lemmas)`` only when the
            graph contains a cycle.
        """
        from collections import deque

        # dependents[x] lists the lemmas waiting on x, in insertion order.
        dependents = {lid: [] for lid in self.lemmas}
        in_degree = dict.fromkeys(self.lemmas, 0)
        for lid, lemma in self.lemmas.items():
            # dict.fromkeys de-duplicates while keeping order, so a repeated
            # dependency cannot inflate the in-degree beyond what is later
            # decremented.
            for dep in dict.fromkeys(lemma.dependencies):
                if dep in dependents:
                    dependents[dep].append(lid)
                    in_degree[lid] += 1

        ready = deque(lid for lid, degree in in_degree.items() if degree == 0)
        order = []
        while ready:
            current = ready.popleft()
            order.append(current)
            for child in dependents[current]:
                in_degree[child] -= 1
                if in_degree[child] == 0:
                    ready.append(child)
        return order

@dataclass
class SolutionResult:
    """Outcome of one ALGSolver.solve call."""
    problem: str
    classification: ProblemClassification
    # Extracted integer answer, or None when no answer could be extracted.
    answer: Optional[int] = None
    # The solver sets this to True exactly when an answer was extracted.
    success: bool = False
    # Elapsed time of the solve call in seconds.
    time_taken: float = 0.0

print('Data structures defined')


In [None]:
# ============================================================
# JUPYTER SANDBOX
# ============================================================

from jupyter_client import KernelManager

class ALGSandbox:
    """Python sandbox backed by a dedicated Jupyter kernel.

    Each instance starts its own kernel on five consecutive ports taken from
    a class-wide counter and pre-imports ``math``, ``sympy``, ``itertools``
    and ``numpy`` in it. Call :meth:`close` to shut the kernel down.
    """

    _port_lock = threading.Lock()
    _next_port = 50000

    # Class-level defaults so close() is safe on a partially built instance.
    _km = None
    _client = None

    @classmethod
    def _get_next_ports(cls, count=5):
        """Reserve and return ``count`` consecutive port numbers."""
        with cls._port_lock:
            first = cls._next_port
            cls._next_port = first + count
        return list(range(first, first + count))

    def __init__(self, timeout=30.0):
        """Start the kernel and run the import preamble.

        Args:
            timeout: Default per-execution timeout in seconds.

        Raises:
            Exception: Whatever kernel start-up raises; the kernel and client
                created so far are shut down before the error propagates.
        """
        self.timeout = timeout
        ports = self._get_next_ports(5)
        env = os.environ.copy()
        env['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'
        env['PYTHONWARNINGS'] = 'ignore'

        try:
            self._km = KernelManager()
            self._km.shell_port = ports[0]
            self._km.iopub_port = ports[1]
            self._km.stdin_port = ports[2]
            self._km.hb_port = ports[3]
            self._km.control_port = ports[4]

            self._km.start_kernel(env=env)
            self._client = self._km.blocking_client()
            self._client.start_channels()
            self._client.wait_for_ready(timeout=30)

            init_code = (
                'import math\n'
                'import sympy as sp\n'
                'import itertools\n'
                'import numpy as np'
            )
            self.execute(init_code)
        except BaseException:
            # Do not leave an orphaned kernel process behind on failure.
            self.close()
            raise

    def execute(self, code, timeout=None):
        """Run ``code`` in the kernel and collect its output.

        Args:
            code: Python source to execute.
            timeout: Seconds to wait before interrupting the kernel; ``None``
                uses the instance default.

        Returns:
            Dict with keys ``success`` (bool), ``output`` (captured stdout)
            and ``error`` (captured stderr / traceback, ``'Timeout'``, or
            ``None`` on success). Any stderr content marks the run as failed.
        """
        limit = self.timeout if timeout is None else timeout
        msg_id = self._client.execute(code, store_history=False)
        stdout, stderr = [], []
        # Monotonic clock: immune to wall-clock adjustments.
        deadline = time.monotonic() + limit
        while True:
            if time.monotonic() > deadline:
                self._km.interrupt_kernel()
                return {'success': False, 'output': '', 'error': 'Timeout'}
            try:
                msg = self._client.get_iopub_msg(timeout=1.0)
            except Exception:
                # Nothing arrived within the poll interval (presumably
                # queue.Empty); keep polling until the deadline.
                continue
            # Skip messages that belong to other (e.g. earlier, timed-out) requests.
            if msg.get('parent_header', {}).get('msg_id') != msg_id:
                continue
            msg_type = msg.get('msg_type')
            content = msg.get('content', {})
            if msg_type == 'stream':
                text = content.get('text', '')
                if content.get('name') == 'stdout':
                    stdout.append(text)
                else:
                    stderr.append(text)
            elif msg_type == 'error':
                stderr.append('\n'.join(content.get('traceback', [])))
            elif msg_type == 'status' and content.get('execution_state') == 'idle':
                # The kernel reports idle once the request has finished.
                break
        stdout, stderr = ''.join(stdout), ''.join(stderr)
        if stderr:
            return {'success': False, 'output': stdout, 'error': stderr}
        return {'success': True, 'output': stdout.strip(), 'error': None}

    def close(self):
        """Stop the client channels and shut the kernel down.

        Safe to call more than once and on an instance whose start-up failed.
        """
        client, km = self._client, self._km
        self._client = None
        self._km = None
        try:
            if client is not None:
                client.stop_channels()
        finally:
            # Shut the kernel down even if stopping the channels failed.
            if km is not None:
                km.shutdown_kernel(now=True)

print('Sandbox defined')


In [None]:
# ============================================================
# LLM INTERFACE
# ============================================================

class LLMInterface:
    """Thin client for the local vLLM server using Harmony-rendered prompts."""

    def __init__(self, cfg):
        """Store the configuration; no connection is made until initialize().

        Args:
            cfg: Object providing ``server_port``, ``served_model_name``,
                ``temperature`` and ``max_tokens_per_turn``.
        """
        self.cfg = cfg
        self.base_url = f'http://0.0.0.0:{cfg.server_port}/v1'
        self.api_key = 'sk-local'
        self.client = None
        self.encoding = None
        self.stop_token_ids = None

    def initialize(self):
        """Create the OpenAI client and load the Harmony encoding."""
        print('[LLM] Connecting to vLLM server...', flush=True)
        self.client = OpenAI(base_url=self.base_url, api_key=self.api_key, timeout=300)
        self.encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
        self.stop_token_ids = self.encoding.stop_tokens_for_assistant_actions()
        print('[LLM] Connected successfully', flush=True)

    def generate(self, system_prompt, user_prompt, temperature=None, max_tokens=None):
        """Generate one completion for a system + user prompt pair.

        Args:
            system_prompt: Text used as the model identity of the system message.
            user_prompt: Text of the user message.
            temperature: Sampling temperature; ``None`` uses ``cfg.temperature``.
                An explicit ``0.0`` is honoured.
            max_tokens: Completion token limit; ``None`` uses
                ``cfg.max_tokens_per_turn``.

        Returns:
            The text of the first completion choice.

        Raises:
            RuntimeError: If initialize() has not been called.
        """
        # `is None` rather than `or`, so a requested temperature of 0.0 is
        # not replaced by the configured default.
        temp = self.cfg.temperature if temperature is None else temperature
        max_tok = self.cfg.max_tokens_per_turn if max_tokens is None else max_tokens
        print(f'[LLM] Generating (temp={temp}, max_tokens={max_tok})...', flush=True)
        if self.client is None:
            raise RuntimeError('LLM client not initialized!')
        if self.encoding is None:
            raise RuntimeError('Encoding not initialized!')

        system_content = (SystemContent.new()
            .with_model_identity(system_prompt)
            .with_reasoning_effort(reasoning_effort=ReasoningEffort.HIGH))
        system_msg = Message.from_role_and_content(Role.SYSTEM, system_content)
        user_msg = Message.from_role_and_content(Role.USER, TextContent(text=user_prompt))
        conversation = Conversation.from_messages([system_msg, user_msg])
        prompt_ids = self.encoding.render_conversation_for_completion(conversation, Role.ASSISTANT)

        # The OpenAI `stop` field takes stop *strings*; token ids are passed
        # through vLLM's `stop_token_ids` sampling parameter instead.
        response = self.client.completions.create(
            model=self.cfg.served_model_name,
            temperature=temp,
            max_tokens=max_tok,
            prompt=prompt_ids,
            extra_body={'stop_token_ids': self.stop_token_ids})
        result = response.choices[0].text
        print(f'[LLM] Generated {len(result)} chars', flush=True)
        return result

print('LLM interface defined')


In [None]:
# ============================================================
# PARSING UTILITIES
# ============================================================

class ParsingUtils:
    """Stateless helpers that turn raw model text into the solver's data structures."""

    @staticmethod
    def parse_classification(text):
        """Parse a classification response into a ProblemClassification.

        Recognises lines starting with ``TOPIC:``, ``COMPLEXITY:``,
        ``KEY_TECHNIQUES:``, ``ESTIMATED_LEMMAS:`` and ``REASONING:``.
        Missing or unparsable fields keep their defaults (topic ``'algebra'``,
        medium complexity, 4 lemmas). The lemma estimate is clamped to
        ``[2, CFG.max_lemmas]``.
        """
        topic, complexity = 'algebra', Complexity.MEDIUM
        techniques, reasoning = [], ''
        estimated = 4
        for raw_line in text.split('\n'):
            key, sep, value = raw_line.strip().partition(':')
            if not sep:
                continue
            value = value.strip()
            if key == 'TOPIC':
                topic = value.lower()
            elif key == 'COMPLEXITY':
                level = value.lower()
                if level in ('simple', 'easy'):
                    complexity = Complexity.SIMPLE
                elif level == 'hard':
                    complexity = Complexity.HARD
            elif key == 'KEY_TECHNIQUES':
                techniques = [t.strip() for t in value.split(',') if t.strip()]
            elif key == 'ESTIMATED_LEMMAS':
                # Take the first integer so answers such as "5 lemmas" or
                # "~5" are still understood.
                number = re.search(r'-?\d+', value)
                if number:
                    estimated = int(number.group())
            elif key == 'REASONING':
                reasoning = value
        estimated = max(2, min(estimated, CFG.max_lemmas))
        return ProblemClassification(topic, complexity, techniques, estimated, reasoning)

    @staticmethod
    def parse_lemma_graph(text, problem, classification):
        """Build a LemmaGraph from a decomposition response.

        Every line containing ``**Lemma`` followed by a number becomes a lemma
        ``L<number>`` whose statement is the whole line; its type is inferred
        from the keywords ``reduction``, ``computational`` or ``verification``
        in that line (structural otherwise). A ``FINAL`` synthesis node that
        depends on all parsed lemmas is always appended last.
        """
        graph = LemmaGraph(problem, classification)
        keyword_types = (
            ('reduction', LemmaType.REDUCTION),
            ('computational', LemmaType.COMPUTATIONAL),
            ('verification', LemmaType.VERIFICATION),
        )
        for raw_line in text.split('\n'):
            line = raw_line.strip()
            if '**Lemma' not in line:
                continue
            match = re.search(r'Lemma\s*(\d+)', line, re.IGNORECASE)
            if not match:
                continue
            lowered = line.lower()
            ltype = next((kind for word, kind in keyword_types if word in lowered),
                         LemmaType.STRUCTURAL)
            lemma_id = f'L{match.group(1)}'
            graph.lemmas[lemma_id] = Lemma(id=lemma_id, statement=line, lemma_type=ltype)

        # Created after the scan so FINAL depends on every lemma, regardless
        # of where a FINAL marker appeared in the response.
        graph.lemmas['FINAL'] = Lemma(
            id='FINAL', statement='Synthesize final answer',
            lemma_type=LemmaType.SYNTHESIS,
            dependencies=[lid for lid in graph.lemmas if lid != 'FINAL'])
        return graph

    @staticmethod
    def extract_answer(text):
        """Return the integer in the last ``boxed{...}`` of ``text``.

        Thousands separators (commas) are removed. Returns ``None`` when there
        is no boxed value, when it is not a valid integer, or when it lies
        outside ``0..99999``.
        """
        matches = re.findall(r'boxed\s*\{\s*([0-9,]+)\s*\}', text)
        if not matches:
            return None
        try:
            val = int(matches[-1].replace(',', ''))
        except ValueError:
            # e.g. a boxed value consisting only of commas
            return None
        return val if 0 <= val <= 99999 else None

print('Parsing utilities defined')


In [None]:
# ============================================================
# ALG SOLVER
# ============================================================
import sys
import traceback

class ALGSolver:
    """Classify, decompose and synthesize pipeline for one problem at a time."""

    def __init__(self, cfg):
        """Create the solver; call initialize() before solve().

        Args:
            cfg: Configuration object (passed on to LLMInterface; also
                provides ``sandbox_timeout``).
        """
        self.cfg = cfg
        self.llm = LLMInterface(cfg)
        # Created in initialize(); not used by the solving steps in this class.
        self.sandbox = None
        self.parser = ParsingUtils()

    def initialize(self):
        """Start the sandbox kernel and connect the LLM client."""
        print('[ALG] Initializing solver...', flush=True)
        self.sandbox = ALGSandbox(timeout=self.cfg.sandbox_timeout)
        self.llm.initialize()
        print('[ALG] Solver ready!', flush=True)

    def classify_problem(self, problem):
        """Phase 1: ask the model for topic/complexity and parse the reply."""
        print('\n=== PHASE 1: PROBLEM CLASSIFICATION ===', flush=True)
        prompt = CLASSIFICATION_PROMPT.format(problem=problem)
        print('[ALG] Sending classification prompt...', flush=True)
        response = self.llm.generate('You are a mathematical problem classifier.', prompt, 0.3, 500)
        print('[ALG] Got response, parsing...', flush=True)
        classification = self.parser.parse_classification(response)
        print(f'[ALG] Topic: {classification.topic}', flush=True)
        print(f'[ALG] Complexity: {classification.complexity.value}', flush=True)
        print(f'[ALG] Budget: {classification.get_time_budget()}s', flush=True)
        return classification

    def build_lemma_graph(self, problem, classification):
        """Phase 2: ask the model for a lemma decomposition and parse it."""
        print('\n=== PHASE 2: LEMMA GRAPH CONSTRUCTION ===', flush=True)
        prompt = LEMMA_GRAPH_PROMPT.format(
            topic=classification.topic,
            complexity=classification.complexity.value,
            estimated_lemmas=classification.estimated_lemmas, problem=problem)
        print('[ALG] Building graph...', flush=True)
        response = self.llm.generate('You are an expert mathematical problem decomposer.', prompt, 0.5, 2000)
        graph = self.parser.parse_lemma_graph(response, problem, classification)
        print(f'[ALG] Graph: {len(graph.lemmas)} lemmas', flush=True)
        print(f'[ALG] Order: {graph.get_dependency_order()}', flush=True)
        return graph

    def solve(self, problem):
        """Run all phases for ``problem`` and return a SolutionResult.

        Never raises: any exception is logged with its traceback and turned
        into an unsuccessful result. If classification had already succeeded
        when the error occurred, that classification is kept in the result;
        otherwise a placeholder ('unknown', medium) is used.
        """
        start_time = time.time()
        print('=' * 60, flush=True)
        print(f'PROBLEM: {problem[:80]}...', flush=True)
        print('=' * 60, flush=True)
        classification = None
        try:
            classification = self.classify_problem(problem)
            graph = self.build_lemma_graph(problem, classification)
            print('\n=== PHASE 3: SYNTHESIS ===', flush=True)
            lemma_summaries = [
                f'{lid}: {lemma.statement[:40]}'
                for lid, lemma in graph.lemmas.items() if lid != 'FINAL'
            ]
            # The answer parser only recognises \boxed{<integer in 0..99999>},
            # so the prompt has to request exactly that form.
            prompt = (
                f'Synthesize from: {lemma_summaries}\nProblem: {problem}\n'
                'Give the final answer as a single integer between 0 and 99999 '
                'inside \\boxed{}.'
            )
            response = self.llm.generate('You are a mathematical synthesizer.', prompt)
            answer = self.parser.extract_answer(response)
            time_taken = time.time() - start_time
            print(f'[ALG] Result: {answer}, Time: {time_taken:.1f}s', flush=True)
            return SolutionResult(problem, classification, answer, answer is not None, time_taken)
        except Exception as e:
            print(f'[ALG] ERROR: {e}', flush=True)
            traceback.print_exc()
            if classification is None:
                classification = ProblemClassification('unknown', Complexity.MEDIUM)
            return SolutionResult(problem, classification,
                                  None, False, time.time() - start_time)

print('ALG Solver defined')


In [None]:
# ============================================================
# SERVER MANAGER
# ============================================================

class ServerManager:
    """Starts, awaits and stops the local vLLM OpenAI-compatible server."""

    def __init__(self, cfg):
        """Store the configuration; nothing is started yet."""
        self.cfg = cfg
        self.server_process = None
        self.log_file = None

    def preload_model(self):
        """Read every file under ``cfg.model_path`` once, using 16 threads.

        The data is discarded; presumably this warms the OS page cache so the
        server loads the weights faster (TODO confirm).
        """
        print(f'Loading model from {self.cfg.model_path}...')
        start = time.time()
        files = []
        for root, _, names in os.walk(self.cfg.model_path):
            for name in names:
                path = os.path.join(root, name)
                if os.path.isfile(path):
                    files.append(path)

        def read_file(path):
            # Read in 1 GiB chunks until EOF, discarding the data.
            with open(path, 'rb') as f:
                while f.read(1024 * 1024 * 1024):
                    pass

        with ThreadPoolExecutor(max_workers=16) as exe:
            list(exe.map(read_file, files))
        print(f'Loaded {len(files)} files in {time.time()-start:.1f}s\n')

    def start_server(self):
        """Launch the vLLM API server as a subprocess.

        Server output goes to ``vllm_server.log`` in the working directory.

        Returns:
            The ``subprocess.Popen`` handle (also stored in
            ``self.server_process``).
        """
        cmd = [sys.executable, '-m', 'vllm.entrypoints.openai.api_server',
               '--model', self.cfg.model_path,
               '--served-model-name', self.cfg.served_model_name,
               '--host', '0.0.0.0', '--port', str(self.cfg.server_port),
               '--tensor-parallel-size', '1',
               '--max-model-len', str(self.cfg.context_tokens),
               '--gpu-memory-utilization', str(self.cfg.gpu_memory_utilization),
               '--kv-cache-dtype', self.cfg.kv_cache_dtype,
               '--disable-log-stats', '--enable-prefix-caching']
        self.log_file = open('vllm_server.log', 'w')
        try:
            process = subprocess.Popen(cmd, stdout=self.log_file, stderr=subprocess.STDOUT)
        except Exception:
            # Do not leak the log handle when the process cannot be spawned.
            self.log_file.close()
            self.log_file = None
            raise
        self.server_process = process
        return process

    def wait_for_server(self, client, timeout=180):
        """Poll the server until it answers ``client.models.list()``.

        Args:
            client: OpenAI-style client pointed at the server.
            timeout: Maximum time to wait, in seconds of elapsed time.

        Raises:
            RuntimeError: ``'Server died'`` if the process has exited,
                ``'Server timeout'`` if it is not ready before the deadline.
        """
        print('Waiting for vLLM server...')
        start = time.monotonic()
        # Deadline based on elapsed time: a failed probe can itself take
        # several seconds, so counting attempts would overshoot the timeout.
        deadline = start + timeout
        while time.monotonic() < deadline:
            if self.server_process.poll() is not None:
                raise RuntimeError('Server died')
            try:
                client.models.list()
            except Exception:
                # Not ready yet; KeyboardInterrupt/SystemExit still propagate.
                time.sleep(1)
                continue
            print(f'Server ready in {time.monotonic()-start:.1f}s\n')
            return
        raise RuntimeError('Server timeout')

    def stop_server(self):
        """Terminate the server process (if any) and close the log file.

        Falls back to ``kill()`` when the process does not exit within
        30 seconds of ``terminate()``. Safe to call repeatedly.
        """
        process, self.server_process = self.server_process, None
        if process is not None:
            process.terminate()
            try:
                process.wait(timeout=30)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()
        if self.log_file is not None:
            self.log_file.close()
            self.log_file = None

print('Server manager defined')


In [None]:
# ============================================================
# KAGGLE INTERFACE
# ============================================================

# Lazily created singletons; populated by initialize_solver() on first success.
_solver = None
_server_manager = None


def initialize_solver():
    """Return the shared ALGSolver, creating it (and the vLLM server) on first use.

    The module-level ``_solver`` / ``_server_manager`` are only assigned after
    the server is reachable and the solver is fully initialized, so a failed
    start never leaves a half-built solver behind for later calls. On failure
    the server process is stopped and the error is re-raised; the next call
    retries from scratch.

    Raises:
        Exception: Whatever server start-up or solver initialization raises
            (e.g. ``RuntimeError('Server timeout')``).
    """
    global _solver, _server_manager
    if _solver is not None:
        return _solver
    print('Initializing...')
    manager = ServerManager(CFG)
    manager.preload_model()
    manager.server_process = manager.start_server()
    try:
        probe_client = OpenAI(base_url=f'http://0.0.0.0:{CFG.server_port}/v1', api_key='sk-local')
        manager.wait_for_server(probe_client, CFG.server_timeout)
        solver = ALGSolver(CFG)
        solver.initialize()
    except Exception:
        # Do not leave an orphaned server process holding the GPU.
        manager.stop_server()
        raise
    _server_manager = manager
    _solver = solver
    return _solver

def predict(id_, question):
    """Solve one problem and return a one-row frame with its answer.

    Args:
        id_: Container whose ``.item(0)`` yields the problem id
            (presumably a polars Series supplied by the gateway - confirm).
        question: Container whose ``.item(0)`` yields the problem text.

    Returns:
        ``pl.DataFrame`` with columns ``id`` and ``answer``; the answer is 0
        when the solver produced none.
    """
    id_value, question_text = id_.item(0), question.item(0)

    banner = '='*60
    print('\n' + banner)
    print(f'PROBLEM ID: {id_value}')
    print(banner)

    result = initialize_solver().solve(question_text)
    # Missing answer (None) is submitted as 0.
    answer = result.answer or 0
    print(f'\nSUBMITTING: {answer}')
    return pl.DataFrame({'id': id_value, 'answer': int(answer)})


In [None]:
# ============================================================
# MAIN
# ============================================================

# `or True` makes this condition always true, so the block runs however the
# file is executed.
if __name__ == '__main__' or True:
    if os.path.exists('/kaggle'):
        # Hand predict() to the competition inference server.
        server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(predict)
        if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
            # Competition rerun: serve prediction requests.
            server.serve()
        else:
            # Otherwise run the local gateway against the bundled test.csv.
            server.run_local_gateway(('/kaggle/input/ai-mathematical-olympiad-progress-prize-3/test.csv',))
    else:
        print('Not on Kaggle')
