In [1]:
import re
import os
import json
import string
from tqdm import tqdm
import torch
import numpy as np
from openai import OpenAI
from collections import Counter
import backoff

In [2]:
# Connection settings for the locally served, OpenAI-compatible model endpoint.
port = 8003
model_name = 'Qwen/QwQ-32B'

# Client-side request timeout: 24 hours, expressed in seconds.
OPENAI_REQUEST_TIMEOUT = 24 * 60 * 60
client = OpenAI(
    base_url=f"http://localhost:{port}/v1",
    api_key="EMPTY",
    timeout=OPENAI_REQUEST_TIMEOUT,
)
# Listing the served models also acts as a quick check that the endpoint is reachable.
print(client.models.list())

# NOTE: the retry decorator below is commented out, so despite its name this
# wrapper currently makes a single attempt and lets any exception propagate.
# @backoff.on_exception(backoff.constant, Exception, interval=5)
def run_chat_completion_with_backoff(client, **kwargs):
    """Forward ``kwargs`` to ``client.chat.completions.create`` and return its result."""
    chat_completions_api = client.chat.completions
    return chat_completions_api.create(**kwargs)


@backoff.on_exception(backoff.constant, Exception, interval=5)
def run_generate_with_backoff(client, **kwargs):
    """Forward ``kwargs`` to ``client.completions.create`` and return its result.

    The ``backoff`` decorator re-invokes this function whenever it raises any
    ``Exception``, using a constant wait (``interval=5``) between attempts.
    """
    completions_api = client.completions
    return completions_api.create(**kwargs)

# load direct pred logs
# Task name -> JSON file holding the direct-prediction run for that task.
_DIRECT_PRED_LOG_PATHS = {
    'gpqa': '/fsx-comem/diwu0162/Search-o1/outputs/gpqa.qwq.direct/diamond.7.1,20:8.json',
    'aime': '/fsx-comem/diwu0162/Search-o1/outputs/aime.qwq.direct/test.7.1,19:42.json',
    'amc': '/fsx-comem/diwu0162/Search-o1/outputs/amc.qwq.direct/test.7.1,19:24.json',
    'math500': '/fsx-comem/diwu0162/Search-o1/outputs/math500.qwq.direct/test.7.1,20:17.json',
    'livecode': '/fsx-comem/diwu0162/Search-o1/outputs/livecode.qwq.direct/test_1to4.7.1,22:30.json',
    'bamboogle': '/fsx-comem/diwu0162/Search-o1/outputs/runs.qa/bamboogle.qwq.direct/test.7.1,16:22.json',
}


def load_direct_pred_logs(task):
    """Load the direct-prediction log for ``task`` and return the parsed JSON.

    Args:
        task: One of the keys of ``_DIRECT_PRED_LOG_PATHS``.

    Returns:
        Whatever ``json.load`` produces for the task's log file.

    Raises:
        NotImplementedError: If ``task`` has no registered log file.
    """
    # The isinstance guard keeps unhashable/non-string inputs on the same
    # NotImplementedError path as any other unknown task.
    log_path = _DIRECT_PRED_LOG_PATHS.get(task) if isinstance(task, str) else None
    if log_path is None:
        raise NotImplementedError(f'no direct-prediction log registered for task {task!r}')
    # Use a context manager so the file handle is closed deterministically.
    with open(log_path) as log_file:
        return json.load(log_file)

def segment_thoughts_v1(x):
    """Strip ``x`` and split it into chunks at every double newline."""
    trimmed = x.strip()
    paragraphs = trimmed.split('\n\n')
    return paragraphs

