# Group Relative Policy Optimization (GRPO)

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import re

from dataclasses import dataclass
from datasets import load_dataset, Dataset
from typing import List, Optional, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM

### Configuration
- Training hyperparameters (learning rate, batch size, etc.)
- Model configuration (model name, output dir, etc.)
- GRPO specific parameters (epsilon, beta, group size G)

In [None]:
@dataclass
class GRPOConfig:
    """Config for GRPO.

    At present the only field is the torch device string used when placing
    the model.
    """
    # The default is computed once, when this class body executes: "cuda" if
    # torch reports CUDA as available at that moment, otherwise "cpu".
    # Because it is a plain class-level default, it can also be read as
    # ``GRPOConfig.device`` without constructing an instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

### Data Loading & Preprocessing Cell
- Load the dataset (e.g., GSM8K)
- Define prompt templates and formatters
- Create data loaders

In [5]:
# System message placed first in every prompt built by get_gsm8k_questions.
# It asks the model to answer using a <reasoning> section followed by an
# <answer> section, each tag on its own line.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# str.format template with {reasoning} and {answer} placeholders that renders
# a completion in the layout requested above. The trailing backslash after the
# opening quotes suppresses the leading newline. Within this file it is only
# referenced by the commented-out one-shot example in get_gsm8k_questions.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

In [6]:
def extract_xml_answer(text: str) -> str:
    """Return the stripped text found inside the final ``<answer>`` section.

    The result is whatever lies between the last ``<answer>`` tag and the
    first ``</answer>`` tag after it, with surrounding whitespace removed.
    Missing tags are tolerated: without an opening tag the search starts at
    the beginning of ``text``; without a closing tag it runs to the end.
    """
    # rpartition yields the full string as the tail when the tag is absent.
    _, _, after_open = text.rpartition("<answer>")
    # partition yields the full string as the head when the tag is absent.
    inner, _, _ = after_open.partition("</answer>")
    return inner.strip()

def extract_hash_answer(text: str) -> Optional[str]:
    """Extract the answer that follows a ``####`` marker.

    Args:
        text: Solution text that may contain a ``####`` marker.

    Returns:
        The stripped text between the first ``####`` marker and the next one
        (or the end of ``text``), or ``None`` when no marker is present.
        Callers must handle the ``None`` case.
    """
    if "####" not in text:
        return None
    # Index 1 is the segment right after the first marker.
    return text.split("####")[1].strip()

In [None]:
def get_gsm8k_questions(split = "train") -> Dataset:
    """Load one split of ``openai/gsm8k`` (``main`` config) as chat prompts.

    Each row gains two columns: ``prompt``, a chat-style message list made of
    the system prompt followed by the row's question as the user turn, and
    ``answer``, the value extracted from the row's original ``answer`` field
    by ``extract_hash_answer``.
    """
    def to_chat_example(row):
        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            # Optional one-shot example, currently disabled:
            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #    reasoning="9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.",
            #    answer="7"
            #)},
            {'role': 'user', 'content': row['question']},
        ]
        return {
            'prompt': messages,
            'answer': extract_hash_answer(row['answer']),
        }

    all_splits = load_dataset('openai/gsm8k', 'main')
    return all_splits[split].map(to_chat_example)

# dataset = get_gsm8k_questions()

### Base Model & Tokenizer Setup Cell

- Load the base model and tokenizer
- Configure model parameters
- Move model to device

In [None]:
# Checkpoint to fine-tune, plus where and under what name the run is stored.
model_name = "meta-llama/Llama-3.2-1B"
output_dir = "outputs/Llama-1B-GRPO"
run_name = "LlamaR1-1B"

# GRPOConfig.device falls back to "cpu" when CUDA is unavailable, but
# FlashAttention-2 only runs on CUDA devices. Request it only when the target
# device is CUDA and use PyTorch's built-in SDPA attention otherwise, so this
# cell also works on CPU-only machines.
attn_implementation = (
    "flash_attention_2" if GRPOConfig.device.startswith("cuda") else "sdpa"
)

