In [None]:
#############   函数    ###############

# GSM8K: counterfactual rewrite
def counterfactual_gsm8k(data: dict, ratios: tuple[float, ...] = (0.5, 1.5, 2)) -> str:
    """Rewrite a GSM8K record into a prompt that asks why a wrong answer is correct.

    Args:
        data: GSM8K record with a ``"question"`` string and an ``"answer"`` string whose
            final numeric answer follows the ``"#### "`` marker.
        ratios: multipliers applied to the correct answer to build candidate wrong answers.

    Returns:
        ``"{question} Why the correct answer is: {wrong}?"``, where ``wrong`` is a plain
        integer that is guaranteed to differ from the correct answer.

    Raises:
        ValueError: if the text after ``"#### "`` is not an integer.
    """
    import random  # local import: `random` is otherwise only imported in a later cell

    question = data["question"]
    # Final answers may carry thousands separators ("1,000"), which int() rejects.
    raw_answer = data["answer"].split('#### ')[-1].strip().replace(",", "")
    correct_answer = int(raw_answer)
    # int() truncation can map a candidate back onto the correct answer
    # (e.g. int(1 * 1.5) == 1, or 0 * r == 0), so exclude the correct answer explicitly.
    candidates = sorted({int(correct_answer * ratio) for ratio in ratios} - {correct_answer})
    if not candidates:
        candidates = [correct_answer + 1]
    # Pick a single value (not a one-element list) so the prompt reads "...: 6?" rather than "...: [6]?".
    incorrect_answer = random.choice(candidates)
    return f"{question} Why the correct answer is: {incorrect_answer}?"

# MMLU: counterfactual rewrite
def counterfactual_mmlu(data: dict) -> str:
    """Rewrite an MMLU record into a multiple-choice prompt that asks why a wrong option is correct.

    Args:
        data: MMLU record with ``"question"``, option texts under ``"A"``..``"D"``, and
            ``"answer"`` holding the correct option letter.

    Returns:
        The question and its four options, followed by
        ``"Why the correct answer is: X. <option text>?"``, where X is a randomly chosen
        incorrect option letter.

    Raises:
        ValueError: if ``data["answer"]`` is not one of "A", "B", "C", "D".
    """
    import random  # local import: `random` is otherwise only imported in a later cell

    letters = ['A', 'B', 'C', 'D']
    question = data["question"]
    answer = data["answer"]
    if answer not in letters:
        raise ValueError(f"expected answer in {letters}, got {answer!r}")
    incorrect_choice = random.choice([letter for letter in letters if letter != answer])
    # Single quotes inside the f-strings: reusing the outer double quote is a
    # SyntaxError before Python 3.12 (PEP 701).
    return (
        f"There is a multiple-choice question: {question}. "
        f"The options are: A. {data['A']}, B. {data['B']}, C. {data['C']}, D. {data['D']}. "
        f"Why the correct answer is: {incorrect_choice}. {data[incorrect_choice]}?"
    )

# GPQA: counterfactual rewrite
def counterfactual_gpqa(data: dict) -> str:
    """Rewrite a GPQA record into a prompt that asks why one of its incorrect answers is correct.

    Args:
        data: record with a ``"question"`` string and a non-empty ``"incorrect answers"`` sequence.

    Returns:
        ``"{question} Why the correct answer is: {incorrect_answer}?"``.

    Raises:
        ValueError: if ``"incorrect answers"`` is empty.
    """
    import random  # local import: `random` is otherwise only imported in a later cell

    question = data["question"]
    incorrect_answers = data["incorrect answers"]
    if not incorrect_answers:
        raise ValueError("GPQA record has no incorrect answers to choose from")
    incorrect_answer = random.choice(incorrect_answers)
    return f"{question} Why the correct answer is: {incorrect_answer}?"


def query(input: str, temperature: float, n: int, max_len: int) -> list[str] | str:
    """Sample completions from the target model for a single user message.

    The message is rendered with the target tokenizer's chat template (generation prompt
    appended, not tokenized) and passed to ``target_model.generate``.

    Args:
        input: user message content.
        temperature: sampling temperature.
        n: number of completions to sample.
        max_len: passed as ``max_tokens`` to ``SamplingParams``.

    Returns:
        The completion text when ``n == 1``; otherwise the list of all completion texts.
    """
    params = SamplingParams(temperature=temperature, max_tokens=max_len, n=n)

    chat = [{'role': 'user', 'content': input}]
    rendered = target_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    # One prompt in, so only the first RequestOutput is relevant.
    request = target_model.generate(rendered, params)[0]
    if n != 1:
        return [completion.text for completion in request.outputs]
    return request.outputs[0].text


def combine_suffix(q: str, n: str, s: str) -> str:
    """Build a raw prompt: chat-templated user turn followed by pre-filled assistant text.

    Args:
        q: user message content.
        n: text appended directly after the generation prompt (e.g. the model's own
           partial reasoning).
        s: suffix appended after ``n``.

    Returns:
        The rendered template for ``q``, then ``n``, then ``s``, as one string.
    """
    templated = target_tokenizer.apply_chat_template(
        [{'role': 'user', 'content': q}], tokenize=False, add_generation_prompt=True
    )
    return "".join((templated, n, s))

In [None]:
import json

#############  Datasets  ##############
# Build one counterfactual prompt per record in every sample file.

dataset = {
    "gsm8k": "/root/project/dos/dataset/ours/gsm8k_sample.jsonl",
    "mmlu": [
        "/root/project/dos/dataset/ours/college_computer_science_sample.jsonl",
        "/root/project/dos/dataset/ours/college_physics_sample.jsonl",
        "/root/project/dos/dataset/ours/econometrics_sample.jsonl",
        "/root/project/dos/dataset/ours/logical_fallacies_sample.jsonl"
    ],
    "GPQA": "/root/project/dos/dataset/ours/gpqa_sample.jsonl"
}


def _load_jsonl(path: str) -> list[dict]:
    """Parse a JSON Lines file, skipping blank lines (json.loads raises on them)."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# (files, rewriter) pairs, processed in the order GSM8K -> MMLU subjects -> GPQA.
_counterfactual_sources = [
    ([dataset["gsm8k"]], counterfactual_gsm8k),
    (dataset["mmlu"], counterfactual_mmlu),
    ([dataset["GPQA"]], counterfactual_gpqa),
]

counterfactuals = []
for paths, rewrite in _counterfactual_sources:
    for path in paths:
        counterfactuals.extend(rewrite(record) for record in _load_jsonl(path))

# Preview rather than dumping every prompt into the notebook output.
len(counterfactuals), counterfactuals[:3]

In [None]:
import asyncio
from googletrans import Translator

def en_to_zh(text: str) -> str:
    """Translate English text to Simplified Chinese with googletrans.

    Works whether ``Translator.translate`` returns a coroutine (async googletrans
    releases) or the result directly (older synchronous releases). It also works
    whether or not an event loop is already running in the calling thread, which is
    the case inside a Jupyter kernel.

    Args:
        text: English source text.

    Returns:
        The translated text. Empty or whitespace-only input is returned unchanged,
        without a network call.
    """
    import concurrent.futures
    import inspect

    if not text or not text.strip():
        return text

    async def _translate() -> str:
        result = Translator().translate(text, src='en', dest='zh-cn')
        if inspect.isawaitable(result):
            result = await result
        return result.text

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread (plain script): drive the coroutine directly.
        return asyncio.run(_translate())

    # A loop is already running (Jupyter executes cells inside one), so
    # loop.run_until_complete() would raise "This event loop is already running".
    # Run the coroutine on a fresh loop in a worker thread instead.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, _translate()).result()


print(en_to_zh("Hello, how are you today?"))

In [None]:
# Imitate overthinking: derive variants of one reasoning line by swapping each
# vocabulary word it contains for every other word from the same vocabulary.
# NOTE(review): `lines`, `end_idx`, `noun`, `verb`, `short`, `counterfactual`,
# `normal_thinking` and `config` are not defined in any cell shown here, and `head`
# is only assigned in later cells. TODO: confirm their source before Restart & Run All.
deriative_overthinking = []  # (sic) name kept in case later cells reference it
for line in lines[end_idx+2:end_idx+3]:  # a single line: the one two positions after end_idx
    # Same order of variants as before: head swaps, then noun, verb and short swaps.
    for vocabulary in (head, noun, verb, short):
        for word in vocabulary:
            if word not in line:
                continue
            deriative_overthinking.extend(
                line.replace(word, other) for other in vocabulary if other != word
            )

suffix = "\n\n".join(deriative_overthinking) + "\n\n"
input_str = combine_suffix(counterfactual, normal_thinking, suffix)
print(input_str + "\n\n")

outputs_1 = target_model.generate(
    input_str,
    SamplingParams(
        temperature=config["temperature"],
        max_tokens=config["max_model_len"],
        n=config["n"]
    )
)
# Collect completion texts that never closed their reasoning block, i.e. those that ran
# into the length limit. Check the generated text of each completion in `outputs_1`:
# a RequestOutput object itself is not a string container.
overthink_outputs_1 = [
    completion.text
    for request in outputs_1
    for completion in request.outputs
    if "</think>" not in completion.text
]

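- 字典构造 (dictionary construction: load the `head` / `error` / `solution` phrase lists from `words_new.json` and sort them longest-first for regex matching)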

In [None]:
import json

with open("/root/project/dos/words_new.json", "r") as f:
    lib = json.load(f)

# Placeholders written into overthinking templates; filled in later via str.format.
PH1, PH2, PH3 = '{head}', '{error}', '{solution}'

head, error, solution = lib['head'], lib['error'], lib['solution']

# Longest phrases first, so that in a regex alternation "New York" is tried before its
# prefix "New". The sort is stable, so equal-length phrases keep their file order.
sorted_head, sorted_error, sorted_solution = (
    sorted(words, key=lambda w: -len(w)) for words in (head, error, solution)
)

In [None]:
import re

# Regexes that find vocabulary phrases in a reasoning step so they can be replaced by
# str.format placeholders. re.escape keeps phrases from acting as regex metacharacters.
# Steps are brace-escaped ("{" -> "{{") before matching, so phrases are escaped the same way.


def _escape_braces(s: str) -> str:
    """Double literal braces so the text survives a later str.format() call unchanged."""
    return s.replace("{", "{{").replace("}", "}}")


def _alternation(words) -> str:
    """Regex alternation of brace-escaped, regex-escaped phrases, in the given order."""
    return "|".join(re.escape(_escape_braces(w)) for w in words)


# An empty alternation matches the empty string at every position, which would splice a
# placeholder between every character. Use a regex that can never match instead.
_NEVER_MATCH = re.compile(r"(?!)")

pattern_head = r"^(?:" + _alternation(sorted_head) + r")"  # head phrases only count at the start of a step
pattern_error = _alternation(sorted_error)
pattern_solution = _alternation(sorted_solution)

head_re = re.compile(pattern_head) if sorted_head else _NEVER_MATCH
error_re = re.compile(pattern_error) if sorted_error else _NEVER_MATCH
solution_re = re.compile(pattern_solution) if sorted_solution else _NEVER_MATCH


def replace_with_placeholders(text: str) -> str:
    """Turn one reasoning step into a str.format template.

    First, literal braces in ``text`` (common in LaTeX such as ``\\frac{1}{2}``) are
    doubled, so formatting the template later reproduces them instead of raising
    KeyError/IndexError. Then:
    - a leading head phrase becomes PH1;
    - every error phrase becomes PH2;
    - every solution phrase becomes PH3.
    The substitutions run in that order, so a span matched by an earlier list is no
    longer visible to the later ones.
    """
    text = _escape_braces(text)
    text = head_re.sub(PH1, text)
    text = error_re.sub(PH2, text)
    text = solution_re.sub(PH3, text)
    return text


def imitate_overthink(overthink_steps: list[str]) -> list[str]:
    """Convert each reasoning step into a placeholder template via replace_with_placeholders."""
    return [replace_with_placeholders(step) for step in overthink_steps]


# NOTE(review): `overthink_steps` is not defined in any cell shown here. TODO: confirm its source.
temp = imitate_overthink(overthink_steps)

In [None]:
import re

# Extract head / error / solution phrase vocabularies from saved overthinking traces.
# NOTE: this rebinds `head`, `error` and `solution` (lists loaded from words_new.json in an
# earlier cell) to sets; the sampling cell that follows draws from these sets.
head = set()
error = set()
solution = set()

NUM_TRACE_FILES = 1  # trace files are named /root/project/dos/data/{i}_overthink.txt
TRACE_SEPARATOR = f"\n{'-' * 50}\n"  # a line of 50 dashes separates sections of a trace file

# Compiled once rather than on every step.
_HEAD_RE = re.compile(r"^([A-Za-z]+),")
_ERROR_RE = re.compile(r"\b(I must|maybe I)[^.,]*", re.IGNORECASE)
_SOLUTION_RE = re.compile(r"\b(Let me)[^.,;:!?()\[\]\"'…]*", re.IGNORECASE)


def collect_step_patterns(text: str, heads: set, errors: set, solutions: set) -> None:
    """Add the phrases found in one reasoning step to the given sets (mutated in place).

    - heads: a single leading word followed by a comma ("Wait, ..." -> "Wait").
    - errors: the first "I must ..." / "maybe I ..." phrase (case-insensitive), up to
      the next '.' or ','.
    - solutions: when the step starts with "Let me" (case-insensitive), the phrase up to
      the first punctuation mark, stripped.
    """
    match = _HEAD_RE.match(text)
    if match:
        heads.add(match.group(1))
    match = _ERROR_RE.search(text)
    if match:
        errors.add(match.group(0))
    match = _SOLUTION_RE.match(text)
    if match:
        solutions.add(match.group(0).strip())


for i in range(NUM_TRACE_FILES):
    try:
        with open(f"/root/project/dos/data/{i}_overthink.txt", "r", encoding="utf-8") as f:
            content = f.read()
    except OSError:
        continue  # best effort: skip trace files that are missing or unreadable

    # Read the file once: a second f.read() starts at EOF and always returns "".
    overthinkings = content.split(TRACE_SEPARATOR)
    # Paragraph-level steps of the first section (used by later cells).
    thinking_steps = overthinkings[0].split("\n\n")

    for thinking in overthinkings:
        for text in thinking.split("\n\n"):
            collect_step_patterns(text, head, error, solution)

print(head)
print(error)
print(solution)

In [None]:
import random

def overthinking_gen(overthink_template: list[str], rng: random.Random | None = None) -> list[str]:
    """Fill placeholder templates with randomly sampled phrases.

    For each step, one phrase each is drawn from the global ``head``, ``error`` and
    ``solution`` collections and substituted via ``str.format``. A step that contains the
    same placeholder twice gets the same phrase twice.

    Args:
        overthink_template: templates containing ``{head}``, ``{error}`` and/or ``{solution}``.
        rng: optional ``random.Random`` for reproducible sampling. Defaults to the
            module-level ``random`` functions.

    Returns:
        One filled-in string per template, in order.

    Raises:
        IndexError: if any of ``head`` / ``error`` / ``solution`` is empty.
    """
    rng = random if rng is None else rng
    # Sort once per call rather than once per step. The vocabularies may be sets, and
    # sorting gives a seeded RNG a deterministic order to pick from.
    heads, errors, solutions = sorted(head), sorted(error), sorted(solution)
    return [
        step.format(head=rng.choice(heads), error=rng.choice(errors), solution=rng.choice(solutions))
        for step in overthink_template
    ]


# Three independently sampled rewrites of the same templates.
rewrited_overthinkings = [overthinking_gen(temp) for _ in range(3)]

rewrited_overthinkings

In [None]:
# One block of "\n\n"-separated steps covering every sampled rewrite, ending with "\n\n".
insertion = "\n\n".join("\n\n".join(steps) for steps in rewrited_overthinkings) + "\n\n"
print(insertion)

In [None]:
def combine_suffix(q: str, original: str, s: str) -> str:
    """Render ``q`` as a chat-templated user turn, then append ``original`` and ``s``.

    NOTE(review): this redefines the ``combine_suffix`` helper from the functions cell with
    the same behaviour, except that the second parameter is named ``original`` instead of
    ``n``. Once this cell has run, keyword callers must use ``original=``.

    Returns:
        The rendered template for ``q``, then ``original``, then ``s``, as one string.
    """
    chat_prefix = target_tokenizer.apply_chat_template(
        [{'role': 'user', 'content': q}],
        add_generation_prompt=True,
        tokenize=False,
    )
    return "".join([chat_prefix, original, s])

# Prompt = chat-templated question + the first (j + 1 + overthink_offset) reasoning
# steps from `thinking_steps` + the sampled overthinking `insertion`.
# NOTE(review): `j`, `overthink_offset` and `counterfactual` are not defined in any cell
# shown here (only the list `counterfactuals` is). TODO: confirm where they are set.
original_thinking = "\n\n".join(thinking_steps[:j+1+overthink_offset]) + "\n\n"
input_str = combine_suffix(counterfactual, original_thinking, insertion)