## 准备

### 数据集

在 AutoDL 上开启学术加速后，可直接从 Hugging Face Hub 加载数据集（如需使用，取消下方单元格代码的注释）。

In [None]:
# import subprocess
# import os

# result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)
# output = result.stdout
# for line in output.splitlines():
#     if '=' in line:
#         var, value = line.split('=', 1)
#         os.environ[var] = value

# from datasets import load_dataset
# data = load_dataset('openai/gsm8k', 'main')

从本地 parquet 文件加载 Hugging Face 数据集（GSM8K 的 main 配置）。

In [None]:
from datasets import load_dataset

# Load GSM8K ("main" config) from local parquet files; one file per split.
base_url = "data/gsm8k/main/"
dataset = load_dataset(
    "parquet",
    data_files={
        split_name: base_url + split_name + "-00000-of-00001.parquet"
        for split_name in ("train", "test")
    },
)

### wandb

In [None]:
%pip install wandb
import wandb
wandb.login(key="your_wandb_api_key")  # 替换为你的wandb API密钥
project_name = "GRPO-Qwen2.5-3B-test" # todo: wandb项目命名
wandb.init(project = project_name)

## GRPO训练流程

In [None]:
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer

# System prompt sent before every GSM8K question (see get_gsm8k_questions and
# the evaluation cell). It asks the model to wrap its reasoning in
# <reasoning>...</reasoning> and its final answer in <answer>...</answer>;
# extract_xml_answer and the format reward functions parse these tags.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Format string with {reasoning} / {answer} placeholders in the same XML layout.
# The leading backslash suppresses the initial newline inside the literal.
# NOTE(review): not referenced anywhere else in this notebook view.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

### 格式化数据集

In [None]:
def extract_xml_answer(text: str) -> str:
    """Return the stripped content of the last ``<answer>`` section in *text*.

    Everything after the final ``<answer>`` tag is kept, then cut at the first
    following ``</answer>``. If *text* has no ``<answer>`` tag, the whole
    text (stripped) is returned; a missing closing tag keeps the rest of text.
    """
    after_open = text.rpartition("<answer>")[2]
    inner, _, _ = after_open.partition("</answer>")
    return inner.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split="train", source=None) -> Dataset:
    """Return one GSM8K split formatted as chat prompts for GRPO training.

    Each example gains a ``prompt`` column (the system prompt followed by the
    example's ``question`` as the user message). Its ``answer`` column is
    replaced by the text after ``####`` in the original solution (``None`` if
    the marker is absent).

    Args:
        split: Name of the split to read from *source*.
        source: Split-name -> dataset mapping (e.g. the object returned by
            ``load_dataset``). Defaults to the notebook-level ``dataset``
            looked up at call time. Pass it explicitly once ``dataset`` has
            been rebound to a single split, where ``dataset[split]`` no
            longer selects a split.

    Returns:
        The mapped dataset for *split*.
    """
    if source is None:
        source = dataset  # notebook-level object from the data-loading cell
    data = source[split]

    def _to_chat_example(example):
        # Build the chat prompt and reduce the reference solution to its final answer.
        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': example['question']},
            ],
            'answer': extract_hash_answer(example['answer']),
        }

    return data.map(_to_chat_example)  # type: ignore

In [None]:
# Format the training split for GRPO. NOTE: this rebinds the global `dataset`
# from the loaded split mapping to the formatted train split only. Re-running
# this cell (which reads dataset["train"]) without first re-running the
# data-loading cell is therefore expected to fail.
dataset = get_gsm8k_questions(split="train")

### 奖励函数

In [None]:
# correctness_reward_func: reward completions whose extracted answer is correct.
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Give 2.0 to each completion whose <answer> text equals its reference answer.

    Comparison is exact string equality after extract_xml_answer. As a side
    effect, prints the first example of the batch (last prompt message,
    reference answer, raw response, and extracted answer) for monitoring.
    """
    texts = [item[0]['content'] for item in completions]
    question = prompts[0][-1]['content']
    predictions = [extract_xml_answer(t) for t in texts]
    print('-' * 20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{texts[0]}", f"\nExtracted:\n{predictions[0]}")
    rewards = []
    for predicted, reference in zip(predictions, answer):
        rewards.append(2.0 if predicted == reference else 0.0)
    return rewards

# int_reward_func: reward outputs whose extracted answer is a plain digit string.
def int_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to each completion whose extracted answer satisfies str.isdigit().

    Answers such as "-5" or "2.5" are not all-digit strings and score 0.0.
    """
    texts = [item[0]['content'] for item in completions]
    candidates = list(map(extract_xml_answer, texts))
    return [0.0 if not c.isdigit() else 0.5 for c in candidates]

# strict_format_reward_func: check completions against the exact required layout.
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for completions that follow the exact XML layout, else 0.0.

    Each tag must sit on its own line and the completion must end with a
    newline after the closing answer tag. re.DOTALL is required so that
    ``.*?`` can span several lines. Without it, any multi-line reasoning (the
    normal case) could never match and this reward would almost never fire.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

