In [3]:
!pip show trl

Name: trl
Version: 0.14.0
Summary: Train transformer language models with reinforcement learning.
Home-page: https://github.com/huggingface/trl
Author: Leandro von Werra
Author-email: leandro.vonwerra@gmail.com
License: Apache 2.0
Location: /home/yyf/anaconda3/envs/grpo/lib/python3.10/site-packages
Requires: accelerate, datasets, rich, transformers
Required-by: unsloth, unsloth_zoo


参考：https://unsloth.ai/blog/r1-reasoning

环境：
- Ubuntu 22.04 LTS
- cuda==12.4
- torch==2.5.1
- vllm==0.7.2
- trl==0.14.0

使用 Qwen2.5-1.5B-Instruct 模型时，单卡 48 GB 显存即可，实际显存消耗约 45 GB。
使用 3B 模型时需要单卡 80 GB 显存，实际显存消耗约 72 GB。


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from vllm import LLM


In [None]:
# Local filesystem path of the base checkpoint that will be fine-tuned.
model_path="/home/data/hub/Qwen2.5-1.5B-Instruct"
# Load the causal LM weights in bfloat16 with no device map, then move the
# whole model to the CUDA device in a single step.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map=None
).to("cuda")

# Tokenizer from the same checkpoint, created with model_max_length=512.
tokenizer = AutoTokenizer.from_pretrained(model_path,model_max_length=512)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
import re
from datasets import load_dataset, Dataset


# System message placed first in every prompt built by get_gsm8k_questions and in
# the evaluation loop. It asks for a <reasoning> block followed by an <answer>
# block; the reward functions below check completions against this layout.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Template of a well-formed response with {reasoning} and {answer} placeholders.
# The backslash right after the opening quotes suppresses the leading newline.
# Not referenced by any other cell visible in this notebook.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    """Pull the answer text out of an XML-style completion.

    Takes everything after the last "<answer>" tag, cuts it at the first
    "</answer>" that follows, and strips surrounding whitespace. If a tag is
    missing, the corresponding cut is simply not applied, so a text without any
    tags comes back stripped but otherwise unchanged.
    """
    # rpartition(...)[2] is the text after the LAST opening tag (or the whole
    # text when the tag is absent).
    after_open_tag = text.rpartition("<answer>")[2]
    # partition(...)[0] is the text before the FIRST closing tag (or everything
    # when the tag is absent).
    before_close_tag = after_open_tag.partition("</answer>")[0]
    return before_close_tag.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split = "train") -> Dataset:
    """Load one split of openai/gsm8k ("main" config) and prepare it for training.

    Every row is mapped through a function that returns two fields:
      * 'prompt' - a two-message chat: SYSTEM_PROMPT as the system turn and the
        row's 'question' as the user turn;
      * 'answer' - the result of extract_hash_answer on the row's 'answer' text.

    Args:
        split: name of the split to select from the loaded dataset.

    Returns:
        The mapped dataset for the requested split.
    """
    def to_chat_example(row):
        # Build the chat prompt and reduce the reference solution to its final answer.
        system_turn = {'role': 'system', 'content': SYSTEM_PROMPT}
        user_turn = {'role': 'user', 'content': row['question']}
        return {
            'prompt': [system_turn, user_turn],
            'answer': extract_hash_answer(row['answer']),
        }

    all_splits = load_dataset('openai/gsm8k', 'main')
    return all_splits[split].map(to_chat_example)

# Training data used by the trainer (and by the evaluation loop) further down.
dataset = get_gsm8k_questions()


def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose extracted answer equals the reference answer.

    For every completion, the 'content' of its first message is run through
    extract_xml_answer and compared (exact string equality) with the matching
    entry of `answer`. A match scores 2.0, anything else 0.0.

    As a side effect the first question, reference answer, raw response and
    extracted answer of the batch are printed for inspection.
    """
    responses = []
    for completion in completions:
        responses.append(completion[0]['content'])
    # Last message of the first prompt, used only in the debug print below.
    q = prompts[0][-1]['content']
    extracted_responses = list(map(extract_xml_answer, responses))
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    rewards = []
    for guess, reference in zip(extracted_responses, answer):
        rewards.append(2.0 if guess == reference else 0.0)
    return rewards

