# GRPO Demo

This tutorial demonstrates training the [Gemma](https://deepmind.google/models/gemma/) 
2 2B-IT model on the [GSM8K math reasoning benchmark](https://huggingface.co/datasets/openai/gsm8k) 
using [Group Relative Policy Optimization (GRPO)](https://arxiv.org/pdf/2402.03300). 
GRPO can enhance your model's problem-solving skills on mathematical word problems,
coding problems, etc.

GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It
is a variant of [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347) 
that reduces memory usage by eliminating the need for a separate value function
model. GRPO works by generating multiple responses for a given prompt,
evaluating these responses with a reward model (in this tutorial, a set of
rule-based reward functions), and then calculating a relative advantage based
on the group's performance to update the policy.
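To illustrate the group-relative idea, the sketch below normalizes each completion's reward by its group's mean and standard deviation. It then applies the PPO-style clipped surrogate, which is controlled by `EPSILON` later in this notebook. This is a minimal pure-Python sketch under stated assumptions. It uses the population standard deviation and a small `eps` to avoid division by zero, and the helper names `group_advantages` and `clipped_surrogate` are hypothetical. Tunix's `GRPOLearner` may differ in these details.

```python
import math
from typing import List


def group_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
  """Normalizes the rewards of one group (G completions for the same prompt).

  Assumption: uses the population standard deviation; an implementation may
  use the sample standard deviation or a different `eps` instead.
  """
  mean = sum(rewards) / len(rewards)
  std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / len(rewards))
  # Completions that beat the group average get a positive advantage.
  return [(r - mean) / (std + eps) for r in rewards]


def clipped_surrogate(ratio: float, advantage: float,
                      epsilon: float = 0.2) -> float:
  """PPO-style clipped objective for one token, to be maximized.

  `ratio` is pi_theta(token) / pi_theta_old(token).
  """
  clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
  return min(ratio * advantage, clipped_ratio * advantage)


# A group of G = 4 completions: two correct (reward 3.0), two wrong (0.0).
print(group_advantages([3.0, 0.0, 0.0, 3.0]))  # roughly [1, -1, -1, 1]
# If every completion gets the same reward, the group gives no learning signal.
print(group_advantages([1.0, 1.0, 1.0, 1.0]))  # [0.0, 0.0, 0.0, 0.0]
# A large ratio cannot push the objective past (1 + epsilon) * advantage.
print(clipped_surrogate(1.5, 1.0))
```

Because the group mean serves as the baseline, GRPO does not need to train a separate value network. This is where the memory savings over PPO come from.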

In this tutorial we use a `v5e-8` TPU for Gemma2-2b-it. Let's get started!

Note that the setup below is for the Gemma2-2B-IT model only. If you want to use
another model (say, Qwen2.5), you may need to change the setup (for example, 
tokenizer, chat template, reward function, etc.).

## Install necessary libraries

In [1]:
!pip install -q wandb
!pip install -q kagglehub

!pip install -q ipywidgets

!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain
!pip install "google-tunix[prod]==0.1.3"

# !pip install -q git+https://github.com/google/tunix
# !pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
# !pip install -U flax
!pip install flax==0.12.0

!pip install -q datasets wandb==0.22.0

Successfully installed google-tunix-0.1.3 hf_transfer-0.1.9 libtpu-0.0.20 llvmlite-0.46.0 numba-0.63.1
Successfully installed flax-0.12.0

In [2]:
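import os

import wandb
from kaggle_secrets import UserSecretsClient

# Read the Weights & Biases API key from Kaggle's secret store so that W&B
# logging can authenticate without an interactive login.
os.environ['WANDB_API_KEY'] = UserSecretsClient().get_secret("WANDB_API_KEY")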



## Imports

In [3]:
import functools
import gc
import os
from pprint import pprint
import re

import csv
import shutil

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
from pathlib import Path
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma import model as gemma_lib
from tunix.models.gemma import params as params_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger



## Hyperparameters

Let's define the configuration we are going to use. Note that this is by no
means a "perfect" set of hyperparameters. To get good results, you might have
to train the model for longer.

In [4]:
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

# ====== LoRA ======
RANK = 64
ALPHA = 64.0

# ====== Sharding ======
MESH = [(1, 4), ("fsdp", "tp")]

# ====== GRPO ======
# === Generation during GRPO training ===
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 512
# Important to keep a high-ish temperature for varied, diverse responses during
# training.
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = 50
# The number of responses the policy generates for each prompt within a single
# training step. This corresponds to `G` in Algorithm 1 in the paper. The
# "group" in GRPO comes from here.
NUM_GENERATIONS = 4

# === other GRPO configs ===
# The number of iterations per batch (𝜇 in GRPO algo 1).
NUM_ITERATIONS = 1
# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function.
# Important to keep a high enough value for this, otherwise, the KL divergence
# can increase unchecked.
BETA = 0.08
# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for
# stable updates.
EPSILON = 0.2

# ====== Training ======
TRAIN_MICRO_BATCH_SIZE = 2
# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
NUM_BATCHES = 2738
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
# increased to at most 330 with a batch size of 4 (the GSM8K test split has
# 1,319 examples).
NUM_TEST_BATCHES = 50

EVAL_EVERY_N_STEPS = 10  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
NUM_EPOCHS = 1  # can potentially train for more epochs

# Number of training steps.
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
# == Cosine decay with warmup scheduler ==
# Linearly increase the learning rate from 0 to `LEARNING_RATE` over the first
# 10% of training steps, then decay it to 0 with a cosine schedule.
WARMUP_STEPS = 0.1 * MAX_STEPS
# == Grad clipping ==
# Clip gradients to prevent overly large updates. We found this important for
# keeping the KL divergence in check.
MAX_GRAD_NORM = 0.1

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/ckpts/"
SAVE_INTERVAL_STEPS = 100
MAX_TO_KEEP = 4

# ====== Inference ======
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}
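To make the step arithmetic concrete, the sketch below recomputes `MAX_STEPS` and `WARMUP_STEPS` from the values above. It then evaluates a linear-warmup, cosine-decay learning-rate curve of the shape described in the comments. The optimizer itself is built later in the notebook. The pure-Python `lr_at` helper here is a hypothetical stand-in, assuming the schedule starts at 0, peaks at `LEARNING_RATE` after warmup, and decays to 0 at `MAX_STEPS`. That is the shape `optax.warmup_cosine_decay_schedule` produces with those arguments.

```python
import math

# Values copied from the configuration cell above.
NUM_BATCHES = 2738
NUM_ITERATIONS = 1
TRAIN_FRACTION = 1.0
NUM_EPOCHS = 1
TRAIN_MICRO_BATCH_SIZE = 2
NUM_GENERATIONS = 4
LEARNING_RATE = 3e-6

MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)
WARMUP_STEPS = 0.1 * MAX_STEPS
# Each step samples NUM_GENERATIONS completions for every prompt in the batch.
COMPLETIONS_PER_STEP = TRAIN_MICRO_BATCH_SIZE * NUM_GENERATIONS


def lr_at(step: float, peak: float = LEARNING_RATE,
          warmup: float = WARMUP_STEPS, total: int = MAX_STEPS) -> float:
  """Hypothetical linear-warmup + cosine-decay schedule (0 -> peak -> 0)."""
  if step < warmup:
    return peak * step / warmup
  progress = min(1.0, (step - warmup) / (total - warmup))
  return peak * 0.5 * (1.0 + math.cos(math.pi * progress))


print(MAX_STEPS, WARMUP_STEPS, COMPLETIONS_PER_STEP)
for step in (0, WARMUP_STEPS, MAX_STEPS // 2, MAX_STEPS):
  print(f"step {step}: lr = {lr_at(step):.2e}")
```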

## Utility functions

In [5]:
def show_hbm_usage():
  """Displays memory usage per device."""
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    used = stats["bytes_in_use"]
    limit = stats["bytes_limit"]
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

## Data preprocessing

First, let's define the tags that structure the output. We instruct the model
to first reason between the `<reasoning>` and `</reasoning>` tags. After
reasoning, we expect it to provide the answer between the `<answer>` and
`</answer>` tags.

In [6]:
reasoning_start = "<reasoning>"
reasoning_end = "</reasoning>"
solution_start = "<answer>"
solution_end = "</answer>"


SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \
provide your reasoning. Place it between {reasoning_start} and \
{reasoning_end}. Then, provide the final answer (i.e., just one numerical \
value) between {solution_start} and {solution_end}."""

TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""
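To see exactly what the policy receives, the snippet below renders `TEMPLATE` for a sample question. The question text is made up for illustration. The prompt ends with an open `<start_of_turn>model` turn, so the model's completion starts right after it.

```python
# Copies of the definitions above, so this snippet runs on its own.
reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"

SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \
provide your reasoning. Place it between {reasoning_start} and \
{reasoning_end}. Then, provide the final answer (i.e., just one numerical \
value) between {solution_start} and {solution_end}."""

TEMPLATE = """<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""

# Hypothetical question, for illustration only.
prompt = TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question="What is 12 * 3?")
print(prompt)
```

The trailing backslashes join the system prompt into a single line, so only the template itself introduces line breaks.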

We use OpenAI's [GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k), which comprises grade school math word problems.

In [7]:
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


def _load_from_tfds(data_dir: str, split: str):
  import tensorflow_datasets.text.gsm8k
  return tfds.data_source(
      "gsm8k",
      split=split,
      data_dir=data_dir,
      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
      download=True,
  )


def download_kaggle_dataset(target_dir="./data/gsm8k"):
  os.makedirs(target_dir, exist_ok=True)
  src = kagglehub.dataset_download("thedevastator/grade-school-math-8k-q-a")
  src = Path(src)
  dst = Path(target_dir)

  for csv_file in src.glob("*.csv"):  # match all CSV files
    shutil.copy2(csv_file, dst / csv_file.name)
    print(f"Copied {csv_file.name} → {dst/csv_file.name}")
  return target_dir


def get_dataset(data_dir, split="train", source="tfds") -> grain.MapDataset:
  # Create the data directory if needed (the download happens below).
  if not os.path.exists(data_dir):
    os.makedirs(data_dir)

  if source == "tfds":
    import tensorflow_datasets.text.gsm8k
    data = tfds.data_source(
        "gsm8k",
        split=split,
        data_dir=data_dir,
        builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
        download=True,
    )

  elif source == "kaggle":
    kaggle_dir = download_kaggle_dataset(data_dir)
    file_name = "main_" + split + ".csv"
    csv_path = os.path.join(kaggle_dir, file_name)  # adjust filename if needed

    data = []
    with open(csv_path, newline="", encoding="utf-8") as csvfile:
      reader = csv.DictReader(csvfile)
      for row in reader:
        data.append({
            "question": row["question"],
            "answer": row["answer"],
        })

  else:
    raise ValueError(f"Unknown source: {source}")

  def _as_text(v):
    return v if isinstance(v, str) else v.decode("utf-8")

  dataset = (
      grain.MapDataset.source(data)
      .shuffle(seed=42)
      .map(
          lambda x: {
              # passed to model forward pass
              "prompts": TEMPLATE.format(
                  system_prompt=SYSTEM_PROMPT,
                  question=_as_text(x["question"]),
              ),
              # passed to reward functions
              "question": _as_text(x["question"]),
              # passed to reward functions
              "answer": extract_hash_answer(_as_text(x["answer"])),
          }
      )
  )
  return dataset

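We batch the data and split the training data into train and validation sets (the validation set is empty when `TRAIN_FRACTION = 1.0`). The test set comes from GSM8K's separate test split.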

In [8]:
# source = input("Choose data source [tfds/kaggle]: ").strip().lower()
source = 'kaggle'

if source not in ("tfds", "kaggle"):
  print("Invalid choice. Defaulting to 'tfds'.")
  source = "tfds"

print(f"Using data source: {source}")

dataset = get_dataset(TRAIN_DATA_DIR, "train", source).batch(TRAIN_MICRO_BATCH_SIZE)[
    :NUM_BATCHES
]

if TRAIN_FRACTION == 1.0:
  train_dataset = dataset.repeat(NUM_EPOCHS)
  val_dataset = None
else:
  train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]
  train_dataset = train_dataset.repeat(NUM_EPOCHS)

  val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)

test_dataset = get_dataset(TEST_DATA_DIR, "test", source).batch(TRAIN_MICRO_BATCH_SIZE)[
    :NUM_TEST_BATCHES
]

dataset_lengths = (
    len(train_dataset),
    len(val_dataset) if val_dataset is not None else 0,
    len(test_dataset),
)
print(f"dataset contains {dataset_lengths} of batches")

Using data source: kaggle
Copied main_test.csv → data/train/main_test.csv
Copied main_train.csv → data/train/main_train.csv


Copied socratic_train.csv → data/train/socratic_train.csv
Copied socratic_test.csv → data/train/socratic_test.csv
Copied main_test.csv → data/test/main_test.csv
Copied main_train.csv → data/test/main_train.csv
Copied socratic_train.csv → data/test/socratic_train.csv
Copied socratic_test.csv → data/test/socratic_test.csv
dataset contains (2738, 0, 50) of batches
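With `TRAIN_FRACTION = 1.0` the validation set is empty, which is why the middle count is 0. The sketch below reproduces the slicing logic on a plain list of batch indices to show what a smaller fraction would produce. The value `TRAIN_FRACTION = 0.9` used here is hypothetical, not the one used in this run.

```python
from typing import List, Tuple

NUM_BATCHES = 2738
batches = list(range(NUM_BATCHES))  # stand-ins for the training batches


def split_batches(batches: List[int],
                  train_fraction: float) -> Tuple[List[int], List[int]]:
  # Mirrors the notebook: keep everything for training when the fraction is 1.0.
  if train_fraction == 1.0:
    return batches, []
  cut = int(len(batches) * train_fraction)
  return batches[:cut], batches[cut:]


train, val = split_batches(batches, 0.9)
print(len(train), len(val))
```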


Let's see what one batch of the training dataset looks like!


In [9]:
for ele in train_dataset[:1]:
  pprint(ele)

{'answer': array(['3', '34'], dtype='<U2'),
 'prompts': array(['<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nMaria has 4 dimes, 4 quarters, and 7 nickels in her piggy bank. Her mom gives her 5 quarters. How much money, in dollars, does Maria have now?<end_of_turn>\n<start_of_turn>model',
       '<start_of_turn>user\nYou are given a problem. Think about the problem and provide your reasoning. Place it between <reasoning> and </reasoning>. Then, provide the final answer (i.e., just one numerical value) between <answer> and </answer>.\n\nA wildlife team is monitoring the number of birds in a park. There are 3 blackbirds in each of the park’s 7 trees. There are also 13 magpies roaming around the park. How many birds are in the park in total?<end_of_turn>\n<start_of_turn>model'],
      dtype='<U

## Load the policy model and the reference model

The policy model is the model that is actually trained and whose weights are
updated. The reference model is the model against which we compute the KL
divergence. The KL penalty keeps each policy update small, so that the policy
does not drift too far from the reference model.
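For intuition about the KL term weighted by `BETA`, the GRPO paper linked above estimates the per-token KL divergence from the policy's and the reference model's token probabilities with the estimator r - log(r) - 1, where r = pi_ref / pi_theta. The following is a pure-Python sketch of that formula. Whether Tunix computes the penalty in exactly this form is an assumption.

```python
import math


def per_token_kl(policy_logprob: float, ref_logprob: float) -> float:
  """Estimator r - log(r) - 1 with r = pi_ref(token) / pi_theta(token).

  It is always >= 0 and equals 0 only when both models agree on the token.
  """
  log_ratio = ref_logprob - policy_logprob  # log(pi_ref / pi_theta)
  return math.exp(log_ratio) - log_ratio - 1.0


BETA = 0.08  # KL coefficient from the hyperparameter cell

same = per_token_kl(math.log(0.5), math.log(0.5))
drifted = per_token_kl(math.log(0.9), math.log(0.5))
# The last value is the penalty this token would add to the loss.
print(same, drifted, BETA * drifted)
```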

Typically, the reference model is the base model, and the policy model is the
same base model, but with LoRA parameters. Only the LoRA parameters are updated.
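Conceptually, LoRA leaves a frozen weight matrix W untouched and learns a low-rank update scaled by alpha / rank, so an adapted layer computes W x + (alpha / rank) B A x. The pure-Python sketch below illustrates this with tiny hand-picked matrices, and the helper names are hypothetical. It also shows the common LoRA initialization in which B starts at zero, so a freshly wrapped policy initially behaves exactly like the reference model. Whether Qwix uses this exact initialization is an assumption. With `RANK = ALPHA = 64` in this notebook, the scale factor is 1.0.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
  return [sum(w * x for w, x in zip(row, v)) for row in m]


def lora_forward(w: Matrix, a: Matrix, b: Matrix, x: Vector,
                 alpha: float, rank: int) -> Vector:
  """Computes W x + (alpha / rank) * B (A x); W is frozen, A and B are trained."""
  base = matvec(w, x)
  update = matvec(b, matvec(a, x))  # A: rank x d_in, B: d_out x rank
  scale = alpha / rank
  return [y + scale * u for y, u in zip(base, update)]


def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
  """Trainable parameters LoRA adds to one d_out x d_in weight matrix."""
  return rank * (d_in + d_out)


w = [[1.0, 2.0], [3.0, 4.0]]  # frozen 2x2 weight
a = [[1.0, 0.0]]              # rank-1 down-projection (1x2)
b_zero = [[0.0], [0.0]]       # B initialized to zeros (2x1)
b = [[1.0], [2.0]]            # B after some (made-up) training
x = [1.0, 1.0]

print(lora_forward(w, a, b_zero, x, alpha=1.0, rank=1))  # same as W x
print(lora_forward(w, a, b, x, alpha=1.0, rank=1))
# A hypothetical 2304 x 2304 projection with rank 64:
print(lora_param_count(2304, 2304, 64), "vs", 2304 * 2304, "full weights")
```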

Note: We train without quantization (the base weights are restored in bfloat16
by `get_gemma_ref_model` below). You can, however, leverage Qwix for
quantization-aware training (QAT).

To load the model, you need to be on [Kaggle](https://www.kaggle.com/) and
to have accepted the Gemma license
[here](https://www.kaggle.com/models/google/gemma/flax/).

In [10]:
# Log in
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
  kagglehub.login()


In [11]:
model_path = {
    "gemma2": "google/gemma-2/flax/", 
}
model_family = "gemma2"
model_version = "gemma2-2b-it"
print(f"{model_path[model_family]}{model_version}")

kaggle_ckpt_path = kagglehub.model_download(
    f"{model_path[model_family]}{model_version}"
)

google/gemma-2/flax/gemma2-2b-it


In [12]:
print(kaggle_ckpt_path)

/kaggle/input/gemma-2/flax/gemma2-2b-it/1


This code snippet serves as a workaround to re-save the pre-trained model checkpoint from Kaggle into a local format that is compatible with the [Flax NNX](https://flax.readthedocs.io/en/stable/why.html) library. Because the original checkpoint has parameter names and tensor structures that don't match the target NNX model architecture, it cannot be loaded directly.

We first load the original weights into a temporary model instance, then extract and re-save the model's state into a new, properly formatted local checkpoint, which can then be successfully loaded by the final sharded NNX model.

In [13]:
!rm /tmp/content/intermediate_ckpt/* -rf

!rm /tmp/content/ckpts/* -rf

if model_family == "gemma2":
  params = params_lib.load_and_format_params(
      os.path.join(kaggle_ckpt_path, "gemma2-2b-it")
  )
  gemma = gemma_lib.Transformer.from_params(params, version="2-2b-it")
  checkpointer = ocp.StandardCheckpointer()
  _, state = nnx.split(gemma)
  checkpointer.save(os.path.join(INTERMEDIATE_CKPT_DIR, "state"), state)
  checkpointer.wait_until_finished()
  # Delete the intermediate model to save memory.
  del params
  del gemma
  del state
  gc.collect()





### Model Loading and LoRA Application

These two functions work together to load a base model from a checkpoint and apply LoRA (Low-Rank Adaptation) to it.

* `get_gemma_ref_model`: Loads the complete Gemma model from the specified checkpoint path. It uses **JAX sharding** to distribute the model parameters across the devices of the mesh defined by `MESH`.
* `get_lora_model`: Takes the base model and applies LoRA layers to it. It uses a `LoraProvider` to select which layers (the attention and MLP projections) are adapted. The resulting LoRA model is then sharded and updated so that it is ready for distributed training.
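`LoraProvider` selects layers through the `module_path` regular expression used in `get_lora_model`. To check which names a pattern would pick up, you can test it with Python's `re` module. The module paths below are hypothetical examples of how Gemma layers might be named. Matching with `re.fullmatch` is also an assumption about how Qwix applies the pattern.

```python
import re

MODULE_PATH = (
    ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
    ".*attn_vec_einsum"
)

# Hypothetical module paths, for illustration only.
candidates = [
    "layers/0/attn/q_einsum",
    "layers/0/attn/kv_einsum",
    "layers/0/attn/attn_vec_einsum",
    "layers/0/mlp/gate_proj",
    "layers/0/mlp/up_proj",
    "layers/0/mlp/down_proj",
    "embedder",
    "final_norm",
]

selected = [p for p in candidates if re.fullmatch(MODULE_PATH, p)]
print(selected)
```

Under this assumption, the embeddings and norm layers receive no LoRA adapters.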

In [14]:
def get_gemma_ref_model(ckpt_path):
  mesh = jax.make_mesh(*MESH)
  model_config = gemma_lib.ModelConfig.gemma2_2b()
  abs_gemma: nnx.Module = nnx.eval_shape(
      lambda: gemma_lib.Transformer(model_config, rngs=nnx.Rngs(params=0))
  )
  abs_state = nnx.state(abs_gemma)
  abs_state = jax.tree.map(
      lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
      abs_state,
      nnx.get_named_sharding(abs_state, mesh),
  )
  checkpointer = ocp.StandardCheckpointer()
  restored_params = checkpointer.restore(ckpt_path, target=abs_state)

  graph_def, _ = nnx.split(abs_gemma)
  gemma = nnx.merge(graph_def, restored_params)
  return gemma, mesh, model_config


def get_lora_model(base_model, mesh):
  lora_provider = qwix.LoraProvider(
      module_path=(
          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
          ".*attn_vec_einsum"
      ),
      rank=RANK,
      alpha=ALPHA,
  )

  model_input = base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(
      base_model, lora_provider, **model_input
  )

  with mesh:
    state = nnx.state(lora_model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(lora_model, sharded_state)

  return lora_model

Now we load the reference and policy Gemma models with Flax NNX and display the policy model's structure.

In [15]:
# Reference model
if model_family == "gemma2":
  ref_model, mesh, model_config = get_gemma_ref_model(
      ckpt_path=os.path.join(INTERMEDIATE_CKPT_DIR, "state")
  )

In [16]:
# Policy model
lora_policy = get_lora_model(ref_model, mesh=mesh)
nnx.display(lora_policy)

In [17]:

# ==============================================================================
# [New] Load a custom LoRA checkpoint to continue training
# ==============================================================================
import os
from orbax import checkpoint as ocp
import jax
from flax import nnx

# Set your model path (here: /kaggle/input/gemma2-2b-grpo/jax/jax-flax/4).
# Note: per the earlier saving code, the parameters are stored in the "params"
# subdirectory.
custom_ckpt_path = "/kaggle/input/gemma2-2b-grpo/jax/jax-flax/4/params"

print(f"Loading LoRA parameters from {custom_ckpt_path}...")

if os.path.exists(custom_ckpt_path):
    # 1. Define the restore target (only the structure of the model's LoRA parameters).
    abs_params = jax.tree.map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
        nnx.state(lora_policy, nnx.LoRAParam),
    )

    # 2. Read the checkpoint.
    checkpointer = ocp.StandardCheckpointer()
    restored_lora_params = checkpointer.restore(custom_ckpt_path, target=abs_params)

    # 3. Update the model weights (apply the restored parameters to lora_policy).
    nnx.update(
        lora_policy,
        jax.tree.map(
            lambda current, restored: restored,
            nnx.state(lora_policy, nnx.LoRAParam),
            restored_lora_params,
        ),
    )
    print("✅ Custom LoRA parameters loaded successfully! Ready to start fine-tuning.")
else:
    print(f"❌ Path does not exist: {custom_ckpt_path}. Please check that the path is correct.")
# ==============================================================================

Loading LoRA parameters from /kaggle/input/gemma2-2b-grpo/jax/jax-flax/4/params...
✅ Custom LoRA parameters loaded successfully! Ready to start fine-tuning.


In [18]:
if model_family == "gemma2":
  tokenizer = tokenizer_lib.Tokenizer(
      tokenizer_path=os.path.join(kaggle_ckpt_path, "tokenizer.model")
  )

## Define reward functions

We define four reward functions:

- reward if the format of the output exactly matches the instruction given in
`SYSTEM_PROMPT`;
- reward if the format of the output approximately matches that instruction;
- reward if the answer is correct or partially correct;
- reward if the number extracted from the answer is correct. The text between
  `<answer>` and `</answer>` is not always a single number, so we extract the
  number first.

The reward functions are inspired by
[this gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb).

First, let's define a regular expression that checks whether the format matches.

In [19]:
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

match_format.search(
    f"{reasoning_start}Let me"
    f" think!{reasoning_end}{solution_start}2{solution_end}",
)

<re.Match object; span=(0, 54), match='<reasoning>Let me think!</reasoning><answer>2</an>

Give the model a reward of 3 points if the format matches exactly.

In [20]:
def match_format_exactly(prompts, completions, **kwargs):
  return [
      0 if match_format.search(response) is None else 3.0
      for response in completions
  ]
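A quick check on a few hand-written completions (made-up examples) shows the all-or-nothing behavior. Only the fully formatted response earns the 3 points.

```python
import re

reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"

match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)


def match_format_exactly(prompts, completions, **kwargs):
  return [
      0 if match_format.search(response) is None else 3.0
      for response in completions
  ]


completions = [
    "<reasoning>3 bags of 4 apples is 12.</reasoning>\n<answer>12</answer>",
    "The answer is 12.",
    "<answer>12</answer>",  # missing the reasoning block
]
print(match_format_exactly(None, completions))  # [3.0, 0, 0]
```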

We also reward the model if the format of the output matches partially.

In [21]:
def match_format_approximately(prompts, completions, **kwargs):
  scores = []

  for completion in completions:
    score = 0
    response = completion
    # Each tag should appear exactly once: +0.5 if it does, -0.5 if it is
    # missing or repeated.
    score += 0.5 if response.count(reasoning_start) == 1 else -0.5
    score += 0.5 if response.count(reasoning_end) == 1 else -0.5
    score += 0.5 if response.count(solution_start) == 1 else -0.5
    score += 0.5 if response.count(solution_end) == 1 else -0.5
    scores.append(score)
  return scores
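Unlike the exact check, this reward is graded. Each of the four tags contributes +0.5 when it appears exactly once and -0.5 otherwise, so scores range from -2.0 to 2.0. The hand-written completions below (made-up examples) show three typical cases.

```python
reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"


def match_format_approximately(prompts, completions, **kwargs):
  scores = []
  for response in completions:
    score = 0
    # Each tag should appear exactly once: +0.5 if so, -0.5 otherwise.
    for tag in (reasoning_start, reasoning_end, solution_start, solution_end):
      score += 0.5 if response.count(tag) == 1 else -0.5
    scores.append(score)
  return scores


completions = [
    "<reasoning>r</reasoning><answer>12</answer>",                  # all tags once
    "12",                                                           # no tags
    "<reasoning>r</reasoning><answer>1</answer><answer>2</answer>",  # answer tags twice
]
print(match_format_approximately(None, completions))  # [2.0, -2.0, 0.0]
```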

Reward the model if the answer is correct. Partial credit is also given when
the answer does not match exactly but is numerically close to the correct
value (within 10% or 20% of it). Other numeric answers and non-numeric answers
are penalized.

In [22]:
def check_answer(prompts, completions, answer, **kwargs):
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_format.search(r)) is not None else None
      for r in responses
  ]

  scores = []
  assert len(extracted_responses) == len(
      answer
  ), f"{extracted_responses} and {answer} have mismatching length"
  for guess, true_answer in zip(extracted_responses, answer):
    score = 0
    if guess is None:
      scores.append(0)
      continue
    # Correct answer gets 3 points!
    if guess == true_answer:
      score += 3.0
    # Match if spaces are seen
    elif guess.strip() == true_answer.strip():
      score += 1.5
    else:
      # We also reward answers that are numerically close, judged by their
      # ratio to the true answer, i.e., if the answer is within some range.
      try:
        ratio = float(guess) / float(true_answer)
        if ratio >= 0.9 and ratio <= 1.1:
          score += 0.5
        elif ratio >= 0.8 and ratio <= 1.2:
          score += 0.25
        else:
          score -= 1.0  # Penalize wrong answers
      except (ValueError, ZeroDivisionError):
        score -= 0.5  # Penalize non-numeric answers
    scores.append(score)
  return scores
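To see how the tiers play out, the condensed copy below (same scoring rules, restructured for brevity) scores several made-up completions against a true answer of 72. Note that a numerically equal answer such as `72.0` only gets 0.5 via the ratio branch rather than full credit. The number-based reward defined next covers that case.

```python
import re

reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"

match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)


def check_answer(prompts, completions, answer, **kwargs):
  """Condensed copy of the reward above (same tiers), for this demo only."""
  scores = []
  for response, true_answer in zip(completions, answer):
    found = match_format.search(response)
    if found is None:
      scores.append(0)  # Wrong format: no answer to grade.
      continue
    guess = found.group(1)
    if guess == true_answer:
      scores.append(3.0)
    elif guess.strip() == true_answer.strip():
      scores.append(1.5)
    else:
      try:
        ratio = float(guess) / float(true_answer)
      except (ValueError, ZeroDivisionError):
        scores.append(-0.5)  # Non-numeric answer.
        continue
      if 0.9 <= ratio <= 1.1:
        scores.append(0.5)
      elif 0.8 <= ratio <= 1.2:
        scores.append(0.25)
      else:
        scores.append(-1.0)
  return scores


def wrap(ans: str) -> str:
  # Builds a correctly formatted, made-up completion around `ans`.
  return f"<reasoning>Some working.</reasoning>\n<answer>{ans}</answer>"


guesses = ["72", " 72 ", "72.0", "70", "60", "10", "seventy-two"]
completions = [wrap(g) for g in guesses] + ["72"]  # last one lacks the tags
scores = check_answer(None, completions, ["72"] * len(completions))
print(scores)  # [3.0, 1.5, 0.5, 0.5, 0.25, -1.0, -0.5, 0]
```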

Sometimes, the text between `<answer>` and `</answer>` might not be one
number; it can be a sentence. So, we extract the number and compare the answer.

In [23]:
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)
match_numbers.findall(f"{solution_start}  0.34  {solution_end}")

['0.34']

In [24]:
def check_numbers(prompts, completions, answer, **kwargs):
  question = kwargs["question"]
  responses = completions

  extracted_responses = [
      guess.group(1) if (guess := match_numbers.search(r)) is not None else None
      for r in responses
  ]

  scores = []

  # Log the first example of each batch to a file instead of the terminal.
  log_dir = "./output"
  os.makedirs(log_dir, exist_ok=True)
  log_file = os.path.join(log_dir, "train_log.txt")

  with open(log_file, "a") as f:
    f.write("START ============================\n")
    f.write(f"Question: {question[0]}\n")
    f.write(f"Answer: {answer[0]}\n")
    f.write(f"Response: {responses[0]}\n")
    f.write(f"Extracted: {extracted_responses[0]}\n")
    f.write("END ==============================\n")

  for guess, true_answer in zip(extracted_responses, answer):
    if guess is None:
      scores.append(0)
      continue
    # Convert both values to numbers before comparing.
    try:
      true_answer = float(true_answer.strip())
      guess = float(guess.strip())
    except (AttributeError, ValueError):  # missing answer or non-numeric text
      scores.append(0)
      continue
    scores.append(1.5 if guess == true_answer else 0.0)
  return scores
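`match_numbers` grabs the first run of digits and dots after `<answer>`. This makes it forgiving about surrounding words, but it has blind spots worth knowing about. The made-up inputs below illustrate them. A trailing period is captured, which is harmless since `float("72.")` is 72.0. A thousands separator truncates the number, and a minus sign is dropped.

```python
import re
from typing import Optional

solution_start = "<answer>"
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
)


def first_number(text: str) -> Optional[str]:
  # Same extraction as in check_numbers: first digit/dot run after <answer>.
  m = match_numbers.search(text)
  return m.group(1) if m is not None else None


print(first_number("<answer>The answer is 72.</answer>"))  # 72.
print(first_number("<answer>$1,000</answer>"))             # 1
print(first_number("<answer>-5</answer>"))                 # 5
print(first_number("<answer>no digits here</answer>"))    # None
```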

## Evaluate


Before we train the model, let's evaluate the model on the test set so we can
see the improvement post training.

We evaluate it in two ways:

**Quantitative**

* **Answer Accuracy**: percentage of samples for which the model predicts the
correct final numerical answer  
* **Answer (Partial) Accuracy**: percentage of samples for which the model
predicts a final numerical answer such that the `model answer / answer`
ratio lies between 0.9 and 1.1.  
* **Format Accuracy**: percentage of samples for which the model outputs the
correct format, i.e., reasoning between the reasoning special tokens, and the
final answer between the answer special tokens (`solution_start` and `solution_end`).
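
To make the three quantitative metrics concrete, the hypothetical helper below scores a single response. It is a sketch, not the notebook's `evaluate` function: the `<reasoning>`/`<answer>` tags in its regexes are assumptions standing in for the notebook's own `match_format` and `match_numbers` patterns.

```python
import re

# Assumed tags for illustration only; the notebook builds its regexes from its own tokens.
MATCH_NUMBERS = re.compile(r"<answer>.*?([\d\.]{1,})", flags=re.DOTALL)
MATCH_FORMAT = re.compile(
    r"<reasoning>.+?</reasoning>.*?<answer>.+?</answer>", flags=re.DOTALL
)


def score_response(response: str, answer: str) -> dict:
    """Scores one response with the exact, partial (ratio in [0.9, 1.1]) and format metrics."""
    found = MATCH_NUMBERS.search(response)
    exact = partial = False
    if found is not None:
        try:
            guess, truth = float(found.group(1)), float(answer)
            exact = guess == truth
            partial = 0.9 <= guess / truth <= 1.1
        except (ValueError, ZeroDivisionError):
            pass  # unparsable number, or a reference answer of 0 in the ratio
    return {
        "exact": exact,
        "partial": partial,
        "format": MATCH_FORMAT.search(response) is not None,
    }


print(score_response("<reasoning>10 + 8 = 18</reasoning><answer>19</answer>", "18"))
```

Here 19 against a reference of 18 gives a ratio of about 1.06, so the response counts as partially correct but not exactly correct.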

**Qualitative**

We'll also print outputs for a few given questions so that we can compare the generated output later.


We define a helper function to generate an answer, given a prompt.

In [25]:
def generate(
    question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None
):
  """Given prompt, generates text."""

  if isinstance(question, str):
    input_batch = [
        TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=question,
        ),
    ]
  else:
    input_batch = [
        TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=q,
        )
        for q in question
    ]

  out_data = sampler(
      input_strings=input_batch,
      max_generation_steps=768,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=False,
      seed=seed,
  )

  output = out_data.text
  if isinstance(question, str):
    return output[0]
  return output

Next, we define a helper that runs the sampler over a dataset and computes the three metrics described above.
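
When `num_passes > 1`, the `evaluate` function samples each question several times (with seeds `0` to `num_passes - 1`) and counts it as correct if *any* pass is correct, so the reported accuracy is an any-of-k (pass@k-style) figure rather than a per-sample average. The hypothetical helpers below contrast the two aggregations on toy inputs; with `num_passes=1`, as in the evaluations in this notebook, both give the same number.

```python
def any_of_k_accuracy(per_question_passes):
    """Percent of questions with at least one correct pass (what `evaluate` counts)."""
    hits = sum(any(passes) for passes in per_question_passes)
    return hits / len(per_question_passes) * 100


def mean_per_sample_accuracy(per_question_passes):
    """Percent of all individual samples that are correct, for comparison."""
    samples = [ok for passes in per_question_passes for ok in passes]
    return sum(samples) / len(samples) * 100


# Toy input: 3 questions, 2 sampled passes each (True = correct answer).
toy = [[True, False], [False, False], [False, True]]
print(any_of_k_accuracy(toy))         # about 66.7
print(mean_per_sample_accuracy(toy))  # about 33.3
```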

In [26]:
def evaluate(
    dataset,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """Computes accuracy and percentage of outputs matching the format."""

  response_lst = []
  corr = 0
  partially_corr = 0
  corr_format = 0
  total = 0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]

    multiple_call_responses = [[] for _ in range(len(questions))]
    for p in range(num_passes):
      responses = generate(
          questions, sampler, temperature, top_k, top_p, seed=p
      )
      for idx, response in enumerate(responses):
        multiple_call_responses[idx].append(response)

    for question, multiple_call_response, answer in zip(
        questions, multiple_call_responses, answers
    ):
      # check answer
      corr_ctr_per_question = 0
      partially_corr_per_question = 0
      corr_format_per_question = 0
      for response in multiple_call_response:
        extracted_response = (
            guess.group(1)
            if (guess := match_numbers.search(response)) is not None
            else "-1000000"
        )
        try:
          if float(extracted_response.strip()) == float(answer.strip()):
            corr_ctr_per_question += 1

          ratio = float(extracted_response.strip()) / float(answer.strip())
          if ratio >= 0.9 and ratio <= 1.1:
            partially_corr_per_question += 1
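        except (ValueError, ZeroDivisionError):
          # Unparsable number, or a reference answer of 0 in the ratio check.
          print("SKIPPED")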

        # check format
        if match_format.search(response) is not None:
          corr_format_per_question += 1

        if (
            corr_ctr_per_question > 0
            and partially_corr_per_question > 0
            and corr_format_per_question > 0
        ):
          break

      if corr_ctr_per_question > 0:
        corr += 1
        if corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      else:
        if not corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      if partially_corr_per_question > 0:
        partially_corr += 1
      if corr_format_per_question > 0:
        corr_format += 1

      total += 1
      if total % 10 == 0:
        print(
            f"===> {corr=}, {total=}, {corr / total * 100=}, "
            f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
        )

  to_return = (
      corr,
      total,
      corr / total * 100,
      partially_corr / total * 100,
      corr_format / total * 100,
  )
  if make_lst:
    return to_return, response_lst
  return to_return

In [27]:
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

Now let's see how the original model does on the test set. The output shows the percentages of model outputs that are fully correct, partially correct, and correctly formatted. The following step might take a couple of minutes to finish.

In [28]:
# The evaluation might take a couple of minutes to finish. Please be patient.

(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["greedy"],
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
    f" {format_accuracy=}%"
)

  0%|          | 0/50 [00:00<?, ?it/s]

===> corr=3, total=10, corr / total * 100=30.0, partially_corr / total * 100=30.0, corr_format / total * 100=100.0


===> corr=10, total=20, corr / total * 100=50.0, partially_corr / total * 100=50.0, corr_format / total * 100=100.0


===> corr=17, total=30, corr / total * 100=56.666666666666664, partially_corr / total * 100=56.666666666666664, corr_format / total * 100=100.0


===> corr=21, total=40, corr / total * 100=52.5, partially_corr / total * 100=55.00000000000001, corr_format / total * 100=100.0


===> corr=24, total=50, corr / total * 100=48.0, partially_corr / total * 100=52.0, corr_format / total * 100=100.0


===> corr=31, total=60, corr / total * 100=51.66666666666667, partially_corr / total * 100=56.666666666666664, corr_format / total * 100=100.0


===> corr=38, total=70, corr / total * 100=54.285714285714285, partially_corr / total * 100=58.57142857142858, corr_format / total * 100=100.0


===> corr=44, total=80, corr / total * 100=55.00000000000001, partially_corr / total * 100=58.75, corr_format / total * 100=100.0


SKIPPED
===> corr=50, total=90, corr / total * 100=55.55555555555556, partially_corr / total * 100=58.88888888888889, corr_format / total * 100=100.0


===> corr=55, total=100, corr / total * 100=55.00000000000001, partially_corr / total * 100=59.0, corr_format / total * 100=100.0
corr=55, total=100, accuracy=55.00000000000001%, partial_accuracy=59.0%, format_accuracy=100.0%


## Train

Let's set up all the configs first - checkpointing, metric logging and training.
We then train the model.
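
The optimizer cell in this section combines a linear-warmup, cosine-decay learning rate with global-norm gradient clipping. To make their behavior concrete, here is a pure-Python sketch of what `optax.schedules.warmup_cosine_decay_schedule` and `optax.clip_by_global_norm` compute, based on our reading of the optax documentation (in particular, optax's `decay_steps` includes the warmup period). The peak value and step counts below are hypothetical, not the notebook's `LEARNING_RATE`, `WARMUP_STEPS` or `MAX_STEPS`.

```python
import math


def warmup_cosine_lr(step, peak, warmup_steps, decay_steps, init=0.0, end=0.0):
    """Linear warmup from `init` to `peak`, then cosine decay to `end`.

    As in optax, `decay_steps` is the total schedule length including warmup.
    """
    if step < warmup_steps:
        return init + (peak - init) * step / warmup_steps
    cosine_len = decay_steps - warmup_steps
    progress = min(step - warmup_steps, cosine_len) / cosine_len
    return end + (peak - end) * 0.5 * (1 + math.cos(math.pi * progress))


def clip_by_global_norm(grads, max_norm):
    """Scales all gradients by one shared factor so their joint L2 norm is <= max_norm."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= max_norm:
        return list(grads)
    return [g * max_norm / global_norm for g in grads]


# Hypothetical settings: peak 3e-6, 10 warmup steps, 100 total steps.
for s in (0, 5, 10, 55, 100):
    print(s, warmup_cosine_lr(s, peak=3e-6, warmup_steps=10, decay_steps=100))

# The global norm of [3, 4] is 5, so both entries are scaled by 1/5 (to about [0.6, 0.8]).
print(clip_by_global_norm([3.0, 4.0], max_norm=1.0))
```

Clipping by the *global* norm rescales all gradients by the same factor, so the update direction is preserved and only its magnitude is capped.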

In [29]:
# Ckpt saving
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS,
    max_to_keep=MAX_TO_KEEP,
    enable_async_checkpointing=False,
)

# Metrics logger
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tmp/tensorboard/grpo", flush_every_n_steps=20
)

In [30]:
# Logs
%load_ext tensorboard
%tensorboard --logdir /tmp/content/tmp/tensorboard/grpo --port=0

In [31]:
# Optimizer, learning rate scheduler, gradient clipping
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)
if MAX_GRAD_NORM is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
      optimizer,
  )

In [32]:
# Training config
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        # metrics logging
        metrics_logging_options=metrics_logging_options,
        # checkpoint saving
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
    ),
)

grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

### Setting Up the GRPO Trainer

Now we initialize our system for training. First, we create an `RLCluster` instance, which brings together the **policy model (`actor`)**, a **reference model (`reference`)**, and a **tokenizer**. Our `actor` is a trainable LoRA model, while the `reference` is the frozen base model: GRPO penalizes the KL divergence between the two (weighted by `BETA`), which keeps the policy from drifting too far from the base model during training.
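
For intuition about how the `reference` model and the `EPSILON`/`BETA` settings enter the objective, here is a per-token sketch of the loss as written in the GRPO paper linked at the top of this tutorial: a PPO-style clipped probability-ratio term, given an advantage value for the completion, plus a KL penalty toward the reference using the estimator `exp(ref - logp) - (ref - logp) - 1`. This follows the paper's formulation; Tunix's `GRPOLearner` may differ in details (for example, how tokens and sequences are averaged). `grpo_token_loss` and its default values are hypothetical, not the notebook's settings.

```python
import math


def grpo_token_loss(logp, old_logp, ref_logp, advantage, epsilon=0.2, beta=0.04):
    """Per-token GRPO loss (to minimize) for one sampled token.

    logp / old_logp / ref_logp: log-probs of the token under the current policy,
    the policy that generated the sample, and the frozen reference model.
    """
    ratio = math.exp(logp - old_logp)
    clipped = min(max(ratio, 1 - epsilon), 1 + epsilon)
    surrogate = min(ratio * advantage, clipped * advantage)
    # Non-negative per-token KL estimate used in the GRPO paper.
    log_ref_ratio = ref_logp - logp
    kl = math.exp(log_ref_ratio) - log_ref_ratio - 1
    return -(surrogate - beta * kl)


# When the current, sampling and reference policies agree, the ratio is 1,
# the KL term is 0, and the loss is simply -advantage.
print(grpo_token_loss(-1.0, -1.0, -1.0, advantage=0.5))  # -0.5
```

The clipping caps how much a single update can increase the probability of a token with positive advantage, which is what `EPSILON` controls; `BETA` trades reward improvement against staying close to the reference.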

We then create a `GRPOLearner`, the specialized trainer that uses a list of **reward functions** to evaluate and optimize the model's output, completing the RL training setup.
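
Conceptually, each completion's reward combines the scores from the list of reward functions, and GRPO turns rewards into advantages by normalizing them within the group of `NUM_GENERATIONS` completions sampled for the same prompt. The sketch below assumes rewards are summed across functions and normalized as `(r - mean) / (std + eps)` with the population standard deviation; it illustrates the idea rather than reproducing Tunix's exact code, and `group_advantages` and the reward values are hypothetical.

```python
import statistics


def group_advantages(rewards_per_fn, eps=1e-4):
    """Group-relative advantages for the completions of a single prompt.

    rewards_per_fn: one list per reward function, each with one score per completion.
    """
    totals = [sum(scores) for scores in zip(*rewards_per_fn)]
    mean = statistics.fmean(totals)
    std = statistics.pstdev(totals)
    return [(t - mean) / (std + eps) for t in totals]


# Toy rewards for 4 completions of one prompt from two reward functions,
# e.g. a format reward and a numeric-correctness reward.
format_scores = [3.0, 3.0, 0.0, 3.0]
number_scores = [1.5, 0.0, 0.0, 1.5]
adv = group_advantages([format_scores, number_scores])
print([round(a, 3) for a in adv])
```

Completions that beat their group's average get positive advantages and the rest get negative ones, so no separate value model is needed. If every completion for a prompt receives the same reward, all advantages are zero and that prompt contributes no policy-gradient signal.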

Tunix trainers are integrated with [Weights & Biases](https://wandb.ai/) to help you visualize the training progress. You can choose how you want to use it:

**Option 1 (Type 1)**: If you're running a quick experiment or just testing things out, choose this. It creates a temporary, private dashboard right in your browser without requiring you to log in or create an account.

**Option 2 (Type 2)**: If you have an existing W&B account and want to save your project's history to your personal dashboard, choose this. You'll be prompted to enter your API key or log in.

In [33]:
import wandb
# If you don't want to log in to W&B, you can use mode="disabled":
# wandb.init(mode="disabled")
# Or, if you want to log metrics:
wandb.init(project="gemma-grpo-finetune", name="run-1")


wandb: Currently logged in as: liuxiaohua721 (netcloud) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: creating run
wandb: Tracking run with wandb version 0.22.0
wandb: Run data is saved locally in /kaggle/working/wandb/run-20251228_102049-l2rc03jw
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run run-1
wandb: ⭐️ View project at https://wandb.ai/netcloud/gemma-grpo-finetune
wandb: 🚀 View run at https://wandb.ai/netcloud/gemma-grpo-finetune/runs/l2rc03jw


In [34]:
# RL cluster
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

wandb: Finishing previous runs because reinit is set to 'default'.
wandb: updating run metadata
wandb: 🚀 View run run-1 at: https://wandb.ai/netcloud/gemma-grpo-finetune/runs/l2rc03jw
wandb: ⭐️ View project at: https://wandb.ai/netcloud/gemma-grpo-finetune
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20251228_102049-l2rc03jw/logs
wandb: creating run
wandb: Tracking run with wandb version 0.22.0
wandb: Run data is saved locally in /kaggle/working/wandb/run-20251228_102051-ub43r5kt
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run 2025-12-28_10-20-51
wandb: ⭐️ View project at https://wandb.ai/netcloud/tunix?apiKey=58e2326a5c73c201d33a3f57d04c26d02a30794a
wandb: 🚀 View run at https://wandb.ai/netcloud/tunix/runs/ub43r5kt?apiKey=58e2326a5c73c201d33a3f57d04c26d02a30794a
wandb: Finishing previous runs because reinit is set to 'default'.
wandb: updating run metadata
wandb: 🚀 View run 2025-12-28_10-20-51 at: https://wandb.ai/netcloud/tunix/runs/ub43r5kt?apiKey=58e2326a5c73c201d33a3f57d04c26d02a30794a
wandb: ⭐️ View project at: https://wandb.ai/netcloud/tunix?apiKey=58e2326a5c73c201d33a3f57d04c26d02a30794a
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20251228_102051-ub43r5kt/logs
wandb: Tracking run with wandb version 0.22.0
wandb: Run data is saved locally in /kaggle/working/wandb/run-20251228_102052-6nbf9erl
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run 2025-12-28_10-20-52
wandb: ⭐️ View project at https://wandb.ai/netcloud/tunix?apiKey=58e2326a5c73c201d33a3f57d04c26d02a30794a
wandb: 🚀 View run at https://wandb.ai/netcloud/tunix/runs/6nbf9erl?apiKey=58e2326a5c73c201d33a3f57d04c26d02a30794a

In [35]:


# GRPO Trainer
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    grpo_config=grpo_config,
)

The first couple of training steps might take up to 5 minutes to finish. Please be patient. If training steps stay slow (e.g., more than 10 minutes per step), please open a bug report.

In [36]:
with mesh:
  grpo_trainer.train(dataset)

Actor Training:   0%|          | 0/2738 [00:00<?, ?step/s]

wandb: updating run metadata
wandb:
wandb: Run history:
wandb:                                 actor/train/kl ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▃▁▂▂▄▅▄█▃▂▃▂▄▂
wandb:                               actor/train/loss ▂▃▂▃▁▁▃▃▃▁▃▂▂▅▄▃▃▄▄▃▃▃▄▃▂▂▂▂▂▂▄▃█▄▂▃▃▅▂▃
wandb:                         actor/train/perplexity ▁▂▂▁▃▃▂▄▃▂▃▁▂▂▅▂▂▄▂▂█▄▂▂▂▂▃▃▂▄▂▃▃▄▂▃▃▃▄▂
wandb:                      actor/train/step_time_sec ▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁
wandb:                      actor/train/steps_per_sec ▇▇▇█▇███▇▇▇█▇▁▇▇▇███▇█▇▇▇▇█████████▇████
wandb:                    actor/train/tflops_per_step ▁
wandb:      jax/core/compile/backend_compile_duration ▁
wandb: jax/core/compile/jaxpr_to_mlir_module_duration ▁
wandb:          jax/core/compile/jaxpr_trace_duration ▁
wandb:               jax/orbax/write/sharded_array_gb ▁
wandb:

wandb: 🚀 View run 2025-12-28_10-20-52 at: https://wandb.ai/netcloud/tunix/runs/6nbf9erl?apiKey=58e2326a5c73c201d33a3f57d04c26d02a30794a
wandb: ⭐️ View project at: https://wandb.ai/netcloud/tunix?apiKey=58e2326a5c73c201d33a3f57d04c26d02a30794a
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20251228_102052-6nbf9erl/logs

## Evaluate

Let's evaluate our finetuned model!
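
The checkpoint-loading cell in this section builds the path as `CKPT_DIR/actor/<step>/model_params` with `<step>` set to `NUM_BATCHES`. If training stopped at a different step, that path will not exist. A more robust option is to pick the largest numeric step directory that is actually present; `latest_step` below is a hypothetical helper that assumes this directory layout and ignores non-numeric entries (such as temporary directories a checkpoint writer might leave behind).

```python
import os
import tempfile


def latest_step(actor_dir):
    """Returns the largest integer-named subdirectory of `actor_dir`, or None."""
    if not os.path.isdir(actor_dir):
        return None
    steps = [
        int(name)
        for name in os.listdir(actor_dir)
        if name.isdigit() and os.path.isdir(os.path.join(actor_dir, name))
    ]
    return max(steps, default=None)


# Self-contained demo on a throwaway directory tree with made-up step numbers.
with tempfile.TemporaryDirectory() as ckpt_dir:
    for s in (100, 200, 300):
        os.makedirs(os.path.join(ckpt_dir, "actor", str(s), "model_params"))
    step = latest_step(os.path.join(ckpt_dir, "actor"))
    print(step)  # 300
```

In the notebook this would be called as `latest_step(os.path.join(CKPT_DIR, "actor"))`, and the result used in place of `NUM_BATCHES` when building `trained_ckpt_path`.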

In [37]:
wandb.init(project='tunix-eval')  # logging bug workaround

wandb: Tracking run with wandb version 0.22.0
wandb: Run data is saved locally in /kaggle/working/wandb/run-20251228_113245-v31uzk66
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run jumping-eon-15
wandb: ⭐️ View project at https://wandb.ai/netcloud/tunix-eval
wandb: 🚀 View run at https://wandb.ai/netcloud/tunix-eval/runs/v31uzk66


In [38]:
# Load checkpoint first.

trained_ckpt_path = os.path.join(
    CKPT_DIR, "actor", str(NUM_BATCHES), "model_params"
)

abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)

nnx.update(
    lora_policy,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(lora_policy, nnx.LoRAParam),
        trained_lora_params,
    ),
)



In [39]:
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

In [40]:
# The evaluation might take a couple of minutes to finish. Please be patient.
(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(
    test_dataset,
    sampler,
    **GENERATION_CONFIGS["greedy"],
)
print(
    f"{corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,"
    f" {format_accuracy=}%"
)

  0%|          | 0/50 [00:00<?, ?it/s]

===> corr=3, total=10, corr / total * 100=30.0, partially_corr / total * 100=30.0, corr_format / total * 100=100.0


===> corr=10, total=20, corr / total * 100=50.0, partially_corr / total * 100=50.0, corr_format / total * 100=100.0


===> corr=18, total=30, corr / total * 100=60.0, partially_corr / total * 100=60.0, corr_format / total * 100=100.0


===> corr=21, total=40, corr / total * 100=52.5, partially_corr / total * 100=57.49999999999999, corr_format / total * 100=100.0


===> corr=26, total=50, corr / total * 100=52.0, partially_corr / total * 100=57.99999999999999, corr_format / total * 100=100.0


===> corr=32, total=60, corr / total * 100=53.333333333333336, partially_corr / total * 100=61.66666666666667, corr_format / total * 100=100.0


===> corr=39, total=70, corr / total * 100=55.714285714285715, partially_corr / total * 100=62.857142857142854, corr_format / total * 100=100.0


===> corr=44, total=80, corr / total * 100=55.00000000000001, partially_corr / total * 100=61.25000000000001, corr_format / total * 100=100.0


SKIPPED
===> corr=51, total=90, corr / total * 100=56.666666666666664, partially_corr / total * 100=62.22222222222222, corr_format / total * 100=100.0


===> corr=57, total=100, corr / total * 100=56.99999999999999, partially_corr / total * 100=63.0, corr_format / total * 100=100.0
corr=57, total=100, accuracy=56.99999999999999%, partial_accuracy=63.0%, format_accuracy=100.0%


In [41]:
# Define a test question
test_question = "why is the sky blue?"

print(f"Question: {test_question}\n")
print("-" * 20 + " Model Output " + "-" * 20)

# Call generate() to produce an answer,
# using the sampler defined earlier
response = generate(
    test_question,
    sampler,
    temperature=0.7,  # a little randomness; pass **GENERATION_CONFIGS["greedy"] instead for greedy-style decoding
    top_k=50,
    top_p=0.95
)

print(response)

Question: why is the sky blue?

-------------------- Model Output --------------------


<reasoning>
Sunlight is made up of all the colors of the rainbow.  When sunlight enters the Earth's atmosphere, it collides with air molecules. The shorter wavelengths of light, like blue and violet, are scattered more effectively than the longer wavelengths, like red and orange. This scattering is called Rayleigh scattering.  So, we see a blue sky because our eyes are more sensitive to blue light than other colors. </reasoning>
<answer> 
<answer> 
 
<end_of_turn>


In [42]:
# Clear the output directory used below (FINAL_MODEL_DIR) so a re-run starts clean
!rm -rf /kaggle/working/output/instance-flax/*




In [43]:
import os
import shutil

from huggingface_hub import HfApi, create_repo
from transformers import AutoTokenizer

# 1. Define the output directory
FINAL_MODEL_DIR = "/kaggle/working/output/instance-flax"
os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

# 2. Save the full set of LoRA parameters (Orbax format)
# Note: only the LoRA parameters are saved here; loading them requires the base model
checkpointer = ocp.StandardCheckpointer()
# Get the current LoRA state
lora_state = nnx.state(lora_policy, nnx.LoRAParam)
# Save to the target directory and wait for the (possibly asynchronous) write to finish
checkpointer.save(os.path.join(FINAL_MODEL_DIR, "params"), lora_state)
checkpointer.wait_until_finished()

# 3. Save the tokenizer (if it is unchanged, copying the base model's tokenizer is usually enough)
try:
    # kaggle_ckpt_path is the global variable defined earlier when the model was downloaded
    hf_tokenizer = AutoTokenizer.from_pretrained(kaggle_ckpt_path)
    hf_tokenizer.save_pretrained(FINAL_MODEL_DIR)
    print("Tokenizer saved in HF format")
except Exception as e:
    print(f"Saving with transformers failed, copying the file directly instead: {e}")
    # Fallback: copy tokenizer.model straight from the original checkpoint directory
    src_tokenizer = os.path.join(kaggle_ckpt_path, "tokenizer.model")
    if os.path.exists(src_tokenizer):
        shutil.copy(src_tokenizer, FINAL_MODEL_DIR)
        print("Copied tokenizer.model directly")

# 4. Create README.md (includes the Model ID required by the competition)
model_card_content = f"""
---
library_name: flax
tags:
- gemma
- tunix
- reinforcement-learning
- grpo
---

# Gemma 2 2B GRPO Finetuned

This model is finetuned on GSM8K using GRPO.

## Model Details
- **Base Model**: google/gemma-2-2b-it
- **Framework**: Tunix (Flax/JAX)
- **Method**: GRPO (Group Relative Policy Optimization)

## Usage
Load using Tunix library.
"""
with open(os.path.join(FINAL_MODEL_DIR, "README.md"), "w") as f:
    f.write(model_card_content)

print(f"Model is ready, saved at: {FINAL_MODEL_DIR}")



Tokenizer saved in HF format
Model is ready, saved at: /kaggle/working/output/instance-flax


In [44]:
!pip install kaggle



Collecting kaggle


  Downloading kaggle-1.8.3-py3-none-any.whl.metadata (16 kB)


Collecting kagglesdk<1.0,>=0.1.14 (from kaggle)
  Downloading kagglesdk-0.1.14-py3-none-any.whl.metadata (13 kB)


Collecting mypy>=1.15.0 (from kaggle)


  Downloading mypy-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (2.2 kB)


Collecting python-slugify (from kaggle)
  Downloading python_slugify-8.0.4-py2.py3-none-any.whl.metadata (8.5 kB)


Collecting types-requests (from kaggle)
  Downloading types_requests-2.32.4.20250913-py3-none-any.whl.metadata (2.0 kB)


Collecting types-tqdm (from kaggle)
  Downloading types_tqdm-4.67.0.20250809-py3-none-any.whl.metadata (1.7 kB)




Collecting librt>=0.6.2 (from mypy>=1.15.0->kaggle)
  Downloading librt-0.7.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (1.3 kB)


Collecting text-unidecode>=1.3 (from python-slugify->kaggle)
  Downloading text_unidecode-1.3-py2.py3-none-any.whl.metadata (2.4 kB)


Downloading kaggle-1.8.3-py3-none-any.whl (102 kB)


Downloading kagglesdk-0.1.14-py3-none-any.whl (159 kB)


Downloading mypy-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (13.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.6/13.6 MB 127.6 MB/s eta 0:00:00
Downloading python_slugify-8.0.4-py2.py3-none-any.whl (10 kB)


Downloading types_requests-2.32.4.20250913-py3-none-any.whl (20 kB)
Downloading types_tqdm-4.67.0.20250809-py3-none-any.whl (24 kB)


Downloading librt-0.7.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (189 kB)
Downloading text_unidecode-1.3-py2.py3-none-any.whl (78 kB)


Installing collected packages: text-unidecode, types-requests, python-slugify, librt, types-tqdm, mypy, kagglesdk, kaggle


Successfully installed kaggle-1.8.3 kagglesdk-0.1.14 librt-0.7.5 mypy-1.19.1 python-slugify-8.0.4 text-unidecode-1.3 types-requests-2.32.4.20250913 types-tqdm-4.67.0.20250809


In [45]:
# 1. Create model-metadata.json

import json
import os

from kaggle_secrets import UserSecretsClient

# Fetch Kaggle credentials from the notebook's secrets
user_secrets = UserSecretsClient()
username = user_secrets.get_secret("KAGGLE_USERNAME")
os.environ['KAGGLE_USERNAME'] = username
os.environ['KAGGLE_KEY'] = user_secrets.get_secret("KAGGLE_KEY")

model_slug = "gemma2-2b-grpo"
instance_slug = "jax-flax"
instance_uri = f"{username}/{model_slug}/jax/{instance_slug}"


model_metadata = {
    "ownerSlug": username,
    "slug": model_slug,
    "title": "Gemma 2 2B GRPO Finetuned",
    "description": "This is a Gemma 2 2B model finetuned on the GSM8K dataset using Group Relative Policy Optimization (GRPO). It includes LoRA weights for Flax/JAX.",
    "isPrivate": True,
    "instances": [
        {
            "framework": "jax",       # framework
            "slug": instance_slug,    # variation identifier, used later when uploading files
            "description": "LoRA weights and tokenizer for Flax"
        }
    ]
}

with open(os.path.join(FINAL_MODEL_DIR, "model-metadata.json"), "w") as f:
    json.dump(model_metadata, f, indent=4)

In [46]:
import json

metadata = {
  "ownerSlug": "liuxiaohua72",  # replace with your Kaggle username
  "modelSlug": model_slug,
  "framework": "jax",
  "instanceSlug": "jax-flax",
  "licenseName": "Apache 2.0",
  "title": "Gemma 2 2B GRPO",
  "subtitle": "Finetuned using JAX/Flax",
  "description": "LoRA weights for Gemma 2 2B-IT trained with GRPO on GSM8K.",
  "is_private": True,
  "licenses": [{"name": "apache-2.0"}],
  "keywords": ["jax", "flax", "gemma", "grpo"],
  "instances": [
    {
      "framework": "jax",
      "instance_slug": "jax-flax"  # this slug determines the variation name
    }
  ]
}

with open(os.path.join(FINAL_MODEL_DIR, "model-instance-metadata.json"), "w") as f:
    json.dump(metadata, f, indent=2)

In [47]:
#!kaggle models instances create -p "$FINAL_MODEL_DIR"

In [48]:
!kaggle models instances versions create $instance_uri -p "$FINAL_MODEL_DIR" --dir-mode zip

Starting upload for file README.md


  0%|                                                 | 0.00/340 [00:00<?, ?B/s]

100%|████████████████████████████████████████████| 340/340 [00:00<00:00, 837B/s]
Upload successful: README.md (340B)
Starting upload for file special_tokens_map.json


  0%|                                                 | 0.00/555 [00:00<?, ?B/s]

100%|██████████████████████████████████████████| 555/555 [00:00<00:00, 1.44kB/s]
Upload successful: special_tokens_map.json (555B)
Starting upload for file tokenizer.model


  0%|                                               | 0.00/4.04M [00:00<?, ?B/s]

100%|██████████████████████████████████████| 4.04M/4.04M [00:00<00:00, 10.1MB/s]
Upload successful: tokenizer.model (4MB)
Starting upload for file tokenizer.json


  0%|                                               | 0.00/32.8M [00:00<?, ?B/s]

 35%|█████████████▋                         | 11.5M/32.8M [00:00<00:00, 108MB/s]

100%|██████████████████████████████████████| 32.8M/32.8M [00:00<00:00, 47.6MB/s]
Upload successful: tokenizer.json (33MB)
Starting upload for file tokenizer_config.json


  0%|                                               | 0.00/45.2k [00:00<?, ?B/s]

100%|███████████████████████████████████████| 45.2k/45.2k [00:00<00:00, 128kB/s]
Upload successful: tokenizer_config.json (45KB)


Starting upload for file params.zip


  0%|                                                | 0.00/108M [00:00<?, ?B/s]

 13%|█████▏                                  | 13.9M/108M [00:00<00:00, 136MB/s]

 30%|████████████                            | 32.6M/108M [00:00<00:00, 170MB/s]

 45%|██████████████████▏                     | 49.0M/108M [00:00<00:00, 167MB/s]

 60%|████████████████████████                | 65.0M/108M [00:00<00:00, 142MB/s]

 73%|█████████████████████████████▎          | 79.0M/108M [00:00<00:00, 142MB/s]

 87%|██████████████████████████████████▊     | 93.9M/108M [00:00<00:00, 147MB/s]

100%|█████████████████████████████████████████| 108M/108M [00:01<00:00, 101MB/s]
Upload successful: params.zip (108MB)


Your model instance version was created. Url=https://www.kaggle.com/models/liuxiaohua72/gemma2-2b-grpo/Jax/jax-flax/5


In [49]:
import os

print("Uploading from:", FINAL_MODEL_DIR)
for root, dirs, files in os.walk(FINAL_MODEL_DIR):
    for f in files:
        print(os.path.join(root, f))

Uploading from: /kaggle/working/output/instance-flax
/kaggle/working/output/instance-flax/README.md
/kaggle/working/output/instance-flax/special_tokens_map.json
/kaggle/working/output/instance-flax/tokenizer.model
/kaggle/working/output/instance-flax/model-instance-metadata.json
/kaggle/working/output/instance-flax/tokenizer.json
/kaggle/working/output/instance-flax/tokenizer_config.json
/kaggle/working/output/instance-flax/model-metadata.json
/kaggle/working/output/instance-flax/params/_CHECKPOINT_METADATA
/kaggle/working/output/instance-flax/params/_METADATA
/kaggle/working/output/instance-flax/params/manifest.ocdbt
/kaggle/working/output/instance-flax/params/_sharding
/kaggle/working/output/instance-flax/params/ocdbt.process_0/manifest.ocdbt
/kaggle/working/output/instance-flax/params/ocdbt.process_0/d/325b238bc57b0c6631ba251d8fd2854a
/kaggle/working/output/instance-flax/params/ocdbt.process_0/d/8ffa82664ac2e81aa1255ce153353527
/kaggle/working/output/instance-flax/params/ocdbt.proce

In [50]:
from huggingface_hub import HfApi, login, upload_folder
from kaggle_secrets import UserSecretsClient

os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
os.environ["HF_HUB_DISABLE_XET"] = "1"

# Optional: uncomment to also upload the model directory to the Hugging Face Hub.
# user_secrets = UserSecretsClient()
# api = HfApi()
# hf_token = user_secrets.get_secret("HF_TOKEN")
# login(token=hf_token)
#
# upload_folder(
#     folder_path=FINAL_MODEL_DIR,
#     repo_id="liuxiaohua72/gemma2",
#     repo_type="model"
# )

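After training, accuracy on the 100 test questions rose from 55% to 57% and partial accuracy from 59% to 63%, while format accuracy stayed at 100%. The improvement in this run is modest; with sufficient training, the percentage of correct model outputs should go up more clearly, which indicates that the training worked.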
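
To judge whether a change of this size is meaningful, a rough yardstick is the binomial standard error of an accuracy measured on n examples, sqrt(p(1 - p) / n). With n = 100 and p around 0.55, that is about 5 percentage points, so a 2-point difference between the two evaluations above is within sampling noise; a larger test set gives a tighter estimate. `accuracy_stderr` is a hypothetical helper, and this independent-samples estimate is conservative because both evaluations use the same questions.

```python
import math


def accuracy_stderr(correct: int, total: int) -> float:
    """Binomial standard error of an accuracy estimate, in percentage points."""
    p = correct / total
    return math.sqrt(p * (1 - p) / total) * 100


before = accuracy_stderr(55, 100)  # baseline evaluation (corr=55, total=100)
after = accuracy_stderr(57, 100)   # evaluation after GRPO training (corr=57, total=100)
print(f"before: 55% ± {before:.1f} pp, after: 57% ± {after:.1f} pp")
```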