# GRPO Demo

This tutorial demonstrates training the [Gemma](https://deepmind.google/models/gemma/)
3 1B-IT model on the [GSM8K math reasoning benchmark](https://huggingface.co/datasets/openai/gsm8k)
using [Group Relative Policy Optimization (GRPO)](https://arxiv.org/pdf/2402.03300).
GRPO can enhance your model's problem-solving skills on mathematical word problems,
coding problems, etc.

GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It
is a variant of [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347)
that reduces memory usage by eliminating the need for a separate value function
model. GRPO works by generating multiple responses for a given prompt,
scoring these responses with a reward model or, as in this notebook, with
rule-based reward functions, and then computing each response's advantage
relative to the rest of the group to update the policy.

In this tutorial we use a `v5e-8` TPU for Gemma3-1b-it. Let's get started!

Note that the setup below is for the Gemma3-1B-IT model only. If you want to use
another model (say, Qwen2.5), you may need to change the setup (for example,
tokenizer, chat template, reward function, etc.).

# Understanding This Notebook: A Deep Dive into GRPO

## What This Notebook Does

This notebook teaches a language model (Gemma 3 1B) to solve math word problems using **reinforcement learning (RL) with rule-based rewards**. Specifically, it uses an algorithm called **Group Relative Policy Optimization (GRPO)**. By the end of training, the model should:

1. **Follow a specific output format** (reasoning in `<reasoning>` tags, answer in `<answer>` tags)
2. **Produce correct numerical answers** to grade-school math problems
3. **Show improved reasoning capabilities**

## What is GRPO? (Explained Simply)

### The Core Idea

Imagine you're learning to solve math problems. Your teacher gives you a problem and you write 4 different solutions. Then:
1. The teacher grades each solution
2. The teacher compares your solutions to each other (not to a single "perfect" worked solution)
3. The teacher tells you which of YOUR solutions were better or worse than average
4. You learn to produce more solutions like your best ones

That's GRPO in a nutshell. It's a **self-improvement** method where the model learns by comparing its own outputs, not by imitating a perfect example.

### Why is This Different from Standard Fine-Tuning?

**Supervised Fine-Tuning (SFT)**: "Here's the perfect answer. Copy it."
- Problem: You need expensive human-written perfect examples
- Problem: Model might memorize patterns without understanding

**GRPO (Reinforcement Learning)**: "Generate multiple solutions. Here's what makes some better than others."
- Advantage: Model learns from its own attempts
- Advantage: Learns to optimize for specific goals (correctness, format, reasoning)
- Advantage: Can improve beyond its training data

### The GRPO Algorithm Step-by-Step

1. **Sample a prompt** (a math problem)
2. **Generate G responses** (G = NUM_GENERATIONS, here 4 different solutions)
3. **Score each response** using reward functions (checking format, correctness, etc.)
4. **Calculate relative advantage**: How much better/worse is each response compared to the group average?
5. **Update the model** to make high-reward responses more likely, low-reward responses less likely
6. **Constrain updates** using KL divergence to prevent the model from changing too drastically
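Steps 2 to 5 can be sketched on a toy problem. The snippet below is an illustrative sketch, not Tunix code: the "policy" is just a softmax over 3 candidate answers (a hypothetical stand-in for an LLM), rewards are 1 for the correct answer and 0 otherwise, and the KL penalty and clipping from step 6 are left out for brevity.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def toy_grpo_step(logits, samples, rewards, lr=0.5, eps=1e-6):
    """One GRPO-style policy-gradient update of a toy categorical policy.

    logits:  (K,) parameters of a softmax "policy" over K candidate answers
    samples: (G,) indices of the G answers sampled for one prompt (the group)
    rewards: (G,) reward of each sampled answer
    """
    rewards = np.asarray(rewards, dtype=float)
    # Step 4: group-relative advantage (normalized by the group's mean and std).
    adv = (rewards - rewards.mean()) / (rewards.std() + eps)
    probs = softmax(logits)
    grad = np.zeros_like(logits)
    for a, A in zip(samples, adv):
        # Gradient of A * log pi(a) w.r.t. the logits is A * (onehot(a) - pi).
        grad += A * (np.eye(len(logits))[a] - probs)
    # Step 5: gradient ascent on the mean objective over the group.
    return logits + lr * grad / len(samples)


CORRECT = 2                                   # index of the right answer
logits = np.zeros(3)                          # start from a uniform policy
samples = [0, 2, 1, 2]                        # step 2: G = 4 sampled answers
rewards = [1.0 if s == CORRECT else 0.0 for s in samples]  # step 3: score them
new_logits = toy_grpo_step(logits, samples, rewards)
print(softmax(logits)[CORRECT], softmax(new_logits)[CORRECT])  # probability of the right answer rises
```

Because advantages are normalized within the group, a group in which every answer gets the same reward (all right or all wrong) produces zero advantages and no update; only groups with mixed outcomes carry a learning signal.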

### Key GRPO Concepts

**Policy Model (π_θ)**: The model we're training. It learns to generate better responses.

**Reference Model (π_ref)**: A frozen copy of the original model. We compare the updated model against it to make sure it doesn't drift too far.

**Reward Functions**: Functions that assign each response a numerical score. Here we use 4:
- Format matching (exact)
- Format matching (approximate)
- Answer correctness
- Number extraction and matching

**KL Divergence Penalty**: A mathematical measure of how different two probability distributions are. We penalize the model if its outputs become too different from the reference model's. This helps guard against:
- Mode collapse (model producing only one type of response)
- Catastrophic forgetting (forgetting useful behaviors)
- Reward hacking (gaming the reward functions in unintended ways)

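**Advantage Function**: In actor-critic RL methods such as PPO, you need a "value function" (a separate model) to estimate how good a state is. GRPO drops this model and instead compares responses within a group: the advantage of a response is its reward normalized by the group's statistics, `(reward - mean(group rewards)) / std(group rewards)`. If your response scored 5 while the group averaged 3, its advantage is positive and the update makes it more likely; responses below the group average get a negative advantage and become less likely.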

**Clipping (ε)**: Borrowed from PPO (Proximal Policy Optimization). Limits how far the probability ratio between the updated and the old policy (π_θ / π_old) can move in one update by clipping it to [1 - ε, 1 + ε]. This prevents overly aggressive updates that could destabilize training.

# Environment Setup

The cell below disables a HuggingFace Hub feature (XET) that can cause issues in some environments. This is a compatibility fix - don't worry about the details.

In [1]:
import os
os.environ["HF_HUB_DISABLE_XET"] = "1"

### What Are We Installing and Why?

Let's break down each package:

**Core ML/DL Frameworks:**
- `tensorflow` & `tensorflow_datasets`: TensorFlow is Google's ML framework. We use it for data loading (GSM8K dataset)
- `transformers`: HuggingFace's library for pre-trained models (tokenizers, model architectures)
- `flax`: Google's neural network library built on JAX, designed for research
- `jax` (installed as dependency): Google's library for high-performance numerical computing with automatic differentiation

**Training Infrastructure:**
- `google-tunix`: Google's JAX-native library for post-training language models, including RL methods like GRPO
- `grain`: Efficient data loading pipeline for JAX
- `optax` (in tunix): Optimization library for JAX (think: Adam, SGD, learning rate schedulers)
- `orbax` (in tunix): Checkpointing library - saves/loads model weights

**Experiment Tracking:**
- `wandb` (Weights & Biases): Tracks training metrics, visualizes loss curves, logs experiments
- `tensorboardX`: Alternative metric logging

**Data Handling:**
- `datasets`: HuggingFace's library for loading datasets
- `kagglehub`: Downloads models and datasets from Kaggle

**Utilities:**
- `ipywidgets`: Interactive widgets for Jupyter (progress bars, etc.)

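**Why These Specific Versions?**
`google-tunix` is pinned to `0.1.3`, and `flax` is uninstalled and then reinstalled with `-U` to pull the latest release, presumably because this Tunix version relies on newer Flax NNX APIs than the version preinstalled in the environment. Package version mismatches are common in ML, so pinning and upgrading explicitly keeps the stack consistent.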

The cell below sets up the Weights & Biases (W&B) API key for experiment tracking. W&B will log training metrics so you can visualize training progress. On Kaggle, secrets are stored securely and retrieved via `UserSecretsClient`.

## Install necessary libraries

### Understanding the Imports

Let me explain what each import does:

**Standard Python:**
- `functools`: For creating partial functions (pre-filling function arguments)
- `gc`: Garbage collection - manually free memory when we delete large models
- `os`: Operating system interactions (file paths, environment variables)
- `pprint`: "Pretty print" - formats complex data structures nicely
- `re`: Regular expressions for pattern matching in text
- `csv`, `shutil`: File handling utilities

**JAX Ecosystem (The Core Computing Engine):**
- `jax`: Main library for numerical computing with automatic differentiation
- `jax.numpy as jnp`: JAX's version of NumPy (array operations, math)
- `flax.nnx`: Neural network modules in Flax (layers, parameters, models)

**Data Loading:**
- `grain`: Efficient data pipeline (like PyTorch's DataLoader but for JAX)
- `tensorflow_datasets as tfds`: Loads standard datasets (GSM8K)
- `datasets` (from HuggingFace): Alternative dataset loading

**Model & Training (Tunix - Google's RL Training Library):**
- `qwix`: Quantization and LoRA (Low-Rank Adaptation) utilities (a separate package, imported directly with `import qwix`)
- `sampler_lib`: Generates text from the model
- `tokenizer_lib`: Converts text to/from numbers (tokens)
- `model`, `params`: Gemma 3 model architecture and parameter loading
- `GRPOConfig`, `GRPOLearner`: The GRPO training algorithm implementation
- `rl_cluster_lib`: Manages distributed training across devices
- `base_rollout`: Generates model outputs during training
- `metrics_logger`: Logs training statistics

**Utilities:**
- `humanize`: Converts numbers to human-readable formats (e.g., "1.2 GiB")
- `optax`: Optimizers for JAX (AdamW, learning rate schedules)
- `orbax.checkpoint as ocp`: Saves/loads model checkpoints
- `tqdm`: Progress bars
- `kagglehub`: Downloads Kaggle datasets
- `Path` from `pathlib`: Modern file path handling

**Why JAX and Flax instead of PyTorch?**

JAX offers:
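1. **JIT Compilation**: `jax.jit` compiles functions with the XLA compiler into optimized kernels for CPUs, GPUs, and TPUs
2. **Automatic Vectorization**: `jax.vmap` turns a function written for one example into one that runs over a whole batch
3. **TPU Support**: Native support for Google's TPU accelerators (what we're using)
4. **Functional Transformations**: `jax.grad` differentiates pure Python functions, which makes custom objectives like the GRPO loss straightforward to differentiate

Flax is the neural network library built on JAX. Its NNX API (`flax.nnx`, used here) offers stateful, PyTorch-like `nnx.Module` objects that work with JAX-style transformations such as `nnx.jit` and `nnx.grad`.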
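A minimal sketch of the three transformations mentioned above (`jax.jit`, `jax.grad`, `jax.vmap`) on a toy scalar linear model; the `loss` function and data are illustrative and not part of the notebook.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Squared error of a toy scalar linear model y_hat = w * x.
    return (w * x - y) ** 2


# jax.grad differentiates w.r.t. the first argument (w); jax.jit compiles it with XLA.
grad_loss = jax.jit(jax.grad(loss))

# jax.vmap maps the per-example gradient over a batch of (x, y) pairs,
# keeping w fixed (in_axes=None) instead of writing a Python loop.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([2.0, 4.0, 6.0])
grads = per_example_grads(1.0, xs, ys)
# d/dw (w*x - y)^2 = 2 * (w*x - y) * x, i.e. -2, -8 and -18 for these inputs at w = 1.
print(grads)
```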

In [2]:
!pip install -q kagglehub

!pip install -q ipywidgets

!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain
!pip install "google-tunix[prod]==0.1.3"

# !pip install -q git+https://github.com/google/tunix
# !pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
# !pip install -q git+https://github.com/google/flax.git
!pip install -U flax


!pip install -q datasets wandb


### Deep Dive into Hyperparameters

This is one of the most important cells in the notebook. These hyperparameters control how the model trains. Let me explain each group:

---

## 📁 Data Configuration
```python
TRAIN_DATA_DIR = "./data/train"      # Where to save/load training data
TEST_DATA_DIR = "./data/test"        # Where to save/load test data
TRAIN_FRACTION = 1.0                  # Use 100% of data for training (no validation split)
```

---

## 🔧 LoRA (Low-Rank Adaptation) Configuration

**What is LoRA?**
Instead of fine-tuning ALL parameters in a huge model (billions of weights), LoRA freezes the original weights and trains small "adapter" matrices added alongside them. It's like adding a small brain alongside the big brain rather than retraining the whole thing.

```python
RANK = 64           # Size of the adapter matrices (higher = more capacity but more memory)
ALPHA = 64.0        # Scaling factor for LoRA updates (typically set equal to RANK)
```

**Why use LoRA?**
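- Gemma 3 1B has ~1 billion parameters
- Full fine-tuning needs accelerator memory for the weights, their gradients, and the optimizer state of every parameter
- LoRA trains only a small fraction of the parameters; the exact share depends on RANK and on which layers are adapted
- It often recovers much of the quality of full fine-tuning at a fraction of the memory cost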

**RANK and ALPHA:**
- RANK determines the "bottleneck" dimension. Higher rank = more expressive but more memory
- ALPHA is a scaling factor: the LoRA update is multiplied by ALPHA / RANK, so with ALPHA = RANK = 64 the update is applied at scale 1.0

---

## 🖥️ Sharding (Distributed Computing) Configuration

```python
MESH = [(1, 4), ("fsdp", "tp")]
```

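This arranges 4 TPU devices into a 1 × 4 mesh with two named axes:
- `fsdp` (Fully Sharded Data Parallel): splits model parameters across the devices along this axis
- `tp` (Tensor Parallel): splits individual matrix operations across the devices along this axis
- `(1, 4)` gives the `fsdp` axis size 1 and the `tp` axis size 4, so the model is split 4-way by tensor parallelism and FSDP sharding is effectively off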

Think of it as: "Cut the model into 4 pieces, one per TPU core"
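To see why splitting a matrix multiply across devices still gives the same result, here is a NumPy simulation of 4-way tensor parallelism on a single machine. It only illustrates the arithmetic; the notebook itself relies on JAX's sharding machinery over `MESH`, not on manual splits like these.

```python
import numpy as np

rng = np.random.default_rng(0)
TP = 4                            # tensor-parallel degree, like the "tp" axis of MESH
x = rng.normal(size=(2, 8))       # 2 activations with hidden size 8
W = rng.normal(size=(8, 12))      # weight matrix to be split across 4 "devices"

# Column-parallel: each device holds 12 / 4 = 3 output columns of W and computes
# its slice of the output independently; the slices are then gathered.
col_shards = np.split(W, TP, axis=1)
y_col = np.concatenate([x @ w for w in col_shards], axis=1)

# Row-parallel: each device holds 8 / 4 = 2 input rows of W and the matching slice
# of x; the partial products are then summed (an all-reduce).
row_shards = np.split(W, TP, axis=0)
x_shards = np.split(x, TP, axis=1)
y_row = sum(xs @ ws for xs, ws in zip(x_shards, row_shards))

print(np.allclose(y_col, x @ W), np.allclose(y_row, x @ W))  # both match the unsplit result
```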

---

## 🎯 GRPO-Specific Configuration

### Generation During Training
```python
MAX_PROMPT_LENGTH = 256              # Max tokens in input prompt
TOTAL_GENERATION_STEPS = 512         # Max tokens model can generate as response
TEMPERATURE = 0.9                     # Randomness in generation (near 0 = almost greedy, 1.0 = the model's unmodified distribution)
TOP_P = 1.0                          # Nucleus sampling threshold (1.0 = consider all tokens)
TOP_K = 50                           # Only sample from top 50 most likely tokens
NUM_GENERATIONS = 4                   # Generate 4 different responses per prompt (the "Group" in GRPO)
```

**Why high temperature (0.9)?**
During training, we WANT diverse responses! If the model always gave the same answer, we couldn't compare different solutions. High temperature = more variety = better learning signal.

**NUM_GENERATIONS = 4:**
This is the "G" in GRPO. For each math problem, the model generates 4 different solutions. These 4 solutions compete against each other. The model learns to produce more solutions like the winners.
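The sampling settings above can be illustrated with a generic sketch of temperature, top-k, and top-p (nucleus) filtering in NumPy. This is not the Tunix sampler's implementation, which may apply the steps in a different order; `sampling_probs` is a hypothetical helper, and the 4-token vocabulary is made up.

```python
import numpy as np


def sampling_probs(logits, temperature=0.9, top_k=50, top_p=1.0):
    """Turn raw logits into the distribution actually sampled from."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(scaled - scaled.max())  # numerically stable softmax
    probs /= probs.sum()

    # Top-k: keep only the k most likely tokens (ties at the cutoff are kept too).
    if top_k is not None and top_k < probs.size:
        cutoff = np.sort(probs)[-top_k]
        probs = np.where(probs >= cutoff, probs, 0.0)
        probs /= probs.sum()

    # Top-p: keep the smallest set of top tokens whose cumulative mass reaches top_p.
    if top_p < 1.0:
        order = np.argsort(probs)[::-1]
        mass_before = np.cumsum(probs[order]) - probs[order]
        keep = np.zeros(probs.size, dtype=bool)
        keep[order[mass_before < top_p]] = True
        probs = np.where(keep, probs, 0.0)
        probs /= probs.sum()

    return probs


rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.0, -1.0])
p_train = sampling_probs(logits, temperature=0.9, top_k=50, top_p=1.0)   # training settings
p_greedy = sampling_probs(logits, temperature=1e-4, top_k=1, top_p=1.0)  # "greedy" config
token = rng.choice(logits.size, p=p_train)  # every token keeps a nonzero chance here
```

With the training settings every token stays reachable, which is what produces diverse groups; the greedy configuration collapses all mass onto the single most likely token.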

### Core GRPO Parameters
```python
NUM_ITERATIONS = 1        # μ in the paper: How many times to update on each batch
BETA = 0.08              # β: KL divergence penalty strength
EPSILON = 0.2            # ε: Clipping range for PPO-style updates
```

**BETA (KL Penalty) - CRITICAL PARAMETER:**
This controls how much we penalize the model for deviating from its original behavior.
- Too low (e.g., 0.01): Model changes too much, may "forget" how to write coherent text
- Too high (e.g., 1.0): Model barely changes, learns very slowly
- 0.08, the value used here, sits between these extremes

**EPSILON (Clipping):**
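Borrowed from PPO. The ratio between a token's new and old probability (π_θ / π_old) is clipped to [1 - ε, 1 + ε] = [0.8, 1.2], so the objective gains nothing from moving a probability more than ±20% in one update. This prevents: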
- Single bad examples from destroying the model
- Overconfident updates based on noisy rewards
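The clipping and KL terms combine into a per-token loss. Below is a NumPy sketch of that loss for a single response, following the clipped-surrogate-plus-KL form from the GRPO paper with EPSILON = 0.2 and BETA = 0.08 as defaults. The function name and argument layout are hypothetical, and Tunix's actual implementation may aggregate tokens and sequences differently.

```python
import numpy as np


def grpo_response_loss(logp_new, logp_old, logp_ref, advantage, epsilon=0.2, beta=0.08):
    """GRPO loss (to minimize) for one sampled response, averaged over its tokens.

    logp_new:  per-token log-probs under the policy being trained (pi_theta)
    logp_old:  per-token log-probs under the policy that sampled the response (pi_old)
    logp_ref:  per-token log-probs under the frozen reference model (pi_ref)
    advantage: the response's group-relative advantage (one scalar per response)
    """
    ratio = np.exp(logp_new - logp_old)                        # pi_theta / pi_old
    unclipped = ratio * advantage
    clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage
    surrogate = np.minimum(unclipped, clipped)                  # pessimistic PPO-style objective
    # KL estimator used in the GRPO paper: pi_ref/pi_theta - log(pi_ref/pi_theta) - 1 (always >= 0).
    log_ref_ratio = logp_ref - logp_new
    kl = np.exp(log_ref_ratio) - log_ref_ratio - 1.0
    return float(np.mean(-(surrogate - beta * kl)))


same = np.log(np.array([0.5, 0.25]))
print(grpo_response_loss(same, same, same, advantage=2.0))     # ratio 1, KL 0: loss is -advantage

old, new = np.log(np.array([0.25])), np.log(np.array([0.5]))   # token probability doubled
print(grpo_response_loss(new, old, new, advantage=1.0))        # gain capped at 1.2, not 2.0
```

With NUM_ITERATIONS = 1, each batch is used for a single update, so at update time π_θ and π_old coincide numerically and the ratio is 1; clipping only starts to matter when the same samples are reused for several updates (μ > 1).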

---

## 🏋️ Training Configuration

```python
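TRAIN_MICRO_BATCH_SIZE = 4           # Number of prompts per training step
NUM_BATCHES = 3738                    # Upper limit on batches taken from the dataset
NUM_EPOCHS = 1                        # Train for 1 pass through the data
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)  # = 3738
```

Note that `MAX_STEPS` is derived from `NUM_BATCHES`, not from the dataset size: the GSM8K train split yields only 1,869 batches of 4, fewer than the 3,738 steps requested.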

**Batch Size:**
Why 4? Memory constraints. Each prompt generates 4 responses (NUM_GENERATIONS), so each step actually processes 4 × 4 = 16 sequences.

### Optimizer Configuration
```python
LEARNING_RATE = 3e-6      # How big each update step is (0.000003 - very small!)
B1 = 0.9                  # AdamW decay rate for the first-moment (gradient mean) estimate
B2 = 0.99                 # AdamW decay rate for the second-moment (squared gradient) estimate
WEIGHT_DECAY = 0.1        # Decoupled weight decay (AdamW's variant of L2 regularization)
WARMUP_STEPS = 0.1 * MAX_STEPS  # 10% of MAX_STEPS (~374 steps): LR rises linearly from 0
MAX_GRAD_NORM = 0.1       # Rescale gradients if their global norm exceeds 0.1
```

**Why such a tiny learning rate (3e-6)?**
RL training is notoriously unstable. Large updates can destroy the model's ability to generate coherent text. Think of it as: "make very small, careful adjustments."

**Warmup:**
Start with learning rate = 0 and increase it linearly to 3e-6 over the first 10% of `MAX_STEPS` (about 374 steps). This prevents the model from making wild updates early in training when gradients are noisy.

**Cosine Decay:**
After warmup, gradually decrease learning rate to 0 following a cosine curve. The intuition: make big changes early, fine-tune at the end.

**Gradient Clipping (MAX_GRAD_NORM = 0.1):**
If the global norm of all gradients exceeds 0.1, rescale them so the norm equals 0.1. This is CRUCIAL for RL stability: it prevents exploding gradients from ruining training.
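The warmup-plus-cosine schedule and global-norm clipping described above fit in a few lines of plain Python/NumPy. This is a sketch for inspecting the numbers, not the notebook's optimizer code; optax provides both as `optax.warmup_cosine_decay_schedule` and `optax.clip_by_global_norm`. The defaults assume MAX_STEPS = 3738, as computed in the hyperparameter cell.

```python
import math
import numpy as np


def warmup_cosine_lr(step, peak_lr=3e-6, max_steps=3738, warmup_frac=0.1):
    """Linear warmup from 0 to peak_lr, then cosine decay to 0 at max_steps."""
    warmup_steps = warmup_frac * max_steps
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min((step - warmup_steps) / (max_steps - warmup_steps), 1.0)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


def clip_by_global_norm(grads, max_norm=0.1):
    """Rescale a list of gradient arrays so their combined L2 norm is at most max_norm."""
    global_norm = math.sqrt(sum(float(np.sum(g * g)) for g in grads))
    if global_norm <= max_norm:
        return grads  # small gradients pass through untouched
    scale = max_norm / global_norm  # same factor for every array: direction is preserved
    return [g * scale for g in grads]


for step in (0, 187, 374, 1869, 3738):
    print(step, warmup_cosine_lr(step))

print(clip_by_global_norm([np.array([3.0, 4.0])]))  # norm 5.0 is rescaled to 0.1
```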

---

## 💾 Checkpointing

```python
CKPT_DIR = "/tmp/content/ckpts/"      # Where to save model checkpoints
SAVE_INTERVAL_STEPS = 500              # Save every 500 steps
MAX_TO_KEEP = 4                        # Keep only 4 most recent checkpoints (save disk space)
```

---

## 🔮 Inference Configuration

```python
GENERATION_CONFIGS = {
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},      # Almost deterministic
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},   # Balanced
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},  # Creative
}
```

These are for evaluation, not training. "Greedy" (nearly deterministic) is used for testing because you want consistent, reproducible results.

In [3]:
import wandb, os
from kaggle_secrets import UserSecretsClient
os.environ['WANDB_API_KEY'] = UserSecretsClient().get_secret("WANDB_API_KEY")

The `show_hbm_usage()` function monitors High Bandwidth Memory (HBM) usage on TPUs. HBM is the fast memory directly attached to the TPU chips. Monitoring it helps you understand if you're running out of memory during training.

**Interpreting the output:**
- If usage is near the limit, you might get out-of-memory errors
- Typical usage for this setup: 50-80% of available memory

## Imports

## Understanding the Prompt Template

This is the **system prompt** - instructions telling the model HOW to respond.

### Why This Structure?

The model is being trained to:
1. **Think step-by-step** (in `<reasoning>` tags)
2. **Give a clear final answer** (in `<answer>` tags)

This is called **structured output** - forcing the model to organize its response in a predictable way makes it easier to:
- Extract the final answer automatically
- Evaluate if the reasoning is correct
- Check if the format is followed

### Breaking Down the Template

```python
TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""
```

This follows **Gemma's chat format**:
- `<start_of_turn>user` and `<end_of_turn>`: Marks the user's message
- `<start_of_turn>model`: Signals that it's the model's turn to respond

The model's response should then be:
```
<reasoning>
My step-by-step thinking here...
</reasoning>
<answer>
42
</answer>
```

**Why tags like `<reasoning>`?**
These are plain-text markers rather than special tokenizer tokens, and they are easy to detect with a regex (pattern matching), making it simple to grade the model's output automatically. Without structure, parsing the answer from free-form text would be error-prone.
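As an illustration of how such tags can be checked, here is a hypothetical set of regex-based reward helpers. The notebook's actual reward functions are not shown in this section and may use different patterns and scores; the 0.0/1.0 values below are illustrative only.

```python
import re

# Hypothetical reward helpers; the notebook's real reward functions may differ.
FORMAT_RE = re.compile(
    r"^\s*<reasoning>.+?</reasoning>\s*<answer>(.+?)</answer>\s*$",
    flags=re.DOTALL,  # let "." match newlines inside the tags
)


def extract_answer(response):
    """Return the text inside <answer>...</answer>, or None if the format is wrong."""
    match = FORMAT_RE.match(response)
    return match.group(1).strip() if match else None


def format_reward(response):
    return 1.0 if FORMAT_RE.match(response) else 0.0


def _as_number(text):
    try:
        return float(text.replace(",", "").replace("$", "").strip())
    except ValueError:
        return None


def correctness_reward(response, ground_truth):
    guess = extract_answer(response)
    if guess is None:
        return 0.0
    if guess == ground_truth:
        return 1.0  # exact string match
    a, b = _as_number(guess), _as_number(ground_truth)
    return 1.0 if a is not None and a == b else 0.0  # numeric match, e.g. "1,000" vs "1000"


response = "<reasoning>\n48/2 = 24 and 48 + 24 = 72\n</reasoning>\n<answer>\n72\n</answer>"
print(extract_answer(response), format_reward(response), correctness_reward(response, "72"))
```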

In [4]:
import functools
import gc
import os
from pprint import pprint
import re

import csv
import shutil

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
from pathlib import Path
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
# from tunix.models.gemma3 import model as gemma_lib
# from tunix.models.gemma3 import params as params_lib
from tunix.models.gemma3 import params
from tunix.models.gemma3 import model
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from datasets import load_dataset



### Understanding the Data Pipeline

The `get_dataset` function is crucial - it transforms raw GSM8K data into the format our training loop expects.

**GSM8K Dataset Structure:**
- **question**: A math word problem
- **answer**: Step-by-step solution ending with `#### [final_number]`

Example:
```
Question: "Natalia sold clips to 48 of her friends in April..."
Answer: "Natalia sold 48/2 = <<48/2=24>>24 clips in May. 
         Natalia sold 48+24 = <<48+24=72>>72 clips altogether.
         #### 72"
```

**What the Function Does:**

1. **Downloads the data** from HuggingFace (or TFDS/Kaggle)
2. **Extracts the final numerical answer** by splitting on "####"
3. **Creates a prompt** using our template (system prompt + question)
4. **Shuffles the data** (seed=42 for reproducibility)
5. **Maps each example** to a dictionary with:
   - `prompts`: The complete formatted prompt for the model
   - `question`: Original question (for reward functions)
   - `answer`: Just the final number (for checking correctness)

**Why three separate fields?**
- `prompts`: What the model sees as input
- `question`: The bare question text (used by reward functions and for debugging/logging)
- `answer`: Ground truth for computing rewards

This separation allows the reward functions to check correctness without needing to parse the entire prompt.
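A minimal sketch of the mapping described above, applied to the abbreviated GSM8K record shown earlier. `SYSTEM_PROMPT`, `TEMPLATE`, and `extract_hash_answer` are copied from the notebook; `to_example` is a hypothetical helper, not the notebook's own `get_dataset` code, and shuffling and batching are omitted.

```python
reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"

SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \
provide your reasoning. Place it between {reasoning_start} and \
{reasoning_end}. Then, provide the final answer (i.e., just one numerical \
value) between {solution_start} and {solution_end}."""

TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""


def extract_hash_answer(text):
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def to_example(record):
    # Hypothetical helper: builds the three fields used during training.
    return {
        "prompts": TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=record["question"]),
        "question": record["question"],
        "answer": extract_hash_answer(record["answer"]),
    }


record = {
    "question": "Natalia sold clips to 48 of her friends in April...",
    "answer": "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n"
              "Natalia sold 48+24 = <<48+24=72>>72 clips altogether.\n#### 72",
}
example = to_example(record)
print(example["answer"])   # only the final number survives
print(example["prompts"])  # ends with "<start_of_turn>model", ready for the model to continue
```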

## Hyperparameters

Let's define the configuration we are going to use. Note that this is by no
means a "perfect" set of hyperparameters. To get good results, you might have
to train the model for longer.

### Creating Train/Val/Test Splits

This cell downloads the GSM8K dataset and prepares it for training.

**What's happening:**
1. Load the "train" split from GSM8K (~7,473 examples)
2. Batch it into groups of 4 (TRAIN_MICRO_BATCH_SIZE)
3. Take at most NUM_BATCHES = 3738 batches. Since 3738 × 4 = 14,952 is more than the 7,473 available examples, this cap has no effect here.

**The Output: `(1869, 0, 100)`**
- Train: 1869 batches
- Validation: 0 batches (we're using TRAIN_FRACTION = 1.0, meaning no validation split)
- Test: 100 batches

**Why 1869 and not 3738?**
The dataset has 7,473 examples, and 7,473 ÷ 4 = 1,868.25. The output shows 1,869, so the final partial batch (1 example) is kept rather than dropped. NUM_BATCHES = 3738 was set higher than the actual dataset size, so we simply use every batch available.

**Why no validation set?**
When `TRAIN_FRACTION = 1.0`, we use ALL data for training. This is common when:
- You want maximum training data
- You're evaluating on a separate test set anyway
- You're monitoring training metrics (reward, KL divergence) instead

**Test set:**
The test split has 1,319 examples. We're only using 100 batches (400 examples) for faster evaluation.

**The `.repeat(NUM_EPOCHS)`:**
If we wanted to train for multiple epochs, this would repeat the dataset. Here, NUM_EPOCHS = 1, so each example is seen once.
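The batch counts above follow from simple arithmetic. The helper below is hypothetical and assumes the final partial batch is kept, which is what the printed 1,869 implies.

```python
def count_batches(num_examples, batch_size, max_batches, drop_remainder=False):
    """Batches produced by batching a dataset, then taking at most `max_batches`."""
    full, leftover = divmod(num_examples, batch_size)
    available = full + (1 if leftover and not drop_remainder else 0)
    return min(available, max_batches)


# GSM8K split sizes (7,473 train, 1,319 test) with the notebook's settings.
train_batches = count_batches(7473, batch_size=4, max_batches=3738)      # 1869: cap not reached
test_batches = count_batches(1319, batch_size=4, max_batches=100)        # 100: capped
max_test_batches = count_batches(1319, batch_size=4, max_batches=10**9)  # 330, the stated maximum

# MAX_STEPS as computed in the hyperparameter cell, versus what one epoch can supply.
MAX_STEPS = int(3738 * 1 * 1.0 * 1)  # NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS
sequences_per_step = 4 * 4           # TRAIN_MICRO_BATCH_SIZE * NUM_GENERATIONS
print(train_batches, test_batches, max_test_batches, MAX_STEPS, sequences_per_step)
```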

In [5]:
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

# ====== LoRA ======
RANK = 64
ALPHA = 64.0

# ====== Sharding ======
MESH = [(1, 4), ("fsdp", "tp")]

# ====== GRPO ======
# === Generation during GRPO training ===
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 512
# Important to keep a high-ish temperature for varied, diverse responses during
# training.
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
# The number of times the policy generates multiple responses for a given prompt
# within a single training step. This corresponds to `G` in Algorithm 1 in the
# paper. The "group" in GRPO comes from here.
NUM_GENERATIONS = 4

# === other GRPO configs ===
# The number of iterations per batch (μ in GRPO algo 1).
NUM_ITERATIONS = 1
# The coefficient for the KL divergence penalty (β) in the GRPO loss function.
# Important to keep a high enough value for this, otherwise, the KL divergence
# can increase unchecked.
BETA = 0.08
# Epsilon value for clipping (ε in GRPO loss in paper). Similar to PPO, for
# stable updates.
EPSILON = 0.2

# ====== Training ======
TRAIN_MICRO_BATCH_SIZE = 4
# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
NUM_BATCHES = 3738
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
# increased to a max. of 330 (if batch size is 4).
NUM_TEST_BATCHES = 100

EVAL_EVERY_N_STEPS = 10  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
NUM_EPOCHS = 1  # can potentially train for more epochs

# Number of training steps.
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
# == Cosine decay with warmup scheduler ==
# Linearly increase learning rate from 0. to 3e-6 in the first 10% training
# steps, and then gradually decrease the learning rate to 0 using cosine
# scheduler.
WARMUP_STEPS = 0.1 * MAX_STEPS
# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/ckpts/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# ====== Inference ======
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}

### Examining a Training Batch

Let's see what the model actually sees during training.

**Understanding the Output:**

Each batch contains 3 arrays (all with 4 elements, one per example in the batch):

1. **`answer`**: Ground truth final numbers
   ```python
   array(['3', '34', '300', '35'], dtype='<U3')
   ```
   These are the correct answers we'll check against.

2. **`prompts`**: Complete formatted prompts
   ```
   <start_of_turn>user
   You are given a problem... [system prompt]
   
   Maria has 4 dimes, 4 quarters... [question]<end_of_turn>
   <start_of_turn>model
   ```
   This is what gets fed to the model. Notice it ends with `<start_of_turn>model` - the model's job is to continue from here.

3. **`question`**: Original questions (for logging/debugging)
   ```
   "Maria has 4 dimes, 4 quarters, and 7 nickels..."
   ```

**Key Observation:**
The prompt includes:
- System instruction (how to format the response)
- The actual math problem
- Gemma's chat template markers

The model must:
1. Understand the problem
2. Reason through it (in `<reasoning>` tags)
3. Compute the answer
4. Format it properly (in `<answer>` tags)

This is significantly harder than just giving the answer - it requires structured reasoning!

## Utility functions

In [6]:
def show_hbm_usage():
  """Displays memory usage per device."""
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    used = stats["bytes_in_use"]
    limit = stats["bytes_limit"]
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

## Data preprocessing

First, let's define some special tokens. We instruct the model to first reason
between the `<reasoning>` and `</reasoning>` tokens. After
reasoning, we expect it to provide the answer between the `<answer>` and
`</answer>` tokens.

### Understanding Checkpoint Conversion

This cell performs a critical but confusing operation: **re-saving the model checkpoint in a different format**.

**Why is this necessary?**

The pre-trained Gemma model from Kaggle uses one checkpoint format. The Flax NNX library (which Tunix uses) expects a different format. This cell:

1. **Loads the Kaggle checkpoint** into a temporary model
2. **Extracts the model state** (all the weights/parameters)
3. **Saves it in NNX-compatible format** to `INTERMEDIATE_CKPT_DIR`
4. **Deletes the temporary model** to free memory (important on TPUs!)

**The `gc.collect()` call:**
Explicit garbage collection. TPUs have limited memory, so we aggressively clean up after ourselves.

**Why the warnings?**
- "StandardCheckpointHandler expects a target tree..." - This is because we're doing a blind save without type checking
- "Could not find the credentials file..." - Google Cloud credentials warning (safe to ignore on Kaggle)

Think of this as: **"Translating the model from Kaggle-speak to Tunix-speak"**

In [7]:
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"


SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \
provide your reasoning. Place it between {reasoning_start} and \
{reasoning_end}. Then, provide the final answer (i.e., just one numerical \
value) between {solution_start} and {solution_end}."""

TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""

### Deep Dive: Model Loading and LoRA Application

These two functions are the heart of how we set up our trainable model.

---

## `get_gemma_ref_model()` - Loading the Base Model

```python
mesh = jax.make_mesh(*MESH)
```
Creates a "mesh" - a logical arrangement of TPU devices. Think of it as: "Here's how we'll split our computation across 4 TPU cores."

```python
abs_gemma: nnx.Module = nnx.eval_shape(
    lambda: params.create_model_from_checkpoint(MODEL_CP_PATH, config)
)
```
**`eval_shape` is key!** It creates a "ghost" model - just the structure and shapes, without actually allocating memory for the weights. This is useful for:
- Planning how to distribute the model across devices
- Determining the sharding strategy before loading heavy weights

```python
abs_state = jax.tree.map(
    lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
    abs_state,
    nnx.get_named_sharding(abs_state, mesh),
)
```
This creates a "specification" for each parameter:
- Shape: How big is this tensor?
- Dtype: bfloat16 (efficient 16-bit float for TPUs)
- Sharding: Which TPU core(s) should hold this parameter?

```python
restored_params = checkpointer.restore(ckpt_path, target=abs_state)
```
Finally loads the actual weights from disk, placing them on the correct TPU cores according to our sharding plan.
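`jax.eval_shape` can be tried on its own. This small sketch, built around a made-up `make_big_params` function, shows that it returns `jax.ShapeDtypeStruct` placeholders without allocating any arrays, and then re-types them as bfloat16 the way `get_gemma_ref_model()` does. The `sharding=` argument is left out because the sketch assumes no device mesh.

```python
import jax
import jax.numpy as jnp


def make_big_params():
    # Executed eagerly, the float32 matrix alone would take 16384 * 16384 * 4 bytes = 1 GiB.
    return {"w": jnp.ones((16384, 16384)), "b": jnp.zeros((16384,))}


# eval_shape only traces the function: the result holds shapes and dtypes, no data.
abstract = jax.eval_shape(make_big_params)
print(abstract["w"])  # a jax.ShapeDtypeStruct describing shape (16384, 16384), float32

# Re-type every leaf as bfloat16, mirroring the notebook's jax.tree.map step
# (which additionally attaches a sharding taken from the mesh).
abstract_bf16 = jax.tree.map(
    lambda a: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16), abstract
)
```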

---

## `get_lora_model()` - Adding LoRA Layers

```python
lora_provider = qwix.LoraProvider(
    module_path=(
        ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
        ".*attn_vec_einsum"
    ),
    rank=RANK,
    alpha=ALPHA,
)
```

**What layers get LoRA?** 
The regex pattern selects specific layers:
- `q_einsum`, `kv_einsum`, `attn_vec_einsum`: Attention layers (query projection, combined key/value projection, and attention output projection)
- `gate_proj`, `down_proj`, `up_proj`: MLP (Feed-forward) layers

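**Why these layers?**
Together they hold nearly all of the weights inside each transformer block. Adapting both the attention and MLP projections, rather than only the attention ones, is a common LoRA choice (QLoRA, for example, applies adapters to all linear layers) and gives the adapters more capacity.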

```python
lora_model = qwix.apply_lora_to_model(base_model, lora_provider, **model_input)
```

**What does LoRA actually do to a layer?**

Original layer: `y = Wx`
LoRA layer: `y = Wx + (ALPHA / RANK) * BAx`

Where:
- `W` is the original weight matrix (frozen, not trained)
- `B` and `A` are low-rank matrices: `A` projects the input down to RANK = 64 dimensions and `B` projects it back up
- Only `B` and `A` are trained

**Math insight:**
If W is 1000×1000 (1M params), B is 1000×64 and A is 64×1000, for a total of 128K params.
We're training 12.8% as many parameters, but can still learn complex adaptations.
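A NumPy sketch of the LoRA layer above, using the 1000×1000 example. It assumes the usual initialization from the original LoRA paper (random `A`, zero `B`), which may not match qwix's defaults, and also shows that a trained adapter can be merged back into `W`.

```python
import numpy as np


def lora_forward(x, W, A, B, alpha, rank):
    """y = W x + (alpha / rank) * B A x, with W frozen and only A, B trainable."""
    return W @ x + (alpha / rank) * (B @ (A @ x))


d_out, d_in, RANK, ALPHA = 1000, 1000, 64, 64.0
rng = np.random.default_rng(0)
W = rng.normal(size=(d_out, d_in))               # frozen pre-trained weight
A = rng.normal(scale=0.01, size=(RANK, d_in))    # trainable down-projection
B = np.zeros((d_out, RANK))                      # trainable up-projection, zero-initialized
x = rng.normal(size=d_in)

# With B = 0 the adapter is a no-op, so training starts exactly at the base model.
y0 = lora_forward(x, W, A, B, ALPHA, RANK)

# After training, the adapter can be merged into W for inference at no extra cost.
B = rng.normal(scale=0.01, size=(d_out, RANK))   # stand-in for trained values
W_merged = W + (ALPHA / RANK) * (B @ A)

lora_params = A.size + B.size                    # 64*1000 + 1000*64 = 128,000
fraction = lora_params / W.size                  # 0.128
```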

```python
with mesh:
    state = nnx.state(lora_model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(lora_model, sharded_state)
```

This ensures our LoRA parameters are properly distributed across TPU cores, just like the base model parameters.

We use OpenAI's [GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k), which comprises grade school math word problems.

### Loading the Reference Model

This cell loads the **frozen reference model** - the original Gemma model that we won't modify.

**Why do we need a reference model?**

In GRPO (as in PPO-based RLHF), we use the reference model to:
1. **Compute KL divergence**: Measure how much the policy has changed
2. **Prevent catastrophic forgetting**: The model shouldn't "forget" how to write coherent text
3. **Regularize training**: Penalize the model if it drifts too far from the original

Think of it as: "Here's what you used to be. Don't change too much from this."

The warning message about the checkpoint handler is normal - it's just saying "I'm loading without checking the exact structure first."

### Creating the Policy Model (with LoRA)

This cell applies LoRA to create our **trainable policy model**.

**What we have now:**
- `ref_model`: Original Gemma 3 1B (frozen, ~1B parameters)
- `lora_policy`: Gemma 3 1B + LoRA adapters (only LoRA params are trained, maybe ~10-50M parameters)

**The commented-out `nnx.display(lora_policy)`:**
Would show the entire model structure - useful for debugging but very verbose. You'd see each layer, including which ones have LoRA adapters.

At this point:
- Both models are loaded into TPU memory
- Both are sharded across 4 TPU cores
- The reference model will stay frozen
- The policy model's LoRA parameters will be updated during training

This is memory-efficient: we're not storing two full models, just one model + small LoRA adapters.

In [8]:
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


def _load_from_tfds(data_dir: str, split: str):
  import tensorflow_datasets.text.gsm8k
  return tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
      download=True,
  )


def download_kaggle_dataset(target_dir="./data/gsm8k"):
  os.makedirs(target_dir, exist_ok=True)
  src = kagglehub.dataset_download("thedevastator/grade-school-math-8k-q-a")
  src = Path(src)
  dst = Path(target_dir)

  for csv_file in src.glob("*.csv"):  # match all CSV files
    shutil.copy2(csv_file, dst / csv_file.name)
    print(f"Copied {csv_file.name} → {dst/csv_file.name}")
  return target_dir


def get_dataset(data_dir, split="train", source="tfds") -> grain.MapDataset:
  # Download data
  if not os.path.exists(data_dir):
    os.makedirs(data_dir)

  if source == "tfds":
    data = _load_from_tfds(data_dir, split)

  elif source == "kaggle":
    kaggle_dir = download_kaggle_dataset(data_dir)
    file_name = "main_" + split + ".csv"
    csv_path = os.path.join(kaggle_dir, file_name)  # adjust filename if needed

    data = []
    with open(csv_path, newline="", encoding="utf-8") as csvfile:
      reader = csv.DictReader(csvfile)
      for row in reader:
        data.append({
            "question": row["question"],
            "answer": row["answer"],
        })

  elif source == "huggingface":    
    os.environ["HF_HUB_DISABLE_XET"] = "1"
    data = load_dataset("gsm8k", "main", split=split)
      
  else:
    raise ValueError(f"Unknown source: {source}")

  def _as_text(v):
    return v if isinstance(v, str) else v.decode("utf-8")

  dataset = (
      grain.MapDataset.source(data)
      .shuffle(seed=42)
      .map(
          lambda x: {
              # passed to model forward pass
              "prompts": TEMPLATE.format(
                  system_prompt=SYSTEM_PROMPT,
                  question=_as_text(x["question"]),
              ),
              # passed to reward functions
              "question": _as_text(x["question"]),
              # passed to reward functions
              "answer": extract_hash_answer(_as_text(x["answer"])),
          }
      )
  )
  return dataset

## Understanding Reward Functions

Reward functions are the **"teachers"** in GRPO. They score each model response, and the model learns to maximize these scores.

### Why Multiple Reward Functions?

We use 4 different reward functions that each capture different aspects of a good response:
1. **Format (Exact)**: Did you follow instructions perfectly?
2. **Format (Approximate)**: Did you at least try to follow the format?
3. **Answer Correctness**: Is the final answer right?
4. **Number Extraction**: Can we extract a number and is it correct?

The **total reward** is the sum of all these individual rewards. This is called **reward shaping** - we guide the model toward multiple goals simultaneously.

---

### The Format-Checking Regex

```python
match_format = re.compile(
    rf"^[\s]{{0,}}"                              # Optional whitespace at start
    rf"{reasoning_start}.+?{reasoning_end}.*?"   # <reasoning>...something...</reasoning>
    rf"{solution_start}(.+?){solution_end}"      # <answer>...capture this...</answer>
    rf"[\s]{{0,}}$",                             # Optional whitespace at end
    flags=re.MULTILINE | re.DOTALL,
)
```

**What this regex does:**
- `^[\s]{0,}`: Start of a line (with `MULTILINE`, any line), then optional whitespace
- `{reasoning_start}.+?{reasoning_end}`: There must be reasoning tags with SOMETHING inside
- `.*?`: Any characters (non-greedy) between reasoning and answer
- `{solution_start}(.+?){solution_end}`: Answer tags with something inside (captured in group 1)
- `[\s]{0,}$`: Optional whitespace, then the end of a line

**The `.+?` (non-greedy):**
- `.+` means "one or more of any character"
- The `?` makes it "non-greedy" - it captures the smallest possible match
- This prevents it from accidentally matching too much

**Flags:**
- `MULTILINE`: `^` and `$` match at the start and end of every line, not just at the start and end of the whole string
- `DOTALL`: `.` matches newline characters too (important for multi-line reasoning)

**Test Output:**
The match object shows it successfully found the pattern in the test string.

We split the dataset into train and test sets as usual.

### Reward Function 1: Exact Format Match

```python
def match_format_exactly(prompts, completions, **kwargs):
    return [
        0 if match_format.search(response) is None else 3.0
        for response in completions
    ]
```

**What it does:**
- For each response in the batch, check if it matches the format regex
- If yes: **+3.0 points**
- If no: **0 points**

**Why 3.0 points?**
The exact values are a design choice, but their relative sizes matter:
- Format matching gets 3.0
- Answer correctness gets 3.0 (defined later)
- This makes them equally important

The model learns: "Following the exact format is AS valuable as getting the right answer."

**The function signature:**
All reward functions have the same signature:
- `prompts`: Input prompts (for context)
- `completions`: Model's generated responses
- `**kwargs`: Extra info (like correct answers, questions)
- Returns: List of scores, one per completion
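
The snippets above depend on tag constants defined elsewhere in the notebook. Here is a self-contained version of the regex and the exact-format reward, assuming the tags are the literal `<reasoning>`/`</reasoning>`/`<answer>`/`</answer>` strings used in the prompt:

```python
import re

# Assumed tag strings; they match the instructions in the prompt template.
reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"

match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)


def match_format_exactly(prompts, completions, **kwargs):
    """3.0 for each completion that matches the full format, else 0."""
    return [
        0 if match_format.search(response) is None else 3.0
        for response in completions
    ]


good = "<reasoning>2 + 2 = 4</reasoning>\n<answer>4</answer>"
bad = "The answer is 4."
print(match_format_exactly(None, [good, bad]))  # [3.0, 0]
print(match_format.search(good).group(1))       # 4
```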

In [9]:
# source = input("Choose data source [tfds/kaggle/huggingface]: ").strip().lower()
source = "huggingface"

if source not in ("tfds", "kaggle", "huggingface"):
  print("Invalid choice. Defaulting to 'tfds'.")
  source = "tfds"

print(f"Using data source: {source}")

dataset = get_dataset(TRAIN_DATA_DIR, "train", source).batch(TRAIN_MICRO_BATCH_SIZE)[
    :NUM_BATCHES
]

if TRAIN_FRACTION == 1.0:
  train_dataset = dataset.repeat(NUM_EPOCHS)
  val_dataset = None
else:
  train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]
  train_dataset = train_dataset.repeat(NUM_EPOCHS)

  val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)

test_dataset = get_dataset(TEST_DATA_DIR, "test", source).batch(TRAIN_MICRO_BATCH_SIZE)[
    :NUM_TEST_BATCHES
]

dataset_lengths = (
    len(train_dataset),
    len(val_dataset) if val_dataset is not None else 0,
    len(test_dataset),
)
print(f"dataset contains {dataset_lengths} of batches")

Using data source: huggingface



dataset contains (1869, 0, 100) of batches
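
The `(1869, 0, 100)` above follows from simple arithmetic. A sketch, assuming `TRAIN_MICRO_BATCH_SIZE = 4`, `NUM_EPOCHS = 1`, `TRAIN_FRACTION = 1.0`, `NUM_TEST_BATCHES = 100`, and `NUM_BATCHES` large enough not to truncate (values consistent with this output and the 400-question evaluation later, but not shown in this section), plus GSM8K's 7,473 train and 1,319 test questions:

```python
import math

# Assumed notebook settings (not shown in this section).
TRAIN_MICRO_BATCH_SIZE = 4
NUM_EPOCHS = 1
TRAIN_FRACTION = 1.0
NUM_TEST_BATCHES = 100

GSM8K_TRAIN, GSM8K_TEST = 7473, 1319

# Batching keeps the final partial batch here (7473 / 4 = 1868.25 -> 1869).
train_batches = math.ceil(GSM8K_TRAIN / TRAIN_MICRO_BATCH_SIZE)
n_train = int(train_batches * TRAIN_FRACTION) * NUM_EPOCHS
n_val = (train_batches - int(train_batches * TRAIN_FRACTION)) * NUM_EPOCHS
n_test = min(math.ceil(GSM8K_TEST / TRAIN_MICRO_BATCH_SIZE), NUM_TEST_BATCHES)

print((n_train, n_val, n_test))                           # (1869, 0, 100)
print(n_test * TRAIN_MICRO_BATCH_SIZE, "test questions")  # 400 test questions
```

The 1869 training batches are also the ~1869 training steps referred to later.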


### Reward Function 2: Approximate Format Match

This is a **softer** reward - it gives partial credit for trying to follow the format.

```python
score += 0.5 if response.count(reasoning_start) == 1 else -0.5
```

**Scoring logic:**
- +0.5 if exactly ONE `<reasoning>` tag (good!)
- -0.5 if zero or more than one (bad!)
- Same for `</reasoning>`, `<answer>`, `</answer>`

- **Maximum score:** +2.0 (all 4 tags present exactly once)
- **Minimum score:** -2.0 (every tag missing or repeated)

**Why penalize multiple tags?**
If the model outputs:
```
<reasoning>thinking</reasoning>
<reasoning>more thinking</reasoning>
<answer>5</answer>
```

That's the wrong format: we want exactly one reasoning section.
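
A self-contained sketch consistent with the scoring rule above, assuming the same four literal tag strings. It shows that the doubled-`<reasoning>` example nets 0 (two tags rewarded, two penalized):

```python
# Assumed tag strings, matching the prompt template.
TAGS = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")


def match_format_approximately(prompts, completions, **kwargs):
    """+0.5 per tag that appears exactly once, -0.5 per tag that is missing or repeated."""
    scores = []
    for response in completions:
        score = 0.0
        for tag in TAGS:
            score += 0.5 if response.count(tag) == 1 else -0.5
        scores.append(score)
    return scores


perfect = "<reasoning>think</reasoning><answer>5</answer>"
doubled = (
    "<reasoning>thinking</reasoning>\n"
    "<reasoning>more thinking</reasoning>\n"
    "<answer>5</answer>"
)
print(match_format_approximately(None, [perfect, doubled, "5"]))  # [2.0, 0.0, -2.0]
```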

**Why this in addition to exact match?**
- Exact match is binary (all or nothing)
- Approximate gives gradient signal even for partial success
- Model gets feedback like: "You got 3 out of 4 tags right, keep trying"

This **dense reward signal** helps the model learn faster.

Let's see what one batch of the training dataset looks like!


### Reward Function 3: Answer Correctness

This is the most important reward function - did the model get the right answer?

**Scoring breakdown:**
- **Exact match**: +3.0 (perfect!)
- **Match after stripping whitespace**: +1.5 (close, just formatting issue)
- **Within 10% of correct (ratio 0.9-1.1)**: +0.5 (close numerically)
- **Within 20% (ratio 0.8-1.2)**: +0.25 (getting warmer)
- **Otherwise**: -1.0 (wrong answer = penalty)
- **Answer isn't a number** (can't be parsed): -0.5
- **Format doesn't match** (no answer extracted): 0

**Why graduated scoring?**

If the correct answer is 100:
- Model says "100": +3.0
- Model says " 100 ": +1.5
- Model says "105": +0.5 (within 10%)
- Model says "115": +0.25 (within 20%)
- Model says "200": -1.0 (way off)

This gives the model useful feedback. "You said 95 but the answer was 100" is better than "You said 50 but the answer was 100". The ratio scoring captures this.

**The try/except block:**
Sometimes the model outputs non-numeric text in the answer field. The try/except handles this gracefully by penalizing but not crashing.

**Why penalize wrong answers (-1.0)?**
Without penalties, the model might:
- Always output something in the answer field (to get format points)
- Not care about correctness (no penalty for being wrong)

The penalty says: "Getting an answer wrong is BAD, not just neutral."
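
A self-contained sketch of the correctness reward following the scoring breakdown above, with the tags assumed to be the literal `<reasoning>`/`<answer>` strings. It catches only `ValueError` and `ZeroDivisionError` (the errors that can occur here) rather than using a broad `except`:

```python
import re

# Assumed format regex (same shape as the exact-format reward).
match_format = re.compile(
    r"^[\s]{0,}<reasoning>.+?</reasoning>.*?<answer>(.+?)</answer>[\s]{0,}$",
    flags=re.MULTILINE | re.DOTALL,
)


def check_answer(prompts, completions, answer, **kwargs):
    """Graduated correctness reward following the scoring table above."""
    scores = []
    for response, true_answer in zip(completions, answer):
        m = match_format.search(response)
        if m is None:                      # format wrong: nothing to grade
            scores.append(0.0)
            continue
        guess = m.group(1)
        if guess == true_answer:
            scores.append(3.0)
        elif guess.strip() == true_answer.strip():
            scores.append(1.5)
        else:
            try:
                ratio = float(guess) / float(true_answer)
            except (ValueError, ZeroDivisionError):
                scores.append(-0.5)        # not a number (or zero ground truth)
                continue
            if 0.9 <= ratio <= 1.1:
                scores.append(0.5)
            elif 0.8 <= ratio <= 1.2:
                scores.append(0.25)
            else:
                scores.append(-1.0)
    return scores


def wrap(ans):
    return f"<reasoning>...</reasoning><answer>{ans}</answer>"


guesses = ["100", " 100 ", "105", "115", "200", "lots"]
print(check_answer(None, [wrap(g) for g in guesses], ["100"] * len(guesses)))
# [3.0, 1.5, 0.5, 0.25, -1.0, -0.5]
```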

In [10]:
for ele in train_dataset[:1]:
  pprint(ele)

{'answer': array(['3', '34', '300', '35'], dtype='<U3'),
 'prompts': array(['<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nMaria has 4 dimes, 4 quarters, and 7 nickels in her piggy bank. Her mom gives her 5 quarters. How much money, in dollars, does Maria have now?<end_of_turn>\n<start_of_turn>model',
       '<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nA wildlife team is monitoring the number of birds in a park. There are 3 blackbirds in each of the park’s 7 trees. There are also 13 magpies roaming around the park. How many birds are in the park in total?<end_of_turn>\n<start_of_turn>model',
 

## Load the policy model and the reference model

The policy model is the model which is actually trained and whose weights are
updated. The reference model is the model with which we compute KL divergence.
This is to ensure that the policy updates are not huge and that it does not
deviate too much from the reference model.

Typically, the reference model is the base model, and the policy model is the
same base model, but with LoRA parameters. Only the LoRA parameters are updated.

Note: We perform full precision (fp32) training. You can, however, leverage
Qwix for quantization-aware training (QAT).

To load the model, you need to be on [Kaggle](https://www.kaggle.com/) and need
to have agreed to the Gemma license
[here](https://www.kaggle.com/models/google/gemma/flax/).

### Reward Function 4: Number Extraction

Sometimes the model's answer field contains text, not just a number:
```
<answer>The final answer is 42.</answer>
```

This regex extracts the first number it finds:
```python
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})",  # Find digits/decimals after <answer>
    flags=re.MULTILINE | re.DOTALL
)
```

**Example:** `"<answer>  0.34  </answer>"` → extracts `"0.34"`

This reward function:
1. Extracts the number from the answer field
2. Compares it to the ground truth
3. Awards +1.5 for exact match, 0 otherwise

**Why is this separate from `check_answer`?**

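`check_answer` only awards points when the whole answer field parses as a number (surrounding whitespace aside), so `<answer>The final answer is 42.</answer>` scores -0.5 there. This function is more forgiving: it extracts the first number even if there's surrounding text.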

**The print statements:**
These print one example per batch during training - useful for seeing what the model is generating and how rewards are being assigned. It's debugging output.
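
A self-contained sketch of the number-extraction reward (debug prints omitted), assuming the literal `<answer>` tag. Note that the `[\d\.]` character class also grabs a trailing period, which `float()` happily accepts:

```python
import re

# Assumed answer tag; grabs the first run of digits/dots after <answer>.
match_numbers = re.compile(r"<answer>.*?([\d\.]{1,})", flags=re.MULTILINE | re.DOTALL)


def check_numbers(prompts, completions, answer, **kwargs):
    """+1.5 if the first number after <answer> equals the ground truth, else 0."""
    scores = []
    for response, true_answer in zip(completions, answer):
        m = match_numbers.search(response)
        if m is None:
            scores.append(0.0)
            continue
        try:
            scores.append(1.5 if float(m.group(1)) == float(true_answer) else 0.0)
        except ValueError:  # e.g. a lone "." was captured
            scores.append(0.0)
    return scores


responses = [
    "<answer>The final answer is 42.</answer>",  # captures "42." -> 42.0
    "<answer>  0.34  </answer>",
    "<answer>about 40</answer>",
]
print(check_numbers(None, responses, ["42", "0.34", "42"]))  # [1.5, 1.5, 0.0]
```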

In [11]:
# Log in
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
  kagglehub.login()


This code snippet serves as a workaround to re-save the pre-trained model checkpoint from Kaggle into a local format that is compatible with the [Flax NNX](https://flax.readthedocs.io/en/stable/why.html) library. Because the original checkpoint has parameter names and tensor structures that don't match the target NNX model architecture, it cannot be loaded directly.

We first load the original weights into a temporary model instance, then extract and re-save the model's state into a new, properly formatted local checkpoint, which can then be successfully loaded by the final sharded NNX model.

### Understanding the `generate()` Function

This helper function takes a question (or batch of questions) and generates responses using our model.

**Key parameters:**

- `temperature=0.7`: Controls randomness. Higher = more creative/diverse
- `top_k=50`: Only sample from the 50 most likely next tokens
- `top_p=0.95`: Nucleus sampling - sample from tokens comprising 95% of probability mass
- `max_generation_steps=768`: Maximum tokens to generate (longer than training's 512 for evaluation)
- `eos_tokens=[1, 106]`: Stop generating when these token IDs are seen
  - Token 1: End of sequence
  - Token 106: End of turn (Gemma's chat format)
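
To make `temperature`, `top_k`, and `top_p` concrete, here is a pure-Python sketch of picking one next token from a `{token_id: logit}` dict. Applying them in the order temperature, then top-k, then top-p is a common convention and an assumption here; Tunix's sampler may differ in details:

```python
import math
import random


def sample_next_token(logits, temperature=0.7, top_k=50, top_p=0.95, rng=random):
    """Pick a token id from {token_id: logit} using temperature, top-k, then top-p."""
    # Temperature: divide logits before softmax (lower = sharper distribution).
    scaled = {t: l / temperature for t, l in logits.items()}
    # Top-k: keep only the k highest-scoring tokens.
    kept = sorted(scaled.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    # Softmax over the survivors (subtract the max for numerical stability).
    m = kept[0][1]
    exps = [(t, math.exp(l - m)) for t, l in kept]
    z = sum(e for _, e in exps)
    probs = [(t, e / z) for t, e in exps]
    # Top-p (nucleus): smallest prefix whose cumulative probability reaches top_p.
    nucleus, cum = [], 0.0
    for t, p in probs:
        nucleus.append((t, p))
        cum += p
        if cum >= top_p:
            break
    tokens, weights = zip(*nucleus)
    return rng.choices(tokens, weights=weights, k=1)[0]


logits = {10: 5.0, 11: 4.0, 12: 0.0, 13: -3.0}
print(sample_next_token(logits, top_k=1))               # 10 (greedy when k=1)
print(sample_next_token(logits, rng=random.Random(0)))  # 10 or 11
```

With these logits, tokens 10 and 11 already hold over 95% of the probability mass, so tokens 12 and 13 can never be sampled.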

**The seed parameter:**
If you provide a seed, you get reproducible outputs. Same seed = same random choices = same output.

**Why 768 tokens for generation?**
Evaluation allows longer responses so that complete answers aren't cut off. During training we use 512 for memory efficiency.

**`echo=False`:**
Don't include the input prompt in the output; return only the model's response.

In [12]:
!rm /tmp/content/intermediate_ckpt/* -rf

!rm /tmp/content/ckpts/* -rf

model_family = "gemma3"
if model_family == "gemma3":
  MODEL_CP_PATH = params.GEMMA3_1B_IT
  config = model.ModelConfig.gemma3_1b()
  gemma = params.create_model_from_checkpoint(MODEL_CP_PATH, config)
  tokenizer = params.create_tokenizer()

  checkpointer = ocp.StandardCheckpointer()
  _, state = nnx.split(gemma)
  checkpointer.save(os.path.join(INTERMEDIATE_CKPT_DIR, "state"), state)
  checkpointer.wait_until_finished()
  # Delete the intermediate model and its state to free memory.
  del gemma
  del state
  gc.collect()



### Understanding the `evaluate()` Function

This comprehensive evaluation function measures three things:
1. **Exact accuracy**: How many answers are exactly correct?
2. **Partial accuracy**: How many are within 10% of correct?
3. **Format accuracy**: How many follow the correct format?

**Multiple passes (`num_passes`):**
The function can run multiple generations per question. This is useful for:
- Pass@k evaluation: "Did ANY of k attempts get it right?"
- Reducing variance from randomness

Here we use `num_passes=1` for simplicity.

**The evaluation loop:**
```python
for question, multiple_call_response, answer in zip(...):
    # For each question, check all generated responses
    # Count as correct if ANY response got it right
```

**Why "any"?**
This is the pass@k convention: a question counts as solved if at least one of the k attempts is correct, so it measures capability rather than consistency. With `num_passes=1`, as here, it reduces to ordinary single-sample accuracy.

**The progress printing:**
Every 10 examples, it prints:
- `corr`: Number of exactly correct answers
- `total`: Total examples evaluated
- `corr / total * 100`: Accuracy percentage
- `partially_corr / total * 100`: Partial accuracy
- `corr_format / total * 100`: Format adherence

This lets you watch progress in real-time and catch issues early.
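
A small sketch of how these counters can be computed with pass@k ("any") semantics. The per-pass record format is hypothetical, and applying "any" to the partial and format metrics as well is an assumption; the notebook's `evaluate()` may count them differently:

```python
def summarize(per_question_passes):
    """per_question_passes: one list per question; each pass is a dict of booleans
    with keys 'correct', 'partial', 'format'. A question counts for a metric if
    ANY of its passes satisfies it (pass@k)."""
    total = len(per_question_passes)
    corr = sum(any(p["correct"] for p in passes) for passes in per_question_passes)
    partial = sum(any(p["partial"] for p in passes) for passes in per_question_passes)
    fmt = sum(any(p["format"] for p in passes) for passes in per_question_passes)
    return {
        "corr": corr,
        "total": total,
        "accuracy": corr / total * 100,
        "partial_accuracy": partial / total * 100,
        "format_accuracy": fmt / total * 100,
    }


# Two questions, two passes each (num_passes=2).
q1 = [{"correct": False, "partial": True, "format": True},
      {"correct": True, "partial": True, "format": True}]
q2 = [{"correct": False, "partial": False, "format": False},
      {"correct": False, "partial": False, "format": True}]
print(summarize([q1, q2]))
# {'corr': 1, 'total': 2, 'accuracy': 50.0, 'partial_accuracy': 50.0, 'format_accuracy': 100.0}
```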

### Creating the Sampler

The `Sampler` wraps our model to make generation easy.

**What is a KV Cache?**

When generating text, the model processes each token sequentially. For efficiency, it caches intermediate computations (the "Key" and "Value" vectors in attention). This cache grows as the sequence gets longer.

```python
cache_config=sampler_lib.CacheConfig(
    cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,  # Total sequence length
    num_layers=model_config.num_layers,                           # 26 layers in Gemma 1B
    num_kv_heads=model_config.num_kv_heads,                       # Number of key/value heads
    head_dim=model_config.head_dim,                               # Dimension per head
)
```

**Why +256?**
Buffer space. Sequences might be slightly longer than expected.
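
A back-of-the-envelope estimate of the cache's memory. It assumes Gemma 3 1B's attention shape (26 layers, 1 KV head, head dimension 256), bfloat16 storage, and `TOTAL_GENERATION_STEPS = 512` (which makes `cache_size` equal the `kv_cache_size=1024` used in the rollout config). Read the real values from `model_config`; an implementation that allocates a smaller cache for Gemma 3's sliding-window layers would use less:

```python
# Assumed Gemma 3 1B attention shape; read the real values from model_config.
NUM_LAYERS = 26
NUM_KV_HEADS = 1
HEAD_DIM = 256
BYTES_PER_VALUE = 2  # bfloat16

MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 512  # assumption
cache_size = MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256  # prompt + generation + buffer


def kv_cache_bytes(batch_size, cache_size):
    # Factor 2: one tensor for keys and one for values, in every layer.
    return 2 * NUM_LAYERS * batch_size * cache_size * NUM_KV_HEADS * HEAD_DIM * BYTES_PER_VALUE


print(cache_size)                              # 1024
print(kv_cache_bytes(1, cache_size) / 2**20)   # 26.0 (MiB per sequence)
print(kv_cache_bytes(16, cache_size) / 2**20)  # 416.0 (4 prompts x 4 generations)
```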

**The sampler object:**
Once created, you can call `sampler(input_strings, ...)` to generate text from the model. It handles:
- Tokenization (text → token IDs)
- Running the model
- Sampling next tokens
- Detokenization (token IDs → text)

### Model Loading and LoRA Application

These two functions work together to load a base model from a checkpoint and apply a LoRA (Low-Rank Adaptation) layer to it.

* `get_ref_model`: Loads the complete Gemma model from a specified checkpoint path. It uses **JAX sharding** to distribute the model parameters across multiple devices.
* `get_lora_model`: Takes the base model and applies LoRA layers to it. It uses a `LoraProvider` to select specific layers (like attention and MLP layers) to be adapted. The resulting LoRA-infused model is then sharded and updated to ensure it's ready for distributed training.

### Pre-Training Baseline Evaluation

**Why evaluate before training?**

This gives us a **baseline** - how good is the model BEFORE we teach it anything? We need this to measure improvement.

**Understanding the Results:**

```
corr=58, total=400, accuracy=14.5%, partial_accuracy=16.0%, format_accuracy=46.25%
```

Let's interpret each metric:

**Accuracy: 14.5%**
- Out of 400 test questions, only 58 got the exact correct answer
- This is our PRIMARY metric - did the model solve the problem?
- 14.5% is actually not bad for a 1B parameter model on grade-school math

**Partial Accuracy: 16.0%**
- 64 out of 400 answers were within 10% of correct
- Slightly higher than exact accuracy
- Shows the model is "in the ballpark" sometimes even when not exact

**Format Accuracy: 46.25%**
- Only 185 out of 400 responses followed the `<reasoning>...<answer>...` format
- Less than half the time!
- This is a major issue - the model doesn't know our special format

**Why is format accuracy so low?**
The base model was never trained on our specific format. It might:
- Just answer directly: "The answer is 42"
- Use its own format: "Let me think... therefore 42"
- Not use any tags at all

**The "SKIPPED" messages:**
Some responses couldn't be parsed as numbers (non-numeric text in answer field). This is another format issue.

**Key Insight:**
Our baseline shows the model:
- Has SOME math capability (14.5% correct)
- But DOESN'T know our format (only 46% compliance)

Training should improve both!

In [13]:
from tunix.models.gemma3 import params

def get_gemma_ref_model(ckpt_path):
  mesh = jax.make_mesh(*MESH)
  model_config = model.ModelConfig.gemma3_1b()
  abs_gemma: nnx.Module = nnx.eval_shape(
      lambda: params.create_model_from_checkpoint(MODEL_CP_PATH, model_config)
  )

  abs_state = nnx.state(abs_gemma)
  abs_state = jax.tree.map(
      lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
      abs_state,
      nnx.get_named_sharding(abs_state, mesh),
  )
  checkpointer = ocp.StandardCheckpointer()
  restored_params = checkpointer.restore(ckpt_path, target=abs_state)

  graph_def, _ = nnx.split(abs_gemma)
  gemma = nnx.merge(graph_def, restored_params)
  return gemma, mesh, model_config


def get_lora_model(base_model, mesh):
  lora_provider = qwix.LoraProvider(
      module_path=(
          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
          ".*attn_vec_einsum"
      ),
      rank=RANK,
      alpha=ALPHA,
  )

  model_input = base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(
      base_model, lora_provider, **model_input
  )

  with mesh:
    state = nnx.state(lora_model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(lora_model, sharded_state)

  return lora_model

### Setting Up Training Infrastructure

Now we configure the training loop. There are two main configurations:

**1. Checkpointing Options:**
```python
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=500,  # Save every 500 training steps
    max_to_keep=4             # Keep only the 4 most recent checkpoints
)
```

This saves the model periodically so you can:
- Resume training if it crashes
- Go back to earlier versions if something goes wrong
- Evaluate intermediate checkpoints

**2. Metrics Logging:**
```python
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tmp/tensorboard/grpo",  # Where to save logs
    flush_every_n_steps=20                         # Write to disk every 20 steps
)
```

This logs training metrics to TensorBoard format. You can visualize:
- Loss curves
- Reward trends
- KL divergence
- Learning rate

Weights & Biases (wandb) will also track these metrics in a nicer web interface.

### Understanding the Optimizer Setup

This cell creates the **optimizer** - the algorithm that updates model weights.

**AdamW Optimizer:**
```python
optimizer = optax.adamw(
    learning_rate=...,      # How big are the update steps
    b1=0.9,                 # Decay rate of the running average of gradients (1st moment)
    b2=0.99,                # Decay rate of the running average of squared gradients (2nd moment)
    weight_decay=0.1,       # Decoupled weight decay (the "W" in AdamW), not an L2 loss term
)
```

**Why AdamW?**
- Most popular optimizer for transformers
- Adaptive learning rates per parameter
- The "W" means weight decay is decoupled (important technical detail)

**Learning Rate Schedule:**
```python
optax.warmup_cosine_decay_schedule(
    init_value=0.0,            # Start at LR = 0
    peak_value=LEARNING_RATE,  # Ramp up to 3e-6
    warmup_steps=187,          # ~10% of MAX_STEPS (1869) to reach peak
    decay_steps=MAX_STEPS,     # Total schedule length, INCLUDING warmup
    end_value=0.0,             # End at LR = 0
)
```

**Visual:**
```
LR
^
|      /\
|     /  \
|    /    \
|   /      \
|  /        \
| /          \_____
+--------------------> steps
  warmup  decay
```

**Why this schedule?**
- **Warmup**: Start slow to avoid wild early updates
- **Peak**: Maximum learning rate
- **Decay**: Fine-tune with smaller adjustments

**Gradient Clipping:**
```python
optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=0.1),  # First: clip gradients
    optimizer,                                   # Then: apply AdamW
)
```

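If the global L2 norm of all gradients combined exceeds 0.1, every gradient is rescaled by the same factor so that the norm equals 0.1. This is especially useful in RL, where policy-gradient estimates are noisy:
- Prevents exploding gradients
- Stabilizes training
- Limits how far a single step can move the policy, which indirectly helps keep the KL divergence in check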
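
A pure-Python sketch of global-norm clipping, the same idea as `optax.clip_by_global_norm`. Gradients are represented here as a flat dict of lists for simplicity (the key names are hypothetical):

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by one common factor so their joint L2 norm is <= max_norm."""
    global_norm = math.sqrt(sum(g * g for leaf in grads.values() for g in leaf))
    if global_norm <= max_norm:
        return grads, global_norm
    scale = max_norm / global_norm
    return {k: [g * scale for g in leaf] for k, leaf in grads.items()}, global_norm


grads = {"lora_a": [0.3], "lora_b": [0.4]}
clipped, norm = clip_by_global_norm(grads, max_norm=0.1)
print(norm, clipped)  # global norm ~0.5, so both entries are scaled by ~0.2
```

Because one factor is shared by all parameters, clipping changes the step size but not the update direction.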

### Configuring the Training Cluster and GRPO Algorithm

This is where we bring everything together.

---

## ClusterConfig - Orchestrating Distributed Training

```python
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,      # The trainable policy model
        rl_cluster_lib.Role.REFERENCE: mesh,  # The frozen reference model
        rl_cluster_lib.Role.ROLLOUT: mesh,    # The model that generates responses
    },
    rollout_engine='vanilla',    # Standard generation method
    offload_to_cpu=False,        # Keep everything on TPU (faster but more memory)
    training_config=...,
    rollout_config=...,
)
```

**What are these roles?**
- **ACTOR**: The model whose parameters we update (policy model with LoRA)
- **REFERENCE**: The frozen baseline for KL divergence calculation
- **ROLLOUT**: The model that generates responses during training

In our setup, ACTOR and ROLLOUT are the same model (the LoRA policy).

**Why separate meshes?**
In larger setups, you might have different hardware for each role. Here they all share the same 4-TPU mesh.

---

## RolloutConfig - How to Generate Training Samples

```python
rollout_config=base_rollout.RolloutConfig(
    max_tokens_to_generate=512,      # Generate up to 512 tokens
    max_prompt_length=256,           # Input prompts up to 256 tokens
    kv_cache_size=1024,              # Cache size for efficient generation
    temperature=0.9,                  # High temperature for diversity
    top_p=1.0,                        # Use all probability mass
    top_k=50,                         # Sample from top 50 tokens
    eos_tokens=[1, 106],             # Stop tokens
)
```

These are the generation settings used DURING TRAINING (not evaluation). Note the high temperature (0.9) for diversity.

---

## GRPOConfig - The Core Algorithm Settings

```python
grpo_config = GRPOConfig(
    num_generations=4,    # G: Generate 4 responses per prompt
    num_iterations=1,     # μ: One gradient update per batch
    beta=0.08,            # β: KL penalty coefficient
    epsilon=0.2,          # ε: PPO clipping range
)
```

These directly correspond to the GRPO paper:
- **num_generations (G)**: The "group" size. We generate 4 different solutions per problem.
- **num_iterations (μ)**: How many times to reuse the same batch. Usually 1.
- **beta (β)**: KL divergence penalty weight. Balances reward optimization vs staying close to reference.
- **epsilon (ε)**: Clipping range for policy ratio. Prevents too-large updates.

Now we load reference and policy Gemma models using the Flax NNX library and display their structures.

### Creating the GRPO Trainer

Now we instantiate the actual training objects.

**RLCluster - The Training Infrastructure:**
```python
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,      # Our trainable model
    reference=ref_model,    # The frozen baseline
    tokenizer=tokenizer,    # Converts text ↔ tokens
    cluster_config=cluster_config,
)
```

This bundles together:
- The models
- The tokenizer
- The distributed computing setup
- The generation settings

**GRPOLearner - The Training Algorithm:**
```python
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,        # Reward 1: Exact format
        match_format_approximately,  # Reward 2: Partial format
        check_answer,                # Reward 3: Answer correctness
        check_numbers,               # Reward 4: Number extraction
    ],
    grpo_config=grpo_config,
)
```

**The reward functions are passed as a list!**
During training, for each generated response:
1. All 4 reward functions are called
2. Their scores are summed
3. This total becomes the reward for that response

**Example:**
```
Response: "<reasoning>Let me think...</reasoning><answer>42</answer>"
- match_format_exactly: +3.0
- match_format_approximately: +2.0
- check_answer: +3.0 (if 42 is correct)
- check_numbers: +1.5
Total reward: 9.5
```

**The Weights & Biases prompt:**
The output shows W&B initialization. You're asked to either:
- Use anonymous mode (quick experiment, no account needed)
- Log in with your API key (persistent tracking)

W&B will create a dashboard showing training metrics in real-time.

In [14]:
# Reference model
if model_family == "gemma3":
  ref_model, mesh, model_config = get_gemma_ref_model(
      ckpt_path=os.path.join(INTERMEDIATE_CKPT_DIR, "state")
  )



### The Training Loop - Where the Magic Happens

```python
with mesh:
    grpo_trainer.train(train_dataset)
```

This single line does ALL the training! Let me break down what happens inside:

---

## What Happens Each Training Step

**1. Sample a Batch**
- Get 4 prompts (questions) from the dataset

**2. Generate Multiple Responses (ROLLOUT)**
For each of the 4 prompts:
- Generate NUM_GENERATIONS=4 different responses
- Total: 4 prompts × 4 generations = 16 responses

**3. Compute Rewards**
For each of the 16 responses:
- Run all 4 reward functions
- Sum the scores
- Get a single reward value per response

**4. Calculate Advantages (GRPO's Key Innovation)**
For each prompt's 4 responses:
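- Compute mean reward: `μ = mean([r1, r2, r3, r4])`
- Compute std: `σ = std([r1, r2, r3, r4])`
- Advantage for response i: `A_i = (r_i - μ) / σ` (in practice a small constant is added to `σ`, so a group whose four rewards are identical gets zero advantage instead of a division by zero)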

This is **group-relative** - we compare within the group, not to absolute scores.
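
A pure-Python sketch of the group-relative advantage. The population standard deviation and `eps=1e-4` are assumptions (implementations vary on both), and the reward values are made up:

```python
import statistics


def group_relative_advantages(rewards, num_generations, eps=1e-4):
    """Normalize each prompt's rewards within its group: (r - mean) / (std + eps).

    rewards is laid out prompt by prompt: [p0_g0, p0_g1, ..., p1_g0, ...].
    """
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mu = statistics.fmean(group)
        sigma = statistics.pstdev(group)
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


# Two prompts x 4 generations.
rewards = [9.5, 2.0, 2.0, 0.5,   # prompt 0: one clearly better response
           3.0, 3.0, 3.0, 3.0]   # prompt 1: all identical -> no learning signal
adv = group_relative_advantages(rewards, num_generations=4)
print([round(a, 2) for a in adv])
```

Each group's advantages sum to zero, so within a prompt some responses are pushed up and others down; a group of identical rewards contributes nothing.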

**5. Compute Policy Gradient**
For each response:
- What's the probability the model assigned to this response?
- If advantage > 0: increase this probability
- If advantage < 0: decrease this probability

**6. Add KL Penalty**
- Compute KL divergence between policy and reference
- Subtract β × KL from the objective
- This prevents the model from changing too drastically
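
The GRPO paper estimates the KL term per token as `π_ref/π_θ - log(π_ref/π_θ) - 1`, which is never negative and is 0 when the two models agree. A pure-Python sketch working from log-probabilities; the function names are hypothetical and Tunix's internals may differ:

```python
import math


def per_token_kl(logp_policy, logp_ref):
    """Per-token KL estimate r - log(r) - 1, with r = pi_ref / pi_theta."""
    kls = []
    for lp, lr in zip(logp_policy, logp_ref):
        log_ratio = lr - lp  # log(pi_ref / pi_theta)
        kls.append(math.exp(log_ratio) - log_ratio - 1.0)
    return kls


def penalized_objective(surrogate, kl, beta=0.08):
    """Objective to maximize: policy surrogate minus beta * KL (beta as in GRPOConfig)."""
    return surrogate - beta * kl


same = per_token_kl([-1.2, -0.5], [-1.2, -0.5])
drift = per_token_kl([-0.1], [-2.0])  # policy far more confident than the reference
print(same, drift)  # [0.0, 0.0] and a positive value (~1.05)
```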

**7. Clip the Update (PPO-style)**
- If the probability ratio `π_new / π_old` moves outside `[1 - ε, 1 + ε]`, it is clipped
- This stabilizes training
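
A pure-Python sketch of the PPO-style clipped objective for one token, with `epsilon=0.2` as in `GRPOConfig`. The function name is hypothetical:

```python
import math


def clipped_surrogate(logp_new, logp_old, advantage, epsilon=0.2):
    """Per-token objective to maximize:
    min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A), with ratio = pi_new / pi_old.
    """
    ratio = math.exp(logp_new - logp_old)
    clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
    return min(ratio * advantage, clipped_ratio * advantage)


# Positive advantage: gains stop once the ratio exceeds 1 + eps.
print(clipped_surrogate(math.log(1.5), 0.0, advantage=1.0))   # ~1.2, not 1.5
# Negative advantage: the pessimistic min keeps the unclipped (worse) value.
print(clipped_surrogate(math.log(1.5), 0.0, advantage=-1.0))  # ~-1.5
```

With `num_iterations=1`, the old and current policies coincide when the loss is computed, so the ratio is exactly 1 and clipping only starts to matter when a batch is reused (μ > 1).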

**8. Update Parameters**
- Compute gradients
- Clip gradients (MAX_GRAD_NORM=0.1)
- Apply AdamW optimizer
- Only LoRA parameters are updated!

---

## Interpreting the Training Output

The output shows progress logs with various metrics:

**Key Metrics to Watch:**

- **`reward/score`**: Average total reward. Should increase over time!
- **`reward/check_answer`**: Correctness reward. Want this to go up.
- **`reward/match_format_exactly`**: Format compliance. Want this to go up.
- **`kl_divergence`**: How much the model has changed. Should stay reasonable (< 1-5).
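- **`loss`**: The optimization objective. Unlike a supervised loss, the GRPO loss need not decrease steadily (advantages are re-normalized within each group every batch), so judge progress mainly by the reward curves.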

**Warning Signs:**
- KL divergence exploding (>10): Model is changing too much
- Rewards not improving: Reward functions might not be informative
- Loss oscillating wildly: Learning rate might be too high

**The "START/END" print blocks:**
These come from the `check_numbers` reward function - showing one example per batch to help debug.

---

## Training Time

The first few steps are slow because:
1. JIT compilation (JAX compiles Python to optimized code)
2. Cache warming
3. Memory allocation

After warmup, each step should be faster (maybe 1-2 minutes per step on v5e-8 TPU).

Training runs for ~1869 steps, so the per-step time dominates the total: at 2 min/step that is ~62 hours, at 1 min/step ~31 hours. Check the actual step time in the logs after warmup to get a real estimate.

In [15]:
# Policy model
lora_policy = get_lora_model(ref_model, mesh=mesh)
# nnx.display(lora_policy)

## Post-Training Evaluation

Now we evaluate our trained model to see how much it improved!

### Loading the Trained Checkpoint

```python
# Find the latest checkpoint
latest_step = ...  # Finds "1869" (the final step)

# Load just the LoRA parameters
trained_lora_params = checkpointer.restore(
    trained_ckpt_path, 
    target=abs_params  # Only load LoRA params, not the full model
)

# Update our model with the trained weights
nnx.update(lora_policy, trained_lora_params)
```

**Why only load LoRA parameters?**
- The base model (Gemma) didn't change
- Only the LoRA adapter weights were trained
- This is much faster and uses less memory

**The `wandb.init()` call:**
There's a logging bug in Tunix that requires initializing W&B again. This is just a workaround.

**The warning about sharding:**
"Sharding info not provided when restoring" appears because `abs_params` describes only shapes and dtypes, so Orbax falls back to the sharding information saved alongside the checkpoint. This is expected here and does not change the restored values.

### Creating a New Sampler with Trained Weights

We need to create a new sampler that uses our newly trained LoRA parameters. The settings are the same as before - we just updated the model weights.

### Final Evaluation Results - Interpreting the Improvement

**Post-Training Results:**
```
corr=173, total=400, accuracy=43.25%, partial_accuracy=46.0%, format_accuracy=95.5%
```

Let's compare to our pre-training baseline:

| Metric | Before Training | After Training | Change (percentage points) |
|--------|----------------|----------------|-------------|
| **Exact Accuracy** | 14.5% | **43.25%** | **+28.75** |
| **Partial Accuracy** | 16.0% | **46.0%** | **+30.0** |
| **Format Accuracy** | 46.25% | **95.5%** | **+49.25** |
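The differences in the table are absolute changes in percentage points, not relative gains. The sketch below recomputes them from the reported percentages, along with the underlying counts out of the 400 test questions and the relative (×) change. The helper name is illustrative.

```python
TOTAL = 400  # number of test questions evaluated

# Percentages copied from the pre- and post-training evaluation runs.
before = {"exact": 14.5, "partial": 16.0, "format": 46.25}
after = {"exact": 43.25, "partial": 46.0, "format": 95.5}


def compare(before, after, total):
    rows = {}
    for metric in before:
        rows[metric] = {
            "before_count": round(before[metric] / 100 * total),
            "after_count": round(after[metric] / 100 * total),
            "delta_pp": after[metric] - before[metric],  # percentage points
            "ratio": after[metric] / before[metric],  # relative change
        }
    return rows


for metric, row in compare(before, after, TOTAL).items():
    print(
        f"{metric:>7}: {row['before_count']} -> {row['after_count']} / {TOTAL}, "
        f"{row['delta_pp']:+.2f} pp, {row['ratio']:.2f}x"
    )
```

The exact-accuracy ratio comes out at about 2.98, which is where the "3x" figure below comes from.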

---

## What These Results Mean

### Accuracy: 14.5% → 43.25% (3x improvement!)
- Before: 58 out of 400 correct
- After: 173 out of 400 correct
- The model now solves nearly half of grade-school math problems correctly!
- This is a **massive improvement** from just ~1869 training steps

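### Partial Accuracy: 16.0% → 46.0%
- Partial accuracy counts answers within 10% of the correct value, so it includes the exact matches
- The gap between partial and exact accuracy (the near misses) only grew from 1.5 to 2.75 percentage points
- Most of the gain therefore comes from fully correct answers rather than from closer wrong ones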

### Format Accuracy: 46.25% → 95.5%
- This is perhaps the most dramatic change
- Before: Less than half followed the required format
- After: 95.5% follow the `<reasoning>...<answer>...` structure
- The model learned our custom format almost perfectly!

---

## Why Did This Work?

**1. Reward Shaping:**
- We rewarded both format AND correctness
- The model learned both skills simultaneously
- The approximate format reward gave dense feedback

**2. Group Relative Learning:**
- Comparing within groups is more informative than absolute scores
- The model learned "this type of solution is better than that type"
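- Even when every solution in a group is wrong, differences in format and partial-credit rewards still show which responses were better; only a group whose rewards are all identical provides no signal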
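A minimal sketch of the group-relative advantage, assuming the usual GRPO normalization: subtract the group mean reward, then divide by the group standard deviation plus a small epsilon. Whether Tunix uses the population or sample standard deviation, and which epsilon it uses, are assumptions here, and the reward values are illustrative.

```python
import statistics


def group_advantages(rewards, eps=1e-4):
    """Normalize one group's rewards: (r - mean) / (std + eps)."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std (an assumption)
    return [(r - mean) / (std + eps) for r in rewards]


# Four completions for one prompt; total rewards are illustrative.
all_wrong_but_different = [2.0, -2.0, 0.5, -0.5]  # e.g. format rewards differ
all_identical = [1.0, 1.0, 1.0, 1.0]

print(group_advantages(all_wrong_but_different))  # nonzero: a learning signal
print(group_advantages(all_identical))  # all zeros: no signal from this group
```

The first group yields non-zero advantages even though no answer is correct. The second yields all zeros, so that prompt contributes no policy-gradient signal.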

**3. KL Regularization:**
- Kept the model close to the reference so it did not "forget" how to generate coherent text
- Discouraged reward hacking (gaming the format without improving reasoning)
- Stable training with gradual improvements
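To make the KL term concrete: the GRPO paper estimates the per-token KL divergence from the reference model as `pi_ref/pi_theta - log(pi_ref/pi_theta) - 1`. This quantity is never negative and is zero when the two models agree. The sketch below applies that formula to per-token log-probabilities in plain Python. Whether Tunix computes its KL term in exactly this form is an assumption. In the loss, the penalty is scaled by `beta` (the `BETA` passed to `GRPOConfig`).

```python
import math


def per_token_kl(policy_logps, ref_logps):
    """KL estimate exp(r) - r - 1, with r = log pi_ref - log pi_theta per token."""
    kls = []
    for lp_policy, lp_ref in zip(policy_logps, ref_logps):
        r = lp_ref - lp_policy
        kls.append(math.exp(r) - r - 1.0)
    return kls


policy = [-0.2, -1.5, -3.0]  # log-probs of sampled tokens under the policy
reference = [-0.2, -1.0, -4.0]  # same tokens under the frozen reference model
print(per_token_kl(policy, reference))
```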

**4. Gradient Clipping:**
- Prevented catastrophic updates
- Allowed small, consistent improvements
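`optax.clip_by_global_norm`, used in the optimizer cell, rescales *all* gradients by the same factor when their combined L2 norm exceeds `MAX_GRAD_NORM`. This preserves the direction of the update and only shrinks its size. Here is a plain-Python sketch on a flat list of gradient values; the real transform works on a whole pytree of arrays.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale every gradient by max_norm / global_norm if the norm is too large."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= max_norm:
        return list(grads)
    scale = max_norm / global_norm
    return [g * scale for g in grads]


print(clip_by_global_norm([3.0, 4.0], max_norm=1.0))  # norm 5 -> rescaled to norm 1
print(clip_by_global_norm([0.3, 0.4], max_norm=1.0))  # norm 0.5 -> unchanged
```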

---

## Is This a Good Model?

**43% accuracy might seem low**, but consider:
- This is a 1B parameter model (small by today's standards)
- GSM8K is genuinely challenging (requires multi-step reasoning)
- Frontier models such as GPT-4 report around 90% or more on GSM8K, but they are far larger (GPT-4's parameter count has not been published)
- We trained for only ONE epoch with limited compute

**What could improve results:**
- More training epochs or steps
- Larger NUM_GENERATIONS (more diverse samples)
- Better reward functions (e.g., verify intermediate steps)
- Larger LoRA rank
- Curriculum learning (start with easy problems)

---

## Key Takeaway

**GRPO successfully taught the model to:**
1. ✅ Follow a specific output format (95.5% compliance)
2. ✅ Reason through math problems (3x better accuracy)
3. ✅ Produce structured, parseable outputs

This demonstrates the power of reinforcement learning for improving specific capabilities in language models!

## Summary and Next Steps

**What We Accomplished:**
- Trained a model with GRPO (Group Relative Policy Optimization) using Tunix's `GRPOLearner`
- Fine-tuned Gemma 3 1B on GSM8K math problems
- Improved accuracy from 14.5% to 43.25% (3x improvement)
- Achieved 95.5% format compliance (from 46.25%)

**Key Concepts Learned:**
1. **GRPO**: RL algorithm that compares responses within groups
2. **LoRA**: Efficient fine-tuning by training small adapter layers
3. **Reward Shaping**: Multiple reward functions to guide learning
4. **KL Divergence**: Regularization to prevent model drift
5. **Distributed Training**: Using TPU mesh for parallel computation

**Potential Next Steps:**
- Increase training time (more epochs, more steps)
- Experiment with hyperparameters (beta, learning rate, NUM_GENERATIONS)
- Try different reward functions (reward chain-of-thought quality)
- Apply to other tasks (code generation, other reasoning tasks)
- Compare with other post-training methods (for example, PPO or DPO)

**Congratulations!** You've successfully trained a language model using reinforcement learning!

## Define reward functions

We define four reward functions:

- `match_format_exactly` rewards outputs whose format exactly matches the instruction given in `TEMPLATE`;
- `match_format_approximately` rewards outputs whose format approximately matches it;
- `check_answer` rewards answers that are correct or partially correct;
- `check_numbers` handles answers where the text between `<answer>` and `</answer>` is not a single number: it extracts the number and rewards the model if it is correct.

The reward functions are inspired by
[this gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb).

First, let's define a regular expression that checks whether the output format matches.

In [16]:
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

match_format.search(
    f"{reasoning_start}Let me"
    f" think!{reasoning_end}{solution_start}2{solution_end}",
)

<re.Match object; span=(0, 54), match='<reasoning>Let me think!</reasoning><answer>2</an>
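The `re.MULTILINE` flag makes `^` and `$` match at every line boundary, not just at the start and end of the whole string. The self-contained sketch below assumes the tag values visible in the match above (`<reasoning>`, `</reasoning>`, `<answer>`, and the closing `</answer>`). It shows the practical effect: a preamble or trailer on its own line is tolerated, but extra text on the same line as the tags is not.

```python
import re

# Assumed tag values (consistent with the match shown above).
reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"

match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

body = "<reasoning>2 + 2 = 4</reasoning><answer>4</answer>"
cases = {
    "exact": body,
    "preamble_own_line": "Sure!\n" + body,
    "preamble_same_line": "Sure! " + body,
    "trailer_own_line": body + "\nHope this helps!",
    "trailer_same_line": body + " Hope this helps!",
    "no_answer_tags": "<reasoning>2 + 2 = 4</reasoning> The answer is 4.",
}
for name, text in cases.items():
    m = match_format.search(text)
    print(f"{name:>18}: {m.group(1) if m else None}")
```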

Give the model a reward of 3 points if the format matches exactly.

In [17]:
def match_format_exactly(prompts, completions, **kwargs):
  return [
      0 if match_format.search(response) is None else 3.0
      for response in completions
  ]

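We also give partial credit when the format matches only partially: each of the four tags earns 0.5 points if it appears exactly once and loses 0.5 points if it is missing or repeated.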

In [18]:
def match_format_approximately(prompts, completions, **kwargs):
  scores = []

  for completion in completions:
    score = 0
    response = completion
    # +0.5 for each tag that appears exactly once; -0.5 if it is missing
    # or repeated, so the total ranges from -2.0 to 2.0
    score += 0.5 if response.count(reasoning_start) == 1 else -0.5
    score += 0.5 if response.count(reasoning_end) == 1 else -0.5
    score += 0.5 if response.count(solution_start) == 1 else -0.5
    score += 0.5 if response.count(solution_end) == 1 else -0.5
    scores.append(score)
  return scores
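A condensed, single-string equivalent of `match_format_approximately` shows the score range. It assumes the same tag values; the function name is illustrative. The third example mirrors a training sample printed later in this notebook, where the model wrote a stray `</answer>` inside its reasoning.

```python
TAGS = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")  # assumed values


def approx_format_score(response):
    # +0.5 per tag seen exactly once, -0.5 per tag missing or repeated
    return sum(0.5 if response.count(tag) == 1 else -0.5 for tag in TAGS)


well_formed = "<reasoning>2 + 2 = 4</reasoning>\n<answer>4</answer>"
no_tags = "The answer is 4."
stray_close = "<reasoning>... $3.00. </answer>\n</reasoning>\n<answer>3.00</answer>"

print(approx_format_score(well_formed))  # 2.0
print(approx_format_score(no_tags))  # -2.0
print(approx_format_score(stray_close))  # 1.0: "</answer>" appears twice
```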

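Reward the model if the answer is correct: an exact string match earns 3.0 points, and a match after stripping whitespace earns 1.5. Otherwise, partial credit depends on how close the answer is to the correct value: 0.5 if the ratio is within 10%, 0.25 if within 20%, and a penalty of -1.0 beyond that. Non-numeric answers are penalized by -0.5, and outputs that fail the format check get 0.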

In [19]:
def check_answer(prompts, completions, answer, **kwargs):
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_format.search(r)) is not None else None
      for r in responses
  ]

  scores = []
  assert len(extracted_responses) == len(
      answer
  ), f"{extracted_responses} and {answer} have mismatching length"
  for guess, true_answer in zip(extracted_responses, answer):
    score = 0
    if guess is None:
      scores.append(0)
      continue
    # Exact string match gets 3 points!
    if guess == true_answer:
      score += 3.0
    # Match after stripping surrounding whitespace
    elif guess.strip() == true_answer.strip():
      score += 1.5
    else:
      # Otherwise, give partial credit if the answer is numerically close,
      # i.e., if the ratio guess / true_answer falls within some range
      try:
        ratio = float(guess) / float(true_answer)
        if 0.9 <= ratio <= 1.1:
          score += 0.5
        elif 0.8 <= ratio <= 1.2:
          score += 0.25
        else:
          score -= 1.0  # Penalize wrong answers
      except (ValueError, ZeroDivisionError):
        score -= 0.5  # Penalize non-numeric answers
    scores.append(score)
  return scores
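To see the tiers of `check_answer` in isolation, here is a condensed per-answer version of its scoring logic (the function name is hypothetical). Note that the first two tiers compare strings. An answer of `3.00` against a reference of `3` (the case in the first training sample printed later) therefore earns only the 0.5 closeness credit here, while `check_numbers`, which compares floats, counts it as correct.

```python
def closeness_score(guess, true_answer):
    """Per-answer score mirroring check_answer's tiers (guess already extracted)."""
    if guess is None:  # format did not match, nothing extracted
        return 0.0
    if guess == true_answer:
        return 3.0
    if guess.strip() == true_answer.strip():
        return 1.5
    try:
        ratio = float(guess) / float(true_answer)
    except (ValueError, ZeroDivisionError):
        return -0.5  # not a number (or a zero reference answer)
    if 0.9 <= ratio <= 1.1:
        return 0.5
    if 0.8 <= ratio <= 1.2:
        return 0.25
    return -1.0


examples = [("42", "42"), (" 42 ", "42"), ("3.00", "3"), ("40", "42"),
            ("35", "42"), ("10", "42"), ("forty-two", "42"), (None, "42")]
for guess, truth in examples:
    print(f"{guess!r:>12} vs {truth!r}: {closeness_score(guess, truth)}")
```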

Sometimes, the text between `<answer>` and `</answer>` might not be one
number; it can be a sentence. So, we extract the number and compare the answer.

In [20]:
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)
match_numbers.findall(f"{solution_start}  0.34  {solution_end}")

['0.34']
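The character class `[\d\.]` accepts only digits and dots, and `search` returns the first such run after `<answer>`. A self-contained sketch, assuming `solution_start = "<answer>"`, shows the consequences: a minus sign and thousands separators are dropped, and a sentence containing several numbers yields the first one. The helper `extract_number` is illustrative.

```python
import re

solution_start = "<answer>"  # assumed tag value

match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)


def extract_number(text):
    """Return the first digit/dot run after <answer>, or None."""
    m = match_numbers.search(text)
    return m.group(1) if m else None


for text in [
    "<answer>The total is $18.</answer>",
    "<answer>-5</answer>",
    "<answer>1,234</answer>",
    "<answer>2 apples cost 18 dollars</answer>",
]:
    print(repr(text), "->", repr(extract_number(text)))
```

A trailing dot is harmless (`float("18.")` is `18.0`), but a model answer of `-5` is read as `5`, and `1,234` as `1`.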

In [21]:
def check_numbers(prompts, completions, answer, **kwargs):
  question = kwargs["question"]
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_numbers.search(r)) is not None else None
      for r in responses
  ]

  scores = []
  print("START ============================")
  print(f"Question: {question[0]}")
  print(f"Answer: {answer[0]}")
  print(f"Response: {responses[0]}")
  print(f"Extracted: {extracted_responses[0]}")
  print("END ==============================")
  for guess, true_answer in zip(extracted_responses, answer):
    if guess is None:
      scores.append(0)
      continue
    # Convert to numbers and compare numerically (so "3.00" matches "3")
    try:
      true_answer = float(true_answer.strip())
      guess = float(guess.strip())
      scores.append(1.5 if guess == true_answer else 0.0)
    except ValueError:
      scores.append(0)
  return scores

## Evaluate


Before training, let's evaluate the model on the test set so that we can
measure the improvement after training.

We evaluate it in two ways:

**Quantitative**

* **Answer Accuracy**: percentage of samples for which the model predicts the
correct final numerical answer (the first number after the `<answer>` tag)
* **Answer (Partial) Accuracy**: percentage of samples for which the model
predicts a final numerical answer such that the `model answer / answer`
ratio lies between 0.9 and 1.1
* **Format Accuracy**: percentage of samples for which the model outputs the
correct format, i.e., reasoning between the `<reasoning>` and `</reasoning>`
tags, and the final answer between the `<answer>` and `</answer>` tags

**Qualitative**

We'll also print outputs for a few given questions so that we can compare the generated output later.


We define a helper function to generate an answer, given a prompt.

In [22]:
def generate(
    question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None
):
  """Generates a completion for one question (str) or a list of questions."""

  questions = [question] if isinstance(question, str) else question
  input_batch = [
      TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=q)
      for q in questions
  ]

  out_data = sampler(
      input_strings=input_batch,
      max_generation_steps=768,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
      seed=seed,
      # Stop on Gemma's <eos> (id 1) or <end_of_turn> (id 106) token
      eos_tokens=[1, 106],
  )

  output = out_data.text
  if isinstance(question, str):
    return output[0]
  return output

Another helper function for evaluation.

In [23]:
def evaluate(
    dataset,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """Computes accuracy and percentage of outputs matching the format."""

  response_lst = []
  corr = 0
  partially_corr = 0
  corr_format = 0
  total = 0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]

    multiple_call_responses = [[] for _ in range(len(questions))]
    for p in range(num_passes):
      responses = generate(
          questions, sampler, temperature, top_k, top_p, seed=p
      )
      for idx, response in enumerate(responses):
        multiple_call_responses[idx].append(response)

    for question, multiple_call_response, answer in zip(
        questions, multiple_call_responses, answers
    ):
      # check answer
      corr_ctr_per_question = 0
      partially_corr_per_question = 0
      corr_format_per_question = 0
      for response in multiple_call_response:
        extracted_response = (
            guess.group(1)
            if (guess := match_numbers.search(response)) is not None
            else "-1000000"
        )
        try:
          if float(extracted_response.strip()) == float(answer.strip()):
            corr_ctr_per_question += 1

          ratio = float(extracted_response.strip()) / float(answer.strip())
          if 0.9 <= ratio <= 1.1:
            partially_corr_per_question += 1
        except (ValueError, ZeroDivisionError):
          print("SKIPPED")  # non-numeric answer or zero reference answer

        # check format
        if match_format.search(response) is not None:
          corr_format_per_question += 1

        if (
            corr_ctr_per_question > 0
            and partially_corr_per_question > 0
            and corr_format_per_question > 0
        ):
          break

      if corr_ctr_per_question > 0:
        corr += 1
        if corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      else:
        if not corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      if partially_corr_per_question > 0:
        partially_corr += 1
      if corr_format_per_question > 0:
        corr_format += 1

      total += 1
      if total % 10 == 0:
        print(
            f"===> {corr=}, {total=}, {corr / total * 100=}, "
            f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
        )

  to_return = (
      corr,
      total,
      corr / total * 100,
      partially_corr / total * 100,
      corr_format / total * 100,
  )
  if make_lst:
    return to_return, response_lst
  return to_return
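With `num_passes > 1`, `evaluate` counts a question as correct if *any* of its sampled responses is correct. The reported accuracy is therefore a pass@k-style metric rather than an average over samples. With a single pass (the default), the two coincide. The sketch below contrasts the two on hypothetical per-pass outcomes; the helper names are illustrative.

```python
def any_pass_accuracy(results):
    """Percent of questions with at least one correct pass (what evaluate reports)."""
    return sum(any(passes) for passes in results) / len(results) * 100


def mean_pass_accuracy(results):
    """Percent of individual responses that are correct."""
    flat = [ok for passes in results for ok in passes]
    return sum(flat) / len(flat) * 100


# Hypothetical outcomes: 4 questions x 3 passes (True = correct answer).
results = [
    [True, False, False],
    [False, False, False],
    [True, True, True],
    [False, True, False],
]
print(any_pass_accuracy(results))  # 75.0
print(mean_pass_accuracy(results))
```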

In [24]:
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

Now let's see how the original model does on the test set. You can see the percentages of the model outputs that are fully correct, partially correct, and correctly formatted. The following step might take a couple of minutes to finish.

In [25]:
# The evaluation might take up to a couple of minutes to finish. Please be patient.

(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["greedy"],
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
    f" {format_accuracy=}%"
)

  0%|          | 0/100 [00:00<?, ?it/s]

===> corr=1, total=10, corr / total * 100=10.0, partially_corr / total * 100=10.0, corr_format / total * 100=30.0
===> corr=3, total=20, corr / total * 100=15.0, partially_corr / total * 100=15.0, corr_format / total * 100=35.0
===> corr=6, total=30, corr / total * 100=20.0, partially_corr / total * 100=20.0, corr_format / total * 100=33.33333333333333
===> corr=7, total=40, corr / total * 100=17.5, partially_corr / total * 100=20.0, corr_format / total * 100=40.0
===> corr=7, total=50, corr / total * 100=14.000000000000002, partially_corr / total * 100=18.0, corr_format / total * 100=42.0
===> corr=9, total=60, corr / total * 100=15.0, partially_corr / total * 100=18.333333333333332, corr_format / total * 100=40.0
===> corr=12, total=70, corr / total * 100=17.142857142857142, partially_corr / total * 100=20.0, corr_format / total * 100=40.0
===> corr=14, total=80, corr / total * 100=17.5, partially_corr / total * 100=20.0, corr_format / total * 100=41.25
SKIPPED
===> corr=15, total=90

## Train

Let's set up all the configs first - checkpointing, metric logging and training.
We then train the model.

In [26]:
# Ckpt saving
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)

# Metrics logger
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tmp/tensorboard/grpo", flush_every_n_steps=20
)

In [27]:
# Optimizer, learning rate scheduler, gradient clipping
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)
if MAX_GRAD_NORM is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
      optimizer,
  )
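`optax.schedules.warmup_cosine_decay_schedule` ramps the learning rate linearly from `init_value` to `peak_value` over `warmup_steps`, then follows a half cosine down to `end_value`. Its `decay_steps` is the total schedule length, including warmup. Here is a plain-Python sketch of the same shape. The numbers are illustrative, not this notebook's actual `LEARNING_RATE`, `WARMUP_STEPS` and `MAX_STEPS`.

```python
import math


def warmup_cosine(step, peak, warmup_steps, decay_steps, init=0.0, end=0.0):
    """Linear warmup to `peak`, then cosine decay to `end` at `decay_steps`."""
    if step < warmup_steps:
        return init + (peak - init) * step / warmup_steps
    progress = min(step - warmup_steps, decay_steps - warmup_steps)
    t = progress / (decay_steps - warmup_steps)  # 0 -> 1 over the decay phase
    return end + (peak - end) * 0.5 * (1 + math.cos(math.pi * t))


# Illustrative values only.
PEAK, WARMUP, TOTAL = 3e-6, 100, 1000
for step in (0, 50, 100, 550, 1000):
    print(step, warmup_cosine(step, PEAK, WARMUP, TOTAL))
```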

In [28]:
# Training config
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        # metrics logging
        metrics_logging_options=metrics_logging_options,
        # checkpoint saving
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        eos_tokens=[1,106],
    ),
)

grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

### Setting Up the GRPO Trainer

Now we initialize our system for training. First, we create an `RLCluster` instance, which brings together the **policy model (`actor`)**, a **reference model (`reference`)**, and a **tokenizer**. Our `actor` is a trainable LoRA model, while the `reference` is the frozen base model; it anchors the KL penalty that keeps the actor from drifting too far during training.

We then create a `GRPOLearner`, the specialized trainer that uses a list of **reward functions** to evaluate and optimize the model's output, completing the RL training setup.

Tunix trainers are integrated with [Weights & Biases](https://wandb.ai/) to help you visualize the training progress. You can choose how you want to use it:

**Option 1 (Type 1)**: If you're running a quick experiment or just testing things out, choose this. It creates a temporary, private dashboard right in your browser without requiring you to log in or create an account.

**Option 2 (Type 2)**: If you have an existing W&B account and want to save your project's history to your personal dashboard, choose this. You'll be prompted to enter your API key or log in.

In [29]:
# RL cluster
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

# GRPO Trainer
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    grpo_config=grpo_config,
)

[34m[1mwandb[0m: Currently logged in as: [33mwindmaple[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




The first couple of training steps might take up to 5 minutes to finish. Please be patient. If you experience very long training steps (e.g., more than 10 minutes per step), please file a bug report. Thank you!

In [30]:
with mesh:
  grpo_trainer.train(train_dataset)

Question: Maria has 4 dimes, 4 quarters, and 7 nickels in her piggy bank. Her mom gives her 5 quarters. How much money, in dollars, does Maria have now?
Answer: 3
Response: Okay, let's break this problem down.

<reasoning>
Maria initially had 4 dimes + 4 quarters + 7 nickels = 15 coins.  Her mom gives her 5 quarters, increasing her total to 15 + 5 = 20 coins.  The value of the dimes is 4 * $0.10 = $0.40. The value of the quarters is 4 * $0.25 = $1.00. The value of the nickels is 7 * $0.05 = $0.35.  The total value of her coins is $0.40 + $1.00 + $0.35 = $1.75.  After her mom gives her quarters, she now has 20 + 5 = 25 coins. The value of the quarters is now 5 * $0.25 = $1.25. The total value of her coins is $1.75 + $1.25 = $3.00. Therefore, Maria has $3.00. </answer>
</reasoning>
<answer>3.00</answer>
Extracted: 3.00


Actor Training:   0%|          | 0/3738 [00:00<?, ?step/s]

Question: Paddy's Confidential has 600 cans of stew required to feed 40 people. How many cans would be needed to feed 30% fewer people?
Answer: 420
Response: Okay, let's break this down.

<reasoning>
The problem states that Paddy's Confidential has 600 cans of stew to feed 40 people. We need to find out how many cans are needed to feed 30% fewer people. To calculate this, we first need to find the number of people needed to feed the reduced population. 30% of 40 is (30/100) * 40 = 0.3 * 40 = 12 people. Therefore, we need to feed 40 - 12 = 28 people. Since each can feeds 40 people, we need 28 / 40 = 0.7 cans. However, we need to consider that we're dealing with cans, and we can't have fractions of cans. We must round up to ensure we have enough. Therefore, we need 1 can. </reasoning>

<answer>1</answer>
Extracted: 1
Question: In Johnstown, the population is 80 and every single person drives by themselves to work. Each car on a road pollutes 10 pounds of carbon a year. A single bus pollu



W&B run history: sparkline charts for `actor/train/kl`, `actor/train/loss`, `actor/train/perplexity`, `actor/train/step_time_sec`, and `actor/train/steps_per_sec` (not reproducible as text).

W&B run summary:

| Metric | Value |
|--------|-------|
| `actor/train/kl` | 0.35771 |
| `actor/train/loss` | 0.1306 |
| `actor/train/perplexity` | 1.13951 |
| `actor/train/step_time_sec` | 111.62965 |
| `actor/train/steps_per_sec` | 0.00896 |
| `actor/train/tflops_per_step` | 14.0514 |
| `jax/core/compile/backend_compile_duration` | 1762646972.33086 |
| `jax/core/compile/jaxpr_to_mlir_module_duration` | 1762646970.82012 |
| `jax/core/compile/jaxpr_trace_duration` | 1762646968.38686 |
| `jax/orbax/write/replicated_array_gb` | 6e-05 |


## Evaluate

Let's evaluate our finetuned model!

In [31]:
# Load checkpoint first.
import re

# Find the latest checkpoint by listing directories in CKPT_DIR/actor
actor_ckpt_dir = os.path.join(CKPT_DIR, "actor")

latest_step = -1
if os.path.exists(actor_ckpt_dir):
  for item in os.listdir(actor_ckpt_dir):
    if os.path.isdir(os.path.join(actor_ckpt_dir, item)) and re.match(r'^\d+$', item):
      step = int(item)
      if step > latest_step:
        latest_step = step

if latest_step == -1:
  raise FileNotFoundError(f"No checkpoints found in {actor_ckpt_dir}")

print(f"Latest checkpoint step: {latest_step}")

wandb.init(project='tunix-eval')  # logging bug workaround

trained_ckpt_path = os.path.join(
    CKPT_DIR, "actor", str(latest_step), "model_params"
)

abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

nnx.update(
    lora_policy,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(lora_policy, nnx.LoRAParam),
        trained_lora_params,
    ),
)

Latest checkpoint step: 1869




In [32]:
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

In [33]:
# The evaluation might take up to a couple of minutes to finish. Please be patient.
(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["greedy"],
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
    f" {format_accuracy=}%"
)

  0%|          | 0/100 [00:00<?, ?it/s]

===> corr=3, total=10, corr / total * 100=30.0, partially_corr / total * 100=30.0, corr_format / total * 100=90.0
===> corr=9, total=20, corr / total * 100=45.0, partially_corr / total * 100=45.0, corr_format / total * 100=90.0
===> corr=14, total=30, corr / total * 100=46.666666666666664, partially_corr / total * 100=46.666666666666664, corr_format / total * 100=93.33333333333333
===> corr=18, total=40, corr / total * 100=45.0, partially_corr / total * 100=52.5, corr_format / total * 100=95.0
===> corr=22, total=50, corr / total * 100=44.0, partially_corr / total * 100=50.0, corr_format / total * 100=96.0
===> corr=28, total=60, corr / total * 100=46.666666666666664, partially_corr / total * 100=51.66666666666667, corr_format / total * 100=95.0
===> corr=34, total=70, corr / total * 100=48.57142857142857, partially_corr / total * 100=52.85714285714286, corr_format / total * 100=95.71428571428572
===> corr=39, total=80, corr / total * 100=48.75, partially_corr / total * 100=52.5, corr_

With sufficient training, the percentages of correct and well-formatted outputs should be clearly higher than the pre-training baseline, which shows that the training worked.