# device_map=None disables automatic placement; the whole model is moved to
# the configured device explicitly with .to(...).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_implementation,
    device_map=None,
).to(GRPOConfig.device)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the EOS token for padding (presumably because this checkpoint ships
# without a dedicated pad token; confirm against the tokenizer config).
tokenizer.pad_token = tokenizer.eos_token

## Reward Functions
#### A. Format Rewards
- Check XML tags structure
- Validate reasoning and answer sections

#### B. Correctness/Outcome Rewards
- Extract Answers
- Compare with ground truth
- Normalize within group

#### C. Additional Reward Signals (optional, to be added)

In [4]:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Score each completion by exact match against its reference answer.

    Args:
        prompts: Sequence of chat message lists; only the last message of the
            first prompt is read, for the debug printout.
        completions: Sequence whose items are indexable, with a first element
            that is a mapping holding the generated text under ``'content'``.
        answer: Reference answers, paired positionally with ``completions``.
        **kwargs: Ignored.

    Returns:
        One reward per (completion, answer) pair: 2.0 when the extracted
        ``<answer>`` text equals the reference exactly, otherwise 0.0.
    """
    texts = [turns[0]['content'] for turns in completions]
    question = prompts[0][-1]['content']
    predictions = list(map(extract_xml_answer, texts))

    # Debug trace of the first sample in the batch.
    print(
        '-'*20,
        f"Question:\n{question}",
        f"\nAnswer:\n{answer[0]}",
        f"\nResponse:\n{texts[0]}",
        f"\nExtracted:\n{predictions[0]}",
    )

    rewards = []
    for predicted, expected in zip(predictions, answer):
        rewards.append(2.0 if predicted == expected else 0.0)
    return rewards


In [9]:
def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions whose extracted answer consists only of digits.

    Args:
        completions: Sequence whose items are indexable, with a first element
            that is a mapping holding the generated text under ``'content'``.
        **kwargs: Ignored.

    Returns:
        One reward per completion: 0.5 when ``str.isdigit()`` is true for the
        extracted ``<answer>`` text, otherwise 0.0. Signs, decimal points and
        empty answers therefore score 0.0.
    """
    texts = [turns[0]['content'] for turns in completions]
    rewards = []
    for text in texts:
        candidate = extract_xml_answer(text)
        rewards.append(0.5 if candidate.isdigit() else 0.0)
    return rewards

In [11]:
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the exact newline-delimited XML layout.

    The whole completion must be ``<reasoning>``, body, ``</reasoning>``,
    ``<answer>``, body, ``</answer>``, with each tag on its own line and a
    newline after the closing answer tag.

    Args:
        completions: Sequence whose items are indexable, with a first element
            that is a mapping holding the generated text under ``"content"``.
        **kwargs: Ignored.

    Returns:
        One reward per completion: 0.5 on a match, otherwise 0.0.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets `.` cross newlines. Without it, any reasoning or answer
    # body spanning more than one line could never match and scored 0.0.
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

In [12]:
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that start with a loosely formatted XML layout.

    The completion must begin with ``<reasoning>...</reasoning>``, followed by
    optional whitespace and ``<answer>...</answer>``. Line breaks inside and
    between the sections are not constrained, and trailing text is allowed.

    Args:
        completions: Sequence whose items are indexable, with a first element
            that is a mapping holding the generated text under ``"content"``.
        **kwargs: Ignored.

    Returns:
        One reward per completion: 0.5 on a match, otherwise 0.0.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets `.` cross newlines. Without it, a completion with line
    # breaks inside the tags (the layout the system prompt asks for) could
    # never match and scored 0.0.
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

In [13]:
def count_xml(text) -> float:
    """Give partial credit for each XML tag that appears exactly once.

    Each of the four newline-delimited markers contributes 0.125 when it
    occurs exactly once in ``text``. The two answer-related markers also
    subtract 0.001 per character of text trailing the closing answer tag
    (the second check forgives one trailing character).
    """
    tag_credit = 0.125
    per_char_penalty = 0.001

    def appears_once(marker):
        return text.count(marker) == 1

    score = 0.0
    if appears_once("<reasoning>\n"):
        score += tag_credit
    if appears_once("\n</reasoning>\n"):
        score += tag_credit
    if appears_once("\n<answer>\n"):
        # Penalise whatever follows the newline-terminated closing tag.
        trailing = text.split("\n</answer>\n")[-1]
        score += tag_credit
        score -= len(trailing) * per_char_penalty
    if appears_once("\n</answer>"):
        # Same penalty, measured from the bare closing tag, minus one char.
        trailing = text.split("\n</answer>")[-1]
        score += tag_credit
        score -= (len(trailing) - 1) * per_char_penalty
    return score

In [14]:
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score every completion with ``count_xml``.

    Args:
        completions: Sequence whose items are indexable, with a first element
            that is a mapping holding the generated text under ``"content"``.
        **kwargs: Ignored.

    Returns:
        One ``count_xml`` score per completion, in input order.
    """
    texts = [turns[0]["content"] for turns in completions]
    return list(map(count_xml, texts))

## GRPO Core Functionality

#### A. Sampling Logic
- Generate group outputs for each prompt
- Handle batching and tokenization

#### B. Advantage Calculation
- Calculate Group Statistics
- Normalize Rewards
- Compute advantages

#### C. Policy Update Logic
- Calculate Policy Ratios
- Apply Clipping
- KL Divergence calculation
- Final loss computation

## Training Loop
- Main training iteration
- Logging and checkpointing
- Evaluation on val set

## Evaluation Cell
- Inference Pipeline
- Metrics calculation
- Result visualization