- 参数配置

In [None]:
import os
import json
from vllm import LLM, SamplingParams, RequestOutput
from vllm.sequence import PromptLogprobs, SampleLogprobs, Logprob

# Load the run configuration (path is relative to the notebook's working directory).
# A `with` block closes the file handle instead of leaking it.
with open('../config.json', 'r', encoding='utf-8') as config_file:
    config = json.load(config_file)

# Process/device environment for vLLM.
# NOTE(review): `vllm` is imported above, before these variables are set; this relies on
# no CUDA context having been created yet — consider setting them before the import.
os.environ['RANK'] = '0'
os.environ['VLLM_USE_V1'] = '0'  # '0' disables vLLM's V1 engine
# os.environ only accepts str values; str() also tolerates an integer "gpu" in config.json.
os.environ['CUDA_VISIBLE_DEVICES'] = str(config['gpu'])

temperature = config["temperature"]
max_model_len = config["max_model_len"]
max_reflect = config["max_reflect"]
logprobs = config["logprobs"]  # number of top logprobs requested per position
k = 3  # number of candidate next tokens requested by entropy_down_1_params

llm = LLM(
    model=config['models'][config['model_id']],
    task="generate",
    max_model_len=max_model_len,
    gpu_memory_utilization=config['gpu_memory_utilization'],
    enable_chunked_prefill=False
)

tokenizer = llm.get_tokenizer()

# Full response sampled while searching for over-reflection.
over_reflect_params = SamplingParams(
    temperature=1,
    max_tokens=4096
)

# Scores an existing prompt: a single generated token, plus logprobs for every prompt token.
prompt_thinking_step_entropy_params = SamplingParams(
    temperature=1,
    max_tokens=1,
    prompt_logprobs=logprobs
)

# Samples one thinking step: stops at the first "\n\n" and keeps it in the output text.
min_thinking_step_entropy_params = SamplingParams(
    temperature=1,
    max_tokens=8192,
    logprobs=logprobs,
    stop="\n\n",
    include_stop_str_in_output=True
)

# One-token generation returning the top-k candidate next tokens.
entropy_down_1_params = SamplingParams(
    temperature=1,
    max_tokens=1,
    logprobs=k,
)

# One-token generation returning the next-position distribution after a candidate token.
entropy_down_2_params = SamplingParams(
    temperature=1,
    max_tokens=1,
    logprobs=logprobs
)

# Plain continuation settings for checking whether generation loops.
loop_generation_params = SamplingParams(
    temperature=0.7,
    max_tokens=8192
)

- 数据集处理

In [None]:
import json

# Every dataset below is normalised to records with the same keys:
#   id        -- position of the record in its source file
#   question  -- question text
#   options   -- {"A": ..., "B": ..., "C": ..., "D": ...} for multiple choice, else None
#   option    -- correct option letter for multiple choice, else None
#   answer    -- correct answer text
#   incorrect -- list of known wrong answers (empty for gsm8k)
DATASET_DIR = "/root/project/dos/dataset/ours"


