# In [None]:

import os
# Environment switches, set before the transformers / vLLM imports below so they
# are in place when those libraries initialise (presumably read at import time).
os.environ.update({
    "TRANSFORMERS_NO_TF": "1",           # skip TensorFlow integration in transformers
    "TRANSFORMERS_NO_FLAX": "1",         # skip Flax/JAX integration in transformers
    "TRITON_PTXAS_PATH": "/usr/local/cuda/bin/ptxas",  # ptxas binary Triton should use
    "CUDA_VISIBLE_DEVICES": "0",         # expose only the first GPU
    "TOKENIZERS_PARALLELISM": "false",   # disable HF tokenizers' internal parallelism
})

import time
import re
import tempfile
import subprocess
from collections import Counter, defaultdict
from typing import Optional

import pandas as pd
import polars as pl
from transformers import set_seed
import torch
from vllm import LLM, SamplingParams
import kaggle_evaluation.aimo_3_inference_server

# Seed the RNGs covered by transformers.set_seed.
set_seed(42)
# Print full cell contents (e.g. long problem statements) when displaying pandas frames.
pd.set_option('display.max_colwidth', None)

# Wall-clock budget; predict_for_question stops calling the model after cutoff_time.
_TIME_BUDGET_SECONDS = (4 * 60 + 45) * 60  # 4 h 45 min
cutoff_time = time.time() + _TIME_BUDGET_SECONDS


import os
import shutil

# Stage the weights in /kaggle/tmp, which (per the original author) is not
# subject to the 20 GB limit.
model_dir = "/kaggle/tmp/phi4-reasoning-plus"
os.makedirs(model_dir, exist_ok=True)

# The checkpoint is split across two Kaggle input datasets; merge their
# top-level files into model_dir. An existing destination file is reused only
# when its size matches the source, so a shard truncated by an interrupted
# earlier copy is copied again instead of being loaded corrupt.
_MODEL_PARTS = ["phi-4-reasoning-plus-1-2", "phi-4-reasoning-plus-2-2"]
for part in _MODEL_PARTS:
    part_dir = f"/kaggle/input/{part}"
    if not os.path.isdir(part_dir):
        # Make a missing dataset visible here rather than as an obscure load error later.
        print(f"Warning: model part not found: {part_dir}")
        continue
    for name in sorted(os.listdir(part_dir)):
        src = os.path.join(part_dir, name)
        dst = os.path.join(model_dir, name)
        if not os.path.isfile(src):
            continue  # only top-level files are copied; subdirectories are skipped
        if os.path.isfile(dst) and os.path.getsize(dst) == os.path.getsize(src):
            continue  # already staged by a previous run
        shutil.copy2(src, dst)
        print(f"Copied: {name}")

# Report what ended up in the staging directory.
print("\nModel files:")
for name in sorted(os.listdir(model_dir)):
    size_gb = os.path.getsize(os.path.join(model_dir, name)) / 1024**3
    print(f"  {name}: {size_gb:.2f} GB")

# Load from the staged copy. Reuse model_dir so the copy destination and the
# load path cannot drift apart.
LLM_MODEL_PATH = model_dir

# vLLM engine for Phi-4-reasoning-plus on a single GPU.
llm = LLM(
    LLM_MODEL_PATH,
    dtype="bfloat16",
    max_num_seqs=256,             # cap on sequences batched together by the engine
    max_model_len=32768,          # per-sequence context length (prompt + generation)
    trust_remote_code=True,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.96,  # fraction of GPU memory vLLM may reserve
)

# NOTE(review): not used by the visible code; kept for interactive inspection.
tokenizer = llm.get_tokenizer()

# Sampling settings shared by every generate() call.
sampling_params = SamplingParams(
    temperature=0.8,
    min_p=0.01,
    skip_special_tokens=True,
    # NOTE(review): equals max_model_len, so prompt + max_tokens exceeds the
    # context window and long generations end at the context limit rather
    # than at this cap. Confirm this is intended.
    max_tokens=32768,
)

# System prompt for Phi-4-Reasoning-Plus (noted by the author as Microsoft's
# recommended prompt). It asks for a <think> ... </think> section followed by a
# solution whose final answer is wrapped in \boxed{}.
SYSTEM_PROMPT = """You have the role of a math professor exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> {Thought section} </think> {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Put your final numerical answer in \\boxed{}."""

