# Google Tunix Hack - Train a Model to Show its Work 

This notebook demonstrates the post-training of Gemma-3-1B-IT across diverse domains, including math, coding, science, and creative writing. Using Tunix—Google's new JAX-native library—we implement a two-phase pipeline: Supervised Fine-Tuning (SFT) to establish a structured reasoning format, followed by Reinforcement Learning (RL) to enhance logical depth. Our approach emphasizes reproducibility and provides a clear end-to-end framework for generating verifiable reasoning traces before reaching a final answer.

In [None]:
import functools
from pprint import pprint
import re
import sys
import os
import csv
import json
import shutil

from flax import nnx
import grain
import humanize
from huggingface_hub import snapshot_download
import jax
import jax.numpy as jnp
import kagglehub
import numpy as np
import optax
from orbax import checkpoint as ocp
from pathlib import Path
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import model as gemma_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
from tunix.models.gemma3 import params as gemma_params
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger
from tunix.sft import peft_trainer
from tunix.sft import utils

## Logging Into Services

In [None]:
# from dotenv import load_dotenv
# load_dotenv()
# print("Using env vars to login")

# import nest_asyncio
# nest_asyncio.apply()
# print("nest_asyncio applied")

# # Only using wandb on TPU VM because it has strange bugs on Colab
# !pip install -q wandb
# import wandb
# # Check if WANDB_API_KEY is set before logging in
# if "WANDB_API_KEY" in os.environ and os.environ["WANDB_API_KEY"]:
#     wandb.login(key=os.environ["WANDB_API_KEY"])
# else:
#     print("WANDB_API_KEY not found. Skipping wandb login.")

# if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
#   kagglehub.login()

# if "HF_TOKEN" in os.environ and os.environ["HF_TOKEN"]:
#     hf_token = os.environ["HF_TOKEN"]
#     !hf auth login --token "$hf_token"
# else:
#     print("HF_TOKEN not found. Skipping Hugging Face login.")

# SFT

In [None]:
# Expose only devices 0 and 1 to this process.
# NOTE(review): presumably this must run before JAX initializes its backend to
# take effect — confirm nothing earlier in the session touches the devices.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

In [None]:
# List the devices JAX can see; the mesh shape defined further down is derived
# from the length of this list.
print(jax.devices())

The show_hbm_usage() function monitors High Bandwidth Memory (HBM) consumption across all local JAX devices (e.g., TPUs or GPUs). It retrieves memory statistics directly from the device runtime, formats the byte counts into human-readable units (MiB/GiB), and calculates the utilization percentage.

In [None]:
def show_hbm_usage():
  """Prints memory usage for every local JAX device.

  For each device in `jax.local_devices()` this reads `bytes_in_use` and
  `bytes_limit` from `Device.memory_stats()` and prints both in binary units
  together with the utilisation percentage.

  Devices whose runtime does not report usable statistics (`memory_stats()`
  returns nothing, a field is missing, or the limit is zero) are reported as
  unavailable instead of raising.
  """
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    used = stats.get("bytes_in_use") if stats else None
    limit = stats.get("bytes_limit") if stats else None
    if used is None or not limit:
      # Guard against TypeError / KeyError / ZeroDivisionError on backends
      # that expose no (or incomplete) memory statistics.
      print(f"Memory statistics unavailable on {d}")
      continue
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

show_hbm_usage()

## Hyperparameters

This cell defines the variables and hyperparameters required for SFT training.

In [None]:
# Base model for post-training (Gemma-3-1B-IT) and its matching tokenizer.
model_id = "google/gemma-3-1b-it"
GEMMA_TOKENIZER_PATH = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"

# Batch size and maximum sequence length used for SFT.
# MAX_TARGET_LENGTH follows the dataset; lower it if device memory is tight.
BATCH_SIZE = 8
MAX_TARGET_LENGTH = 1200

# ---- Sharding / model setup ----
NUM_DEVICES = len(jax.devices())

# The mesh shape is (FSDP, TP): all available devices are placed on the
# tensor-parallel axis. Only 1, 2, 4 or 8 devices are supported.
if NUM_DEVICES not in (1, 2, 4, 8):
  raise ValueError(f"Unsuppored Number of TPUs: {NUM_DEVICES}")
MESH_COUNTS = (1, NUM_DEVICES)

MESH = [
    MESH_COUNTS,
    ("fsdp", "tp"),
]

# ---- LoRA / QLoRA configuration ----
RANK = 64
ALPHA = 64

# ---- Training ----
# 99% of the data is used for training and the remaining 1% is held out for
# evaluation. Set TRAIN_FRACTION to 1.0 to skip the eval split entirely.
TRAIN_FRACTION = 0.99
SFT_MAX_STEPS = 1000  # Maximum number of training steps.
# Evaluation interval in steps.
# NOTE(review): 2000 is larger than SFT_MAX_STEPS, so periodic evaluation
# presumably never fires during the run — confirm this is intended.
EVAL_EVERY_N_STEPS = 2000
NUM_EPOCHS = 1

# ---- Checkpoint saving ----
# Directories (under the current working directory) for SFT outputs.
SFT_FULL_CKPT_DIR = f"{os.getcwd()}/tmp/sft/full_ckpts/"
SFT_CKPT_DIR = f"{os.getcwd()}/tmp/sft/lora_ckpts/"
SFT_PROFILING_DIR = f"{os.getcwd()}/tmp/sft/profiling/"

The following function creates the folders required for checkpoints.

In [None]:
import logging
import shutil

def create_dir(path):
  """Creates `path` (including missing parents) and logs the outcome.

  An already existing directory is not an error. An `OSError` raised while
  creating the directory is logged and swallowed, so the function returns
  `None` either way.

  Args:
    path: Directory to create.
  """
  try:
    os.makedirs(path, exist_ok=True)
  except OSError as err:
    logging.error(f"Error creating directory '{path}': {err}")
  else:
    logging.info(f"Created dir: {path}")


# Make sure every SFT output directory exists before training starts.
for _sft_dir in (SFT_FULL_CKPT_DIR, SFT_CKPT_DIR, SFT_PROFILING_DIR):
  create_dir(_sft_dir)

## Load model from HF

In the following cell, we are downloading our model from Hugging Face

In [None]:
# Skip PyTorch .pth weight files when downloading the snapshot.
ignore_patterns = [
    "*.pth",
]
print(f"Downloading {model_id} from Hugging Face...")
local_model_path = snapshot_download(
    repo_id=model_id, ignore_patterns=ignore_patterns
)
print(f"Model successfully downloaded to: {local_model_path}")

# Collect the EOS token ids declared in the checkpoint's generation config.
# `eos_token_id` may be stored as a single int, a list of ints, or null, so it
# is always normalised to a list: later cells test membership in EOS_TOKENS
# and append to it.
EOS_TOKENS = []
generation_config_path = os.path.join(local_model_path, "generation_config.json")
if os.path.exists(generation_config_path):
  with open(generation_config_path, "r", encoding="utf-8") as f:
    generation_configs = json.load(f)
  eos_token_id = generation_configs.get("eos_token_id")
  if eos_token_id is None:
    EOS_TOKENS = []
  elif isinstance(eos_token_id, (list, tuple)):
    EOS_TOKENS = list(eos_token_id)
  else:
    EOS_TOKENS = [eos_token_id]
  print(f"Using EOS token IDs: {EOS_TOKENS}")

print("\n--- HBM Usage BEFORE Model Load ---")
# show_hbm_usage()

We first load the original weights into a temporary model instance, then extract and re-save the model's state into a new, properly formatted local checkpoint, which can then be successfully loaded by the final sharded NNX model.

In [None]:
MODEL_CP_PATH = local_model_path

# Only the Gemma-3-1B family is wired up in this notebook.
if "gemma-3-1b" not in model_id:
    raise ValueError(f"Unsupported model: {model_id}")
model_config = gemma_lib.ModelConfig.gemma3_1b_it()

# Build the device mesh from the (shape, axis names) pair stored in MESH;
# every mesh axis uses the `Auto` axis type.
mesh = jax.make_mesh(
    MESH[0],
    MESH[1],
    axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]),
)
with mesh:
    base_model = params_safetensors_lib.create_model_from_safe_tensors(
        MODEL_CP_PATH, model_config, mesh
    )
    # nnx.display(base_model)

## Initialize Tokenizer

This snippet initializes the tokenizer for our model

In [None]:
tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)

# Register the tokenizer's own EOS id as a stop token if the generation config
# did not already list it.
_tokenizer_eos_id = tokenizer.eos_id()
if _tokenizer_eos_id not in EOS_TOKENS:
    EOS_TOKENS.append(_tokenizer_eos_id)
    print(f"Using EOS token IDs: {EOS_TOKENS}")

## Apply LoRA to the base model

The `get_lora_model` takes the base model and applies LoRA layers to it. It uses a `LoraProvider` to select specific layers (like attention and MLP layers) to be adapted. The resulting LoRA-infused model is then sharded and updated to ensure it is ready for distributed training. 

In [None]:
# from grpo notebook
def get_lora_model(base_model, mesh):
  """Wraps `base_model` with LoRA adapters and shards the resulting state.

  The adapter rank and alpha come from the module-level `RANK` and `ALPHA`.

  Args:
    base_model: Model to adapt; must provide `get_model_input()`.
    mesh: Device mesh entered as a context while the sharding constraint is
      applied.

  Returns:
    The model returned by `qwix.apply_lora_to_model`, updated in place with
    its sharding-constrained state.
  """
  # Regex alternatives naming the attention and MLP projections to adapt.
  target_modules = "|".join([
      ".*q_einsum",
      ".*kv_einsum",
      ".*gate_proj",
      ".*down_proj",
      ".*up_proj",
      ".*attn_vec_einsum",
  ])
  provider = qwix.LoraProvider(
      module_path=target_modules, rank=RANK, alpha=ALPHA
  )
  adapted = qwix.apply_lora_to_model(
      base_model, provider, **base_model.get_model_input()
  )

  with mesh:
    adapted_state = nnx.state(adapted)
    constrained_state = jax.lax.with_sharding_constraint(
        adapted_state, nnx.get_partition_spec(adapted_state)
    )
    nnx.update(adapted, constrained_state)

  return adapted

In [None]:
# Build the LoRA model that the SFT phase will train.
sft_lora_model = get_lora_model(base_model, mesh=mesh)
# nnx.display(sft_lora_model)  # uncomment this if you want to display the model

## Load Datasets for SFT Training

The following code is adapted from the Tunix library. Because Tunix does not natively provide SFT dataset preparation—specifically tokenization, padding, truncation, train-test splitting, and shuffling—this notebook introduces a modified function to handle these preprocessing steps for our custom dataset.

For our SFT training, we are going to use a dataset called VITHURSHAN/Selected_SFT_plus_Cascade-SFT-Stage-1 from Hugging Face
https://huggingface.co/datasets/VITHURSHAN/Selected_SFT_plus_Cascade-SFT-Stage-1

Since the SFT dataset spans multiple domains, each domain needs its own system prompt. To streamline the training process, the dataset stores the system prompt for every data point in a separate column, so no per-domain system prompt has to be defined in this notebook.

In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data loading and preprocessing."""

from collections.abc import Iterable
from typing import Any
from datasets import load_dataset
import datasets
from grain import python as grain
import numpy as np
import tensorflow_datasets as tfds
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.sft.peft_trainer import TrainingInput  


# Chat-turn wrapper placed around each question. The `{SYSTEM_PROMPT}`
# placeholder in the prefix is filled in per example.
INPUT_TEMPLATE_IT = dict(
    prefix="<start_of_turn>user\n{SYSTEM_PROMPT}\n",
    suffix="\n<end_of_turn>\n<start_of_turn>model\n",
)

def create_datasets(
    dataset_name: str,
    global_batch_size: int,
    max_target_length: int,
    num_train_epochs: int | None,
    tokenizer: tokenizer_lib.Tokenizer,
    instruct_tuned: bool = False,
    TRAIN_FRACTION: float = 1.0,
) -> tuple[Iterable[TrainingInput], Iterable[TrainingInput] | None]:
  """Creates the train (and optionally eval) data iterators.

  The `train` split of `dataset_name` is loaded with `load_dataset`. When
  `TRAIN_FRACTION` is exactly 1.0 the whole split is used for training and no
  eval loader is built; otherwise the split is shuffled and divided with a
  fixed seed (42) and the held-out part becomes the eval set.

  Args:
    dataset_name: The name of the dataset to load.
    global_batch_size: The global batch size to use for both train and eval.
    max_target_length: The maximum length of the target sequence.
    num_train_epochs: The number of epochs to use for training. If None, the
      dataset will be repeated indefinitely.
    tokenizer: The tokenizer to use for tokenizing the dataset.
    instruct_tuned: Unused; kept for backward compatibility. The
      `INPUT_TEMPLATE_IT` template is always applied.
    TRAIN_FRACTION: Fraction of the split used for training. 1.0 disables the
      eval split.

  Returns:
    A `(train_loader, eval_loader)` tuple. `eval_loader` is None when
    `TRAIN_FRACTION` is 1.0; otherwise it iterates the held-out data once.

  Raises:
    ValueError: If `load_dataset` rejects `dataset_name` with a ValueError
      (re-raised as "Unsupported dataset", chained to the original error).
  """
  # Only the dataset lookup is guarded, so ValueErrors raised while splitting
  # or building the loaders are no longer mislabelled as an unsupported
  # dataset, and the original cause is preserved.
  try:
    raw_data = load_dataset(dataset_name, split="train")
  except ValueError as e:
    raise ValueError(f"Unsupported dataset: {dataset_name}") from e

  def _loader(data_source, num_epochs):
    """Builds a loader over `data_source` with the shared settings."""
    return _build_data_loader(
        data_source=data_source,
        batch_size=global_batch_size,
        num_epochs=num_epochs,
        max_seq_len=max_target_length,
        tokenizer=tokenizer,
        input_template=INPUT_TEMPLATE_IT,
    )

  if TRAIN_FRACTION == 1.0:
    return _loader(raw_data, num_train_epochs), None

  split_data = raw_data.train_test_split(
      train_size=TRAIN_FRACTION, seed=42, shuffle=True
  )
  train_loader = _loader(split_data["train"], num_train_epochs)
  # The eval data is always iterated for exactly one epoch.
  eval_loader = _loader(split_data["test"], 1)
  return train_loader, eval_loader


def _build_data_loader(
    *,
    data_source: grain.RandomAccessDataSource,
    batch_size: int,
    num_epochs: int | None,
    max_seq_len: int,
    tokenizer: tokenizer_lib.Tokenizer,
    input_template: dict[str, str],
) -> grain.DataLoader:
  """Builds a shuffled, batched `grain.DataLoader` over `data_source`.

  Records are sampled without sharding using a fixed shuffle seed (42), then
  tokenized, turned into padded training inputs, filtered by length and
  grouped into batches of `batch_size` (an incomplete last batch is dropped).
  """
  index_sampler = grain.IndexSampler(
      num_records=len(data_source),
      num_epochs=num_epochs,
      shard_options=grain.NoSharding(),
      shuffle=True,
      seed=42,
  )
  # Per-record transforms, applied in this order.
  pipeline = [
      _Tokenize(tokenizer, input_template),
      _BuildTrainInput(max_seq_len, tokenizer.pad_id()),
      _FilterOverlength(max_seq_len),
      grain.Batch(batch_size=batch_size, drop_remainder=True),
  ]
  return grain.DataLoader(
      data_source=data_source,
      sampler=index_sampler,
      worker_count=5,
      operations=pipeline,
  )

class _Tokenize(grain.MapTransform):
    """Tokenizes one dataset record into (source, target) tokens.

    The source is the record's `question` wrapped in the input template (the
    prefix is formatted with the record's `system_prompt`), without EOS. The
    target is the record's `response` with an EOS token appended.
    """

    # Record fields read by `map`.
    _REQUIRED_KEYS = ("question", "system_prompt", "response")

    def __init__(
        self, tokenizer: tokenizer_lib.Tokenizer, input_template: dict[str, str]
    ):
      self._tokenizer = tokenizer
      self._input_template = input_template

    def map(self, element: dict[str, Any]) -> tuple[np.ndarray, np.ndarray]:
      """Tokenizes `element` and returns `(src_tokens, dst_tokens)`.

      Raises:
        KeyError: If `element` lacks any of the required fields.
      """
      missing = [key for key in self._REQUIRED_KEYS if key not in element]
      if missing:
        # A record without "question" used to fall through to the return
        # statement and fail with an opaque UnboundLocalError.
        raise KeyError(
            f"Dataset record is missing required field(s) {missing}; "
            f"available fields: {list(element.keys())}"
        )

      src_tokens = self._tokenizer.tokenize(
          element["question"],
          prefix=self._input_template["prefix"].format(
              SYSTEM_PROMPT=element["system_prompt"]
          ),
          suffix=self._input_template["suffix"],
          add_eos=False,
      )
      dst_tokens = self._tokenizer.tokenize(element["response"], add_eos=True)
      return src_tokens, dst_tokens


class _BuildTrainInput(grain.MapTransform):
    """Builds a `TrainingInput` from a tuple of source and target tokens.

    The token sequence is the source followed by the target. The mask is False
    over the source and True over the target, so only target positions are
    marked. Both arrays are right-padded up to `max_seq_len`; sequences that
    are already longer are returned unpadded and untruncated.
    """

    def __init__(self, max_seq_len: int, pad_value: int | bool):
      self._max_seq_len = max_seq_len
      self._pad_value = pad_value

    def map(self, tokens: tuple[np.ndarray, np.ndarray]) -> TrainingInput:
      """Concatenates, masks and pads one (source, target) token pair."""
      src_tokens, dst_tokens = tokens

      # The input sequence fed to the model is simply the concatenation of the
      # source and the destination.
      # `np.concatenate` / `np.bool_` are used instead of the `np.concat` /
      # `np.bool` spellings, which are not available in every NumPy release.
      tokens = np.concatenate([src_tokens, dst_tokens], axis=0)

      # To prevent the model from updating based on the source (input)
      # tokens, add a target mask to each input.
      q_mask = np.zeros_like(src_tokens, dtype=np.bool_)
      a_mask = np.ones_like(dst_tokens, dtype=np.bool_)
      mask = np.concatenate([q_mask, a_mask], axis=0)

      # If the input tokens sequence is smaller than the target sequence size,
      # then pad it with pad tokens.
      tokens = self._pad_up_to_max_len(tokens, self._pad_value)

      # Don't want to perform the backward pass on the pad tokens.
      mask = self._pad_up_to_max_len(mask, 0)

      return TrainingInput(input_tokens=tokens, input_mask=mask)

    def _pad_up_to_max_len(
        self, input_tensor: np.ndarray, pad_value: int
    ) -> np.ndarray:
      """Right-pads a 1-D array with `pad_value` up to `max_seq_len`.

      Arrays that already have `max_seq_len` or more elements are returned
      with zero padding added (i.e. unchanged in content and length).
      """
      seq_len = input_tensor.shape[0]
      to_pad = np.maximum(self._max_seq_len - seq_len, 0)
      return np.pad(
          input_tensor,
          [[0, to_pad]],
          mode="constant",
          constant_values=pad_value,
      )

class _FilterOverlength(grain.FilterTransform):
    """Keeps only examples whose token sequence fits into `max_seq_len`."""

    def __init__(self, max_seq_len: int):
      self._max_seq_len = max_seq_len

    def filter(self, element: TrainingInput) -> bool:
      """Returns True when the example has at most `max_seq_len` tokens."""
      num_tokens = element.input_tokens.shape[0]
      return not num_tokens > self._max_seq_len


# Build the SFT train and eval loaders with the function defined above.
# Remember: with TRAIN_FRACTION = 1.0 the eval loader (`sft_eval_ds`) is None.

# VITHURSHAN/Full_SFT_plus_Cascade-SFT-Stage-1
sft_train_ds, sft_eval_ds = create_datasets(
    "VITHURSHAN/SFT_Jan22",
    global_batch_size=BATCH_SIZE,
    max_target_length=MAX_TARGET_LENGTH,
    num_train_epochs=NUM_EPOCHS,
    tokenizer=tokenizer,
    instruct_tuned=True,
    TRAIN_FRACTION=TRAIN_FRACTION,
)

`gen_model_input_fn` transforms raw training inputs into the structured format required by the transformer. It generates essential metadata, including a positional index and a causal attention mask, based on the input tokens' padding mask to ensure the model only attends to valid, preceding information during training.

Key Components:
pad_mask: Identifies non-padding tokens.

positions: Maps the sequential order of tokens, ignoring padding.

attention_mask: A causal mask that prevents the model from "looking ahead" at future tokens or attending to padding.

In [None]:
def gen_model_input_fn(x: peft_trainer.TrainingInput):
  """Builds the model's keyword inputs from a `TrainingInput`.

  Positions and the causal attention mask are both derived from the
  non-padding mask, i.e. the positions where `x.input_tokens` differs from the
  tokenizer's pad id.

  Args:
    x: Training input providing `input_tokens` and `input_mask`.

  Returns:
    A dict with `input_tokens`, `input_mask`, `positions` and
    `attention_mask`.
  """
  tokens = x.input_tokens
  non_pad = tokens != tokenizer.pad_id()
  return dict(
      input_tokens=tokens,
      input_mask=x.input_mask,
      positions=utils.build_positions_from_mask(non_pad),
      attention_mask=utils.make_causal_attn_mask(non_pad),
  )

## Training

This code configures and executes the **Supervised Fine-Tuning (SFT)** pipeline for the model using **LoRA** (Low-Rank Adaptation).

### Key Components:

* **Logging & Checkpointing:** Checkpoints are saved automatically every 250 steps (at most 4 are kept). **Weights & Biases (WandB)** and TensorBoard logging are wired in but currently commented out; uncomment them to enable real-time experiment tracking.
* **Optimization Suite:** Implements an **Optax** chain featuring a **warmup-cosine decay schedule** and **global gradient clipping** (set to 1.0) to stabilize training and prevent divergent gradients.
* **Trainer Initialization:** Configures the `PeftTrainer` with gradient accumulation and a custom input function to handle the JAX-native data flow.
* **Execution:** Runs the training loop within a distributed **JAX mesh** context. A commented-out `try`/`finally` variant additionally closes the WandB run via `wandb.finish()` even if training crashes.

In [None]:
import wandb
import os
# ignore wandb errors
# os.environ["WANDB_SILENT"] = "true"

# lora_logging_options = metrics_logger.MetricsLoggerOptions(
#     log_dir=f"./tmp/sft/tensorboard/lora", flush_every_n_steps=20
# )

# Save a checkpoint every 250 steps and keep at most 4 of them.
checkpointing_options = ocp.CheckpointManagerOptions(
    max_to_keep=4,
    save_interval_steps=250,
)
# 1. Initialize WandB
# wandb.init(
#     project="gemma-tunix",  # Change this to your project name
#     name=f"run-LORA",
#     config={
#         "eval_every_n_steps": EVAL_EVERY_N_STEPS,
#         "max_steps": SFT_MAX_STEPS,
#         "learning_rate": 2e-4,
#         "method": "LORA",
#     }
# )
# wandb.init()

# Trainer configuration: 2 gradient-accumulation steps, with LoRA checkpoints
# written under SFT_CKPT_DIR.
sft_training_config = peft_trainer.TrainingConfig(
    max_steps=SFT_MAX_STEPS,
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    gradient_accumulation_steps=2,
    # metrics_logging_options=lora_logging_options,
    checkpoint_root_directory=SFT_CKPT_DIR,
    checkpointing_options=checkpointing_options,
)

# Peak learning rate for the LoRA run.
PEAK_LR = 2e-4

# A. Learning-rate schedule: warm up from 0 to PEAK_LR over the first 15% of
# the steps, then cosine-decay to 20% of the peak (4e-5).
sft_lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=PEAK_LR,
    warmup_steps=0.15 * SFT_MAX_STEPS,
    decay_steps=SFT_MAX_STEPS,
    end_value=PEAK_LR * 0.2,
)

# B. Optimizer: clip gradients to a global norm of 1.0 (guards against
# exploding gradients), then apply AdamW driven by the schedule above.
sft_optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=sft_lr_schedule, weight_decay=0.01),
)

sft_trainer = peft_trainer.PeftTrainer(
    sft_lora_model, sft_optimizer, sft_training_config
)
sft_trainer = sft_trainer.with_gen_model_input_fn(gen_model_input_fn)

# Run SFT inside the device-mesh context.
with mesh:
    sft_trainer.train(sft_train_ds, sft_eval_ds)
# try:
#     with mesh:
#         sft_trainer.train(sft_train_ds, sft_eval_ds)
# finally:
#     # Ensure run closes even if training crashes
#     wandb.finish()

## Sampler

In [None]:
from tunix.generate import sampler as sampler_lib

# Cache dimensions mirror the model config; the cache holds 1700 positions.
_sft_cache_config = sampler_lib.CacheConfig(
    cache_size=1700,
    num_layers=model_config.num_layers,
    num_kv_heads=model_config.num_kv_heads,
    head_dim=model_config.head_dim,
)
# Non-Gemma model ids use the wrapped `tokenizer.tokenizer` instead.
_sft_sampler_tokenizer = (
    tokenizer.tokenizer if "gemma" not in model_id else tokenizer
)
sft_sampler = sampler_lib.Sampler(
    transformer=sft_lora_model,
    tokenizer=_sft_sampler_tokenizer,
    cache_config=_sft_cache_config,
)
# wandb.init()

# Multiple-choice science prompt used to smoke-test the SFT model. It asks for
# the reasoning inside <reasoning> tags and the final answer inside <answer>
# tags.
question ="""you are a good scientist. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final answer between <answer> and </answer> tags. 
Which of the following will most likely to be true?

Options:
A: focused light beams like lasers can endanger pilots
B: napkins can endanger pilots
C: sweaters can endanger pilots
D: teddy bears can endanger pilots\n"""
# question = "write a python function to find the maximum from a given array, do not use any in-build methods. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final answer between <answer> and </answer> tags"

# question = """you are a good summarizer. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final summary between <answer> and </answer> tags. 
# The transition to solid-state batteries (SSBs) represents a paradigm shift in automotive engineering, promising energy densities nearly double that of current liquid-electrolyte lithium-ion cells. Proponents argue that SSBs eliminate the risk of thermal runaway by replacing flammable liquid electrolytes with non-combustible ceramic or polymer separators. However, significant manufacturing hurdles remain. High-volume production is currently cost-prohibitive due to the sensitivity of solid electrolytes to moisture and the difficulty of maintaining consistent 'solid-to-solid' interface contact during the battery's expansion and contraction cycles. While companies like Toyota and QuantumScape claim commercial viability is imminent, skeptics maintain that the supply chain for specialized raw materials is at least a decade away from maturity.
# """
# Chat template for a single user turn. It ends with "model\n" so the prompt
# matches the suffix used when building the SFT training inputs
# (INPUT_TEMPLATE_IT) and generation starts directly with the reply.
TEMPLATE = """<start_of_turn>user
{question}<end_of_turn>
<start_of_turn>model
"""

# question = """you are a good math solver. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final answer between <answer> and </answer> tags. 
# Vithurshan is installing solar panels on a laboratory roof. Each solar panel produces 250 watts of power. He installs 12 rows of panels, with 8 panels in each row. However, due to building shade, 4 panels in total only operate at 50% capacity, and 2 panels are completely broken.

# How many total watts of power does the solar array produce?
# """

# question = """You are a good math solver. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final answer between <answer> and </answer> tags. 
# I have two mangoes and three apples. How many fruits do I have?"""

# question = """you are a good story writer. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final answer between <answer> and </answer> tags. 
# Write a story about a man who found a treasure in the jungle. The story should be 100 words long."""

# question = """you are a good summarizer. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final summary between <answer> and </answer> tags. 
# Summarize the following text: after the WW2, USA and USSR were the two superpowers in the world. They were in a race to see who could build the most powerful weapon. The USA won the race by building the atomic bomb. The USSR lost the race by not having the technology to build the atomic bomb."""


# Wrap the question in the chat template; the sampler takes a list of prompts.
input_batch = [TEMPLATE.format(question=question)]

# input_batch = [
#     question
# ]

# Generate with a low temperature, at most 800 generation steps, stopping on
# any of the EOS token ids.
out_data = sft_sampler(
    input_strings=input_batch,
    max_generation_steps=800,
    eos_tokens=EOS_TOKENS,
    temperature=0.2,
)

# Print every prompt next to its generated text.
for input_string, out_string in zip(input_batch, out_data.text):
  print("----------------------")
  print(f"Prompt:\n{input_string}")
  print(f"Output:\n{out_string}")
# wandb.finish()


## Clear Caches

In [None]:
# import gc
# import jax

# try:
#     # Delete references
#     del sft_lora_model, sft_optimizer, sft_lr_schedule, sft_trainer
#     del sft_training_config, sft_train_ds, sft_eval_ds
# except NameError:
#     pass

# # 1. Clear JIT/compilation caches
# jax.clear_caches() 

# # 2. Force Python garbage collection
# gc.collect() 

# # 3. Optional: Clear the live buffer cache (important for JAX)
# # This is often needed to truly see memory return in nvidia-smi
# for buf in jax.live_arrays():
#     buf.delete()

With the SFT phase complete, we now transition to Reinforcement Learning. We will utilize Group Relative Policy Optimization (GRPO), an algorithm specifically designed to enhance the reasoning capabilities of LLMs. As a variant of Proximal Policy Optimization (PPO), GRPO significantly reduces memory overhead by eliminating the need for a separate value function (critic) model. Instead, it generates a group of outputs for each prompt, evaluates them via a reward function, and updates the policy based on the relative advantage within that group.

In [None]:
# Report per-device memory usage after the SFT phase, before the GRPO setup.
show_hbm_usage()

# GRPO

This part of the notebook is mainly based on the `grpo_gemma` example.

## Hyperparameters

Let's define the configuration we are going to use for grpo training.

In [None]:
# ====== GRPO ======
# NOTE: TRAIN_FRACTION, EVAL_EVERY_N_STEPS and NUM_EPOCHS below rebind names
# that the SFT hyperparameter cell also defines; re-run that cell before
# re-running any SFT cell.
# === Generation during GRPO training ===
MAX_PROMPT_LENGTH = 1100
TOTAL_GENERATION_STEPS = 950
# Important to keep a high-ish temperature for varied, diverse responses during
# training.
TEMPERATURE = 0.7
TOP_P = 1.0
TOP_K = 50
# The number of times the policy generates multiple responses for a given prompt
# within a single training step. This corresponds to `G` in Algorithm 1 in the
# paper. The "group" in GRPO comes from here.
NUM_GENERATIONS = 4

# === other GRPO configs ===
# The number of iterations per batch (μ in GRPO algo 1).
NUM_ITERATIONS = 1
# The coefficient for the KL divergence penalty (β) in the GRPO loss function.
# Important to keep a high enough value for this, otherwise, the KL divergence
# can increase unchecked.
BETA = 0.08
# Epsilon value for clipping (ε in GRPO loss in paper). Similar to PPO, for
# stable updates.
EPSILON = 0.2

# ====== Training ======
TRAIN_MICRO_BATCH_SIZE = 1
TRAIN_FRACTION = 1.0
# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
NUM_BATCHES = 8000
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
# increased to a max. of 330 (if batch size is 4).
NUM_TEST_BATCHES = 64

EVAL_EVERY_N_STEPS = 64  # this doesn't matter if `TRAIN_FRACTION = 1.0`.
NUM_EPOCHS = 1  # can potentially train for more epochs

# Number of training steps.
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
# == Cosine decay with warmup scheduler ==
# Linearly increase learning rate from 0. to LEARNING_RATE (3e-6) in the first
# 10% training steps, and then gradually decrease the learning rate to 0 using
# cosine scheduler.
# NOTE: WARMUP_STEPS is a float (0.1 * MAX_STEPS), not an int.
WARMUP_STEPS = 0.1 * MAX_STEPS
# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1

# Checkpoint saving
# Both directories live under the current working directory.
INTERMEDIATE_CKPT_DIR = f"{os.getcwd()}/tmp/grpo/intermediate_ckpt/"
CKPT_DIR = f"{os.getcwd()}/tmp/grpo/ckpts/"
SAVE_INTERVAL_STEPS = 25
MAX_TO_KEEP = 3

## Data Processing

The following function is inspired by the data processing function in the `grpo_gemma` notebook, slightly modified to handle the custom dataset. As mentioned above, the dataset itself contains the system prompt, including formatting instructions with `<reasoning>`, `</reasoning>`, `<answer>`, `</answer>` tags, so the user does not need to define it for GRPO training.

In [None]:
from datasets import load_dataset
import grain

# Prompt template that combines the system prompt and the question; both
# values come from the dataset.
TEMPLATE = (
    "<start_of_turn>user\n"
    "{system_prompt}\n"
    "\n"
    "{question}<end_of_turn>\n"
    "<start_of_turn>model\n"
)

def get_dataset(data_dir, split="train", TRAIN_MICRO_BATCH_SIZE = 1, NUM_BATCHES=3738):
    """Loads a dataset split and returns shuffled, batched prompt records.

    Every record keeps its `domain`, `question` and `answer` and gains a
    `prompts` entry built by filling `TEMPLATE` with the record's
    `system_prompt` and `question`.

    Args:
      data_dir: Dataset identifier passed to `load_dataset`.
      split: Name of the split to read.
      TRAIN_MICRO_BATCH_SIZE: Number of records per batch.
      NUM_BATCHES: Maximum number of batches to keep.

    Returns:
      The first `NUM_BATCHES` batches of the shuffled (seed 42) dataset.
    """

    def _to_example(row):
        """Maps a raw dataset row to the fields used for GRPO."""
        return {
            "domain": row["domain"],
            "prompts": TEMPLATE.format(
                system_prompt=row['system_prompt'],
                question=row["question"],
            ),
            "question": row["question"],
            "answer": row["answer"],
        }

    records = load_dataset(data_dir)[split]
    shuffled = grain.MapDataset.source(records).shuffle(seed=42)
    batched = shuffled.map(_to_example).batch(TRAIN_MICRO_BATCH_SIZE)
    return batched[:NUM_BATCHES]

# Build the GRPO dataset; each batch holds 2 * TRAIN_MICRO_BATCH_SIZE records.
dataset = get_dataset(
    "VITHURSHAN/RL_With_Cascade",
    split="train",
    TRAIN_MICRO_BATCH_SIZE=TRAIN_MICRO_BATCH_SIZE * 2,
    NUM_BATCHES=NUM_BATCHES,
)

In [None]:
# With TRAIN_FRACTION == 1.0 there is no validation set (as in the SFT phase):
# given the limited time, all batches are used for training. Otherwise the
# leading TRAIN_FRACTION share of the batches is used for training and the
# rest for validation.
if TRAIN_FRACTION != 1.0:
  _split_at = int(len(dataset) * TRAIN_FRACTION)
  train_dataset = dataset[:_split_at].repeat(NUM_EPOCHS)
  val_dataset = dataset[_split_at:].repeat(NUM_EPOCHS)
else:
  train_dataset = dataset.repeat(NUM_EPOCHS)
  val_dataset = None

# Finally, report the number of batches as (train, validation).
_num_val_batches = 0 if val_dataset is None else len(val_dataset)
dataset_lengths = (len(train_dataset), _num_val_batches)
print(f"dataset contains {dataset_lengths} of batches")

## Load Policy and Reference Model

The policy model is the primary model undergoing training and weight updates. In contrast, the reference model remains frozen and is used to calculate KL divergence. This constraint prevents the policy from deviating too drastically from its original distribution, ensuring training stability.

In this configuration, we use the base model as our reference and the `sft_lora_model` as our policy. Because the policy utilizes LoRA (Low-Rank Adaptation), only the lightweight adapter weights are updated, significantly reducing the computational footprint.

In [None]:
# MODEL_CP_PATH = local_model_path
# wandb.init()
# if "gemma-3-1b" in model_id:
#     model_config = gemma_lib.ModelConfig.gemma3_1b_it()
# else:
#     raise ValueError(f"Unsupported model: {model_id}")

# mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]))
# with mesh:
#     base_model = params_safetensors_lib.create_model_from_safe_tensors(
#         MODEL_CP_PATH, (model_config), mesh
#     )
#     # nnx.display(base_model)
# wandb.finish()

In [None]:
# Policy model
# wandb.finish()
# wandb.init()
# Wrap the base model with LoRA again to obtain the policy model; the
# SFT-trained LoRA parameters are restored into it in a later cell.
lora_policy = get_lora_model(base_model, mesh=mesh)
# sft_lora = get_lora_model(base_model, mesh=mesh)
# nnx.display(lora_policy)
# wandb.finish()

In [None]:
# Report per-device memory usage now that the policy model has been created.
show_hbm_usage()

In [None]:
# Restore the LoRA weights saved by the previous (SFT) run into `lora_policy`.
from flax import nnx

# wandb.init()
ckp_path = f"{SFT_CKPT_DIR}/{SFT_MAX_STEPS}/model_params"


def _as_abstract_leaf(leaf):
  """Describe a parameter by shape and dtype only, without carrying its data."""
  return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype)


# Abstract (shape/dtype-only) copy of the policy's LoRA parameters, used as the
# restore target so the checkpoint is loaded with the expected tree structure.
abs_params = jax.tree.map(
    _as_abstract_leaf,
    nnx.state(lora_policy, nnx.LoRAParam),
)

checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(ckp_path, target=abs_params)

# Walk the live LoRA state and the restored tree together, keeping the restored
# value at every leaf, then write the result back into the policy model.
restored_lora_state = jax.tree.map(
    lambda _current, restored: restored,
    nnx.state(lora_policy, nnx.LoRAParam),
    trained_lora_params,
)
nnx.update(lora_policy, restored_lora_state)
# wandb.finish()

In [None]:
from tunix.generate import sampler as sampler_lib

# Sampler over the restored policy, used for a quick qualitative check of the
# checkpoint before RL training starts.
ckp_sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer if "gemma" in model_id else tokenizer.tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=1700,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)
# wandb.init()

# Candidate prompts for the smoke test, keyed by domain. Previously these were
# successive assignments to `question`, so only the last one ever took effect;
# keeping them in a dict makes the active choice explicit and switchable.
SAMPLE_QUESTIONS = {
    "science": """you are a good scientist. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final answer between <answer> and </answer> tags. 
Which of the following will most likely to be true?

Options:
A: focused light beams like lasers can endanger pilots
B: napkins can endanger pilots
C: sweaters can endanger pilots
D: teddy bears can endanger pilots\n""",
    "coding": "write a python function to find the maximum from a given array, do not use any in-build methods. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final answer between <answer> and </answer> tags",
    "math": """you are a good math solver. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final answer between <answer> and </answer> tags. 
Vithurshan is installing solar panels on a laboratory roof. Each solar panel produces 250 watts of power. He installs 12 rows of panels, with 8 panels in each row. However, due to building shade, 4 panels in total only operate at 50% capacity, and 2 panels are completely broken.

How many total watts of power does the solar array produce?
""",
}

# Further prompts kept for manual experimentation:
# question = """you are a good summarizer. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final summary between <answer> and </answer> tags. 
# The transition to solid-state batteries (SSBs) represents a paradigm shift in automotive engineering, promising energy densities nearly double that of current liquid-electrolyte lithium-ion cells. Proponents argue that SSBs eliminate the risk of thermal runaway by replacing flammable liquid electrolytes with non-combustible ceramic or polymer separators. However, significant manufacturing hurdles remain. High-volume production is currently cost-prohibitive due to the sensitivity of solid electrolytes to moisture and the difficulty of maintaining consistent 'solid-to-solid' interface contact during the battery's expansion and contraction cycles. While companies like Toyota and QuantumScape claim commercial viability is imminent, skeptics maintain that the supply chain for specialized raw materials is at least a decade away from maturity.
# """

# question = """You are a good math solver. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final answer between <answer> and </answer> tags. 
# I have two mangoes and three apples. How many fruits do I have?"""

# question = """you are a good story writer. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final answer between <answer> and </answer> tags. 
# Write a story about a man who found a treasure in the jungle. The story should be 100 words long."""

# question = """you are a good summarizer. Think carefully and write your reasoning between <reasoning> and </reasoning> tags and final summary between <answer> and </answer> tags. 
# Summarize the following text: after the WW2, USA and USSR were the two superpowers in the world. They were in a race to see who could build the most powerful weapon. The USA won the race by building the atomic bomb. The USSR lost the race by not having the technology to build the atomic bomb."""

# The prompt actually sent to the sampler (same one the original cell ended on).
question = SAMPLE_QUESTIONS["math"]

# Chat template wrapped around the question before sampling.
TEMPLATE = """<start_of_turn>user
{question}<end_of_turn>
<start_of_turn>model"""

input_batch = [
TEMPLATE.format(question=question)
]

# input_batch = [
#     question
# ]
out_data = ckp_sampler(
    input_strings=input_batch,
    max_generation_steps=800,  # The number of steps performed when generating a response.
    eos_tokens=EOS_TOKENS,
    temperature=0.2,
)

# Show each prompt next to the text generated for it.
for input_string, out_string in zip(input_batch, out_data.text):
  print("----------------------")
  print(f"Prompt:\n{input_string}")
  print(f"Output:\n{out_string}")
# wandb.finish()


## Define Reward Functions

To effectively improve the reasoning capabilities of the model during the reinforcement learning (RL) stage, this approach employs four complementary reward functions. Each reward targets a distinct aspect of model behavior, collectively guiding the model toward producing well-structured, logically sound, and high-quality responses. The combination of these rewards balances strict structural constraints with semantic quality and reference alignment.


1. **Reward for Correct Formatting** - inspired by the grpo_gemma notebook

The first reward function enforces strict adherence to the required output format, which is essential for downstream evaluation and interpretability. The model is expected to generate outputs that strictly follow the predefined structure: `<reasoning>reasoning</reasoning><answer>answer</answer>`. A response that matches earns +1; one that does not is penalised with -1.

2. **Reward for Approximate Formatting** - inspired by the grpo_gemma notebook

While strict formatting is important, overly harsh penalties can hinder exploration during RL, particularly for smaller models. To mitigate this, a soft or approximate formatting reward is introduced.

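This reward provides partial credit when the model output is close to the desired format. It runs five independent checks, each worth +0.2 when it passes and -0.2 when it fails, so the score ranges from -1.0 to +1.0:

- `<reasoning>` appears exactly once

- The response begins with `<reasoning>`

- `</reasoning>` appears exactly once

- `<answer>` appears exactly once

- `</answer>` appears exactly once

The checks are exact string matches, so a tag with missing angle brackets or different capitalization counts as absent; the response still earns credit for every tag it does get right. In the combined reward, this score is added only when the strict format check fails.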

3. **Reinforcement Learning with Reference Probability Reward (RLPR)** evaluates reasoning quality without external verifiers by using the model's own confidence as the reward signal. For a given prompt, the model computes the token-level probability of a high-quality reference reasoning and answer. A higher probability indicates that the model internally judges the reasoning process as more likely to lead to the correct answer. During training, reinforcement learning maximizes this probability-based reward, encouraging reasoning trajectories that naturally align with correct outcomes across both structured and open-ended domains.

RLPR: Extrapolating RLVR to General Domains without Verifiers (https://arxiv.org/pdf/2506.18254)

4. **LLM-as-a-Judge Reward**

The final reward function employs a large, external LLM as an automated evaluator to assess the quality of the model's outputs. Given the input prompt and the model-generated response, the judge model evaluates multiple qualitative dimensions, including:

- Correctness of the final answer

- Logical coherence and completeness of the reasoning trace

- Relevance to the input prompt

- Clarity and usefulness of the explanation

The judge model outputs a scalar score or categorical rating, which is then converted into a reward signal for RL optimization. This reward is particularly valuable for open-ended tasks such as creative writing, summarization, ideation, and story generation, where exact reference matching is infeasible.

Rubrics as Rewards: Reinforcement Learning Beyond Verifiable Domains (https://arxiv.org/abs/2507.17746)


**NOTE**: While the cell below includes the code for the RLPR (Reinforcement Learning with Reference Probability) Reward, it will not be active during this run. Since RLPR requires real-time tokenization and forward passes to generate logits, omitting it allows us to optimize training speed and remain within the competition's 9-hour TPU time limit.

In [None]:
import re
from typing import List, Optional
import jax.numpy as jnp
from flax import nnx
from tunix.sft import utils
import jax
import google.generativeai as genai
import json
# import os
from datetime import datetime
import random

# Tag vocabulary the model is trained to emit around its reasoning and answer.
reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"

# ========================Match format=================================
# Strict layout: optional whitespace, a non-empty reasoning block, anything,
# a non-empty answer block (captured as group 1), optional whitespace.
# DOTALL lets `.` span newlines inside the blocks. NOTE: because MULTILINE is
# set, `^` and `$` anchor at line boundaries, so text on earlier or later lines
# does not by itself prevent a match.
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.DOTALL | re.MULTILINE,
)

def match_format_exactly(response) -> float:
    """Strict format reward: 1.0 when `match_format` matches `response`, else -1."""
    if match_format.search(response) is None:
        return -1
    return 1.0

# ========================Partial match=================================
def match_format_approximately(response) -> float:
    """Soft format reward built from five independent tag checks.

    Each check adds 0.2 when it passes and subtracts 0.2 when it fails:
    every tag must occur exactly once (extra copies are penalised just like
    missing ones) and the response must begin with the opening reasoning tag.
    """
    checks = (
        response.count(reasoning_start) == 1,
        response.startswith(reasoning_start),
        response.count(reasoning_end) == 1,
        response.count(solution_start) == 1,
        response.count(solution_end) == 1,
    )

    score = 0
    # Accumulate one term at a time, in order, so float rounding is unchanged.
    for passed in checks:
        score += 0.2 if passed else -0.2

    return score

# ======================== RLPR =================================
def get_token_probabilities( 
    model: nnx.Module,
    tokenizer,
    sequence: str,
    reference_answer: str,
) -> jax.Array:
    """Return the model's probability for each answer token at the end of `sequence`.

    The whole sequence goes through the model once. The log-probabilities of
    the tokens that follow the non-answer prefix are selected, exponentiated
    and returned.

    args:
        model: model called as model(tokens, positions, None, attention_mask);
            the first element of its output is used as the logits
        tokenizer: tokenizer exposing tokenize(text, add_eos=...) and pad_id()
        sequence: str (Reasoning + Answer)
        reference_answer: str (Answer only)
    returns:
        probs: 1-D array with one probability per scored token; empty when
            there is no answer token to score (e.g. empty reference_answer)
    """
    # NOTE(review): eval mode is not undone afterwards, so the module passed in
    # stays in eval mode — confirm this is acceptable when it is the policy
    # that is being trained.
    model.eval()
    # 1. Tokenize precisely to find the split point
    # We slice the string to get reasoning to ensure we match the sequence prefix exactly
    if not reference_answer:
        # `sequence[:-0]` is the empty string, which would make the whole
        # sequence count as "answer". With no answer, everything is prefix.
        reasoning = sequence
    elif sequence.endswith(reference_answer):
        reasoning = sequence[:-len(reference_answer)]
    else:
        reasoning = sequence.replace(reference_answer, "")
    tokenized_sequence = tokenizer.tokenize(sequence, add_eos=False)
    # The index in tokenized_sequence where the answer begins.
    # Both tokenizations share the same BOS token, so the length of the
    # reasoning IDs is used as the starting index for the answer.
    # NOTE(review): this assumes tokenizing the prefix alone yields a prefix of
    # the full tokenization (no token merges across the boundary) — confirm.
    if not reasoning:
        # Only the leading BOS token precedes the answer.
        start_idx = 1
    else:
        tokenized_reasoning = tokenizer.tokenize(reasoning, add_eos=False)
        start_idx = tokenized_reasoning.shape[0]

    # 2. Setup inputs (add a leading batch axis of size 1)
    pad_mask = tokenized_sequence != tokenizer.pad_id()
    pad_mask = jnp.expand_dims(pad_mask, axis=0)
    positions = utils.build_positions_from_mask(pad_mask)
    attention_mask = utils.make_causal_attn_mask(pad_mask)

    # 3. Forward Pass (third argument is None; presumably the KV cache)
    logits = model(
        jnp.expand_dims(tokenized_sequence, axis=0),
        positions,
        None,
        attention_mask,
    )

    # 4. Extract Log Probs
    # Take the logits from the model output and drop the batch axis again.
    logits = jnp.squeeze(logits[0], axis=0)
    # Logits at i predict token at i+1. Shift both to align.
    log_probs = jax.nn.log_softmax(logits[:-1], axis=-1)
    target_ids = tokenized_sequence[1:]

    # Get log probs for the actual tokens that appeared
    log_probs_all = jnp.take_along_axis(
        log_probs,
        target_ids[..., None],
        axis=-1
    ).squeeze(-1)

    # 5. Extract Answer Slice
    # Since log_probs_all starts from the 2nd token (index 1 of sequence),
    # the answer starting at start_idx in the sequence is at start_idx - 1 here.
    ans_log_probs = log_probs_all[start_idx - 1:]

    # Convert to probabilities
    probs = jnp.exp(ans_log_probs)

    return probs

def calculate_probability_reward(
    model: nnx.Module,
    tokenizer,
    reasoning: Optional[str],
    reference_answer: str,
) -> float:
    """Mean probability the model assigns to the reference answer's tokens.

    args:
        model: model
        tokenizer: tokenizer
        reasoning: optional reasoning text placed before the answer; when it
            is None or empty the answer is scored on its own
        reference_answer: answer text whose tokens are scored
    returns:
        r: mean of the per-token answer probabilities, or 0.0 when there are
            no answer tokens to score
    """
    # Sequence to score: reasoning and answer joined by one space, or the
    # answer alone when no reasoning is supplied.
    o_prime = f"{reasoning} {reference_answer}" if reasoning else reference_answer

    answer_token_probs = get_token_probabilities(
        model,
        tokenizer,
        o_prime,
        reference_answer,
    )

    # Nothing to average over.
    if answer_token_probs.shape[0] == 0:
        return 0.0

    return jnp.mean(answer_token_probs)

def calculate_debiased_reward(
    model: nnx.Module,
    tokenizer,
    reasoning: str,
    reference_answer: str,
) -> float:
    """Debiased RLPR reward: how much the reasoning raises the answer's probability.

    The answer is scored twice, once preceded by `reasoning` and once on its
    own. The difference is scaled by 5 and clipped to [0, 1].

    args:
        model: model
        tokenizer: tokenizer
        reasoning: str
        reference_answer: str
    returns:
        r: float in [0.0, 1.0]
    """
    r = calculate_probability_reward(
        model,
        tokenizer,
        reasoning,
        reference_answer,
    )

    # Baseline: probability of the answer without any reasoning in front of it.
    r_prime = calculate_probability_reward(
        model,
        tokenizer,
        None,
        reference_answer,
    )

    r_hat = r - r_prime
    r_scaled = r_hat * 5
    # Bounds are passed positionally: the `a_min` / `a_max` keywords are
    # deprecated in jax.numpy.clip, while positional bounds work on all versions.
    r_final = jnp.clip(r_scaled, 0.0, 1.0)

    return float(r_final)

# rlpr main
def extract_reasoning_and_answer(response: str):
    """Pull the first reasoning block and the first answer block out of `response`.

    Returns a `(reasoning, answer)` tuple holding the text between the
    respective tags; either element is None when its tag pair is not found.
    """
    reasoning_match = re.search(
        rf'{reasoning_start}(.*?){reasoning_end}', response, re.DOTALL
    )
    answer_match = re.search(
        rf'{solution_start}(.*?){solution_end}', response, re.DOTALL
    )

    reasoning_text = reasoning_match.group(1) if reasoning_match else None
    answer_text = answer_match.group(1) if answer_match else None
    return reasoning_text, answer_text

def rlpr(reasoning, reference_answer):
    """RLPR reward for one response, scored with the global `lora_policy` and `tokenizer`."""
    return calculate_debiased_reward(
        model=lora_policy,
        tokenizer=tokenizer,
        reasoning=reasoning,
        reference_answer=reference_answer,
    )

# ========================LLM as Judge=================================
# Rubric prompts sent to the judge model. Each template is filled in with
# str.format() by llm_as_judge() and asks the judge to reply with one score.
#
# Rubric for math, science and coding. Placeholders: {prompt}, {response} and
# {ground_truth}. Scores run from 0.0 to 2.0 in steps of 0.5.
LLM_AS_JUDGE_PROMPT_MATH_SCIENCE_CODING = """
You are a rigorous technical grader. Your task is to evaluate a model's response based on a Prompt and a Ground Truth answer.

### Evaluation Criteria:
1. **Correctness (Critical):** Does the text inside <answer> matches the Ground Truth? 
2. **Logic (Essential):** Does the <reasoning> trace logically lead to the answer without hallucinations or "looping" logic?
3. **Format (Required):** Are the <reasoning> and <answer> tags used correctly?

### Scoring Rubric:
- **2.0:** Perfect logic, correct final answer, and proper formatting.
- **1.5:** Correct final answer and proper format, but the reasoning is slightly disorganized or contains minor fluff.
- **1.0:** Correct final answer, but the reasoning is logically flawed, non-existent, or hallucinated.
- **0.5:** Wrong final answer, but the reasoning shows significant effort and follows the correct methodology.
- **0.0:** Wrong final answer and nonsensical/empty reasoning.

### Inputs:
- **Prompt:** {prompt}
- **Model Response:** {response}
- **Ground Truth:** {ground_truth}

### Output Instruction:
Examine the response carefully. Provide **ONLY** the numerical score as a float (e.g., 1.5 or 2.0). Do not include any text, explanations, or labels.
"""

# Rubric for story generation. Placeholders: {prompt} and {response}; no
# ground truth is used. Scores run from 0.0 to 2.0 in steps of 0.5.
LLM_AS_JUDGE_PROMPT_STORY = """
You are a professional literary critic. Evaluate the story based on the prompt.

### Evaluation Criteria:
1. **Narrative Arc:** Does the <reasoning> show a clear plan (intro, conflict, climax) that the story follows?
2. **Creativity:** Is the prose engaging, or is it repetitive and cliché?
3. **Consistency:** Do characters and settings stay consistent between the reasoning and the final story?

### Scoring Rubric:
- **2.0:** Excellent storytelling, vivid imagery, and a clear logical plan in the reasoning tags.
- **1.5:** Good story, but the reasoning is thin or the prose uses too many repetitive "filler" words.
- **1.0:** The story is coherent but boring/generic, or the reasoning doesn't match the output.
- **0.5:** Fragmented story or major contradictions between the plan and the final text.
- **0.0:** Non-sensical, offensive, or fails to follow the prompt entirely.

### Inputs:
- **Prompt:** {prompt}
- **Response:** {response}

### Output Instruction:
Provide ONLY the numerical score as a float (0.0 - 2.0). No text.
"""

# Rubric for creative ideation. Placeholders: {prompt} and {response}.
# This rubric defines the tiers 2.0, 1.5, 1.0 and 0.0 (there is no 0.5 tier).
LLM_AS_JUDGE_PROMPT_IDEATION = """
You are an innovation consultant. Evaluate the ideas generated.

### Evaluation Criteria:
1. **Originality:** Are these "outside-the-box" ideas, or just the most obvious solutions?
2. **Feasibility:** Are the ideas actionable and relevant to the prompt?
3. **Reasoning Quality:** Does the <reasoning> explore different angles before settling on the ideas in <answer>?

### Scoring Rubric:
- **2.0:** Diverse, high-quality, and unique ideas with thorough brainstorming in reasoning.
- **1.5:** Good ideas, but they feel somewhat similar to each other.
- **1.0:** Obvious or "boring" ideas that don't show much creative effort.
- **0.0:** Repetitive ideas (listing the same thing twice) or irrelevant suggestions.

### Inputs:
- **Prompt:** {prompt}
- **Response:** {response}

### Output Instruction:
Provide ONLY the numerical score as a float (0.0 - 2.0). No text.
"""

# Rubric for summarization. Placeholders: {prompt} (shown to the judge as the
# source text) and {response} (the summary). Tiers: 2.0, 1.5, 1.0 and 0.0.
LLM_AS_JUDGE_PROMPT_SUMMARIZATION = """
You are an expert editor. Evaluate the summary of the provided text.

### Evaluation Criteria:
1. **Factuality:** Does the summary contain ONLY information present in the source?
2. **Density:** Is the summary concise while keeping all "key points"?
3. **Reasoning:** Does the <reasoning> correctly identify the main entities and themes before summarizing?

### Scoring Rubric:
- **2.0:** Perfect summary: concise, factual, and covers all core points.
- **1.5:** Factual summary, but misses one minor detail or is slightly too wordy.
- **1.0:** Contains a minor factual hallucination or misses a major key point.
- **0.0:** Major hallucinations, or the summary is longer than the original text.

### Inputs:
- **Source Text:** {prompt}
- **Summary Response:** {response}

### Output Instruction:
Provide ONLY the numerical score as a float (0.0 - 2.0). No text.
""" 

# Rubric for creative writing. Placeholders: {prompt} and {response}.
# Scores run from 0.0 to 2.0 in steps of 0.5.
LLM_AS_JUDGE_PROMPT_CREATIVE = """
You are an expert editor and literary critic. 

### Evaluation Criteria:
1. **The Plan (<reasoning>):** Does the model outline a creative strategy (e.g., tone, structure, or plot points) before writing?
2. **Execution (<answer>):** Is the writing engaging, vivid, and free of repetitive "AI-style" filler?
3. **Prompt Adherence:** Did the model follow all constraints of the creative prompt?

### Scoring Rubric:
- **2.0:** Excellent. The reasoning shows a deep plan, and the writing is professional and creative.
- **1.5:** Good. The writing is solid, but the reasoning is brief or the prose is slightly generic.
- **1.0:** Functional. The text is coherent but lacks creativity or the reasoning is disconnected from the output.
- **0.5:** Poor. Significant repetition, boring prose, or failed to use tags correctly.
- **0.0:** Fail. Nonsensical, off-topic, or empty tags.

### Inputs:
- **Prompt:** {prompt}
- **Response:** {response}

### Output Instruction:
Provide ONLY the numerical score as a float (0.0 - 2.0). No text.
"""

# Rubric for general reasoning questions. Placeholders: {prompt} and
# {response}. Scores run from 0.0 to 2.0 in steps of 0.5.
LLM_AS_JUDGE_PROMPT_GENERAL = """
You are an expert evaluator of general-purpose AI assistants.

### Evaluation Criteria:
1. **Intent Analysis (<reasoning>):** Does the model correctly identify the user's core intent and break down the problem logically before answering?
2. **Helpfulness & Clarity (<answer>):** Is the final response comprehensive, accurate, and structured in a way that is easy to read?
3. **Instruction Following:** Did the model satisfy all explicit constraints (e.g., "list format," "concise," "explain like I'm 5")?

### Scoring Rubric:
- **2.0:** Excellent. The reasoning shows deep understanding of the user's goal. The answer is precise, high-quality, and follows all instructions perfectly.
- **1.5:** Good. The answer is helpful and correct, but the reasoning is slightly generic or the writing style is verbose.
- **1.0:** Functional. The response addresses the prompt, but misses minor constraints or the reasoning is disconnected from the final output.
- **0.5:** Poor. Misses the core intent, contains hallucinations, or fails to use the required tags correctly.
- **0.0:** Fail. Nonsensical, off-topic, harmful, or empty tags.

### Inputs:
- **Prompt:** {prompt}
- **Response:** {response}

### Output Instruction:
Provide ONLY the numerical score as a float (0.0 - 2.0). No text.
"""

# Set up the judge model.
# The judge is reached through the google.generativeai client; the model that
# is actually used is the one named by `llm_judge` below.

# SECURITY: the API key must never be written into the notebook. It is read
# from the GEMINI_API_KEY environment variable (`os` is imported in the
# notebook's first cell). Any key that was ever committed here must be treated
# as leaked and revoked.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

if GEMINI_API_KEY:
    try:
        genai.configure(api_key=GEMINI_API_KEY)
        llm_judge = "gemma-3-27b-it"
        # llm_judge = "gemini-3-flash-preview"
        gemini_model = genai.GenerativeModel(llm_judge)
    except Exception as e:
        print(f"Warning: Failed to initialize Gemini model: {e}")
        gemini_model = None
else:
    # Without a key the judge is disabled; llm_as_judge() then returns 0.0.
    print("GEMINI_API_KEY not found; skipping Gemini model initialization.")
    gemini_model = None

# Since our dataset has a column that stores the domain information,
# we can choose the correct prompt based on the domain

def llm_as_judge(domain, prompt, response, ground_truth):
    """Score a response with the external judge model.

    args:
        domain: dataset domain label; selects which rubric prompt is used
        prompt: the prompt the policy model was given
        response: the policy model's full response
        ground_truth: reference answer; only used for the "basic_science",
            "coding_data" and "gsm8k" domains
    returns:
        reward: float in [0.0, 2.0]. It is 0.0 when the judge is not
            configured, the domain is unknown, the call fails, or the reply
            contains no number.
    """
    if gemini_model is None:
        return 0.0

    # Domains graded against a reference answer.
    reference_graded_domains = ("basic_science", "coding_data", "gsm8k")
    # Open-ended domains: rubric templates that take only prompt and response.
    open_ended_templates = {
        "creative_ideation": LLM_AS_JUDGE_PROMPT_IDEATION,
        "story_generation": LLM_AS_JUDGE_PROMPT_STORY,
        "creative_writing": LLM_AS_JUDGE_PROMPT_CREATIVE,
        "summarization": LLM_AS_JUDGE_PROMPT_SUMMARIZATION,
        "general": LLM_AS_JUDGE_PROMPT_GENERAL,
    }

    if domain in reference_graded_domains:
        judge_prompt = LLM_AS_JUDGE_PROMPT_MATH_SCIENCE_CODING.format(
            prompt=prompt,
            response=response,
            ground_truth=ground_truth,
        )
    elif domain in open_ended_templates:
        judge_prompt = open_ended_templates[domain].format(
            prompt=prompt,
            response=response,
        )
    else:
        return 0.0

    # Call the judge. Failures (quota, safety filters, network, ...) are
    # deliberately non-fatal, but they are reported instead of being hidden
    # so a run of zero rewards can be diagnosed.
    try:
        judge_reply = gemini_model.generate_content(judge_prompt)
        reply_text = judge_reply.text
    except Exception as e:
        print(f"Warning: LLM judge call failed ({type(e).__name__}: {e}); using reward 0.0")
        return 0.0

    # Finds the first integer or decimal in the reply.
    score_match = re.search(r"[-+]?\d*\.\d+|\d+", reply_text)
    if score_match is None:
        return 0.0

    # Every rubric is defined on 0.0-2.0; clamp so a stray number in the reply
    # (e.g. "10" or a negative value) cannot dominate the total reward.
    return min(max(float(score_match.group()), 0.0), 2.0)

# GRPO
def GRPO_Rewards(prompts, completions, answer, **kwargs):
    """Combined GRPO reward: strict/approximate format score plus LLM-judge score.

    For each completion:
      * strict format match  -> +1.0 plus the judge score
      * no strict match      -> -1 plus the approximate-format score
    The RLPR term is kept in the log schema but is disabled (always 0.0).

    Each completion is paired with the prompt, answer, domain and question at
    the same index. A metadata sequence shorter than `completions` falls back
    to its first entry, which reproduces the earlier "first entry for
    everything" behaviour for single-prompt batches.
    NOTE(review): this assumes the trainer passes one prompt/answer/metadata
    entry per completion (prompts repeated once per generation) — confirm.

    args:
        prompts: prompts given to the policy model
        completions: generated responses to score
        answer: reference answers
        **kwargs: must contain "domain" and "question" sequences
    returns:
        scores: list with one total reward per completion
    side effects:
        appends one JSON line per completion to "grpo_log.jsonl"
    """

    def _entry(values, index):
        """Return values[index], or the first entry when `values` is shorter."""
        return values[index] if index < len(values) else values[0]

    domains = kwargs["domain"]
    questions = kwargs["question"]
    scores = []
    log_buffer = []

    for index, response in enumerate(completions):
        domain = _entry(domains, index)
        question = _entry(questions, index)
        prompt = _entry(prompts, index)
        reference_answer = _entry(answer, index)

        # Components that do not apply to this response stay at 0.0.
        reward_for_rlpr = 0.0
        reward_for_llm_as_judge = 0.0
        reward_for_format_aptly = 0.0
        reward_for_format_exactly = match_format_exactly(response)

        if reward_for_format_exactly == 1.0:
            # RLPR is not used due to the time limit. To re-enable it:
            # reasoning, extracted_answer = extract_reasoning_and_answer(response)
            # reward_for_rlpr = rlpr(reasoning, extracted_answer)
            reward_for_llm_as_judge = llm_as_judge(
                domain, prompt, response, reference_answer
            )
        else:
            # Badly formatted output is not sent to the judge; it only gets
            # partial credit for the tags it did produce.
            reward_for_format_aptly = match_format_approximately(response)

        total_reward = reward_for_format_exactly + \
                      reward_for_format_aptly + \
                      reward_for_rlpr + \
                      reward_for_llm_as_judge

        scores.append(total_reward)

        log_buffer.append({
                "timestamp": datetime.now().isoformat(),
                "domain": domain,
                "question": question,
                "prompt": prompt,
                "response": response,
                "answer": reference_answer,
                "rewards":
                {
                    "reward_for_format_exactly": reward_for_format_exactly,
                    "reward_for_format_aptly": reward_for_format_aptly,
                    "reward_for_llm_as_judge": reward_for_llm_as_judge,
                    "reward_for_rlpr": reward_for_rlpr,
                    "total_reward": total_reward,
                }
            }
        )
    # Append to the JSONL log. `default=str` keeps logging from crashing the
    # training step if a field is not a plain JSON type (e.g. a NumPy scalar).
    with open("grpo_log.jsonl", "a", encoding="utf-8") as f:
        for entry in log_buffer:
            f.write(json.dumps(entry, default=str) + "\n")

    return scores

## Train

In [None]:
# Checkpoint policy: save every SAVE_INTERVAL_STEPS steps and keep at most
# MAX_TO_KEEP checkpoints.
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS,
    max_to_keep=MAX_TO_KEEP,
)

# Metrics logging is currently disabled:
# metrics_logging_options = metrics_logger.MetricsLoggerOptions(
#     log_dir="/tmp/grpo/tensorboard/grpo", flush_every_n_steps=20
# )

In [None]:
# Learning-rate schedule: warm up from 0.0 to LEARNING_RATE over WARMUP_STEPS,
# then cosine-decay back to 0.0; the schedule is configured with
# decay_steps=MAX_STEPS.
lr_schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    decay_steps=MAX_STEPS,
    end_value=0.0,
)

# AdamW driven by the schedule above.
optimizer = optax.adamw(
    learning_rate=lr_schedule,
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)

# Optional gradient clipping, applied before the AdamW update.
if MAX_GRAD_NORM is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
      optimizer,
  )

In [None]:
# Training config
# The actor, reference and rollout roles all run on the same device mesh.
shared_role_mesh = {
    role: mesh
    for role in (
        rl_cluster_lib.Role.ACTOR,
        rl_cluster_lib.Role.REFERENCE,
        rl_cluster_lib.Role.ROLLOUT,
    )
}

# Optimisation, evaluation cadence and checkpointing for the actor.
rl_training_config = rl_cluster_lib.RLTrainingConfig(
    actor_optimizer=optimizer,
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=MAX_STEPS,
    mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
    train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
    # metrics logging
    # metrics_logging_options=metrics_logging_options,
    # checkpoint saving
    checkpoint_root_directory=CKPT_DIR,
    checkpointing_options=checkpointing_options,
)

# Sampling settings for rollouts. The KV cache is sized to hold the longest
# prompt plus the full generation, with 256 extra slots of headroom.
rollout_config = base_rollout.RolloutConfig(
    max_tokens_to_generate=TOTAL_GENERATION_STEPS,
    max_prompt_length=MAX_PROMPT_LENGTH,
    kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    top_k=TOP_K,
    eos_tokens=EOS_TOKENS,
)

cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh=shared_role_mesh,
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_training_config,
    rollout_config=rollout_config,
)

# GRPO algorithm hyperparameters.
grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

## Setting Up the GRPO Trainer

In [None]:
# RL cluster: `lora_policy` is the actor being trained and `sft_lora_model`
# is the reference model it is compared against.
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=sft_lora_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

# GRPO trainer: GRPO_Rewards is the only reward function; it already combines
# the format rewards and the LLM-judge reward into one score per completion.
reward_functions = [GRPO_Rewards]
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=reward_functions,
    algo_config=grpo_config,
)

## Training

In [None]:
# Optional Weights & Biases logging (disabled):
# import wandb
# os.environ["WANDB_PROJECT"] = "tunix-grpo"
# os.environ["WANDB_SILENT"] = "true"
# wandb.init()
# Run GRPO training inside the device-mesh context. `val_dataset` is None
# when TRAIN_FRACTION == 1.0, i.e. when no validation split was built.
with mesh:
  grpo_trainer.train(train_dataset, val_dataset)
# wandb.finish()