# MCTS-Stepwise Reasoning: Efficient Tree Search for LLM Reasoning with Adaptive Width and Self-Refine Guidance

## Abstract

Large Language Models (LLMs) have demonstrated remarkable reasoning capabilities, yet they often struggle with complex multi-step problems due to their linear decoding nature. Tree-based search methods like Tree-of-Thoughts (ToT) and Monte Carlo Tree Search (MCTS) have been proposed to enhance reasoning by exploring multiple solution paths. However, existing methods suffer from high computational costs, large action spaces, and reliance on fine-tuned models or auxiliary embeddings. We introduce **MCTS-Stepwise Reasoning**, a novel inference-time algorithm that combines Monte Carlo Tree Search with step-wise decomposition and dynamic expansion control. Our method leverages self-refine guidance to generate diverse yet plausible reasoning steps without repetitive sampling, uses fixed-length segmentation to avoid fine-tuning requirements, and introduces an adaptive child limit that grows with node visits to balance exploration and exploitation under massive action spaces. Experiments on the AIME25 subset of the srt-test dataset show that MCTS-Stepwise Reasoning boosts baseline accuracy from 76.7% to 93.3% with only 7 MCTS invocations, achieving a 71.4% correction rate on initially incorrect problems. The framework uses only standard API completions, making it model-agnostic and easily deployable.

## 1. Introduction

Recent advances in Large Language Models (LLMs) have enabled impressive performance on various reasoning tasks. However, even state-of-the-art models often fail on multi-step problems that require careful planning and backtracking. The standard autoregressive decoding produces a single linear chain of thought, which may prematurely commit to a suboptimal path.

To address this, several works have proposed tree-based search algorithms that explore multiple reasoning trajectories. **Tree-of-Thoughts (ToT)** [1] maintains a tree of thoughts and uses a breadth-first search with pruning based on LLM self-evaluation. **MCTSr** [2] adapts Monte Carlo Tree Search to reasoning by treating each reasoning step as a node and using self-critique as reward. While promising, these methods face practical challenges:

- **High computational overhead**: ToT requires generating multiple candidate thoughts at each step and often relies on embedding similarity to deduplicate semantically equivalent thoughts, adding extra API calls and latency.
- **Fine-tuning sensitivity**: Forcing LLMs to output structured formats (e.g., JSON) without fine-tuning can degrade performance, as models are not trained to follow rigid schemas.
- **Combinatorial explosion**: The action space (possible next reasoning steps) is enormous. Traditional UCT [3] struggles to balance exploration and exploitation when each node has potentially hundreds of children.

We propose **MCTS-Stepwise Reasoning**, a lightweight tree search algorithm designed for LLM reasoning that addresses these limitations through three key innovations:

1. **Self-Refine Guidance**: Instead of sampling multiple independent thoughts at a node, we generate a single new continuation conditioned on the best prior answer from sibling branches and its critique. This self-refine process naturally yields diverse yet relevant reasoning paths without redundant sampling or deduplication.

2. **Fixed-Length Step Segmentation**: We segment the generated answer into fixed-length chunks by token count, so the LLM never has to output special step markers. Because segmentation is applied after generation, the model never sees the chunk boundaries, and they do not constrain how it reasons.

3. **Adaptive Child Limit**: To handle the huge action space, we introduce a dynamic cap on the number of children per node. The maximum number of children grows with the node's visit count: `max_children(node) = max(1, min(global_max, floor(k * visit_count^alpha)))`. This ensures that nodes are only allowed to branch widely after they have been sufficiently explored, effectively balancing exploration and exploitation.

Our method uses only standard API completions (no logits, embeddings, or structured outputs), making it readily applicable to any LLM without modification. We evaluate on the AIME25 subset of the srt-test dataset, a challenging collection of math problems. Results demonstrate that MCTS-Stepwise Reasoning significantly improves accuracy over direct generation, with modest additional computation.

## 2. Related Work

**Chain-of-Thought (CoT)** [4] elicits reasoning by prompting the model to "think step by step" before giving the final answer. While effective, CoT is still a single-path approach.

**Tree-of-Thoughts (ToT)** [1] maintains a tree of thoughts, where each node represents a partial solution. At each step, the model proposes several candidate thoughts, which are then evaluated and pruned. ToT requires careful prompt design for thought generation and evaluation, and often uses embedding similarity to merge duplicate thoughts.

**MCTSr** [2] adapts MCTS to reasoning by treating each reasoning step as a node and using self-critique scores as rewards. It uses a fixed number of children per node and relies on UCT for selection. However, MCTSr still generates multiple candidate steps independently, which can be inefficient.

**Self-Refine** [5] iteratively improves an answer by generating critiques and revisions. Our work integrates self-refine into the tree search: when expanding a node, we provide the best previous answer from sibling branches along with its critique to guide the generation of a new solution.

**Dynamic Expansion in MCTS** has been explored in game playing (e.g., Progressive Widening [6]) to handle large branching factors. We adapt this idea to LLM reasoning by making the child limit a function of visit count, allowing the tree to grow wider only after sufficient exploration.

## 3. Method

### 3.1 Overview

MCTS-Stepwise Reasoning builds a tree where each node contains a partial reasoning step (a chunk of text). The root node is empty. Starting from the root, the algorithm iteratively performs four phases: **Selection**, **Expansion**, **Evaluation**, and **Backpropagation** until a stopping criterion is met.

The key difference from standard MCTS lies in the expansion phase: instead of randomly sampling multiple next steps, we generate a single new reasoning chain by self-refining based on the best existing sibling path. This generates one new leaf per expansion, and the number of children per node is capped dynamically.

### 3.2 Tree Node Structure

Each node stores:

- `parent_idx`: index of parent node (-1 for root)
- `content`: the text of this reasoning step (fixed-length token chunk)
- `visit_count`: number of times this node has been visited
- `children_indices`: list of child node indices
- `Q_value`: the value estimate for this node (for a leaf, the minimum of its reward samples; for an internal node, the maximum Q among its children, see Section 3.8)
- `reward_samples`: list of reward scores from multiple evaluations
- `critique`: a textual critique of the complete answer ending at this leaf
- `fully_expanded`: whether the node has reached its current child limit
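These fields map onto a flat list of nodes that refer to each other by index, which is also how the `MCTS_NODE` class in the code below is used. The following is a minimal dataclass sketch of the same record; the class name `Node` is illustrative, and the comment on `Q_value` folds in the backpropagation rule of Section 3.8.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """One tree node; the fields mirror the list above."""
    parent_idx: int          # -1 for the root
    content: str             # one fixed-length token chunk ("" for the root)
    visit_count: int = 0
    children_indices: List[int] = field(default_factory=list)
    Q_value: float = 0.0     # leaf: min of reward samples; internal node: max Q over children
    reward_samples: List[float] = field(default_factory=list)
    critique: str = ""       # critique of the complete answer ending at this leaf
    fully_expanded: bool = False

    def is_leaf(self) -> bool:
        return not self.children_indices


# The tree is a flat list; index 0 is the root.
tree: List[Node] = [Node(parent_idx=-1, content="")]
tree.append(Node(parent_idx=0, content="First chunk of an answer..."))
tree[0].children_indices.append(len(tree) - 1)
```

`field(default_factory=list)` gives every node its own list; `dataclasses` rejects a bare `[]` default for exactly this reason.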

### 3.3 Step Decomposition

We avoid asking the model to output structured step markers. Instead, we generate a complete answer and then split it post-hoc, using a tokenizer, into fixed-length segments by token count (256 tokens by default; 512 in our experiments). Whenever the segments are fed back to the model, as a prefix or for evaluation, they are concatenated into plain text, so the model never sees the boundaries and its reasoning is not constrained by them.

Formally, given a complete answer `A`, we tokenize it and split the token sequence into chunks `[c1, c2, ..., ck]`, each of at most `step_length` tokens. These chunks become a chain of nodes from the parent node down to a new leaf.
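The segmentation and the way the resulting chunks hang under the expanded node can be sketched as follows. The tokenizer is passed in as `encode`/`decode` callables; the toy character-level tokenizer is a stand-in for the tiktoken `cl100k_base` encoding the code uses, and the dict-based nodes and helper names are hypothetical.

```python
from typing import Callable, Dict, List, Sequence


def split_into_steps(answer: str,
                     encode: Callable[[str], Sequence[int]],
                     decode: Callable[[Sequence[int]], str],
                     step_length: int) -> List[str]:
    """Split an answer into consecutive chunks of at most step_length tokens."""
    tokens = list(encode(answer))
    if len(tokens) <= step_length:
        return [answer]
    return [decode(tokens[i:i + step_length])
            for i in range(0, len(tokens), step_length)]


def attach_chain(tree: List[Dict], parent_idx: int, steps: List[str]) -> int:
    """Append the steps as a parent -> child chain under parent_idx; return the leaf index."""
    current = parent_idx
    for text in steps:
        tree.append({"parent_idx": current, "content": text, "children_indices": []})
        new_idx = len(tree) - 1
        tree[current]["children_indices"].append(new_idx)
        current = new_idx
    return current


# Stand-in character-level tokenizer so the sketch runs without tiktoken.
def char_encode(text: str) -> List[int]:
    return [ord(ch) for ch in text]


def char_decode(ids: Sequence[int]) -> str:
    return "".join(chr(i) for i in ids)


steps = split_into_steps("abcdefghij", char_encode, char_decode, step_length=4)
tree = [{"parent_idx": -1, "content": "", "children_indices": []}]
leaf = attach_chain(tree, 0, steps)
```

With a byte-level BPE tokenizer such as tiktoken, `decode` is lossy when a chunk boundary falls inside a multi-byte UTF-8 character (the fragments decode to U+FFFD), so rejoined chunks can differ from the original text in that case.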

### 3.4 Self-Refine Guided Expansion

When expanding a node `v`, we first prepare a context that includes:

- The original question.
- The best answer found so far in an existing sibling branch (if any), along with its critique. We retrieve it by taking the node's most recently expanded child and repeatedly descending into the highest-Q child until reaching a leaf; that leaf's complete answer (the concatenated text from the root) and its critique are used.

This context is combined with a prefix that contains the reasoning path leading to `v`. The LLM is then prompted to generate a **new complete answer** that addresses the critiques and improves upon previous attempts.
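The context assembly can be sketched as below. The prompt wording follows `generate_complete_answer` in the code; the dict-based tree, the helper names, and the `Critique:` layout of the previous answer are assumptions made for illustration. For reasoner models the code additionally prepends `"<think>\n"` to the prefix, which this sketch omits, and the `"prefix": True` flag is the assistant-prefix convention the code sends to the DeepSeek beta endpoint.

```python
from typing import Dict, List, Optional

Tree = List[Dict]  # node keys: parent_idx, content, Q_value, critique, children_indices


def path_text(tree: Tree, idx: int) -> str:
    """Concatenate node contents from the root down to idx."""
    parts = []
    while idx != -1:
        parts.append(tree[idx]["content"])
        idx = tree[idx]["parent_idx"]
    return "".join(reversed(parts))


def reference_leaf(tree: Tree, v: int) -> Optional[int]:
    """From v's most recent child, follow the highest-Q child down to a leaf."""
    if not tree[v]["children_indices"]:
        return None  # first expansion of v: no sibling branch to learn from
    idx = tree[v]["children_indices"][-1]
    while tree[idx]["children_indices"]:
        idx = max(tree[idx]["children_indices"], key=lambda c: tree[c]["Q_value"])
    return idx


def build_messages(tree: Tree, question: str, v: int) -> List[Dict]:
    user = f"Question: {question}\n"
    ref = reference_leaf(tree, v)
    if ref is not None:
        # Hypothetical layout for one previous answer plus its critique.
        previous = f"{path_text(tree, ref)}\nCritique: {tree[ref]['critique']}"
        user += ("\nHere are some previous responses to the question with their critiques:\n"
                 f"{previous}\n"
                 "\nPlease provide a NEW answer that addresses these critiques "
                 "and improves upon the previous attempts.\n")
    else:
        user += "\nPlease provide a complete answer.\n"
    messages = [{"role": "user", "content": user}]
    prefix = path_text(tree, v)
    if prefix:
        # Prefix completion: the model continues the partial reasoning path ending at v.
        messages.append({"role": "assistant", "content": prefix, "prefix": True})
    return messages


tree = [
    {"parent_idx": -1, "content": "", "Q_value": 60.0, "critique": "", "children_indices": [1]},
    {"parent_idx": 0, "content": "Let x = 2. ", "Q_value": 60.0, "critique": "",
     "children_indices": [2]},
    {"parent_idx": 1, "content": "So x + x = 5.", "Q_value": 60.0,
     "critique": "2 + 2 equals 4, not 5.", "children_indices": []},
]
msgs = build_messages(tree, "Let x = 2. What is x + x?", v=1)
```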

The generated answer is then split into steps, forming a new chain from `v` to a new leaf. The leaf is evaluated using multiple independent scoring calls (with temperature sampling) to obtain a set of reward samples; the Q-value of the leaf is set to the minimum of these samples (conservative estimate). A textual critique from the lowest-scoring evaluation is stored.

### 3.5 Adaptive Child Limit

To manage the enormous branching factor, we introduce a dynamic cap on the number of children a node can have. The effective maximum number of children for node `v` is:

```
effective_max(v) = max(1, min(global_max_children, floor(k * visit_count(v)^alpha)))
```

where `k` and `alpha` are hyperparameters and `global_max_children` is a hard upper bound (e.g., 3). The outer `max(1, ...)` guarantees that an unvisited node can still receive its first child; beyond that, a node gains additional children only after it has been visited sufficiently many times. Early in the search, nodes are forced to stay narrow, promoting deeper exploration; later, they can broaden to consider alternative strategies.

A node is marked `fully_expanded` only when it has reached `effective_max(v)` children. During backpropagation, we re-evaluate `effective_max` for each ancestor; if it increases and the node had fewer children than the new limit, we unmark `fully_expanded` to allow further expansions.
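Plugging in the hyperparameters from Section 4.1 (`k = 2.0`, `alpha = 0.5`, `global_max_children = 3`) shows how slowly a node widens. The sketch mirrors `get_effective_max_children` in the code, including its floor of one child; `is_fully_expanded` is a hypothetical restatement of the flag rule above.

```python
import math


def effective_max_children(visit_count: int, global_max: int = 3,
                           k: float = 2.0, alpha: float = 0.5) -> int:
    """max(1, min(global_max, floor(k * visit_count ** alpha)))"""
    dynamic_limit = math.floor(k * visit_count ** alpha)
    return max(1, min(global_max, dynamic_limit))


def is_fully_expanded(num_children: int, visit_count: int) -> bool:
    """Re-evaluated on every backpropagation pass, so the flag can be cleared again."""
    return num_children >= effective_max_children(visit_count)


limits = [effective_max_children(n) for n in range(6)]
print(limits)  # [1, 2, 2, 3, 3, 3]
```

The second child becomes available after one visit and the third after three; from then on `global_max_children` caps the width. A node with one child and no visits is fully expanded, but a single further visit clears the flag.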

### 3.6 Selection with Modified UCT

We use a variant of UCT where unexpanded nodes (those that have not yet reached their child limit) are assigned a virtual UCT value:

```
UCT_unexpanded(v) = c * sqrt( log(N(v) + 1) / epsilon )
```

where `N(v)` is the visit count of `v` itself (the would-be parent of the new child) and `epsilon` is a small constant (1e-6 in the code). During selection, if the best UCT among `v`'s existing children is lower than this virtual value, `v` is chosen for expansion; otherwise, we descend into the child with the highest UCT.
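A small sketch makes the size of the virtual term concrete. With `epsilon = 1e-6` the virtual value is about `1000 * c * sqrt(ln(N + 1))`, i.e. more than 1,100 for `c = 1.4` once `N >= 1`, which exceeds the UCT of any visited child when rewards lie in [0, 100]. Under these defaults, a visited node below its child limit is therefore effectively always expanded, and selection descends only through fully expanded nodes (unless a child has never been visited, whose UCT is infinite). The helper names are hypothetical; the code's `get_uct_value` divides by `visit_count + 1e-6` rather than by the visit count, which makes no practical difference for visited children.

```python
import math
from typing import List, Optional, Tuple

C = 1.4         # exploration constant (default c in the code)
EPSILON = 1e-6  # denominator for the virtual "not yet created" child


def uct(q: float, parent_visits: int, child_visits: int, c: float = C) -> float:
    if child_visits == 0:
        return math.inf  # unvisited children are tried first
    return q + c * math.sqrt(math.log(parent_visits + 1) / child_visits)


def virtual_uct(node_visits: int, c: float = C, eps: float = EPSILON) -> float:
    return c * math.sqrt(math.log(node_visits + 1) / eps)


def select_step(node_visits: int, fully_expanded: bool,
                children: List[Tuple[float, int]]) -> Optional[int]:
    """Return the position of the child to descend into, or None to expand this node.

    children holds one (Q_value, visit_count) pair per existing child.
    """
    if not children:
        return None  # a leaf is always expanded
    scores = [uct(q, node_visits, n) for q, n in children]
    best = max(range(len(children)), key=lambda i: scores[i])
    if not fully_expanded and virtual_uct(node_visits) > scores[best]:
        return None
    return best


children = [(95.0, 3), (40.0, 1)]
print(select_step(4, fully_expanded=False, children=children))  # None: expand the node
print(select_step(4, fully_expanded=True, children=children))   # 0: descend into the Q=95 child
```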

### 3.7 Evaluation

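Leaf nodes are evaluated by calling the LLM several times (three by default) with the same scoring prompt. The prompt asks for a critical evaluation of the complete answer that starts from 100 and deducts 5 points per identified issue; each returned score is clamped to [0, 100]. Samples whose score cannot be parsed are re-queried with a stricter format prompt and dropped if parsing still fails. The final score is the lowest score among the remaining samples, and the critique from that sample is retained.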
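Score parsing and the min-aggregation can be sketched with the same regular expressions the code uses in `evaluate_complete_answer` and `_extract_score_with_retry`. The helper names are hypothetical, and the sketch simply drops unparseable samples instead of re-querying them.

```python
import re
from typing import List, Optional, Tuple

SCORE_RE = re.compile(r"\[Score\]\s*(\d+)", re.IGNORECASE)
ANALYSIS_RE = re.compile(r"\[Analysis\](.*?)(?=\[Score\]|$)", re.DOTALL | re.IGNORECASE)


def parse_evaluation(text: str) -> Optional[Tuple[int, str]]:
    """Return (score clamped to [0, 100], analysis text), or None if no score is found."""
    score_match = SCORE_RE.search(text)
    if score_match is None:
        return None
    score = max(0, min(100, int(score_match.group(1))))
    analysis_match = ANALYSIS_RE.search(text)
    return score, (analysis_match.group(1).strip() if analysis_match else "")


def aggregate(samples: List[str]) -> Tuple[float, str]:
    """Keep the lowest parsed score and the critique that produced it."""
    parsed = [p for p in map(parse_evaluation, samples) if p is not None]
    if not parsed:
        return 0.0, "Unable to generate critique."
    score, critique = min(parsed, key=lambda p: p[0])
    return float(score), critique


samples = [
    "[Analysis] Notation is inconsistent (-5).\n[Score] 95",
    "[Analysis] Step 3 divides by zero (-5); the final value is unjustified (-5).\n[Score] 90",
    "The answer looks fine overall.",  # unparseable: dropped
]
q_value, critique = aggregate(samples)
```

Taking the minimum makes the leaf value conservative: one strict evaluator is enough to pull a flawed answer down.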

### 3.8 Backpropagation

After a leaf is evaluated, we backpropagate along the path to the root, updating `visit_count` and setting each node's Q-value to the maximum Q among its children (since the tree represents alternative reasoning paths, we take the optimistic view that the best child's value reflects the node's potential). We also re-check the `fully_expanded` status as described.
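The update rule can be sketched as follows, assuming dict-based nodes and hypothetical helper names; the `fully_expanded` re-check from Section 3.5 is left out for brevity.

```python
from typing import Dict, List


def backpropagate(tree: List[Dict], leaf_idx: int) -> None:
    """Count a visit on every node from the new leaf up to the root, and set each
    internal node's Q to the maximum Q among its children. The leaf keeps the
    minimum-of-samples Q assigned during evaluation."""
    idx = leaf_idx
    while idx != -1:
        node = tree[idx]
        node["visit_count"] += 1
        if node["children_indices"]:
            node["Q_value"] = max(tree[c]["Q_value"] for c in node["children_indices"])
        idx = node["parent_idx"]


def new_node(parent_idx: int, children: List[int], q: float = 0.0, visits: int = 0) -> Dict:
    return {"parent_idx": parent_idx, "children_indices": children,
            "Q_value": q, "visit_count": visits}


tree = [
    new_node(-1, [1, 3], visits=1),     # root
    new_node(0, [2]),                   # first chunk of a new answer
    new_node(1, [], q=70.0),            # new leaf, just evaluated
    new_node(0, [], q=85.0, visits=1),  # older sibling answer
]
backpropagate(tree, leaf_idx=2)
```

Every chunk node in a single answer's chain has exactly one child, so the whole chain inherits the leaf's Q, and the root tracks the best leaf value in the tree (85 here, from the older sibling).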

### 3.9 Termination and Answer Selection

The search runs for a fixed number of iterations or until a leaf with Q-value ≥ 90 is found. The final answer is the complete path to the leaf with the highest Q-value.
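The stopping rule and the final answer selection can be sketched as below. `iterate` is a hypothetical callable that performs one selection/expansion/evaluation/backpropagation pass and returns the index of the new leaf; the toy `fake_iterate` only attaches preset answers so the control flow can be run. The sketch assumes the root is node 0.

```python
from typing import Callable, Dict, List


def path_text(tree: List[Dict], idx: int) -> str:
    """Concatenate node contents from the root down to idx."""
    parts = []
    while idx != -1:
        parts.append(tree[idx]["content"])
        idx = tree[idx]["parent_idx"]
    return "".join(reversed(parts))


def best_leaf(tree: List[Dict]) -> int:
    leaves = [i for i, node in enumerate(tree) if i != 0 and not node["children_indices"]]
    return max(leaves, key=lambda i: tree[i]["Q_value"])


def run_search(tree: List[Dict], iterate: Callable[[List[Dict]], int],
               max_iterations: int = 12, target_q: float = 90.0) -> str:
    for _ in range(max_iterations):
        leaf = iterate(tree)
        if tree[leaf]["Q_value"] >= target_q:
            break  # early stop: a sufficiently good answer was found
    return path_text(tree, best_leaf(tree))


# Toy iterate(): attaches one complete answer per call under the root with a preset Q.
preset_q = iter([60.0, 92.0, 99.0])


def fake_iterate(tree: List[Dict]) -> int:
    q = next(preset_q)
    tree.append({"parent_idx": 0, "content": f"answer scored {q}", "Q_value": q,
                 "children_indices": []})
    tree[0]["children_indices"].append(len(tree) - 1)
    return len(tree) - 1


tree = [{"parent_idx": -1, "content": "", "Q_value": 0.0, "children_indices": []}]
final_answer = run_search(tree, fake_iterate)
```

The search stops after the second pass because its leaf reaches Q = 92, so the third preset answer is never generated.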

## 4. Experiments

### 4.1 Setup

We evaluate on the **srt-test dataset** (AIME25 subset), which contains 30 challenging math problems from the AIME competition. We use **DeepSeek-R1** (deepseek-reasoner) as the underlying LLM, accessed via API with `max_tokens=32768`. The hyperparameters are:

- `step_length = 512` tokens
- `global_max_children = 3`
- `k = 2.0`, `alpha = 0.5`
- `baseline_temperature = 0.0` (for direct answer)
- `explore_temperature = 0.8` (for MCTS expansion)
- `evaluate_temperature = 0.7` (for scoring)
- MCTS iterations = 12

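We adopt a **baseline-first strategy**: a direct answer is generated first (temperature 0). If it is correct, we skip MCTS to save cost; otherwise, we run the full MCTS search. This spends expensive search only on problems the direct answer gets wrong. Note that the gate relies on the same reference-answer check used for scoring (below); in deployment, where no reference answer is available, it would have to be replaced by a reference-free criterion.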
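The gating logic is small enough to state as code. `generate`, `search`, and `is_correct` are hypothetical callables standing in for the temperature-0 baseline call, the MCTS run, and the answer check; passing `reference` as an argument makes visible which input the skip decision depends on.

```python
from typing import Callable, Dict


def baseline_first(question: str, reference: str,
                   generate: Callable[[str], str],
                   search: Callable[[str], str],
                   is_correct: Callable[[str, str], bool]) -> Dict:
    """Run MCTS only when the direct answer fails the correctness check."""
    baseline = generate(question)  # temperature 0 in the experiments
    if is_correct(baseline, reference):
        return {"answer": baseline, "mcts_invoked": False}
    return {"answer": search(question), "mcts_invoked": True}


calls = []


def search(question: str) -> str:
    calls.append(question)  # record which problems triggered the expensive search
    return "42"


result_ok = baseline_first("q1", "42", lambda q: "42", search, lambda a, r: a == r)
result_fix = baseline_first("q2", "42", lambda q: "41", search, lambda a, r: a == r)
```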

Correctness is determined by comparing extracted numeric answers after normalization (removing spaces, commas, and standardizing decimals). Manual inspection revealed that one problem (question 20) was actually correct in both baseline and MCTS but misclassified due to extraneous units; we report both raw and adjusted numbers.
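The normalization is described only in outline ("removing spaces, commas, and standardizing decimals"), so the following is one possible reading rather than the checker actually used. Taking the last `\boxed{...}` as the extracted answer is an assumption, as are the helper names.

```python
import re
from decimal import Decimal, InvalidOperation
from typing import Optional

BOXED_RE = re.compile(r"\\boxed\{([^{}]*)\}")


def extract_answer(text: str) -> Optional[str]:
    r"""Take the contents of the last \boxed{...} that has no nested braces."""
    matches = BOXED_RE.findall(text)
    return matches[-1] if matches else None


def normalize(answer: str) -> str:
    """Drop spaces and thousands separators, then canonicalize numeric values."""
    cleaned = answer.strip().replace(" ", "").replace(",", "")
    try:
        value = Decimal(cleaned)
    except InvalidOperation:
        return cleaned  # not a plain number: compare as text
    return format(value.normalize(), "f")  # "042", "42.0", "4.2E1" -> "42"


def is_correct(response: str, reference: str) -> bool:
    predicted = extract_answer(response)
    return predicted is not None and normalize(predicted) == normalize(reference)
```

An answer with a unit attached, such as `\boxed{70 degrees}`, fails this exact comparison even when the number is right, which is the kind of false negative described for question 20.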

### 4.2 Results

| Metric | Value |
|--------|-------|
| Total problems | 30 |
| Baseline correct | 23 (76.7%) |
| MCTS correct | 28 (93.3%) |
| MCTS invoked (baseline wrong) | 7 |
| MCTS corrections among invoked | 5 (71.4%) |
| Total API calls | 256 |
| Total tokens | 5,121,029 |
| Average time per problem | 246 s |

**Adjusted accuracy** (correcting the false negative on Q20): Baseline 24/30 (80.0%), MCTS 29/30 (96.7%).

The results show that MCTS-Stepwise Reasoning corrects a majority of the initially incorrect answers (5 of 7) with a modest computational budget. The baseline-first strategy avoids search on problems the direct answer already solves: only 7 of the 30 problems required MCTS, yet overall accuracy improved by 5 problems, i.e., 16.7 percentage points, under both the raw and the adjusted counts.
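The percentages follow directly from the counts in the table. Computing them from the fractions, rather than subtracting already-rounded percentages, gives the same 5/30 gain under both accountings; the helper name is illustrative.

```python
def accuracy_summary(total: int, baseline_correct: int, final_correct: int) -> dict:
    """Accuracies and gain in percentage points, rounded to one decimal."""
    return {
        "baseline_acc": round(100 * baseline_correct / total, 1),
        "final_acc": round(100 * final_correct / total, 1),
        "gain_pp": round(100 * (final_correct - baseline_correct) / total, 1),
    }


raw = accuracy_summary(30, 23, 28)
adjusted = accuracy_summary(30, 24, 29)
# MCTS only runs when the raw check marks the baseline wrong, so each raw gain is an MCTS correction.
correction_rate = round(100 * (28 - 23) / 7, 1)
print(raw, adjusted, correction_rate)
```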

### 4.3 Analysis of Correction Cases

We examined the five problems where MCTS turned a wrong baseline into a correct answer. In most of them, the baseline contained a substantive error (e.g., a miscalculation, a missing step, or a logical flaw); in one (Problem 24, below), the error was in the answer format. The MCTS search, guided by self-refinement from the critique of the initial wrong answer, found a corrected reasoning path. For example:

- **Problem 12**: Baseline misapplied a formula; MCTS generated a revised solution after incorporating a critique about the incorrect assumption, leading to the correct numeric answer.

- **Problem 24**: Baseline gave an answer with units attached; MCTS produced a unitless numeric answer that matched the expected format.

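Two of the seven invoked problems were not corrected. One of them is question 20, which, as noted in Section 4.1, was in fact answered correctly by both the baseline and MCTS but rejected by the checker because of extraneous units. The other reflects a persistent conceptual misunderstanding that the self-refine process could not overcome within the iteration limit.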

### 4.4 Efficiency

The totals of 256 API calls and 5,121,029 tokens (about 20,000 tokens per call) are moderate given that they cover 30 baseline generations plus up to 12 MCTS iterations on each of the 7 problems that invoked search. The average time per problem (246 seconds) is dominated by API latency; with faster inference or local models, the overhead would be lower.

## 5. Discussion

### 5.1 Design Rationale

- **Self-Refine over Independent Sampling**: Traditional ToT generates multiple candidate thoughts independently, then prunes duplicates. This duplicates effort and requires embedding similarity checks. By generating a single new answer conditioned on the best existing sibling and its critique, we inherently produce a diverse path without redundancy, and the critique provides targeted guidance for improvement.

- **Fixed-Length Segmentation**: Many prior works force the model to output structured markers (e.g., "Step 1:"), which can degrade performance on models not fine-tuned for such formats. Our approach avoids this by segmenting post-hoc; the model is unaware of the chunk boundaries, so its reasoning remains natural.

- **Adaptive Child Limit**: In reasoning tasks, the branching factor is astronomical (any next sentence could be a new step). Traditional UCT would require exploring many children to get reliable estimates, which is infeasible. Our dynamic cap ensures that nodes only branch widely after they have been visited enough, effectively implementing a form of progressive widening that prioritizes depth over breadth initially.

- **Model-Agnostic**: By relying solely on standard completions, our framework can be applied to any LLM without special endpoints or fine-tuning. This broadens its applicability.

### 5.2 Limitations

- **Evaluation Cost**: Each leaf evaluation requires multiple API calls (default 3). While this provides robust scores, it adds cost. Future work could explore using a single scoring call with higher temperature or learned reward models.

- **Step Boundary Insensitivity**: While we argue that fixed boundaries don't harm reasoning, they may occasionally cut a thought mid-sentence. However, because we reconstruct the full answer during evaluation, the segmentation is invisible to the model; it only affects the tree structure. We observed no negative impact.

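- **Baseline-First Bias**: Because MCTS runs only on problems whose baseline answer fails the correctness check, the reported results cannot contain regressions by construction: an already-correct answer is never passed through search. Running MCTS on every problem could degrade some correct answers, and measuring how often that happens requires evaluating all problems with search enabled. The sample is also small (30 problems, 7 of which invoked MCTS).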

## 6. Conclusion

We presented MCTS-Stepwise Reasoning, a practical and efficient tree search algorithm for LLM reasoning. By integrating self-refine guidance, fixed-length segmentation, and adaptive child limits, we overcome key limitations of prior methods: high computational cost, need for structured outputs, and inability to handle massive action spaces. Experiments on challenging math problems demonstrate substantial accuracy gains with modest overhead, validating the design choices. The framework is model-agnostic and ready for deployment in applications requiring improved reasoning reliability.

## References

[1] Yao, S., et al. "Tree of Thoughts: Deliberate Problem Solving with Large Language Models." NeurIPS 2023.

[2] Zhang, D., et al. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte Carlo Tree Self-refine with LLaMa-3 8B." arXiv preprint arXiv:2406.07394, 2024.

[3] Kocsis, L., & Szepesvári, C. "Bandit based Monte-Carlo Planning." ECML 2006.

[4] Wei, J., et al. "Chain-of-Thought Prompting Elicits Reasoning in Large Language Models." NeurIPS 2022.

[5] Madaan, A., et al. "Self-Refine: Iterative Refinement with Self-Feedback." NeurIPS 2023.

[6] Chaslot, G., et al. "Progressive Strategies for Monte-Carlo Tree Search." New Mathematics and Natural Computation 2008.

---

*Note: This work was conducted as part of my Final Year Project (FYP). The code and additional materials are available upon request.*

import pandas as pd
import time
import requests
import json
import random
import math
import numpy as np
import re
import os
import tiktoken
from typing import List, Dict, Optional, Tuple
from datasets import load_dataset
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock
import traceback
from datetime import datetime

class MCTS_NODE:
    def __init__(self, parent_idx: int, content: str):
        self.parent_idx = parent_idx
        self.content = content
        self.visit_count = 0
        self.children_indices = []
        self.Q_value = 0.0
        self.reward_samples = []
        self.critique = ""
        self.fully_expanded = False

    def is_leaf(self) -> bool:
        return len(self.children_indices) == 0

    def to_dict(self) -> dict:
        return {
            'parent_idx': self.parent_idx,
            'content': self.content,
            'visit_count': self.visit_count,
            'children_indices': self.children_indices,
            'Q_value': self.Q_value,
            'reward_samples': self.reward_samples,
            'critique': self.critique,
            'fully_expanded': self.fully_expanded,
            'is_leaf': self.is_leaf()
        }

class MCTS_STEPWISE_REASONING:
    def __init__(self, api_key: str, base_url: str = "https://api.deepseek.com/beta",
                 model: str = "deepseek-reasoner", max_children: int = 3, c: float = 1.4,
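                 step_length: int = 256, log_file: Optional[str] = None,
                 k: float = 2.0, alpha: float = 0.5,
                 # ── Added: max_tokens hyperparameter (reasoner supports up to 64k) ──
                 max_tokens: int = 16384,
                 # ── Added: a separate temperature hyperparameter for each stage ──
                 baseline_temperature: float = 0.0,   # baseline (direct answer) temperature
                 explore_temperature: float = 0.8,    # MCTS exploration/expansion temperature
                 evaluate_temperature: float = 0.7):  # evaluation/scoring temperature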
        self.api_key = api_key
        self.base_url = base_url
        self.model = model
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        self.max_children = max_children
        self.c = c
        self.step_length = step_length
        self.k = k
        self.alpha = alpha

        # Temperature hyperparameters
        self.baseline_temperature = baseline_temperature
        self.explore_temperature = explore_temperature
        self.evaluate_temperature = evaluate_temperature

        # max_tokens hyperparameter
        self.max_tokens = max_tokens

        self.nodes: List[MCTS_NODE] = []
        self.query = None
        self.total_requests = 0
        self.total_tokens = 0
        self.log_file = log_file
        self.logs = []

        try:
            self.tokenizer = tiktoken.encoding_for_model("gpt-4")
        except KeyError:
            # encoding_for_model raises KeyError for model names it cannot map to an encoding
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def get_effective_max_children(self, node_idx: int) -> int:
        """
        Compute the node's dynamic child limit.
        It grows with the visit count: min(max_children, floor(k * visit_count ** alpha)),
        clamped to at least 1 so that a node can always be expanded.
        """
        visit_count = self.nodes[node_idx].visit_count
        dynamic_limit = int(self.k * (visit_count ** self.alpha))
        effective = min(self.max_children, dynamic_limit)
        effective = max(1, effective)
        self.log(f"[DEBUG] Node {node_idx} dynamic child limit: min({self.max_children}, floor({self.k}*{visit_count}^{self.alpha})) = min({self.max_children}, {dynamic_limit}) = {effective}")
        return effective

    def log(self, message: str):
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_entry = f"[{timestamp}] {message}"
        self.logs.append(log_entry)
        print(log_entry)

    def call_api(self, messages: List[Dict], temperature: Optional[float] = None,
                 max_tokens: Optional[int] = None, retry_count: int = 0,
                 retry_max: int = 3) -> Tuple[str, Dict]:
        """
        Call the chat-completions API.
        - temperature defaults to the instance-level explore_temperature (callers may override it).
        - max_tokens defaults to the instance-level self.max_tokens (callers may override it).
        Failed requests, timeouts included, are retried with exponential backoff, at most retry_max times.
        """
        if temperature is None:
            temperature = self.explore_temperature
        if max_tokens is None:
            max_tokens = self.max_tokens

        url = f"{self.base_url}/chat/completions"
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens
        }
        try:
            response = requests.post(url, headers=self.headers, json=payload, timeout=120)
            response.raise_for_status()
            result = response.json()
            text = result["choices"][0]["message"]["content"]
            usage = result.get("usage", {})
            self.total_requests += 1
            self.total_tokens += usage.get("total_tokens", 0)
            return text.strip(), usage
        except requests.exceptions.RequestException as e:
            # Timeout is a subclass of RequestException, so timeouts go through the same
            # bounded retry logic as every other request error.
            if isinstance(e, requests.exceptions.Timeout):
                self.log("[ERROR] API request timed out")
            else:
                self.log(f"[ERROR] API call error: {e}")
                err_response = getattr(e, "response", None)
                if err_response is not None:
                    self.log(f"[ERROR] Response status code: {err_response.status_code}")
                    self.log(f"[ERROR] Response body: {err_response.text}")
                self.log(f"[ERROR] Full traceback:\n{traceback.format_exc()}")
            if retry_count < retry_max:
                wait_time = 2 ** retry_count
                self.log(f"[ERROR] Retry {retry_count + 1}/{retry_max} in {wait_time} s...")
                time.sleep(wait_time)
                return self.call_api(messages, temperature, max_tokens, retry_count + 1, retry_max)
            self.log("[ERROR] Maximum number of retries reached; API call failed")
            return "API call failed", {}

    def call_api_batch(self, messages_list: List[List[Dict]], temperature: float = None,
                       max_tokens: int = None) -> List[Tuple[str, Dict]]:
        if temperature is None:
            temperature = self.evaluate_temperature
        if max_tokens is None:
            max_tokens = self.max_tokens

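        self.log(f"[DEBUG] Starting batch API call with {len(messages_list)} requests")
        results = []
        # max(1, ...) avoids ThreadPoolExecutor(max_workers=0), which raises ValueError.
        # Results are collected in completion order, not submission order.
        with ThreadPoolExecutor(max_workers=max(1, min(10, len(messages_list)))) as executor:
            futures = [
                executor.submit(self.call_api, messages, temperature, max_tokens)
                for messages in messages_list
            ]
            for future in as_completed(futures):
                try:
                    result = future.result()
                    results.append(result)
                except Exception as e:
                    self.log(f"[ERROR] A request in the batch API call failed: {e}")
                    results.append(("", {}))
        # call_api returns "API call failed" after exhausting its retries; do not count that as success.
        succeeded = sum(1 for text, _ in results if text and text != "API call failed")
        self.log(f"[DEBUG] Batch API call finished, {succeeded} succeeded")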
        return results

    def generate_complete_answer(self, query: str, prev_answer_with_critiques: str = "",
                                 prefix: str = "", for_baseline: bool = False) -> str:
        self.log("[DEBUG] ========== generate_complete_answer called ==========")
        self.log(f"[DEBUG] for_baseline: {for_baseline}")
        self.log(f"[DEBUG] query:\n{query[:100]}......")

        # For reasoner models, make sure the prefix always starts with "<think>\n" so that
        # CoT is triggered whether or not the caller passed a prefix (baseline generation included).
        if "reasoner" in self.model.lower() and not prefix.startswith("<think>"):
            prefix = "<think>\n" + prefix

        if prefix or prev_answer_with_critiques:
            user_content = f"Question: {query}\n"
            if prev_answer_with_critiques:
                user_content += f"\nHere are some previous responses to the question with their critiques:\n{prev_answer_with_critiques}\n"
                user_content += "\nPlease provide a NEW answer that addresses these critiques and improves upon the previous attempts.\n"
            else:
                user_content += "\nPlease provide a complete answer.\n"
            messages = [{"role": "user", "content": user_content}]
            self.log(f"[DEBUG] user_content:\n{user_content}")
            if prefix:
                messages.append({"role": "assistant", "content": prefix, "prefix": True})
                self.log(f"[DEBUG] prefix:\n{prefix}")
        else:
            user_content = f"Question: {query}\n\nPlease provide a complete answer.\n"
            messages = [{"role": "user", "content": user_content}]
            self.log(f"[DEBUG] user_content (no prefix):\n{user_content}")

        # ── Pick the temperature hyperparameter that matches the calling context ──
        temp = self.baseline_temperature if for_baseline else self.explore_temperature
        self.log(f"[DEBUG] 使用 temperature: {temp} ({'baseline' if for_baseline else 'explore'})")
        answer, _ = self.call_api(messages, temperature=temp, max_tokens=self.max_tokens)
        self.log(f"[DEBUG] Generated complete answer:\n{answer}")
        self.log("[DEBUG] ========== generate_complete_answer finished ==========\n")
        return answer

    def split_answer_into_steps(self, answer: str) -> List[str]:
        """Split an answer into consecutive chunks of at most step_length tokens."""
        self.log("[DEBUG] ========== split_answer_into_steps called ==========")
        self.log(f"[DEBUG] Input answer length: {len(answer)} characters")
        tokens = self.tokenizer.encode(answer)
        total_tokens = len(tokens)
        self.log(f"[DEBUG] Total tokens: {total_tokens}")

        if total_tokens <= self.step_length:
            self.log("[DEBUG] Answer is no longer than step_length; returning it as a single step")
            return [answer]

        steps = []
        for step_num, i in enumerate(range(0, total_tokens, self.step_length), start=1):
            step_tokens = tokens[i:i + self.step_length]
            # tiktoken's decode is lossy if a chunk boundary splits a multi-byte character.
            steps.append(self.tokenizer.decode(step_tokens))
            self.log(f"[DEBUG] Step {step_num}: {len(step_tokens)} tokens")

        self.log(f"[DEBUG] Produced {len(steps)} steps")
        self.log("[DEBUG] ========== split_answer_into_steps finished ==========\n")
        return steps

    def evaluate_complete_answer(self, query: str, complete_answer: str,
                                 num_samples: int = 3, max_retries: int = 3) -> Tuple[List[float], str]:
        self.log("[DEBUG] ========== evaluate_complete_answer called ==========")
        self.log(f"[DEBUG] Number of evaluation samples: {num_samples}")
        self.log(f"[DEBUG] query:\n{query[:100]}......")
        self.log(f"[DEBUG] complete_answer:\n{complete_answer[:100]}......")

        evaluation_prompt = f"""Question: {query}
Answer: {complete_answer}
Please evaluate this answer critically using a deduction-based scoring system starting from a perfect score of 100 down to a minimum of 0. For each potential issue you identify in the answer, deduct 5 points. Consider the following aspects:
- Correctness and accuracy
- Logical flow and reasoning
- Completeness
- Clarity
Be strict in your evaluation. In your analysis, clearly describe each identified issue, explain why it is a problem, and explicitly state the 5-point deduction for it. Sum up all deductions to arrive at the final score.

Format your response as:
[Analysis] Your detailed analysis here, maintaining logical flow while listing issues and their deductions
[Score] final_score_between_0_and_100"""

        messages_list = [
            [{"role": "user", "content": evaluation_prompt}]
            for _ in range(num_samples)
        ]
        # Pass evaluate_temperature explicitly
        results = self.call_api_batch(messages_list,
                                      temperature=self.evaluate_temperature,
                                      max_tokens=self.max_tokens)

        scores = []
        critiques = []

        for i, (text, usage) in enumerate(results):
            self.log(f"[DEBUG] --- Processing evaluation sample {i+1}/{num_samples} ---")
            score = self._extract_score_with_retry(text, query, complete_answer, max_retries)
            if score is not None:
                analysis_match = re.search(r'\[Analysis\](.*?)(?=\[Score\]|$)', text, re.DOTALL | re.IGNORECASE)
                critique_temp = analysis_match.group(1).strip() if analysis_match else ""
                scores.append(score)
                critiques.append(critique_temp)
                self.log(f"[DEBUG] Sample {i+1}: extracted score {score}")
            else:
                self.log(f"[WARNING] Sample {i+1}: no score could be extracted after {max_retries} retries; skipping this sample")

        if not scores:
            self.log("[ERROR] All evaluation samples failed; using default values")
            return [0], "Unable to generate critique after multiple attempts."

        min_score = min(scores)
        min_index = scores.index(min_score)
        critique = critiques[min_index]

        self.log(f"[DEBUG] All scores: {scores}")
        self.log(f"[DEBUG] Selected minimum score: {min_score} (index {min_index})")
        self.log(f"[DEBUG] Corresponding full critique:\n{critique}")
        final_scores = [min_score]
        self.log(f"[DEBUG] Returning scores: {final_scores}, critique length: {len(critique)}")
        self.log("[DEBUG] ========== evaluate_complete_answer finished ==========\n")
        return final_scores, critique

    def _extract_score_with_retry(self, text: str, query: str, complete_answer: str,
                                   max_retries: int) -> Optional[int]:
        score_match = re.search(r'\[Score\]\s*(\d+)', text, re.IGNORECASE)
        if score_match:
            score = int(score_match.group(1))
            return max(0, min(100, score))

        for retry in range(max_retries):
            self.log(f"[DEBUG] Score extraction failed; retry {retry + 1}/{max_retries}")
            retry_prompt = f"""Question: {query}
Answer: {complete_answer}
Please evaluate this answer and provide a score between 0 and 100.
Use a deduction-based system starting from 100, deducting 5 points for each issue found.

You MUST format your response EXACTLY as:
[Analysis] Your analysis here
[Score] numeric_score_here

Example:
[Analysis] The answer contains 3 issues...
[Score] 85"""
            messages = [{"role": "user", "content": retry_prompt}]
            retry_text, _ = self.call_api(messages, temperature=0.5, max_tokens=min(2048, self.max_tokens))
            score_match = re.search(r'\[Score\]\s*(\d+)', retry_text, re.IGNORECASE)
            if score_match:
                score = int(score_match.group(1))
                self.log(f"[DEBUG] Retry succeeded; extracted score {score}")
                return max(0, min(100, score))
            time.sleep(1)
        return None

    def calculate_q_value(self, reward_samples: List[float]) -> float:
        if not reward_samples:
            return 0.0
        return min(reward_samples)

    def get_uct_value(self, parent_idx: int, child_idx: int) -> float:
        parent = self.nodes[parent_idx]
        child = self.nodes[child_idx]
        if child.visit_count == 0:
            return float('inf')
        exploitation = child.Q_value
        exploration = self.c * math.sqrt(math.log(parent.visit_count + 1) / (child.visit_count + 1e-6))
        return exploitation + exploration

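    def get_unexpanded_uct_value(self, node_idx: int) -> float:
        # Score for the "add a new child" option. The unexpanded slot is treated as
        # having ~0 visits (epsilon = 1e-6), so once the node has been visited this
        # value is very large, and a node that is not fully expanded is usually
        # expanded rather than descended.
        parent = self.nodes[node_idx]
        exploration = self.c * math.sqrt(math.log(parent.visit_count + 1) / 1e-6)
        return exploration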

    def select_path_to_expand(self) -> int:
        self.log("[DEBUG] ========== select_path_to_expand called ==========")
        current_idx = 0
        path = [0]
        decision_log = []

        while True:
            current_node = self.nodes[current_idx]
            self.log(f"[DEBUG] Current node index: {current_idx}")
            self.log(f"[DEBUG] Is leaf: {current_node.is_leaf()}")
            self.log(f"[DEBUG] Children: {current_node.children_indices}")
            self.log(f"[DEBUG] Fully expanded: {current_node.fully_expanded}")
            self.log(f"[DEBUG] Visit count: {current_node.visit_count}")
            self.log(f"[DEBUG] Q value: {current_node.Q_value}")

            if current_node.is_leaf():
                self.log(f"[DEBUG] Reached leaf node {current_idx}")
                self.log(f"[DEBUG] Selected path: {path}")
                self.log("[DEBUG] Decision trace:\n" + "\n".join(decision_log))
                self.log("[DEBUG] ========== select_path_to_expand end ==========\n")
                return current_idx

            if current_node.fully_expanded:
                decision_log.append(f"Node {current_idx}: fully expanded")
                if current_node.children_indices:
                    uct_values = {}
                    for child_idx in current_node.children_indices:
                        uct = self.get_uct_value(current_idx, child_idx)
                        uct_values[child_idx] = uct
                        child = self.nodes[child_idx]
                        self.log(f"[DEBUG] Child {child_idx}: UCT={uct:.4f}, Q={child.Q_value:.4f}, visits={child.visit_count}")
                    # Reuse the cached scores instead of recomputing UCT for every child.
                    best_child = max(uct_values, key=uct_values.get)
                    decision_log.append(f"Selected best child {best_child} (UCT={uct_values[best_child]:.4f})")
                    self.log(f"[DEBUG] Selected best child: {best_child}")
                    current_idx = best_child
                    path.append(current_idx)
                    continue
                else:
                    decision_log.append(f"Node {current_idx}: fully expanded but has no children")
                    self.log(f"[DEBUG] Node {current_idx} is fully expanded but has no children")
                    self.log("[DEBUG] ========== select_path_to_expand end ==========\n")
                    return current_idx
            else:
                decision_log.append(f"Node {current_idx}: not fully expanded")
                unexpanded_uct = self.get_unexpanded_uct_value(current_idx)
                self.log(f"[DEBUG] UCT of the unexpanded slot: {unexpanded_uct:.4f}")

                if current_node.children_indices:
                    uct_values = {}
                    for child_idx in current_node.children_indices:
                        uct = self.get_uct_value(current_idx, child_idx)
                        uct_values[child_idx] = uct
                        child = self.nodes[child_idx]
                        self.log(f"[DEBUG] Existing child {child_idx}: UCT={uct:.4f}, Q={child.Q_value:.4f}, visits={child.visit_count}")

                    best_child = max(uct_values, key=uct_values.get)
                    best_child_uct = uct_values[best_child]
                    self.log(f"[DEBUG] Best UCT among existing children: {best_child_uct:.4f}")

                    if best_child_uct > unexpanded_uct:
                        decision_log.append(f"Existing child has higher UCT; selecting child {best_child} (UCT={best_child_uct:.4f} > {unexpanded_uct:.4f})")
                        self.log(f"[DEBUG] Existing child has higher UCT; selecting child {best_child}")
                        current_idx = best_child
                        path.append(current_idx)
                        continue
                    else:
                        decision_log.append(f"Unexpanded slot has higher UCT ({unexpanded_uct:.4f} >= {best_child_uct:.4f}); expanding the current node")
                        self.log(f"[DEBUG] Unexpanded slot has higher UCT; expanding current node {current_idx}")
                else:
                    decision_log.append(f"Node has no children; unexpanded UCT={unexpanded_uct:.4f}; expanding the current node")
                    self.log(f"[DEBUG] Node has no children; expanding current node {current_idx}")

                self.log("[DEBUG] ========== select_path_to_expand end ==========\n")
                return current_idx

    def expand_node(self, node_idx: int) -> int:
        self.log("[DEBUG] ========== expand_node called ==========")
        self.log(f"[DEBUG] Node to expand: {node_idx}")

        # A leaf is already a complete answer: nothing is generated, and
        # mcts_iteration simply backpropagates from it again.
        if self.nodes[node_idx].is_leaf():
            self.log("[DEBUG] Node is a leaf; returning it directly")
            return node_idx

        node = self.nodes[node_idx]
        self.log(f"[DEBUG] Current number of children: {len(node.children_indices)}")
        self.log(f"[DEBUG] Maximum number of children (static): {self.max_children}")
        self.log(f"[DEBUG] Node content:\n{node.content[:100]}......")

        effective_max = self.get_effective_max_children(node_idx)
        self.log(f"[DEBUG] Current effective child limit: {effective_max}")

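        # A node that already holds effective_max children is at its limit, so compare
        # with >= (as the max_children check after expansion does) to avoid exceeding it.
        if len(node.children_indices) >= effective_max:
            node.fully_expanded = True
            if len(node.children_indices) >= self.max_children:
                self.log(f"[DEBUG] Node {node_idx} has reached the global child limit; marked fully expanded")
            else:
                self.log(f"[DEBUG] Node {node_idx} has reached its current dynamic limit ({effective_max}); not expanding for now")
            self.log("[DEBUG] ========== expand_node end (child limit reached) ==========\n")
            return -1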

        self.log("[DEBUG] Preparing expansion context...")
        prev_answer_with_critiques, prefix = self.prepare_expansion_context(node_idx)

        self.log("[DEBUG] Generating a new complete answer...")
        complete_answer = self.generate_complete_answer(self.query, prev_answer_with_critiques, prefix, for_baseline=False)
        self.log(f"[DEBUG] Generated complete answer:\n{complete_answer}")

        self.log("[DEBUG] Splitting the answer into steps...")
        steps = self.split_answer_into_steps(complete_answer)

        if not steps:
            self.log("[DEBUG] No steps generated; expansion failed")
            return -1

        chain_start_idx = len(self.nodes)
        self.log(f"[DEBUG] Creating a new chain starting at index {chain_start_idx}")
        self.log(f"[DEBUG] Chain contains {len(steps)} steps")

        for i, step in enumerate(steps):
            # The first step hangs off the expanded node; every later step hangs off
            # the step appended just before it, so the answer forms a linear chain.
            new_node = MCTS_NODE(
                parent_idx=node_idx if i == 0 else len(self.nodes) - 1,
                content=step
            )

            if i == len(steps) - 1:
                # Only the final step (the new leaf) is scored, on the full answer text.
                full_answer = prefix + ''.join(steps)
                self.log("[DEBUG] Last step; evaluating the complete answer")
                new_node.reward_samples, new_node.critique = self.evaluate_complete_answer(
                    self.query, full_answer
                )
                new_node.Q_value = self.calculate_q_value(new_node.reward_samples)
                self.log(f"[DEBUG] Evaluation result - reward_samples: {new_node.reward_samples}, Q_value: {new_node.Q_value}")

            new_node.visit_count = 0
            new_node_idx = len(self.nodes)
            self.nodes.append(new_node)
            self.log(f"[DEBUG] Added new node to the tree at index {new_node_idx}")

            if i == 0:
                node.children_indices.append(new_node_idx)
            else:
                self.nodes[-2].children_indices.append(new_node_idx)

        if len(node.children_indices) >= self.max_children:
            node.fully_expanded = True
            self.log(f"[DEBUG] Node {node_idx} now has {len(node.children_indices)} children; marked fully expanded")

        self.log(f"[DEBUG] Expansion complete; returning chain start index {chain_start_idx}")
        self.log("[DEBUG] ========== expand_node end ==========\n")
        return chain_start_idx

    def prepare_expansion_context(self, node_idx: int) -> Tuple[str, str]:
        self.log("[DEBUG] ========== prepare_expansion_context called ==========")
        node = self.nodes[node_idx]
        prev_answer_with_critiques = ""
        prefix = self.get_path_to_node(node_idx)

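        if node.children_indices:
            # Pair the most recent sibling's answer with that answer's own critique:
            # find_leaf_node and get_complete_answer_from_chain both follow the
            # highest-Q child, so starting both from last_child_idx keeps them aligned.
            last_child_idx = node.children_indices[-1]
            leaf_node = self.find_leaf_node(last_child_idx)
            complete_answer = self.get_complete_answer_from_chain(last_child_idx)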

            prev_answer_with_critiques += f"\nPrevious Answer:\n{complete_answer}\n"

            # if leaf_node and leaf_node.reward_samples:
            #     min_score = min(leaf_node.reward_samples)
            #     prev_answer_with_critiques += f"Score: {min_score:.1f}\n"

            if leaf_node and leaf_node.critique:
                prev_answer_with_critiques += f"\nCritique:\n{leaf_node.critique}\n"

        self.log("[DEBUG] ========== prepare_expansion_context end ==========\n")
        return prev_answer_with_critiques, prefix

    def get_path_to_node(self, node_idx: int) -> str:
        path = []
        current_idx = node_idx
        while current_idx != -1:
            path.append(self.nodes[current_idx].content)
            current_idx = self.nodes[current_idx].parent_idx
        path.reverse()
        return ''.join(path)

    def get_complete_answer_from_chain(self, start_idx: int) -> str:
        current_idx = start_idx
        complete_answer = self.get_path_to_node(start_idx)
        while self.nodes[current_idx].children_indices:
            child_Q_value = [self.nodes[c].Q_value for c in self.nodes[current_idx].children_indices]
            current_idx = self.nodes[current_idx].children_indices[np.argmax(child_Q_value)]
            complete_answer += self.nodes[current_idx].content
        return complete_answer

    def find_leaf_node(self, start_idx: int) -> Optional['MCTS_NODE']:
        current_idx = start_idx
        while self.nodes[current_idx].children_indices:
            child_Q_value = [self.nodes[c].Q_value for c in self.nodes[current_idx].children_indices]
            current_idx = self.nodes[current_idx].children_indices[np.argmax(child_Q_value)]
        return self.nodes[current_idx]

    def backpropagate(self, leaf_idx: int):
        self.log("[DEBUG] ========== backpropagate called ==========")
        self.log(f"[DEBUG] Backpropagating from leaf node {leaf_idx}")

        current_idx = leaf_idx
        update_count = 0
        path = []

        while current_idx != -1:
            current_node = self.nodes[current_idx]
            path.append(current_idx)

            current_node.visit_count += 1

            if not current_node.is_leaf():
                # Max backup: an internal node is worth as much as its best child.
                leaf_q_values = [self.nodes[c].Q_value for c in current_node.children_indices]
                if leaf_q_values:
                    old_q = current_node.Q_value
                    current_node.Q_value = np.max(leaf_q_values)
                    self.log(f"[DEBUG] Node {current_idx} Q value updated: {old_q:.4f} -> {current_node.Q_value:.4f}")

            # While backing up, check whether the dynamic child limit has grown with
            # the new visit; if so, reopen the node for expansion.
            if not current_node.is_leaf():
                new_effective = self.get_effective_max_children(current_idx)
                if current_node.fully_expanded and len(current_node.children_indices) < self.max_children:
                    if len(current_node.children_indices) < new_effective:
                        current_node.fully_expanded = False
                        self.log(f"[DEBUG] Node {current_idx}: dynamic limit raised to {new_effective}; clearing fully_expanded")

            update_count += 1
            current_idx = current_node.parent_idx

        self.log(f"[DEBUG] Backpropagation path: {path}, {update_count} nodes updated")
        self.log("[DEBUG] ========== backpropagate end ==========\n")

    def initialize_tree(self, query: str):
        self.query = query
        self.nodes = []
        self.log("=" * 70)
        self.log(f"Initializing MCTS tree (baseline_temperature={self.baseline_temperature})")

        root_node = MCTS_NODE(parent_idx=-1, content="")
        root_node.visit_count = 0
        self.nodes.append(root_node)
        self.log("Created root node (node 0) with empty content")

        self.log(f"Generating initial answer (baseline_temperature={self.baseline_temperature})...")
        complete_answer = self.generate_complete_answer(query, for_baseline=True)
        steps = self.split_answer_into_steps(complete_answer)

        if not steps:
            steps = ["Let me solve this problem."]
        self.log(f"Split into {len(steps)} steps")

        for i, step in enumerate(steps):
            new_node = MCTS_NODE(
                parent_idx=0 if i == 0 else len(self.nodes) - 1,
                content=step
            )
            if i == len(steps) - 1:
                self.log("Evaluating initial answer...")
                new_node.reward_samples, new_node.critique = self.evaluate_complete_answer(
                    query, complete_answer
                )
                new_node.Q_value = self.calculate_q_value(new_node.reward_samples)
                self.log(f"Initial score: {new_node.Q_value:.1f}")
            new_node.visit_count = 0
            self.nodes.append(new_node)

            if i == 0:
                root_node.children_indices.append(len(self.nodes) - 1)
            else:
                self.nodes[-2].children_indices.append(len(self.nodes) - 1)

        self.backpropagate(len(self.nodes) - 1)

    def mcts_iteration(self):
        expand_idx = self.select_path_to_expand()
        if expand_idx == -1:
            return
        new_chain_start = self.expand_node(expand_idx)
        if new_chain_start == -1:
            return
        leaf_node = self.find_leaf_node(new_chain_start)
        if leaf_node:
            leaf_idx = self.nodes.index(leaf_node)
            self.backpropagate(leaf_idx)

    def run(self, query: str, iterations: int = 10) -> str:
        self.log(f"Starting MCTS reasoning ({iterations} iterations)")
        self.log("=" * 70)
        self.initialize_tree(query)

        for i in range(iterations):
            self.log(f"\n{'='*70}")
            self.log(f"Iteration {i+1}/{iterations}")
            self.log('='*70)
            self.mcts_iteration()
            best_leaf = self.get_best_leaf()
            if best_leaf:
                self.log(f"Current best score: {best_leaf.Q_value:.1f}")
                if best_leaf.Q_value >= 90:
                    self.log("Early stop: found a high-quality answer (score >= 90)")
                    break

        self.log("\n" + "=" * 70)
        self.log("MCTS reasoning finished")
        self.log(f"Total API calls: {self.total_requests}")
        self.log(f"Total tokens used: {self.total_tokens}")
        self.log(f"Total nodes: {len(self.nodes)}")
        self.log("=" * 70)

        best_leaf = self.get_best_leaf()
        if best_leaf:
            return self.reconstruct_answer_to_leaf(best_leaf)
        else:
            return "Unable to generate an answer"

    def get_best_leaf(self) -> Optional['MCTS_NODE']:
        leaf_nodes = [node for node in self.nodes if node.is_leaf() and node.reward_samples]
        if not leaf_nodes:
            return None
        return max(leaf_nodes, key=lambda n: n.Q_value)

    def reconstruct_answer_to_leaf(self, leaf_node: 'MCTS_NODE') -> str:
        leaf_idx = self.nodes.index(leaf_node)
        return self.get_path_to_node(leaf_idx)

    def get_tree_structure(self) -> str:
        tree_info = []
        tree_info.append("\n" + "="*70)
        tree_info.append("Full tree structure")
        tree_info.append("="*70 + "\n")

        for idx, node in enumerate(self.nodes):
            tree_info.append(f"\n{'='*70}")
            tree_info.append(f"Node {idx}")
            tree_info.append('='*70)
            tree_info.append(f"Parent index: {node.parent_idx}")
            tree_info.append(f"Child indices: {node.children_indices}")
            tree_info.append(f"Visit count: {node.visit_count}")
            tree_info.append(f"Q value: {node.Q_value:.2f}")
            tree_info.append(f"Is leaf: {node.is_leaf()}")
            tree_info.append(f"Fully expanded: {node.fully_expanded}")
            tree_info.append(f"\nContent:\n{node.content}")

            if node.reward_samples:
                tree_info.append(f"\nReward samples: {node.reward_samples}")
            if node.critique:
                tree_info.append(f"\nCritique:\n{node.critique}")
            tree_info.append("")

        return "\n".join(tree_info)

    def save_results(self, output_file: str, problem: str, correct_answer: str,
                    direct_answer: Optional[str], direct_correct: bool,
                    mcts_answer: Optional[str], mcts_correct: bool,
                    mcts_skipped: bool = False):
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("="*70 + "\n")
            f.write("MCTS reasoning: detailed results\n")
            f.write("="*70 + "\n\n")
            f.write("Problem:\n")
            f.write(problem + "\n\n")
            f.write(f"Correct answer: {correct_answer}\n\n")
            f.write("="*70 + "\n")
            f.write("Direct answer (baseline)\n")
            f.write("="*70 + "\n")
            f.write(f"Answer: {direct_answer}\n")
            f.write(f"Correct: {'✓' if direct_correct else '✗'}\n\n")
            f.write("="*70 + "\n")
            f.write("MCTS result\n")
            f.write("="*70 + "\n")
            if mcts_skipped:
                f.write("MCTS search skipped (baseline already correct)\n")
            else:
                f.write(f"Answer: {mcts_answer}\n")
                f.write(f"Correct: {'✓' if mcts_correct else '✗'}\n")
                f.write(f"Total API calls: {self.total_requests}\n")
                f.write(f"Total tokens: {self.total_tokens}\n\n")
                best_leaf = self.get_best_leaf()
                if best_leaf:
                    f.write("="*70 + "\n")
                    f.write("Best answer path\n")
                    f.write("="*70 + "\n")
                    best_answer = self.reconstruct_answer_to_leaf(best_leaf)
                    f.write(best_answer + "\n\n")
                    f.write(f"Best Q value: {best_leaf.Q_value:.2f}\n\n")
                f.write(self.get_tree_structure())
                f.write("\n\n" + "="*70 + "\n")
                f.write("Detailed log\n")
                f.write("="*70 + "\n\n")
                f.write("\n".join(self.logs))
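

# Illustrative sketch (hypothetical helpers; not called by MCTS_STEPWISE_REASONING).
# A toy, API-free version of the tree bookkeeping in the class above, which can be
# handy for unit-testing the search logic without model calls. add_chain mirrors how
# expand_node attaches a new answer as a linear chain of step nodes; backup mirrors
# backpropagate (count a visit, set Q to the best child's Q); uct mirrors
# get_uct_value. can_add_child ASSUMES the dynamic limit has the common
# progressive-widening form floor(k * N**alpha), capped at max_children; the real
# get_effective_max_children is defined elsewhere and may use a different rule.
# ToyNode, add_chain, backup, uct and can_add_child are illustrative names only.
import math
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyNode:
    parent: int                                   # -1 for the root
    content: str
    children: List[int] = field(default_factory=list)
    visits: int = 0
    q: float = 0.0


def add_chain(nodes: List[ToyNode], parent_idx: int, steps: List[str], leaf_q: float) -> int:
    """Append steps as a parent -> child chain; only the final step carries a reward."""
    if not steps:
        raise ValueError("a chain needs at least one step")
    prev = parent_idx
    for step in steps:
        nodes.append(ToyNode(parent=prev, content=step))
        new_idx = len(nodes) - 1
        nodes[prev].children.append(new_idx)
        prev = new_idx
    nodes[prev].q = leaf_q
    return prev  # index of the new leaf


def backup(nodes: List[ToyNode], leaf_idx: int) -> None:
    """Walk leaf -> root, counting a visit and setting Q to the max child Q."""
    idx = leaf_idx
    while idx != -1:
        node = nodes[idx]
        node.visits += 1
        if node.children:
            node.q = max(nodes[c].q for c in node.children)
        idx = node.parent


def uct(nodes: List[ToyNode], parent_idx: int, child_idx: int, c: float) -> float:
    """Q + c * sqrt(ln(N_parent + 1) / N_child); unvisited children score +inf."""
    child = nodes[child_idx]
    if child.visits == 0:
        return float("inf")
    return child.q + c * math.sqrt(math.log(nodes[parent_idx].visits + 1) / child.visits)


def can_add_child(nodes: List[ToyNode], idx: int, k: float, alpha: float, max_children: int) -> bool:
    """ASSUMED rule: allow a new child while children < min(max_children, floor(k * N**alpha))."""
    limit = min(max_children, math.floor(k * nodes[idx].visits ** alpha))
    return len(nodes[idx].children) < limit
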


# ─────────────────────────────────────────────────────────────────────────────
# Utility functions
# ─────────────────────────────────────────────────────────────────────────────
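
# Illustrative sketch (hypothetical helpers; not called by the pipeline).
# Mirrors the reward path of MCTS_STEPWISE_REASONING: _extract_score_with_retry
# reads the first "[Score] N" tag and clamps it to [0, 100]; evaluate_complete_answer
# keeps only the lowest score across evaluation samples, and calculate_q_value uses
# that minimum as the leaf's Q (a pessimistic estimate). parse_score and
# pessimistic_q are illustrative names; the 0.0 fallback corresponds to the [0]
# default returned when no sample yields a score.
import re
from typing import List, Optional

_SCORE_RE = re.compile(r"\[Score\]\s*(\d+)", re.IGNORECASE)


def parse_score(text: str) -> Optional[int]:
    """Return the first "[Score] N" value clamped to [0, 100], or None if absent."""
    match = _SCORE_RE.search(text)
    if match is None:
        return None
    return max(0, min(100, int(match.group(1))))


def pessimistic_q(evaluations: List[str]) -> float:
    """Minimum over the scores that parse; 0.0 when none do."""
    scores = [s for s in (parse_score(t) for t in evaluations) if s is not None]
    return float(min(scores)) if scores else 0.0
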

def extract_answer(text: str) -> Optional[str]:
    patterns = [
        r'\\boxed\{([^}]+)\}',
        r'boxed\{([^}]+)\}',
        r'answer\s+is\s+([^\s.,;]+)',
        r'final\s+answer\s*:?\s*([^\s.,;]+)',
        r'therefore[,\s]+([^\s.,;]+)',
        r'=\s*([^\s.,;]+)\s*$',
        r'\$([^\$]+)\$'
    ]
    for pattern in patterns:
        matches = list(re.finditer(pattern, text, re.IGNORECASE | re.MULTILINE))
        if matches:
            answer = matches[-1].group(1).strip()
            answer = answer.replace('\\', '').replace('$', '').strip()
            return answer
    numbers = list(re.finditer(r'\b(\d+\.?\d*)\b', text))
    if numbers:
        return numbers[-1].group(1)
    return None
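

# Illustrative sketch (hypothetical helper; not used by extract_answer).
# The \boxed\{([^}]+)\} pattern above stops at the first closing brace, so an answer
# such as \boxed{\frac{1}{2}} is truncated. AIME answers are integers, so this does
# not affect integer-answer problems, but a brace-balanced scan handles nested LaTeX.
# extract_last_boxed is an illustrative name only.
from typing import Optional


def extract_last_boxed(text: str) -> Optional[str]:
    r"""Return the contents of the last \boxed{...}, honoring nested braces."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    content_start = start + len("\\boxed{")
    depth = 1
    for i in range(content_start, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:i]
    return None  # unbalanced: the last \boxed{ is never closed
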


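def normalize_answer(answer: Optional[str]) -> str:
    if answer is None:
        return ""
    answer = str(answer).strip().replace(' ', '').replace(',', '')
    try:
        return str(int(answer))
    except ValueError:
        pass
    try:
        value = float(answer)
    except ValueError:
        return answer.lower()
    # Drop a trailing ".0" so that e.g. "70" and "70.0" normalize to the same string.
    return str(int(value)) if value.is_integer() else str(value)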


def compare_answers(predicted: str, correct: str) -> bool:
    pred_norm = normalize_answer(predicted)
    corr_norm = normalize_answer(correct)
    return pred_norm == corr_norm


def process_single_problem(problem_data: dict, api_key: str, iterations: int,
                          output_dir: str, lock: Lock, results_list: list,
                          step_length: int = 256,
                          k: float = 2.0, alpha: float = 0.5,
                          max_tokens: int = 16384,
                          baseline_temperature: float = 0.0,
                          explore_temperature: float = 0.8,
                          evaluate_temperature: float = 0.7,
                          model: str = "deepseek-reasoner") -> dict:
    idx = problem_data['idx']
    problem = problem_data['problem']
    correct_answer = problem_data['correct_answer']
    data_source = problem_data['data_source']

    try:
        print(f"\n{'='*70}")
        print(f"Worker processing problem {idx+1}")
        print(f"Data source: {data_source}")
        print('='*70)

        print(f"[Problem {idx+1}] Starting direct-answer (baseline) evaluation...")
        mcts_direct = MCTS_STEPWISE_REASONING(
            api_key=api_key,
            model=model,
            max_children=3,
            c=1,
            step_length=step_length,
            k=k,
            alpha=alpha,
            max_tokens=max_tokens,
            baseline_temperature=baseline_temperature,
            explore_temperature=explore_temperature,
            evaluate_temperature=evaluate_temperature,
        )

        direct_start = time.time()
        direct_full = mcts_direct.generate_complete_answer(problem, for_baseline=True)
        direct_time = time.time() - direct_start
        direct_ans = extract_answer(direct_full)
        direct_correct = compare_answers(direct_ans, correct_answer)
        direct_api_calls = mcts_direct.total_requests
        direct_tokens = mcts_direct.total_tokens

        print(f"[Problem {idx+1}] Direct answer done: {direct_ans} {'✓' if direct_correct else '✗'}")
        print(f"[Problem {idx+1}] Time: {direct_time:.2f}s, API calls: {direct_api_calls}, Tokens: {direct_tokens}")

        mcts_skipped = False
        if direct_correct:
            print(f"[Problem {idx+1}] ✓ Baseline correct; skipping MCTS search")
            mcts_full = direct_full
            mcts_ans = direct_ans
            mcts_correct = True
            mcts_time = 0
            mcts_api_calls = 0
            mcts_tokens = 0
            mcts_skipped = True
            mcts_search = mcts_direct
        else:
            print(f"[Problem {idx+1}] ✗ Baseline wrong; starting MCTS search...")
            mcts_search = MCTS_STEPWISE_REASONING(
                api_key=api_key,
                model=model,
                max_children=3,
                c=30,
                step_length=step_length,
                k=k,
                alpha=alpha,
                max_tokens=max_tokens,
                baseline_temperature=baseline_temperature,
                explore_temperature=explore_temperature,
                evaluate_temperature=evaluate_temperature,
            )

            mcts_start = time.time()
            mcts_full = mcts_search.run(problem, iterations=iterations)
            mcts_time = time.time() - mcts_start
            mcts_ans = extract_answer(mcts_full)
            mcts_correct = compare_answers(mcts_ans, correct_answer)
            mcts_api_calls = mcts_search.total_requests
            mcts_tokens = mcts_search.total_tokens

            print(f"[Problem {idx+1}] MCTS search done: {mcts_ans} {'✓' if mcts_correct else '✗'}")
            print(f"[Problem {idx+1}] MCTS time: {mcts_time:.2f}s, API calls: {mcts_api_calls}, Tokens: {mcts_tokens}")

        output_file = os.path.join(output_dir, f"question_{idx+1}.txt")
        mcts_search.save_results(
            output_file=output_file,
            problem=problem,
            correct_answer=correct_answer,
            direct_answer=direct_ans,
            direct_correct=direct_correct,
            mcts_answer=mcts_ans,
            mcts_correct=mcts_correct,
            mcts_skipped=mcts_skipped
        )

        if direct_correct and mcts_correct:
            comparison = "Both_Correct"
        elif direct_correct and not mcts_correct:
            comparison = "Direct_Only_Correct"
        elif not direct_correct and mcts_correct:
            comparison = "MCTS_Only_Correct"
        else:
            comparison = "Both_Wrong"

        result = {
            'Question_Number': idx + 1,
            'Data_Source': data_source,
            'Problem': problem[:500],
            'Correct_Answer': correct_answer,
            'Direct_Answer': direct_ans,
            'Direct_Correct': direct_correct,
            'Direct_Full_Response': direct_full,
            'Direct_Time_Seconds': round(direct_time, 2),
            'Direct_API_Calls': direct_api_calls,
            'Direct_Tokens': direct_tokens,
            'MCTS_Answer': mcts_ans,
            'MCTS_Correct': mcts_correct,
            'MCTS_Full_Response': mcts_full,
            'MCTS_Time_Seconds': round(mcts_time, 2),
            'MCTS_API_Calls': mcts_api_calls,
            'MCTS_Tokens': mcts_tokens,
            'MCTS_Skipped': mcts_skipped,
            'Comparison': comparison,
            'Improvement': mcts_correct and not direct_correct,
            'Regression': direct_correct and not mcts_correct,
            'Output_File': output_file
        }

        with lock:
            results_list.append(result)

        status = "MCTS skipped" if mcts_skipped else comparison
        print(f"✓ Problem {idx+1} done [{status}]")
        return result

    except Exception as e:
        error_msg = f"Error while processing problem {idx+1}: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)

        error_file = os.path.join(output_dir, f"question_{idx+1}_ERROR.txt")
        with open(error_file, 'w', encoding='utf-8') as f:
            f.write(error_msg)

        result = {
            'Question_Number': idx + 1, 'Data_Source': data_source,
            'Problem': problem[:500], 'Correct_Answer': correct_answer,
            'Direct_Answer': None, 'Direct_Correct': False,
            'Direct_Full_Response': f"ERROR: {str(e)}",
            'Direct_Time_Seconds': 0, 'Direct_API_Calls': 0, 'Direct_Tokens': 0,
            'MCTS_Answer': None, 'MCTS_Correct': False,
            'MCTS_Full_Response': f"ERROR: {str(e)}",
            'MCTS_Time_Seconds': 0, 'MCTS_API_Calls': 0, 'MCTS_Tokens': 0,
            'MCTS_Skipped': False, 'Comparison': 'ERROR',
            'Improvement': False, 'Regression': False,
            'Output_File': error_file
        }

        with lock:
            results_list.append(result)

        return result


def evaluate_srt_test_dataset_parallel(
        dataset_name: str = "ftajwar/srt_test_dataset",
        output_csv: str = 'srt_test_mcts_results.csv',
        iterations: int = 15,
        sample_size: Optional[int] = None,
        max_workers: int = 5,
        step_length: int = 256,
        filter_source: Optional[str] = None,
        k: float = 2.0,
        alpha: float = 0.5,
        # ── New hyperparameters ──
        model: str = "deepseek-reasoner",
        max_tokens: int = 16384,
        baseline_temperature: float = 0.0,
        explore_temperature: float = 0.8,
        evaluate_temperature: float = 0.7):

    print("="*70)
    print("MCTS parallel evaluation (baseline-first strategy + dynamic child limit)")
    print("="*70)
    print(f"Model: {model}")
    print(f"Dataset: {dataset_name}")
    print(f"MCTS iterations: {iterations}")
    print(f"Worker threads: {max_workers}")
    print(f"Step length: {step_length} tokens")
    print(f"max_tokens: {max_tokens}")
    print(f"Dynamic width parameters: k={k}, alpha={alpha}")
    print(f"Temperatures: baseline={baseline_temperature} | explore={explore_temperature} | evaluate={evaluate_temperature}")
    if filter_source:
        print(f"Data source filter: {filter_source}")
    print()
    print("Strategy:")
    print("- Baseline correct → skip MCTS entirely, saving many tokens")
    print("- Baseline wrong → run MCTS search to try to correct it")
    print("- deepseek-reasoner: a <think> prefix is inserted automatically to trigger CoT")
    print("="*70 + "\n")

    try:
        ds = load_dataset(dataset_name, split='train')
        print(f"✓ Loaded dataset with {len(ds)} problems")
    except Exception as e:
        raise ValueError(f"Failed to load dataset '{dataset_name}': {e}")

    if filter_source:
        ds = ds.filter(lambda x: x['data_source'] == filter_source)
        print(f"✓ {len(ds)} problems remain after filtering (data source: {filter_source})")

    if sample_size and sample_size < len(ds):
        ds_sample = ds.select(range(sample_size))
    else:
        ds_sample = ds
        sample_size = len(ds)

    output_dir = os.path.join(os.getcwd(), 'evaluation_data')
    os.makedirs(output_dir, exist_ok=True)
    print(f"✓ Output directory: {output_dir}\n")

    api_key = "<your_key_here>"

    problems = [
        {
            'idx': idx,
            'problem': row['Problem'],
            'correct_answer': row['Answer'],
            'data_source': row['data_source']
        }
        for idx, row in enumerate(ds_sample)
    ]

    results_list = []
    lock = Lock()
    start_time = time.time()

    print(f"Processing {len(problems)} problems in parallel...\n")

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_problem = {
            executor.submit(
                process_single_problem,
                problem, api_key, iterations, output_dir,
                lock, results_list, step_length, k, alpha,
                max_tokens, baseline_temperature, explore_temperature,
                evaluate_temperature, model
            ): problem for problem in problems
        }

        completed = 0
        for future in as_completed(future_to_problem):
            completed += 1
            problem = future_to_problem[future]
            try:
                future.result()
                print(f"\nProgress: {completed}/{len(problems)} completed")
            except Exception as e:
                print(f"\n✗ Problem {problem['idx']+1} failed: {e}")

    total_time = time.time() - start_time
    results_list.sort(key=lambda x: x['Question_Number'])
    results_df = pd.DataFrame(results_list)
    results_df.to_csv(output_csv, index=False, encoding='utf-8-sig')
    print(f"\n✓ Results saved to: {output_csv}")

    correct_count_direct = sum(1 for r in results_list if r['Direct_Correct'])
    correct_count_mcts   = sum(1 for r in results_list if r['MCTS_Correct'])
    mcts_skipped_count   = sum(1 for r in results_list if r.get('MCTS_Skipped', False))
    mcts_triggered_count = len(problems) - mcts_skipped_count
    mcts_only            = sum(1 for r in results_list if r['Comparison'] == 'MCTS_Only_Correct')

    total_direct_api    = sum(r['Direct_API_Calls'] for r in results_list)
    total_direct_tokens = sum(r['Direct_Tokens'] for r in results_list)
    total_mcts_api      = sum(r['MCTS_API_Calls'] for r in results_list)
    total_mcts_tokens   = sum(r['MCTS_Tokens'] for r in results_list)

    sources = {}
    for r in results_list:
        source = r['Data_Source']
        if source not in sources:
            sources[source] = {'total': 0, 'direct_correct': 0, 'mcts_correct': 0}
        sources[source]['total'] += 1
        if r['Direct_Correct']:
            sources[source]['direct_correct'] += 1
        if r['MCTS_Correct']:
            sources[source]['mcts_correct'] += 1

    print("\n" + "="*70)
    print("Evaluation complete - final report")
    print("="*70)
    print(f"Total processing time: {total_time:.2f}s ({total_time/60:.2f} min)")
    print(f"Average time per problem: {total_time/len(problems):.2f}s")
    print(f"\nBaseline accuracy: {correct_count_direct}/{len(problems)} ({100*correct_count_direct/len(problems):.1f}%)")
    print(f"Final accuracy (with MCTS correction): {correct_count_mcts}/{len(problems)} ({100*correct_count_mcts/len(problems):.1f}%)")
    print(f"\nMCTS skipped: {mcts_skipped_count} | MCTS triggered: {mcts_triggered_count}")
    if mcts_triggered_count > 0:
        print(f"MCTS correction rate: {mcts_only}/{mcts_triggered_count} ({100*mcts_only/mcts_triggered_count:.1f}%)")

    for source, stats in sources.items():
        print(f"\n{source}: Baseline {stats['direct_correct']}/{stats['total']} | Final {stats['mcts_correct']}/{stats['total']}")

    print(f"\nTotal API calls: {total_direct_api + total_mcts_api} | Total tokens: {total_direct_tokens + total_mcts_tokens}")
    print("="*70)

    return results_df
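

# Illustrative sketch (hypothetical helper; not called by the pipeline).
# Recomputes the three headline numbers printed in the final report of
# evaluate_srt_test_dataset_parallel from per-problem result dicts: baseline
# accuracy, final accuracy, and the MCTS correction rate, i.e.
# MCTS_Only_Correct / (number of problems where MCTS was triggered).
# For example, the abstract's AIME25 percentages correspond to 23/30 baseline-correct
# problems, 7 MCTS runs and 5 corrections: 76.7%, 93.3% and 5/7 = 71.4%.
from typing import Dict, List


def summarize(records: List[Dict[str, bool]]) -> Dict[str, float]:
    """Percentages (one decimal place) as printed in the final report."""
    n = len(records)
    if n == 0:
        raise ValueError("no records to summarize")
    baseline = sum(r["Direct_Correct"] for r in records)
    final = sum(r["MCTS_Correct"] for r in records)
    triggered = sum(not r["MCTS_Skipped"] for r in records)
    corrected = sum(r["MCTS_Correct"] and not r["Direct_Correct"] for r in records)
    return {
        "baseline_pct": round(100 * baseline / n, 1),
        "final_pct": round(100 * final / n, 1),
        "correction_pct": round(100 * corrected / triggered, 1) if triggered else 0.0,
    }
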


# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    results = evaluate_srt_test_dataset_parallel(
        dataset_name="ftajwar/srt_test_dataset",
        output_csv='srt_test_mcts_results.csv',
        iterations=12,
        sample_size=30,
        max_workers=30,
        step_length=512,
        filter_source="aime25",
        k=2.0,
        alpha=0.5,
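        # ── New hyperparameters (adjust as needed) ──
        model="deepseek-reasoner",   # model to use
        max_tokens=32768,            # deepseek-reasoner supports up to 64k; set to 32k here
        baseline_temperature=0.0,    # stable baseline
        explore_temperature=0.8,     # diversity during MCTS exploration
        evaluate_temperature=0.7,    # moderate randomness in evaluation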
    )

MCTS parallel evaluation (baseline-first strategy + dynamic child limit)
Model: deepseek-reasoner
Dataset: ftajwar/srt_test_dataset
MCTS iterations: 12
Worker threads: 30
Step length: 512 tokens
max_tokens: 32768
Dynamic width parameters: k=2.0, alpha=0.5
Temperatures: baseline=0.0 | explore=0.8 | evaluate=0.7
Data source filter: aime25

Strategy:
- Baseline correct → skip MCTS entirely, saving many tokens
- Baseline wrong → run MCTS search to try to correct it
- deepseek-reasoner: a <think> prefix is inserted automatically to trigger CoT

✓ Loaded dataset with 273 problems
✓ 30 problems remain after filtering (data source: aime25)
✓ Output directory: c:\Users\PhilipZhu\PycharmProjects\FYP\evaluation_data

Processing 30 problems in parallel...


Worker processing problem 1
Data source: aime25
[Problem 1] Starting direct-answer (baseline) evaluation...
[2026-02-25 17:31:35] [DEBUG] for_baseline: True
[2026-02-25 17:31:35] [DEBUG] query:
Find the sum of all integer bases $b>9$ for which $17_{b}$ is a divisor of $97_{b}$.......
[2026-02-25 17:31:35] [DEBUG] user_content:
Question: Find the sum of all integer bases $b>9$ for which $17_{b}$ is a divisor of $97_{b}$.

Please provide a complete answer.

[2026-02-25 17:31:35] [DEBUG] prefix:
<think>

[2026-02-25 17:31:35] [DEBUG] Using temperature: 0.0 (baseline)

Worker processing problem 2
Data source: