# Multi-Domain GRPO Training with Gemma 3 (1B) using Tunix

This notebook presents a **complete, end-to-end training pipeline** for teaching a **Gemma 3 (1B)** language model to **explicitly show its reasoning** before producing a final answer.

The core objective is not just answer accuracy, but **reasoning transparency**.  
The model is trained to consistently generate outputs in the following structured format:

```

<reasoning>
step-by-step reasoning trace
</reasoning>
<answer>
final answer
</answer>
```

Training is performed using **Group Relative Policy Optimization (GRPO)** via **Tunix**, Google's JAX-native post-training library. GRPO enables stable reinforcement-learning-style updates without requiring a separate value model, making it well-suited for **small models and limited TPU budgets**.
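To make "group relative" concrete: GRPO samples several completions per prompt (`NUM_GENERATIONS` in this notebook), scores each one, and normalizes every reward against the statistics of its own group. That normalized score serves as the baseline-corrected advantage, which is why no learned value model is needed. The NumPy sketch below illustrates only that normalization. It is a conceptual illustration, not Tunix's internal implementation, and the use of the population standard deviation plus a small epsilon is an assumption (implementations differ on these details).

```python
import numpy as np

def group_relative_advantages(rewards: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalize rewards within each group of completions.

    rewards: shape (num_prompts, num_generations), one scalar reward per completion.
    Returns advantages of the same shape: (r - group_mean) / (group_std + eps).
    """
    mean = rewards.mean(axis=1, keepdims=True)
    std = rewards.std(axis=1, keepdims=True)  # population std (ddof=0)
    return (rewards - mean) / (std + eps)

# One prompt, four sampled completions: two well-formatted (3.0), two not (0.0).
rewards = np.array([[3.0, 0.0, 3.0, 0.0]])
print(group_relative_advantages(rewards))
```

Completions scoring above their group's mean receive positive advantages and are reinforced, while those below it are pushed down. If every completion in a group receives the same reward, all advantages are zero and that group contributes no learning signal.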

---

## What This Notebook Demonstrates

* Fine-tuning **Gemma 3 (1B)** on a **single Kaggle TPU session**
* Multi-domain reasoning training across:

  * Math
  * Logic
  * Basic science
  * Coding (non-executed, verifiable by structure)
* Strict enforcement of **reasoning + answer separation**
* Reward shaping for:

  * Format compliance
  * Refusal avoidance
  * Correctness (where applicable)
  * Conciseness and termination discipline
* A **phase-based curriculum** inspired by Open-R1 / R1-style training

This notebook is intentionally written to be:

* **Minimal** (no YAML, no hidden configs)
* **Debuggable** (explicit reward functions)
* **Reproducible** (single-session training)
* **Notebook-native** (no external orchestration)

---

## Scope and Constraints

* **Model**: Gemma 3 (1B-IT)
* **Framework**: Tunix (JAX)
* **Hardware**: Kaggle TPU (single session)
* **Max output length**: < 1K tokens
* **Language**: English only
* **No tool use or code execution during training**

Coding tasks are evaluated **without running unit tests**. Instead, the model is rewarded for producing **well-formed Python code**, checked structurally rather than by execution, that follows the required format, aligning with the hackathon's LLM-as-a-judge evaluation setup.

---

## Why This Matters

Large reasoning models are expensive to train.
This notebook shows that **even a 1B-parameter model** can be taught to reason more transparently using:

* Careful reward design
* Curriculum-style training
* Strong format constraints

The result is a small, efficient model that **thinks before it answers**, making reasoning more accessible, interpretable, and reproducible.

In [None]:
import os
os.environ["HF_HUB_DISABLE_XET"] = "1"

## Install necessary libraries

In [None]:
!pip install -q kagglehub

!pip install -q ipywidgets

!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain
# !pip install "google-tunix[prod]==0.1.5"

!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
# !pip install -U flax
!pip install flax==0.12.0

!pip install -q datasets wandb==0.22.0

In [None]:
import wandb, os
from kaggle_secrets import UserSecretsClient
os.environ['WANDB_API_KEY'] = UserSecretsClient().get_secret("WANDB_API_KEY")

## Imports

In [None]:
import functools
import gc
import os
from pprint import pprint
import re

import csv
import shutil

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
from pathlib import Path
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import params
from tunix.models.gemma3 import model
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from datasets import load_dataset

## Hyperparameters

Let's define the configuration we are going to use. Note that this is by no
means a "perfect" set of hyperparameters. To get good results, you might have
to train the model for longer.

In [None]:
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

# ====== LoRA ======
RANK = 64
ALPHA = 64.0

# ====== Sharding ======
MESH = [(1, 4), ("fsdp", "tp")]

# ====== GRPO ======
# === Generation during GRPO training ===
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 512
# Important to keep a high-ish temperature for varied, diverse responses during
# training.
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
# The number of responses the policy samples for each prompt within a single
# training step. This corresponds to `G` in Algorithm 1 of the GRPO paper; the
# "group" in GRPO refers to this set of responses.
NUM_GENERATIONS = 4

# === other GRPO configs ===
# The number of iterations per batch (μ in GRPO Algorithm 1).
NUM_ITERATIONS = 1
# The coefficient for the KL divergence penalty (β) in the GRPO loss function.
# Keep this value high enough; otherwise, the KL divergence can grow
# unchecked.
BETA = 0.08
# Epsilon value for clipping (ε in the GRPO loss in the paper). As in PPO, this
# keeps updates stable.
EPSILON = 0.2

# ====== Training ======
TRAIN_MICRO_BATCH_SIZE = 4
# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
# NUM_BATCHES = 3738
NUM_BATCHES = 3
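# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. With GSM8K alone
# (1,319 test examples) and a batch size of 4, the maximum would be 330; the
# limit for the mixed dataset depends on the mixed dataset's length.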
NUM_TEST_BATCHES = 100

EVAL_EVERY_N_STEPS = 10  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
NUM_EPOCHS = 1  # can potentially train for more epochs

# Number of training steps.
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
# == Cosine decay with warmup scheduler ==
# Linearly increase the learning rate from 0 to `LEARNING_RATE` (3e-6) over the
# first 10% of training steps, then gradually decrease it to 0 with a cosine
# schedule.
WARMUP_STEPS = 0.1 * MAX_STEPS
# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/ckpts/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# ====== Inference ======
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}
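The optimizer comments above describe a linear warmup followed by cosine decay. The pure-Python sketch below reproduces that learning-rate shape so it can be inspected without a TPU. It mirrors the intent of a warmup-plus-cosine schedule; in practice an Optax schedule such as `optax.warmup_cosine_decay_schedule` would typically be used. The helper name `warmup_cosine_lr` is hypothetical.

```python
import math

def warmup_cosine_lr(step: int, peak_lr: float, warmup_steps: float, total_steps: int,
                     end_lr: float = 0.0) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay to end_lr at total_steps."""
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Fraction of the decay phase that has elapsed, clipped to [0, 1].
    decay_len = max(total_steps - warmup_steps, 1e-9)
    progress = min(max((step - warmup_steps) / decay_len, 0.0), 1.0)
    return end_lr + (peak_lr - end_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

# Example with a longer run than the demo's MAX_STEPS = 3.
total = 1000
for s in (0, 50, 100, 550, 1000):
    print(s, warmup_cosine_lr(s, peak_lr=3e-6, warmup_steps=0.1 * total, total_steps=total))
```

Gradient clipping with `MAX_GRAD_NORM` is a separate transformation. It is typically chained in front of the optimizer, for example with `optax.clip_by_global_norm`.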

## Utility functions

In [None]:
def show_hbm_usage():
  """Displays memory usage per device."""
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    used = stats["bytes_in_use"]
    limit = stats["bytes_limit"]
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

## Data preprocessing

First, let's define the tags that delimit each part of the output. We instruct
the model to first reason between the `<reasoning>` and `</reasoning>` tags.
After reasoning, we expect it to provide the answer between the `<answer>` and
`</answer>` tags.
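A completion in this format can be split back into its reasoning trace and final answer with a regular expression. The helper below, `split_reasoning_answer`, is a hypothetical illustration and not part of Tunix. It returns `None` when either block is missing.

```python
import re

REASONING_START, REASONING_END = "<reasoning>", "</reasoning>"
ANSWER_START, ANSWER_END = "<answer>", "</answer>"

# DOTALL lets '.' span newlines inside multi-line reasoning traces.
_BLOCKS = re.compile(
    re.escape(REASONING_START) + r"(.*?)" + re.escape(REASONING_END)
    + r".*?" + re.escape(ANSWER_START) + r"(.*?)" + re.escape(ANSWER_END),
    flags=re.DOTALL,
)

def split_reasoning_answer(text: str):
    """Return (reasoning, answer) stripped of whitespace, or None if the tags are missing."""
    m = _BLOCKS.search(text)
    if m is None:
        return None
    return m.group(1).strip(), m.group(2).strip()

sample = "<reasoning>\n2 apples + 3 apples = 5 apples.\n</reasoning>\n<answer>\n5\n</answer>"
print(split_reasoning_answer(sample))
```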

In [None]:
REASONING_START = "<reasoning>"
REASONING_END = "</reasoning>"
ANSWER_START = "<answer>"
ANSWER_END = "</answer>"

SYSTEM_PROMPT = f"""
You are given a problem.

Think step by step and write your reasoning between
{REASONING_START} and {REASONING_END}.

Then write the final answer between
{ANSWER_START} and {ANSWER_END}.

Do not write anything outside these tags.
""".strip()

TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""

## 📦 Training Datasets

This notebook trains a reasoning model using a **mixed, multi-domain dataset** constructed from high-signal public benchmarks hosted on Hugging Face.  
Each example is tagged with a `domain` field, which is later used for **domain-aware reward computation** during GRPO training.

| Domain | Dataset | Source | Purpose |
|------|--------|--------|--------|
| Math | GSM8K | [openai/gsm8k](https://huggingface.co/datasets/openai/gsm8k) | Numerical and word-problem reasoning |
| Code | MBPP (sanitized) | [google-research-datasets/mbpp](https://huggingface.co/datasets/google-research-datasets/mbpp) | Python program synthesis |
| Science | ARC-Easy | [allenai/ai2_arc](https://huggingface.co/datasets/allenai/ai2_arc) | Basic scientific reasoning |
| Logic | StrategyQA | [ChilleD/StrategyQA](https://huggingface.co/datasets/ChilleD/StrategyQA) | Yes / No logical reasoning |

### Why this dataset mix?

- All datasets are **verifiable or semi-verifiable**
- Suitable for **stable GRPO training** within a single Kaggle TPU session
- Covers **diverse reasoning styles** (numerical, symbolic, procedural)
- Inspired by early **Open-R1 / DeepSeek-R1 curriculum design**

Datasets are **mixed at sampling time**, not concatenated: examples from the four sources are interleaved according to the mixing weights and then shuffled, so training batches typically contain heterogeneous reasoning tasks, encouraging generalizable reasoning behavior.
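The idea behind weighted mixing can be illustrated as picking the source of each next example with probability proportional to its weight, rather than reading the datasets one after another. The pure-Python sketch below does exactly that with the weights used in `get_dataset`. It is not how Grain's `MapDataset.mix` is implemented internally, and Grain's interleaving and stopping rules may differ. `weighted_mix` is a hypothetical helper.

```python
import random
from collections import Counter
from typing import Iterator, Sequence

def weighted_mix(sources: Sequence[Sequence[dict]], weights: Sequence[float],
                 seed: int = 42) -> Iterator[dict]:
    """Yield examples by repeatedly choosing a source with probability proportional to its weight.

    Stops as soon as a chosen source is exhausted (one simple stopping rule among several).
    """
    rng = random.Random(seed)
    iterators = [iter(s) for s in sources]
    while True:
        (it,) = rng.choices(iterators, weights=weights, k=1)
        try:
            example = next(it)
        except StopIteration:
            return
        yield example

# Toy stand-ins for GSM8K, MBPP, ARC-Easy and StrategyQA.
domains = ["math", "code", "science", "logic"]
sources = [[{"domain": d}] * 10_000 for d in domains]
mixer = weighted_mix(sources, weights=[0.35, 0.30, 0.20, 0.15])
drawn = [next(mixer)["domain"] for _ in range(2_000)]
print(Counter(drawn))
```

With 2,000 draws, the domain counts land close to the 35/30/20/15 split. Any single micro-batch of 4, however, can still be dominated by one domain.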


In [None]:
from datasets import load_dataset
import grain
import os
import time

# =========================
# GSM8K answer extractor
# =========================
def extract_hash_answer(text: str) -> str | None:
    if not isinstance(text, str):
        return None
    if "####" not in text:
        return None
    return text.split("####")[-1].strip()

# =========================
# DATASET BUILDER
# =========================
def get_dataset(split="train") -> grain.MapDataset:
    print(f"\n🚀 Building mixed dataset | split = {split}")
    t0 = time.time()

    os.environ["HF_HUB_DISABLE_XET"] = "1"
    os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

    # ============
    # GSM8K (MATH)
    # ============
    print("📥 Loading GSM8K...")
    gsm = load_dataset("openai/gsm8k", "main", split=split)
    print(f"✅ GSM8K loaded ({len(gsm)} samples)")

    gsm_ds = (
        grain.MapDataset.source(gsm)
        .map(lambda x: {
            "prompts": TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                question=str(x["question"]),
            ),
            "question": str(x["question"]),
            "answer": extract_hash_answer(str(x["answer"])),
            "domain": "math",
        })
        .filter(lambda x: x["answer"] is not None)
    )

    print("🧠 GSM8K pipeline built")

    # ============
    # MBPP (CODE: STRUCTURE ONLY, TPU SAFE)
    # ============
    print("📥 Loading MBPP...")
    mbpp = load_dataset("google-research-datasets/mbpp", "sanitized", split=split)
    print(f"✅ MBPP loaded ({len(mbpp)} samples)")
    
    def mbpp_map(x):
        question = x.get("prompt") or x.get("text")
        if question is None:
            return None
    
        return {
            "prompts": TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                question=str(question),
            ),
            "question": str(question),
            "answer": None,
            "domain": "code",
        }
    
    mbpp_ds = (
        grain.MapDataset.source(mbpp)
        .map(mbpp_map)
        .filter(lambda x: x is not None)
    )
    
    print("🧠 MBPP pipeline built (STRUCTURE-ONLY)")

    # ============
    # ARC Easy (SCIENCE)
    # ============
    print("📥 Loading ARC Easy...")
    arc = load_dataset("allenai/ai2_arc", "ARC-Easy", split=split)
    print(f"✅ ARC Easy loaded ({len(arc)} samples)")

    def arc_prompt(x):
        choices = "\n".join(
            f"{l}. {t}"
            for l, t in zip(x["choices"]["label"], x["choices"]["text"])
        )
        return f"{x['question']}\n\nChoices:\n{choices}"

    arc_ds = grain.MapDataset.source(arc).map(lambda x: {
        "prompts": TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=arc_prompt(x),
        ),
        "question": x["question"],
        "answer": x["answerKey"],
        "domain": "science",
    })

    print("🧠 ARC pipeline built")

    # =================
    # StrategyQA (LOGIC)
    # =================
    print("📥 Loading StrategyQA...")
    strategyqa = load_dataset("ChilleD/StrategyQA", split=split)
    print(f"✅ StrategyQA loaded ({len(strategyqa)} samples)")

    def strategyqa_prompt(x):
        return f"Question: {x['question']}\nAnswer yes or no."

    strategyqa_ds = grain.MapDataset.source(strategyqa).map(lambda x: {
        "prompts": TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=strategyqa_prompt(x),
        ),
        "question": x["question"],
        "answer": "yes" if x["answer"] else "no",
        "domain": "logic",
    })

    print("🧠 StrategyQA pipeline built")

    # ============
    # MIX
    # ============
    print("🧪 Mixing datasets...")
    mixed = (
        grain.MapDataset.mix(
            datasets=[gsm_ds, mbpp_ds, arc_ds, strategyqa_ds],
            weights=[0.35, 0.30, 0.20, 0.15],
        )
        .shuffle(seed=42)
    )

    print("✅ Mixed dataset pipeline ready")
    print(f"⏱️ Total build time: {time.time() - t0:.2f}s")

    return mixed

We build the train and test sets from each source dataset's own `train` and `test` splits.

In [None]:
print("Using data source: huggingface (mixed domains)")

import os
import time
import grain
import itertools
from collections import Counter

os.environ["HF_HUB_DISABLE_XET"] = "1"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

t0 = time.time()

print("\n🚧 Building TRAIN dataset (STREAMING)...")
train_ds_raw = get_dataset("train")

train_domain_counter = Counter(
    ex["domain"] for ex in itertools.islice(train_ds_raw, 300)
)
print("🔍 TRAIN domain mix:", train_domain_counter)

train_dataset = (
    train_ds_raw
    .to_iter_dataset()
    .batch(TRAIN_MICRO_BATCH_SIZE)
)

print("✅ TRAIN streaming dataset ready")

print("\n🚧 Building TEST dataset (STREAMING)...")
test_ds_raw = get_dataset("test")

test_domain_counter = Counter(
    ex["domain"] for ex in itertools.islice(test_ds_raw, 300)
)
print("🔍 TEST domain mix:", test_domain_counter)

test_dataset = (
    test_ds_raw
    .to_iter_dataset()
    .batch(TRAIN_MICRO_BATCH_SIZE)
)

print("✅ TEST streaming dataset ready")

print("\n📊 DATASET STATUS")
print("   TRAIN: pure streaming")
print("   TEST:  pure streaming")
print(f"⏱️ Total setup time: {time.time() - t0:.2f}s")

Let's see what one batch of the training dataset looks like!


In [None]:
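# `train_dataset` is an IterDataset, which does not support slicing, so take
# the first batch from an iterator instead.
pprint(next(iter(train_dataset)))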

## Load the policy model and the reference model

The policy model is the model which is actually trained and whose weights are
updated. The reference model is the model with which we compute KL divergence.
This is to ensure that the policy updates are not huge and that it does not
deviate too much from the reference model.

Typically, the reference model is the base model, and the policy model is the
same base model, but with LoRA parameters. Only the LoRA parameters are updated.
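The GRPO formulation estimates the per-token KL divergence between policy and reference from the log-probabilities of the sampled tokens. It uses the non-negative estimator `exp(r) - r - 1`, with `r = log p_ref - log p_policy`, and the penalty is scaled by `BETA`. The NumPy sketch below shows that estimator. Whether Tunix computes its KL term in exactly this form is an assumption, and `per_token_kl` is a hypothetical name.

```python
import numpy as np

def per_token_kl(policy_logprobs: np.ndarray, ref_logprobs: np.ndarray) -> np.ndarray:
    """Non-negative per-token KL estimate: exp(r) - r - 1 with r = ref_logp - policy_logp."""
    log_ratio = ref_logprobs - policy_logprobs
    return np.exp(log_ratio) - log_ratio - 1.0

BETA = 0.08  # same value as in the hyperparameter cell
policy_lp = np.log(np.array([0.50, 0.20, 0.90]))  # policy probs of the sampled tokens
ref_lp = np.log(np.array([0.40, 0.20, 0.95]))     # reference probs of the same tokens
kl = per_token_kl(policy_lp, ref_lp)
print("per-token KL:", kl, "penalty:", BETA * kl.mean())
```

Tokens where the policy and reference agree contribute zero. The estimate grows as the policy drifts from the reference, which is what keeps LoRA updates close to the base model.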

Note: `get_gemma_ref_model` below restores the base weights in bfloat16. Qwix
can also be used for quantization-aware training (QAT).

To load the model, you need to be on [Kaggle](https://www.kaggle.com/) and need
to have agreed to the Gemma license
[here](https://www.kaggle.com/models/google/gemma/flax/).

In [None]:
import os
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()

os.environ["KAGGLE_KEY"] = user_secrets.get_secret("KAGGLE_KEY")
os.environ["KAGGLE_USERNAME"] = user_secrets.get_secret("KAGGLE_USERNAME")

# Now this will NOT trigger login
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
    kagglehub.login()

This code snippet serves as a workaround to re-save the pre-trained model checkpoint from Kaggle into a local format that is compatible with the [Flax NNX](https://flax.readthedocs.io/en/stable/why.html) library. Because the original checkpoint has parameter names and tensor structures that don't match the target NNX model architecture, it cannot be loaded directly.

We first load the original weights into a temporary model instance, then extract and re-save the model's state into a new, properly formatted local checkpoint, which can then be successfully loaded by the final sharded NNX model.

In [None]:
from tunix.models.gemma3 import params
from tunix.models.gemma3 import model as gemma_model

import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp
from flax import nnx

def get_gemma_ref_model(ckpt_path):
    # ===============================
    # Device mesh
    # ===============================
    mesh = jax.make_mesh(*MESH)

    # ===============================
    # Model config (✅ CORRECT API)
    # ===============================
    model_config = gemma_model.ModelConfig.gemma3_1b_it()

    # ===============================
    # Build abstract (shape-only) model
    # ===============================
    abs_gemma: nnx.Module = nnx.eval_shape(
        lambda: params.create_model_from_checkpoint(
            params.GEMMA3_1B_IT,
            model_config,
        )
    )

    # ===============================
    # Prepare sharded state structure
    # ===============================
    abs_state = nnx.state(abs_gemma)
    pspecs = nnx.get_named_sharding(abs_state, mesh)

    abs_state = jax.tree.map(
        lambda a, s: jax.ShapeDtypeStruct(
            a.shape,
            jnp.bfloat16,
            sharding=s,
        ),
        abs_state,
        pspecs,
    )

    # ===============================
    # Restore checkpoint
    # ===============================
    checkpointer = ocp.StandardCheckpointer()
    restored_params = checkpointer.restore(
        ckpt_path,
        target=abs_state,
    )

    # ===============================
    # Materialize reference model
    # ===============================
    graph_def, _ = nnx.split(abs_gemma)
    ref_model = nnx.merge(graph_def, restored_params)

    return ref_model, mesh, model_config


def get_lora_model(base_model, mesh):
    # ===============================
    # LoRA configuration
    # ===============================
    lora_provider = qwix.LoraProvider(
        module_path=(
            ".*q_einsum|.*kv_einsum|.*gate_proj|"
            ".*down_proj|.*up_proj|.*attn_vec_einsum"
        ),
        rank=RANK,
        alpha=ALPHA,
    )

    # ===============================
    # Apply LoRA
    # ===============================
    model_input = base_model.get_model_input()
    lora_model = qwix.apply_lora_to_model(
        base_model,
        lora_provider,
        **model_input,
    )

    # ===============================
    # Re-apply sharding
    # ===============================
    with mesh:
        state = nnx.state(lora_model)
        pspecs = nnx.get_partition_spec(state)
        sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
        nnx.update(lora_model, sharded_state)

    return lora_model

In [None]:
# ===============================
# Cleanup
# ===============================
!rm -rf /tmp/content/intermediate_ckpt/*
!rm -rf /tmp/content/ckpts/*

import os, gc, jax
import jax.numpy as jnp
from tunix.models.gemma3 import params
from tunix.models.gemma3 import model as gemma_model
from flax import nnx
import orbax.checkpoint as ocp

CKPT_PATH = os.path.join(INTERMEDIATE_CKPT_DIR, "state")

# ===============================
# Model config (CORRECT)
# ===============================
model_config = gemma_model.ModelConfig.gemma3_1b_it()

# ===============================
# Load base Gemma 3 1B
# ===============================
base_model = params.create_model_from_checkpoint(
    params.GEMMA3_1B_IT,
    model_config,
)

tokenizer = params.create_tokenizer()
print("✅ Base Gemma-3 1B loaded")

# ===============================
# Save clean base state
# ===============================
checkpointer = ocp.StandardCheckpointer()
_, base_state = nnx.split(base_model)

checkpointer.save(CKPT_PATH, base_state)
checkpointer.wait_until_finished()

print("✅ Clean base checkpoint saved")

### Model Loading and LoRA Application

These two functions work together to load a base model from a checkpoint and apply LoRA (Low-Rank Adaptation) adapters to it.

* `get_gemma_ref_model`: Loads the complete Gemma model from a specified checkpoint path. It uses **JAX sharding** to distribute the model parameters across multiple devices.
* `get_lora_model`: Takes the base model and applies LoRA layers to it. It uses a `LoraProvider` to select specific layers (like attention and MLP layers) to be adapted. The resulting LoRA-infused model is then sharded and updated to ensure it's ready for distributed training.
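For intuition about what `get_lora_model` adds: LoRA keeps the base weight `W` frozen and learns a low-rank update `A @ B`, scaled by `alpha / rank`. With `RANK = 64` and `ALPHA = 64.0`, the scale is 1. The NumPy sketch below shows the standard formulation with `B` initialized to zeros, so the adapted layer initially reproduces the base layer exactly. Qwix's actual parameterization and initialization may differ, and `lora_forward` is a hypothetical helper.

```python
import numpy as np

def lora_forward(x, w, a, b, alpha: float, rank: int):
    """y = x @ W + (alpha / rank) * (x @ A) @ B; W is frozen, A and B are trainable."""
    return x @ w + (alpha / rank) * (x @ a) @ b

rng = np.random.default_rng(0)
d_in, d_out, rank, alpha = 64, 64, 4, 4.0
x = rng.normal(size=(2, d_in))
w = rng.normal(size=(d_in, d_out))        # frozen base weight
a = rng.normal(size=(d_in, rank)) * 0.01  # trainable down-projection
b = np.zeros((rank, d_out))               # trainable up-projection, zero-initialized

y = lora_forward(x, w, a, b, alpha, rank)
print("starts as a no-op:", np.allclose(y, x @ w))
print("trainable params:", d_in * rank + rank * d_out, "vs full matrix:", d_in * d_out)
```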

Now we load the reference and policy Gemma models and run a quick sanity check on the policy's parameter tree.

In [None]:
# ===============================
# Load reference model
# ===============================
ref_model, mesh, model_config = get_gemma_ref_model(
    ckpt_path=CKPT_PATH
)

print("✅ Reference model loaded")

# ===============================
# Create LoRA actor
# ===============================
lora_policy = get_lora_model(ref_model, mesh)

print("✅ LoRA actor created")

# ===============================
# Cleanup memory
# ===============================
del base_model, base_state
gc.collect()

# ===============================
# Sanity check
# ===============================
actor_params = nnx.state(lora_policy)
print(f"Actor param leaves: {len(jax.tree.leaves(actor_params))}")


## Reward Function Design

This notebook uses a **composed reward signal** to train Gemma 3 (1B) to produce
accurate answers **with explicit reasoning**, while remaining concise and
well-structured.

Rather than relying on a single reward, multiple reward functions are combined
additively. Each reward targets a specific behavior, and some are enabled only
during specific training phases (curriculum learning).
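Additive composition means each reward function returns one score per completion, and the scores are summed element-wise. The sketch below demonstrates this with two toy reward functions that share the `(prompts, completions, **kwargs)` signature used throughout this notebook. `combine_rewards` is a hypothetical helper, and how Tunix aggregates the reward functions internally is not shown here.

```python
from typing import Callable, List, Sequence

RewardFn = Callable[..., List[float]]

def combine_rewards(reward_fns: Sequence[RewardFn], prompts, completions, **kwargs) -> List[float]:
    """Sum the per-completion scores of several reward functions."""
    totals = [0.0] * len(completions)
    for fn in reward_fns:
        scores = fn(prompts, completions, **kwargs)
        assert len(scores) == len(completions), f"{fn.__name__} returned wrong length"
        totals = [t + s for t, s in zip(totals, scores)]
    return totals

# Two toy rewards with the same signature as the notebook's reward functions.
def has_answer_tag(prompts, completions, **kwargs):
    return [1.0 if "<answer>" in c else 0.0 for c in completions]

def too_short(prompts, completions, **kwargs):
    return [-1.0 if len(c) < 15 else 0.0 for c in completions]

completions = ["<reasoning>2+2=4</reasoning><answer>4</answer>", "4"]
print(combine_rewards([has_answer_tag, too_short], ["q", "q"], completions))
```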

### Refusal Penalty (`punish_refusal`)
- Strong negative reward for:
  - Empty outputs
  - Refusals (e.g. “I can’t help with that”)
  - Deflections or clarification requests
- Purpose: prevent RL collapse and force the model to always attempt an answer  
- **Active in all phases**

---

### Strict Format Reward (`match_format_exactly`)
- Rewards outputs that **exactly match** the required structure: `<reasoning>...</reasoning>` followed by `<answer>...</answer>`
- Acts as a **hard gate**: without structure, learning does not proceed
- High weight early in training
- **Active in all phases**

---

### Soft Format Reward (`match_format_approximately`)

* Provides partial credit for near-correct formatting
* Encourages compliance without brittleness
* Helps stabilize early training when strict format is still being learned
* **Active in all phases**

---

### Answer Correctness Reward (`check_answer`)

* Domain-aware correctness signal:

  * **Math**: numeric exact match + tolerance bands
  * **Science**: multiple-choice exact match
  * **Logic**: yes / no correctness
  * **Code**: scored by structural heuristics (function signature, `return`/`print`); the code is never executed
* Correct answers receive a strong positive reward
* Incorrect answers receive a negative reward
* **Enabled only after format stabilization (Phase 1 onward)**

---

### Reasoning Quality Reward (`reasoning_quality_reward`)

* Encourages **high-quality reasoning traces** without incentivizing verbosity
* Rewards:

  * Presence of a `<reasoning>` block
  * Multi-step structure
  * Logical connectors and grounding signals
* Penalizes:

  * Keyword spam
  * Pure fluff
  * Excessive rambling
* Intentionally **lower magnitude** than correctness
* Used to break ties between otherwise correct answers
* **Enabled after correctness learning begins**

---

### Length & Termination Penalty (`penalize_length_and_rambling`)

* Penalizes:

  * Overlong reasoning
  * Text after `</answer>`
  * Repetitive or rambling outputs
* Enforces concise, clean completions
* **Enabled only in the final training phase**

---

### Why Multiple Rewards?

Each reward targets a **single failure mode**:

* Refusal → nuked
* Bad format → blocked
* Wrong answer → penalized
* Good answer + bad reasoning → improved
* Long-winded answers → trimmed

Together, they teach the model not just *what* to answer, but *how* to answer.
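The phase gating described above can be expressed as a lookup from phase to active reward functions. The sketch below assumes three phases numbered from 0, which matches this notebook's description of reasoning quality as "Phase 1" and "the second training phase". The exact phase boundaries and the helper `rewards_for_phase` are hypothetical; only the reward function names come from this notebook.

```python
ALWAYS_ON = ["punish_refusal", "match_format_exactly", "match_format_approximately"]

# Hypothetical phase schedule; names match the notebook's reward functions.
PHASE_REWARDS = {
    0: ALWAYS_ON,                                                  # learn the format
    1: ALWAYS_ON + ["check_answer", "reasoning_quality_reward"],   # learn to be right
    2: ALWAYS_ON + ["check_answer", "reasoning_quality_reward",
                    "penalize_length_and_rambling"],               # learn to be concise
}

def rewards_for_phase(phase: int) -> list[str]:
    """Return the names of the reward functions active in the given phase."""
    if phase not in PHASE_REWARDS:
        raise ValueError(f"unknown phase {phase}; expected one of {sorted(PHASE_REWARDS)}")
    return PHASE_REWARDS[phase]

for phase in PHASE_REWARDS:
    print(phase, rewards_for_phase(phase))
```

In the notebook itself, the names would be replaced by the reward function objects before being passed to the trainer.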

In [None]:
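# Strict format: optional leading whitespace, a non-empty reasoning block,
# then an answer block whose content is captured in group 1.
match_format = re.compile(
    rf"^\s*"
    rf"{REASONING_START}.+?{REASONING_END}.*?"
    rf"{ANSWER_START}(.+?){ANSWER_END}"
    rf"\s*$",
    flags=re.MULTILINE | re.DOTALL,
)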

match_format.search(
    f"{REASONING_START}Let me"
    f" think!{REASONING_END}{ANSWER_START}2{ANSWER_END}",
)

Give the model a reward of 3 points if the format matches exactly.

In [None]:
def match_format_exactly(prompts, completions, **kwargs):
  return [
      0 if match_format.search(response) is None else 3.0
      for response in completions
  ]
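One subtlety of `match_format`: because it is compiled with `re.MULTILINE`, `$` can match at the end of any line. As a result, a completion that continues on a new line after `</answer>` still counts as an exact match. The self-contained check below demonstrates this using the same pattern, and compares it against the same pattern compiled without `re.MULTILINE`. The stricter variant is shown only for comparison; in this notebook, text after `</answer>` is penalized separately by `penalize_length_and_rambling`.

```python
import re

REASONING_START, REASONING_END = "<reasoning>", "</reasoning>"
ANSWER_START, ANSWER_END = "<answer>", "</answer>"

PATTERN = (
    rf"^\s*{REASONING_START}.+?{REASONING_END}.*?"
    rf"{ANSWER_START}(.+?){ANSWER_END}\s*$"
)
as_in_notebook = re.compile(PATTERN, flags=re.MULTILINE | re.DOTALL)
whole_string = re.compile(PATTERN, flags=re.DOTALL)  # ^ and $ anchor to the whole text

trailing = "<reasoning>2+2=4</reasoning><answer>4</answer>\nAnything else?"
print(as_in_notebook.search(trailing) is not None)  # matches despite the trailing line
print(whole_string.search(trailing) is not None)    # rejected
```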

We also reward the model if the format of the output matches partially.

In [None]:
def match_format_approximately(prompts, completions, answer, **kwargs):
  question = kwargs["question"]
  responses = completions

  # Debug-only extraction; `match_numbers` is defined in a later cell, which
  # must be run before training calls this reward.
  extracted_responses = [
      guess.group(1) if (guess := match_numbers.search(r)) is not None else None
      for r in responses
  ]
  print("START ============================")
  print(f"Question: {question[0]}")
  print(f"Answer: {answer[0]}")
  print(f"Response: {responses[0]}")
  print(f"Extracted: {extracted_responses[0]}")
  print("END ==============================")

  scores = []
  for completion in completions:
    score = 0
    response = completion
    # Each tag should appear exactly once: +0.5 for each tag that does,
    # -0.5 for each tag that is missing or repeated.
    score += 0.5 if response.count(REASONING_START) == 1 else -0.5
    score += 0.5 if response.count(REASONING_END) == 1 else -0.5
    score += 0.5 if response.count(ANSWER_START) == 1 else -0.5
    score += 0.5 if response.count(ANSWER_END) == 1 else -0.5
    scores.append(score)
  return scores

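Reward the model if the answer is correct. The reward is domain-aware: for math,
partial credit is also given when the numeric answer is close to the correct
value (within 5% or 10%); science and logic answers must match exactly; and code
answers receive a structural score, since they are never executed.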

In [None]:
def check_answer(prompts, completions, answer, **kwargs):
    """
    Domain-aware correctness reward.
    Works with mixed-domain GRPO batches.
    Includes structural reward for MBPP-style coding tasks.
    """

    domains = kwargs["domain"]
    responses = completions

    extracted = [
        m.group(1).strip() if (m := match_format.search(r)) else None
        for r in responses
    ]

    scores = []

    assert len(extracted) == len(answer) == len(domains), (
        f"Length mismatch: extracted={len(extracted)}, "
        f"answer={len(answer)}, domain={len(domains)}"
    )

    for guess, gold, domain in zip(extracted, answer, domains):

        if guess is None:
            scores.append(0.0)
            continue

        guess_l = guess.lower().strip()
        gold_l = str(gold).lower().strip()
        score = 0.0

        # ====================
        # 🔢 MATH
        # ====================
        if domain == "math":
            try:
                g = float(guess)
                a = float(gold)

                if g == a:
                    score = 3.0
                else:
                    ratio = g / a if a != 0 else 0.0
                    if 0.95 <= ratio <= 1.05:
                        score = 1.5
                    elif 0.9 <= ratio <= 1.1:
                        score = 0.5
                    else:
                        score = -1.0
            except (TypeError, ValueError):
                # Non-numeric answer text, e.g. "$18" or a full sentence.
                score = -0.5

        # ====================
        # 🔬 SCIENCE (MCQ)
        # ====================
        elif domain == "science":
            score = 3.0 if guess_l == gold_l else -1.0

        # ====================
        # 🧠 LOGIC (yes / no)
        # ====================
        elif domain == "logic":
            yes_set = {"yes", "true"}
            no_set = {"no", "false"}

            if guess_l in yes_set and gold_l in yes_set:
                score = 3.0
            elif guess_l in no_set and gold_l in no_set:
                score = 3.0
            else:
                score = -1.0

        # ====================
        # 💻 CODE (MBPP - structural reward)
        # ====================
        elif domain == "code":
            s = guess.strip()

            has_def = "def " in s
            has_paren = "(" in s and ")" in s
            has_colon = ":" in s
            has_return = "return" in s
            has_print = "print" in s

            # Hard fail cases
            if len(s) < 20 or "i cannot" in guess_l:
                score = -1.0

            else:
                structure_score = 0.0

                if has_def and has_paren and has_colon:
                    structure_score += 1.5
                if has_return or has_print:
                    structure_score += 1.0

                score = structure_score

        # ====================
        # ❓ Unknown domain
        # ====================
        else:
            score = 0.0

        scores.append(score)

    return scores
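The code branch above rewards substrings such as `def`, `:`, and `return`, which a syntactically broken answer can still contain. A stricter structural signal that still never executes the code is to parse the answer with Python's built-in `ast` module and check for a function definition. The helper below is a hypothetical alternative for illustration, not what `check_answer` does. Answers wrapped in Markdown code fences would likely need the fences stripped before parsing.

```python
import ast

def defines_valid_function(code: str) -> bool:
    """True if `code` parses as Python and defines at least one top-level function.

    The code is only parsed, never executed.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):  # ValueError: e.g. null bytes on some Python versions
        return False
    return any(isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) for node in tree.body)

good = "def add(a, b):\n    return a + b\n"
bad = "def add(a, b):\nreturn a + b"  # has def, (), : and return, but the body is not indented
print(defines_valid_function(good), defines_valid_function(bad))
```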

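Sometimes, the text between `<answer>` and `</answer>` might not be a single
number; it can be a sentence. `match_numbers` extracts the first number after
`<answer>`. In this notebook it is used only for the debug printout in
`match_format_approximately`; `check_answer` parses the full answer text instead.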

In [None]:
match_numbers = re.compile(
    rf"{ANSWER_START}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)
match_numbers.findall(f"{ANSWER_START}  0.34  {ANSWER_END}")
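When the answer block contains a sentence, the pattern skips ahead to the first run of digits and dots. The self-contained check below reuses the same pattern on a few sentence-style answers. Because `[\d\.]` also accepts a lone dot, a trailing full stop is captured (`"42."`, which `float()` still parses as 42.0). A period that precedes the number, as in "Approx.", is captured on its own and cannot be parsed. This is harmless here because the extraction feeds only the debug printout.

```python
import re

ANSWER_START, ANSWER_END = "<answer>", "</answer>"
# Same pattern as the cell above: first run of digits/dots after <answer>.
match_numbers = re.compile(
    rf"{ANSWER_START}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)

examples = [
    f"{ANSWER_START}The store sold 42 apples.{ANSWER_END}",
    f"{ANSWER_START}The answer is 42.{ANSWER_END}",
    f"{ANSWER_START}Approx. 5 apples{ANSWER_END}",
]
for text in examples:
    print(repr(match_numbers.search(text).group(1)))
```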

In [None]:
def punish_refusal(prompts, completions, **kwargs):
    scores = []

    REFUSAL_PHRASES = [
        "please provide the problem",
        "i need the problem",
        "cannot solve without",
        "please provide the reasoning",
        "don’t provide",  # curly apostrophe
        "don't provide",  # straight apostrophe
        "i cannot help",
        "cannot answer",
        "unable to solve",
        "please provide",
        "need more information",
    ]

    for completion in completions:
        # -----------------------------
        # Safe text extraction
        # -----------------------------
        if isinstance(completion, str):
            text = completion.lower()
        elif isinstance(completion, list) and len(completion) > 0:
            first = completion[0]
            if isinstance(first, dict) and "content" in first:
                text = first["content"].lower()
            else:
                text = str(completion).lower()
        else:
            text = str(completion).lower()

        text = text.strip()

        # -----------------------------
        # 🚨 HARD REFUSAL = ABSOLUTE DEATH
        # -----------------------------
        if any(p in text for p in REFUSAL_PHRASES):
            scores.append(-20.0) 
            continue

        # -----------------------------
        # 🚫 Empty / ultra-short junk
        # -----------------------------
        if len(text) < 15:
            scores.append(-10.0)
            continue

        # -----------------------------
        # ✅ ATTEMPT BONUS (CRITICAL)
        # -----------------------------
        attempted = (
            "<reasoning>" in text
            or "<answer>" in text
            or any(c.isdigit() for c in text)
        )

        if attempted:
            scores.append(+1.0)   # 🛟 trying > refusing
        else:
            scores.append(-2.0)   # vague fluff still bad

    return scores


In [None]:
def penalize_length_and_rambling(prompts, completions, **kwargs):
    scores = []

    MAX_LEN = 80          # measured in characters (len(text)), not tokens
    LENGTH_PENALTY = 4.0  # strong on purpose

    RAMBLE_MARKERS = [
        "let's re-read",
        "this is not correct",
        "not possible",
        "let's rephrase the problem",  # lowercase: matched against text_lower
    ]

    for completion in completions:
        # Extract text safely (plain string or chat-style [{"content": ...}])
        if isinstance(completion, str):
            text = completion
        elif isinstance(completion, list) and len(completion) > 0 and isinstance(completion[0], dict):
            text = completion[0].get("content", "")
        else:
            text = str(completion)

        text_lower = text.lower()
        score = 0.0

        # ===============================
        # 1️⃣ Length penalty (HUGE)
        # ===============================
        length = len(text)
        if length > MAX_LEN:
            excess = (length - MAX_LEN) / MAX_LEN
            score -= LENGTH_PENALTY * excess

        # ===============================
        # 2️⃣ Talking AFTER answer (ILLEGAL)
        # ===============================
        if "</answer>" in text:
            after = text.split("</answer>", 1)[-1].strip()
            if after:
                score -= 2.5  # hard slap

        # ===============================
        # 3️⃣ Rambling / restart detection
        # ===============================
        ramble_hits = sum(m in text_lower for m in RAMBLE_MARKERS)
        score -= 0.5 * ramble_hits

        scores.append(score)

    return scores

### Phase 1: Reasoning Quality Reward

In the second training phase, the model is encouraged to produce **high-quality reasoning traces** in addition to correct answers.

A dedicated reasoning-quality reward evaluates the content inside the `<reasoning>` block using domain-agnostic signals such as:
- Multi-sentence structure
- Logical flow markers (e.g., “because”, “therefore”)
- Grounding via numbers, examples, or named entities
- Length sanity (neither too short nor excessively verbose)

The reward penalizes:
- Missing or malformed reasoning blocks
- Keyword spam
- Purely rhetorical or fluffy explanations

This phase focuses on **learning how to reason**, while deferring strict brevity constraints to the final phase.
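Keyword-based signals like the ones listed above are easy to inflate by accident when keywords are matched as raw substrings: `so` also fires inside `also`, `some`, and `reasoning`, and `then` fires inside `strengthen`. The sketch below contrasts substring counting (what `str.count` does) with whole-word counting via regex word boundaries on a made-up sentence. The four-keyword list is a hypothetical subset chosen for illustration, not the notebook's full `REASONING_KEYWORDS`.

```python
import re

# Hypothetical subset of reasoning keywords, for illustration only.
KEYWORDS = ["so", "then", "reason", "because"]


def substring_hits(text: str) -> int:
    """Counts raw substring occurrences, as str.count does."""
    return sum(text.count(k) for k in KEYWORDS)


def whole_word_hits(text: str) -> int:
    """Counts only standalone occurrences, using regex word boundaries."""
    return sum(len(re.findall(rf"\b{re.escape(k)}\b", text)) for k in KEYWORDS)


sample = "we also need some reasoning, so we strengthen the bound because it helps."
print(substring_hits(sample))   # 7
print(whole_word_hits(sample))  # 2
```

Only `so` and `because` are genuine reasoning markers in this sentence; substring matching more than triples the count, which can push an ordinary answer into a "keyword spam" band.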


In [None]:
import re
from collections import Counter

# üîë Generic reasoning markers (domain-agnostic)
REASONING_KEYWORDS = [
    # logical flow
    "because", "therefore", "thus", "hence", "so", "as a result",
    "this implies", "it follows", "which means",

    # reasoning actions
    "assume", "consider", "analyze", "evaluate", "compare",
    "explain", "reason", "conclude", "determine",

    # structure
    "first", "second", "next", "then", "finally",

    # evidence / grounding
    "given", "based on", "from this", "according to"
]


def reasoning_quality_reward(prompts, completions, **kwargs):
    scores = []

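    for response in completions:
        score = 0.0
        # Safe text extraction (plain string or chat-style [{"content": ...}])
        if isinstance(response, list) and response and isinstance(response[0], dict):
            response = response[0].get("content", "")
        text = str(response).lower()

        # -----------------------------------
        # 1️⃣ Require reasoning block
        # -----------------------------------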
        if "<reasoning>" not in text or "</reasoning>" not in text:
            scores.append(-0.4)
            continue

        m = re.search(r"<reasoning>(.*?)</reasoning>", text, re.S)
        if m is None:
            scores.append(-0.4)
            continue

        reasoning = m.group(1).strip()

        # -----------------------------------
        # 2️⃣ Sentence structure (domain-agnostic)
        # -----------------------------------
        sentences = [
            s.strip() for s in re.split(r"[.\n]", reasoning)
            if len(s.strip()) > 6
        ]

        if len(sentences) >= 2:
            score += 0.15
        if len(sentences) >= 4:
            score += 0.15
        if len(sentences) >= 7:
            score += 0.1

        # -----------------------------------
        # 3️⃣ Reasoning keyword usage (NOT spam)
        # -----------------------------------
        # Whole-word/phrase matching, so e.g. "so" does not fire inside "also"
        def kw_count(kw):
            return len(re.findall(rf"\b{re.escape(kw)}\b", reasoning))

        keyword_hits = sum(kw_count(k) for k in REASONING_KEYWORDS)

        if 1 <= keyword_hits <= 5:
            score += 0.25
        elif keyword_hits > 8:
            score -= 0.25  # keyword spam

        # -----------------------------------
        # 4️⃣ Keyword repetition penalty (ONLY keywords)
        # -----------------------------------
        keyword_counts = Counter()

        for kw in REASONING_KEYWORDS:
            c = kw_count(kw)
            if c > 0:
                keyword_counts[kw] += c

        if keyword_counts:
            max_rep = max(keyword_counts.values())

            if max_rep >= 5:
                score -= 0.5
            elif max_rep == 4:
                score -= 0.35
            elif max_rep == 3:
                score -= 0.2
            elif max_rep == 2:
                score -= 0.1

        # -----------------------------------
        # 5️⃣ Length sanity (GENERIC, relaxed)
        # -----------------------------------
        token_len = len(reasoning.split())

        if token_len < 25:
            score -= 0.25
        elif 50 <= token_len <= 250:
            score += 0.25
        elif 250 < token_len <= 450:
            score += 0.15
        elif token_len > 600:
            score -= 0.35  # rambling

        # -----------------------------------
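        # 6️⃣ Grounding signals (numbers OR entities OR examples)
        # -----------------------------------
        has_numbers = bool(re.search(r"\d", reasoning))
        has_examples = "example" in reasoning or "for instance" in reasoning
        # `m` was matched on lowercased text, so look for capitalized
        # entities in the original-case response instead
        m_raw = re.search(r"<reasoning>(.*?)</reasoning>", str(response), re.S | re.I)
        has_entities = bool(m_raw and re.search(r"[A-Z][a-z]+", m_raw.group(1)))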

        grounding_hits = sum([has_numbers, has_examples, has_entities])

        if grounding_hits >= 1:
            score += 0.15
        if grounding_hits >= 2:
            score += 0.15

        # -----------------------------------
        # 7️⃣ Penalize pure fluff phrases
        # -----------------------------------
        fluff_phrases = [
            "it is obvious", "clearly", "everyone knows",
            "needless to say", "without loss of generality"
        ]

        if any(p in reasoning for p in fluff_phrases):
            score -= 0.3

        # -----------------------------------
        # 8️⃣ Final clamp (keep GRPO stable)
        # -----------------------------------
        score = max(-0.6, min(0.9, score))
        scores.append(score)

    return scores


## Evaluate


Before training, we evaluate the model on the test set to establish a baseline,
so that we can measure the improvement after training.

We evaluate it in two ways:

**Quantitative**

* **Answer Accuracy**: percentage of samples for which the model predicts the correct final numerical answer.
* **Answer (Partial) Accuracy**: percentage of samples for which the model predicts a final numerical answer such that the `model answer / answer` ratio lies between 0.9 and 1.1.
* **Format Accuracy**: percentage of samples for which the model outputs the correct format, i.e., reasoning between the `<reasoning>` and `</reasoning>` tags, and the final answer between the `<answer>` and `</answer>` tags.
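The three quantitative checks can be sketched per sample as below. `FORMAT_RE`, `NUMBER_RE`, and `score_response` are hypothetical stand-ins for the notebook's own `match_format` and `match_numbers` regexes (defined elsewhere), so the exact matching rules may differ. The partial-credit test mirrors the 0.9 to 1.1 ratio band described above, and each percentage is then `100 * hits / total` over the test set.

```python
import re

# Hypothetical patterns; the notebook's real regexes may be stricter or looser.
FORMAT_RE = re.compile(r"<reasoning>.+?</reasoning>\s*<answer>.+?</answer>", re.S)
# First number inside the <answer> block (stops at the next '<').
NUMBER_RE = re.compile(r"<answer>[^<]*?(-?\d+(?:\.\d+)?)")


def score_response(response: str, answer: str) -> dict:
    """Returns the exact / partial / format checks for one sample."""
    result = {
        "correct": False,
        "partial": False,
        "format": FORMAT_RE.search(response) is not None,
    }
    m = NUMBER_RE.search(response)
    if m is None:
        return result
    try:
        pred, gold = float(m.group(1)), float(answer)
    except ValueError:
        return result
    result["correct"] = pred == gold
    # Partial credit: ratio within [0.9, 1.1]; undefined for a zero reference.
    if gold != 0:
        result["partial"] = 0.9 <= pred / gold <= 1.1
    return result


examples = [
    ("<reasoning>12 * 3 = 36</reasoning><answer>36</answer>", "36"),
    ("<answer>38</answer>", "36"),
    ("<reasoning>no idea</reasoning><answer>unknown</answer>", "36"),
]
for response, answer in examples:
    print(score_response(response, answer))
```

The second example shows why partial accuracy is reported separately: an off-by-a-little answer with no reasoning block still counts as partially correct, but fails both the exact and format checks.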

**Qualitative**

We'll also print outputs for a few given questions so that we can compare the generated output later.


We define a helper function to generate an answer, given a prompt.

In [None]:
def generate(
    question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None
):
  """Given prompt, generates text."""

  if isinstance(question, str):
    input_batch = [
        TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=question,
        ),
    ]
  else:
    input_batch = [
        TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=q,
        )
        for q in question
    ]

  out_data = sampler(
      input_strings=input_batch,
      max_generation_steps=768,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
      seed=seed,
      eos_tokens=[1, 106],  # Gemma 3 token ids: <eos> = 1, <end_of_turn> = 106
  )

  output = out_data.text
  if isinstance(question, str):
    return output[0]
  return output
    

Another helper function for evaluation.

In [None]:
def evaluate(
    dataset,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """Computes accuracy and percentage of outputs matching the format."""

  response_lst = []
  corr = 0
  partially_corr = 0
  corr_format = 0
  total = 0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]

    multiple_call_responses = [[] for _ in range(len(questions))]
    for p in range(num_passes):
      responses = generate(
          questions, sampler, temperature, top_k, top_p, seed=p
      )
      for idx, response in enumerate(responses):
        multiple_call_responses[idx].append(response)

    for question, multiple_call_response, answer in zip(
        questions, multiple_call_responses, answers
    ):
      # check answer
      corr_ctr_per_question = 0
      partially_corr_per_question = 0
      corr_format_per_question = 0
      for response in multiple_call_response:
        extracted_response = (
            guess.group(1)
            if (guess := match_numbers.search(response)) is not None
            else "-1000000"
        )
        try:
          if float(extracted_response.strip()) == float(answer.strip()):
            corr_ctr_per_question += 1

          ratio = float(extracted_response.strip()) / float(answer.strip())
          if ratio >= 0.9 and ratio <= 1.1:
            partially_corr_per_question += 1
        except Exception:
          # Non-numeric extraction or reference, or a zero reference answer.
          print("SKIPPED")

        # check format
        if match_format.search(response) is not None:
          corr_format_per_question += 1

        if (
            corr_ctr_per_question > 0
            and partially_corr_per_question > 0
            and corr_format_per_question > 0
        ):
          break

      if corr_ctr_per_question > 0:
        corr += 1
        if corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      else:
        if not corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      if partially_corr_per_question > 0:
        partially_corr += 1
      if corr_format_per_question > 0:
        corr_format += 1

      total += 1
      if total % 10 == 0:
        print(
            f"===> {corr=}, {total=}, {corr / total * 100=}, "
            f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
        )

  to_return = (
      corr,
      total,
      corr / total * 100,
      partially_corr / total * 100,
      corr_format / total * 100,
  )
  if make_lst:
    return to_return, response_lst
  return to_return

In [None]:
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

Now let's see how the original model does on the test set. The output reports the percentages of model outputs that are fully correct, partially correct, and correctly formatted. The following step might take a couple of minutes to finish.

In [None]:
# The evaluation might take a couple of minutes to finish. Please be patient.

(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["greedy"],
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
    f" {format_accuracy=}%"
)

## Train

Let's set up all the configs first: checkpointing, metric logging, and training.
We then train the model.

In [None]:
# ===============================
# Checkpointing options
# ===============================
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS,
    max_to_keep=MAX_TO_KEEP,
)

# ===============================
# Metrics logger options (NEW API)
# ===============================
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tmp/tensorboard/grpo",
    project_name="tunix-grpo",
    run_name="gemma3-1b-grpo",
    flush_every_n_steps=20,
)


In [None]:
# Optimizer, learning rate scheduler, gradient clipping
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)
if MAX_GRAD_NORM is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
      optimizer,
  )
    

In [None]:
# Training config
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        # metrics logging
        metrics_logging_options=metrics_logging_options,
        # checkpoint saving
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        eos_tokens=[1,106],
    ),
)
grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

### Setting Up the GRPO Trainer

Now we initialize our system for training. First, we create an `RLCluster` instance, which brings together the **policy model (`actor`)**, a **reference model (`reference`)**, and a **tokenizer**. Our `actor` is a trainable LoRA model, while the `reference` is a frozen base model; GRPO uses it to compute a KL penalty (weighted by `beta`) that keeps the policy from drifting too far from its starting point.

We then create a `GRPOLearner`, the specialized trainer that uses a list of **reward functions** to evaluate and optimize the model's output, completing the RL training setup.
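GRPO scores each prompt's `NUM_GENERATIONS` completions with every reward function and then normalizes the totals within that group, so each completion is judged relative to its siblings rather than against a learned value model. The sketch below shows that group-relative advantage step in plain Python. It assumes the per-function scores are summed per completion and normalized with a population standard deviation plus a small epsilon; `group_relative_advantages` is a hypothetical helper, and Tunix's internal normalization details may differ.

```python
from statistics import mean, pstdev


def group_relative_advantages(reward_matrix, eps=1e-4):
    """Hypothetical helper.

    reward_matrix[i][j] is the score that reward function j gave to
    completion i of a single prompt (one GRPO group).
    """
    # Assumption: scores from all reward functions are summed per completion.
    totals = [sum(row) for row in reward_matrix]
    mu, sigma = mean(totals), pstdev(totals)
    # eps avoids division by zero when every completion scores the same.
    return [(r - mu) / (sigma + eps) for r in totals]


# One prompt, 4 completions, 3 reward functions (made-up scores);
# completion 2 is a refusal hit by the -20 penalty from `punish_refusal`.
group = [
    [1.0, 3.0, 0.5],
    [1.0, 0.0, 0.5],
    [-20.0, 0.0, 0.0],
    [1.0, 3.0, 1.5],
]
print([round(a, 3) for a in group_relative_advantages(group)])
```

Because the refusal drags the group mean down to -2.125, even the weak second completion (total 1.5) receives a positive advantage. Large-magnitude penalties like this mainly separate refusals from everything else in a group, which is why the relative scale of the reward functions matters.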

Tunix trainers are integrated with [Weights & Biases](https://wandb.ai/) to help you visualize the training progress. You can choose how you want to use it:

**Option 1 (Type 1)**: If you're running a quick experiment or just testing things out, choose this. It creates a temporary, private dashboard right in your browser without requiring you to log in or create an account.

**Option 2 (Type 2)**: If you have an existing W&B account and want to save your project's history to your personal dashboard, choose this. You'll be prompted to enter your API key or log in.

## 🧩 Phase-Based Curriculum (R1-Style)

Training is organized into **three phases**, inspired by Open-R1 and DeepSeek-R1.

Rather than enabling all rewards from the beginning, rewards are **gradually introduced**
to stabilize GRPO and prevent early collapse.

| Phase | Training Progress | Goal |
|------|------------------|------|
| Phase 0 | 0–25% | Learn to respond + follow format |
| Phase 1 | 25–75% | Learn to be correct |
| Phase 2 | 75–100% | Be concise and strict |

Each phase enables a different subset of reward functions.
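The table can be read as a mapping from training progress to an active reward set. The sketch below computes the per-phase step budgets with the same 25% / 50% / remainder split as the phase cells and maps a global step to its phase. `phase_step_budgets`, `phase_for_step`, and `active_rewards` are hypothetical helpers for illustration; the notebook itself realizes the curriculum by running one `GRPOLearner` per phase rather than switching rewards per step.

```python
def phase_step_budgets(max_steps: int) -> tuple[int, int, int]:
    """Splits max_steps into the 25% / 50% / 25% phase budgets."""
    phase0 = int(0.25 * max_steps)
    phase1 = int(0.50 * max_steps)
    phase2 = max_steps - phase0 - phase1  # remainder absorbs rounding
    return phase0, phase1, phase2


def phase_for_step(step: int, max_steps: int) -> int:
    """Maps a 0-indexed global step to its curriculum phase."""
    p0, p1, _ = phase_step_budgets(max_steps)
    if step < p0:
        return 0
    if step < p0 + p1:
        return 1
    return 2


# Names of the reward functions passed to each phase's GRPOLearner.
PHASE_REWARDS = {
    0: ["punish_refusal", "match_format_exactly", "match_format_approximately"],
    1: ["punish_refusal", "match_format_exactly", "match_format_approximately",
        "check_answer", "reasoning_quality_reward"],
    2: ["punish_refusal", "match_format_exactly", "match_format_approximately",
        "check_answer", "penalize_length_and_rambling"],
}


def active_rewards(step: int, max_steps: int) -> list[str]:
    return PHASE_REWARDS[phase_for_step(step, max_steps)]


for step in (0, 30, 80):
    print(step, phase_for_step(step, 100), active_rewards(step, 100)[-1])
```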


In [None]:
# RL cluster
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

The first couple of training steps might take up to 5 minutes to finish. Please be patient. If you experience long training steps (e.g., more than 10 minutes per step), please open a bug. Really appreciated!

---

## Phase 0 (Bootstrap): Learning to Respond and Follow Structure

**Objective:**  
Ensure the model always responds and learns the required output structure before optimizing correctness.

In the early stages of GRPO training, applying correctness or length-based penalties can destabilize learning. The model must first learn **how to respond**, not **how to be perfect**.

### Enabled Rewards
- **Refusal penalty**  
  Penalizes empty responses, refusals, or deflections.
- **Strict format reward**  
  Requires exact usage of:
  ```
  <reasoning>...</reasoning> <answer>...</answer>
  ```
- **Soft format reward**  
  Provides partial credit for near-correct formatting.

### Disabled Rewards
- Answer correctness
- Length / verbosity penalties

### Why this phase matters
- Prevents RL collapse from early negative rewards
- Teaches the model to separate reasoning from answers
- Establishes a stable behavioral baseline

This phase mirrors the early “format bootstrapping” stage used in Open-R1–style training.
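A minimal sketch of how a strict and a soft format reward differ. `strict_format_score`, `soft_format_score`, and the reward values (3.0 for strict, ±0.5 per tag for soft) are illustrative stand-ins, not the notebook's `match_format_exactly` / `match_format_approximately`, which are defined elsewhere and may use other patterns and values.

```python
import re

# Illustrative stand-ins for the notebook's format rewards.
STRICT_RE = re.compile(
    r"^\s*<reasoning>.+?</reasoning>\s*<answer>.+?</answer>\s*$", re.S
)
TAGS = ["<reasoning>", "</reasoning>", "<answer>", "</answer>"]


def strict_format_score(text: str) -> float:
    """Full credit only if the whole output is a reasoning block followed by an answer block."""
    return 3.0 if STRICT_RE.match(text) else 0.0


def soft_format_score(text: str) -> float:
    """Partial credit: +0.5 for each tag that appears exactly once, -0.5 otherwise."""
    return sum(0.5 if text.count(tag) == 1 else -0.5 for tag in TAGS)


good = "<reasoning>2 + 2 = 4</reasoning>\n<answer>4</answer>"
print(strict_format_score(good), soft_format_score(good))  # 3.0 2.0
print(strict_format_score(good + " Hope this helps!"))     # 0.0
```

The soft reward keeps a gradient signal alive while the model is still producing near-misses, which the all-or-nothing strict reward cannot provide on its own.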

In [None]:
PHASE_0_STEPS = int(0.25 * MAX_STEPS)

grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        punish_refusal,             # üö® prevent collapse
        match_format_exactly,       # üö™ structure
        match_format_approximately, # üß≠ soft guidance
    ],
    algo_config=grpo_config,
)

with mesh:
    grpo_trainer.train(
        train_dataset,
        max_steps=PHASE_0_STEPS,
    )


## Phase 1 (Correctness): Learning to Solve the Task

**Objective:**  
Teach the model to produce correct answers *within* the learned reasoning format.

Once the model reliably follows the required structure, we introduce correctness-based rewards. At this stage, the model already knows *how* to answer; now it learns *what* the correct answer is.

### Enabled Rewards
- **Refusal penalty**
- **Strict + soft format rewards**
- **Domain-aware correctness reward**
  - Exact match → full reward
  - Near match → partial reward
  - Numeric tolerance for math tasks
- **Reasoning-quality reward**

### Disabled Rewards
- Length / verbosity penalties

### Why length penalties are disabled
Applying conciseness pressure too early can cause:
- Truncated reasoning
- Incomplete answers
- Mode collapse

This phase focuses purely on **accuracy**, not polish.
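The tiered correctness reward described above (exact match, near match, numeric tolerance) can be sketched as follows. `correctness_score` and its reward values (3.0 / 1.5 / 0.5 / -1.0) are hypothetical; the notebook's actual `check_answer` is defined elsewhere and may use different tiers, normalization, and values.

```python
import math


def correctness_score(predicted: str, reference: str) -> float:
    """Hypothetical tiered correctness reward (values are illustrative)."""
    pred_s, ref_s = predicted.strip(), reference.strip()
    if pred_s == ref_s:
        return 3.0  # exact match: full reward
    if pred_s.lower() == ref_s.lower():
        return 1.5  # near match (differs only in case): partial reward
    try:
        pred, gold = float(pred_s), float(ref_s)
    except ValueError:
        return -1.0  # non-numeric mismatch
    if math.isclose(pred, gold, rel_tol=1e-9, abs_tol=1e-9):
        return 3.0  # same number written differently, e.g. "4.0" vs "4"
    if gold != 0 and 0.9 <= pred / gold <= 1.1:
        return 0.5  # numeric tolerance: within +/-10% of the reference
    return -1.0


for pred, ref in [("42", "42"), ("Paris", "paris"), ("4.0", "4"), ("38", "36"), ("50", "36")]:
    print(pred, ref, correctness_score(pred, ref))
```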

In [None]:
PHASE_1_STEPS = int(0.50 * MAX_STEPS)

grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        punish_refusal,
        match_format_exactly,
        match_format_approximately,
        check_answer,               # üéØ correctness ON
        reasoning_quality_reward
    ],
    algo_config=grpo_config,
)

with mesh:
    grpo_trainer.train(
        train_dataset,
        max_steps=PHASE_1_STEPS,
    )


## Phase 2 (Polish): Being Correct, Concise, and Controlled

**Objective:**  
Refine the model’s outputs to be short, clean, and strictly formatted.

At this stage, the model already:
- Responds reliably
- Uses the correct format
- Produces mostly correct answers

We now apply termination pressure to discourage verbosity and rambling.

### Enabled Rewards
- **Refusal penalty**
- **Strict + soft format rewards**
- **Correctness reward**
- **Length and rambling penalty**
  - Penalizes overly long reasoning
  - Penalizes text after `</answer>`
  - Encourages concise explanations

### What this phase achieves
- Sharp, professional reasoning traces
- Clean separation between reasoning and final answer
- Better alignment with human and LLM-as-judge evaluation

This final phase aligns the model with real-world expectations for transparent but concise reasoning.
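To get a feel for how steep the length term in `penalize_length_and_rambling` is, the snippet below reproduces only that term with the same constants (`MAX_LEN = 80`, `LENGTH_PENALTY = 4.0`) and evaluates it for a few response lengths. `length_penalty` is a standalone re-derivation for illustration, not a replacement for the notebook's function (which also penalizes text after `</answer>` and rambling markers).

```python
# Same constants as in `penalize_length_and_rambling`.
MAX_LEN = 80          # characters
LENGTH_PENALTY = 4.0


def length_penalty(num_chars: int) -> float:
    """Length term only: linear in the fraction by which the output exceeds MAX_LEN."""
    if num_chars <= MAX_LEN:
        return 0.0
    return -LENGTH_PENALTY * (num_chars - MAX_LEN) / MAX_LEN


for n in (80, 200, 400, 800):
    print(n, length_penalty(n))
# 80 0.0
# 200 -6.0
# 400 -16.0
# 800 -36.0
```

Since `len(text)` counts characters, an 80-character budget is roughly one short sentence, so almost any output with a full `<reasoning>` block is penalized, and an 800-character response receives -36, a larger magnitude than the -20 refusal penalty.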

In [None]:
PHASE_2_STEPS = MAX_STEPS - PHASE_0_STEPS - PHASE_1_STEPS

grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        punish_refusal,
        match_format_exactly,
        match_format_approximately,
        check_answer,
        penalize_length_and_rambling,  # only now
    ],
    algo_config=grpo_config,
)

with mesh:
    grpo_trainer.train(
        train_dataset,
        max_steps=PHASE_2_STEPS,
    )

## Evaluate

Let's evaluate our finetuned model!

In [None]:
# Load checkpoint first.
import re

# Find the latest checkpoint by listing directories in CKPT_DIR/actor
actor_ckpt_dir = os.path.join(CKPT_DIR, "actor")

latest_step = -1
if os.path.exists(actor_ckpt_dir):
  for item in os.listdir(actor_ckpt_dir):
    if os.path.isdir(os.path.join(actor_ckpt_dir, item)) and re.match(r'^\d+$', item):
      step = int(item)
      if step > latest_step:
        latest_step = step

if latest_step == -1:
  raise FileNotFoundError(f"No checkpoints found in {actor_ckpt_dir}")

print(f"Latest checkpoint step: {latest_step}")

wandb.init(project='tunix-eval')  # logging bug workaround

trained_ckpt_path = os.path.join(
    CKPT_DIR, "actor", str(latest_step), "model_params"
)

abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

nnx.update(
    lora_policy,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(lora_policy, nnx.LoRAParam),
        trained_lora_params,
    ),
)

In [None]:
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

In [None]:
# The evaluation might take a couple of minutes to finish. Please be patient.
(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["greedy"],
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
    f" {format_accuracy=}%"
)

With sufficient training, the accuracy and format percentages should be clearly higher than the pre-training baseline, which indicates that training worked.

In [None]:
# =====================================
# SAVE FULL TUNIX-LOADABLE CHECKPOINT
# =====================================
import os
import json
import orbax.checkpoint as ocp
from flax import nnx

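# Orbax checkpointers generally expect an absolute checkpoint path
EXPORT_DIR = os.path.abspath("./tunix_full_model")
os.makedirs(EXPORT_DIR, exist_ok=True)

# Extract the FULL model state: base weights plus the separate LoRA
# adapter params (LoRA is NOT merged into the base weights here)
full_params = nnx.state(lora_policy)

# Save using Orbax (Tunix-compatible)
full_checkpointer = ocp.StandardCheckpointer()
full_checkpointer.save(
    EXPORT_DIR,
    full_params,
    force=True,
)
# Saving may run asynchronously; block until it is complete before
# writing extra files into the same directory
full_checkpointer.wait_until_finished()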

# Optional but recommended: minimal model card
model_card = {
    "model_name": "gemma3-1b-grpo-reasoning",
    "base_model": "google/gemma-3-1b-it",
    "framework": "tunix / jax / flax-nnx",
    "training_method": "GRPO",
    "task": "reasoning"
}

with open(os.path.join(EXPORT_DIR, "model_card.json"), "w") as f:
    json.dump(model_card, f, indent=2)

print("✅ Full Tunix model saved at:", EXPORT_DIR)

In [None]:
# =====================================
# UPLOAD FULL MODEL TO KAGGLE (MODEL)
# =====================================

import json
import subprocess

MODEL_SLUG = "gemma3-1b-grpo-reasoning-tunix"

OWNER = subprocess.check_output(["kaggle", "whoami"]).decode().strip()
MODEL_ID = f"{OWNER}/{MODEL_SLUG}"

print(f"📦 Uploading Kaggle model: {MODEL_ID}")

# Kaggle model metadata
metadata = {
    "title": "Gemma3 1B GRPO Reasoning (Tunix)",
    "id": MODEL_ID,
    "licenses": [{"name": "apache-2.0"}]
}

with open("model-metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)

# Create model (no-op if already exists)
!kaggle models create -p . --metadata model-metadata.json || true

# Upload model version
!kaggle models versions create \
    -p tunix_full_model \
    -m "Full Gemma3 1B model trained with GRPO (Tunix compatible)"

print(MODEL_ID)