# Prompt variants: the bare system prompt plus four extra nudges. Each variant
# becomes its own conversation per question, giving diverse samples to vote over.
_PROMPT_SUFFIXES = (
    "",
    " Double-check your arithmetic carefully.",
    " Consider multiple approaches before settling on an answer.",
    " Verify your solution by substituting back.",
    " Break down the problem into smaller parts.",
)
thoughts = [SYSTEM_PROMPT + suffix for suffix in _PROMPT_SUFFIXES]

# Unsigned integer or decimal; used when a boxed group is not a bare integer.
_BOXED_NUMBER_RE = re.compile(r"\d+(?:\.\d+)?")


def _iter_boxed_contents(text: str):
    """Yield the brace-balanced contents of every ``boxed{...}`` group in *text*.

    The search key is ``oxed{`` (the same anchor the original regex used), so it
    finds the LaTeX ``boxed{`` command. Brace depth is tracked, so nested groups
    such as ``boxed{text{n}=42}`` come back whole instead of being cut at the
    first closing brace. A group whose braces never balance is skipped, and the
    scan resumes just after its opening brace.
    """
    marker = "oxed{"
    pos = text.find(marker)
    while pos != -1:
        start = pos + len(marker)
        depth = 1
        end = start
        while end < len(text):
            char = text[end]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    break
            end += 1
        if depth == 0:
            yield text[start:end]
            pos = text.find(marker, end + 1)
        else:
            pos = text.find(marker, start)


def extract_boxed_answers(text: str) -> list[int]:
    """Return one integer per ``boxed{...}`` answer found in *text*, in order.

    Each boxed group is normalised as follows:
      * thousands separators are removed: ``,``, LaTeX ``{,}``, the LaTeX thin
        space (backslash-comma), and plain spaces;
      * a bare digit string is used as-is;
      * otherwise the LAST number in the group is used. This is a heuristic;
        e.g. ``10^{3}`` yields 3;
      * a decimal is accepted only when its fractional part is all zeros
        (``12.0`` -> 12), because a fractional value cannot be an integer answer;
      * groups with no usable number are skipped.

    Returns an empty list when *text* contains no usable boxed answer.
    """
    ans = []
    for content in _iter_boxed_contents(text):
        # "{,}" and the thin space must go before the plain "," replacement,
        # otherwise they would leave stray "{}" / backslash characters behind.
        content = (
            content.replace("{,}", "")
            .replace("\\,", "")
            .replace(",", "")
            .replace(" ", "")
        )
        if content.isdigit():
            num = content
        else:
            numbers = _BOXED_NUMBER_RE.findall(content)
            if not numbers:
                continue
            whole, _, fraction = numbers[-1].partition(".")
            if fraction.strip("0"):
                continue  # genuinely fractional: cannot be an integer answer
            num = whole
        try:
            ans.append(int(num))
        except ValueError:
            # str.isdigit() accepts characters (e.g. superscripts) that int() rejects.
            continue
    return ans

def select_answer(answers: list, default: int = 49) -> int:
    """Return the most common valid answer in *answers* (majority vote).

    A candidate is valid when it converts to ``int`` without loss
    (``int(a) == float(a)``) and lies in ``[0, 99999]``. Candidates that cannot
    be converted (wrong type, non-numeric string, too large for ``float``) are
    skipped. Ties go to the value encountered first, per ``Counter.most_common``.

    Args:
        answers: Candidate answers, typically ints.
        default: Value returned when no candidate is valid (49, as before).
    """
    valid_answers = []
    for answer in answers:
        try:
            as_int = int(answer)
            if as_int != float(answer):
                continue  # fractional values such as 2.5 are not integer answers
        except (TypeError, ValueError, OverflowError):
            continue
        if 0 <= as_int <= 99999:
            valid_answers.append(as_int)
    if not valid_answers:
        return default
    answer, _ = Counter(valid_answers).most_common(1)[0]
    return answer

# Fenced ```python blocks; DOTALL lets the body span multiple lines.
_PYTHON_FENCE_RE = re.compile(r'```python\s*(.*?)\s*```', re.DOTALL)


def extract_python_code(text: str) -> list[str]:
    """Return the bodies of all fenced ```python blocks in *text*, in order.

    Surrounding whitespace inside each fence is excluded from the returned body.
    """
    return _PYTHON_FENCE_RE.findall(text)

