In [1]:
# Import necessary libraries
import logging
import os
import sys
import re
import math
from dataclasses import dataclass, field
from typing import List, Optional

# Import PyTorch and Hugging Face Transformers
import torch
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers.trainer_utils import get_last_checkpoint

# Import dataset utilities
import datasets
from datasets import load_dataset

# Import libraries from TRL (Transformers Reinforcement Learning)
from trl import (
    AutoModelForCausalLMWithValueHead, 
    PPOConfig, 
    PPOTrainer, 
    GRPOTrainer, 
    GRPOConfig, 
    SFTTrainer
)


from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify

In [2]:
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
OUTPUT_DIR = "/tmp/Qwen-GRPO-training" # For saving our trained model

# Create output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize tokenizer with chat template
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    padding_side="right"
)

# Set pad token if not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Vocabulary size: {len(tokenizer)}")
print(f"Model max length: {tokenizer.model_max_length}")
print(f"Pad token: {tokenizer.pad_token}")
print(f"EOS token: {tokenizer.eos_token}")

Vocabulary size: 151665
Model max length: 131072
Pad token: <|endoftext|>
EOS token: <|im_end|>


In [3]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
)

print(f"Model parameters: {model.num_parameters():,}")

`torch_dtype` is deprecated! Use `dtype` instead!


Loading weights:   0%|          | 0/290 [00:00<?, ?it/s]

Model parameters: 494,032,768


In [4]:
# Check CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Move model to the appropriate device
model.to(device)

# Test basic inference
def test_model_inference(user_input: str):
    """Test basic model inference with the loaded model and tokenizer."""
    messages = [
        {"role": "system", "content": "You are Qwen, a helpful assistant."},
        {"role": "user", "content": user_input}
    ]

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize and generate
    inputs = tokenizer(text, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7
    )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

# Test the model
test_input = "how are you?"
response = test_model_inference(test_input)
print(f"Test Input: {test_input}")
print(f"Model Response: {response}")

Using device: cpu
Test Input: how are you?
Model Response: system
You are Qwen, a helpful assistant.
user
how are you?
assistant
Hello! I'm just an AI language model created by Alibaba Cloud. How can I help you today?


In [5]:
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)
def make_conversation(example):
    """Convert dataset examples into conversation format."""
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ],
    }

In [6]:
# Load the "AI-MO/NuminaMath-TIR" dataset from DigitalLearningGmbH
MATH_le = load_dataset("AI-MO/NuminaMath-TIR", "default")  

# Access the first sample in the training set
MATH_le['train'][0]

{'problem': 'What is the coefficient of $x^2y^6$ in the expansion of $\\left(\\frac{3}{5}x-\\frac{y}{2}\\right)^8$?  Express your answer as a common fraction.',
 'solution': "To determine the coefficient of \\(x^2y^6\\) in the expansion of \\(\\left(\\frac{3}{5}x - \\frac{y}{2}\\right)^8\\), we can use the binomial theorem.\n\nThe binomial theorem states:\n\\[\n(a + b)^n = \\sum_{k=0}^{n} \\binom{n}{k} a^{n-k} b^k\n\\]\n\nIn this case, \\(a = \\frac{3}{5}x\\), \\(b = -\\frac{y}{2}\\), and \\(n = 8\\).\n\nWe are interested in the term that contains \\(x^2y^6\\). In the general term of the binomial expansion:\n\\[\n\\binom{8}{k} \\left(\\frac{3}{5}x\\right)^{8-k} \\left(-\\frac{y}{2}\\right)^k\n\\]\n\nTo get \\(x^2\\), we need \\(8 - k = 2\\), thus \\(k = 6\\).\n\nSubstituting \\(k = 6\\) into the expression:\n\\[\n\\binom{8}{6} \\left(\\frac{3}{5}x\\right)^{8-6} \\left(-\\frac{y}{2}\\right)^6 = \\binom{8}{6} \\left(\\frac{3}{5}x\\right)^2 \\left(-\\frac{y}{2}\\right)^6\n\\]\n\nNow, we w

In [7]:
dataset = load_dataset(
        "AI-MO/NuminaMath-TIR",
        name="default",
        split=['train', 'test']
    )
    
    # Convert splits into dictionary
dataset = {
    'train': dataset[0],
    'test': dataset[1]
}

In [8]:
dataset['train'].column_names

['problem', 'solution', 'messages']

In [9]:
dataset["train"] = dataset["train"].map(make_conversation)
    

In [10]:
dataset["train"][0]

{'problem': 'What is the coefficient of $x^2y^6$ in the expansion of $\\left(\\frac{3}{5}x-\\frac{y}{2}\\right)^8$?  Express your answer as a common fraction.',
 'solution': "To determine the coefficient of \\(x^2y^6\\) in the expansion of \\(\\left(\\frac{3}{5}x - \\frac{y}{2}\\right)^8\\), we can use the binomial theorem.\n\nThe binomial theorem states:\n\\[\n(a + b)^n = \\sum_{k=0}^{n} \\binom{n}{k} a^{n-k} b^k\n\\]\n\nIn this case, \\(a = \\frac{3}{5}x\\), \\(b = -\\frac{y}{2}\\), and \\(n = 8\\).\n\nWe are interested in the term that contains \\(x^2y^6\\). In the general term of the binomial expansion:\n\\[\n\\binom{8}{k} \\left(\\frac{3}{5}x\\right)^{8-k} \\left(-\\frac{y}{2}\\right)^k\n\\]\n\nTo get \\(x^2\\), we need \\(8 - k = 2\\), thus \\(k = 6\\).\n\nSubstituting \\(k = 6\\) into the expression:\n\\[\n\\binom{8}{6} \\left(\\frac{3}{5}x\\right)^{8-6} \\left(-\\frac{y}{2}\\right)^6 = \\binom{8}{6} \\left(\\frac{3}{5}x\\right)^2 \\left(-\\frac{y}{2}\\right)^6\n\\]\n\nNow, we w

In [11]:
def load_math_dataset():
    """Load and prepare the mathematics dataset."""
    dataset = load_dataset(
        "AI-MO/NuminaMath-TIR",
        name="default",
        split=['train', 'test']
    )
    
    # Convert splits into dictionary
    dataset = {
        'train': dataset[0],
        'test': dataset[1]
    }
    
    # Apply conversation format
    for split in dataset:
        dataset[split] = dataset[split].map(make_conversation)

        # Remove 'messages' column if exists
        if "messages" in dataset[split].column_names:
            dataset[split] = dataset[split].remove_columns("messages")
    
    return dataset

In [12]:
dataset = load_math_dataset()

print(f"Train set size: {len(dataset['train'])}")
print(f"Test set size: {len(dataset['test'])}")

Train set size: 72441
Test set size: 99


In [13]:
dataset['train'][0]

{'problem': 'What is the coefficient of $x^2y^6$ in the expansion of $\\left(\\frac{3}{5}x-\\frac{y}{2}\\right)^8$?  Express your answer as a common fraction.',
 'solution': "To determine the coefficient of \\(x^2y^6\\) in the expansion of \\(\\left(\\frac{3}{5}x - \\frac{y}{2}\\right)^8\\), we can use the binomial theorem.\n\nThe binomial theorem states:\n\\[\n(a + b)^n = \\sum_{k=0}^{n} \\binom{n}{k} a^{n-k} b^k\n\\]\n\nIn this case, \\(a = \\frac{3}{5}x\\), \\(b = -\\frac{y}{2}\\), and \\(n = 8\\).\n\nWe are interested in the term that contains \\(x^2y^6\\). In the general term of the binomial expansion:\n\\[\n\\binom{8}{k} \\left(\\frac{3}{5}x\\right)^{8-k} \\left(-\\frac{y}{2}\\right)^k\n\\]\n\nTo get \\(x^2\\), we need \\(8 - k = 2\\), thus \\(k = 6\\).\n\nSubstituting \\(k = 6\\) into the expression:\n\\[\n\\binom{8}{6} \\left(\\frac{3}{5}x\\right)^{8-6} \\left(-\\frac{y}{2}\\right)^6 = \\binom{8}{6} \\left(\\frac{3}{5}x\\right)^2 \\left(-\\frac{y}{2}\\right)^6\n\\]\n\nNow, we w

In [None]:
# Cold-start SFT data prep (few-shot long-CoT style)
import re

COLD_START_SYSTEM_PROMPT = (
    "You are a math tutor. Solve step-by-step, then give the final result. "
    "Use <think>...</think> for reasoning and <answer>...</answer> for final answer."
)

FEW_SHOT_COT_EXAMPLES = [
    {
        "question": "What is 2 + 3 * 4?",
        "think": "Apply order of operations: 3*4=12, then 2+12=14.",
        "answer": "14",
    },
    {
        "question": "Solve 5x - 10 = 0.",
        "think": "Add 10 on both sides to get 5x=10, then divide by 5.",
        "answer": "x = 2",
    },
]


def extract_final_answer(solution_text: str) -> str:
    boxed = re.findall(r"\\boxed\{([^}]*)\}", solution_text)
    if boxed:
        return boxed[-1].strip()

    marker = re.findall(r"####\s*(.+)", solution_text)
    if marker:
        return marker[-1].strip()

    lines = [line.strip() for line in solution_text.splitlines() if line.strip()]
    return lines[-1] if lines else solution_text.strip()


def make_cold_start_text(example):
    messages = [{"role": "system", "content": COLD_START_SYSTEM_PROMPT}]

    for ex in FEW_SHOT_COT_EXAMPLES:
        messages.append({"role": "user", "content": ex["question"]})
        messages.append(
            {
                "role": "assistant",
                "content": f"<think>{ex['think']}</think>\n<answer>{ex['answer']}</answer>",
            }
        )

    messages.append({"role": "user", "content": example["problem"]})
    messages.append(
        {
            "role": "assistant",
            "content": (
                f"<think>{example['solution']}</think>\n"
                f"<answer>{extract_final_answer(example['solution'])}</answer>"
            ),
        }
    )

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )

    return {"text": text}


cold_start_limit = min(256, len(dataset["train"]))  # Mac-safe debug size
cold_start_sft_train = dataset["train"].select(range(cold_start_limit)).map(
    make_cold_start_text,
    remove_columns=dataset["train"].column_names,
)

print(f"Cold-start SFT examples: {len(cold_start_sft_train)}")


In [None]:
# Preview one cold-start SFT sample
print(cold_start_sft_train[0]["text"][:1500])


In [None]:
# Build cold-start SFT trainer (version + Mac compatible)
import inspect

if torch.cuda.is_available():
    sft_device = torch.device("cuda")
    sft_dtype = torch.bfloat16
elif torch.backends.mps.is_available():
    sft_device = torch.device("mps")
    sft_dtype = torch.float16
else:
    sft_device = torch.device("cpu")
    sft_dtype = torch.float32

print(f"SFT device: {sft_device} | dtype: {sft_dtype}")

sft_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=sft_dtype,
).to(sft_device)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

cold_start_output_dir = "./qwen2_cold_start_sft"

raw_sft_args = {
    "output_dir": cold_start_output_dir,
    "num_train_epochs": 1,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,
    "learning_rate": 2e-5,
    "logging_steps": 5,
    "save_strategy": "steps",
    "save_steps": 50,
    "save_total_limit": 1,
    "report_to": "none",
    "bf16": False,
    "fp16": False,
    "dataloader_num_workers": 0,
    "remove_unused_columns": False,
}

accepted_args = set(inspect.signature(TrainingArguments.__init__).parameters)
sft_args = TrainingArguments(**{k: v for k, v in raw_sft_args.items() if k in accepted_args})

sft_init_params = set(inspect.signature(SFTTrainer.__init__).parameters)
trainer_kwargs = {
    "model": sft_model,
    "args": sft_args,
    "train_dataset": cold_start_sft_train,
}

if "tokenizer" in sft_init_params:
    trainer_kwargs["tokenizer"] = tokenizer
elif "processing_class" in sft_init_params:
    trainer_kwargs["processing_class"] = tokenizer

if "dataset_text_field" in sft_init_params:
    trainer_kwargs["dataset_text_field"] = "text"
elif "formatting_func" in sft_init_params:
    trainer_kwargs["formatting_func"] = lambda ex: ex["text"]

if "max_seq_length" in sft_init_params:
    trainer_kwargs["max_seq_length"] = 1024
if "packing" in sft_init_params:
    trainer_kwargs["packing"] = False

cold_start_trainer = SFTTrainer(**trainer_kwargs)
print("Cold-start SFT trainer ready")


In [None]:
# Train and save cold-start SFT checkpoint
cold_start_trainer.train()
cold_start_trainer.save_model(cold_start_output_dir)
tokenizer.save_pretrained(cold_start_output_dir)

GRPO_BASE_MODEL = cold_start_output_dir
print(f"Saved cold-start model to: {cold_start_output_dir}")
print(f"GRPO_BASE_MODEL set to: {GRPO_BASE_MODEL}")


In [None]:
# Quick test from the cold-start SFT model
sft_model.eval()

test_question = "If 3x + 5 = 20, what is x?"
messages = [
    {"role": "system", "content": COLD_START_SYSTEM_PROMPT},
    {"role": "user", "content": test_question},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(sft_device)

with torch.no_grad():
    output = sft_model.generate(
        **inputs,
        max_new_tokens=96,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
    )

print(tokenizer.decode(output[0], skip_special_tokens=True))


In [14]:
def accuracy_reward(completions, **kwargs):
    """
    Reward function to check if the model's response is mathematically 
    equivalent to the ground truth solution.
    Uses latex2sympy2 for parsing and math_verify for validation.
    """
    
    # Extract responses
    contents = [completion[0]["content"] for completion in completions]
    rewards = []

    solutions = kwargs.get("solution") # Get solutions from kwargs
    
    for content, sol in zip(contents, solutions):
        # Parse the ground truth solution
        gold_parsed = parse(sol, extraction_mode="first_match", 
                            extraction_config=[LatexExtractionConfig()])
        
        if gold_parsed:  # Check if parsing was successful
            # Parse the model's answer with relaxed normalization
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )

            # Reward 1.0 if correct, 0.0 if incorrect
            reward = float(verify(answer_parsed, gold_parsed))
        else:
            # If ground truth cannot be parsed, assign neutral reward (0.5)
            reward = 0.5
            print("Warning: Failed to parse gold solution:", sol)

        rewards.append(reward)
    
    return rewards

In [15]:
def format_reward(completions, **kwargs):
    """
    Reward function to check if the completion has the correct format:
    <think>...</think> <answer>...</answer>.
    """
    # Define the regex pattern for the desired format
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"

    # Extract the content from each completion
    completion_contents = [completion[0]["content"] for completion in completions]

    # Check if each completion matches the pattern
    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE)
               for content in completion_contents]

    # Reward 1.0 for correct format, 0.0 otherwise
    return [1.0 if match else 0.0 for match in matches]