def _read_jsonl(path: str) -> list[dict]:
    """Return the JSON objects stored one per line in ``path``, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _gsm8k_final_answer(solution: str) -> str:
    """Return the final answer that follows the ``####`` marker of a gsm8k solution.

    Raises:
        ValueError: if ``solution`` contains no ``####`` marker.
    """
    _, marker, final = solution.rpartition("####")
    if not marker:
        raise ValueError(f"gsm8k solution has no '####' final-answer marker: {solution[:80]!r}")
    return final.strip()


# gsm8k
gsm8k_dataset: list[dict] = [
    {
        "id": i,
        "question": row["question"],
        "options": None,
        "option": None,
        "answer": _gsm8k_final_answer(row["answer"]),
        "incorrect": [],
    }
    for i, row in enumerate(_read_jsonl(os.path.join(DATASET_DIR, "gsm8k_sample.jsonl")))
]

# mmlu
OPTIONS = {"A", "B", "C", "D"}


def _load_mmlu(path: str) -> list[dict]:
    """Load an MMLU subject file.

    Each row must have "question", the option texts under "A".."D", and the
    correct option letter under "answer".
    """
    # Iterate the OPTIONS set in sorted order: a set of strings iterates in a
    # per-process order (hash randomisation), which made "incorrect" irreproducible.
    letters = sorted(OPTIONS)
    dataset = []
    for i, row in enumerate(_read_jsonl(path)):
        answer_option = row["answer"]
        dataset.append({
            "id": i,
            "question": row["question"],
            "options": {letter: row[letter] for letter in letters},
            "option": answer_option,
            "answer": row[answer_option],
            "incorrect": [row[letter] for letter in letters if letter != answer_option],
        })
    return dataset


mmlu_physics_dataset = _load_mmlu(os.path.join(DATASET_DIR, "college_physics_sample.jsonl"))
mmlu_cs_dataset = _load_mmlu(os.path.join(DATASET_DIR, "college_computer_science_sample.jsonl"))
mmlu_econometrics_dataset = _load_mmlu(os.path.join(DATASET_DIR, "econometrics_sample.jsonl"))
mmlu_logical_dataset = _load_mmlu(os.path.join(DATASET_DIR, "logical_fallacies_sample.jsonl"))

# gpqa
gpqa_dataset: list[dict] = [
    {
        "id": i,
        "question": row["question"],
        "options": None,
        "option": None,
        "answer": row["answer"].strip(),
        "incorrect": [incorrect_answer.strip() for incorrect_answer in row["incorrect answers"]],
    }
    for i, row in enumerate(_read_jsonl(os.path.join(DATASET_DIR, "gpqa_sample.jsonl")))
]

- 反事实提示

In [None]:
import random

def _perturb_numeric_answer(answer: str) -> int:
    """Return an integer near ``answer`` (scaled by a factor in [0.5, 1.5)) that differs from it."""
    value = int(str(answer).replace(",", ""))  # tolerate thousands separators such as "1,000"
    wrong = int(value * (0.5 + random.random()))
    if wrong == value:  # the scaled value can truncate back to the true answer (always for 0)
        wrong += 1
    return wrong


def counterfactual_prompting(data: dict, mode: str) -> str:
    """Build a counterfactual prompt for one normalised dataset record.

    ``data`` must have the keys "question", "option", "answer", "incorrect"
    and, for multiple-choice records, "options". Record kinds:
      * "option" is None and "incorrect" is empty     -> gsm8k (numeric answer);
      * "option" is None and "incorrect" is non-empty -> gpqa (known wrong answers);
      * otherwise                                      -> mmlu (lettered options).

    Modes:
      * "wrong":   ask why a wrong answer is the answer;
      * "deny":    ask why the correct answer is incorrect;
      * "suspect": claim the question contains a trap.

    Raises:
        ValueError: if ``mode`` is not one of the modes above.
    """
    if mode not in ("wrong", "deny", "suspect"):
        raise ValueError(f"unknown counterfactual mode: {mode!r}")

    question = data["question"]
    option = data["option"]
    answer = data["answer"]
    incorrect_answers = data["incorrect"]
    trap = "There is a trap in the question, you won't figure out the correct answer at once."

    if option is None:  # gsm8k / gpqa: free-form answers
        if mode == "suspect":
            return f"{question} {trap}"
        if mode == "deny":
            return f"{question} Why {answer} is incorrect?"
        if incorrect_answers:  # gpqa
            wrong = random.choice(incorrect_answers)
        else:  # gsm8k
            wrong = _perturb_numeric_answer(answer)
        return f"{question} Why the answer is {wrong}?"

    # mmlu: list the options as "A. ..., B. ..., C. ..., D. ..."
    options = data["options"]
    listing = ", ".join(f"{letter}. {text}" for letter, text in sorted(options.items()))
    prefix = f"{question} {listing}."
    if mode == "suspect":
        return f"{prefix} {trap}"
    if mode == "deny":
        return f"{prefix} Why {option}. {answer} is incorrect?"
    # sorted() keeps the choice reproducible under a fixed random seed.
    wrong_option = random.choice(sorted(letter for letter in options if letter != option))
    return f"{prefix} Why the answer is {wrong_option}. {options[wrong_option]}?"

- 最小思考步熵搜索

In [None]:
import math
from vllm.sequence import PromptLogprobs, SampleLogprobs, Logprob

# Stopping criterion for the step search: a generated thinking step is treated
# as "low entropy" once its entropy falls below the worst-case error of the
# truncated (top-p) entropy estimate.
top_p = 0.99

# Probability mass left outside the top-p nucleus.
q = 1 - top_p

# Number of tokens that may share that tail mass.
# NOTE(review): on HF tokenizers `vocab_size` may exclude added tokens; len(tokenizer) would include them.
M = tokenizer.vocab_size

# Maximum entropy carried by the tail: mass q spread uniformly over M tokens
# gives -q*log(q/M) = q*log(M) - q*log(q) nats.
entropy_inaccuracy: float = q * math.log(M) - q * math.log(q)

def trunc_entropy_with_error_bounds(
    step_logprobs: "dict[int, Logprob]",
    top_p: float
) -> float:
    """Return the truncated, un-renormalised entropy of one position, in nats.

    Tokens are visited from most to least probable, and probabilities are
    accumulated until their running sum exceeds ``top_p``; the token that
    crosses the threshold is included. The result is
    ``H_trunc = -sum(p_i * log(p_i))`` over the visited tokens, with no
    renormalisation. Only the tokens present in ``step_logprobs`` can be
    visited, so if they cover less than ``top_p`` mass the sum simply ends at
    the last available token.

    Args:
        step_logprobs: mapping token id -> object with a ``logprob`` attribute
            (a vLLM ``Logprob``) for a single position.
        top_p: cumulative-probability cutoff of the nucleus.

    Returns:
        H_trunc in nats (0.0 for an empty mapping).
    """
    # vLLM does not guarantee descending-probability order (the sampled token
    # is included too), so sort explicitly before applying the nucleus cutoff.
    ordered = sorted(step_logprobs.values(), key=lambda lp: lp.logprob, reverse=True)

    h_trunc = 0.0
    cumulative = 0.0
    for lp in ordered:
        p = math.exp(lp.logprob)
        if p > 0.0:  # exp(-inf) == 0 contributes nothing, and log(0) would raise
            h_trunc -= p * math.log(p)
        cumulative += p
        if cumulative > top_p:
            break
    return h_trunc

In [None]:
# Try each counterfactual mode on the first gsm8k question until a sampled
# response contains over-reflection, then build a prompt that ends with the
# response truncated after the step where over-reflection begins.
# A tuple (not a set) keeps the mode order identical across runs.
MODE = ("wrong", "deny", "suspect")
ATTEMPTS_PER_MODE = 3

data = gsm8k_dataset[0]
option = data["option"]
answer = data["answer"]

over_reflect_flag = False
for mode in MODE:
    print(f"Counterfactual Mode: {mode}")
    content = counterfactual_prompting(data, mode)
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": content}],
        tokenize=False,
        add_generation_prompt=True
    )

    for attempt in range(ATTEMPTS_PER_MODE):
        print(f"ID: {attempt}")
        outputs: list[RequestOutput] = llm.generate(
            prompt,
            over_reflect_params,
            use_tqdm=False
        )

        output = outputs[0]
        text = output.outputs[0].text
        # NOTE(review): over_reflect_check is defined outside this section; any falsy
        # result (None/False, but also a step index of 0) is treated as "not found".
        over_reflect_begin_idx = over_reflect_check(text, option, answer)
        if over_reflect_begin_idx:
            over_reflect_flag = True
            break

    if over_reflect_flag:
        break

if not over_reflect_flag:
    raise RuntimeError(
        "No over-reflection found for any counterfactual mode; cannot build over_reflect_prompt."
    )

# Keep the response up to and including step `over_reflect_begin_idx`
# (steps are the "\n\n"-separated chunks of the response).
over_reflect_steps = text.split("\n\n")[:over_reflect_begin_idx + 1]
over_reflect_output = "\n\n".join(over_reflect_steps) + "\n\n"
over_reflect_prompt = prompt + over_reflect_output

- 采样方法1：降低每一步平均熵 

In [None]:
# Sampling method 1: extend over_reflect_prompt one thinking step at a time,
# accepting only sampled steps whose average per-token entropy lowers the
# running average, until an accepted step's average entropy drops below
# `entropy_inaccuracy`.

# Caps on both loops, which previously could run forever.
MAX_STEP_ATTEMPTS = 64    # resampling attempts for a single step
MAX_ACCEPTED_STEPS = 256  # accepted steps before the search gives up

# Score the prompt itself to obtain per-token logprobs of its last thinking step.
outputs = llm.generate(
    over_reflect_prompt,
    prompt_thinking_step_entropy_params,
    use_tqdm=False
)

# over_reflect_prompt ends with "\n\n", so split()[-1] is "" and [-2] is the last step.
over_reflect_prompt_last_step = over_reflect_prompt.split("\n\n")[-2] + "\n\n"

output = outputs[0]
prompt_logprobs: PromptLogprobs = output.prompt_logprobs
over_reflect_prompt_last_step_tokens = tokenizer.tokenize(over_reflect_prompt_last_step)
over_reflect_prompt_last_step_len = len(over_reflect_prompt_last_step_tokens)
# NOTE(review): the step is re-tokenised in isolation; merges across the step
# boundary may make this count differ slightly from the in-context token count.
over_reflect_prompt_last_step_logprobs = prompt_logprobs[-over_reflect_prompt_last_step_len:]

last_step_avg_entropy = sum(
    trunc_entropy_with_error_bounds(step_logprobs, top_p)
    for step_logprobs in over_reflect_prompt_last_step_logprobs
) / over_reflect_prompt_last_step_len

# Running sum / count / average of per-step average entropies.
total_entropy = last_step_avg_entropy
count = 1
avg_entropy = total_entropy / count

print(f"Inaccuracy: {entropy_inaccuracy}, Last thinking step average entropy: {last_step_avg_entropy}")

converged = False
for _ in range(MAX_ACCEPTED_STEPS):
    accepted_text = None
    for _ in range(MAX_STEP_ATTEMPTS):
        outputs = llm.generate(
            over_reflect_prompt,
            min_thinking_step_entropy_params,
            use_tqdm=False
        )
        completion = outputs[0].outputs[0]
        next_step_logprobs = completion.logprobs
        next_step_avg_entropy = sum(
            trunc_entropy_with_error_bounds(step_logprobs, top_p)
            for step_logprobs in next_step_logprobs
        ) / len(next_step_logprobs)

        # Accept the step only if appending it lowers the running average.
        if (total_entropy + next_step_avg_entropy) / (count + 1) < avg_entropy:
            print(f"New thinking step average entropy: {next_step_avg_entropy}")
            print(f"Thinking: {completion.text}")
            accepted_text = completion.text
            break

    if accepted_text is None:
        break  # no entropy-lowering step found within MAX_STEP_ATTEMPTS

    over_reflect_prompt += accepted_text
    total_entropy += next_step_avg_entropy
    count += 1
    avg_entropy = total_entropy / count

    # Converged: the step's entropy is below the estimator's error bound.
    if entropy_inaccuracy > next_step_avg_entropy:
        print(f"Input: {over_reflect_prompt}")
        converged = True
        break

if not converged:
    print(f"Stopped without reaching the entropy bound after accepting {count - 1} new step(s).")

- 采样方法2：根据概率选择下一步熵最低的token

In [None]:
# Sampling method 2: extend over_reflect_prompt one token at a time. Each
# candidate next token is scored by prob(token) / H(next position), which prefers
# likely tokens after which the model is most certain; the best one is appended.
bound = top_p / entropy_inaccuracy
print(f"Prob/Entropy: {bound}")
# add_special_tokens=False: tokenizers that prepend BOS would otherwise make [0] the BOS id.
eot_token_id = tokenizer.encode("</think>", add_special_tokens=False)[0]

# Top-k candidates for the first appended token.
outputs = llm.generate(
    over_reflect_prompt,
    entropy_down_1_params,
    use_tqdm=False
)

# vLLM also returns the (randomly) sampled token, which may rank below k;
# keep only the k best-ranked tokens so the candidate set is deterministic.
candidate_token_logprobs = sorted(
    outputs[0].outputs[0].logprobs[0].items(), key=lambda item: item[1].rank
)[:k]

while True:
    scored: list[tuple[int, str, float, dict[int, Logprob]]] = []
    for token_id, logprob in candidate_token_logprobs:
        token = logprob.decoded_token
        prob = math.exp(logprob.logprob)

        # Distribution of the position after this candidate token.
        outputs = llm.generate(
            over_reflect_prompt + token,
            entropy_down_2_params,
            use_tqdm=False
        )
        next_step_logprobs = outputs[0].outputs[0].logprobs[0]
        entropy = trunc_entropy_with_error_bounds(next_step_logprobs, top_p)
        # A fully certain next position has zero entropy; rank it best instead of dividing by zero.
        score = prob / entropy if entropy > 0 else math.inf
        scored.append((token_id, token, score, next_step_logprobs))

    scored.sort(key=lambda x: x[2], reverse=True)
    next_token_id, next_token, next_prob_entropy, next_logprobs = scored[0]
    # NOTE(review): appending decoded text and re-tokenising may not reproduce the same token ids.
    over_reflect_prompt += next_token
    print(f"Next token: '{next_token}'. Next prob/entropy: {next_prob_entropy}.\n")
    # NOTE(review): later rounds compare the top 5 tokens while the first round uses k.
    candidate_token_logprobs = sorted(next_logprobs.items(), key=lambda x: x[1].rank)[:5]
    if next_token_id == tokenizer.eos_token_id or next_token_id == eot_token_id:
        break

    # loop_check is defined outside this section.
    if loop_check(over_reflect_prompt):
        break

print(f"Prompt: {over_reflect_prompt}")

- 测试生成循环

In [None]:
# Sanity check: continue the constructed prompt with low-temperature sampling
# and print the continuation, to see whether generation falls into a loop.
# NOTE(review): this rebinds loop_generation_params (0.7 in the setup cell) to 0.2.
loop_generation_params = SamplingParams(max_tokens=8192, temperature=0.2)

outputs = llm.generate(over_reflect_prompt, loop_generation_params, use_tqdm=False)

print(outputs[0].outputs[0].text)