def process_python_code(query):
    """Prefix *query* with the standard math imports and strip outer whitespace.

    The prelude brings in ``math``, ``numpy as np``, ``sympy as sp`` and every
    name exported by ``from sympy import *``.
    """
    prelude = "\n".join((
        "import math",
        "import numpy as np",
        "import sympy as sp",
        "from sympy import *",
    ))
    return (prelude + "\n" + query).strip()

class PythonREPL:
    """Run a Python snippet in a fresh interpreter subprocess with a timeout.

    Calling an instance returns ``(success, text)``:
      * ``(True, stdout)`` when the process exits with status 0;
      * ``(False, stderr)`` on a non-zero exit;
      * ``(False, "Timeout after {timeout}s")`` when it runs too long.
    Output text is stripped of surrounding whitespace.
    """

    def __init__(self, timeout=8):
        # Seconds a snippet may run before it is killed.
        self.timeout = timeout

    def __call__(self, query):
        import sys  # local import keeps this block self-contained

        # Use the interpreter running this notebook, so the snippet sees the same
        # installed packages (numpy/sympy) rather than whatever "python3" is on PATH.
        interpreter = sys.executable or "python3"
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(query)
            try:
                result = subprocess.run(
                    [interpreter, temp_file_path],
                    capture_output=True,
                    check=False,
                    text=True,
                    timeout=self.timeout,
                    # Run inside the throwaway directory so any files the
                    # snippet writes are deleted along with it.
                    cwd=temp_dir,
                    # A snippet calling input() gets EOF immediately instead of
                    # blocking on the notebook's stdin until the timeout.
                    stdin=subprocess.DEVNULL,
                )
            except subprocess.TimeoutExpired:
                return False, f"Timeout after {self.timeout}s"
            if result.returncode == 0:
                return True, result.stdout.strip()
            return False, result.stderr.strip()

# A batch of chat conversations; each conversation is a list of
# {"role": ..., "content": ...} message dicts.
MessagesBatch = list[list[dict[str, str]]]

# Maps each boxed answer to the prompt indices that produced it
# (appended to by batch_message_filter).
answer_contributions: defaultdict = defaultdict(list)

def format_phi4_prompt(system: str, user: str) -> str:
    """Render a system + user exchange in Phi-4 chat markup, ending with an open assistant turn.

    Each turn is emitted as ``<|im_start|>{role}<|im_sep|>`` + newline + content +
    ``<|im_end|>`` + newline. The final assistant header is left open for the
    model to complete.

    NOTE(review): the newlines after ``<|im_sep|>`` / ``<|im_end|>`` may not match
    the tokenizer's bundled chat template; compare against
    ``tokenizer.apply_chat_template`` before relying on exact formatting.
    """
    turns = (("system", system), ("user", user))
    rendered = "".join(
        f"<|im_start|>{role}<|im_sep|>\n{content}<|im_end|>\n" for role, content in turns
    )
    return rendered + "<|im_start|>assistant<|im_sep|>\n"

def batch_message_generate(msg_batch: MessagesBatch) -> MessagesBatch:
    """Generate one assistant reply per conversation using the global ``llm``.

    Only the first two messages of each conversation (system, user) are
    rendered into the prompt, so later turns do not influence generation.
    The text of the first completion for each prompt is appended to its
    conversation in place, and the same ``msg_batch`` list is returned.
    """
    prompts = [
        format_phi4_prompt(conversation[0]['content'], conversation[1]['content'])
        for conversation in msg_batch
    ]
    outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)
    for conversation, output in zip(msg_batch, outputs):
        reply = output.outputs[0].text
        conversation.append({'role': 'assistant', 'content': reply})
    return msg_batch

def batch_message_filter(msg_batch: MessagesBatch, list_of_idx: list[int]):
    """Split conversations by whether their latest reply contains a boxed answer.

    Args:
        msg_batch: Conversations whose last message is the newest reply.
        list_of_idx: Prompt index for each conversation, in the same order.

    Returns:
        ``(msgs_to_keep, extracted_answers, idx_to_keep)``. The first and last
        items hold the conversations (and their indices) with NO boxed answer;
        the middle item holds every boxed answer found, in order.

    Side effect: each boxed answer records its prompt index in the global
    ``answer_contributions``.
    """
    extracted_answers = []
    unanswered_msgs = []
    unanswered_idx = []
    for prompt_idx, conversation in zip(list_of_idx, msg_batch):
        found = extract_boxed_answers(conversation[-1]['content'])
        if not found:
            unanswered_msgs.append(conversation)
            unanswered_idx.append(prompt_idx)
            continue
        extracted_answers.extend(found)
        for value in found:
            answer_contributions[value].append(prompt_idx)
    return unanswered_msgs, extracted_answers, unanswered_idx