In [16]:
def reasoning_steps_reward(completions, **kwargs):
    r"""
    Reward function to encourage clear step-by-step reasoning.
    It looks for patterns like "Step 1:", numbered lists, bullet points,
    and transition words.
    """
    # Regex pattern to find indicators of reasoning steps
    pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"

    # Extract completion contents
    completion_contents = [completion[0]["content"] for completion in completions]

    # Count the number of reasoning step indicators in each completion
    matches = [len(re.findall(pattern, content, re.MULTILINE))
               for content in completion_contents]

    # Reward is proportional to the number of reasoning steps, maxing out at 1.0
    # We're using a "magic number" 3 here - encourage at least 3 steps for full reward
    return [min(1.0, count / 3) for count in matches]

In [17]:
def get_cosine_scaled_reward(
    min_value_wrong: float = -0.5,
    max_value_wrong: float = -0.1,
    min_value_correct: float = 0.8,
    max_value_correct: float = 1.0,
    max_len: int = 1000,
):
    """
    Returns a cosine scaled reward function. This function scales the accuracy reward
    based on completion length. Shorter correct solutions get higher rewards,
    longer incorrect solutions get less penalty.
    """
    def cosine_scaled_reward(completions, solution, accuracy_rewards, **kwargs):
        """
        Cosine scaled reward function that adjusts accuracy rewards based on completion length.
        """
        contents = [completion[0]["content"] for completion in completions]
        rewards = []

        for content, sol, acc_reward in zip(contents, solution, accuracy_rewards):
            gen_len = len(content)  # Length of the generated answer
            progress = gen_len / max_len # How far we are to max length
            cosine = math.cos(progress * math.pi) # Cosine value based on progress

            if acc_reward > 0.5: # Assuming accuracy_reward gives ~1.0 for correct answers
                min_value = min_value_correct
                max_value = max_value_correct
            else: # Incorrect answer
                min_value = max_value_wrong  # Note the swap!
                max_value = min_value_wrong

            # Cosine scaling formula!
            reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
            rewards.append(float(reward))
        return rewards
    return cosine_scaled_reward

