# Test R1 on MMLU 

If R1 reasons about an answer multiple times, does the probability of the answer increases?

In [119]:
from vllm import LLM, SamplingParams
import torch
import numpy as np
import json
from transformers import AutoTokenizer
from utils import *
from wrapper import *
from tqdm import tqdm
seed = 42
from collections import defaultdict
torch.cuda.manual_seed(seed)
np.random.seed(seed)

In [3]:
model_path = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model = LLM(model=model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.tokenizer = tokenizer

INFO 03-07 20:43:55 config.py:542] This model supports multiple tasks: {'classify', 'generate', 'reward', 'embed', 'score'}. Defaulting to 'generate'.
INFO 03-07 20:43:55 config.py:1556] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 03-07 20:43:55 llm_engine.py:234] Initializing a V0 LLM engine (v0.7.2) with config: model='deepseek-ai/DeepSeek-R1-Distill-Llama-8B', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-R1-Distill-Llama-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_mod

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]


INFO 03-07 20:44:02 model_runner.py:1115] Loading model weights took 14.9888 GB
INFO 03-07 20:44:03 worker.py:267] Memory profiling takes 0.54 seconds
INFO 03-07 20:44:03 worker.py:267] the current vLLM instance can use total_gpu_memory (79.33GiB) x gpu_memory_utilization (0.90) = 71.39GiB
INFO 03-07 20:44:03 worker.py:267] model weights take 14.99GiB; non_torch_memory takes 0.14GiB; PyTorch activation peak memory takes 1.21GiB; the rest of the memory reserved for KV Cache is 55.05GiB.
INFO 03-07 20:44:03 executor_base.py:110] # CUDA blocks: 28187, # CPU blocks: 2048
INFO 03-07 20:44:03 executor_base.py:115] Maximum concurrency for 131072 tokens per request: 3.44x
INFO 03-07 20:44:04 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory

Capturing CUDA graph shapes: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:15<00:00,  2.20it/s]

INFO 03-07 20:44:20 model_runner.py:1562] Graph capturing finished in 16 secs, took 0.96 GiB
INFO 03-07 20:44:20 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 17.86 seconds





In [72]:
from datasets import load_dataset
dataset = load_dataset('cais/mmlu', "all",split = 'test')

In [73]:
# random sample 1K ques
dataset = dataset.to_list()
ds = np.random.choice(dataset,2000,replace = False)


In [100]:
gen_kwargs = SamplingParams(temperature = 0.,max_tokens = 4096)
ans_kwargs = SamplingParams(temperature = 0.,max_tokens = 10,logprobs=True)
prompt_template = """Question: {question}\n\nChoices:\n{choice}\n\nPlease reason step by step, before stating the correct choice."""

def format_prompt(question,choices,tokenizer):
    choices = '\n'.join([f'({chr(97+i).upper()}) {c}' for i,c in enumerate(choices)])
    inp_ =prompt_template.format(question = question,choice = choices)
    formatted_prompt = [{'role':'user','content':inp_}]
    return tokenizer.apply_chat_template(formatted_prompt,add_generation_prompt=True,tokenize=False)

