In [38]:
from transformers import AutoTokenizer
from efficient_reasoning.mcts import Node
from efficient_reasoning.utils import last_boxed_only_string, remove_boxed, is_equiv, AutoScoringJudge
from vllm import LLM, SamplingParams
import numpy as np
from datasets import load_dataset
import json
from typing import TypeAlias, Literal, List, Tuple

In [2]:
Benchmark: TypeAlias = Literal["AIME_2024", "MATH-500", "OlympiadBench-674-MATH_TO_EN"]
dataset = 'MATH-500'

In [3]:
path = f'../data/{dataset}/train.jsonl'
datapoints = []
with open(path, 'r') as f:
    for line in f:
        datapoints.append(json.loads(line))

In [4]:
model = "Qwen/Qwen2.5-3B-Instruct"
end_of_text_token = "<|end_of_text|>"
tokenizer = AutoTokenizer.from_pretrained(model)

llm = LLM(model=model, 
          tensor_parallel_size=4)


INFO 02-25 10:23:40 config.py:526] This model supports multiple tasks: {'generate', 'score', 'reward', 'classify', 'embed'}. Defaulting to 'generate'.
INFO 02-25 10:23:41 config.py:1383] Defaulting to use mp for distributed inference
INFO 02-25 10:23:41 llm_engine.py:232] Initializing a V0 LLM engine (v0.7.1) with config: model='Qwen/Qwen2.5-3B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-3B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=4, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]


INFO 02-25 10:23:45 weight_utils.py:251] Using model weights format ['*.safetensors']


Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:00<00:00,  3.44it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:00<00:00,  3.77it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:00<00:00,  3.70it/s]



