# Coding GRPO from Scratch: A Guide to Distributed Implementation with Qwen2.5-1.5B-Instruct

In this tutorial, we demonstrate how to build a distributed reinforcement learning (RL) pipeline using the GRPO (Group Relative Policy Optimization) method to finetune a language model for math, logic, and coding tasks. These are tasks for which a unique correct answer exists and can be verified against the ground truth with a simple string comparison.

GRPO was invented by DeepSeek and used to finetune DeepSeek R1 and R1-Zero models to excel in math and logic tasks by learning to generate a chain of thought (CoT).

The objective of this tutorial is to transform a generalist language model **Qwen2.5-1.5B-Instruct** into a math problem solver. We will code GRPO from scratch and then integrate it with several popular libraries and tools to implement a distributed training pipeline, including:

- **PyTorch:** For tensor operations and distributed training.
- **Hugging Face Transformers:** For loading pre-trained language models and tokenizers.
- **FlashAttention2:** For optimized attention mechanisms that help reduce memory usage and improve training speed (if CUDA is available).

The tutorial is organized into several parts. We start with the basic setup and imports, then move on to data formatting and answer extraction, dataset preparation, evaluation functions, reward functions, training setup and execution, and finally loading and testing the model. In the process, we implement the GRPO algorithm from scratch.

**Note:** Modified from https://github.com/aburkov/theLMbook/blob/main/GRPO_From_Scratch_Multi_GPU_DataParallel_Qwen_2_5_1_5B_Instruct.ipynb

## Part 1: Basic Setup and Imports

In this first part, we install and import all necessary modules. We also set up our environment by configuring random seeds for reproducibility and initializing environment variables required for experiment tracking. In addition, we install and import libraries that provide optimized transformer attention mechanisms (FlashAttention2) and reporting (Weights and Biases):

In [1]:

# Import necessary libraries
# Basic Python libraries for various operations
import random
import copy
import re
import os
import numpy as np

# PyTorch and related libraries for deep learning
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset

# Hugging Face libraries for transformer models
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList

class NanSafeLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids, scores):
        # Replace NaN, inf, or -inf values with a very negative number
        safe_scores = torch.nan_to_num(scores, nan=-1e9, posinf=-1e9, neginf=-1e9)
        return safe_scores

# Custom Dataset class
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

from datasets import load_dataset

def set_random_seed(seed: int = 42):
    """
    Set the random seed for reproducibility across Python, NumPy, and PyTorch.

    Args:
        seed (int): The seed value to use for random number generation.

    Returns:
        None

    Explanation:
        1. Sets seed for Python's built-in random module for basic random operations.
        2. Sets seed for NumPy, ensuring consistent random number generation in array operations.
        3. Sets seed for PyTorch CPU operations.
        4. If CUDA is available, sets seed for all GPU devices.
        5. Configures cuDNN to ensure deterministic behavior:
           - Sets deterministic flag to True, ensuring reproducible results.
           - Disables benchmarking to prevent algorithm selection based on hardware.

    Note:
        Setting deterministic behavior may impact performance but ensures consistent results
        across multiple runs, which is crucial for debugging and research.
    """
    # Set the seed for Python's built-in random module
    random.seed(seed)
    # Set the seed for NumPy
    np.random.seed(seed)
    # Set the seed for PyTorch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Ensure deterministic behavior in cuDNN (may impact performance)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Call the function to set random seed for reproducibility
set_random_seed(42)
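
To see what seeding buys us, here is a minimal sketch that uses only Python's built-in `random` module; the same idea carries over to the NumPy and PyTorch generators seeded above. The helper name `sample_after_seed` is made up for this illustration.

```python
import random

def sample_after_seed(seed: int, n: int = 5) -> list[float]:
    """Hypothetical helper: reseed Python's RNG, then draw n numbers."""
    random.seed(seed)
    return [random.random() for _ in range(n)]

first_run = sample_after_seed(42)
second_run = sample_after_seed(42)
other_seed = sample_after_seed(43)

print(first_run == second_run)  # same seed, same sequence
print(first_run == other_seed)  # different seed, different sequence
```

Reseeding with the same value replays the same sequence of draws, which is what makes runs comparable when debugging.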

## Part 2: Data Formatting and Answer Extraction
In this section, we define how our data is formatted and how to extract the answer segments from both the model's output and the dataset. To ensure that the model outputs its response in a consistent format, we define a system prompt. The prompt instructs the model to generate output in an XML-like format containing `<reasoning>` and `<answer>` tags. We then provide two functions:
1. **`extract_answer_from_model_output`:** This function takes the model's output text and extracts the content within the `<answer>` tags.
2. **`extract_answer_from_dataset`:** This function extracts the expected answer from the GSM8K dataset, which separates the answer using the `"####"` delimiter:

In [2]:
# enable Chain of Draft (CoD) prompting
SYSTEM_PROMPT = """
You are solving a math problem. Your response must STRICTLY follow this format:

<reasoning>
Step 1: [5 words max]
Step 2: [5 words max]
Step 3: [5 words max]
[continue with numbered steps]
</reasoning>

<answer>
[single number only]
</answer>

DO NOT repeat these instructions. DO NOT explain the format. DO NOT write anything outside the tags. DO NOT write more than 5 words per reasoning step.

Example:
Question: What is 25 x 4?
<reasoning>
25 x 4.
Multiply tens: 20 x 4 = 80.
Multiply units: 5 x 4 = 20.
Add parts: 80 + 20.
</reasoning>
<answer>
100
</answer>
"""

def extract_answer_from_model_output(text):
   """
   Extracts the value from the last <answer> tag in the text.

   Args:
       text (str): The model-generated text containing XML-style <answer> tags.

   Returns:
       str or None: The content inside the <answer> tags, or None if no valid answer is found.

   Explanation:
       1. Splits the text on the <answer> tag to isolate content after the tag.
       2. Checks if at least one <answer> tag exists in the text.
       3. For the last <answer> segment:
          - Verifies it contains a closing </answer> tag.
          - Extracts only the content between the tags.
       4. Returns None if the answer is empty (just "...") or if tags are missing.
   """
   # Split on <answer> and take everything after the last occurrence
   parts = text.split("<answer>")
   if len(parts) < 2:  # No <answer> tag found
       return None
   last_part = parts[-1]

   # Extract content up to </answer>
   if "</answer>" not in last_part:
       return None
   answer = last_part.split("</answer>")[0].strip()
   return None if answer == "..." else answer

def extract_answer_from_dataset(text):
   """
   Extracts the answer from the GSM8K dataset examples.

   Args:
       text (str): The `answer` field of a GSM8K example: a worked solution
           followed by '####' and the final answer.

   Returns:
       str or None: The extracted answer part after the '####' delimiter, or None if not found.

   Explanation:
       1. Checks if the text contains the '####' delimiter that separates the worked solution from the final answer.
       2. If found, splits the text at this delimiter and returns the second part (the answer).
       3. The answer is stripped of leading/trailing whitespace.
       4. Returns None if no delimiter is present.
   """
   if "####" not in text:
       return None
   return text.split("####")[1].strip()
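
To make the behavior of the two extractors concrete, here is a minimal self-contained sketch. It re-implements the same logic under the hypothetical names `last_answer_tag` and `gsm8k_final_answer` so that the block runs on its own. The sample strings are made up for illustration and are not taken from GSM8K.

```python
def last_answer_tag(text: str):
    """Return the content of the last <answer>...</answer> pair, or None."""
    parts = text.split("<answer>")
    if len(parts) < 2 or "</answer>" not in parts[-1]:
        return None
    answer = parts[-1].split("</answer>")[0].strip()
    return None if answer == "..." else answer

def gsm8k_final_answer(text: str):
    """Return whatever follows the '####' delimiter, or None if it is absent."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Made-up strings in the formats described above
model_output = "<reasoning>\n6 x 7.\n</reasoning>\n<answer>\n42\n</answer>"
unclosed_output = "<reasoning>\n6 x 7.\n</reasoning>\n<answer>42"
dataset_answer = "Six groups of seven make 6 * 7 = 42 items.\n#### 42"

print(last_answer_tag(model_output))       # well-formed tags
print(last_answer_tag(unclosed_output))    # missing </answer>
print(gsm8k_final_answer(dataset_answer))  # text after '####'
```

A response whose `<answer>` tag is never closed yields `None`, so a truncated generation earns no correctness credit.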

## Part 3: Dataset Preparation

In this part we prepare the GSM8K dataset for training. GSM8K is a dataset of 8.5K high-quality, linguistically diverse grade school math word problems created by human problem writers. We use its examples to train our model in the reinforcement learning (RL) paradigm: the model generates several sample problem solutions, and we compare each one to the ground truth number from the GSM8K example. If there's a match, we provide a high reward to the RL algorithm (GRPO), which updates the model's weights so that the chance of earning the high reward next time increases.

We first load the dataset from Hugging Face and then format each example to include a system prompt and a user prompt. We also extract the expected answer from the dataset. Two helper functions are defined here:

1. **`prepare_dataset`:** Loads and prepares the GSM8K dataset by creating a prompt that includes a system prompt (with the formatting instructions) and a user message (the question). It also extracts the answer from the dataset.
2. **`build_prompt`:** Concatenates the list of message dictionaries into a single prompt string. This ensures consistency in how the prompt is constructed during both training and inference.

In [3]:
def prepare_dataset(split="train"):
   """
   Load and prepare the GSM8K dataset for training with string prompts.

   Args:
       split (str): The dataset split to load ("train" or "test"). Defaults to "train".

   Returns:
       list: A list of formatted examples, each containing a prompt string and answer.

   Explanation:
       1. Loads the GSM8K dataset from the Hugging Face datasets hub.
       2. For each example in the dataset:
          - Creates a list of messages with system prompt and the question.
          - Converts this list into a single string prompt using build_prompt().
          - Extracts the answer from the dataset example.
          - Creates a formatted example dictionary with prompt and answer.
       3. Returns the list of formatted examples ready for model training or evaluation.
   """
   data = load_dataset('openai/gsm8k', 'main')[split]
   formatted_data = []
   for example in data:
       # Convert list of messages to a single string prompt.
       prompt_str = build_prompt([
           {"role": "system", "content": SYSTEM_PROMPT},
           {"role": "user", "content": example["question"]}
       ])
       formatted_example = {
           "prompt": prompt_str,  # Now a string rather than a list.
           "answer": extract_answer_from_dataset(example["answer"])
       }
       formatted_data.append(formatted_example)
   return formatted_data

def build_prompt(messages):
   """
   Build a single prompt string from a list of messages.

   Args:
       messages (list): A list of message dictionaries, each with 'role' and 'content' keys.

   Returns:
       str: A concatenated string of all message contents.

   Explanation:
       1. Takes a list of message dictionaries in the typical chat format.
       2. Extracts the 'content' field from each message and strips whitespace.
       3. Joins all content strings with newlines to create a single prompt.
       4. This preserves the training format while converting from structured messages to a string.
   """
   return "\n".join([msg["content"].strip() for msg in messages])
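
The following sketch shows what `build_prompt` produces for a two-message conversation. The one-line function is repeated so that the block runs on its own, and the short system message is a stand-in for `SYSTEM_PROMPT`, chosen only to keep the output readable.

```python
def build_prompt(messages):
    """Same one-liner as above, repeated so this block runs on its own."""
    return "\n".join(msg["content"].strip() for msg in messages)

messages = [
    # Stand-in for SYSTEM_PROMPT (assumption made for brevity)
    {"role": "system", "content": "\nAnswer inside <answer> tags.\n"},
    {"role": "user", "content": "What is 6 x 7?"},
]

prompt = build_prompt(messages)
print(prompt)
# Answer inside <answer> tags.
# What is 6 x 7?
```

Note that the role labels are discarded: the model receives the stripped message contents joined by newlines, as plain text.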

## Part 4: Evaluation Functions

Evaluation is crucial to track the model's progress. In this part, we define functions that allow us to evaluate the model on a set of examples. The evaluation functions perform the following tasks:

- **Tokenize the prompt and generate a response:** The model's output is generated given the tokenized prompt.
- **Extract the predicted answer:** The answer is extracted from the generated response.
- **Compare the predicted answer with the expected answer:** This comparison is done using exact matching as well as numeric equivalence checks.

Two helper functions, `extract_last_number` and `extract_single_number`, are used to extract numbers from text. The main evaluation function, `evaluate_model`, uses these helpers to determine if the predicted answer is correct:

In [4]:
def extract_last_number(text):
   """
   Extracts the last number appearing in the text.

   Args:
       text (str): The text to extract a number from.

   Returns:
       float or None: The last number in the text, or None if no number is found.

   Explanation:
       1. Removes dollar signs and percent symbols from the text.
       2. Uses regex to find a number that appears at the end of the text (possibly after whitespace).
       3. The pattern matches numbers that appear at the end of the string, with or without decimal points.
       4. Returns the found number as a float, or None if no match is found.
   """
   text = text.replace('$', '').replace('%', '')
   pattern = r'(?:^|\s|=)\s*(-?\d*\.?\d+)\s*$'
   match = re.search(pattern, text)
   return float(match.group(1)) if match else None

def extract_single_number(text):
   """
   Extracts a single number from text if exactly one number is present.

   Args:
       text (str): The text to extract a number from.

   Returns:
       float or None: The single number in the text, or None if zero or multiple numbers are found.

   Explanation:
       1. Uses regex to find all numbers in the text (including negative numbers and decimals).
       2. If exactly one number is found, returns it as a float.
       3. If zero or multiple numbers are found, returns None.
   """
   text = re.sub(r'[^\d\.\-]', '', text)
   numbers = re.findall(r'-?\d*\.?\d+', text)
   return float(numbers[0]) if len(numbers) == 1 else None

def evaluate_model(model, tokenizer, eval_examples, device):
   """
   Evaluates the model on a set of examples and prints detailed results.

   Args:
       model: The language model to evaluate.
       tokenizer: The tokenizer for encoding inputs and decoding outputs.
       eval_examples (list): List of evaluation examples, each containing "prompt" and "answer".
       device: The device (CPU or GPU) to run evaluation on.

   Returns:
       float: The accuracy percentage (correct predictions / total examples * 100).

   Explanation:
       1. Sets the model to evaluation mode.
       2. For each example in the evaluation set:
          - Encodes the prompt and generates a response using the model.
          - Extracts the predicted answer from the generated response.
          - Compares the predicted answer with the expected answer using multiple methods:
            a. Exact string matching
            b. Single number extraction and comparison
            c. Last number extraction and comparison
          - Prints detailed information about each example.
       3. Calculates and returns the overall accuracy.
       4. Returns the model to training mode.
   """
   model.eval()
   correct = 0
   total = len(eval_examples)
   print("\n" + "="*50)
   print("EVALUATION ON", total, "EXAMPLES")
   print("="*50)

   safe_processor = NanSafeLogitsProcessor()
   logits_processor = LogitsProcessorList([safe_processor])


   for example in eval_examples:
       # Get the prompt and expected answer
       full_prompt = example["prompt"]
       expected = example["answer"]

       # Tokenize and generate response
       inputs = tokenizer([full_prompt], return_tensors="pt").to(device)
       with torch.no_grad():
           outputs = model.generate(
               **inputs,
               max_new_tokens=512,
               temperature=0.5,
               num_return_sequences=1,
               pad_token_id=tokenizer.pad_token_id,
               eos_token_id=tokenizer.eos_token_id,
               forced_eos_token_id=tokenizer.eos_token_id,
               early_stopping=False,
               logits_processor=logits_processor,
           )
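       # Decode only the newly generated tokens; the prompt itself contains an
       # example <answer> block that would otherwise be picked up by the extractor
       response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)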

       try:
           # Extract answer and check correctness
           predicted = extract_answer_from_model_output(response)

           # Try different matching methods
           if predicted == expected:  # Exact match
               is_correct = True
           else:
               # Try single number matching
               pred_num = extract_single_number(str(predicted))
               exp_num = extract_single_number(str(expected))
               if pred_num is not None and exp_num is not None and pred_num == exp_num:
                   is_correct = True
               else:
                   # Try last number matching
                   pred_num = extract_last_number(str(predicted))
                   exp_num = extract_last_number(str(expected))
                   is_correct = (pred_num is not None and exp_num is not None and
                               pred_num == exp_num)

           # Update counter for correct answers
           if is_correct:
               correct += 1

           # Print evaluation details
           print("\nPrompt:")
           print(full_prompt)
           print("\nExpected Answer:")
           print(expected)
           print("\nExtracted Answer:")
           print(predicted)
           print("\nFull Generated Response:")
           print(response)
           print("\nCorrect:", "✓" if is_correct else "✗")
           print("-"*50)

       except Exception as e:
           print("\nFailed to parse model output for prompt:")
           print(full_prompt)
           print("Error:", e)
           print("-"*50)

   # Calculate and print final accuracy
   accuracy = (correct / total) * 100
   print(f"\nAccuracy: {accuracy:.2f}% ({correct}/{total})")
   print("="*50)

   # Return model to training mode
   model.train()
   return accuracy
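
The three-stage comparison used inside `evaluate_model` can be sketched as a standalone function. The block below repeats the two regex helpers under shorter names and wraps the cascade in `answers_match`, a hypothetical name introduced only for this illustration.

```python
import re

def single_number(text: str):
    """Strip everything but digits, dots and minus signs; succeed on one number."""
    cleaned = re.sub(r"[^\d\.\-]", "", text)
    numbers = re.findall(r"-?\d*\.?\d+", cleaned)
    return float(numbers[0]) if len(numbers) == 1 else None

def last_number(text: str):
    """Return the number that ends the text, ignoring '$' and '%'."""
    text = text.replace("$", "").replace("%", "")
    match = re.search(r"(?:^|\s|=)\s*(-?\d*\.?\d+)\s*$", text)
    return float(match.group(1)) if match else None

def answers_match(predicted, expected: str) -> bool:
    """Hypothetical wrapper: exact match, then single number, then last number."""
    if predicted is None:
        return False
    if predicted == expected:
        return True
    for extract in (single_number, last_number):
        p, e = extract(predicted), extract(expected)
        if p is not None and e is not None and p == e:
            return True
    return False

print(answers_match("42", "42"))         # exact string match
print(answers_match("$1,000", "1000"))   # equal once symbols are stripped
print(answers_match("3 + 4 = 7", "7"))   # equal only via the trailing number
print(answers_match("41", "42"))         # mismatch
```

Because the single-number helper deletes every non-numeric character before matching, `"3 + 4 = 7"` collapses to `347`; the last-number fallback is what recovers the intended `7` in that case.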

## Part 5: Reward Functions

In reinforcement learning, reward functions guide the training process by providing feedback on the model's output. In our pipeline, we define two reward functions:

1. **`correctness_reward`:**  
   This function assigns rewards based on whether the generated answer is correct. It compares the extracted answer from the model output with the expected answer, using both exact string matching and numeric equivalence checks. An exact match earns the highest reward (2.0), a match based on numeric equivalence earns a smaller reward (1.5), an incorrect answer earns a small baseline reward (0.1), and a response with no extractable answer earns 0.0.
   
2. **`format_reward`:**  
   This function encourages the model to adhere to the desired XML-like output format. It provides a small reward for the presence of each of the `<reasoning>`, `</reasoning>`, `<answer>`, and `</answer>` tags in the generated text. We use a relatively small value of 0.2 for each of the four tags (0.8 at most) because the model is already capable of using these tags from the previous supervised finetuning step; the small reward keeps it from forgetting the format during the RL updates.

In [5]:
def correctness_reward(prompts, completions, answer, **kwargs):
   """
   Assigns a reward based on the correctness of the model's answer.

   Args:
       prompts (list): List of input prompts.
       completions (list): List of model completions, each containing content.
       answer (list): List of expected answers.
       **kwargs: Additional keyword arguments.

   Returns:
       list: List of numerical rewards for each completion.

   Explanation:
       1. Extracts the content from each completion.
       2. Extracts the answer portion from each response using extract_answer_from_model_output.
       3. Assigns rewards based on matching criteria:
          - 2.0 points for an exact match
          - 1.5 points for numeric equivalence (when values match but format differs)
          - 0.1 points for an incorrect answer (a small baseline to avoid a zero signal)
          - 0.0 points when no answer can be extracted
   """
   responses = [completion[0]['content'] for completion in completions]
   extracted = [extract_answer_from_model_output(r) for r in responses]
   rewards = []
   
   for r, a in zip(extracted, answer):
       # No extractable answer: no correctness reward
       if r is None:
           rewards.append(0.0)
           continue
       # Case-insensitive, whitespace-stripped comparison for robustness
       if r.strip().lower() == a.strip().lower():
           rewards.append(2.0)
           continue
       # Fall back to numeric equivalence (same value, different formatting)
       r_num = extract_single_number(str(r))
       a_num = extract_single_number(str(a))
       if r_num is not None and a_num is not None and r_num == a_num:
           rewards.append(1.5)
       else:
           # Small baseline reward instead of 0.0 to avoid a zero signal
           rewards.append(0.1)
   return rewards


def format_reward(completions, **kwargs):
   """
   Assigns a reward for adhering to the desired XML format.

   Args:
       completions (list): List of model completions, each containing content.
       **kwargs: Additional keyword arguments.

   Returns:
       list: List of format compliance scores for each completion.

   Explanation:
       1. Extracts the content from each completion.
       2. Evaluates format compliance by checking for required XML tags:
          - 0.2 points for each tag present (<reasoning>, </reasoning>, <answer>, </answer>)
          - Maximum score of 0.8 for perfect format compliance
       3. Stores and returns the format compliance scores.
   """
   responses = [completion[0]['content'] for completion in completions]
   rewards = []
   format_scores = []
   for response in responses:
       score = 0.0
       if "<reasoning>" in response: score += 0.2
       if "</reasoning>" in response: score += 0.2
       if "<answer>" in response: score += 0.2
       if "</answer>" in response: score += 0.2
       rewards.append(score)
       format_scores.append(score)
   return rewards

def combined_reward(prompts, completions, answer):
   """
   Combines correctness and format rewards.

   Args:
       prompts (list[str]): List of prompt texts
       completions (list[list[dict]]): List of completion dictionaries
       answer (list[str]): List of expected answers

   Returns:
       list[float]: Combined rewards for each prompt-completion pair

   Explanation:
       1. Calculates separate rewards for correctness and format compliance.
       2. Combines the rewards with the following weights:
          - Correctness score range: 0.0 to 2.0
          - Format score range: 0.0 to 0.8
          - Total possible range: 0.0 to 2.8
       3. Returns the combined reward for each example.
   """
   # Get individual rewards
   correctness_scores = correctness_reward(prompts=prompts, completions=completions, answer=answer)
   format_scores = format_reward(completions=completions)

   # Combine rewards - correctness is weighted more heavily
   combined_rewards = []
   for c_score, f_score in zip(correctness_scores, format_scores):
       # Correctness score range: 0.0 to 2.0
       # Format score range: 0.0 to 0.8
       # Total range: 0.0 to 2.8
       combined_rewards.append(c_score + f_score)

   return combined_rewards
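
To get a feel for the reward scale, the sketch below scores three made-up completions for a question whose expected answer is `42`. It is a compact stand-in rather than the functions above: `score` and `to_number` are hypothetical names, and `to_number` simply parses the text as a float instead of using the regex helpers.

```python
def extract_answer(text: str):
    """Content of the last <answer>...</answer> pair, or None."""
    parts = text.split("<answer>")
    if len(parts) < 2 or "</answer>" not in parts[-1]:
        return None
    return parts[-1].split("</answer>")[0].strip()

def to_number(text: str):
    """Hypothetical stand-in for the numeric check: parse the text as a float."""
    try:
        return float(text.replace("$", "").replace(",", ""))
    except ValueError:
        return None

def score(completion: str, expected: str) -> float:
    """Correctness (2.0 / 1.5 / 0.1 / 0.0) plus 0.2 per format tag present."""
    tags = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")
    format_score = 0.2 * sum(tag in completion for tag in tags)
    answer = extract_answer(completion)
    if answer is None:
        correctness = 0.0
    elif answer == expected:
        correctness = 2.0
    elif to_number(answer) is not None and to_number(answer) == to_number(expected):
        correctness = 1.5
    else:
        correctness = 0.1
    return correctness + format_score

completions = [
    "<reasoning>\n6 x 7.\n</reasoning>\n<answer>\n42\n</answer>",    # exact match, full format
    "<reasoning>\n6 x 7.\n</reasoning>\n<answer>\n42.0\n</answer>",  # numeric match, full format
    "The answer is 42.",                                              # no tags at all
]
rewards = [score(c, "42") for c in completions]
print([round(r, 2) for r in rewards])
```

The third completion states the right number but earns nothing, because without `<answer>` tags there is nothing to extract and no format credit either.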

## Part 6: DataParallel GRPO From Scratch

In this section, we implement all the building blocks of the GRPO algorithm from scratch. The original implementation assumes a machine with at least 2 GPUs and uses PyTorch's `DataParallel` API to place one copy of the policy model on each GPU and split every batch between them. The adapted code in this section instead selects Apple's MPS backend when it is available and falls back to the CPU otherwise, processing rollouts in small chunks to limit memory use.
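
Before reading the tensor code, it may help to see the core arithmetic on plain Python numbers. The sketch below standardizes rewards within one group of completions and evaluates the clipped surrogate with its KL penalty for a single token. The function names `group_advantages` and `token_objective` and the reward values are made up for this illustration; the formulas follow the `grpo_loss` function in this section.

```python
import math
import statistics

def group_advantages(rewards, eps=1e-4):
    """Standardize rewards within one prompt's group of completions."""
    mean = statistics.mean(rewards)
    std = statistics.stdev(rewards)  # sample std, like torch.Tensor.std's default
    return [(r - mean) / (std + eps) for r in rewards]

def token_objective(new_logp, old_logp, ref_logp, advantage, beta=0.01, epsilon=0.2):
    """Clipped surrogate minus the KL penalty for a single token."""
    ratio = math.exp(new_logp - old_logp)
    clipped = min(max(ratio, 1 - epsilon), 1 + epsilon)
    surrogate = min(ratio * advantage, clipped * advantage)
    # Non-negative KL estimate: exp(d) - d - 1 with d = ref_logp - new_logp
    d = ref_logp - new_logp
    kl = math.exp(d) - d - 1
    return surrogate - beta * kl

rewards = [2.8, 0.9, 0.9, 0.0]  # made-up combined rewards for 4 completions
advantages = group_advantages(rewards)
print([round(a, 3) for a in advantages])

# Right after a rollout the policy equals the old policy, so ratio == 1
# and the objective reduces to the advantage (minus beta * KL).
unchanged = token_objective(-1.0, -1.0, -1.0, advantages[0])
print(unchanged)

# A large ratio with a positive advantage is capped at (1 + epsilon) * advantage.
capped = token_objective(-0.5, -1.5, -0.5, advantage=1.0)
print(capped)
```

The loss returned by `grpo_loss` is the negative of this per-token objective, averaged over the unmasked completion tokens of each sequence and then over sequences.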

In [6]:
def selective_log_softmax(logits, input_ids):
    """
    Computes log probabilities for specific tokens in the vocabulary.
    """
    log_probs = nn.functional.log_softmax(logits, dim=-1)
    return log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

def compute_log_probs(model, input_ids, attention_mask, logits_to_keep):
    """
    Computes the log probabilities for a batch of tokens.

    Args:
        model: The language model.
        input_ids (torch.Tensor): Token IDs for input sequences.
        attention_mask (torch.Tensor): Attention mask for input sequences.
        logits_to_keep (int): Number of tokens to keep from the end of the sequence.

    Returns:
        torch.Tensor: Log probabilities of the selected tokens.

    Explanation:
        1. Gets logits from the model for the input sequence.
        2. Selects logits for all tokens except the last one (as we predict next tokens).
        3. Selects only the last 'logits_to_keep' tokens from both logits and input_ids.
        4. Computes log probabilities for these tokens using selective_log_softmax.
    """
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits[:, :-1, :]
    input_ids = input_ids[:, -logits_to_keep:]
    logits = logits[:, -logits_to_keep:, :]
    return selective_log_softmax(logits, input_ids)

def create_completion_mask(completion_ids, eos_token_id):
    """
    Creates a mask for completion tokens that excludes tokens after the EOS token.

    Args:
        completion_ids (torch.Tensor): Token IDs of the generated completions.
        eos_token_id (int): The ID of the end-of-sequence token.

    Returns:
        torch.Tensor: A binary mask with 1s for valid tokens and 0s after the EOS token.

    Explanation:
        1. Identifies positions where EOS tokens occur in each sequence.
        2. Finds the index of the first EOS token in each sequence.
        3. Creates a mask where positions before and including the first EOS are 1, others are 0.
        4. If no EOS token is found in a sequence, all positions are set to 1.
    """
    is_eos = completion_ids == eos_token_id
    eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device)
    mask_exists = is_eos.any(dim=1)
    eos_idx[mask_exists] = is_eos.int().argmax(dim=1)[mask_exists]
    sequence_indices = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1)
    return (sequence_indices <= eos_idx.unsqueeze(1)).int()

def generate_completions(model, tokenizer, prompts, num_generations=2, max_completion_length=16):
    """
    Generates multiple completions for each prompt.
    Reduced num_generations and max_completion_length for M1 compatibility.
    """
    # Use MPS if available, otherwise CPU
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
    prompt_ids = inputs["input_ids"].to(device)
    prompt_mask = inputs["attention_mask"].to(device)
    print(f"Input batch size: {prompt_ids.size(0)}, Device: {prompt_ids.device}")
    
    prompt_length = prompt_ids.size(1)
    prompt_ids = prompt_ids.repeat_interleave(num_generations, dim=0)
    prompt_mask = prompt_mask.repeat_interleave(num_generations, dim=0)
    
    # Use smaller batches for generation to avoid OOM
    batch_size = 2  # Small batch size for M1
    all_outputs = []

    safe_processor = NanSafeLogitsProcessor()
    logits_processor = LogitsProcessorList([safe_processor])
    
    for i in range(0, prompt_ids.size(0), batch_size):
        batch_end = min(i + batch_size, prompt_ids.size(0))
        batch_prompt_ids = prompt_ids[i:batch_end]
        batch_prompt_mask = prompt_mask[i:batch_end]
        
        batch_outputs = model.generate(
            batch_prompt_ids,
            attention_mask=batch_prompt_mask,
            max_new_tokens=max_completion_length,
            do_sample=True,
            temperature=0.5,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            early_stopping=False,
            logits_processor=logits_processor
        )
        if device.type == "mps":
            torch.mps.empty_cache()  # Release cached MPS memory between chunks
        all_outputs.append(batch_outputs)
    
    outputs = torch.cat(all_outputs, dim=0)
    print(f"Output batch size: {outputs.size(0)}, Device: {outputs.device}")
    
    completion_ids = outputs[:, prompt_length:]
    completion_mask = create_completion_mask(completion_ids, tokenizer.eos_token_id)
    return prompt_ids, prompt_mask, completion_ids, completion_mask

def generate_rollout_data(model, ref_model, tokenizer, batch_samples, num_generations, max_completion_length):
    """
    Generates data for GRPO rollouts including completions and log probabilities.
    """
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

    prompts = [prompt for prompt in batch_samples["prompt"]]
    answers = [answer for answer in batch_samples["answer"]]
    
    with torch.no_grad():
        prompt_ids, prompt_mask, completion_ids, completion_mask = generate_completions(
            model, tokenizer, prompts, num_generations, max_completion_length
        )
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)
        old_log_probs = compute_log_probs(model, input_ids, attention_mask, logits_to_keep)
        ref_log_probs = compute_log_probs(ref_model, input_ids, attention_mask, logits_to_keep)
    formatted_completions = [[{'content': tokenizer.decode(ids, skip_special_tokens=True)}] for ids in completion_ids]
    repeated_prompts = [p for p in prompts for _ in range(num_generations)]
    repeated_answers = [a for a in answers for _ in range(num_generations)]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "completion_mask": completion_mask,
        "old_log_probs": old_log_probs,
        "ref_log_probs": ref_log_probs,
        "formatted_completions": formatted_completions,
        "repeated_prompts": repeated_prompts,
        "repeated_answers": repeated_answers,
        "logits_to_keep": logits_to_keep,
        "batch_size": len(prompts),
        "num_generations": num_generations
    }


def grpo_loss(model, ref_model, rollout_data, tokenizer, reward_function, beta=0.01, epsilon=0.2):
    """
    Computes the GRPO loss for updating the policy model.

    Args:
        model: The policy model being trained.
        ref_model: The reference model for KL divergence calculation.
        rollout_data (dict): Data generated by generate_rollout_data.
        tokenizer: The tokenizer for encoding and decoding text.
        reward_function: Function that calculates rewards for completions.
        beta (float): KL penalty coefficient.
        epsilon (float): Clipping parameter for PPO.

    Returns:
        tuple: (loss, avg_reward), where loss is the scalar torch.Tensor to be minimized
            and avg_reward is the mean reward over the batch as a float.

    Explanation:
        1. Computes current token log probabilities using the policy model.
        2. Calculates the probability ratio between current and old policies.
        3. Computes rewards using the provided reward_function.
        4. Calculates advantages by standardizing rewards within each prompt.
        5. Computes the PPO surrogate objective with clipping.
        6. Calculates the KL divergence between reference and policy models.
        7. Combines surrogate loss and KL penalty.
        8. Averages the loss across all tokens and batches.
    """
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

    input_ids = rollout_data["input_ids"]
    attention_mask = rollout_data["attention_mask"]
    completion_mask = rollout_data["completion_mask"]
    logits_to_keep = rollout_data["logits_to_keep"]
    old_log_probs = rollout_data["old_log_probs"]
    ref_log_probs = rollout_data["ref_log_probs"]
    
    # Process in smaller batches
    batch_size = 4  # Small batch size for M1
    token_log_probs_list = []
    
    for i in range(0, input_ids.size(0), batch_size):
        batch_end = min(i + batch_size, input_ids.size(0))
        batch_input_ids = input_ids[i:batch_end]
        batch_attention_mask = attention_mask[i:batch_end]
        
        batch_token_log_probs = compute_log_probs(model, batch_input_ids, batch_attention_mask, logits_to_keep)
        token_log_probs_list.append(batch_token_log_probs)
    
    token_log_probs = torch.cat(token_log_probs_list, dim=0)
    ratio = torch.exp(token_log_probs - old_log_probs)
    
    rewards = torch.tensor(
        reward_function(prompts=rollout_data["repeated_prompts"], completions=rollout_data["formatted_completions"], answer=rollout_data["repeated_answers"]),
        dtype=torch.float32,
        device=device
    )
    
    batch_size = rollout_data["batch_size"]
    num_generations = rollout_data["num_generations"]
    rewards = rewards.view(batch_size, num_generations)
    avg_reward = rewards.mean().item()
    print("Average Reward:", avg_reward)
    
    mean_rewards = rewards.mean(dim=1).repeat_interleave(num_generations)
    std_rewards = rewards.std(dim=1).repeat_interleave(num_generations)
    advantages = ((rewards.view(-1) - mean_rewards) / (std_rewards + 1e-4)).unsqueeze(1)
    
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    surrogate_loss = torch.min(surr1, surr2)
    
    kl = torch.exp(ref_log_probs - token_log_probs) - (ref_log_probs - token_log_probs) - 1
    per_token_loss = surrogate_loss - beta * kl
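    # Clamp the token count so a completion with an empty mask cannot cause a division by zero
    loss = -((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1).clamp(min=1)).mean()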
    
    return loss, avg_reward

def optimize_model_memory(model):
    """
    Optimizes the model to use less memory during training.
    """
    model.train()
    model.config.use_cache = False  # Disable KV cache to save memory

    # Ensure inputs will require gradients
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)
        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    # Enable gradient checkpointing to trade computation for memory
    model.gradient_checkpointing_enable()
    
    return model

def train_with_grpo_m1(model, tokenizer, train_data, num_iterations=1, num_steps=100, batch_size=2,
                       num_generations=4, max_completion_length=400, beta=0.1,
                       learning_rate=5e-6, mu=1, epsilon=0.2, reward_function=None,
                       accumulation_steps=4, num_workers=0):
    """
    M1-optimized training function with DataLoader integration and gradient accumulation.
    Mixed precision is currently disabled (GradScaler(enabled=False), autocast commented out).
    """
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Training on device: {device}")

    model.to(device)
    model = optimize_model_memory(model)

    # Ensure all parameters in the main model require gradients
    for param in model.parameters():
        param.requires_grad = True

    scaler = GradScaler(enabled=False)

    # Initialize DataLoader
    dataset = CustomDataset(train_data)
    # pin_memory only benefits CUDA host-to-device copies, so it is left off for MPS/CPU
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=False)

    for iteration in range(num_iterations):
        print(f"\nIteration {iteration + 1}/{num_iterations}")

        ref_model = copy.deepcopy(model)
        ref_model.eval()
        for param in ref_model.parameters():
            param.requires_grad = False
        print("Reference model created.")

        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        model.train()

        accumulation_counter = 0

        for step, batch_samples in enumerate(train_loader):
            if step >= num_steps:
                break  # Respect the num_steps cap
            with torch.no_grad():
                rollout_data = generate_rollout_data(
                    model, ref_model, tokenizer, batch_samples,
                    num_generations, max_completion_length
                )

            for grpo_iter in range(mu):
                #with torch.autocast(device_type='mps', dtype=torch.float16):
                loss, avg_reward = grpo_loss(
                    model, ref_model, rollout_data, tokenizer,
                    reward_function, beta=beta, epsilon=epsilon
                )

                loss = loss / accumulation_steps
                scaler.scale(loss).backward()
                accumulation_counter += 1

                if accumulation_counter % accumulation_steps == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    accumulation_counter = 0

                print(f"Iteration {iteration + 1}/{num_iterations}, Step {step + 1}/{len(train_loader)}, "
                      f"GRPO iter {grpo_iter + 1}/{mu}, loss: {loss.item():.4f}")

                if device.type == 'mps':
                    torch.mps.empty_cache()

    return model

# Reduced training dataset size function
def get_reduced_dataset(dataset, size=100):
    """
    Reduces the dataset size for M1 compatibility.
    
    Args:
        dataset: The original dataset.
        size: The desired reduced size.
        
    Returns:
        A reduced subset of the dataset.
    """
    return random.sample(dataset, min(size, len(dataset)))
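
The advantage computation in step 4 of `grpo_loss` can be looked at in isolation. The following is a minimal sketch, assuming the same unbiased standard deviation and `1e-4` stabilizer as the loss above; the helper name `group_relative_advantages` and the reward values are hypothetical and chosen only for illustration.

```python
import torch

def group_relative_advantages(rewards, num_generations, eps=1e-4):
    """Standardize rewards within each prompt's group of completions (hypothetical helper)."""
    grouped = rewards.view(-1, num_generations)       # (num_prompts, num_generations)
    mean = grouped.mean(dim=1, keepdim=True)          # per-prompt mean reward
    std = grouped.std(dim=1, keepdim=True)            # per-prompt (unbiased) std
    return ((grouped - mean) / (std + eps)).view(-1, 1)  # one advantage per completion

# Made-up rewards: 2 prompts x 3 completions each
rewards = torch.tensor([1.0, 0.0, 2.0, 0.5, 0.5, 0.5])
advantages = group_relative_advantages(rewards, num_generations=3)
print(advantages.squeeze(1))
```

In the first group the best completion gets a positive advantage and the worst a negative one. In the second group all completions receive the same reward, so every advantage is zero and that prompt contributes no learning signal to the surrogate term.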

## Part 7: Training Setup and Execution

In this section, we put together all components to set up and run the training. We begin by loading the pre-trained model and tokenizer, prepare the evaluation data, and then run reinforcement learning (RL) fine-tuning with `train_with_grpo_m1`, the GRPO training function we implemented from scratch above.

Key steps include:

- **Model and Tokenizer Initialization:**  
  The reference configuration loads `"Qwen/Qwen2.5-1.5B-Instruct"` with optimized settings (`torch.bfloat16` and FlashAttention2); the Apple-silicon code below instead loads the smaller `"Qwen/Qwen2.5-0.5B-Instruct"` in `torch.float16`. The tokenizer is also loaded, and its padding token is set to the end-of-sequence token. Loading a model in a 16-bit format stores each parameter in 16 bits instead of 32, which halves the memory used by the weights and can make training faster on modern GPUs.
  
- **Initial Evaluation:**  
  Before fine-tuning, the model is evaluated on a few examples to establish a baseline performance.
    
- **Reinforcement Learning Fine-Tuning (RL):**  
  The training function (`train_with_grpo_m1` in the code below), which implements GRPO from scratch, is configured with the training arguments and reward function described below. RL training then proceeds on the training data that was not reserved for evaluation.
  
- **Final Evaluation and Model Saving:**  
  After RL fine-tuning, the model is evaluated again, and the final model is saved.
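
The claim that a 16-bit format halves weight memory can be checked with simple arithmetic. The sketch below assumes a nominal 1.5 billion parameters (the exact count of the checkpoint may differ slightly) and uses a hypothetical helper, `param_memory_gib`, that is not part of the tutorial code.

```python
import torch

def param_memory_gib(num_params, dtype):
    """Approximate memory needed to store the weights alone, in GiB (hypothetical helper)."""
    bytes_per_param = torch.empty((), dtype=dtype).element_size()
    return num_params * bytes_per_param / 1024**3

num_params = 1_500_000_000  # assumed nominal size of a 1.5B-parameter model
fp32 = param_memory_gib(num_params, torch.float32)
bf16 = param_memory_gib(num_params, torch.bfloat16)
print(f"float32: {fp32:.2f} GiB, bfloat16: {bf16:.2f} GiB")
```

This gives roughly 5.6 GiB for 32-bit weights versus 2.8 GiB for 16-bit weights. It counts only the parameters; gradients, optimizer state, and activations add to the total during training.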

In the code below:
  
- The device is determined (Apple's `mps` backend if available, otherwise CPU).
- The pre-trained model and tokenizer are loaded; the code uses the smaller `Qwen/Qwen2.5-0.5B-Instruct` checkpoint. The tokenizer's pad token is set to `eos_token`.
- A small subset of the dataset is reserved for evaluation to provide a baseline.
- The model is optimized for memory efficiency by enabling gradient checkpointing and disabling KV caching.
- **Step 1:** The model is evaluated before fine-tuning to establish a baseline accuracy.
- **Step 2:** Reinforcement learning fine-tuning is performed using the `train_with_grpo_m1` function with our defined reward functions (`format_reward` and `correctness_reward`, combined into `combined_reward`). In this version, training runs on a single device instead of multiple GPUs.
- **Step 3:** The final, fine-tuned model and tokenizer are saved to disk.

We used the following hyperparameters for our GRPO training pipeline:

### **Training Configuration**

These parameters configure the reinforcement learning fine-tuning run using the GRPO algorithm. We set them as follows:

- **num_iterations=1**  
  The number of outer iterations where a new reference model is created from the current policy model. One iteration is one pass over the entire dataset.

- **num_steps=500**  
  The training loop will perform a maximum of 500 steps, each processing a batch of examples.

- **batch_size=7**  
  Each step processes a batch of 7 examples, which, with 8 GPUs, puts one example on each worker GPU. One GPU (0) is used as the master by `DataParallel` for aggregating gradients and gathering outputs.

- **num_generations=14**  
  For every prompt in the training data, the trainer will generate 14 different completions. These multiple generations are used to compute a relative advantage (or reward signal) that guides the RL update. Reduce this number if you have GPUs with less VRAM.

- **max_completion_length=400**  
  When generating completions (the "response" portion of the sequence), the generation is capped at 400 tokens. This limits the length of the outputs produced by the model during the RL phase. Reduce this number if you have GPUs with less VRAM.

- **beta=0.04**  
  The coefficient for the KL divergence penalty in the GRPO loss function. This controls how much the model is allowed to diverge from the reference model.

- **learning_rate=5e-6**  
  The learning rate for RL finetuning. A relatively low learning rate is used for stable policy updates.

- **mu=1**  
  The number of policy updates performed for each batch of rollout data. In our case, we perform just one update per batch.

- **epsilon=0.1**  
  The clipping parameter for the PPO component of GRPO. This prevents the policy from changing too drastically in a single update.
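
To make the roles of `epsilon` and `beta` concrete, here is a minimal numeric sketch of the clipped surrogate and the KL term, using the same formulas as `grpo_loss` above. The probabilities and the advantage are made-up values for a single three-token completion, and the reference policy is assumed to equal the old policy.

```python
import torch

epsilon, beta = 0.1, 0.04  # values from the configuration above

# Hypothetical per-token probabilities for one completion of 3 tokens
old_log_probs = torch.log(torch.tensor([[0.50, 0.50, 0.50]]))
new_log_probs = torch.log(torch.tensor([[0.70, 0.50, 0.30]]))
ref_log_probs = old_log_probs.clone()   # assumption: reference equals old policy
advantage = torch.tensor([[1.0]])       # positive: better than the group average

ratio = torch.exp(new_log_probs - old_log_probs)        # approx. [1.4, 1.0, 0.6]
clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)  # approx. [1.1, 1.0, 0.9]
surrogate = torch.min(ratio * advantage, clipped * advantage)  # approx. [1.1, 1.0, 0.6]

# Same KL estimator as in grpo_loss: exp(d) - d - 1 with d = ref - current
diff = ref_log_probs - new_log_probs
kl = torch.exp(diff) - diff - 1
per_token_objective = surrogate - beta * kl

print("surrogate:", surrogate)
print("kl:", kl)
print("objective:", per_token_objective)
```

With a positive advantage, the first token's ratio of 1.4 is capped at 1.1, so pushing its probability further earns no extra credit, while the third token's ratio of 0.6 is left unclipped. The KL term is zero where the current and reference policies agree and positive elsewhere, and `beta` scales how strongly that drift is penalized.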

The model is evaluated both before and after fine-tuning to measure the improvement in accuracy. Finally, the fine-tuned model is saved to the "grpo_finetuned_model_m1" directory.

In [None]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")


# 1. Load your model and tokenizer
model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # Choose a smaller model if needed
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # Use float16 instead of bfloat16 for M1
).to(device)  # Moved explicitly to the selected device, so device_map="auto" is not needed

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
model.config.eos_token_id = tokenizer.eos_token_id

# 2. Optimize the model for memory efficiency
model = optimize_model_memory(model)

# 3. Load your dataset, hold out an evaluation subset, and reduce the training set
all_data = prepare_dataset("train")
eval_data = get_reduced_dataset(all_data[:30], size=10)  # Sample the eval set from the first 30 examples
train_data = get_reduced_dataset(all_data[30:], size=50)  # Train on the remaining data so eval examples are never seen

print(f"Using device: {device}")

print("\nInitial model evaluation before finetuning:")
pre_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device)
print(f"Pre-GRPO Accuracy: {pre_grpo_accuracy:.2f}%")

print("\nStarting RL fine-tuning using GRPO...")
# 4. Configure training with smaller batches and shorter sequences
training_config = {
    'num_iterations': 1,
    'num_steps': 50,  # Reduced from 500
    'batch_size': 2,  # Reduced from 7
    'num_generations': 2,  # Reduced from 12
    'max_completion_length': 400,  # Kept at 400; reduce if memory is tight
    'beta': 0.04,
    'learning_rate': 5e-6,
    'mu': 1,
    'epsilon': 0.1
}

# 5. Train the model
model = train_with_grpo_m1(
    model=model,
    tokenizer=tokenizer,
    train_data=train_data,
    reward_function=combined_reward,
    **training_config
)

# 6. Save the model
model.save_pretrained("grpo_finetuned_model_m1")
tokenizer.save_pretrained("grpo_finetuned_model_m1")

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Using device: mps

Initial model evaluation before finetuning:

Starting RL fine-tuning using GRPO...
Training on device: mps

Iteration 1/1
Reference model created.
Input batch size: 2, Device: mps:0




Output batch size: 4, Device: mps:0
Average Reward: 0.17499999701976776
Iteration 1/1, Step 1/25, GRPO iter 1/1, loss: nan
Input batch size: 2, Device: mps:0


As you can see, the model learned to generate the correct solution for 90% of problems.

In [None]:
# Evaluate the finetuned model
post_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device)
print(f"Post-GRPO Accuracy: {post_grpo_accuracy:.2f}%")