In [18]:
def get_repetition_penalty_reward(ngram_size: int = 3, max_penalty: float = -0.1):
    """
    Returns a repetition penalty reward function. Penalizes repetitions of n-grams
    in the generated text.
    """
    if max_penalty > 0:
        raise ValueError(f"max_penalty {max_penalty} should not be positive")

    def zipngram(text: str, ngram_size: int):
        """Helper function to generate n-grams from text."""
        words = text.lower().split() # Lowercase and split into words
        return zip(*[words[i:] for i in range(ngram_size)]) # Create n-grams

    def repetition_penalty_reward(completions, **kwargs) -> float:
        """
        Repetition penalty reward function.
        """
        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        for completion in contents:
            if completion == "": # No penalty for empty completions
                rewards.append(0.0)
                continue
            if len(completion.split()) < ngram_size: # No penalty for short completions
                rewards.append(0.0)
                continue

            ngrams = set() # Use a set to store unique n-grams
            total = 0
            for ng in zipngram(completion, ngram_size): # Generate n-grams
                ngrams.add(ng) # Add n-gram to the set (duplicates are ignored)
                total += 1 # Count total n-grams

            # Calculate scaling factor: more repetition -> higher scaling
            scaling = 1 - len(ngrams) / total
            reward = scaling * max_penalty # Apply penalty based on scaling
            rewards.append(reward)
        return rewards
    return repetition_penalty_reward