def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions whose extracted answer consists only of digits.

    The 'content' of each completion's first message is passed through
    extract_xml_answer; the reward is 0.5 when str.isdigit() is true for the
    result and 0.0 otherwise (so signs, decimals and empty answers score 0.0).
    """
    responses = []
    for completion in completions:
        responses.append(completion[0]['content'])
    extracted_responses = map(extract_xml_answer, responses)
    rewards = []
    for candidate in extracted_responses:
        if candidate.isdigit():
            rewards.append(0.5)
        else:
            rewards.append(0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the exact line layout of the response format.

    The whole completion must be, line by line:
    "<reasoning>", reasoning text, "</reasoning>", "<answer>", answer text,
    "</answer>", followed by a final newline.

    Args:
        completions: list of completions; the 'content' of each completion's
            first message is checked.
        **kwargs: ignored.

    Returns:
        One reward per completion: 0.5 when the layout matches, else 0.0.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets `.` cross newlines. Without it, any completion whose
    # reasoning (or answer) spans more than one line could never match.
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that start with a reasoning block followed by an answer block.

    The completion must begin with "<reasoning>...</reasoning>", optional
    whitespace, then "<answer>...</answer>". Anything after the closing answer
    tag is allowed, and line breaks inside the blocks do not matter.

    Args:
        completions: list of completions; the 'content' of each completion's
            first message is checked.
        **kwargs: ignored.

    Returns:
        One reward per completion: 0.5 when the layout matches, else 0.0.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets `.` cross newlines. The requested format puts a newline
    # right after "<reasoning>", so without this flag a correctly formatted
    # completion could never match.
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    """Give partial credit for each correctly placed format tag.

    Each of the four tag patterns is worth 0.125 when it occurs exactly once:
    "<reasoning>\\n", "\\n</reasoning>\\n", "\\n<answer>\\n" and "\\n</answer>".
    The two answer-tag checks also subtract 0.001 per character found after the
    closing answer tag. The second of these discounts one character, so a
    single trailing newline is free.
    """
    score = 0.0
    if text.count("<reasoning>\n") == 1:
        score = score + 0.125
    if text.count("\n</reasoning>\n") == 1:
        score = score + 0.125
    if text.count("\n<answer>\n") == 1:
        # Text after the last "\n</answer>\n"; this is the whole text when that
        # sequence never occurs.
        after_closed_line = text.rpartition("\n</answer>\n")[2]
        score = score + 0.125 - len(after_closed_line) * 0.001
    if text.count("\n</answer>") == 1:
        # Text after the closing tag, with one character allowed for free.
        after_close_tag = text.rpartition("\n</answer>")[2]
        score = score + 0.125 - (len(after_close_tag) - 1) * 0.001
    return score

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion with count_xml.

    The 'content' of every completion's first message is collected and passed to
    count_xml; the resulting scores are returned in the same order.
    """
    contents = []
    for completion in completions:
        contents.append(completion[0]["content"])
    return list(map(count_xml, contents))

In [None]:
# One epoch of bf16 training. On a single GPU, set vllm_device="cuda:0" and
# vllm_gpu_memory_utilization; vllm_gpu_memory_utilization must be neither too large nor too small.
# If vllm_device is not set, the last GPU is used by default.
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    # optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = True,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4, # Increase to 4 for smoother training
    num_generations = 8, # Decrease if out of memory
    max_prompt_length = 256,
    max_completion_length = 200,
    num_train_epochs = 1, # Set to 1 for a full training run
    # max_steps = 250,
    save_steps = 500,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs/GRPO_model",
    # Generation shares the training GPU, so vLLM is given only part of its memory.
    vllm_device="cuda:0",
    vllm_gpu_memory_utilization=0.3,

)

In [None]:
# Build the GRPO trainer from the model, tokenizer, dataset and config defined
# in the cells above.
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    # Five reward functions defined earlier: three format-oriented ones, one
    # that checks for an all-digit answer, and one that checks correctness.
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
# Start training; checkpoints go to the output_dir configured in training_args.
trainer.train()

In [None]:
# Quick sanity check: sample one response per question from a saved checkpoint
# with vLLM and report the exact-match accuracy of the extracted answers.
# NOTE(review): `dataset` is the GSM8K *train* split built earlier, so this measures
# accuracy on training data rather than on held-out data - confirm this is intended.
from vllm import LLM,SamplingParams
# NOTE(review): the config cell has `max_steps = 250` commented out and
# save_steps = 500, so "checkpoint-250" presumably comes from a run that used
# max_steps = 250 - verify the directory exists before running this cell.
model_path="outputs/GRPO_model/checkpoint-250"
# Rebinds `model` from the trained transformers model to a vLLM engine.
model=LLM(model=model_path,max_model_len=16384,gpu_memory_utilization=0.4,enforce_eager=True)

sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
correct = 0
total = len(dataset)
for data in dataset:
    question = data["question"]
    messages = [{"role":"system","content":SYSTEM_PROMPT},{"role" : "user", "content" : question}]
    outputs = model.chat(
        messages=messages,
        sampling_params=sampling_params,
    )

    # Use the text of the last returned output. Fall back to an empty response
    # so that a previous question's text is never reused when nothing comes back.
    response = outputs[-1].outputs[0].text if outputs else ""

    # A response without an <answer> tag counts as wrong. Checking explicitly
    # replaces the former bare `except: pass`, which also hid unrelated errors
    # and keyboard interrupts.
    parts = response.split('<answer>')
    if len(parts) < 2:
        continue
    number = parts[1].split('</answer>')[0].strip()
    if number == data["answer"]:
        correct += 1

# Guard against an empty dataset and show the result, which was previously
# computed but never displayed.
acc = round(correct/total, 4) if total else 0.0
print(f"accuracy: {acc} ({correct}/{total})")