In [None]:
from IPython.core.display import display, HTML

def highlight_corrected_tokens(response, corrected_tokens, tokenizer):
    """
    Render ``response`` as an HTML fragment, highlighting corrected tokens.

    ``response`` is re-tokenized with ``tokenizer`` and each token id is
    decoded individually, so output pieces correspond one-to-one with token
    positions. A position listed in ``corrected_tokens`` is rendered as the
    correction's ``"target"`` text on a red background; every other position
    is rendered as the decoded token text.

    Args:
        response: The generated answer text to tokenize and render.
        corrected_tokens: Iterable of dicts that each provide a ``"position"``
            (index into the tokenized ``response``) and a ``"target"``
            (replacement text to show). If a position occurs more than once,
            the first entry wins.
        tokenizer: Callable such that
            ``tokenizer(text, return_tensors="pt").input_ids[0].tolist()``
            yields token ids, and that offers ``decode(list_of_ids)``.

    Returns:
        A single HTML string with the rendered pieces joined by spaces.

    NOTE(review): positions are matched against the ids returned by the
    tokenizer call below. If that call adds special tokens (for example a
    leading BOS token), positions recorded at generation time may be shifted
    relative to these indices -- confirm against the generation code.
    """
    # Local import so this function stays self-contained within its cell.
    from html import escape

    token_ids = tokenizer(response, return_tensors="pt").input_ids[0].tolist()
    # Decode one id at a time so that tokens[i] corresponds to token_ids[i].
    tokens = [tokenizer.decode([tid]) for tid in token_ids]

    # Index the corrections once instead of scanning the list for every token.
    # setdefault keeps the first entry for a repeated position, which matches
    # a first-match linear scan.
    corrections_by_position = {}
    for correction in corrected_tokens:
        corrections_by_position.setdefault(correction["position"], correction)

    highlighted_text = []
    for i, token in enumerate(tokens):
        correction = corrections_by_position.get(i)
        if correction is not None:
            # Escape the replacement text: generated math frequently contains
            # '<', '>' and '&', which would otherwise be parsed as markup.
            target = escape(str(correction["target"]))
            highlighted_text.append(
                f'<span style="background-color: red; color: white;">{target}</span>'
            )
        else:
            # Escaping also makes special tokens such as "<s>" visible as
            # text instead of being swallowed as unknown HTML tags.
            highlighted_text.append(escape(token))

    return " ".join(highlighted_text)

def display_results(results, tokenizer):
    """
    Show each question and its generated answer in the notebook as HTML.

    For every item, the generated answer is passed through
    ``highlight_corrected_tokens`` so that corrected tokens are highlighted,
    and the result is rendered in a bordered box via ``display(HTML(...))``.

    Args:
        results: Iterable of dicts providing the keys ``"question"``,
            ``"generated_answer"`` and ``"corrected_tokens"``.
        tokenizer: Tokenizer forwarded unchanged to
            ``highlight_corrected_tokens``.

    Returns:
        None. The output is displayed as a side effect.
    """
    # Local import so this function stays self-contained within its cell.
    from html import escape

    for item in results:
        # Escape the question: it is plain text, and characters such as '<'
        # or '&' would otherwise be interpreted as HTML markup.
        question = escape(str(item["question"]))
        generated_answer = item["generated_answer"]
        corrected_tokens = item["corrected_tokens"]

        # The highlighted answer is already an HTML fragment (it contains
        # <span> tags), so it must not be escaped again here.
        highlighted_answer = highlight_corrected_tokens(generated_answer, corrected_tokens, tokenizer)

        html_content = f"""
        <div style="border: 1px solid #ccc; padding: 10px; margin-bottom: 10px;">
            <p><b>Question:</b> {question}</p>
            <p><b>Generated Answer:</b> {highlighted_answer}</p>
        </div>
        """
        display(HTML(html_content))

# Load the saved JSON results (read from the current working directory).
import json
from transformers import AutoTokenizer

with open("math500_results.json", "r", encoding="utf-8") as f:
    results = json.load(f)

# Load the tokenizer; it must be the same one that was used at generation
# time, otherwise the recorded token positions will not line up.
# NOTE(review): the result files evaluated elsewhere in this notebook are
# named after deepseek models, while this loads a Llama-2 tokenizer --
# confirm that this is the tokenizer that produced math500_results.json.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# Render every result with corrected tokens highlighted.
display_results(results, tokenizer)

