<a href="https://colab.research.google.com/github/shivvor2/RL-PEFT-a-small-reasoner/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h3 align="center"></h3>

<h1 align="center">Qwen 0.5b on GRPO</h1>

---

<h1 align="center">Training a small math reasoner with RL</h1>

Original notebook by [Will Brown](https://x.com/willccbb). Unfortunately, I can no longer find the X/Twitter release post.

On top of the original notebook, we have implemented:
1. Evaluation code (to evaluate performance of the finetuned model vs the original model)
2. LoRA finetuning (instead of full finetuning) of the model (in progress)

Here is the release message for the original notebook:

> This notebook is an alternate version of the [GRPO demo](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) by [will brown,](https://x.com/willccbb) training llama-1b on the gsm8k math dataset.

> We've only implemented a series of changes to make the code more workable on Colab:
> * Replacement of llama-1b with Qwen-0.5b
> * Generation with vllm, which yields a significant speed-up. Qwen's small size makes it possible to run vllm on the same GPU as the one being used for GRPO.
> * Dropping flash-attn (recurring bug with modeling Qwen, not clear why)

## Setting up the environment.

First we install vllm. Notice that you'll have to restart the session afterwards.

In [None]:
!pip install vllm

Then we install trl, datasets and peft. The order matters: installing vllm after trl triggers a bug in trl.

In [None]:
!pip install trl datasets peft

(Optional) We mount Google Drive for persistent storage.

Change the root storage path (`base_path`) if another form of persistent storage is used.

In [None]:
from google.colab import drive
import os
drive.mount('/content/drive')

# Root directory on persistent storage for every artifact of this notebook
# (trainer checkpoints, merged models, evaluation logs). Change it if you use
# a different persistent store.
base_path = "/content/drive/MyDrive/ML_Experiments/qwen2.5_0.5B_GRPO_LoRA"
# Create base_path itself rather than only its parent directory, so the
# storage root is guaranteed to exist for any later cell that writes into it.
os.makedirs(base_path, exist_ok=True)

## Defining the RL rewards

Now we have everything ready to set up our RL training set and reward policy.

First we set the general prompt structure (with the reasoning tags).

In [None]:
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer

# Prompt / output-format definitions (the dataset itself is loaded in a later cell).

# System message prepended to every question. The model is asked to put its
# chain of thought inside <reasoning> tags and its final answer inside
# <answer> tags; the reward functions and the evaluation parse those tags.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Template of a well-formed completion with {reasoning}/{answer} placeholders.
# NOTE(review): not referenced by any cell visible here — presumably intended
# for building few-shot examples; confirm before removing.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

Now we import the gsm8k dataset and restructure it to fit into a conversational prompt format:

In [None]:
def extract_xml_answer(text: str) -> str:
    """Return the text between the last ``<answer>`` and the next ``</answer>``.

    Missing tags are tolerated. Without any ``<answer>`` the whole text is
    used, and without a closing ``</answer>`` everything after the opening tag
    is kept. The result is stripped of surrounding whitespace.
    """
    # rpartition/partition mirror split(...)[-1] / split(...)[0] exactly,
    # including the "separator not found" cases.
    _, _, after_open = text.rpartition("<answer>")
    inner, _, _ = after_open.partition("</answer>")
    return inner.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split = "train") -> Dataset:
    """Load one GSM8K split and reshape it for conversational GRPO training.

    Each row gets a ``prompt`` (a system message with SYSTEM_PROMPT followed
    by the question as a user message) and an ``answer`` column holding the
    reference answer string parsed by ``extract_hash_answer``.
    """
    def to_chat_example(row):
        chat = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': row['question']},
        ]
        return {'prompt': chat, 'answer': extract_hash_answer(row['answer'])}

    gsm8k = load_dataset('openai/gsm8k', 'main')  # type: ignore
    return gsm8k[split].map(to_chat_example)  # type: ignore


# Training split used by the GRPO trainers below.
dataset = get_gsm8k_questions()

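We now move on to the reward functions. The most important one is the "correctness" function, which acts as a verifier: it compares the answer extracted from each completion with the reference answer. The other four shape the output. Three reward the XML format, and one rewards an answer made only of digits.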

In [None]:
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 for every completion whose extracted answer equals the reference.

    The comparison is an exact string match on the ``<answer>`` contents. As
    a side effect, the first question, reference, completion and extracted
    answer of the batch are printed so progress can be eyeballed during training.
    """
    texts = [c[0]['content'] for c in completions]
    question = prompts[0][-1]['content']
    extracted = list(map(extract_xml_answer, texts))
    print('-'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{texts[0]}", f"\nExtracted:\n{extracted[0]}")
    rewards = []
    for guess, truth in zip(extracted, answer):
        rewards.append(2.0 if guess == truth else 0.0)
    return rewards

def int_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to completions whose extracted answer consists only of digits.

    NOTE(review): ``str.isdigit`` rejects signs, decimals and thousands
    separators (e.g. "-3", "2.5", "1,000"), so such answers earn nothing here.
    """
    rewards = []
    for completion in completions:
        candidate = extract_xml_answer(completion[0]['content'])
        rewards.append(0.5 if candidate.isdigit() else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for completions that follow the exact tag-per-line layout.

    The completion must start with a "<reasoning>" line, then the reasoning,
    then a "</reasoning>" line, an "<answer>" line, the answer, and an
    "</answer>" line terminated by a newline.

    ``re.DOTALL`` is required so that ``.`` can cross newlines. Without it, any
    reasoning or answer that spans more than one line can never match.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for completions that loosely follow the reasoning/answer format.

    The completion must start with a ``<reasoning>...</reasoning>`` block
    followed (after optional whitespace) by an ``<answer>...</answer>`` block.
    Tag contents may span multiple lines. ``re.DOTALL`` is needed for that;
    without it, the newline-separated layout requested by SYSTEM_PROMPT would
    never earn this "lenient" reward.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    """Score how closely *text* follows the XML reasoning/answer layout.

    Each of the four expected tag lines that occurs exactly once adds 0.125
    (max 0.5). Characters after the closing answer tag cost 0.001 each.

    NOTE(review): when "\\n<answer>\\n" is present but "\\n</answer>\\n" is
    not (e.g. the completion ends right at "</answer>"), the split finds no
    separator and the *whole* completion length is penalised. This is kept
    as-is to preserve the reward signal; confirm whether it is intended.
    """
    score = 0.0
    for tag_line in ("<reasoning>\n", "\n</reasoning>\n"):
        if text.count(tag_line) == 1:
            score += 0.125
    if text.count("\n<answer>\n") == 1:
        trailing = text.split("\n</answer>\n")[-1]
        score += 0.125
        score -= 0.001 * len(trailing)
    if text.count("\n</answer>") == 1:
        tail = text.split("\n</answer>")[-1]
        score += 0.125
        # One trailing character (the final newline) is free.
        score -= 0.001 * (len(tail) - 1)
    return score

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Apply ``count_xml`` to the text of every completion in the batch."""
    return [count_xml(completion[0]["content"]) for completion in completions]

And here is a helper function that finds the latest checkpoint. It is used both to resume training and to locate the model for evaluation.

In [None]:
import os
import re
from google.colab import files
import shutil

def get_latest_checkpoint(base_dir: str):
    """Find the latest checkpoint in the given directory.

    Scans *base_dir* for sub-directories named exactly ``checkpoint-<step>``
    and picks the one with the largest numeric step.

    Args:
        base_dir: Directory to scan, typically a trainer ``output_dir``.

    Returns:
        The path of the highest-numbered checkpoint directory, or ``None``
        when *base_dir* is missing or contains no checkpoint directory.
    """
    if not os.path.isdir(base_dir):
        print(f"Warning: Directory {base_dir} does not exist")
        return None

    # (step, directory name) pairs. Names that are not exactly
    # "checkpoint-<digits>" (e.g. "checkpoint-final") and plain files are
    # ignored instead of crashing on a failed regex match.
    candidates = []
    for entry in os.listdir(base_dir):
        found = re.fullmatch(r"checkpoint-(\d+)", entry)
        if found and os.path.isdir(os.path.join(base_dir, entry)):
            candidates.append((int(found.group(1)), entry))

    if not candidates:
        return None

    # Compare numerically (1000 > 900), and return the real directory name.
    _, latest_name = max(candidates)
    return os.path.join(base_dir, latest_name)

## Full finetuning and evaluation

### Training loop

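(Optional) Resuming from a checkpoint: the training cell automatically resumes from the latest checkpoint in `output_dir` if one exists. Set `checkpoint_path = None` there to start from scratch.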

We now set the training arguments:

In [None]:
# Base model for the full fine-tuning run.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

# Checkpoints go to persistent storage so the training cell can resume from
# the latest one after a disconnect.
output_dir=os.path.join(base_path, "outputs/Qwen-0.5B-GRPO")
run_name="Qwen-0.5B-GRPO-gsm8k"

# NOTE(review): newer TRL releases require the global train batch size
# (num_processes * per_device_train_batch_size) to be divisible by
# num_generations; with 1 and 16 this config may be rejected there (the LoRA
# section uses 16) — confirm against the installed TRL version.
training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,              # completions sampled per prompt (GRPO group size)
    max_prompt_length=256,
    max_completion_length=200,
    num_train_epochs=1,
    save_steps=100,                  # write a checkpoint-<step> directory every 100 steps
    max_grad_norm=0.1,
    log_on_each_node=False,
    use_vllm=True,                   # generate completions with vLLM
    vllm_gpu_memory_utilization=.3,  # fraction of GPU memory given to vLLM; the rest stays for training
    vllm_device="cuda:0",            # same GPU the model is moved to below
    report_to="none" # Disable Weights & Biases / other experiment trackers.
)

# Load the weights in bf16 and move them to the GPU (no device_map sharding).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token

And launch the actual training:

In [None]:
# Resume from the newest checkpoint in output_dir if one exists
# (get_latest_checkpoint returns None otherwise).
checkpoint_path = get_latest_checkpoint(output_dir)
# checkpoint_path = None # Uncomment this if we want to restart training

# Rewards are summed by GRPO: three format checks, an integer check, and the
# correctness verifier.
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
)

# resume_from_checkpoint=None is Trainer.train's default, i.e. a fresh run,
# so one call covers both the "no checkpoint" and the "resume" case.
trainer.train(resume_from_checkpoint=checkpoint_path)

### Evaluating the trained model

In [None]:
from vllm import SamplingParams, LLM

In [None]:
# GSM8K test split, in the same chat-prompt format used for training.
test_data = get_gsm8k_questions(split="test")

In [None]:
# 2. Load Trained Model & Tokenizer
# Evaluate the newest checkpoint written by the full fine-tuning run above.
model_path = get_latest_checkpoint(output_dir)
if model_path is None:
    # Fail with a clear message instead of an obscure from_pretrained(None) error.
    raise FileNotFoundError(
        f"No checkpoint-<step> directory found under {output_dir}; run the training cell first."
    )
# Assumes the trainer saved the tokenizer inside the checkpoint directory — TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [None]:
# 3. Format Prompts Using Chat Template
# Render each system+user conversation to a plain string that ends with the
# assistant turn header, ready for vLLM's text-prompt interface.
test_prompts = [
    tokenizer.apply_chat_template(
        example["prompt"],
        tokenize=False,
        add_generation_prompt=True,
    )
    for example in test_data
]

In [None]:
# 4. Set Up vLLM for Batch Inference
# Load the fine-tuned checkpoint into a single-GPU vLLM engine that may use
# up to 30% of GPU memory.
llm = LLM(
    model=model_path,
    trust_remote_code=True,
    gpu_memory_utilization=0.3,
    tensor_parallel_size=1,
)

# 5. Configure Sampling Parameters
sampling_params = SamplingParams(
    stop=["<|im_end|>"],  # Qwen's stop token
    max_tokens=200,       # Same as training's max_completion_length
    temperature=0.0,      # Greedy decoding for evaluation
)

In [None]:
# 6. Generate Responses
# One RequestOutput per prompt. The accuracy step zips these with
# true_answers, relying on vLLM returning outputs in input order.
outputs = llm.generate(test_prompts, sampling_params)

In [None]:
# 7. Extract Answers
# Evaluation-only extractor, deliberately given its own name. Redefining
# `extract_xml_answer` here would replace the training helper of that name,
# and the reward functions look it up at call time. Any later GRPO run in
# this session (e.g. the LoRA section) would then silently be scored with
# these stricter rules.
def extract_eval_answer(text: str) -> str:
    """Return the stripped text inside the first <answer>...</answer> pair.

    Unlike the training-time extractor, a completion missing either tag
    yields "" so it cannot match a non-empty reference answer.
    """
    if "<answer>" in text and "</answer>" in text:
        return text.split("<answer>")[1].split("</answer>")[0].strip()
    return ""

pred_answers = [extract_eval_answer(output.outputs[0].text) for output in outputs]
true_answers = [example["answer"] for example in test_data]

In [None]:
# 8. Calculate Accuracy
# Fraction of test questions whose extracted answer exactly equals the
# GSM8K reference string (True counts as 1 in the sum).
accuracy = sum(pred == ref for pred, ref in zip(pred_answers, true_answers)) / len(true_answers)
print(f"GSM8K Test Accuracy: {accuracy * 100:.2f}%")

In [None]:
# (Optional) 9. Log the results
# Keep the file name relative: os.path.join discards base_path when a later
# component is absolute ("/..."), which would write to the filesystem root
# instead of persistent storage.
results_path = os.path.join(base_path, "grpo_lora_results.txt")
os.makedirs(os.path.dirname(results_path), exist_ok=True)

with open(results_path, "a", encoding="utf-8") as f:
    # accuracy is a fraction in [0, 1]; scale it to match the "%" label.
    f.write(f"Baseline (full finetuning): {accuracy * 100:.2f}%\n")

## LoRA finetuning and evaluation

### Training loop



We first setup the PEFT (LoRA) configuration

In [None]:
from peft import LoraConfig

rank = 16  # LoRA rank; also embedded in the output/run/merged-model names

# LoRA adapters on every attention and MLP projection of the causal LM,
# with alpha scaled to twice the rank.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=rank,
    lora_alpha=2 * rank,
    lora_dropout=0.1,
    bias="none",
    target_modules=[
        # Attention projections
        "q_proj", "k_proj", "v_proj", "o_proj",
        # MLP projections
        "gate_proj", "up_proj", "down_proj",
    ],
)

and set up the trainer as before, but without vLLM (it does not support LoRA in this setup).

In [None]:
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

# Separate output directory per LoRA rank so runs do not share checkpoints.
output_dir = os.path.join(base_path, f"outputs/Qwen-0.5B-GRPO-LoRA-r{rank}")
run_name = f"Qwen-0.5B-GRPO-LoRA-r{rank}-gsm8k"

# Mirrors the full fine-tuning config except for the batch size and use_vllm.
# NOTE(review): learning_rate is the full fine-tuning value; LoRA runs often
# use a larger rate — confirm this is intended.
training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    # Raised to 16 because otherwise the training would not start; presumably
    # TRL requires the batch to be divisible by num_generations (TODO confirm).
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=200,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    log_on_each_node=False,
    use_vllm=False,        # Use the PEFT model directly instead of vLLM engine
    report_to="none",
)

# Fresh copy of the base model; GRPOTrainer applies peft_config to it later.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token

Now we launch the actual training:

In [None]:
# Resume from the newest LoRA checkpoint in output_dir if one exists
# (get_latest_checkpoint returns None otherwise).
checkpoint_path = get_latest_checkpoint(output_dir)
# checkpoint_path = None # Uncomment this if we want to restart training

# Same reward stack as full fine-tuning; peft_config makes the trainer
# fine-tune a LoRA adapter instead of the full model.
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    peft_config=peft_config,
)

# resume_from_checkpoint=None is Trainer.train's default (a fresh run), so a
# single call handles both cases.
trainer.train(resume_from_checkpoint=checkpoint_path)

### Evaluation with the PEFT model

We merge the trained LoRA adapter into the base model so that we can evaluate with vLLM, which is much faster. Evaluation with `transformers` takes over an hour.

The merged model should behave almost the same as the unmerged PEFT model (up to floating point rounding in matrix additions), so for the purpose of evaluation, the merged model should have the same performance as the unmerged model

We load the model from the latest checkpoint and merge it (in memory)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import torch

base_model = 'Qwen/Qwen2.5-0.5B-Instruct'
checkpoint_path = get_latest_checkpoint(output_dir)  # LoRA adapter
if checkpoint_path is None:
    # Fail early with a clear message instead of an opaque PEFT loading error.
    raise FileNotFoundError(
        f"No LoRA checkpoint found under {output_dir}; run the LoRA training cell first."
    )

tokenizer = AutoTokenizer.from_pretrained(base_model)
# If you saved tokenizer with PEFT, replace base_model with checkpoint_path above

# Load the base model in bf16 on the GPU.
model = AutoModelForCausalLM.from_pretrained(
    base_model, torch_dtype=torch.bfloat16, device_map=None
).to("cuda")

# Attach the trained LoRA adapter (inference only).
model = PeftModel.from_pretrained(model, checkpoint_path, is_trainable = False).to("cuda")

# Fold the adapter weights into the base weights, giving a plain
# transformers model that vLLM can load after it is saved to disk.
merged_model = model.merge_and_unload()

Now we save the merged model and tokenizer to a directory under `base_path`. vLLM cannot load a model from memory, so it needs to load it from disk.

In [None]:
import tempfile

# Persistent (not temporary) directory for the merged model, keyed by LoRA
# rank so runs with different ranks do not overwrite each other.
merged_dir = os.path.join(base_path, f"merged/Qwen-0.5B-GRPO-LoRA-r{rank}")
os.makedirs(os.path.dirname(merged_dir), exist_ok=True)

# Save merged model and tokenizer so vLLM can load them from disk after the
# kernel restart.
merged_model.save_pretrained(merged_dir)
tokenizer.save_pretrained(merged_dir)

And perform the evaluation on the gsm8k test

(Restart the kernel before running the following code. Loading a model into `vllm` after a model has been loaded with `torch` or `transformers` in the same session causes errors. The following code is self-contained.)

Remount persistent storage if needed

In [None]:
# Must be the same base_path used before the kernel restart.
base_path = "/content/drive/MyDrive/ML_Experiments/qwen2.5_0.5B_GRPO_LoRA"

rank = 16 # Use the same LoRA rank the adapter was trained with (it selects the merged-model directory)

In [None]:
from google.colab import drive
import os
drive.mount('/content/drive')

# Create base_path itself rather than only its parent directory, so the
# storage root exists for the cells below that read from / write into it.
os.makedirs(base_path, exist_ok=True)

The following code mirrors the full-finetuning evaluation, adapted to be self-contained.

~~I know the whole kernel-restart thing is awkward, but this project lives in a notebook and I don't want to deal with any subprocess shenanigans~~

In [None]:
from vllm import SamplingParams, LLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, Dataset

import os  # used for the merged-model path; imported here so the cell stands alone after a restart

base_model = 'Qwen/Qwen2.5-0.5B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(base_model) # The tokenizer is not updated during the PEFT finetuning process

# 1. Load Test Data (same prompt format as training)
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

def extract_hash_answer(text: str) -> str | None:
    """Return the GSM8K reference answer after "####", or None if absent."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split="test"):
    """Load a GSM8K split as system+user chat prompts plus reference answers."""
    data = load_dataset('openai/gsm8k', 'main')[split]
    data = data.map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })
    return data

test_data = get_gsm8k_questions(split="test")

# 2. Format Prompts Using Chat Template
test_prompts = [
    tokenizer.apply_chat_template(
        example["prompt"],
        tokenize=False,
        add_generation_prompt=True,
    )
    for example in test_data
]

# 3. Load the merged model (written by the merge cell) into vLLM
merged_dir = os.path.join(base_path, f"merged/Qwen-0.5B-GRPO-LoRA-r{rank}")
llm = LLM(
    model=merged_dir,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.3,
    trust_remote_code=True,
)

# 4. Configure Sampling Parameters (same as the full fine-tuning evaluation)
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=200,
    stop=["<|im_end|>"] # Qwen's stop token
)

# 5. Generate Responses
outputs = llm.generate(test_prompts, sampling_params)

# 6. Extract Answers
# Named differently from the training-time `extract_xml_answer`. If this cell
# is run without a kernel restart, the reward functions (which look that name
# up at call time) are then not silently switched to these stricter rules.
def extract_eval_answer(text: str) -> str:
    """Return the stripped text inside the first <answer>...</answer>, or "" if a tag is missing."""
    if "<answer>" in text and "</answer>" in text:
        return text.split("<answer>")[1].split("</answer>")[0].strip()
    return ""

pred_answers = [extract_eval_answer(output.outputs[0].text) for output in outputs]
true_answers = [example["answer"] for example in test_data]

# 7. Calculate Accuracy (exact string match against the reference)
accuracy = sum(p == t for p, t in zip(pred_answers, true_answers)) / len(true_answers)
print(f"GSM8K Test Accuracy: {accuracy * 100:.2f}%")

In [None]:
# Inspect the extracted predictions (one string per test question).
pred_answers

In [None]:
# Inspect the raw vLLM generation outputs.
outputs

And we log the results (optional)

In [None]:
# Keep the file name relative: os.path.join discards base_path when a later
# component is absolute ("/..."), which would write to the filesystem root
# instead of persistent storage.
results_path = os.path.join(base_path, "grpo_lora_results.txt")
os.makedirs(os.path.dirname(results_path), exist_ok=True)

with open(results_path, "a", encoding="utf-8") as f:
    # accuracy is a fraction in [0, 1]; scale it to match the "%" label.
    f.write(f"Rank {rank}: {accuracy * 100:.2f}%\n")

### Delete the model to perform another round of training

In [None]:
# Free GPU memory before another training round.
# `trainer` and `merged_model` also reference the model, so deleting `model`
# alone would not release it. Names that do not exist (e.g. after the kernel
# restart required for vLLM evaluation) are skipped instead of raising
# NameError. NOTE(review): a vLLM engine (`llm`) may not release all of its
# GPU memory on deletion — a kernel restart is the reliable fallback.
import gc
import torch

for _name in ("model", "merged_model", "trainer", "llm"):
    globals().pop(_name, None)
del _name
gc.collect()
torch.cuda.empty_cache()