# Training a Small Model for Mathematical Reasoning using Reinforcement Learning (GRPO)

In this notebook, we use **GRPO (Group Relative Policy Optimization)** to train a small language model, **Qwen2.5-0.5B-Instruct**. The goal is to improve its mathematical reasoning without a large corpus of labeled "chain-of-thought" data. Instead of imitating reference solutions, the model is rewarded for reaching the correct final answer and for following the required output format.

We use the **GSM8K dataset**, a standard benchmark of high-quality grade-school math word problems. The objective is to teach the model to write its chain of thought (reasoning) inside specific XML tags before giving the final answer.



# 1. Environment Configuration and vLLM Installation

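The first step is to install the required libraries. `trl` provides `GRPOTrainer`, and `vllm` is a high-throughput, memory-efficient library for LLM inference and serving that TRL can use to speed up generation. This notebook sets `use_vllm=False` in the training configuration to save VRAM on smaller GPUs, so vLLM is installed but not used during training. The cell also uninstalls TensorFlow and pins `numpy`, `pandas`, `scipy`, and `protobuf` to versions compatible with this setup.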

**Important Note:** Installing libraries that ship compiled CUDA kernels, and downgrading packages such as NumPy, usually requires a session restart. **After running this cell, go to 'Runtime' > 'Restart Session' so the newly installed packages and drivers are loaded correctly.**
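After the restart, a quick way to confirm that the pinned versions are the ones actually loaded is to query package metadata. This is a minimal sketch using the standard library's `importlib.metadata`; the helper `installed_versions` is a hypothetical name written for illustration, not part of any library.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each package name to its installed version string, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Run after restarting the session to check that the install cell's pins took effect.
for name, ver in installed_versions(["numpy", "pandas", "scipy", "protobuf", "vllm", "trl"]).items():
    print(f"{name:10s} {ver or 'NOT INSTALLED'}")
```

Given the `numpy<2.0` pin, `numpy` should report a 1.x version; a 2.x version usually means the session was not restarted.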

In [None]:
!pip uninstall -y tensorflow
!pip install "numpy<2.0" "pandas<2.2.0" "scipy<1.13.0" protobuf==3.20.3
!pip install vllm
!pip install trl

# 2. Defining the Prompt Structure and System Instructions

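In this reinforcement learning setup, the model is guided by a **system prompt**. This prompt establishes the expected behavior and, crucially, the **output format** that the reward functions will check. Unlike RLHF, no human feedback is involved: every reward is computed programmatically from the model's output.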

We define a system prompt that instructs the model to respond in a specific XML structure with `<reasoning>` and `<answer>` tags.

In [None]:
# Define the prompt and load necessary libraries
import re
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""


## 2.1 Preparing and Restructuring the GSM8K Dataset
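We define two helper functions: `extract_xml_answer` pulls the final answer out of a model response that uses `<answer>` tags, and `extract_hash_answer` pulls the reference answer out of a GSM8K solution, where it follows the `####` marker. We then load the GSM8K training split and convert each example into a chat-style prompt (system message plus user question) paired with its reference answer.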

In [None]:
# Functions to extract the answer and load the GSM8K dataset

def extract_xml_answer(text: str) -> str:
    """Extract the answer from XML-formatted response."""
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    """Extract answer from GSM8K format (after ####)."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Load and prepare the dataset
data = load_dataset('openai/gsm8k', 'main')['train']
data = data.map(lambda x: {
    'prompt': [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': x['question']}
    ],
    'answer': extract_hash_answer(x['answer'])
})

print(f"Dataset loaded with {len(data)} examples")
print(f"Example prompt: {data[0]['prompt']}")
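
To make the two answer formats concrete, the sketch below runs both helpers on a made-up GSM8K-style solution and a made-up model completion (both strings are invented for illustration, not taken from the dataset). It repeats the two helper definitions so that it runs on its own.

```python
from __future__ import annotations


def extract_xml_answer(text: str) -> str:
    """Return the text between <answer> and </answer> (same logic as above)."""
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def extract_hash_answer(text: str) -> str | None:
    """Return the text after the '####' marker of a GSM8K solution."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


# Made-up solution in GSM8K style: worked steps, then '#### <final answer>'.
gsm8k_solution = "Tom has 3 bags with 4 apples each.\n3 * 4 = <<3*4=12>>12\n#### 12"

# Made-up completion in the format requested by SYSTEM_PROMPT.
completion = "<reasoning>\n3 bags * 4 apples = 12 apples.\n</reasoning>\n<answer>\n12\n</answer>\n"

reference = extract_hash_answer(gsm8k_solution)
predicted = extract_xml_answer(completion)
print(reference, predicted, reference == predicted)  # 12 12 True

# Without <answer> tags, the whole response comes back unchanged (stripped).
print(extract_xml_answer("The answer is 12."))  # The answer is 12.
```

Because `extract_xml_answer` falls back to the whole response when no `<answer>` tag is present, an untagged reply never matches the reference string exactly and therefore earns no correctness reward.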

# 3. Defining Reward Functions

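We define five reward functions, each evaluating a different aspect of the model's output. The trainer sums their scores into a single reward per completion:

- **Correctness** (2.0, `correctness_reward_func`): Whether the extracted answer exactly matches the ground truth
- **Integer format** (0.5, `int_reward_func`): Whether the extracted answer consists only of digits, so negative or decimal answers do not earn this reward
- **Strict XML format** (0.5, `strict_format_reward_func`): Whether the whole response matches the exact line-by-line layout of the system prompt, including a newline after `</answer>`
- **Soft XML format** (0.5, `soft_format_reward_func`): Whether the response starts with a `<reasoning>` block followed by an `<answer>` block, with any whitespace between them
- **XML tag count** (up to 0.5, `xmlcount_reward_func`): 0.125 for each of the four tags present in the response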

In [None]:
# Reward Functions

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward for correct final answer."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward answers consisting only of digits (i.e. non-negative integers)."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward the exact XML layout, including a trailing newline after </answer>."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward responses starting with a <reasoning> block followed by an <answer> block."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Reward for presence of XML tags."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for text in contents:
        count = 0.0
        if "<reasoning>" in text:
            count += 0.125
        if "</reasoning>" in text:
            count += 0.125
        if "<answer>" in text:
            count += 0.125
        if "</answer>" in text:
            count += 0.125
        rewards.append(count)
    return rewards
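
To see how the five functions combine, the sketch below restates them in condensed form for a single completion string. `total_reward` is a hypothetical helper written for illustration; the trainer itself calls the five batch functions separately and, with default weights, sums their outputs. The example strings are made up.

```python
import re


def total_reward(text: str, gold: str) -> float:
    """Condensed restatement of the five reward functions for one completion."""
    answer = text.split("<answer>")[-1].split("</answer>")[0].strip()
    reward = 0.0
    # correctness_reward_func: exact string match with the reference answer
    reward += 2.0 if answer == gold else 0.0
    # int_reward_func: extracted answer made only of digits
    reward += 0.5 if answer.isdigit() else 0.0
    # strict_format_reward_func: exact layout, including the trailing newline
    strict = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    reward += 0.5 if re.match(strict, text, re.DOTALL) else 0.0
    # soft_format_reward_func: starts with a reasoning block, then an answer block
    soft = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    reward += 0.5 if re.match(soft, text, re.DOTALL) else 0.0
    # xmlcount_reward_func: 0.125 for each tag present
    tags = ["<reasoning>", "</reasoning>", "<answer>", "</answer>"]
    reward += sum(0.125 for tag in tags if tag in text)
    return reward


perfect = "<reasoning>\n3 * 4 = 12\n</reasoning>\n<answer>\n12\n</answer>\n"
no_trailing_newline = perfect.rstrip("\n")
wrong_answer = perfect.replace("\n12\n</answer>", "\n13\n</answer>")
untagged = "3 * 4 = 12, so the answer is 12."

examples = {
    "perfect": perfect,                          # 4.0
    "no trailing newline": no_trailing_newline,  # 3.5
    "wrong answer": wrong_answer,                # 2.0
    "untagged": untagged,                        # 0.0
}
for name, text in examples.items():
    print(f"{name:20s} {total_reward(text, '12')}")
```

A correct, well-formatted response reaches the maximum of 4.0, and correctness alone accounts for half of it. The second case shows a quirk of the strict pattern: because it requires a newline after `</answer>`, a response that ends exactly at the closing tag loses the strict-format reward even though everything else is right.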


# 4. Training Hyperparameters and Memory Optimization

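We configure `GRPOConfig` with hyperparameters chosen for training on a single modest GPU, such as the 16 GB NVIDIA T4.

Key configurations:
- **Float32 precision**: The T4 (Turing architecture) has no native bfloat16 support, and fp16 mixed precision depends on loss scaling that can be unstable, so we train in full float32 (`bf16=False`, `fp16=False`)
- **Conservative learning rate**: 5e-6, to reduce the risk of catastrophic forgetting
- **Batching**: In `GRPOTrainer`, `per_device_train_batch_size` counts completions rather than prompts, so a batch size of 8 with `gradient_accumulation_steps=2` gives an effective batch of 16 completions per optimizer step, i.e. 4 prompts with `num_generations=4` completions each
- **Tight gradient clipping**: `max_grad_norm=0.1`, well below the usual default of 1.0, keeps each update small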
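The batch arithmetic above can be checked with a few lines of Python. This is a minimal sketch assuming, as in TRL's `GRPOTrainer`, that batch sizes count completions and that the effective batch must split evenly into groups of `num_generations`; `grpo_batch_summary` is a hypothetical helper, not a TRL function.

```python
def grpo_batch_summary(per_device_batch: int, grad_accum: int,
                       num_generations: int, num_devices: int = 1) -> dict:
    """Effective batch per optimizer step, assuming batch sizes count completions."""
    completions = per_device_batch * grad_accum * num_devices
    if completions % num_generations != 0:
        raise ValueError("effective batch must be divisible by num_generations")
    return {
        "completions_per_step": completions,
        "prompts_per_step": completions // num_generations,
    }


print(grpo_batch_summary(per_device_batch=8, grad_accum=2, num_generations=4))
# {'completions_per_step': 16, 'prompts_per_step': 4}
```

With `max_steps=400`, this amounts to 400 x 4 = 1,600 prompts, roughly a fifth of the 7,473 questions in the GSM8K training split, so the run covers well under one epoch.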

In [None]:
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

output_dir = "outputs/Qwen-0.5B-GRPO"
run_name = "Qwen-0.5B-GRPO-gsm8k"

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=10,
    bf16=False,
    fp16=False,  # Float32 for stability on T4
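    per_device_train_batch_size=8,  # counted in completions, not prompts
    gradient_accumulation_steps=2,  # effective batch: 16 completions = 4 prompts x 4 generations
    num_generations=4,  # group size: completions sampled per prompt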
    max_prompt_length=256,
    max_completion_length=512,
    max_steps=400,
    save_steps=100,
    max_grad_norm=0.1,
    log_on_each_node=False,
    use_vllm=False,
    report_to="none"
)

# 5. Loading Model and Tokenizer

We load **Qwen/Qwen2.5-0.5B-Instruct**, an instruction-tuned model with roughly 0.5 billion parameters. It already follows chat-style instructions, and it is small enough that this notebook trains all of its parameters in float32 on a single T4-class GPU.
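To see why a model of this size fits, here is a back-of-the-envelope estimate. It assumes standard AdamW, which keeps two float32 state tensors per parameter, and it ignores activations, the KV cache used during generation, CUDA overhead, and any frozen reference model. `fp32_adamw_memory_gib` is a hypothetical helper, and 500 million is a round-number stand-in for the real parameter count printed by `model.num_parameters()`.

```python
def fp32_adamw_memory_gib(num_params: int) -> dict:
    """Rough lower bound on memory for full fine-tuning in float32 with AdamW."""
    bytes_per_value = 4  # float32
    gib = 1024 ** 3
    weights = num_params * bytes_per_value
    grads = num_params * bytes_per_value
    adam_states = 2 * num_params * bytes_per_value  # first and second moment estimates
    return {
        "weights": weights / gib,
        "gradients": grads / gib,
        "adam_states": adam_states / gib,
        "total": (weights + grads + adam_states) / gib,
    }


for part, size in fp32_adamw_memory_gib(500_000_000).items():
    print(f"{part:12s} {size:5.2f} GiB")
# weights 1.86, gradients 1.86, adam_states 3.73, total 7.45
```

About 7.5 GiB before activations leaves headroom on a 16 GB T4, but generating 8 completions of up to 512 tokens each, plus a reference copy of the model if a KL penalty is used, consumes part of the remainder.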

In [None]:
# Load Model and Tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
tokenizer.padding_side = "left"

print(f"Model loaded: {model_name}")
print(f"Model parameters: {model.num_parameters():,}")

# 6. Verify GPU Availability

In [None]:
# Verify GPU availability
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Device name: {torch.cuda.get_device_name(0)}")
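
The float32 choice in section 4 follows from the T4's hardware: bfloat16 requires CUDA compute capability 8.0 (Ampere) or newer, while the T4 reports 7.5. If you run the notebook on a different GPU, a minimal sketch like the one below could pick the precision flags from the detected capability. `precision_flags` is a hypothetical helper; its output mirrors the `bf16`/`fp16` arguments of `GRPOConfig`.

```python
def precision_flags(compute_capability: tuple) -> dict:
    """Choose bf16/fp16 flags from a (major, minor) CUDA compute capability."""
    major, _minor = compute_capability
    if major >= 8:  # Ampere or newer: bfloat16 is supported in hardware
        return {"bf16": True, "fp16": False}
    return {"bf16": False, "fp16": False}  # older GPUs such as the T4: stay in float32


try:
    import torch

    if torch.cuda.is_available():
        print(precision_flags(torch.cuda.get_device_capability(0)))
except ImportError:
    pass  # torch not installed; nothing to detect
```

On a T4 this prints `{'bf16': False, 'fp16': False}`, matching the configuration used here.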

# 7. Launching the GRPO Training Loop

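We initialize the `GRPOTrainer` with our model, the five reward functions, and the training arguments defined above.

When `trainer.train()` runs, each training step repeats the following loop:
1. For each math question, the model samples a group of 4 completions (`num_generations=4`).
2. The reward functions score every completion (is the math right? is the format right?), and the scores are summed into one reward per completion.
3. Each reward is compared with the others in its group: GRPO subtracts the group mean and divides by the group standard deviation, producing an *advantage* that is positive for better-than-average completions and negative for worse ones. Because the group mean serves as the baseline, no separate value (critic) model is needed.
4. The model updates its weights to increase the probability of completions with positive advantage (correct math + correct XML) and decrease the probability of those with negative advantage.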
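Step 3 is what makes GRPO "group relative". The following is a minimal sketch of that normalization in plain Python, not TRL's implementation: it uses the population standard deviation (libraries may use the sample version) and a small `eps` to avoid division by zero. `group_relative_advantages` and the reward values are hypothetical.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards: list, group_size: int, eps: float = 1e-4) -> list:
    """Normalize each reward against the other completions sampled for the same prompt."""
    if len(rewards) % group_size != 0:
        raise ValueError("rewards must split evenly into groups")
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu = mean(group)
        sigma = pstdev(group)  # population standard deviation of the group
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


# Made-up total rewards for 2 prompts x 4 completions (num_generations=4).
rewards = [4.0, 2.0, 2.0, 0.0,   # prompt 1: one perfect, two wrong but formatted, one untagged
           2.0, 2.0, 2.0, 2.0]   # prompt 2: every completion scored the same
for adv in group_relative_advantages(rewards, group_size=4):
    print(f"{adv:+.3f}")
# prompt 1: +1.414 +0.000 +0.000 -1.414; prompt 2: all +0.000
```

The second group shows an important property: when every completion in a group receives the same reward, all of its advantages are zero and that prompt contributes no reward-driven learning signal. This is why sampling several varied completions per prompt matters.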

In [None]:
# Initialize Trainer
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func
    ],
    args=training_args,
    train_dataset=data,
)

# Start training
print("Starting training...")
trainer.train()
print("Training finished.")