INFO 02-25 10:23:46 model_runner.py:1116] Loading model weights took 1.4820 GB
[1;36m(VllmWorkerProcess pid=3582472)[0;0m INFO 02-25 10:23:46 model_runner.py:1116] Loading model weights took 1.4820 GB
[1;36m(VllmWorkerProcess pid=3582462)[0;0m INFO 02-25 10:23:46 model_runner.py:1116] Loading model weights took 1.4820 GB
[1;36m(VllmWorkerProcess pid=3582467)[0;0m INFO 02-25 10:23:46 model_runner.py:1116] Loading model weights took 1.4820 GB
[1;36m(VllmWorkerProcess pid=3582467)[0;0m INFO 02-25 10:23:51 worker.py:266] Memory profiling takes 4.72 seconds
[1;36m(VllmWorkerProcess pid=3582467)[0;0m INFO 02-25 10:23:51 worker.py:266] the current vLLM instance can use total_gpu_memory (47.53GiB) x gpu_memory_utilization (0.90) = 42.78GiB
[1;36m(VllmWorkerProcess pid=3582467)[0;0m INFO 02-25 10:23:51 worker.py:266] model weights take 1.48GiB; non_torch_memory takes 0.65GiB; PyTorch activation peak memory takes 1.14GiB; the rest of the memory reserved for KV Cache is 39.51GiB.
[1;

Capturing CUDA graph shapes: 100%|██████████| 35/35 [00:19<00:00,  1.76it/s]

INFO 02-25 10:24:14 model_runner.py:1563] Graph capturing finished in 20 secs, took 0.46 GiB
[1;36m(VllmWorkerProcess pid=3582467)[0;0m INFO 02-25 10:24:14 model_runner.py:1563] Graph capturing finished in 20 secs, took 0.46 GiB
[1;36m(VllmWorkerProcess pid=3582462)[0;0m INFO 02-25 10:24:14 model_runner.py:1563] Graph capturing finished in 20 secs, took 0.46 GiB
[1;36m(VllmWorkerProcess pid=3582472)[0;0m INFO 02-25 10:24:14 model_runner.py:1563] Graph capturing finished in 20 secs, took 0.46 GiB
INFO 02-25 10:24:14 llm_engine.py:429] init engine (profile, create kv cache, warmup model) took 28.27 seconds





In [5]:
sampling_params = SamplingParams(
    max_tokens=1024,
    temperature=0.7,
    top_p=0.95,
    n=2,
)

In [9]:
def extract_final_answer(response_text_list: List[str], verbose: bool = False) -> Tuple[List[str], List[str]]:
    # # get the last 4 lines of the response text
    last_four_lines_list = []
    for response_text in response_text_list:
        response_text = response_text.strip()
        last_four_lines = "".join(response_text.split("\n")[-4:])
        last_four_lines_list.append(last_four_lines)

    # returning `failed_last_line_list` for debugging purposes
    final_answer_list = []
    failed_list = []

    for last_four_lines in last_four_lines_list:
        # extract final answer with latex box: \boxed{}, \fbox{}, \framebox{}, \x08oxed{}
        boxed_answer = last_boxed_only_string(last_four_lines)
        # if no boxed answer is found, use an error message as the placeholder for the final answer
        if not boxed_answer:
            if verbose:
                print(f"Error: no boxed answer found in the last four lines: {last_four_lines}")
            final_answer_list.append("Error: no boxed answer found")
            failed_list.append(last_four_lines)
            continue
        # if the boxed answer is found, remove the latex box
        else:
            final_answer = remove_boxed(boxed_answer)
            final_answer_list.append(final_answer)

    return final_answer_list, failed_list

def compute_accuracy(
    benchmark: Benchmark, ground_truth_list: List[str], final_answer_list: List[str], verbose: bool = False
) -> List[bool]:
    # check if the number of final answers and ground truths are equal
    assert len(final_answer_list) == len(ground_truth_list), "The number of final answers and ground truths should be equal."

    # initialize the scorer from OlympiadBench, it's almost compatible with the AIME_2024 and MATH benchmarks
    # excpet for cases like \$18.90 and 18.90, which can be handled by `is_equiv` but not `AutoScoringJudge`
    # in general, the `AutoScoringJudge` is more robust and can handle more cases, see test cases in `evaluation.py`
    scorer = AutoScoringJudge()
    accuracy_result_list = []

    # use the corresponding accuracy metric for the benchmark
    for index, line in enumerate(final_answer_list):
        ground_truth = ground_truth_list[index]
        final_answer = final_answer_list[index]

        # if failed to extract the final answer, set the accuracy to False
        if final_answer == "Error: no boxed answer found":
            accuracy_result = False
        # for AIME 2024, `AutoScoringJudge` is completely compatible
        elif benchmark == "AIME_2024":
            accuracy_result = scorer.judge(ground_truth, final_answer)
        # for MATH, use both `is_equiv` and `AutoScoringJudge` for more robust equivalence checking
        elif benchmark == "MATH-500":
            accuracy_result = is_equiv(ground_truth, final_answer) or scorer.judge(ground_truth, final_answer)
        # for OlympiadBench, use the native `AutoScoringJudge`
        elif benchmark == "OlympiadBench-674-MATH_TO_EN":
            ground_truth_answer, precision = ground_truth
            if not precision:
                accuracy_result = scorer.judge(ground_truth_answer, final_answer)
            else:
                accuracy_result = scorer.judge(ground_truth_answer, final_answer, precision=float(precision))
        # other benchmarks are not supported
        else:
            raise ValueError(f"Benchmark: {benchmark} is not supported.")

        # if `verbose` is set True, print the final answer and ground truth
        if verbose:
            print(f"Ground Truth: {ground_truth}, Final Answer: {final_answer}, Accuracy: {accuracy_result}")

        accuracy_result_list.append(accuracy_result)

    return accuracy_result_list


def evaluate(benchmark, responses, ground_truth_list, verbose=False):

    # construct the final answer list
    final_answer_list, failed_list = extract_final_answer(responses, verbose)

    # compute the accuracy result list
    accuracy_result_list = compute_accuracy(benchmark, ground_truth_list, final_answer_list, verbose)
    
    return accuracy_result_list

In [31]:
class Vine(Node):
    def __init__(self, demonstration_steps: List[str], llm: LLM, sampling_params: SamplingParams, curr_step_index: int, target: str, value: float, benchmark: Benchmark):
        super().__init__()
        self.value = value
        self.llm = llm
        self.sampling_params = sampling_params
        self.curr_step_index = curr_step_index
        self.target = target
        self.demonstration_steps = demonstration_steps
        self.benchmark = benchmark
        self.roll_out()
        
    def roll_out(self):
        
        # generate responses
        self.responses = self.llm.generate(
            "".join(self.demonstration_steps[:self.curr_step_index+1]),
            sampling_params=self.sampling_params,
        )
        
        self.responses = [response.text for response in self.responses[0].outputs]
        
        # evaluate the responses
        rewards = evaluate(self.benchmark, self.responses, [self.target]*len(self.responses))
        
        # compute average reward (Q-value)
        self.q_value = np.mean(rewards)
        
        # compute the advantage
        self.advantage = self.q_value - self.value
    
    def find_children(self):
        if self.curr_step_index == len(self.demonstration_steps) - 1:
            return []
        
        return [
            Vine(
                demonstration_steps=self.demonstration_steps,
                llm=self.llm,
                sampling_params=self.sampling_params,
                curr_step_index=self.curr_step_index + 1,
                target=self.target,
                value=self.q_value,
                benchmark=self.benchmark
            )
        ]
    
    def find_random_child(self):
        return self.find_children()[0]
    
    def reward(self):
        return evaluate(self.benchmark, self.responses, [self.target]*len(self.responses))
    
    def is_terminal(self):
        return self.curr_step_index == len(self.demonstration_steps) - 1
    
    def make_move(self):
        return self.find_children()[0]

In [32]:
model_name = model.split("/")[-1]
step_limit = 10
from collections import defaultdict

In [40]:
benchmark = "MATH-500"
games = []
for i, problem in enumerate(datapoints):
    target = problem["answer"]
    demonstration_steps = [problem['problem']] + problem["solution"].split(".")
    demonstration_tokens = []
    for step in demonstration_steps:
        demonstration_tokens.extend(tokenizer.encode(step))
    curr_step_index = 0
    game = {}
    game['problem'] = problem
    game['index'] = i
    game['demonstration_steps'] = demonstration_steps
    game['advantage'] = []
    game['q_value'] = []
    game['value'] = []
    while True:
        vinegame = Vine(
            demonstration_steps=demonstration_steps, 
            llm=llm, 
            sampling_params=sampling_params, 
            curr_step_index=curr_step_index, 
            target=target, 
            value=0, 
            benchmark=benchmark)
        curr_step_index += 1
        game['advantage'].append(vinegame.advantage)
        game['q_value'].append(vinegame.q_value)
        game['value'].append(vinegame.value)
        
        if curr_step_index == len(demonstration_steps):
            break
        else:
            vinegame = vinegame.find_children()[0]
    breakpoint()
            


Processed prompts:  50%|█████     | 1/2 [00:03<00:03,  3.46s/it, est. speed input: 7.23 toks/s, output: 194.43 toks/s]


ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1
ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1


Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  2.21it/s, est. speed input: 108.45 toks/s, output: 214.69 toks/s]


ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1
ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1


Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  1.08it/s, est. speed input: 53.16 toks/s, output: 218.04 toks/s]


ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1
ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1


Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  2.87it/s, est. speed input: 245.79 toks/s, output: 213.94 toks/s]


ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1
ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1


Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  2.83it/s, est. speed input: 241.63 toks/s, output: 170.55 toks/s]


ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1
ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1


Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  9.41it/s, est. speed input: 943.56 toks/s, output: 133.41 toks/s]


ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1
ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1


Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  9.65it/s, est. speed input: 967.71 toks/s, output: 136.80 toks/s]


ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1
ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1


Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  8.94it/s, est. speed input: 901.15 toks/s, output: 127.40 toks/s]


ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1
ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1


Processed prompts:  50%|█████     | 1/2 [00:00<00:00,  9.40it/s, est. speed input: 941.18 toks/s, output: 133.07 toks/s]

ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1
ANTLR runtime and generated code versions disagree: 4.13.2!=4.11.1