def segment_thoughts_v2(x):
    """Segment a reasoning trace into "thoughts" that begin with a reasoning marker.

    The text is first split into paragraphs with ``segment_thoughts_v1``. A
    paragraph opens a new thought when it starts (case-insensitively) with one
    of the marker phrases below; otherwise it is appended to the previous
    thought, re-joined with a blank line. The first paragraph always opens a
    thought.

    Args:
        x: The raw reasoning text.

    Returns:
        A list of thought strings, in order.
    """
    # note: excluding things like "so" "therefore", "but", "let me"
    reasoning_word_list = [
        'okay', 'hmm', 'wait', 'but wait', 'oh wait', 'no wait', 'no, wait', 'but let me', 'but actually', 'alternatively',
        'now', 'the question', 'ah', 'oh', 'next', 'another angle', 'another approach', 'also', 'hold on', 'looking it up',
        'another point', 'I don\'t think', 'perhaps I', 'putting this together', 'Putting it all together', 'i\'m', 'but i\'m',
        'let me think again', 'I don\'t see', 'maybe I', 'alternative', "I wonder if", "another way", 'an alternative',
    ]
    # Lowercase the markers once; str.startswith accepts a tuple of prefixes.
    # This is a plain prefix test, so e.g. 'now' also matches "nowhere".
    marker_prefixes = tuple(marker.lower() for marker in reasoning_word_list)

    final_thoughts = []
    for paragraph in segment_thoughts_v1(x):
        opens_new_thought = paragraph.lower().startswith(marker_prefixes)
        if opens_new_thought or not final_thoughts:
            final_thoughts.append(paragraph)
        else:
            final_thoughts[-1] += '\n\n' + paragraph
    return final_thoughts


SyncPage[Model](data=[Model(id='Qwen/QwQ-32B', created=1752164388, object='model', owned_by='vllm', root='Qwen/QwQ-32B', parent=None, max_model_len=40960, permission=[{'id': 'modelperm-1733db744c8344ad8e775c7040500d5c', 'object': 'model_permission', 'created': 1752164388, 'allow_create_engine': False, 'allow_sampling': True, 'allow_logprobs': True, 'allow_search_indices': False, 'allow_view': True, 'allow_fine_tuning': False, 'organization': '*', 'group': None, 'is_blocking': False}])], object='list')


In [16]:
# evaluation helper (right now only supporting choice, qa, and math)

def _fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string

def _fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        a = int(a)
        b = int(b)
        assert string == "{}/{}".format(a, b)
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except:
        return string

def _remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        assert len(splits) == 2
        return splits[0]
    else:
        return string

def _fix_sqrt(string):
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0] 
    for split in splits[1:]:
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string