# soft_format_reward_func: check completions against a looser layout requirement.
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 when a completion starts with <reasoning>...</reasoning> followed by <answer>...</answer>.

    Any whitespace is allowed between the two sections. re.DOTALL lets the
    tag contents span multiple lines; without it multi-line content never
    matched. re.match anchors at the start, so leading text still fails.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

# count_xml: score how well the text's XML tag structure matches the format.
def count_xml(text) -> float:
    """Return a partial-credit score for the XML tag structure of *text*.

    Each of the four newline-delimited tags earns 0.125 when it occurs exactly
    once. The two answer-tag checks also subtract 0.001 per character found
    after the last closing answer tag. The second check lets one character go
    unpenalised, so nothing at all after "\\n</answer>" yields a +0.001 bonus.
    """
    score = 0.0
    for opening_or_closing in ("<reasoning>\n", "\n</reasoning>\n"):
        if text.count(opening_or_closing) == 1:
            score += 0.125
    if text.count("\n<answer>\n") == 1:
        score += 0.125
        # NOTE(review): if "\n</answer>\n" is absent, rpartition yields the
        # whole text, so the entire completion length is penalised — confirm intended.
        trailing = text.rpartition("\n</answer>\n")[2]
        score -= len(trailing) * 0.001
    if text.count("\n</answer>") == 1:
        score += 0.125
        trailing = text.rpartition("\n</answer>")[2]
        score -= (len(trailing) - 1) * 0.001
    return score

# xmlcount_reward_func: compute the XML-structure score for every completion.
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Return count_xml(...) of the first message's content for each completion."""
    texts = [item[0]["content"] for item in completions]
    return list(map(count_xml, texts))

### 训练流程

In [None]:
# TODO: control the dataset size.
# Train on at most `max_train_samples` examples. Clamping with min() keeps
# Dataset.select from failing when the split has fewer rows than the cap.
max_train_samples = 100
dataset = dataset.select(range(min(max_train_samples, len(dataset))))
# TODO: choose the model (directory name under models/).
model_name = "Qwen2.5-3B-Instruct"

In [None]:
# The base model is read from models/<name>; the fine-tuned model is written
# to models/<name>-GRPO.
model_dir = "models/" + model_name
output_dir = "models/" + model_name + "-GRPO"
run_name = model_name + "-GRPO-gsm8k"  # name of the current run (passed as run_name)

# TODO: tune hyperparameters.
training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=4,  # per-device batch size
    gradient_accumulation_steps=4,  # accumulate gradients to emulate a larger batch
    num_generations=4,  # number of completions generated per prompt
    max_prompt_length=256,
    max_completion_length=256,  # maximum number of tokens in each generated reply
    num_train_epochs=1,
    # NOTE(review): on a small training subset the run may end before step 300,
    # so no intermediate checkpoint would be written — confirm.
    save_steps=300,
    save_total_limit=2,  # keep at most the 2 newest checkpoints so the disk does not fill up
    max_grad_norm=0.1,
    log_on_each_node=False,
    use_vllm=False,
    vllm_gpu_memory_utilization=.3,  # presumably only relevant when use_vllm=True
    # vllm_device="cuda:0",
    report_to="wandb",  # logging backend
)

In [None]:
# Load the base model in bfloat16. device_map=None disables automatic
# placement, so the model is moved to the GPU explicitly with .to("cuda").
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, device_map=None)
model = model.to("cuda")

# Use the EOS token as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_dir)
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Build the GRPO trainer. The reward functions are summed by the trainer in
# the order listed here (format/structure rewards first, correctness last).
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
)

trainer.train()

# Write the final model to output_dir.
trainer.save_model(output_dir)

### 查看结果

In [None]:
# Reload the fine-tuned model saved by the training cell, letting
# transformers choose dtype and device placement.
grpo_model_name = output_dir
model = AutoModelForCausalLM.from_pretrained(grpo_model_name, device_map="auto", torch_dtype="auto")
# The tokenizer is read from the base-model directory.
tokenizer = AutoTokenizer.from_pretrained(model_dir)

In [None]:
# A sample GSM8K-style word problem for a quick qualitative check.
prompt = "Joy can read 8 pages of a book in 20 minutes. How many hours will it take her to read 120 pages?"
messages = [
    dict(role="system", content=SYSTEM_PROMPT),
    dict(role="user", content=prompt),
]

# Render the chat template to a string (not token ids) with the generation
# prompt appended, then tokenize and move the tensors to the model's device.
text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

In [None]:
generated_ids = model.generate(**model_inputs, max_new_tokens=512)
# generate() returns prompt + continuation; keep only the newly generated
# tokens of each row by dropping the first len(prompt) ids.
generated_ids = [
    full_sequence[len(prompt_sequence):]
    for prompt_sequence, full_sequence in zip(model_inputs.input_ids, generated_ids)
]

In [None]:
# Decode the single generated sequence, keeping special tokens visible so the
# raw end-of-generation markup can be inspected; the bare name displays it.
response = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
response