In [1]:
import __init__
from utils.data_utils import read_saved_results
from utils.qwen_math_parser import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def check_math_correctness(ref, generation):
    """
    Decide whether ``generation`` contains an answer equal to ``ref``.

    Returns False immediately when ``find_box(generation)`` is falsy.
    Otherwise the reference and the answer extracted from ``generation`` are
    both normalized with ``strip_answer_string`` and compared with
    ``math_equal``, whose result is returned.

    Args:
        ref: The reference answer.
        generation: The generated text to check.
    """
    if find_box(generation):
        expected = strip_answer_string(ref)
        predicted = strip_answer_string(extract_answer(generation))
        return math_equal(predicted, expected)
    # find_box reported nothing usable, so the generation counts as wrong.
    return False


In [3]:
def check_results(path):
    """
    Score a saved results file and print per-mistake lines plus a summary.

    Every record loaded by ``read_saved_results(path)`` is checked with
    ``check_math_correctness``. For each incorrect record one line is printed
    with its 1-based index, reference answer, generation time, token count and
    number of corrected tokens. A final summary line reports:

    * ``num``       -- number of records
    * ``acc``       -- fraction of records judged correct
    * ``tokens/s``  -- total tokens divided by total generation time
    * ``token_num`` -- mean number of tokens per record
    * ``num_8k``    -- number of *incorrect* records longer than 8 * 1024 tokens
    * ``num_corr``  -- total corrected tokens divided by total tokens

    Ratios whose denominator is zero are reported as ``nan`` instead of
    raising ``ZeroDivisionError``; an empty result set prints a short notice
    and returns early.

    Args:
        path: Location of the saved results, passed to ``read_saved_results``.
            Each record must provide ``'answer'``, ``'generated_answer'``,
            ``'generation_time'``, ``'num_tokens'`` and ``'corrected_tokens'``.

    Returns:
        None. All output is printed.
    """
    results = read_saved_results(path)
    if len(results) == 0:
        # Nothing to score; every average below would divide by zero.
        print(f'num: 0, no results found in {path}')
        return

    def _ratio(numerator, denominator):
        # Guarded division so that zero totals cannot crash the summary.
        return numerator / denominator if denominator else float('nan')

    correct_num, all_time, all_tokens_num, num_8k, num_corr = 0, 0, 0, 0, 0
    # Totals are accumulated with explicit additions in record order so the
    # floating-point sums are built in exactly that order.
    for r_i, r in enumerate(results):
        right_flag = check_math_correctness(r['answer'], r['generated_answer'])
        if right_flag:
            correct_num = correct_num + 1
        all_time = all_time + r['generation_time']
        all_tokens_num = all_tokens_num + r['num_tokens']
        num_corr = num_corr + len(r['corrected_tokens'])
        if not right_flag:
            print('#:',r_i+1, 'ans:' ,r['answer'], 'time:' ,r['generation_time'], 'num_tokens' ,r['num_tokens'], 'num_corr', len(r['corrected_tokens']))
            # Only incorrect records are counted towards num_8k.
            if r['num_tokens'] > 8*1024:
                num_8k = num_8k + 1

    print(f'num: {len(results)}, acc: {correct_num/len(results)}, tokens/s: {_ratio(all_tokens_num, all_time)}, token_num:{all_tokens_num/len(results)}, num_8k:{num_8k}, num_corr:{_ratio(num_corr, all_tokens_num)}')

In [20]:
# Score the run stored at this path; going by the file name this is the
# deepseek-32b configuration with "None" as the second model -- TODO confirm
# the naming scheme. NOTE(review): absolute, machine-specific path.
check_results('/home/wxy320/ondemand/program/speculative_thinking/results/MATH500_deepseek-32b_None.json')

#: 2 ans: p - q time: 41.044524908065796 num_tokens 415 num_corr 0
#: 19 ans: 28 time: 339.0802273750305 num_tokens 3823 num_corr 0
#: 21 ans: 6+9i time: 90.07904958724976 num_tokens 1008 num_corr 0
#: 22 ans: 13535 time: 38.0232412815094 num_tokens 425 num_corr 0
#: 25 ans: 10 time: 235.81136083602905 num_tokens 2634 num_corr 0
#: 26 ans: 1,-2 time: 168.15165328979492 num_tokens 1886 num_corr 0
#: 33 ans: 720 time: 835.905914068222 num_tokens 8683 num_corr 0
#: 34 ans: \frac{243}{625} time: 96.51075792312622 num_tokens 1110 num_corr 0
#: 37 ans: 3, 5, 7 time: 138.31290650367737 num_tokens 1586 num_corr 0
#: 44 ans: 70 \sqrt{2} time: 101.02843022346497 num_tokens 1150 num_corr 0
#: 60 ans: 5 time: 73.34125518798828 num_tokens 844 num_corr 0
#: 72 ans: 2516_8 time: 37.99384117126465 num_tokens 436 num_corr 0
#: 81 ans: 501 time: 46.224913358688354 num_tokens 530 num_corr 0
#: 91 ans: \frac{3}{2} time: 23.03094792366028 num_tokens 263 num_corr 0
#: 95 ans: 80 time: 496.0704679489136 num_