def _strip_string(string):
    """Normalise a LaTeX-ish answer string so equivalent answers compare equal.

    Applies, in this order: removal of newlines, ``\\!``, ``\\left``/``\\right``,
    degree marks, escaped dollar and percent signs and right-hand units;
    ``tfrac``/``dfrac`` -> ``frac``; leading-zero insertion for ``.5``-style
    decimals; removal of a short ``x =`` prefix; ``sqrt``/``frac`` brace
    repair; space removal; and ``a/b`` -> ``\\frac{a}{b}``. The order matters
    because later steps rely on earlier ones.

    Args:
        string: Raw answer text.

    Returns:
        The normalised string (possibly empty).
    """
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage
    # The second call used to be spelled "\%", an invalid escape sequence that
    # Python reads as the same two characters but warns about. Both passes are
    # kept because removing one "\%" can expose another (e.g. in "\\%%").
    string = string.replace("\\%", "")
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    eq_parts = string.split("=")
    if len(eq_parts) == 2 and len(eq_parts[0]) <= 2:
        string = eq_parts[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string

def is_equiv(str1, str2, verbose=False):
    """Return whether two answer strings are equal after normalisation.

    Args:
        str1: First answer, or ``None``.
        str2: Second answer, or ``None``.
        verbose: If true, print both normalised strings before comparing.

    Returns:
        ``True`` when both inputs are ``None`` (a warning is printed),
        ``False`` when exactly one is ``None``, otherwise the result of
        comparing the ``_strip_string``-normalised forms. If normalisation
        raises, falls back to comparing the raw inputs.
    """
    if str1 is None and str2 is None:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False

    try:
        ss1 = _strip_string(str1)
        ss2 = _strip_string(str2)
        if verbose:
            print(ss1, ss2)
        return ss1 == ss2
    except Exception:
        # Best-effort fallback when normalisation fails. Catching Exception
        # rather than everything lets KeyboardInterrupt/SystemExit propagate.
        return str1 == str2
        
def _last_boxed_content(text):
    r"""Return the content of the last brace-balanced ``\boxed{...}`` in ``text``.

    Braces are counted so nested groups such as ``\boxed{\frac{1}{2}}`` are
    returned whole, and two boxes on the same line are not merged. A box must
    close on the line it opens on. If the last ``\boxed{`` never closes, earlier
    occurrences are tried. Returns ``None`` when no closed box exists.
    """
    marker = '\\boxed{'
    search_end = len(text)
    while True:
        start = text.rfind(marker, 0, search_end)
        if start == -1:
            return None
        content_start = start + len(marker)
        depth = 1
        for pos in range(content_start, len(text)):
            ch = text[pos]
            if ch == '\n':
                break
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    return text[content_start:pos]
        # This occurrence never closed on its line; look for an earlier one.
        search_end = start


def extract_answer_fn(output, mode='qa', extract_answer=False):
    r"""Extract the final answer from a raw model output.

    Args:
        output: Full generated text.
        mode: One of ``'qa'``, ``'math'``, ``'choose'``, ``'codegen'``,
            ``'infogen'``, ``'summary'``, ``'research'``.
        extract_answer: When false (and ``mode`` is not one of the
            info/summary/research modes) no structured extraction is done:
            ``'qa'`` returns the stripped output and other modes return its
            last three lines.

    Returns:
        The extracted text, or ``''`` when nothing could be extracted:
        - ``'codegen'``: body of the last ```` ```python ```` fenced block.
        - ``'infogen'``/``'summary'``/``'research'``: text after ``</think>`` or
          after ``**Final Information``, else the last five lines; truncated
          to 6000 characters for ``'research'`` and 2500 otherwise.
        - ``'math'``/``'choose'``/``'qa'``: content of the last
          ``\boxed{...}``, else the text after the last ``ANSWER:``. For
          ``'choose'`` an inner ``\text{...}`` is unwrapped and surrounding
          parentheses are stripped.
    """
    if extract_answer == False and mode not in ['infogen', 'summary', 'research']:
        if mode == 'qa':
            return output.strip()
        pred_answer_lines = output.replace("\n\n", "\n").strip().split('\n')
        pred_answer = '\n'.join(pred_answer_lines[-3:])
        return pred_answer
    extracted_text = ''
    if mode == 'codegen':
        pattern = r'```python\s*(.*?)\s*```'  # Extract the code between ```python and ```
        matches = re.findall(pattern, output, re.DOTALL | re.IGNORECASE)
        if matches:
            extracted_text = matches[-1].strip()  # Take the last match
    elif mode in ['infogen', 'summary', 'research']:
        pattern_info = "**Final Information"
        if "</think>\n" in output:
            extracted_text = output.split("</think>\n")[-1].split("<|begin_click_link|>")[0].replace(pattern_info, "").strip(':**').strip('\n').strip("```").strip()  # Take the content after </think>
            if mode == 'infogen':
                extracted_text = '\n'.join(extracted_text.replace("\n\n", "\n").split('\n')[:5])  # Keep only the first 5 lines
        elif pattern_info in output:
            extracted_text = output.split(pattern_info)[-1].split("<|begin_click_link|>")[0].strip('\n').strip(':**').strip("```").strip()  # Take the content after **Final Information**
            if mode == 'infogen':
                extracted_text = '\n'.join(extracted_text.replace("\n\n", "\n").split('\n')[:5])  # Keep only the first 5 lines
        else:
            # extracted_text = "No helpful information found."
            extracted_text = '\n'.join(output.strip().replace("</think>\n", "").replace("\n\n", "\n").split('\n')[-5:])  # Nothing matched: keep only the last 5 lines
        if mode == 'research':
            extracted_text = extracted_text[:6000]
        else:
            extracted_text = extracted_text[:2500]
    elif mode in ['math', 'choose', 'qa']:
        # A greedy regex over "\boxed{A} ... \boxed{B}" on one line captured
        # "A} ... \boxed{B"; brace matching returns just the last box.
        boxed = _last_boxed_content(output)
        if boxed is not None:
            extracted_text = boxed
        else:
            pattern = 'ANSWER:'
            if pattern in output:
                # Strip whitespace before the markdown asterisks; the reverse
                # order left "**C" for "ANSWER: **C**".
                extracted_text = output.split(pattern)[-1].strip().strip('*').strip()
        if mode in ['choose']:
            inner_pattern = r'\\text\{(.*)\}'
            inner_matches = re.findall(inner_pattern, extracted_text)
            if inner_matches:
                extracted_text = inner_matches[-1]  # Take the last match
            extracted_text = extracted_text.strip("()")
    return extracted_text
    
def evaluate_predictions(output, labeled_answer, mode='math', use_llm=False, question=None, extract_answer=False):
    """Score one model output against the reference answer(s).

    Args:
        output: Raw generated text.
        labeled_answer: For ``mode='qa'`` an iterable of acceptable answer
            strings (a single string is treated as a one-element list); for
            ``'math'``/``'choose'`` a single answer string.
        mode: ``'qa'``, ``'math'`` or ``'choose'``; other modes are only
            passed to ``extract_answer_fn`` and leave the metrics at zero.
        use_llm: With ``question`` set, marks ``llm_equal`` as 0 (placeholder;
            no LLM call is made here).
        question: See ``use_llm``.
        extract_answer: Forwarded to ``extract_answer_fn``.

    Returns:
        ``(final_metric, pred_answer)`` where ``final_metric`` has keys
        ``is_valid_answer``, ``acc``, ``em``, ``f1``, ``math_equal`` and
        ``llm_equal``, and ``pred_answer`` is the extracted answer (``''``
        when extraction found nothing).
    """
    final_metric = {"is_valid_answer": False, "acc": 0, "em": 0, "f1": 0, 'math_equal': 0, 'llm_equal': 0}
    pred_answer = extract_answer_fn(output, mode=mode, extract_answer=extract_answer)
    pred_answer_new = pred_answer
    if pred_answer != '':
        final_metric["is_valid_answer"] = True
    else:
        # If no answer was extracted, score against the last 5 lines of the output
        pred_answer_new = '\n'.join(output.replace("\n\n", "\n").strip().split('\n')[-5:])

    if mode in ['qa']:
        def normalize_answer_qa(s):
            def remove_articles(text):
                return re.sub(r"\b(a|an|the)\b", " ", text)
            def white_space_fix(text):
                return " ".join(text.strip().split())
            def remove_punc(text):
                exclude = set(string.punctuation)
                return "".join(ch for ch in text if ch not in exclude)
            def lower(text):
                return text.lower()
            return white_space_fix(remove_articles(remove_punc(lower(s))))
        normalized_pred_answer = normalize_answer_qa(pred_answer_new)
        prediction_tokens = normalized_pred_answer.split()

        # A bare string would otherwise be iterated character by character.
        gold_answers = [labeled_answer] if isinstance(labeled_answer, str) else labeled_answer
        for answer in gold_answers:
            normalized_ground_truth = normalize_answer_qa(answer)
            ground_truth_tokens = normalized_ground_truth.split()
            if not ground_truth_tokens:
                # Nothing to compare against (an empty reference is a substring
                # of anything, which would make acc meaningless).
                continue
            em = int(normalized_pred_answer == normalized_ground_truth)
            acc = int(normalized_ground_truth in normalized_pred_answer)

            common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
            num_same = sum(common.values())
            if num_same == 0:
                # No shared token: f1 is 0, but acc (substring containment) may
                # still be 1, so the scores are still recorded below.
                f1 = 0
            else:
                precision = 1.0 * num_same / len(prediction_tokens)
                recall = 1.0 * num_same / len(ground_truth_tokens)
                f1 = (2 * precision * recall) / (precision + recall)
            # Keep the best score over all acceptable answers.
            scores = {"em": em, "acc": acc, "f1": f1}
            for k, score in scores.items():
                final_metric[k] = max(score, final_metric[k])

    elif mode in ['math', 'choose']:
        def normalize_answer(text):
            text = text.lower()
            text = " ".join(text.strip().split())
            return text
        normalized_pred_answer = normalize_answer(pred_answer_new)
        normalized_ground_truth = normalize_answer(labeled_answer)

        em = int(normalized_pred_answer == normalized_ground_truth)
        acc = int(normalized_ground_truth in normalized_pred_answer)

        prediction_tokens = normalized_pred_answer.split()
        ground_truth_tokens = normalized_ground_truth.split()
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            f1 = 0
        else:
            precision = 1.0 * num_same / len(prediction_tokens) if len(prediction_tokens) > 0 else 0
            recall = 1.0 * num_same / len(ground_truth_tokens) if len(ground_truth_tokens) > 0 else 0
            if (precision + recall) == 0:
                f1 = 0
            else:
                f1 = (2 * precision * recall) / (precision + recall)

        final_metric["em"] = em
        final_metric["acc"] = acc
        final_metric["f1"] = f1

        final_metric["math_equal"] = is_equiv(normalized_pred_answer, normalized_ground_truth)

        # Add LLM-based evaluation if requested
        if use_llm and question is not None:
            final_metric["llm_equal"] = 0  # Will be updated in batch later

    return final_metric, pred_answer

In [5]:
# load rollout log
# Pick the evaluation mode and the headline metric for the chosen task.
task = 'gpqa'
if task in ['gpqa']:
    eval_task_type = 'choose'
    metric_name = 'acc'
elif task in ['aime', 'amc', 'math500']:
    eval_task_type = 'math'
    metric_name = 'math_equal'
# elif task in ['livecode']:
#     eval_task_type = 'code'
elif task in ['gaia', 'bamboogle']:
    eval_task_type = 'qa'
    metric_name = 'f1'
else:
    raise NotImplementedError
assert eval_task_type in ['math', 'choose', 'qa']   # not supporting code here


# Keep only the examples the direct run got wrong: falsy metric for the binary
# metrics, below 0.5 for f1.
orig_pred_logs = load_direct_pred_logs(task)
if metric_name in ['acc', 'math_equal']:
    orig_pred_logs_wrong = [x for x in orig_pred_logs if not x['Metrics'][metric_name]]
else:
    orig_pred_logs_wrong = [x for x in orig_pred_logs if x['Metrics'][metric_name] < 0.5]

# The rollout log holds one JSON object per line.
# NOTE(review): the file name says math500 while `task` is 'gpqa' — confirm this is the intended file.
rollout_log_path = '/fsx-comem/diwu0162/Search-o1/explorations/20250707_rollout_analysis/perf_dicts_wrong_only_math500_20250708-1641.json'
# Context manager closes the handle; blank lines are skipped instead of failing in json.loads.
with open(rollout_log_path) as rollout_file:
    rollout_logs = [json.loads(line) for line in rollout_file if line.strip()]
# len(rollout_logs)
# idx = 1
# print([(x, np.mean(rollout_logs[idx]['perf_dict'][str(x)]['metrics'])) for x in range(len(rollout_logs[idx]['perf_dict']))])
# print(json.dumps(rollout_logs[idx]['item'], indent=4))

In [14]:
def roll_out_single_node(prefix, steps, step_id, answer, eval_task_type, metric_name, n_sample_per_node, node_join_char='\n\n'):
    """Sample continuations from the prefix ending just before step ``step_id`` and score them.

    The prompt is ``prefix`` followed by ``steps[:step_id]`` joined with
    ``node_join_char``. ``n_sample_per_node`` completions are requested from
    the module-level ``client``/``model_name`` and each is scored with
    ``evaluate_predictions`` against ``answer``.

    Returns:
        ``{step_id: {'preds': raw completions, 'preds_processed': extracted
        answers, 'metrics': the ``metric_name`` value per completion}}``.
    """
    print(f'Roll out at (before) node {step_id}, {n_sample_per_node} samples per node...')

    kept_steps = steps[:step_id]
    prompt_text = prefix + node_join_char.join(kept_steps)
    print('Prompt:', [prompt_text])

    sampling_extras = {'top_k': 20, 'include_stop_str_in_output': True, 'repetition_penalty': 1.05}
    responses = client.completions.create(
        model=model_name,
        prompt=prompt_text,
        n=n_sample_per_node,
        temperature=0.7,
        top_p=0.8,
        max_tokens=20000,
        timeout=OPENAI_REQUEST_TIMEOUT,
        extra_body=sampling_extras,
    )

    sample_preds = [choice.text for choice in responses.choices]

    # Score each completion, collecting (extracted answer, metric value) pairs.
    scored_samples = []
    for completion_text in sample_preds:
        sample_scores, extracted_answer = evaluate_predictions(
            completion_text, answer, mode=eval_task_type, use_llm=False, question=None, extract_answer=True)
        scored_samples.append((extracted_answer, sample_scores[metric_name]))
    sample_preds_processed = [extracted for extracted, _ in scored_samples]
    metrics = [value for _, value in scored_samples]

    mean_metric = np.mean(metrics)
    print(metrics)
    print(f'Mean metrics {round(mean_metric, 4)}')
    return {
        step_id: {
            'preds': sample_preds,
            'preds_processed': sample_preds_processed,
            'metrics': metrics,
        }
    }


def roll_out_single_node_with_hint(orig_prefix, hint, steps, step_id, answer, eval_task_type, metric_name, n_sample_per_node, node_join_char='\n\n', max_tokens=24000):
    """Like ``roll_out_single_node``, but with a hint injected before step ``step_id``.

    The instruction inside ``orig_prefix`` is extended with a notice that
    hints may appear. The prompt is then that prefix, ``steps[:step_id]``
    joined with ``node_join_char``, and ``[hint] ... [end of hint]`` followed
    by ``Okay,`` to start the continuation.

    Args:
        orig_prefix: Prompt prefix containing the task instruction sentence.
        hint: Hint text to inject.
        steps: Segmented reasoning steps of the original trajectory.
        step_id: Index of the first step to regenerate.
        answer: Reference answer passed to ``evaluate_predictions``.
        eval_task_type: Evaluation mode passed to ``evaluate_predictions``.
        metric_name: Key of the metric to collect per sample.
        n_sample_per_node: Number of completions to sample.
        node_join_char: Separator used between kept steps.
        max_tokens: Generation budget per completion.

    Returns:
        ``{step_id: {'preds': ..., 'preds_processed': ..., 'metrics': ...}}``.

    Raises:
        NotImplementedError: If the module-level ``task`` has no hint template.
        AssertionError: If the expected instruction sentence is not in
            ``orig_prefix``.
    """
    hint_notice = (' Hints might be provided during your question answering wrapped within [hint] and [end of hint].'
                   ' If you see hints, try to leverage them to guide your thinking process.')

    def process_question_instruct_add_hint(instruction):
        # Reads the notebook-level `task` to pick the instruction sentence to extend.
        if task in ['gpqa']:
            opening = 'Please answer the following multiple-choice question.'
        elif task in ['aime', 'amc', 'math500']:
            opening = 'Please answer the following math question.'
        elif task in ['bamboogle']:
            opening = 'Please answer the following question.'
        else:
            # Covers 'livecode' and any other task without a template; other
            # tasks used to fail with UnboundLocalError instead.
            raise NotImplementedError(f'no hint instruction template for task {task!r}')
        instruction_new = instruction.replace(opening, opening + hint_notice)
        assert instruction_new != instruction, 'instruction sentence not found in the prompt prefix'
        return instruction_new

    print(f'Roll out with hint injected at (before) node {step_id}, {n_sample_per_node} samples per node...')

    perf_dict = {}

    cur_prefix_nodes = steps[:step_id]
    formatted_question = process_question_instruct_add_hint(orig_prefix)
    hint_str = f'[hint] {hint} [end of hint]\n\nOkay,'
    # NOTE(review): no separator is inserted between the last kept step and
    # "[hint]" when step_id > 0 — confirm this is intended.
    final_prompt = formatted_question + node_join_char.join(cur_prefix_nodes) + hint_str

    print('Prompt:', [final_prompt])

    responses = client.completions.create(model=model_name, prompt=final_prompt, n=n_sample_per_node, temperature=0.7, top_p=0.8, max_tokens=max_tokens, timeout=OPENAI_REQUEST_TIMEOUT,
                                          extra_body={'top_k': 20, 'include_stop_str_in_output': True, 'repetition_penalty': 1.05,})

    sample_preds = [x.text for x in responses.choices]
    sample_preds_processed = []
    metrics = []
    for cur_pred in sample_preds:
        metrics_dict, processed_pred = evaluate_predictions(cur_pred, answer, mode=eval_task_type, use_llm=False, question=None, extract_answer=True)
        sample_preds_processed.append(processed_pred)
        metrics.append(metrics_dict[metric_name])
    perf_dict[step_id] = {
        'preds': sample_preds,
        'preds_processed': sample_preds_processed,
        'metrics': metrics
    }
    cur_node_mean_metrics = np.mean(metrics)
    print(metrics)
    print(f'Mean metrics {round(cur_node_mean_metrics, 4)}')
    return perf_dict



In [None]:

# Pick one wrongly answered example and split its reasoning into steps.
idx = 18
entry = orig_pred_logs_wrong[idx]
question_only_prefix = entry['Question']
if 'answer' in entry:
    answer = entry['answer']
else:
    answer = entry["Correct Choice"]
# Everything before the final '</think>' is treated as the reasoning trace.
think_chunks = entry['Output'].split('</think>')[:-1]
pred_steps = segment_thoughts_v2('\n\n'.join(think_chunks))

print('Original ID:', entry['id'])
print('Question:', [question_only_prefix])
print('Answer:', answer)

# Optional inspection / baseline rollout snippets, kept disabled:
# len(pred_steps)
# print(*pred_steps, sep='\n=============\n')
# print(json.dumps(entry, indent=4))

# step_to_inspect = 0
# # perf_dict_orig = roll_out_single_node(question_only_prefix, pred_steps, step_to_inspect, answer, eval_task_type, metric_name, n_sample_per_node=20, node_join_char='\n\n')
# print(np.mean(perf_dict_orig[step_to_inspect]['metrics']), perf_dict_orig[step_to_inspect]['metrics'])
# print(json.dumps(perf_dict_orig[step_to_inspect]['preds_processed'], indent=4))
# # perf_dict_orig[step_to_inspect]['preds']

# Inject the hint before the first step and sample fresh completions.
step_to_intervene = 0
hint = "For any l=0 \u2192 0 cascade and beyond, follow the reasoning logic:\na) Fix the only l-sequence consistent with two \u0394l = \u00b11 jumps.  \nb) List every m\u2032 that the first E1 step can access (\u0394m = 0, \u00b11).  \nc) Use Clebsch\u2013Gordan factors to divide the amplitude among those m\u2032\u2019s.  \nd) Follow each fraction through any subsequent steps; only then match an individual branch to the answer choices.\n\nIf I stick to that four-point audit trail, the correct probability falls out automatically and I never have to rely on memory tricks."
hint = hint.strip()
perf_dict_intervene = roll_out_single_node_with_hint(
    question_only_prefix,
    hint,
    pred_steps,
    step_to_intervene,
    answer,
    eval_task_type,
    metric_name,
    n_sample_per_node=10,
    node_join_char='\n\n',
    max_tokens=20000,
)
intervene_metrics = perf_dict_intervene[step_to_intervene]['metrics']
print(np.mean(intervene_metrics), intervene_metrics)
print(json.dumps(perf_dict_intervene[step_to_intervene]['preds_processed'], indent=4))


Original ID: 55
Question: ['<|im_start|>user\nPlease answer the following multiple-choice question. You should provide your final choice in the format \\boxed{YOUR_CHOICE}.\n\nQuestion:\nThe |3,0,0\\rangle state in the standard notation |n,l,m\\rangle of the H -atom in the non-relativistic theory decays to the state |1,0,0\\rangle via two dipole transition. The transition route and the corresponding probability are (use latex), Choices:\n(A) \\rangle\\rightarrow|2,1,0\\rangle\\rightarrow|1,0,0\\rangle  and \\frac{2}{3}\n(B) |3,0,0\\rangle\\rightarrow|2,1,-1\\rangle\\rightarrow|1,0,0\\rangle and \\frac{1}{4}\n(C) |3,0,0\\rangle\\rightarrow|2,1,1\\rangle\\rightarrow|1,0,0\\rangle\nand \\frac{1}{4}\n(D) |3,0,0\\rangle\\rightarrow|2,1,0\\rangle\\rightarrow|1,0,0\\rangle and \\frac{1}{3}\n\n\n<|im_end|>\n<|im_start|>assistant\n<think>\n']
Answer: D
Roll out with hint injected at (before) node 0, 10 samples per node...
Prompt: ['<|im_start|>user\nPlease answer the following multiple-choice q

In [22]:
# Show the raw text of the second sampled completion at node 0, JSON-escaped
# so newlines and non-ASCII characters are visible.
second_sample_text = perf_dict_intervene[0]['preds'][1]
print(json.dumps(second_sample_text))

" so I need to figure out which transition path and probability is correct for the decay from |3,0,0\u27e9 to |1,0,0\u27e9 via two dipole transitions. Let me start by recalling some quantum mechanics here. \n\nFirst, dipole transitions are electric dipole (E1) transitions, right? And the selection rules for E1 transitions require that the angular momentum changes by \u0394l = \u00b11 and \u0394m = 0 or \u00b11. Also, the total angular momentum change has to satisfy certain conditions. Wait, but for hydrogen, since there's no orbital angular momentum for the electron except what comes from the electron's own motion, the selection rules are straightforward.\n\nThe problem states it's a two-step decay. So the initial state is n=3, l=0, m=0. It must go through an intermediate state before reaching the final state |1,0,0\u27e9. The options give different paths and probabilities. Let me go step by step using the hints provided.\n\nThe hint says: For l=0 \u2192 0 cascade via two steps, follow