# Tutorial: Fine-Tuning Gemma with GRPO for Math Reasoning


## Introduction

This script fine-tunes Google's Gemma 3 1B-IT language model to improve its mathematical reasoning. It uses Group Relative Policy Optimization (GRPO) together with Low-Rank Adaptation (LoRA) to train the model efficiently on GSM8K, a dataset of grade-school math word problems.

The core idea is to teach the model not just to give an answer, but also to provide a step-by-step reasoning process, much as a student would show their work. To make the model follow a specific response structure, we guide it to output its reasoning and answer within XML-like tags: `<reasoning>...</reasoning>` and `<answer>...</answer>`.

**Underlying Theory and Principles:**

*   **Large Language Models (LLMs):** These are AI models (like Gemma) trained on vast amounts of text data, enabling them to understand and generate human-like text.
*   **Fine-Tuning:** While pre-trained LLMs are knowledgeable, fine-tuning adapts them to specific tasks or datasets. Here, we're fine-tuning for math problem-solving.
*   **LoRA (Low-Rank Adaptation):** Training entire LLMs is computationally expensive. LoRA is a parameter-efficient fine-tuning (PEFT) technique that adds small, trainable "adapter" layers to the model. This drastically reduces the number of parameters that need to be updated, making fine-tuning faster and requiring less memory, without sacrificing much performance.
*   **GRPO (Group Relative Policy Optimization):** This is a reinforcement-learning fine-tuning method that goes beyond simple supervised learning. Instead of just learning to predict the next word based on a "correct" example, GRPO uses reward functions. The model generates a group of possible responses for each prompt, custom reward functions score each response, and each score is compared with the others in the same group. Responses that are more "correct" or follow the desired format better than the rest of their group are reinforced, guiding the model toward preferred behaviors.
*   **Chain-of-Thought (CoT) Prompting:** The script encourages the model to produce a reasoning process before the final answer. This is a form of CoT, where eliciting a step-by-step thinking process often leads to more accurate results from LLMs, especially for reasoning tasks. The XML tags help structure this CoT.
*   **Unsloth:** This library is used to speed up the fine-tuning process and reduce memory usage, making it feasible to train powerful models on consumer-grade hardware. It optimizes parts of the model loading and training pipeline.
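
The group-relative scoring described in the GRPO bullet can be illustrated with a few lines of plain Python. This is a minimal sketch, assuming the common formulation in which each reward is normalized by the mean and standard deviation of its group; the function name `group_relative_advantages`, the `eps` value, and the sample rewards are hypothetical, and the exact normalization inside a given training library may differ.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-4):
    """Normalize rewards within one group of completions for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    # eps keeps the division defined when every completion gets the same reward
    return [(r - mu) / (sigma + eps) for r in rewards]


# Hypothetical rewards for four completions sampled from a single prompt.
rewards = [2.0, 1.0, 0.5, 0.5]
advantages = group_relative_advantages(rewards)
for r, a in zip(rewards, advantages):
    print(f"reward={r:.2f} -> advantage={a:+.3f}")
```

Completions that score above their group's mean get a positive advantage and are made more likely; those below the mean get a negative one. When all completions in a group receive the same reward, every advantage is zero and that group contributes no learning signal.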

This script operates in the context of improving LLM reliability and interpretability. By forcing the model to show its reasoning and adhere to a specific format, we can better understand its "thought process" and trust its answers more.

Let's dive into the code!

## 📺 Watch the Tutorial

Prefer a video walkthrough? Check out the accompanying tutorial on YouTube:

[Fine-Tuning LLM for Math Reasoning (GRPO)](https://youtu.be/2CI0yQHyoxU)

## Getting Started

### Hardware Requirements

This tutorial requires a machine with NVIDIA GPUs. The fine-tuning process is computationally intensive, and while Unsloth and LoRA make it more efficient, you'll still need GPU acceleration for reasonable training times.

### Environment Setup

First, ensure you have the proper NVIDIA drivers installed. For Ubuntu systems:

```bash
sudo apt install nvidia-utils-535 nvidia-driver-535 -y
sudo apt install python3.12-dev -y
# Reboot your system after installing drivers
```

Next, set up a Python virtual environment and install the required packages:

```bash
# Create a virtual environment
python -m venv venv

# Activate the virtual environment
# On Linux/Mac:
source venv/bin/activate
# On Windows:
# venv\Scripts\activate

# Install requirements
pip install -r requirements.txt
```

### Running the Application

Once your environment is set up, you can run the main script:

```bash
python main.py
```

This will start the fine-tuning process. Depending on your hardware and the configuration settings, this may take several hours to complete.

## Structured Code Walk-through

The `main.py` script can be broken down into several logical sections. We'll go through them one by one.

### 1. Imports and Initial Setup

```python
#!/usr/bin/env python

import re
import torch
from unsloth import FastLanguageModel
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer
from difflib import SequenceMatcher
```

**Explanation:**

*   `#!/usr/bin/env python`: This is a shebang line, typically used in Unix-like systems to specify that the script should be executed with Python.
*   `import re`: Imports the regular expression module, which is essential for pattern matching in strings. This will be used to extract information from text, like the content within our XML tags or numerical answers.
*   `import torch`: Imports PyTorch, a popular deep learning framework. LLMs like Gemma are built and trained using frameworks like PyTorch.
*   `from unsloth import FastLanguageModel`: Imports `FastLanguageModel` from the Unsloth library. Unsloth provides optimizations for faster training and reduced memory usage of LLMs. `FastLanguageModel` is a wrapper around Hugging Face models that incorporates these optimizations.
*   `from datasets import load_dataset, Dataset`: Imports functions from the `datasets` library by Hugging Face. `load_dataset` is used to easily download and prepare standard datasets like GSM8K. `Dataset` is a class for representing datasets.
*   `from trl import GRPOConfig, GRPOTrainer`: Imports `GRPOConfig` and `GRPOTrainer` from the `trl` (Transformer Reinforcement Learning) library.
    *   `GRPOConfig`: A configuration class to set up parameters for the GRPO training process (e.g., learning rate, batch size, number of generations per prompt).
    *   `GRPOTrainer`: The main class that handles the GRPO training loop, incorporating the model, dataset, tokenizer, and reward functions.
*   `from difflib import SequenceMatcher`: Imports `SequenceMatcher` from the `difflib` module. This is used to compare two sequences (like strings) and calculate their similarity, which can be helpful in one of our reward functions to see how close the model's text answer is to the correct text answer.

**Why it matters:** These imports bring in all the necessary tools and libraries to load the model, prepare data, define the training process, and evaluate the model's outputs.

### 2. Constants and Configuration

```python
# ==========================================
# Constants and Configuration
# ==========================================
max_seq_length = 1024               # Total tokens (prompt + completion)
lora_rank = 32                      # LoRA low-rank adaptation dimension
max_prompt_length = 256             # Max tokens for the prompt in GRPO
SYSTEM_PROMPT = """
Respond **only** in the exact format below, with no extra text, no deviations, and preserving these tags (including newlines):

<reasoning>
.....
</reasoning>
<answer>
.....
</answer>

Note:
- The answer must be a number, and the units must be included if the question asks for them.

"""
```

**Explanation:**

*   `max_seq_length = 1024`: This defines the maximum number of tokens (words or sub-words) that the model can handle for a single input sequence. This includes both the user's prompt (the question) and the model's generated completion (reasoning + answer).
    *   *Analogy:* Think of this as the maximum number of words allowed in a short essay.
*   `lora_rank = 32`: This sets the "rank" for the LoRA adaptation. In LoRA, the original weight matrices of the model are frozen, and two smaller matrices (whose product approximates the change we want to make to the original weights) are trained. The rank (here, 32) determines the size/complexity of these smaller matrices. A smaller rank means fewer trainable parameters, leading to faster training and less memory, but potentially less expressive power for the adaptation. `32` is a common choice.
*   `max_prompt_length = 256`: Specifies the maximum number of tokens allowed for the input prompt when using GRPO. The remaining tokens up to `max_seq_length` (i.e., `1024 - 256 = 768`) are available for the model's generated completion.
*   `SYSTEM_PROMPT`: This is a crucial piece of text. It's a system-level instruction given to the language model to guide its behavior.
    *   It strictly defines the output format: the model **must** use `<reasoning>...</reasoning>` followed by `<answer>...</answer>`.
    *   It emphasizes no extra text or deviations.
    *   It also gives a hint about the answer format: it should be a number, and units should be included if the question requires them.

**Why it matters:** These constants define key parameters for the model and the training process. The `SYSTEM_PROMPT` is particularly important as it's our primary way of instructing the model on how to structure its responses. This explicit instruction is a form of prompt engineering.
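
As a quick check of the token budget and the target format, the sketch below recomputes the completion budget from the two length constants and spells out one response that follows the layout in `SYSTEM_PROMPT`. The names `max_completion_length` and `example_response`, and the sample problem, are illustrative assumptions and not part of the script.

```python
max_seq_length = 1024
max_prompt_length = 256

# Tokens left for the generated reasoning and answer
max_completion_length = max_seq_length - max_prompt_length
print(max_completion_length)  # 768

# A hypothetical response that follows the requested layout, newlines included
example_response = (
    "<reasoning>\n"
    "Each box holds 12 eggs, so 3 boxes hold 3 * 12 = 36 eggs.\n"
    "</reasoning>\n"
    "<answer>\n"
    "36\n"
    "</answer>\n"
)
print(example_response)
```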

### 3. Load and Prepare Model for LoRA Fine-Tuning

```python
# ==========================================
# 1. Load and Prepare Model for LoRA Fine-Tuning
# ==========================================
# Load the pre-trained Gemma 3 1B instruct model optimized by Unsloth
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="google/gemma-3-1b-it",  # Choose the base model
    max_seq_length=max_seq_length,       # Support long reasoning traces
    load_in_4bit=True,                   # Quantize model to 4-bit to save GPU memory
    fast_inference=True,                 # Enable vLLM acceleration for generation
    max_lora_rank=lora_rank,             # Set maximum LoRA rank
    gpu_memory_utilization=0.6,          # Cap GPU memory usage to avoid OOM
)

# Apply LoRA (Low-Rank Adaptation) on selected projection layers
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=[                   # Layers to fine-tune via LoRA
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # Save memory for long contexts
    random_state=3407,                      # Ensure reproducibility
)
```

**Explanation:**

*   **Loading the Model and Tokenizer:**
    *   `model, tokenizer = FastLanguageModel.from_pretrained(...)`: This line loads the pre-trained language model and its associated tokenizer.
        *   `model_name="google/gemma-3-1b-it"`: Specifies the base model we're using. "gemma-3-1b-it" refers to Google's Gemma 3 model with 1 billion parameters, instruction-tuned (meaning it's already been trained to follow instructions).
        *   `max_seq_length=max_seq_length`: Passes the previously defined maximum sequence length.
        *   `load_in_4bit=True`: This is a quantization technique. It loads the model's weights using only 4 bits per parameter instead of the usual 16 or 32 bits. This significantly reduces the model's memory footprint (roughly by 4x compared to 16-bit), making it possible to run larger models on GPUs with less VRAM.
        *   `fast_inference=True`: Enables Unsloth's fast generation path, which uses vLLM to sample completions. GRPO generates several completions per prompt at every training step, so generation speed has a large effect on total training time.
        *   `max_lora_rank=lora_rank`: Informs Unsloth about the maximum LoRA rank we intend to use, allowing it to optimize memory allocation.
        *   `gpu_memory_utilization=0.6`: Tells Unsloth to try and limit GPU memory usage to 60% of its capacity. This can help prevent "Out of Memory" (OOM) errors, especially when other processes might also be using the GPU.
    *   The `tokenizer` is responsible for converting text into a sequence of numbers (tokens) that the model can understand, and vice-versa.

*   **Applying LoRA:**
    *   `model = FastLanguageModel.get_peft_model(...)`: This function modifies the loaded model to prepare it for LoRA fine-tuning. "PEFT" stands for Parameter-Efficient Fine-Tuning.
        *   `r=lora_rank`: Sets the rank of the LoRA matrices. This is the same `lora_rank` (32) we defined earlier.
        *   `target_modules=[...]`: This is a crucial parameter. It specifies *which* layers of the original model will have LoRA adapters applied to them. The names like `"q_proj", "k_proj", "v_proj", "o_proj"` refer to projection layers within the Transformer architecture's attention mechanism. `"gate_proj", "up_proj", "down_proj"` refer to layers within the feed-forward network parts of the Transformer. By targeting these specific layers, we aim to adapt the model's core processing capabilities.
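        *   `lora_alpha=lora_rank`: `lora_alpha` is a scaling factor in LoRA: the adapter update is multiplied by `lora_alpha / r`. Setting it equal to `r` is a common practice and gives a factor of 1. A larger ratio strengthens the adapters' contribution, which acts somewhat like a higher learning rate for them.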
        *   `use_gradient_checkpointing="unsloth"`: Gradient checkpointing is a technique to save memory during training by recomputing some values during the backward pass instead of storing them all. Unsloth provides an optimized version. This is especially useful for long sequences or large models.
        *   `random_state=3407`: Sets a seed for random number generation. This helps ensure that if you run the script again with the same settings, you get the same results (e.g., same initial LoRA weights), which is important for reproducibility.

**Why it matters:** This section sets up the core of our fine-tuning process. We load a powerful base model, make it memory-efficient using 4-bit quantization, and then apply LoRA so we can fine-tune it effectively without needing massive computational resources. The choice of `target_modules` directs where the fine-tuning will focus its efforts.
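
To make the savings from LoRA and 4-bit loading concrete, here is a back-of-the-envelope sketch. The layer dimensions are hypothetical and chosen only for illustration (they are not Gemma's actual shapes), the helper names are made up for this example, and the memory estimate counts raw weight storage only, ignoring quantization metadata, activations, and optimizer state.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one (d_out x d_in) weight matrix."""
    # LoRA trains A with shape (r x d_in) and B with shape (d_out x r)
    return r * d_in + d_out * r


def approx_weight_gb(num_params, bits_per_param):
    """Raw weight storage in gigabytes (10**9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Hypothetical square projection layer
d_in, d_out, r, lora_alpha = 1024, 1024, 32, 32

added = lora_param_count(d_in, d_out, r)
full = d_in * d_out
scaling = lora_alpha / r

print(f"full layer: {full} params, LoRA adds: {added} ({added / full:.2%})")
print(f"LoRA scaling factor lora_alpha / r = {scaling}")
print(f"1B params at 16-bit: {approx_weight_gb(1e9, 16)} GB")
print(f"1B params at 4-bit:  {approx_weight_gb(1e9, 4)} GB")
```

For this made-up layer, rank 32 trains about 6% as many parameters as updating the full matrix would, and storing one billion weights at 4 bits takes roughly a quarter of the space needed at 16 bits.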

### 4. Data Preparation: GSM8K with XML Chain-of-Thought

```python
# ==========================================
# 2. Data Preparation: GSM8K with XML Chain-of-Thought
# ==========================================
def extract_xml_answer(text: str) -> str:
    m = re.search(r"<answer>\s*(.*?)\s*</answer>", text, flags=re.DOTALL)
    return m.group(1).strip() if m else ""

def extract_hash_answer(text: str) -> str | None:
    return text.split('####')[1].strip() if '####' in text else None


def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset("openai/gsm8k", "main")[split]
    return data.map(lambda x: {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user",   "content": x["question"]},
        ],
        "answer": extract_hash_answer(x["answer"]),
    })

# Prepare the training dataset
dataset = get_gsm8k_questions()
```

**Explanation:**

*   **`extract_xml_answer(text: str) -> str`:**
    *   This function is designed to extract the content from within `<answer>...</answer>` tags in a given text.
    *   `re.search(r"<answer>\s*(.*?)\s*</answer>", text, flags=re.DOTALL)`: Uses a regular expression to find the pattern.
        *   `<answer>` and `</answer>`: Match the literal tags.
        *   `\s*`: Matches zero or more whitespace characters (spaces, newlines, tabs). This makes the extraction robust to variations in spacing around the answer.
        *   `(.*?)`: This is the capturing group. `.` matches any character (except newline, unless `re.DOTALL` is used), `*?` matches it zero or more times, but non-greedily (meaning it captures the shortest possible string). This is important if there are multiple potential matches or nested structures (though not expected here).
        *   `flags=re.DOTALL`: Makes the `.` in the regex also match newline characters. This is important if the answer inside the tags spans multiple lines.
    *   `return m.group(1).strip() if m else ""`: If a match (`m`) is found, `m.group(1)` returns the content of the first capturing group (the actual answer). `.strip()` removes any leading/trailing whitespace from this extracted answer. If no match is found, it returns an empty string.
    *   *Purpose:* This function will be used by one of the reward functions later to get the model's predicted answer from its generated XML-formatted output.

*   **`extract_hash_answer(text: str) -> str | None`:**
    *   The GSM8K dataset stores each solution as worked steps followed by a final line of the form `#### <number>`. This function extracts the `<number>` part.
    *   `text.split('####')`: Splits the string at "####". If "####" is present, this returns a list of two strings.
    *   `[1].strip()`: Takes the second part (the answer after "####") and removes whitespace.
    *   `if '####' in text else None`: Only performs the split and extraction if "####" is actually in the text; otherwise, it returns `None`.
    *   *Purpose:* To get the ground truth numerical answer from the GSM8K dataset's original format.

*   **`get_gsm8k_questions(split="train") -> Dataset`:**
    *   This function loads and preprocesses the GSM8K dataset.
    *   `data = load_dataset("openai/gsm8k", "main")[split]`: Loads the specified `split` (e.g., "train" or "test") of the "openai/gsm8k" dataset (main configuration).
    *   `.map(lambda x: { ... })`: Applies a transformation function (the `lambda`) to each example (`x`) in the dataset.
        *   `"prompt": [...]`: Creates the input prompt for the model. This is structured in a chat format, which instruction-tuned models like Gemma expect.
            *   `{"role": "system", "content": SYSTEM_PROMPT}`: The first message is from the "system", providing the overall instructions (our `SYSTEM_PROMPT` defined earlier).
            *   `{"role": "user", "content": x["question"]}`: The second message is from the "user", containing the actual math problem from the dataset (`x["question"]`).
        *   `"answer": extract_hash_answer(x["answer"])`: Extracts the ground truth numerical answer from the dataset's `x["answer"]` field using the helper function.
    *   *Purpose:* To prepare the training data in the format required by the GRPO trainer: a "prompt" (a list of chat messages) and an "answer" (the gold standard answer to the question).

*   **`dataset = get_gsm8k_questions()`:**
    *   Calls the function to get the training split of the GSM8K dataset and stores it in the `dataset` variable.

**Why it matters:** Data is the fuel for machine learning. This section meticulously prepares the GSM8K dataset. It formats the questions into a chat-like prompt that includes our system-level instructions and extracts the final numerical answers to be used as ground truth for the reward functions.
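
The two extraction helpers and the prompt layout can be tried without downloading the dataset. The sketch below repeats the helper definitions so it runs on its own; the `record` dictionary is a hypothetical example written in the GSM8K layout, and the shortened `SYSTEM_PROMPT` is a stand-in for the real one.

```python
import re


def extract_xml_answer(text):
    m = re.search(r"<answer>\s*(.*?)\s*</answer>", text, flags=re.DOTALL)
    return m.group(1).strip() if m else ""


def extract_hash_answer(text):
    return text.split("####")[1].strip() if "####" in text else None


# Stand-in for the full system prompt defined earlier
SYSTEM_PROMPT = "Respond in the <reasoning>/<answer> format."

# Hypothetical record shaped like a GSM8K row
record = {
    "question": "Tom has 3 boxes of 12 eggs. How many eggs does he have?",
    "answer": "Each box has 12 eggs, so 3 * 12 = 36.\n#### 36",
}

# The same mapping that get_gsm8k_questions applies to every row
example = {
    "prompt": [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": record["question"]},
    ],
    "answer": extract_hash_answer(record["answer"]),
}

model_output = "<reasoning>\n3 * 12 = 36\n</reasoning>\n<answer>\n36\n</answer>\n"

print(example["answer"])                       # 36
print(extract_xml_answer(model_output))        # 36
print(extract_hash_answer("no marker here"))   # None
print(repr(extract_xml_answer("no tags")))     # ''
```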

### 5. Reward Functions with Debug Prints

This is a very important section as GRPO relies on reward functions to guide the learning process. The model will generate several responses, and these functions will score how "good" each response is.

```python
# ==========================================
# 3. Reward Functions with Debug Prints
# These guide the GRPO trainer and log detailed per-step info
# ==========================================
def correctness_reward_func(prompts, completions, answer, **kwargs):
    """
    Reward = 2.0 * (0.8 * content_score + tag_bonus), where:
      - content_score ∈ [0,1]: relative-error for numeric or string-similarity for text
      - tag_bonus ∈ {0, 0.2}: +0.2 if exactly one of each <reasoning>…</reasoning> and <answer>…</answer>
    """
    def string_similarity(a: str, b: str) -> float:
        return SequenceMatcher(None, a, b).ratio()

    def extract_between(text: str, start: str, end: str) -> str:
        m = re.search(fr"{re.escape(start)}\s*(.*?)\s*{re.escape(end)}", text, flags=re.DOTALL)
        return m.group(1).strip() if m else ""

    rewards = []
    for gen_list, gold in zip(completions, answer):
        # Each completion is a list holding one assistant message
        full      = gen_list[0]["content"]
        reasoning = extract_between(full, "<reasoning>", "</reasoning>")
        pred      = extract_between(full, "<answer>",    "</answer>")

        try:
            # First integer or decimal number in each string
            num_pattern = r"[-+]?\d*\.?\d+"
            m_pred = re.search(num_pattern, pred)
            m_gold = re.search(num_pattern, gold)

            if m_pred and m_gold:
                p_f = float(m_pred.group())
                g_f = float(m_gold.group())
                rel_err = abs(p_f - g_f) / (abs(g_f) + 1e-8)
                numeric_score = max(0.0, 1 - rel_err)

                # Whatever follows the number is treated as the unit
                unit_pred = pred[m_pred.end():].strip()
                unit_gold = gold[m_gold.end():].strip()

                if unit_pred or unit_gold:
                    unit_score = string_similarity(unit_pred, unit_gold)
                    content_score = 0.5 * numeric_score + 0.5 * unit_score
                else:
                    content_score = numeric_score
            else:
                content_score = string_similarity(pred, gold)
        except Exception:
            content_score = string_similarity(pred, gold)

        # NOTE: the docstring describes a tag_bonus term, but this listing does not
        # implement it; the reward here depends only on content_score.
        reward = 2.0 * content_score

        # Reads the first prompt in the batch; assumes every completion in this
        # call was generated from the same prompt
        question = prompts[0][-1]["content"]
        print(f"=== GRPO Step ===\n"
              f"Question: {question}\n\n"
              f"--- Reasoning ---\n{reasoning}\n\n"
              f"--- Predicted Answer ---\n{pred}\n\n"
              f"--- Gold Answer ---\n{gold}\n\n"
              f"Content_score={content_score:.3f}, "
              f"Reward={reward:.3f}\n")

        rewards.append(reward)
    return rewards
```

**`correctness_reward_func` Explanation:**

*   **Inputs:**
    *   `prompts`: The input prompts given to the model, one per completion, so a prompt is repeated for each of its generations.
    *   `completions`: The model's generated responses for the current batch, one entry per sampled completion (with `num_generations` set to 6, each prompt contributes 6 entries). Because the prompts are in chat format, each entry is itself a list holding a single message dictionary such as `{"role": "assistant", "content": "generated_text"}`.
    *   `answer`: The gold standard (correct) answers, aligned with `completions` entry by entry.
*   **`string_similarity(a: str, b: str) -> float`:** A helper function that calculates the similarity ratio between two strings `a` and `b` using `SequenceMatcher`. The ratio is between 0 (completely different) and 1 (identical).
*   **`extract_between(text: str, start: str, end: str) -> str`:** A generic helper to extract text between `start` and `end` delimiters. It's similar to `extract_xml_answer` but more general.
*   **Main Loop (`for gen_list, gold in zip(completions, answer):`)**:
    *   Iterates through each set of model generations (`gen_list`) and its corresponding gold answer (`gold`).
    *   `full = gen_list[0]["content"]`: Reads the generated text of the current completion. `completions` is a list with one entry per sampled completion, and each entry (`gen_list`) is a list containing a single assistant message, so `gen_list[0]` is that message and `["content"]` is its text. The loop therefore scores every completion in the batch, one per iteration.
    *   `reasoning = extract_between(full, "<reasoning>", "</reasoning>")`: Extracts the model's reasoning.
    *   `pred = extract_between(full, "<answer>", "</answer>")`: Extracts the model's predicted answer.
*   **Calculating `content_score`:**
    *   It first tries to parse numbers from both the predicted answer (`pred`) and the gold answer (`gold`).
        *   `num_pattern = r"[-+]?\d*\.?\d+"`: A regex to find integer or decimal numbers.
        *   If both `pred` and `gold` contain numbers:
            *   It calculates the `numeric_score` based on relative error: `max(0.0, 1 - rel_err)`. A score of 1 means a perfect match, and it decreases as the relative error increases. `1e-8` is added to the denominator to prevent division by zero.
            *   It then tries to extract and compare units (text after the number). `unit_pred = pred[m_pred.end():].strip()`.
            *   If units are present, the `content_score` is a 50/50 blend of `numeric_score` and `unit_score` (string similarity of units).
            *   If no units, `content_score` is just the `numeric_score`.
        *   If numbers cannot be extracted from both, `content_score` defaults to the direct string similarity between `pred` and `gold`.
    *   A `try-except` block handles potential errors during parsing, defaulting to string similarity.
*   **Calculating `reward`:**
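    *   `reward = 2.0 * content_score`. The docstring mentions a `tag_bonus` (`+0.2 if exactly one of each <reasoning>…</reasoning> and <answer>…</answer>`), but this bonus is not implemented in the provided code for this function. The reward is simply the content score scaled by 2.0, so it ranges from 0.0 to 2.0.
*   **Debugging Prints:**
    *   `question = prompts[0][-1]["content"]`: Takes the last message (the user question) of the first prompt in the batch. The same question is printed for every completion in the call, so the log is accurate only when all completions in the batch were generated from one prompt.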
    *   Prints detailed information about the current step: the question, the model's reasoning and predicted answer, the gold answer, the calculated `content_score`, and the final `reward`. This is invaluable for debugging and understanding how the model is learning.
*   `rewards.append(reward)`: Collects the reward for the current generation.
*   `return rewards`: Returns a list of rewards, one for each completion evaluated.

**Why `correctness_reward_func` matters:** This is the primary reward function that tells the model how accurate its mathematical answer is, considering both numerical value and units. The detailed print statements are excellent for monitoring the training process.
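
A few hand-checked cases show how the content score behaves. The sketch below is a standalone restatement of the scoring logic, assuming string inputs and omitting the `try`/`except` wrapper and the debug prints; the function name `content_score` and the sample answers are chosen for illustration.

```python
import re
from difflib import SequenceMatcher

NUM_PATTERN = r"[-+]?\d*\.?\d+"


def content_score(pred, gold):
    """Score a predicted answer string against the gold answer string."""
    m_pred = re.search(NUM_PATTERN, pred)
    m_gold = re.search(NUM_PATTERN, gold)
    if not (m_pred and m_gold):
        # No number on one side: fall back to plain string similarity
        return SequenceMatcher(None, pred, gold).ratio()
    p, g = float(m_pred.group()), float(m_gold.group())
    numeric = max(0.0, 1 - abs(p - g) / (abs(g) + 1e-8))
    unit_pred = pred[m_pred.end():].strip()
    unit_gold = gold[m_gold.end():].strip()
    if unit_pred or unit_gold:
        unit = SequenceMatcher(None, unit_pred, unit_gold).ratio()
        return 0.5 * numeric + 0.5 * unit
    return numeric


for pred in ["72", "70", "72 dollars", "144", "no idea"]:
    score = content_score(pred, "72")
    print(f"pred={pred!r:14} content_score={score:.3f} reward={2.0 * score:.3f}")
```

Two things stand out. A near miss such as `70` against `72` still earns most of the reward, which gives the model a gradient toward the right value. And because the gold answers extracted from GSM8K are bare numbers, a prediction like `72 dollars` has its score halved: the unit comparison is between `"dollars"` and an empty string, which has zero similarity.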

---

```python
def int_reward_func(completions, **kwargs):
    """
    Reward function that checks if the answer is a pure integer.
    """
    # Reward digit-only answers with 0.5
    responses = [c[0]["content"] for c in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    rewards = [0.5 if r.isdigit() else 0.0 for r in extracted]
    print(f"[int_reward] Ans: {extracted} | R: {rewards}\n")
    return rewards
```

**`int_reward_func` Explanation:**

*   `responses = [c[0]["content"] for c in completions]`: Extracts the generated text of each completion. Each `c` is a list holding one assistant message, just like `gen_list` in `correctness_reward_func`.
*   `extracted = [extract_xml_answer(r) for r in responses]`: Extracts the content from the `<answer>` tags for each response.
*   `rewards = [0.5 if r.isdigit() else 0.0 for r in extracted]`: For each extracted answer, it checks if the string `r` consists *only* of digits using `r.isdigit()`. If it is purely a number (e.g., "123", not "12.3" or "123 apples"), it gets a reward of `0.5`, otherwise `0.0`.
*   `print(...)`: Logs the extracted answers and their corresponding rewards.

**Why `int_reward_func` matters:** This function provides a small incentive for the model to produce answers that are purely numerical integers. This might be useful for certain types of problems or as a shaping reward. However, it might conflict with the `correctness_reward_func` if units are expected or if answers can be decimals. The GSM8K dataset often has integer answers, so this might be a gentle nudge.
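
The behavior of `str.isdigit()` on a few candidate answers shows exactly which outputs earn this reward. The sample strings below are made up for illustration.

```python
# Hypothetical extracted answers
candidates = ["123", "12.3", "-5", "123 apples", ""]

# Same rule as int_reward_func: 0.5 for digit-only strings, 0.0 otherwise
int_rewards = {c: (0.5 if c.isdigit() else 0.0) for c in candidates}

for answer, reward in int_rewards.items():
    print(f"{answer!r:14} -> {reward}")
```

Only `"123"` is rewarded. Decimals, answers with units, empty answers, and negative integers such as `"-5"` all score 0.0, because the sign and the decimal point are not digits.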

---

```python
def strict_format_reward_func(completions, **kwargs):
    # Reward strict XML formatting
    # Note: without re.DOTALL, "." does not match newlines, so the reasoning
    # and the answer must each fit on a single line to match
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [c[0]["content"] for c in completions]
    matches = [bool(re.match(pattern, r)) for r in responses]
    rewards = [0.5 if m else 0.0 for m in matches]
    print(f"[strict_format] matches={matches} | r={rewards}\n")
    return rewards
```

**`strict_format_reward_func` Explanation:**

*   `pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"`: This is a regular expression defining a very strict format:
    *   `^`: Matches the beginning of the string.
    *   `<reasoning>\n`: Literal tag followed by a newline.
    *   `.*?\n`: Any characters except newlines (non-greedy) followed by a newline (for the reasoning content).
    *   `</reasoning>\n`: Closing tag followed by a newline.
    *   `<answer>\n`: Literal tag followed by a newline.
    *   `.*?\n`: Any characters except newlines (non-greedy) followed by a newline (for the answer content).
    *   `</answer>\n`: Closing tag followed by a newline.
    *   `$`: Matches the end of the string.
    *   This means the entire output must *exactly* match this structure, including the newlines in specific places, and nothing before or after. Because the pattern is applied without `re.DOTALL`, `.` does not match newline characters, so the reasoning and the answer must each fit on a single line; a response with multi-line reasoning fails this check.
*   `matches = [bool(re.match(pattern, r)) for r in responses]`: For each response, checks if it fully matches the `pattern` from the beginning. `re.match` only matches at the beginning of the string.
*   `rewards = [0.5 if m else 0.0 for m in matches]`: Gives a reward of `0.5` if the format is strictly matched, `0.0` otherwise.
*   `print(...)`: Logs matches and rewards.

**Why `strict_format_reward_func` matters:** This function strongly encourages the model to adhere to the precise XML and newline structure specified in the `SYSTEM_PROMPT`. This is key for reliable parsing of the output.
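
The effect of the strict pattern is easiest to see on concrete strings. The sketch below applies the same regex to three made-up responses, and also shows, purely for comparison, what would change if `re.DOTALL` were passed (the script itself does not pass it).

```python
import re

STRICT = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"

# Hypothetical model outputs
single_line = "<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>\n"
multi_line = (
    "<reasoning>\nStep 1: add 2 and 2\nStep 2: the sum is 4\n</reasoning>\n"
    "<answer>\n4\n</answer>\n"
)
no_final_newline = single_line.rstrip("\n")

strict_results = {
    "single_line": bool(re.match(STRICT, single_line)),
    "multi_line": bool(re.match(STRICT, multi_line)),
    "no_final_newline": bool(re.match(STRICT, no_final_newline)),
    # For comparison only: the script does not use re.DOTALL here
    "multi_line_with_dotall": bool(re.match(STRICT, multi_line, flags=re.DOTALL)),
}

for name, matched in strict_results.items():
    print(f"{name:24} -> {matched}")
```

Only the single-line response matches as written. Multi-line reasoning is rejected because `.` stops at newlines, and a response without the final newline after `</answer>` is rejected because the pattern requires it.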

---

```python
def soft_format_reward_func(completions, **kwargs):
    # Reward looser XML formatting
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [c[0]["content"] for c in completions]
    # re.match anchors at the start of the response; re.search would accept
    # the pattern anywhere in the text
    matches = [bool(re.match(pattern, r)) for r in responses]
    rewards = [0.5 if m else 0.0 for m in matches]
    print(f"[soft_format] matches={matches} | r={rewards}\n")
    return rewards
```

**`soft_format_reward_func` Explanation:**

*   `pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"`: This regex defines a looser format:
    *   It looks for `<reasoning>...</reasoning>` followed by zero or more whitespace characters (`\s*`) and then `<answer>...</answer>`.
    *   It does *not* anchor to the end (`$`) of the string, nor does it enforce specific newlines around the tags like the strict version.
    *   As in the strict version, `.` does not match newlines because `re.DOTALL` is not set. The content between each pair of tags must therefore be on the same line as the tags, so a response that puts newlines inside the tags, as `SYSTEM_PROMPT` requests, does not match this pattern.
*   `matches = [bool(re.match(pattern, r)) for r in responses]`: Like the strict function, it uses `re.match`, so the pattern must appear at the *beginning* of the response. To find the pattern anywhere in the text, `re.search` would be the appropriate call. With `re.match`, outputs that *start* with this structure are rewarded even if trailing content follows.
*   `rewards = [0.5 if m else 0.0 for m in matches]`: Reward of `0.5` for a match, `0.0` otherwise.
*   `print(...)`: Logs matches and rewards.

**Why `soft_format_reward_func` matters:** This provides a gentler reward for getting the basic tag order right, even if the newlines or surrounding text aren't perfect. It can act as an intermediate step towards the stricter format.
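
The following sketch runs the soft pattern against three made-up responses to show which ones it accepts, and contrasts `re.match` with `re.search` and with `re.DOTALL`. The alternative calls are shown only for comparison; the script uses plain `re.match`.

```python
import re

SOFT = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"

# Hypothetical model outputs
compact = "<reasoning>2 + 2 = 4</reasoning>\n<answer>4</answer>"
prompted = "<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>\n"
with_preamble = "Sure! " + compact

soft_results = {
    "compact": bool(re.match(SOFT, compact)),
    "prompted_layout": bool(re.match(SOFT, prompted)),
    "with_preamble": bool(re.match(SOFT, with_preamble)),
    # For comparison only: alternatives the script does not use
    "prompted_layout_dotall": bool(re.match(SOFT, prompted, flags=re.DOTALL)),
    "with_preamble_search": bool(re.search(SOFT, with_preamble)),
}

for name, matched in soft_results.items():
    print(f"{name:24} -> {matched}")
```

As written, the soft check accepts only the compact layout. The layout requested by `SYSTEM_PROMPT`, with newlines inside the tags, is rejected unless `re.DOTALL` is set, and any text before `<reasoning>` is rejected unless `re.search` is used.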

---

```python
def count_xml(text: str) -> float:
    # Reward each correctly placed tag; penalize text after the closing </answer> tag
    score = 0.0
    if text.count("<reasoning>\n") == 1:
        score += 0.125
    if text.count("\n</reasoning>\n") == 1:
        score += 0.125
    if text.count("\n<answer>\n") == 1:
        score += 0.125
        # Subtract 0.001 per character found after "\n</answer>\n"
        score -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        score += 0.125
        # Subtract 0.001 per character after "\n</answer>", allowing one (e.g. a newline)
        score -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return score


def xmlcount_reward_func(completions, **kwargs):
    # Reward based on XML tag usage
    responses = [c[0]["content"] for c in completions]
    rewards = [count_xml(r) for r in responses]
    print(f"[xmlcount] R: {rewards}\n")
    return rewards
```

**`count_xml` and `xmlcount_reward_func` Explanation:**

*   **`count_xml(text: str) -> float`:**
    *   This function calculates a score based on the count of specific tag occurrences and penalizes extra content, especially after the `</answer>` tag.
    *   It checks for very specific newline arrangements around the tags:
        *   `text.count("<reasoning>\n") == 1`: Adds `0.125` if `<reasoning>` followed by a newline appears exactly once.
        *   `text.count("\n</reasoning>\n") == 1`: Adds `0.125` if `</reasoning>` surrounded by newlines appears exactly once.
        *   `text.count("\n<answer>\n") == 1`: Adds `0.125` if `<answer>` surrounded by newlines appears exactly once.
            *   `score -= len(text.split("\n</answer>\n")[-1]) * 0.001`: This penalty is applied whenever the `\n<answer>\n` check passes. The text is split on `\n</answer>\n`, and the last piece (`[-1]`) is whatever comes *after* the final occurrence of that sequence. Its length is penalized at `0.001` per character, which discourages text after the answer block. If `\n</answer>\n` never appears, `split` returns the whole text as a single piece, so the entire response length is penalized.
        *   `text.count("\n</answer>") == 1`: Adds `0.125` if `\n</answer>` appears exactly once. Because `\n</answer>` is a prefix of `\n</answer>\n`, a response that ends with `\n</answer>\n` satisfies this check as well.
            *   `score -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001`: This penalizes characters after `\n</answer>`. The `- 1` exempts one trailing character, so a response ending in `\n</answer>\n` incurs no penalty here. If nothing at all follows `\n</answer>`, the term evaluates to `-0.001`, and subtracting it adds a `0.001` bonus.
    *   The maximum score from the four tag checks is `0.125 * 4 = 0.5`.
*   **`xmlcount_reward_func(completions, **kwargs)`:**
    *   Applies the `count_xml` function to each generated response.
    *   Prints the rewards.

**Why `xmlcount_reward_func` matters:** This function provides a granular reward/penalty based on the precise counts and placement of XML tags and associated newlines. It also actively penalizes any superfluous text after the final `</answer>` tag, which is critical for ensuring the output matches the `SYSTEM_PROMPT`'s requirement of "no extra text."
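
To make the scoring concrete, the sketch below repeats the logic of `count_xml` from the cell above and applies it to three made-up responses: one that ends with `</answer>` plus a newline, one with extra text after the answer block, and one that stops right after `</answer>` with no newline. The sample strings are illustrative assumptions, not model outputs.

```python
def count_xml(text: str) -> float:
    # Same logic as the count_xml function in the cell above
    score = 0.0
    if text.count("<reasoning>\n") == 1:
        score += 0.125
    if text.count("\n</reasoning>\n") == 1:
        score += 0.125
    if text.count("\n<answer>\n") == 1:
        score += 0.125
        score -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        score += 0.125
        score -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return score


# Hypothetical responses used only for illustration
clean = "<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>\n"
chatty = clean + "Hope this helps!"   # 16 extra characters after the answer block
no_newline = clean.rstrip("\n")       # ends directly after </answer>

for name, text in [("clean", clean), ("chatty", chatty), ("no_newline", no_newline)]:
    print(f"{name:>10}: {count_xml(text):.3f}")
```

The clean response scores `0.500`. The chatty one scores `0.468`, because its 16 trailing characters are penalized by both checks. The response without a final newline scores `0.446`: since `\n</answer>\n` is missing, the first penalty covers all 55 characters of the text, offset only by the `0.001` bonus from the second check.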

**Overall on Reward Functions:** The script uses a combination of reward functions, which is a powerful aspect of GRPO. Some rewards focus on the correctness of the content (`correctness_reward_func`), others on structural integrity and formatting (`strict_format_reward_func`, `soft_format_reward_func`, `xmlcount_reward_func`), and some on specific content properties (`int_reward_func`). The trainer combines these rewards into a final score for each generated response. In current TRL releases, `GRPOTrainer` adds the per-function rewards together, optionally scaled by the `reward_weights` setting in `GRPOConfig` (all weights default to `1.0`).
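
The aggregation described above can be sketched in plain Python as a weighted sum. This is an illustration only, not `GRPOTrainer`'s actual code: `total_rewards`, `has_answer_tag`, and `is_short` are hypothetical names, and the two toy reward functions stand in for the script's real ones.

```python
def total_rewards(reward_funcs, completions, weights=None):
    """Combine several reward functions into one score per completion."""
    # Assumption: equal weights unless the caller supplies others
    weights = weights or [1.0] * len(reward_funcs)
    # Each reward function returns one score per completion
    per_func = [func(completions) for func in reward_funcs]
    return [
        sum(w * scores[i] for w, scores in zip(weights, per_func))
        for i in range(len(completions))
    ]


# Two toy reward functions with the same call shape as the script's functions
def has_answer_tag(completions, **kwargs):
    return [0.5 if "<answer>" in c[0]["content"] else 0.0 for c in completions]


def is_short(completions, **kwargs):
    return [0.25 if len(c[0]["content"]) < 40 else 0.0 for c in completions]


completions = [
    [{"role": "assistant", "content": "<answer>4</answer>"}],
    [{"role": "assistant", "content": "The result is four, I think, but let me explain at length."}],
]

totals = total_rewards([has_answer_tag, is_short], completions)
weighted = total_rewards([has_answer_tag, is_short], completions, weights=[2.0, 1.0])
print(totals, weighted)
```

With equal weights, the first completion earns `0.5 + 0.25 = 0.75` and the second earns `0.0`. Doubling the weight of the first function raises the first total to `1.25`, which shows how weights shift the balance between formatting and other criteria.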

### 6. Configure and Initialize GRPO Trainer

In [None]:
# ==========================================
# 4. Configure and Initialize GRPO Trainer
# ==========================================
training_args = GRPOConfig(
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    logging_steps=1,                  # Log every step
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    num_generations=6,                # Generations per prompt
    max_prompt_length=max_prompt_length,
    max_completion_length=max_seq_length - max_prompt_length,
    max_steps=250,                    # Total training steps
    save_steps=250,                   # Save checkpoint at end
    max_grad_norm=0.1,
    report_to="none",
    output_dir="outputs",            # Directory for outputs
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,       # GRPOTrainer receives the tokenizer through processing_class
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
)

**Explanation:**

*   **`training_args = GRPOConfig(...)`:**
    *   This creates an instance of `GRPOConfig` to hold all the hyperparameters and settings for the training process.
    *   `learning_rate=5e-6`: The learning rate for the optimizer. A small value like `5e-6` (0.000005) is common for fine-tuning LLMs.
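    *   `adam_beta1=0.9, adam_beta2=0.99`: Parameters for the AdamW optimizer. `0.9` is the usual default for beta1. `0.99` is lower than the common default of `0.999` for beta2, so the running estimate of squared gradients reacts faster to recent steps.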
    *   `weight_decay=0.1`: A regularization technique to prevent overfitting by penalizing large weights.
    *   `warmup_ratio=0.1`: For the first 10% of training steps, the learning rate will gradually increase from 0 to `learning_rate`. This helps stabilize training at the beginning.
    *   `lr_scheduler_type="cosine"`: Specifies how the learning rate should change after the warmup phase. A cosine scheduler gradually decreases the learning rate following a cosine curve.
    *   `optim="paged_adamw_8bit"`: Uses an 8-bit paged AdamW optimizer, which is memory-efficient, often provided by libraries like `bitsandbytes` and integrated with Unsloth.
    *   `logging_steps=1`: How often to log training information (like loss, rewards). Here, it's every step.
    *   `per_device_train_batch_size=1`: The number of training examples to process on each GPU per step. A batch size of 1 is common when VRAM is limited or sequences are long.
    *   `gradient_accumulation_steps=1`: If this were > 1, gradients would be accumulated over multiple small batches before an optimizer step, effectively increasing the batch size without increasing memory. Here, it's 1, so the optimizer takes a step after every batch.
    *   `num_generations=6`: For each prompt in a batch, the model will generate 6 different completions. These completions will then be evaluated by the reward functions. This is a key parameter for GRPO, as it needs multiple diverse generations to learn from.
    *   `max_prompt_length=max_prompt_length`: Maximum length of the input prompt (set to 256 earlier).
    *   `max_completion_length=max_seq_length - max_prompt_length`: Maximum length for the model's generated text (1024 - 256 = 768 tokens).
    *   `max_steps=250`: The total number of training steps to perform.
    *   `save_steps=250`: How often to save a model checkpoint. Here, it saves at the very end of training (since `max_steps` is also 250).
    *   `max_grad_norm=0.1`: Gradient clipping. If the norm (magnitude) of the gradients exceeds this value, they will be scaled down. This helps prevent exploding gradients and stabilizes training.
    *   `report_to="none"`: Disables reporting to external services like Weights & Biases or TensorBoard. Set to `"wandb"` or `"tensorboard"` to enable them.
    *   `output_dir="outputs"`: The directory where training outputs (like model checkpoints) will be saved.

*   **`trainer = GRPOTrainer(...)`:**
    *   This initializes the `GRPOTrainer`, which orchestrates the GRPO fine-tuning process.
    *   `model=model`: The LoRA-adapted model we prepared earlier.
    *   `processing_class=tokenizer`: The tokenizer associated with the model. In TRL's `GRPOTrainer`, `processing_class` is the name of the argument that receives the tokenizer instance, so the call is correct as written.
    *   `reward_funcs=[...]`: A list of all the reward functions we defined earlier. The trainer uses each of them to score the `num_generations` completions for every prompt. The per-function rewards are added together, so the order of the list does not change the total reward.
    *   `args=training_args`: The training configuration we just defined.
    *   `train_dataset=dataset`: The prepared GSM8K dataset.

**Why it matters:** This section configures all aspects of the GRPO training loop, from learning rates and optimizers to how many responses the model should generate and how long it should train. It then brings together the model, tokenizer, data, and reward functions into the `GRPOTrainer` object, ready for training.
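
To see what `warmup_ratio=0.1` and `lr_scheduler_type="cosine"` mean for the 250 steps configured above, here is a minimal sketch of a linear-warmup, cosine-decay schedule. It follows the standard formula for this kind of schedule and is an approximation for illustration; `lr_at_step` is a hypothetical helper, not a function from TRL or Transformers.

```python
import math


def lr_at_step(step, base_lr=5e-6, total_steps=250, warmup_ratio=0.1):
    """Learning rate at a given step: linear warmup, then cosine decay to zero."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)  # 25 steps here
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr
        return base_lr * step / warmup_steps
    # Fraction of the post-warmup phase that has elapsed, from 0.0 to 1.0
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    # Half a cosine wave: starts at base_lr and ends at 0
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, 12, 25, 100, 200, 250):
    print(f"step {step:>3}: lr = {lr_at_step(step):.2e}")
```

The rate climbs from zero to the peak of `5e-6` at step 25, then falls smoothly back to zero by step 250. With such a short run, a tenth of all updates happen during warmup.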

### 7. Start Training

In [None]:
# ==========================================
# 5. Start Training
# Detailed prints show each generation, reasoning, and reward
# ==========================================
trainer.train()

**Explanation:**

*   `trainer.train()`: This single line kicks off the entire GRPO training process.
    *   The `GRPOTrainer` will now iterate through the `train_dataset` for `max_steps`.
    *   In each step:
        1.  It takes a batch of prompts.
        2.  For each prompt, it has the model generate `num_generations` (6) different completions.
        3.  Each of these completions is then passed to all the `reward_funcs`.
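        4.  The rewards from these functions are summed to get a final reward for each completion.
        5.  The rewards are compared within the group of completions generated for the same prompt: a completion that scores above its group's average gets a positive advantage, and one that scores below gets a negative advantage. GRPO then updates the model's policy (the trainable LoRA weights) to make high-advantage generations more likely and low-advantage ones less likely.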
        6.  The optimizer (`paged_adamw_8bit`) updates the weights using the calculated gradients.
        7.  Logging (due to `logging_steps=1` and prints in reward functions) will occur, showing progress.

**Why it matters:** This is where the learning happens! The model iteratively improves its ability to generate well-formatted and correct mathematical reasoning and answers based on the feedback from the reward functions. The debug prints inside our reward functions will provide a rich log of this process.
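
The step in which rewards are turned into a learning signal can be sketched as follows. GRPO scores each completion relative to the other completions for the same prompt. The sketch normalizes by the group's mean and standard deviation; the use of the population standard deviation, the `eps` value, and the reward numbers are assumptions for illustration, and `group_advantages` is a hypothetical helper rather than the trainer's code.

```python
from statistics import mean, pstdev


def group_advantages(rewards, eps=1e-4):
    """Normalize rewards within one group of completions for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # assumption: population standard deviation
    # eps avoids division by zero when every completion gets the same reward
    return [(r - mu) / (sigma + eps) for r in rewards]


# Hypothetical total rewards for num_generations=6 completions of one prompt
rewards = [2.5, 0.5, 0.0, 2.5, 1.0, 0.5]
advantages = group_advantages(rewards)
for r, a in zip(rewards, advantages):
    print(f"reward {r:.1f} -> advantage {a:+.2f}")

# If all completions score the same, every advantage is zero
flat = group_advantages([1.0] * 6)
print(flat)
```

Completions above the group average get positive advantages and are reinforced, while those below are discouraged. When every completion in a group receives the same reward, all advantages are zero and that prompt contributes no learning signal, which is one reason GRPO needs several diverse generations per prompt.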

### 8. Save LoRA Weights for Later Use

In [None]:
# ==========================================
# 6. Save LoRA Weights for Later Use
# ==========================================
model.save_lora("grpo_saved_lora")  # Save only the adapted weights

**Explanation:**

*   `model.save_lora("grpo_saved_lora")`: After training is complete, this line saves the trained LoRA adapter weights.
    *   The argument `"grpo_saved_lora"` specifies the directory name where these weights will be saved.
    *   Importantly, this saves *only* the LoRA weights (the small, newly trained matrices), not the entire base model. This makes the saved artifact very small and portable.

**Why it matters:** Saving the LoRA weights allows us to reuse the fine-tuned adaptation later. We can load the original base model and then apply these saved LoRA weights to get back our specialized, fine-tuned model without having to retrain.
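
A rough parameter count shows why the saved adapter is so small. For a single weight matrix of shape `d_out x d_in`, LoRA stores two thin matrices of rank `r` instead of a full update. The dimensions and rank below are hypothetical round numbers chosen for the arithmetic, not the values used by Gemma or by this script.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in one LoRA adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * d_in + d_out * rank


# Hypothetical layer size and rank, for illustration only
d_in, d_out, rank = 1024, 1024, 16

full_params = d_in * d_out                      # a full update of this one matrix
lora_params = lora_param_count(d_in, d_out, rank)
ratio = lora_params / full_params

print(f"full matrix:  {full_params:,} parameters")
print(f"LoRA adapter: {lora_params:,} parameters ({ratio:.2%} of full)")
```

For this example layer, the adapter holds 32,768 parameters against 1,048,576 for the full matrix, about 3%. The same proportion applies to every adapted layer, which is why the LoRA directory is a small fraction of the base model's size.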

### 9. Test the Fine-Tuned Model

In [None]:
# ==========================================
# 7. Test the Fine-Tuned Model
# ==========================================
# SamplingParams is vLLM's sampling configuration object (used when fast_inference=True)
from vllm import SamplingParams

chat_input = tokenizer.apply_chat_template(
    [
        {"role": "system",  "content": SYSTEM_PROMPT},
        {"role": "user",    "content": "Calculate pi."},
    ],
    tokenize=False,
    add_generation_prompt=True,
)
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=1024,
)
output = (
    model.fast_generate(
        chat_input,
        sampling_params=sampling_params,
        lora_request=model.load_lora("grpo_saved_lora"),  # Apply the saved LoRA adapter to this request
    )[0]
    .outputs[0]
    .text
)
print("\n=== Model Output ===\n", output)

**Explanation:**

*   **Import `SamplingParams`:** `from vllm import SamplingParams` brings in vLLM's `SamplingParams` class, which holds the generation settings passed to `model.fast_generate`. `main.py` already imports it at the top of the script; the import is repeated in the code block above so that the testing section can be read on its own.


*   **`chat_input = tokenizer.apply_chat_template(...)`:**
    *   Prepares a test prompt in the same chat format used during training.
    *   `{"role": "system", "content": SYSTEM_PROMPT}`: Includes the system prompt.
    *   `{"role": "user", "content": "Calculate pi."}`: A new question to test the model.
    *   `tokenize=False`: Returns the formatted prompt as a single string.
    *   `add_generation_prompt=True`: Appends the necessary tokens/text that signals to the model it should start generating a response (e.g., for Gemma, this might be `"<start_of_turn>model\n"`).

*   **`sampling_params = SamplingParams(...)`:**
    *   Defines parameters that control how the model generates text (sampling strategy):
        *   `temperature=0.8`: Controls randomness. Higher temperature (e.g., 0.8) makes the output more random and creative. Lower temperature (e.g., 0.2) makes it more deterministic and focused.
        *   `top_p=0.95` (Nucleus Sampling): Considers only the smallest set of tokens whose cumulative probability exceeds `top_p`. This helps generate more coherent text by avoiding very unlikely tokens.
        *   `max_tokens=1024`: Maximum number of tokens to generate for the response. Since `max_seq_length` is also 1024 and has to cover the prompt plus the completion, a value closer to the training-time `max_completion_length` (768) would fit the context budget better.

*   **`output = model.fast_generate(...)`:**
    *   This is where the fine-tuned model generates a response to the `chat_input`.
    *   `model.fast_generate`: Unsloth's optimized generation function.
    *   `sampling_params=sampling_params`: Uses the defined sampling parameters.
    *   `lora_request=model.load_lora("grpo_saved_lora")`: Loads the `grpo_saved_lora` adapter we saved earlier and applies it to this generation request. Because the adapter is attached per request, the base model does not need to be permanently merged with the LoRA weights to use them.
    *   `[0].outputs[0].text`: Extracts the generated text. `[0]` selects the result for the first (and only) prompt, `.outputs[0]` selects its first completion, and `.text` is the decoded string.

*   `print("\n=== Model Output ===\n", output)`: Prints the model's generated response to the console. We would hope to see something like:
    ```xml
    <reasoning>
    Pi is an irrational number representing the ratio of a circle's circumference to its diameter. It cannot be expressed as a simple fraction. Its approximate value is 3.14159.
    </reasoning>
    <answer>
    Approximately 3.14159
    </answer>
    ```

**Why it matters:** This section demonstrates how to use the fine-tuned model for inference. It shows how to format a prompt, apply the saved LoRA weights, and generate a response. The output will give us an idea of how well the fine-tuning worked on a new, unseen question. The dynamic loading of LoRA weights via `lora_request` is a flexible feature.
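
The effect of `temperature` and `top_p` described above can be illustrated on a toy distribution. The sketch below is a simplified, standalone version of temperature scaling and nucleus filtering, not vLLM's implementation; the logits and the helper names `softmax_with_temperature` and `top_p_filter` are hypothetical.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Turn logits into probabilities; lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs, top_p):
    """Keep the smallest set of most likely tokens whose cumulative probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    # Renormalize so the kept probabilities sum to 1
    return {i: probs[i] / mass for i in kept}


logits = [4.0, 3.0, 1.0, -2.0]  # hypothetical scores for four candidate tokens

probs = softmax_with_temperature(logits, temperature=0.8)
sharp = softmax_with_temperature(logits, temperature=0.2)
nucleus = top_p_filter(probs, top_p=0.95)

print("T=0.8:", [round(p, 3) for p in probs])
print("T=0.2:", [round(p, 3) for p in sharp])
print("kept token indices after top_p=0.95:", sorted(nucleus))
```

At `temperature=0.8` the top token holds roughly three quarters of the probability mass, while at `0.2` it takes almost all of it. With `top_p=0.95`, only the two most likely tokens survive and the unlikely tail is never sampled.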

*Note: `SamplingParams` is a `vllm` class. If this cell is run without the `from vllm import SamplingParams` line in scope, the call to `SamplingParams(...)` raises a `NameError`.*

### 10. Save the Merged Model (Optional but often done)

In [None]:
# ==========================================
# 8. Save the model
# ==========================================
model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")

**Explanation:**

*   `model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")`: This line saves the base model with the LoRA weights *merged* into it.
    *   `"model"`: The directory where the merged model will be saved.
    *   `tokenizer`: The tokenizer is also saved alongside the model, which is standard practice.
    *   `save_method="merged_16bit"`: Specifies how to merge and save.
        *   `merged`: The LoRA weights are combined with the base model's weights to create a new set of full weights.
        *   `16bit`: The merged model is saved in 16-bit precision (half-precision floating point, also known as `float16` or `bfloat16`). This offers a good balance between model size, speed, and numerical precision for many LLMs. This is larger than the 4-bit quantized model we used for training but is a common format for deployment.

**Why it matters:**
While LoRA allows for efficient storage and dynamic loading of adapters, sometimes you want a standalone model that has the fine-tuning "baked in." Merging the weights creates such a model. This merged model can then be loaded and used like any standard Hugging Face model, without needing special LoRA-handling code (though you lose the ability to easily swap LoRA adapters). Saving in 16-bit is a common practice for deploying models that were trained with quantization and LoRA.
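
What "merging" means can be shown with tiny matrices. In the usual LoRA formulation, the adapter contributes `(alpha / r) * B @ A` on top of the base weight `W`, so folding that product into `W` once gives the same outputs as applying the adapter at run time. The sketch below uses plain Python lists and made-up numbers; it illustrates the arithmetic only and is not how `save_pretrained_merged` is implemented.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def scale(a, s):
    return [[x * s for x in row] for row in a]


# Hypothetical 2x2 base weight and a rank-1 adapter
W = [[1.0, 0.0], [0.0, 1.0]]
B = [[0.5], [0.0]]   # shape (d_out x r)
A = [[0.0, 2.0]]     # shape (r x d_in)
alpha, r = 2, 1
scaling = alpha / r

x = [[3.0], [4.0]]   # one input column vector

# Unmerged: base output plus the adapter's contribution, computed separately
y_adapter = add(matmul(W, x), scale(matmul(B, matmul(A, x)), scaling))

# Merged: fold the adapter into the weight once, then use a single matmul
W_merged = add(W, scale(matmul(B, A), scaling))
y_merged = matmul(W_merged, x)

print("merged weight:", W_merged)
print("outputs match:", y_adapter == y_merged)
```

Both paths produce the same output, which is why a merged model needs no LoRA-specific code at inference time. The trade-off is that the adapter can no longer be swapped out, since its contribution is now part of `W_merged`.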

## Conclusion

The `main.py` script provides a comprehensive example of fine-tuning a Gemma language model using the advanced GRPO technique with LoRA and Unsloth for optimization. It meticulously prepares data, defines a suite of reward functions to guide the model towards correct and well-formatted mathematical reasoning, and demonstrates how to train, save, and test the resulting specialized model.

Key takeaways:

*   **Structured Output:** The emphasis on XML-formatted output (`<reasoning>`, `<answer>`) is crucial for making the model's responses predictable and parsable.
*   **Reward Engineering:** The success of GRPO heavily relies on well-designed reward functions that accurately capture the desired behaviors (correctness, formatting, etc.).
*   **Efficiency with Unsloth and LoRA:** These tools make it feasible to fine-tune billion-parameter models on relatively modest hardware by significantly reducing memory and computational demands.
*   **Step-by-Step Process:** The script follows a logical flow: model setup, data prep, defining rewards, configuring the trainer, training, and finally saving/testing.

This script serves as an excellent starting point for anyone looking to delve into advanced LLM fine-tuning for tasks requiring structured, reasoned outputs. By understanding and modifying the components, especially the reward functions and system prompt, you can adapt this approach to a wide variety of other tasks.