In [4]:
# Score the run stored at this path; the file name pairs deepseek-32b with
# deepseek-1.5b and a setting of 0.05 -- TODO confirm what the trailing
# number denotes. NOTE(review): absolute, machine-specific path.
check_results('/home/wxy320/ondemand/program/speculative_thinking/results/MATH500_deepseek-32b_deepseek-1.5b_0.05.json')

#: 10 ans: 4 time: 392.92427802085876 num_tokens 5820 num_corr 428
#: 12 ans: \frac{3}{56} time: 58.335697412490845 num_tokens 961 num_corr 26
#: 16 ans: 6 - 5i time: 18.957124948501587 num_tokens 308 num_corr 10
#: 25 ans: 10 time: 31.40611743927002 num_tokens 502 num_corr 19
#: 27 ans: 144 time: 2531.0188755989075 num_tokens 32771 num_corr 1056
#: 37 ans: 3, 5, 7 time: 158.38452625274658 num_tokens 2642 num_corr 59
#: 44 ans: 70 \sqrt{2} time: 32.343971967697144 num_tokens 519 num_corr 16
#: 51 ans: 203 time: 121.12784457206726 num_tokens 1982 num_corr 57
#: 65 ans: \frac{35}{64} time: 232.97329211235046 num_tokens 3817 num_corr 104
#: 87 ans: 240 time: 30.359503746032715 num_tokens 434 num_corr 36
#: 90 ans: 21 time: 26.805914640426636 num_tokens 434 num_corr 14
#: 91 ans: \frac{3}{2} time: 19.383261919021606 num_tokens 296 num_corr 16
#: 95 ans: 80 time: 159.01009392738342 num_tokens 2459 num_corr 149
#: 97 ans: 1 \pm \sqrt{19} time: 356.400422334671 num_tokens 5773 num_corr 162
#:

In [22]:
# Score the run stored at this path; the file name pairs deepseek-32b with
# deepseek-1.5b and a setting of 0.01 -- TODO confirm what the trailing
# number denotes. NOTE(review): absolute, machine-specific path.
check_results('/home/wxy320/ondemand/program/speculative_thinking/results/MATH500_deepseek-32b_deepseek-1.5b_0.01.json')

#: 8 ans: 90^\circ time: 82.44849371910095 num_tokens 1377 num_corr 32
#: 10 ans: 4 time: 38.15027594566345 num_tokens 603 num_corr 25
#: 12 ans: \frac{3}{56} time: 306.63749527931213 num_tokens 4916 num_corr 86
#: 13 ans: 284 time: 34.6164288520813 num_tokens 580 num_corr 12
#: 18 ans: \pi time: 45.182859897613525 num_tokens 758 num_corr 17
#: 19 ans: 28 time: 738.8522665500641 num_tokens 11072 num_corr 248
#: 21 ans: 6+9i time: 216.89724802970886 num_tokens 3648 num_corr 63
#: 22 ans: 13535 time: 68.06164741516113 num_tokens 1164 num_corr 11
#: 27 ans: 144 time: 58.482487201690674 num_tokens 975 num_corr 22
#: 37 ans: 3, 5, 7 time: 215.31054759025574 num_tokens 3702 num_corr 27
#: 44 ans: 70 \sqrt{2} time: 106.77568411827087 num_tokens 1766 num_corr 43
#: 72 ans: 2516_8 time: 155.97350811958313 num_tokens 2644 num_corr 36
#: 73 ans: 3 time: 18.223875284194946 num_tokens 295 num_corr 9
#: 83 ans: \frac{3}{2} time: 30.82742738723755 num_tokens 505 num_corr 11
#: 90 ans: 21 time: 66.578

In [23]:
# Score the run stored at this path; the file name pairs deepseek-32b with
# deepseek-1.5b and a setting of 3 -- TODO confirm what the trailing number
# denotes. NOTE(review): absolute, machine-specific path.
check_results('/home/wxy320/ondemand/program/speculative_thinking/results/MATH500_deepseek-32b_deepseek-1.5b_3.json')