In [19]:
@dataclass
class GRPOScriptArguments:
    """
    Script arguments for GRPO training, specifically related to reward functions.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={
            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'"
        },
    )
    cosine_min_value_wrong: float = field(
        default=-0.5,
        metadata={"help": "Minimum reward for cosine scaling for wrong answers"},
    )
    cosine_max_value_wrong: float = field(
        default=-0.1,
        metadata={"help": "Maximum reward for cosine scaling for wrong answers"},
    )
    cosine_min_value_correct: float = field(
        default=0.8,
        metadata={"help": "Minimum reward for cosine scaling for correct answers"},
    )
    cosine_max_value_correct: float = field(
        default=1.0,
        metadata={"help": "Maximum reward for cosine scaling for correct answers"},
    )
    cosine_max_len: int = field(
        default=1000,
        metadata={"help": "Maximum length for cosine scaling"},
    )

    repetition_n_grams: int = field(
        default=3,
        metadata={"help": "Number of n-grams for repetition penalty reward"},
    )
    repetition_max_penalty: float = field(
        default=-0.1,
        metadata={"help": "Maximum (negative) penalty for for repetition penalty reward"},
    )

In [20]:
# Define TrainingArguments from transformers
import inspect
import os
import shutil

if os.path.exists(OUTPUT_DIR):
    shutil.rmtree(OUTPUT_DIR)  # This manually deletes the old folder

raw_training_kwargs = {
    "output_dir": OUTPUT_DIR,
    "overwrite_output_dir": True,
    "num_train_epochs": 1,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 16,
    "gradient_accumulation_steps": 2,
    "learning_rate": 5e-5,
    "warmup_ratio": 0.1,
    "weight_decay": 0.01,
    "logging_steps": 10,
    "eval_strategy": "steps",
    "evaluation_strategy": "steps",  # fallback for older transformers versions
    "eval_steps": 50,
    "save_strategy": "steps",
    "save_steps": 50,
    "save_total_limit": 2,
    "dataloader_num_workers": 2,
    "seed": 42,
    "bf16": True,
    "push_to_hub": False,
    "gradient_checkpointing": True,
    "report_to": "none",
    "remove_unused_columns": False,
}

accepted_params = set(inspect.signature(TrainingArguments.__init__).parameters)
training_kwargs = {k: v for k, v in raw_training_kwargs.items() if k in accepted_params}

# Keep only one eval strategy key depending on installed transformers version.
if "eval_strategy" in training_kwargs and "evaluation_strategy" in training_kwargs:
    if "eval_strategy" in accepted_params:
        training_kwargs.pop("evaluation_strategy", None)
    else:
        training_kwargs.pop("eval_strategy", None)

training_args = TrainingArguments(**training_kwargs)



warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.


In [21]:
@dataclass
class ModelConfig:
    """
    Configuration for the model.
    """
    model_name_or_path: str = field(
        default=MODEL_NAME, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    model_revision: Optional[str] = field(
        default="main", metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
    )
    torch_dtype: Optional[str] = field(
        default="bfloat16", metadata={"help": "Override the default `torch_dtype` and load the model under this dtype."}
    )
    trust_remote_code: bool = field(
        default=True, metadata={"help": "Trust remote code when loading model and tokenizer."}
    )
    attn_implementation: Optional[str] = field(
        default="flash_attention_2", metadata={"help": "Attention implementation to use. 'flash_attention_2' or None"}
    )

In [22]:
# Instantiate configuration objects
script_args = GRPOScriptArguments()
model_args = ModelConfig()

In [23]:
reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={
            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'"
        },
    )
reward_funcs

Field(name=None,type=None,default=<dataclasses._MISSING_TYPE object at 0x1077ce120>,default_factory=<function <lambda> at 0x124cb0670>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'"}),kw_only=<dataclasses._MISSING_TYPE object at 0x1077ce120>,doc=None,_field_type=None)

In [24]:
for func in script_args.reward_funcs:
    print(f"Selected reward function: {func}")

Selected reward function: accuracy
Selected reward function: format


In [25]:
def get_reward_functions(script_args):
    """
    Returns a list of reward functions based on the script arguments.
    """
    reward_funcs_list = []
    reward_funcs_registry = {
        "accuracy": accuracy_reward,  # Assuming accuracy_reward is defined in previous steps
        "format": format_reward,      # Assuming format_reward is defined in previous steps
        "reasoning_steps": reasoning_steps_reward, # Assuming reasoning_steps_reward is defined
        "cosine": get_cosine_scaled_reward( # Assuming get_cosine_scaled_reward is defined
            min_value_wrong=script_args.cosine_min_value_wrong,
            max_value_wrong=script_args.cosine_max_value_wrong,
            min_value_correct=script_args.cosine_min_value_correct,
            max_value_correct=script_args.cosine_max_value_correct,
            max_len=script_args.cosine_max_len,
        ),
        "repetition_penalty": get_repetition_penalty_reward( # Assuming get_repetition_penalty_reward is defined
            ngram_size=script_args.repetition_n_grams,
            max_penalty=script_args.repetition_max_penalty,
        ),
    }

    for func_name in script_args.reward_funcs:
        if func_name not in reward_funcs_registry:
            raise ValueError(f"Reward function '{func_name}' not found in registry.")
        reward_funcs_list.append(reward_funcs_registry[func_name])

    return reward_funcs_list

In [26]:

logger = logging.getLogger(__name__)

class LoggingCallback(TrainerCallback):
    """
    A simple callback for logging training information at specific steps.
    """
    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.global_step % args.logging_steps == 0:
            if state.log_history and len(state.log_history) > 0:
                logger.info(f"Step {state.global_step}: Loss = {state.log_history[-1].get('loss', None)}, Learning Rate = {state.log_history[-1].get('learning_rate', None)}")
            else:
                logger.info(f"Step {state.global_step}: No logging information available yet")

def get_callbacks(training_args, model_args, script_args):
    """
    Returns a list of callbacks to be used during training.
    For now, it includes only the LoggingCallback. You can extend this to add more callbacks.
    """
    callbacks = [LoggingCallback()] # Instantiate our LoggingCallback
    return callbacks

In [27]:
# Get reward functions and callbacks
reward_functions = get_reward_functions(script_args)
callbacks = get_callbacks(training_args, model_args, script_args)

In [28]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)

In [29]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

class ScratchGRPOTrainer:
    def __init__(self, model, ref_model, tokenizer, reward_funcs, training_args, script_args):
        self.model = model
        self.ref_model = ref_model 
        self.ref_model.eval() # Reference model is ALWAYS frozen
        self.tokenizer = tokenizer
        self.reward_funcs = reward_funcs
        self.args = training_args
        self.script_args = script_args
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate)

    def compute_log_probs(self, model, input_ids, attention_mask, prompt_len):
        """
        Calculates log-probs for the 'completion' tokens ONLY.
        """
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # [Batch*G, SeqLen, Vocab]
        
        # Shift so logit at i predicts token at i+1
        logits = logits[:, :-1, :] 
        labels = input_ids[:, 1:] 
        
        log_probs = F.log_softmax(logits, dim=-1)
        per_token_log_probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(2)
        
        # KEY FIX: Mask out the prompt tokens. We only care about the answer.
        # Everything before 'prompt_len - 1' is the question.
        return per_token_log_probs[:, (prompt_len - 1):]

    def train_step(self, batch_prompts, batch_solutions):
        G = 8 # Group size
        device = self.model.device
        
        # 1. GENERATION (Done only ONCE per step)
        # We take the first prompt in the batch for this demo
        prompt = batch_prompts[0]
        solution = batch_solutions[0]
        
        inputs = self.tokenizer(prompt, return_tensors="pt").to(device)
        prompt_len = inputs["input_ids"].shape[1]

        # Generate G completions
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs, 
                max_new_tokens=256,
                num_return_sequences=G,
                do_sample=True,
                temperature=0.9,
                pad_token_id=self.tokenizer.pad_token_id
            ) # Shape: [G, SeqLen]

        # 2. REWARDS
        # Decode only the NEWLY generated tokens for scoring
        completion_ids = output_ids[:, prompt_len:]
        completions_text = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
        
        formatted_completions = [[{"content": text}] for text in completions_text]
        
        rewards = torch.zeros(G, device=device)
        for func in self.reward_funcs:
            scores = func(completions=formatted_completions, solution=[solution]*G)
            rewards += torch.tensor(scores, device=device)

        # 3. ADVANTAGES (The "Secret Sauce" of GRPO)
        # We normalize rewards within the group of G
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) # Shape: [G]

        # 4. LOG-PROBS (Training the 'Brain')
        # Get log-probs from the model we are training
        curr_log_probs = self.compute_log_probs(self.model, output_ids, None, prompt_len)
        
        # Get log-probs from the frozen reference model
        with torch.no_grad():
            ref_log_probs = self.compute_log_probs(self.ref_model, output_ids, None, prompt_len)

        # 5. GRPO LOSS CALCULATION
        # Ratio of new policy vs old policy
        ratio = torch.exp(curr_log_probs - ref_log_probs) # Shape: [G, CompletionLen]
        
        # Combine with advantages (Broadcasting [G] across [G, CompletionLen])
        surrogate_loss = -(ratio * advantages.unsqueeze(1)).mean()

        # KL Penalty: Don't let the model drift too far from the reference
        kl_div = torch.exp(ref_log_probs - curr_log_probs) - (ref_log_probs - curr_log_probs) - 1
        kl_loss = 0.1 * kl_div.mean()

        total_loss = surrogate_loss + kl_loss

        # 6. BACKPROPAGATION
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        return total_loss.item()

In [None]:
# 1. Initialize models and tokenizer
from torch.utils.data import DataLoader

grpo_base_model = GRPO_BASE_MODEL if "GRPO_BASE_MODEL" in globals() else "Qwen/Qwen2-0.5B"
print(f"GRPO base model: {grpo_base_model}")

if torch.cuda.is_available():
    grpo_device = torch.device("cuda")
elif torch.backends.mps.is_available():
    grpo_device = torch.device("mps")
else:
    grpo_device = torch.device("cpu")

print(f"GRPO device: {grpo_device}")

model = AutoModelForCausalLM.from_pretrained(grpo_base_model, trust_remote_code=True).to(grpo_device)
ref_model = AutoModelForCausalLM.from_pretrained(grpo_base_model, trust_remote_code=True).to(grpo_device)
ref_model.eval()  # Reference model never trains

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 2. Instantiate your custom trainer
trainer = ScratchGRPOTrainer(
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    reward_funcs=[accuracy_reward],
    training_args=training_args,
    script_args=script_args,
)

# 3. Build a simple dataloader for the custom trainer
def collate_batch(examples):
    return {
        "prompts": [ex["problem"] for ex in examples],
        "solutions": [ex["solution"] for ex in examples],
    }

train_dataset = dataset["train"] if isinstance(dataset, dict) else dataset

dataloader = DataLoader(
    train_dataset,
    batch_size=1,  # Mac-safe debug batch size
    shuffle=True,
    collate_fn=collate_batch,
    num_workers=0,
)

# 4. The training loop (short debug loop)
epochs = 1
max_steps_per_epoch = 20
for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        if step >= max_steps_per_epoch:
            break

        loss = trainer.train_step(
            batch_prompts=batch["prompts"],
            batch_solutions=batch["solutions"],
        )
        if step % 2 == 0:
            print(f"Epoch {epoch} | Step {step} | Loss: {loss:.4f}")



In [None]:
# Save trained model + tokenizer
save_dir = "./grpo_qwen2_0.5b_finetuned"

model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

print(f"Saved model to: {save_dir}")


In [None]:
# Reload and test generation
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

test_dir = "./grpo_qwen2_0.5b_finetuned"
device = "cuda" if torch.cuda.is_available() else "cpu"

test_tokenizer = AutoTokenizer.from_pretrained(test_dir)
test_model = AutoModelForCausalLM.from_pretrained(test_dir).to(device)
test_model.eval()

prompt = "Solve: If 3x + 5 = 20, what is x?"

inputs = test_tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    out = test_model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=test_tokenizer.eos_token_id,
    )

print(test_tokenizer.decode(out[0], skip_special_tokens=True))


In [None]:
sample = dataset["test"][0]
print("Q:", sample["problem"])
print("GT:", sample["solution"][:300], "...")