def batch_execute_and_get_answer(list_of_messages: MessagesBatch) -> list[int]:
    """Execute the ```python blocks in each latest reply and collect their results.

    Every fenced block in each conversation's last message is run (with the
    standard prelude from ``process_python_code``) in a fresh ``PythonREPL``.
    For each successful run, only the LAST integer in its stdout is kept. This
    is the snippet's final printed value. Keeping every printed integer would
    let a snippet that prints intermediate values flood the majority vote.
    Failed, timed-out, or number-free runs contribute nothing.
    """
    ans = []
    for messages in list_of_messages:
        for python_code in extract_python_code(messages[-1]['content']):
            try:
                success, output = PythonREPL()(process_python_code(python_code))
            except Exception as exc:  # best-effort boundary around model-written code
                print(f"Python execution failed: {exc!r}")
                continue
            if not success:
                continue
            numbers = re.findall(r'\d+', output)
            if numbers:
                ans.append(int(numbers[-1]))
    return ans

def predict_for_question(question: str, max_rounds: int = 1) -> int:
    """Answer one problem by sampling every prompt variant and majority-voting.

    Each round generates one reply per still-unanswered variant in ``thoughts``.
    It collects integers from the replies' executed python blocks and from their
    boxed answers, then drops variants that produced a boxed answer. Rounds stop
    early when no variants remain or the time budget is spent.

    Returns:
        ``select_answer`` over all collected candidates, or 210 without calling
        the model when ``cutoff_time`` has already passed.
    """
    # Per-question bookkeeping. Prompt indices restart at 0 for every question,
    # so without this reset entries from earlier questions pile up in the same
    # dict and become meaningless.
    answer_contributions.clear()
    if time.time() > cutoff_time:
        return 210  # time budget exhausted: fixed fallback answer

    msgs_batch: MessagesBatch = [
        [{"role": "system", "content": t}, {"role": "user", "content": question}]
        for t in thoughts
    ]

    all_extracted_answers = []
    list_of_idx = list(range(len(msgs_batch)))

    for _ in range(max_rounds):
        msgs_batch = batch_message_generate(msgs_batch)
        # Run code from every reply of this round BEFORE filtering, so replies
        # that also contain a boxed answer still contribute their code results.
        extracted_python_answer = batch_execute_and_get_answer(msgs_batch)
        msgs_batch, extracted_answers, list_of_idx = batch_message_filter(msgs_batch, list_of_idx)
        all_extracted_answers.extend(extracted_python_answer)
        all_extracted_answers.extend(extracted_answers)
        if not msgs_batch or time.time() > cutoff_time:
            break

    return select_answer(all_extracted_answers)

def predict(id_: pl.DataFrame, question: pl.DataFrame, answer: Optional[pl.DataFrame] = None) -> pl.DataFrame:
    """Inference-server callback: answer the single problem in this request.

    The first value of ``id_`` and of ``question`` is used; ``answer`` is
    accepted for interface compatibility and ignored. Returns a polars
    DataFrame with columns ``id`` and ``answer``.

    NOTE(review): ``.item(0)`` with a single positional index is the
    ``pl.Series.item`` API, so the ``pl.DataFrame`` annotations look inaccurate.
    Confirm what the gateway actually passes.
    """
    problem_id = id_.item(0)
    predicted = predict_for_question(question.item(0))
    return pl.DataFrame({'id': problem_id, 'answer': predicted})

# Wrap `predict` in the competition's inference server.
inference_server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(predict)

if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    # Rerun flag absent: drive predict() through a local gateway over the
    # public test CSV.
    inference_server.run_local_gateway(
        ('/kaggle/input/ai-mathematical-olympiad-progress-prize-3/test.csv',)
    )
else:
    # Competition rerun (flag set): serve requests from the gateway.
    inference_server.serve()