#: 10 ans: 4 time: 40.4138605594635 num_tokens 717 num_corr 17
#: 12 ans: \frac{3}{56} time: 221.57176995277405 num_tokens 4180 num_corr 57
#: 19 ans: 28 time: 19.468947410583496 num_tokens 329 num_corr 14
#: 21 ans: 6+9i time: 213.16167736053467 num_tokens 4049 num_corr 66
#: 23 ans: 5 time: 521.1013174057007 num_tokens 9095 num_corr 275
#: 26 ans: 1,-2 time: 283.22146439552307 num_tokens 5332 num_corr 111
#: 27 ans: 144 time: 536.5382895469666 num_tokens 9221 num_corr 325
#: 33 ans: 720 time: 410.34335255622864 num_tokens 7303 num_corr 214
#: 37 ans: 3, 5, 7 time: 35.28829646110535 num_tokens 717 num_corr 3
#: 43 ans: 4 time: 17.03549551963806 num_tokens 304 num_corr 9
#: 44 ans: 70 \sqrt{2} time: 64.6195478439331 num_tokens 1282 num_corr 10
#: 51 ans: 203 time: 149.51262784004211 num_tokens 2773 num_corr 59
#: 63 ans: 4 time: 20.871702909469604 num_tokens 402 num_corr 5
#: 69 ans: 46 time: 35.28115701675415 num_tokens 694 num_corr 6
#: 72 ans: 2516_8 time: 198.97911024093628 num_tok

In [24]:
# Score the run stored at this path; the file name pairs deepseek-32b with
# deepseek-1.5b and a setting of 5 -- TODO confirm what the trailing number
# denotes. NOTE(review): absolute, machine-specific path.
check_results('/home/wxy320/ondemand/program/speculative_thinking/results/MATH500_deepseek-32b_deepseek-1.5b_5.json')

#: 6 ans: 42 time: 13.990874290466309 num_tokens 261 num_corr 6
#: 10 ans: 4 time: 25.819690942764282 num_tokens 481 num_corr 9
#: 12 ans: \frac{3}{56} time: 72.67899799346924 num_tokens 1406 num_corr 15
#: 19 ans: 28 time: 241.8283622264862 num_tokens 4718 num_corr 50
#: 20 ans: 3 time: 49.12025451660156 num_tokens 1016 num_corr 1
#: 21 ans: 6+9i time: 402.2250757217407 num_tokens 7483 num_corr 68
#: 22 ans: 13535 time: 141.37737917900085 num_tokens 2821 num_corr 30
#: 26 ans: 1,-2 time: 147.72704219818115 num_tokens 2970 num_corr 20
#: 30 ans: 225 time: 58.62046194076538 num_tokens 1098 num_corr 24
#: 32 ans: 11\sqrt2 time: 13.599175453186035 num_tokens 216 num_corr 9
#: 37 ans: 3, 5, 7 time: 208.94304966926575 num_tokens 4267 num_corr 12
#: 39 ans: 2000 time: 296.5413055419922 num_tokens 5863 num_corr 50
#: 44 ans: 70 \sqrt{2} time: 1252.6817829608917 num_tokens 21901 num_corr 283
#: 51 ans: 203 time: 34.25513315200806 num_tokens 689 num_corr 5
#: 63 ans: 4 time: 16.479451179504395 

In [18]:
# Score the run stored at this path; going by the file name this is the
# deepseek-1.5b configuration with "None" as the second model -- TODO confirm
# the naming scheme. NOTE(review): absolute, machine-specific path.
check_results('/home/wxy320/ondemand/program/speculative_thinking/results/MATH500_deepseek-1.5b_None.json')
# NOTE(review): the output saved under this cell has no num_corr column, so
# it appears to come from an earlier version of check_results -- re-run.

#: 9 ans: 3\sqrt{13} time: 20.468008279800415 num_tokens 532
#: 10 ans: 4 time: 130.00941586494446 num_tokens 3503
#: 12 ans: \frac{3}{56} time: 50.399529695510864 num_tokens 1353
#: 18 ans: \pi time: 16.321738481521606 num_tokens 435
#: 21 ans: 6+9i time: 119.26699471473694 num_tokens 3209
#: 22 ans: 13535 time: 13.217105865478516 num_tokens 355
#: 23 ans: 5 time: 302.6427173614502 num_tokens 8156
#: 25 ans: 10 time: 20.13202404975891 num_tokens 544
#: 37 ans: 3, 5, 7 time: 107.20967102050781 num_tokens 2893
#: 42 ans: 17 time: 455.1857998371124 num_tokens 12155
#: 44 ans: 70 \sqrt{2} time: 675.5228221416473 num_tokens 17823
#: 54 ans: -\frac{\pi}{6} time: 18.35799264907837 num_tokens 495
#: 61 ans: (6,31,-1) time: 40.43211650848389 num_tokens 1090
#: 65 ans: \frac{35}{64} time: 45.492159843444824 num_tokens 1224
#: 69 ans: 46 time: 17.069411993026733 num_tokens 460
#: 72 ans: 2516_8 time: 398.9450409412384 num_tokens 10713
#: 81 ans: 501 time: 260.2784128189087 num_tokens 7022
#: 82 