In [75]:
saved_results = []
bz = 2
for i in tqdm(range(0,len(ds),bz),total = len(ds)//bz):
    samples = ds[i:i+bz]
    x = [format_prompt(d['question'],d['choices'],model.tokenizer) for d in samples]
    y = [chr(d['answer']+97).lower() for d in samples]
    out = model.generate(x,gen_kwargs,use_tqdm=False)
    cot = [oo.outputs[0].text for oo in out]

    print (cot[0])
    break
    
    # sample answer again
    x_cot = [xx+ f'\n{p}' + '\nFinalizing my decision, the final answer is (' for xx,p in zip(x,cot)]
    
    ans_out = model.generate(x_cot,ans_kwargs,use_tqdm=False)
    ans_pred = [oo.outputs[0].text for oo in ans_out]
    for j,d in enumerate(samples):
        d['answer'] = y[j]
        d['is_correct'] = y[j] == ans_pred[j].lower()
        d['pred'] = ans_pred[j].lower()
        d['cot'] = cot[j]
        saved_results.append(d)


In [25]:
import os
os.makedirs('data',exist_ok = True)
with open('data/mmlu_r1_llama_2k.jsonl', 'w') as f:
    for d in saved_results:
        f.write(json.dumps(d)+'\n')

In [146]:
with open('data/mmlu_r1_llama_2k.jsonl', 'r') as f:
    ds = [json.loads(l) for l in f]

In [147]:
correct_samples = [d for d in ds if d['is_correct']]
print (np.mean([d['cot'].lower().count('wait') for d in correct_samples]))
# print (len(b))

4.330262225372077


In [149]:
pprint (ds[2]['cot'])

# Check GSM8K answer prob after each presence of an answer

In [160]:
from copy import deepcopy
import re
def preprocess_answers(answers):
  responses_by_question = {}
  for idx, example in enumerate(answers):
      answer_str = example["answer"].split("#### ")[1].strip()
      answer_str_tmp = deepcopy(answer_str)
      if ',' in answer_str:
          answer_str_tmp = answer_str_tmp.replace(',', '.')
      # if '.' is contained twice, remove the second one
      if answer_str_tmp.count('.') > 1:
          answer_str_tmp = answer_str_tmp.replace('.', '')
      try:
          answer_float = float(answer_str_tmp)
      except:
          print(f"Failed to convert {answer_str_tmp} to float")
          answer_float = 0.
          continue

      # Extract the response text
      response_text = example["r1_full"]

      # Find all occurrences of numbers matching the answer within <think> tags
      think_start = response_text.find("<think>")
      think_end = response_text.find("</think>")

      if think_start != -1 and think_end != -1:
          think_content = response_text[think_start:think_end]

          # Initialize entry for this question
          responses_by_question[idx] = {
              "answer_str":answer_str,
              "answer_float": answer_float,
              "response": response_text,
              "end_indices": [],
              "think_start": think_start,
            "think_end": think_end,
          }

          # Create partial responses up to each occurrence of the answer
          last_pos = 0
          while True:
              # Find next occurrence of answer after last position
              pos = think_content.find(answer_str, last_pos)
              if pos == -1:
                  break   

              # Find the end of the sentence after this occurrence
              # Look for various sentence endings (. ? ! ;)
              endings = ['.', '?', '!', ';', '\n']
              sentence_end = -1
              for ending in endings:
                  end_pos = think_content.find(ending, pos)
                  if end_pos != -1:
                      if sentence_end == -1 or end_pos < sentence_end:
                          sentence_end = end_pos

              if sentence_end == -1:
                  sentence_end = len(think_content)
              else:
                  sentence_end += 1  # Include the punctuation mark

              # Store the end index
              responses_by_question[idx]["end_indices"].append(think_start + sentence_end)

              last_pos = pos + 1
  responses = list(responses_by_question.values())
  return responses

def check_answer(responses):
    acc = []
    sample_result = {}
    for i,r in enumerate(responses):
        ans = r['answer'].split('####')[-1].strip()
        if ',' in ans:
            ans = ans.replace(',','')
        ans = int(ans)

        res = r['r1_full']
        try:
            ans_ctx = res.split("</think>")[-1]
        except:
            ans_ctx = -1

        if ans_ctx != -1:
            ans_start = ans_ctx.find('boxed{') + 6
            if ans_start != -1:
                ans_end = ans_ctx[ans_start:].find('}')
                pred_ans = ans_ctx[ans_start:ans_start + ans_end].strip()
                pred_ans = re.sub(r"\D","",pred_ans) # remove non-digit
                if ',' in pred_ans:
                    pred_ans = pred_ans.replace(',','')
                if '.' in pred_ans:
                    try:
                        pred_ans = float(pred_ans)
                    except:
                        pass
                else:
                    try:
                        pred_ans = int(pred_ans)
                    except:
                        pass
                if isinstance(pred_ans,str):
                    correct = str(ans) == pred_ans
                else:
                    correct = pred_ans == ans
                acc.append(correct)
                # if not correct:
                #     print (pred_ans,ans)
                sample_result[i] = correct
    print (f'acc: {np.mean(acc)}, num evaluated: {len(acc)}/{len(responses)}')
    return sample_result

In [134]:
def get_probs(out,only_digit=False):
    all_probs = [[] for _ in out]
    out_toks = [[] for _ in out]
    for i,oo in enumerate(out):
        for o in oo.outputs[0].logprobs:
            if not only_digit:
                all_probs[i].append(list(o.values())[0].logprob)
            else: # only take digit and stop at } decoded_token
                output_tok = list(o.values())[0].decoded_token
                out_id = list(o.keys())[0]
                if output_tok == '}':
                    break
                if output_tok.isdigit():
                    all_probs[i].append(list(o.values())[0].logprob)
                    out_toks[i].append(out_id) # get the token id
    all_probs = [np.mean(p) if len(p) > 1 else p[0] for p in all_probs]
    return all_probs,out_toks

In [86]:
from huggingface_hub import hf_hub_download

file_name = hf_hub_download(
    repo_id="wendlerc/GSM8K_solutions_of_DeepSeek-R1-Distill-Llama-8B",
    filename="gsm8k_responses_test.jsonl",
    repo_type="dataset"
)
print(f"File downloaded to: {file_name}")
gsm8k_ds = load_jsonl(file_name)

File downloaded to: /export/home2/weijie210/.cache/huggingface/hub/datasets--wendlerc--GSM8K_solutions_of_DeepSeek-R1-Distill-Llama-8B/snapshots/e1d6c21d1905bbcc37297e762dbb30a4105c78e8/gsm8k_responses_test.jsonl


In [159]:
pprint(gsm8k_ds[0]['r1_full'])

In [161]:
correct_idx = check_answer(gsm8k_ds)
correct_gsm8k_ds = [gsm8k_ds[i] for i,v in correct_idx.items() if v]
print (len(correct_gsm8k_ds),len(gsm8k_ds))

processed_ds = preprocess_answers(correct_gsm8k_ds)
filtered_ds = [d for d in processed_ds if len(d['end_indices']) > 3 and len(d['end_indices']) <= 10] # too many endices could be false.
print (len(filtered_ds),len(processed_ds))

acc: 0.8923426838514026, num evaluated: 1319/1319
1177 1319
446 1177


In [140]:
def split_into_chunk(text,end_indices):
    out = []
    for i,chunk in enumerate(end_indices):
        out.append(text[:chunk])
        # if i < len(split_)-1:
        #     out.append((prev + chunk).strip())
        # else:
        #     final = (prev + chunk).strip()
        #     if 'Final Answer' in final:
        #         final = '\n\n'.join(final.split('Final Answer')[0].strip().split('\n\n')[:-1]).strip()
        #     elif '</think>' in final:
        #         final = final.split('</think>')[0].strip()
        #     out.append(final)
        # prev += (chunk + 'Wait')
    return out

def norm(x):
    return (x - np.min(x))/(np.max(x) - np.min(x))

bz = 32

sample_lp = defaultdict(list) # separate by chunk len
sample_ac = defaultdict(list) # acc of answer

for t in tqdm(range(0,len(filtered_ds),bz),total = len(filtered_ds)//bz):
    responses = [d['response'] for d in filtered_ds[t:t+bz]]
    end_indices = [d['end_indices'] for d in filtered_ds[t:t+bz]]
    answers = [d['answer_str'] for d in filtered_ds[t:t+bz]]
    ques_start = [d['think_start'] for d in filtered_ds[t:t+bz]]
    chunk_samples = [split_into_chunk(r,e) for r,e in zip(responses,end_indices)]
    ques_only = [r[:s] + '\nI figured it out' for r,s in zip(responses,ques_start)]
    for i in range(len(chunk_samples)):
        chunk_samples[i] = [ques_only[i]] + [c + '\n</think>\nThe final answer is \\boxed{' for c in chunk_samples[i]] # early exit + no thinking at the start
    chunk_lens = [len(c) for c in chunk_samples] # to split between samples
    chunk_samples = sum(chunk_samples,[])
    out = model.generate(chunk_samples,ans_kwargs,use_tqdm=False)
    for i,cl in enumerate(chunk_lens): # over each sample split into chunks
        curr_sample = out[:cl]
        ans_probs,ans_toks = get_probs(curr_sample,only_digit=True)
        ans_toks = [model.tokenizer.decode(tok) for tok in ans_toks]
        ans_correct = [str(yhat) == str(answers[i]) for yhat in ans_toks]
        norm_ans_probs = norm(ans_probs) # normalize across chunks
        out = out[cl:]
        sample_lp[cl-1].append(norm_ans_probs) # -1 because 1st is without question
        sample_ac[cl-1].append(np.array(ans_correct).astype(float))

    
    

14it [02:01,  8.69s/it]                                                                                                                                                                                                                                                                   


In [129]:
# plot utils
import pandas as pd
import plotly.express as px
def line(tensor, labels=None, x_label="x-axis", y_label="y-axis", **kwargs):
    if tensor.ndim == 2:
        tensor = tensor.T  # Ensure rows correspond to different series, columns to x-axis
        df = pd.DataFrame(tensor)

        if labels is not None:
            df.columns = labels  # Assign meaningful column names if labels are provided

        df["index"] = df.index  # Add index for x-axis
        df = df.melt(id_vars="index", var_name="line_id", value_name="y")  # Convert to long format

        fig = px.line(df, x="index", y="y", color="line_id", **kwargs)
        fig.update_layout(
            xaxis_title=x_label,  # Set X-axis label
            yaxis_title=y_label   # Set Y-axis label
        )
        fig.show()
    else:
        fig = px.line(
            y=tensor,
            **kwargs
        )
        fig.update_layout(
            xaxis_title=x_label,  # Set X-axis label
            yaxis_title=y_label   # Set Y-axis label
        )
        fig.show()

In [145]:
mean_sample_lp = {k:np.stack(v).mean(0) for k,v in sample_lp.items()}
mean_sample_acc = {k:np.stack(v).mean(0) for k,v in sample_ac.items()}
for k,v in sorted(mean_sample_lp.items(),key = lambda x:x[0]):
    line(np.stack([v,mean_sample_acc[k]]),x_label = 'Answer occurences',title = f'Num ans occurences: {k}',labels = ['logprob','acc'])