# 👋 Introduction

---

### Google Tunix Hack
*Train a model to show its work*


<div align="center">
  <img src="https://github.com/san9min/tunix-hack/blob/main/assets/profile.png?raw=true" width="60%">
</div>


This project targets the Google Tunix Hack (*Train a model to show its work*). We propose a post-training pipeline that enables small language models (SLMs) to produce explicit and consistent reasoning traces.

<div align="center">
  <img src="https://raw.githubusercontent.com/san9min/tunix-hack/main/assets/post-training-pipeline.png" width="80%">
</div>


Our approach first applies LoRA-based supervised fine-tuning (SFT) on a teacher-distilled, high-quality general-reasoning dataset to stabilize output format and reasoning behavior. We then perform LoRA-based GRPO reinforcement learning with rule- and probability-based rewards on a general reasoning dataset to improve reasoning ability.

This work demonstrates a practical recipe for training small reasoning models under limited compute, aligned with Tunix's goal of making step-by-step reasoning accessible to the open-source community.



# 🎯 Training Strategy
---
We select Gemma 2 2B IT (`gemma2-2b-it`), the smallest model in the Gemma 2 family, because it starts from an instruction-tuned checkpoint.

## üë®‚Äçüè´ Synthetic Data SFT Distillation

We adopt LoRA-based SFT as an initial stage to stabilize the model's reasoning structure. The SFT process is conducted in two stages. In the first stage, the model is trained on general-domain data containing relatively long chains of thought (CoT) to learn reasoning processes. In the second stage, we further fine-tune the model on samples with reasoning lengths of up to 1024 tokens, which is our target length, while slightly increasing the proportion of mathematical reasoning data.


## ⚖️ RL with GRPO using Rule- & Probability-Based Rewards

We apply Group Relative Policy Optimization (GRPO) to adaptively improve reasoning quality using relative comparisons among sampled responses.
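
As a minimal sketch of the "relative comparison" step, the standard GRPO advantage normalizes each rollout's reward by the mean and standard deviation of the $G$ rollouts sampled for the same prompt. The NumPy function below is illustrative only; it is not Tunix's `GRPOLearner` implementation, which may differ in details such as the `eps` value (an assumption here) or the standard-deviation estimator.

```python
import numpy as np


def group_relative_advantages(rewards: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Group-normalized advantages as in the standard GRPO formulation.

    rewards: array of shape (num_prompts, G), one row per prompt.
    Returns an array of the same shape: (r - group_mean) / (group_std + eps).
    """
    mean = rewards.mean(axis=1, keepdims=True)
    std = rewards.std(axis=1, keepdims=True)
    return (rewards - mean) / (std + eps)


# Two prompts, G = 4 rollouts each (reward values are made up for illustration).
rewards = np.array([
    [3.0, 0.0, 3.0, -1.0],  # mixed outcomes -> informative advantages
    [1.0, 1.0, 1.0, 1.0],   # identical rewards -> all-zero advantages
])
advantages = group_relative_advantages(rewards)
```

A group whose rollouts all receive the same reward contributes no learning signal, which is the same intuition behind the std-based prompt filtering described below.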

### Reward Function

<div align="center">
      <img src="https://github.com/san9min/tunix-hack/blob/main/assets/reward_function_pipeline.png?raw=true" width="50%">
</div>

We use a domain-routed hybrid reward with a shared format constraint across all domains.

Specifically:

* Math samples receive format + answer correctness rewards.

* Non-math samples receive format + probability-based (RLPR-style) rewards.


### Rule-based Reward $R_{\text{rule}}$

We define a rule-based reward composed of two components:
a format reward, which is applied to all domains, and an answer reward, which is only applied to math samples.

For the math domain, the rule-based reward is defined as:
$$R_{\text{rule}} = R_{\text{fmt}} + R_{\text{ans}}$$
  
  - **Format Reward $R_{\text{fmt}}$**: Applied to both math and non-math samples, this reward enforces a strict output contract: `<reasoning> ... </reasoning> <answer> ... </answer>`. A response receives a positive reward only if it strictly matches the required tag template: no extra text outside the tags, correct tag order, and non-empty reasoning and answer spans. Missing tags, malformed structure, strict-match failures, or empty spans receive negative penalties.

  - **Answer Reward $R_{\text{ans}}$**: Applied to the math domain only. Answer correctness is evaluated with a cascade of increasingly permissive checks:
      - Exact string match after normalization (whitespace, currency symbols, commas)
      - Numeric parsing for valid numeric forms, including integers, floating-point / scientific notation, fractions, and percentages
      - Ratio-band partial credit when numeric parsing succeeds but the values differ, using $[0.9, 1.1]$ and then $[0.8, 1.2]$, with decreasing reward. If direct parsing fails but the ground-truth answer is numeric, we extract the last numeric token from the model output and give limited partial credit.
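
The two rule-based components above can be sketched as plain Python functions. This is a hypothetical illustration, not the notebook's reward code: the reward values mirror the `RuleWeights` defaults defined later (`format_exact=3.0`, `format_missing_penalty=-1.0`, `answer_ratio_0p9_1p1=0.5`, `answer_ratio_0p8_1p2=0.25`, `answer_wrong_penalty=-1.0`), while the fallback credit of `0.5` and the choice to compare `"50%"` as `50` are assumptions.

```python
import re
from fractions import Fraction
from typing import Optional

TAGS = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")
# The whole response must be exactly: <reasoning>...</reasoning> <answer>...</answer>
STRICT_RE = re.compile(r"^\s*<reasoning>(.*?)</reasoning>\s*<answer>(.*?)</answer>\s*$", re.DOTALL)
NUM_RE = re.compile(r"[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?")


def format_reward(text: str) -> float:
    counts = [text.count(tag) for tag in TAGS]
    if min(counts) == 0:
        return -1.0  # a required tag is missing
    if max(counts) > 1:
        return -0.5  # duplicated tags: malformed structure
    m = STRICT_RE.match(text)
    if m is None:
        return -0.5  # wrong order or extra text outside the tags
    penalty = 0.0
    if not m.group(1).strip():
        penalty -= 0.5  # empty reasoning span
    if not m.group(2).strip():
        penalty -= 0.5  # empty answer span
    return penalty if penalty < 0 else 3.0


def _normalize(s: str) -> str:
    # Drop currency symbols and thousands separators, collapse whitespace.
    return re.sub(r"\s+", " ", s.replace("$", "").replace(",", "")).strip()


def _to_number(s: str) -> Optional[float]:
    s = _normalize(s).rstrip("%").strip()  # assumption: "50%" is compared as 50
    m = re.fullmatch(r"([+-]?\d+)\s*/\s*([+-]?\d+)", s)
    if m and int(m.group(2)) != 0:
        return float(Fraction(int(m.group(1)), int(m.group(2))))
    try:
        return float(s)  # integers, decimals, scientific notation
    except ValueError:
        return None


def answer_reward(pred: str, gt: str) -> float:
    if _normalize(pred) == _normalize(gt):
        return 3.0  # exact match after normalization
    gt_val = _to_number(gt)
    if gt_val is None:
        return 0.0  # non-numeric ground truth: no numeric fallback
    pred_val = _to_number(pred)
    if pred_val is None:
        # Fallback: compare the last number mentioned in the output (limited credit).
        nums = NUM_RE.findall(pred)
        return 0.5 if nums and float(nums[-1]) == gt_val else 0.0
    if pred_val == gt_val:
        return 3.0  # same value, different surface form (e.g. 1/2 vs 0.5)
    if gt_val != 0:
        ratio = pred_val / gt_val
        if 0.9 <= ratio <= 1.1:
            return 0.5
        if 0.8 <= ratio <= 1.2:
            return 0.25
    return -1.0  # parsed but wrong
```

Because `format_reward` only pays out when the entire response is the tagged template, a preamble such as `Sure! <reasoning>...` is penalized even if the answer itself is correct.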


### Probability-based Reward $R_{\text{prob}}$ (Non-Math)

For non-math samples, we compute a policy-confidence signal by measuring the rollout policy's average probability of the ground-truth answer tokens under teacher forcing.

Let $y^*$ denote the ground-truth answer (optionally wrapped in `<answer>` ... `</answer>` tags and followed by EOS), and let $\mathcal{Y}$ be the set of token positions of $y^*$.

We compute per-token probabilities under the rollout policy $\pi$ as:

$$ R_{\text{full}}= \frac{1}{|\mathcal{Y}|}\sum_{t \in \mathcal{Y}} \exp\!\left(\log p_{\pi}\left(y_t^{*}\mid\text{prompt} + \text{reasoning},y_{<t}^{*}\right)\right).$$


**Debiasing**

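We also compute a prompt-only baseline $R_{\text{base}}$, the same average probability of $y^*$ conditioned on the prompt alone (without the generated reasoning), and define the final probability reward as
$$R_{\text{prob}} = \operatorname{clip}\left(R_{\text{full}} - R_{\text{base}},\, 0,\, 1\right)$$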
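
Assuming the logits for the answer positions have already been obtained from two teacher-forced forward passes (one conditioned on prompt + reasoning, one on the prompt alone), $R_{\text{full}}$, $R_{\text{base}}$, and the clipped reward reduce to a few array operations. The NumPy sketch below uses hypothetical helper names; the forward passes themselves are not shown.

```python
import numpy as np


def answer_token_logprobs(logits: np.ndarray, target_ids: np.ndarray) -> np.ndarray:
    """Gather log p(y*_t | context) for each answer token.

    logits: (T, V) logits at the positions that predict y*_1..y*_T.
    target_ids: (T,) token ids of y*.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerically stable log-softmax
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return log_probs[np.arange(target_ids.shape[0]), target_ids]


def mean_token_prob(answer_logprobs: np.ndarray) -> float:
    """(1/|Y|) * sum_t exp(log p_t); used for both R_full and R_base."""
    return float(np.mean(np.exp(answer_logprobs)))


def prob_reward(logprobs_full: np.ndarray, logprobs_base: np.ndarray) -> float:
    """R_prob = clip(R_full - R_base, 0, 1)."""
    r_full = mean_token_prob(logprobs_full)
    r_base = mean_token_prob(logprobs_base)
    return float(np.clip(r_full - r_base, 0.0, 1.0))
```

For example, answer-token probabilities of 0.9 and 0.8 with reasoning versus 0.5 and 0.5 without give $R_{\text{full}} = 0.85$, $R_{\text{base}} = 0.5$, and $R_{\text{prob}} = 0.35$; if the reasoning lowers the answer's probability, the reward is clipped to 0.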

**Std-based Prompt Filtering**

When sampling $G$ rollouts per prompt, we compute the per-prompt standard deviation of $R_{\text{prob}}$.
Prompts with standard deviation below a threshold $\beta$ (fixed or EMA-updated) are filtered by zeroing out their $R_{\text{prob}}$.
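
A minimal NumPy sketch of this filter, assuming the $R_{\text{prob}}$ values for a batch are arranged as a `(num_prompts, G)` array. The EMA variant is a hypothetical design in which $\beta$ tracks a running average of the batch's per-prompt standard deviations; the `init` and `decay` values are placeholders, not the notebook's settings.

```python
import numpy as np


def filter_low_std_prompts(r_prob: np.ndarray, beta: float) -> np.ndarray:
    """Zero out R_prob for every prompt whose G rollouts have std below beta.

    r_prob: (num_prompts, G) probability rewards, one row per prompt.
    """
    per_prompt_std = r_prob.std(axis=1, keepdims=True)  # (num_prompts, 1)
    return np.where(per_prompt_std < beta, 0.0, r_prob)


class EmaThreshold:
    """Hypothetical EMA-updated beta that tracks the mean per-prompt std."""

    def __init__(self, init: float = 0.05, decay: float = 0.99):
        self.value = init
        self.decay = decay

    def update(self, r_prob: np.ndarray) -> float:
        batch_mean_std = float(r_prob.std(axis=1).mean())
        self.value = self.decay * self.value + (1.0 - self.decay) * batch_mean_std
        return self.value


r_prob = np.array([
    [0.30, 0.31, 0.30, 0.31],  # nearly constant -> filtered out
    [0.10, 0.60, 0.20, 0.90],  # informative spread -> kept
])
ema = EmaThreshold()
filtered = filter_low_std_prompts(r_prob, beta=ema.update(r_prob))
```

Whether $\beta$ is updated before or after filtering each batch is a design choice the description above leaves open.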

**Objective**

The final reward combines a **shared format reward** with a
**domain-specific learning signal** (answer correctness for math, probability for non-math):

For each sample $i$, the final reward is:

$$R_i=\begin{cases}\lambda_{\text{rule}} \, R_{\text{rule}, i},& \text{if } d_i = \text{math}, \\[6pt]\lambda_{\text{prob}} \, R_{\text{prob}, i}+\lambda_{\text{fmt}} \, R_{\text{fmt}, i},& \text{otherwise}.\end{cases}$$

where $d_i \in \{\text{math}, \text{non-math}\}$ denotes the per-sample domain label.

*(Optionally, the probability term can be **gated on strict format success** for non-math samples;  
if the strict tag format fails, $R_{\text{prob}}$ is dropped and only the format reward is applied.)*
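
Putting the pieces together, the per-sample routing can be sketched as follows. The function and its argument names are hypothetical, the $\lambda$ weights default to 1.0 only as placeholders, and `format_ok` stands for whatever strict-format check is used for the optional gating.

```python
def final_reward(
    domain: str,
    *,
    r_rule: float,
    r_fmt: float,
    r_prob: float,
    format_ok: bool,
    lam_rule: float = 1.0,
    lam_prob: float = 1.0,
    lam_fmt: float = 1.0,
    gate_on_format: bool = True,
) -> float:
    """Domain-routed reward R_i (hypothetical sketch)."""
    if domain == "math":
        # Math: R_rule already includes the format term (R_fmt + R_ans).
        return lam_rule * r_rule
    if gate_on_format and not format_ok:
        # Optional gating: drop R_prob when the strict tag format fails.
        return lam_fmt * r_fmt
    return lam_prob * r_prob + lam_fmt * r_fmt
```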


# 🗂️ Finetuning Dataset
---
## 📝 SFT Dataset

We construct our SFT dataset by curating and filtering subsets from publicly available reasoning datasets, including `glaiveai/reasoning-v1-20m` and `bespokelabs/Bespoke-Stratos-17k`. The Glaive dataset supplies large-scale general-domain reasoning traces beyond math and code, while the Bespoke dataset complements it with high-quality, correctness-filtered reasoning data across math, code, science, and logic tasks.

We strictly preprocess the dataset by enforcing standardized reasoning/answer tag formats, correct structure and uniqueness, token budget limits, removal of malformed samples, and English-only content to ensure training stability and format reliability.

Under constrained compute, we further reduce the dataset size through domain- and CoT-length-aware sampling. Specifically, we sample reasoning traces at a 1:4 ratio of long to short CoT examples, balancing exposure to rich reasoning processes (such as the self-correction and verification patterns typical of long CoTs) against shorter reasoning traces that are easier for small models to learn.
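
The 1:4 long-to-short sampling can be sketched with pandas. The `reasoning_tokens` column, the `long_threshold` cutoff, and the target size `n_total` are hypothetical; the text does not specify how long and short CoTs were separated.

```python
import pandas as pd


def sample_long_short(
    df: pd.DataFrame,
    *,
    long_threshold: int,
    n_total: int,
    len_col: str = "reasoning_tokens",
    seed: int = 42,
) -> pd.DataFrame:
    """Sample long and short CoT examples at a 1:4 ratio (hypothetical sketch)."""
    is_long = df[len_col] > long_threshold
    n_long = n_total // 5          # 1 part long
    n_short = n_total - n_long     # 4 parts short
    long_df = df[is_long]
    short_df = df[~is_long]
    picked = pd.concat([
        long_df.sample(n=min(n_long, len(long_df)), random_state=seed),
        short_df.sample(n=min(n_short, len(short_df)), random_state=seed),
    ])
    # Shuffle so long and short examples are interleaved.
    return picked.sample(frac=1.0, random_state=seed).reset_index(drop=True)
```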

For SFT Stage 2, we refine this dataset by retaining only samples with reasoning lengths below 1024 tokens and slightly increasing the proportion of mathematical reasoning tasks.
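
A hypothetical pandas sketch of the Stage 2 refinement: keep samples whose reasoning is shorter than 1024 tokens, then keep every math sample and downsample the rest so that math makes up roughly `math_frac` of the subset. The column names and the `math_frac` value are assumptions; the text only says the math proportion is increased slightly.

```python
import pandas as pd


def build_stage2_subset(
    df: pd.DataFrame,
    *,
    max_reasoning_tokens: int = 1024,
    math_frac: float = 0.4,
    len_col: str = "reasoning_tokens",
    domain_col: str = "domain",
    seed: int = 42,
) -> pd.DataFrame:
    if not 0.0 < math_frac < 1.0:
        raise ValueError("math_frac must be in (0, 1)")
    short = df[df[len_col] < max_reasoning_tokens]
    is_math = short[domain_col] == "math"
    math_df, other_df = short[is_math], short[~is_math]
    # Number of non-math rows that makes math roughly math_frac of the result.
    n_other = int(len(math_df) * (1.0 - math_frac) / math_frac)
    other_df = other_df.sample(n=min(n_other, len(other_df)), random_state=seed)
    subset = pd.concat([math_df, other_df])
    return subset.sample(frac=1.0, random_state=seed).reset_index(drop=True)
```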

<div align="center">
      <img src="https://github.com/san9min/tunix-hack/blob/main/assets/sft_dataset_response_distribution.webp?raw=true" width="50%">
</div>

The SFT datasets are available on Kaggle:

* SFT Stage1 Dataset : [DOD Gemma2 SFT Dataset](https://www.kaggle.com/datasets/sangminlee09/small-reasoning-model-sft-dataset)
* SFT Stage2 Dataset : [DOD Gemma2 SFT subset](https://www.kaggle.com/datasets/sangminlee09/small-reasoning-model-sft-subset)


## 📚 RL Dataset

The RL training dataset is constructed by integrating multiple reasoning-oriented sources to ensure broad general-domain coverage.

Tasks with strictly rule-verifiable answers are collected from the `knoveleng/open-rs` dataset. We retain only samples with unambiguous numeric ground-truth answers, restricting valid formats to integers, floating-point numbers, and rational fractions, including both plain-text (e.g., 7/24) and LaTeX-style representations (e.g., \frac{7}{24}).
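
One way to implement this filter is to accept an answer only if it fully matches one of the allowed numeric forms. The patterns below are a hypothetical sketch, and stripping surrounding `$` math delimiters is an assumption about how LaTeX answers are stored.

```python
import re

# Allowed answer forms (hypothetical patterns): integer, decimal, a/b, \frac{a}{b}.
_ALLOWED_FORMS = (
    re.compile(r"[+-]?\d+"),
    re.compile(r"[+-]?(?:\d+\.\d*|\.\d+)"),
    re.compile(r"[+-]?\d+\s*/\s*\d+"),
    re.compile(r"[+-]?\\(?:d|t)?frac\{\s*\d+\s*\}\{\s*\d+\s*\}"),
)


def is_unambiguous_numeric(answer: str) -> bool:
    """True if the whole answer is a single integer, float, or rational fraction."""
    s = answer.strip().strip("$").strip()  # assumption: drop surrounding $...$ delimiters
    return any(p.fullmatch(s) for p in _ALLOWED_FORMS)


# Usage (column name is hypothetical):
# df = df[df["answer"].map(is_unambiguous_numeric)]
```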

Additional general-domain reasoning tasks are drawn from two sources. Creative, preference-driven tasks come from `SAA-Lab/LitBench-Train`, where samples are ranked by user-feedback metadata. Non-creative reasoning tasks spanning domains such as science and business come from `openbmb/RLPR-Train-Dataset`.

For datasets with difficulty annotations, samples are stratified using a 1:3 Easy-to-Hard ratio. Creative tasks, which lack explicit difficulty labels, are included as a fixed-size subset selected based on preference-related signals derived from user feedback.
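
A hypothetical sketch of this selection: difficulty-labeled rows are sampled at a 1:3 Easy-to-Hard ratio, and creative rows are ranked by a preference signal and truncated to a fixed size. The `level` and `upvotes` column names, the label values, and the sizes are illustrative assumptions; the actual metadata fields may differ.

```python
import pandas as pd


def select_rl_samples(
    labeled: pd.DataFrame,
    creative: pd.DataFrame,
    *,
    n_labeled: int,
    n_creative: int,
    level_col: str = "level",
    pref_col: str = "upvotes",
    seed: int = 42,
) -> pd.DataFrame:
    n_easy = n_labeled // 4         # 1 part Easy
    n_hard = n_labeled - n_easy     # 3 parts Hard
    easy = labeled[labeled[level_col] == "easy"]
    hard = labeled[labeled[level_col] == "hard"]
    parts = [
        easy.sample(n=min(n_easy, len(easy)), random_state=seed),
        hard.sample(n=min(n_hard, len(hard)), random_state=seed),
        # Creative tasks: keep a fixed-size top-n by the preference signal.
        creative.nlargest(n_creative, pref_col),
    ]
    return pd.concat(parts, ignore_index=True)
```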

<div align="center">
  <table>
    <tr>
      <td>
        <img src="https://raw.githubusercontent.com/san9min/tunix-hack/main/assets/rl_dataset_level_distribution.webp" width="45%">
      </td>
      <td>
        <img src="https://raw.githubusercontent.com/san9min/tunix-hack/main/assets/rl_dataset_domain_distribution.webp" width="45%">
      </td>
    </tr>
  </table>
</div>


The RL dataset is available on Kaggle:
[DOD Gemma2 RL Dataset](https://www.kaggle.com/datasets/sangminlee09/small-reasoning-model-rl-dataset)

# 🧩 Finetuning code
---

### Installation & Imports

In [None]:
import os
os.environ["HF_HUB_DISABLE_XET"] = "1"

In [None]:
!pip install -q kagglehub

!pip install -q ipywidgets

!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain
!pip install "google-tunix[prod]==0.1.3"

!pip uninstall -q -y flax
# !pip install -U flax
!pip install flax==0.12.0

!pip install -q datasets wandb==0.22.0
!pip install -q 'numpy>2'
!pip install -q kaggle

In [None]:
from __future__ import annotations

# =========================
# Standard library
# =========================
import csv
import functools
import gc
import json
import math
import os
import re
import shutil
from dataclasses import dataclass
from fractions import Fraction
from pathlib import Path
from pprint import pprint
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple

# =========================
# Third-party libraries
# =========================
import grain.python as grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import numpy as np
import optax
import pandas as pd
import qwix
import tensorflow_datasets as tfds
import wandb
from flax import nnx
from orbax import checkpoint as ocp
from tqdm.auto import tqdm  

# =========================
# Tunix / Project-specific
# =========================
from tunix import PeftTrainer, TrainingConfig  
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma import model as model_lib
from tunix.models.gemma import params as params_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import peft_trainer, utils
from tunix.sft.metrics_logger import MetricsLogger 
from tunix.sft import metrics_logger

## ⚙️ Hyperparameters

In [None]:
# =====================================================
# ================ Path & Checkpoints =================
# =====================================================

# TensorBoard logs for SFT training
SFT_LOG_DIR = "/tmp/content/tmp/tensorboard/sft"
SFT_CKPT_DIR = "/kaggle/working/sft_ckpts/actor"

# Intermediate checkpoints (for recovery / resume)
SFT_INTERMEDIATE_CKPT_DIR = "/kaggle/working/ckpts/intermediate_ckpt"

# SFT stage 2 checkpoints
SFT_STAGE2_CKPT_DIR = "/kaggle/working/sft_ckpts/actor2"

# Final checkpoints after RL
FINAL_CKPT_DIR = "/kaggle/working/ckpts"


# =====================================================
# ======================= Data ========================
# =====================================================

# SFT stage 1 dataset (full general-domain reasoning data)
SFT_TRAIN_DATA_DIR = (
    "/kaggle/input/small-reasoning-model-sft-dataset/"
    "small_reasoning_model_sft_dataset.csv"
)

# SFT stage 2 dataset (shorter reasoning, math-heavy subset)
SFT_STAGE2_TRAIN_DATA_DIR = (
    "/kaggle/input/small-reasoning-model-sft-subset/"
    "small_reasoning_model_sft_dataset_stage2.csv"
)

# RL training dataset (verifiable + non-verifiable tasks)
RL_TRAIN_DATA_DIR = (
    "/kaggle/input/small-reasoning-model-rl-dataset/"
    "small_reasoning_model_rl_dataset.csv"
)

# Fraction of dataset to use (1.0 = full dataset)
SFT_TRAIN_FRACTION = 1.0
RL_TRAIN_FRACTION = 1.0

# ======================= LoRA ========================

# LoRA rank (controls adapter capacity)
RANK = 32

# LoRA scaling factor
ALPHA = 64.0

# ===================== Sharding ======================
# Device mesh configuration
# (fsdp: parameter sharding, tp: tensor parallelism)
MESH = [(1, 4), ("fsdp", "tp")]

# =================== SFT Stage 1 =====================

# Maximum training steps
#  - 0  : skip SFT entirely
#  - -1 : automatically computed from dataset size
SFT_MAX_STEPS = -1

# Per-device micro batch size
SFT_TRAIN_MICRO_BATCH_SIZE = 2

# Gradient accumulation steps
SFT_GRAD_ACCUM_STEPS = 64

# Number of epochs
SFT_NUM_EPOCHS = 1

# Evaluation frequency (in steps)
SFT_EVAL_EVERY_N_STEPS = 150

# Base learning rate
SFT_LEARNING_RATE = 5e-5

# Warmup steps for LR scheduler
SFT_WARMUP_STEPS = 20

# Weight decay for optimizer
SFT_WEIGHT_DECAY = 0.1

# Maximum sequence length
SFT_MAX_SEQ_LENGTH = 2048

# Checkpoint save frequency
SFT_SAVE_INTERVAL_STEPS = 100

# Maximum number of checkpoints to keep
SFT_MAX_TO_KEEP = 5

# Gradient clipping norm
SFT_MAX_GRAD_NORM = 0.1

# =================== SFT Stage 2 =====================

# Lower learning rate for refinement stage
SFT_STAGE2_LEARNING_RATE = 1e-5

# Number of epochs for stage 2
SFT_STAGE2_NUM_EPOCHS = 1

# Shorter max sequence length (target reasoning length)
SFT_STAGE2_MAX_SEQ_LENGTH = 1280

# Maximum training steps (-1 = auto)
SFT_STAGE2_MAX_STEPS = -1

# Checkpoint save frequency
SFT_STAGE2_SAVE_INTERVAL_STEPS = 4

# ==================== RL (GRPO) ======================

# Maximum prompt length fed to the policy
RL_MAX_PROMPT_LENGTH = 256

# Total generation length
RL_TOTAL_GENERATION_STEPS = 768

# Sampling temperature (keep relatively high for diversity)
RL_TEMPERATURE = 0.6

# Nucleus sampling
RL_TOP_P = 1.0

# Top-k sampling
RL_TOP_K = 50

# Number of responses generated per prompt (G in GRPO)
RL_NUM_GENERATIONS = 4

# Number of GRPO inner iterations (μ in the paper)
RL_NUM_ITERATIONS = 1

# KL divergence penalty coefficient (β)
BETA = 0.08

# PPO-style clipping epsilon (ε)
EPSILON = 0.2

# Per-device micro batch size
RL_TRAIN_MICRO_BATCH_SIZE = 2

# Number of batches per epoch
RL_NUM_BATCHES = 1200

# Number of evaluation batches
RL_NUM_TEST_BATCHES = 100

# Evaluation frequency
RL_EVAL_EVERY_N_STEPS = 100

# Number of epochs
RL_NUM_EPOCHS = 1

# Total number of RL training steps
RL_MAX_STEPS = int(
    RL_NUM_BATCHES * RL_NUM_ITERATIONS * RL_TRAIN_FRACTION * RL_NUM_EPOCHS
)

# Optimizer learning rate
RL_LEARNING_RATE = 1e-6

# Adam optimizer beta parameters
RL_B1 = 0.9
RL_B2 = 0.99

# Weight decay
RL_WEIGHT_DECAY = 0.1

# Warmup steps (10% of total steps, as an integer step count)
RL_WARMUP_STEPS = int(0.1 * RL_MAX_STEPS)

# Gradient clipping norm
RL_MAX_GRAD_NORM = 0.1

# Checkpoint save frequency
RL_SAVE_INTERVAL_STEPS = 200

# Maximum number of checkpoints to keep
RL_MAX_TO_KEEP = 8

# ============================================================== #

# DO NOT CHANGE BELOW

# Use these standard output tags so that your model's output follows this format in plain text (no JSON/XML):
# <reasoning>model_reasoning_trace</reasoning>
# <answer>model_final_answer</answer>

REASONING_START = "<reasoning>"
REASONING_END = "</reasoning>"
SOLUTION_START = "<answer>"
SOLUTION_END = "</answer>"

# Use these parameters for greedy decoding; used in competition evaluation
INF_TEMPERATURE=None
INF_TOP_K=1
INF_TOP_P=None
SEED=42

### System Prompt & Template

We use the same system prompt and template for both the SFT and RL stages.


In [None]:
SYSTEM_PROMPT = f"""You are helpful Assistant. The user asks a question, and you solve it.
You first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within {REASONING_START} {REASONING_END} and
{SOLUTION_START} {SOLUTION_END} tags, respectively, i.e., {REASONING_START} reasoning process here {REASONING_END} {SOLUTION_START} answer here {SOLUTION_END}.
"""

TEMPLATE = """
{system_prompt}
User: {user}. Assistant:
"""

## 🧹 Data Preprocessing

### SFT Data Preprocessing

We preprocess the SFT data by tokenizing each prompt-completion pair with a fixed maximum sequence length. Prompt tokens are masked out of the loss, so supervision applies only to the completion tokens. When a sequence exceeds the maximum length, the prompt is truncated from the left first to preserve as much of the completion (the CoT) as possible; only if that is not enough is the tail of the completion cut. All samples are then padded to a uniform length for efficient batched training.

In [None]:
def get_tokenizer():
    kaggle_ckpt_path = kagglehub.model_download(
        "google/gemma-2/flax/gemma2-2b-it"
      )
    tokenizer = tokenizer_lib.Tokenizer(
          tokenizer_path=os.path.join(kaggle_ckpt_path, "tokenizer.model"))
    return tokenizer, kaggle_ckpt_path



def gen_model_input_fn(x: peft_trainer.TrainingInput):
    pad_mask = x.input_tokens != tokenizer.pad_id()
    positions = utils.build_positions_from_mask(pad_mask)
    attention_mask = utils.make_causal_attn_mask(pad_mask)

    return {
        "input_tokens": x.input_tokens,
        "input_mask": x.input_mask,          # <-- LOSS MASK (completion-only)
        "positions": positions,
        "attention_mask": attention_mask,
    }

In [None]:
import grain.python as grain

TrainingInput = peft_trainer.TrainingInput

# =========================================
# 1) Tokenize: (prompt_tokens, completion_tokens)
# =========================================
class _TokenizeSFT(grain.MapTransform):
    def __init__(self, tokenizer, *, template: str, system_prompt: str,):
        self._tokenizer = tokenizer
        self._template = template
        self._system_prompt = system_prompt

    def map(self, element: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
        def _as_text(v):
            return v if isinstance(v, str) else v.decode("utf-8")

        user_text = _as_text(element["user"])
        assistant_text = _as_text(element["assistant"])

        prompt = self._template.format(
            system_prompt=self._system_prompt,
            user=user_text,
        )
        completion = assistant_text

        prompt_tokens = np.asarray(list(self._tokenizer.encode(prompt)), dtype=np.int32)
        comp_tokens   = np.asarray(list(self._tokenizer.encode(completion)), dtype=np.int32)

        return prompt_tokens, comp_tokens

# =========================================
# 2) BuildTrainInput: concat + completion-only mask + (TRUNCATE + pad)
# =========================================
class _BuildTrainInputSFT(grain.MapTransform):
    def __init__(self, max_seq_len: int, pad_id: int, eos_id: int):
        self._max_seq_len = int(max_seq_len)
        self._pad_id = int(pad_id)
        self._eos_id = int(eos_id)

    def map(self, tokens: Tuple[np.ndarray, np.ndarray]) -> TrainingInput:
        prompt_tokens, comp_tokens = tokens
    
        # -------------------------
        # 1) Ensure EOS at end of completion (same intent as tokenize_sft_for_tunix)
        # -------------------------
        if self._eos_id is not None:
            if comp_tokens.size == 0 or comp_tokens[-1] != self._eos_id:
                comp_tokens = np.concatenate(
                    [comp_tokens.astype(np.int32, copy=False), np.asarray([self._eos_id], np.int32)],
                    axis=0,
                )
    
        # -------------------------
        # 2) Concat + completion-only loss mask
        # -------------------------
        prompt_tokens = prompt_tokens.astype(np.int32, copy=False)
        comp_tokens = comp_tokens.astype(np.int32, copy=False)
    
        input_tokens = np.concatenate([prompt_tokens, comp_tokens], axis=0)
    
        q_mask = np.zeros(prompt_tokens.shape[0], dtype=np.bool_)
        a_mask = np.ones(comp_tokens.shape[0], dtype=np.bool_)
        input_mask = np.concatenate([q_mask, a_mask], axis=0)
    
        # -------------------------
        # 3) Truncate (keep completion as much as possible by trimming prompt first)
        # -------------------------
        max_len = self._max_seq_len
        if input_tokens.shape[0] > max_len:
            overflow = int(input_tokens.shape[0] - max_len)
    
            # (a) drop from the LEFT of the prompt first
            drop = min(overflow, int(prompt_tokens.shape[0]))
            if drop > 0:
                input_tokens = input_tokens[drop:]
                input_mask = input_mask[drop:]
                overflow -= drop
    
            # (b) if still overflow, drop from the RIGHT (may cut completion tail)
            if overflow > 0:
                input_tokens = input_tokens[:-overflow]
                input_mask = input_mask[:-overflow]
    
        # -------------------------
        # 4) Pad up to max_len (mask pad = 0)
        # -------------------------
        input_tokens = self._pad_up_to_max_len(input_tokens, self._pad_id)
        input_mask   = self._pad_up_to_max_len(input_mask, 0)
    
        return TrainingInput(input_tokens=input_tokens, input_mask=input_mask)


    def _pad_up_to_max_len(self, arr: np.ndarray, pad_value: int) -> np.ndarray:
        seq_len = arr.shape[0]
        to_pad = max(self._max_seq_len - seq_len, 0)
        return np.pad(arr, [[0, to_pad]], mode="constant", constant_values=pad_value)

# =========================================
# 3) DataLoader builder
# =========================================
def build_sft_dataloader(
    *,
    data_source: grain.RandomAccessDataSource,
    batch_size: int,
    num_epochs: int | None,
    max_seq_len: int,
    tokenizer,
    template: str,
    system_prompt: str,
    drop_remainder: bool = True,
) -> grain.DataLoader:
    return grain.DataLoader(
        data_source=data_source,
        sampler=grain.IndexSampler(
            num_records=len(data_source),
            num_epochs=num_epochs,
            shard_options=grain.NoSharding(),
        ),
        operations=[
            _TokenizeSFT(tokenizer, template=template, system_prompt=system_prompt),
            _BuildTrainInputSFT(max_seq_len, pad_id=int(tokenizer.pad_id()), eos_id=int(tokenizer.eos_id())),
            grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder),
        ],
    )
    
def create_sft_dataset(
    df_sft,
    *,
    tokenizer,
    template: str,
    system_prompt: str,
    train_micro_batch_size: int,
    eval_micro_batch_size: int,
    num_train_epochs: int | None,
    max_seq_len: int,
    train_fraction: float,
    seed: int = 42,
):
    if not (0.0 < train_fraction <= 1.0):
        raise ValueError("train_fraction must be in (0, 1].")

    records = df_sft.to_dict(orient="records")
    n = len(records)
    if n == 0:
        raise ValueError("df_sft is empty.")

    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)

    # -------------------------
    # split
    # -------------------------
    if train_fraction == 1.0:
        train_idx = perm
        val_idx = None
    else:
        n_train = int(np.floor(n * train_fraction))
        n_train = max(1, n_train)
        train_idx = perm[:n_train]
        val_idx = perm[n_train:]

    train_records = [records[i] for i in train_idx]
    val_records = [] if (val_idx is None) else [records[i] for i in val_idx]


    # -------------------------
    # build loaders (real)
    # -------------------------
    train_ds = grain.MapDataset.source(train_records)
    sft_train_loader = build_sft_dataloader(
        data_source=train_ds,
        batch_size=train_micro_batch_size,
        num_epochs=num_train_epochs,
        max_seq_len=max_seq_len,
        tokenizer=tokenizer,
        template=template,
        system_prompt=system_prompt,
        drop_remainder=True,  # unchanged from the original behavior
    )

    if val_idx is None or len(val_records) == 0:
        sft_val_loader = None
    else:
        val_ds = grain.MapDataset.source(val_records)
        sft_val_loader = build_sft_dataloader(
            data_source=val_ds,
            batch_size=eval_micro_batch_size,
            num_epochs=1,
            max_seq_len=max_seq_len,
            tokenizer=tokenizer,
            template=template,
            system_prompt=system_prompt,
            drop_remainder=True,
        )

    # -------------------------
    # summary print
    # -------------------------
    raw_train = len(train_records)
    raw_val = 0 if val_idx is None else len(val_records)


    print(
            f"[split] total={n:,}  train={raw_train:,}  "
            f"val={raw_val:,}  train_fraction={train_fraction}"
        )



    return sft_train_loader, sft_val_loader, raw_train

In [None]:
#Tokenizer

# Initialize tokenizer and load base model checkpoint
# - tokenizer: used for SFT/RL preprocessing and generation
# - kaggle_ckpt_path: local path to the downloaded base model
tokenizer, kaggle_ckpt_path = get_tokenizer()

In [None]:
# Load SFT Datasets 

# Load SFT stage 1 dataset
# - General-domain reasoning data
# - Used to stabilize overall reasoning structure with longer CoT
df_sft = pd.read_csv(SFT_TRAIN_DATA_DIR)

sft_train_dataset, sft_val_dataset, effective_train_len = create_sft_dataset(
    df_sft,
    tokenizer=tokenizer,
    template=TEMPLATE,
    system_prompt=SYSTEM_PROMPT,
    train_micro_batch_size=SFT_TRAIN_MICRO_BATCH_SIZE,
    eval_micro_batch_size=SFT_TRAIN_MICRO_BATCH_SIZE,
    num_train_epochs=SFT_NUM_EPOCHS,
    max_seq_len=SFT_MAX_SEQ_LENGTH,
    train_fraction=SFT_TRAIN_FRACTION,
    seed=SEED,   
 )

# Load SFT stage 2 dataset
# - Shorter reasoning sequences with higher math proportion (Subset of stage 1)
# - Used for refinement toward target reasoning length
df_sft_stage2 = pd.read_csv(SFT_STAGE2_TRAIN_DATA_DIR)

sft_stage2_train_dataset, sft_stage2_val_dataset, stage2_effective_train_len = create_sft_dataset(
    df_sft_stage2,
    tokenizer=tokenizer,
    template=TEMPLATE,
    system_prompt=SYSTEM_PROMPT,
    train_micro_batch_size=SFT_TRAIN_MICRO_BATCH_SIZE,
    eval_micro_batch_size=SFT_TRAIN_MICRO_BATCH_SIZE,
    num_train_epochs=SFT_STAGE2_NUM_EPOCHS,
    max_seq_len=SFT_STAGE2_MAX_SEQ_LENGTH,
    train_fraction=SFT_TRAIN_FRACTION,
    seed=SEED,   
 )


In [None]:
# Auto-compute SFT stage 1 max steps from the dataset size
if SFT_MAX_STEPS == -1:
    SFT_MAX_STEPS = (
        effective_train_len * SFT_NUM_EPOCHS //
        (SFT_GRAD_ACCUM_STEPS * SFT_TRAIN_MICRO_BATCH_SIZE)
    )

# Auto-compute SFT stage 2 max steps from the dataset size
if SFT_STAGE2_MAX_STEPS == -1:
    SFT_STAGE2_MAX_STEPS = (
        stage2_effective_train_len * SFT_STAGE2_NUM_EPOCHS //
        (SFT_GRAD_ACCUM_STEPS * SFT_TRAIN_MICRO_BATCH_SIZE)
    )

### RL Data Preprocessing

We preprocess the RL dataset with a focus on domain-aware system prompt selection to support downstream reward function design. For each sample, the system prompt is chosen based on its domain: mathematical problems enforce a structured reasoning-and-answer format with the final output restricted to a numeric value, whereas non-mathematical tasks use a general-purpose prompt.

In [None]:
# Math
MATH_SYSTEM_PROMPT =  f"""You are helpful assistant. You are given a math problem. Think about the problem and \
provide your reasoning. Place reasoning trace between {REASONING_START} and \
{REASONING_END}. Then, provide the only final answer (i.e., 'just' numerical value\
) between {SOLUTION_START} and {SOLUTION_END}."""

SYSTEM_PROMPT_BY_DOMAIN = {
    "math": MATH_SYSTEM_PROMPT,
    "default": SYSTEM_PROMPT,
}


In [None]:
def _as_text(v) -> str:
    """Safely convert a value to string (handles None / NaN)."""
    if v is None:
        return ""
    try:
        if pd.isna(v):
            return ""
    except Exception:
        pass
    return v if isinstance(v, str) else str(v)


def _get(x: dict, k: str, default: str = "") -> str:
    """Dictionary getter with safe string conversion."""
    return _as_text(x.get(k, default))


def create_rl_dataset(
    df_rl: pd.DataFrame,
    *,
    shuffle_seed: int = 42,
) -> grain.MapDataset:
    """
    Build an RL training dataset for Tunix/Grain.

    Expected columns:
      - question (required)
      - answer (required)
      - answer_canon (optional): canonical/normalized answer; used if present
      - reward_model_type (optional): used for reward routing/selection
      - domain (optional): metadata for logging/routing
      - answer_type (optional): metadata for logging/routing
    """
    required_cols = {"question", "answer"}
    missing = required_cols - set(df_rl.columns)
    if missing:
        raise ValueError(f"Missing columns: {missing}. Got={list(df_rl.columns)}")

    records = df_rl.to_dict(orient="records")

    def _map_record(x: dict) -> dict:
        question = _as_text(x["question"])

        # Prefer canonical answer if provided; otherwise fall back to raw answer
        answer_canon = _get(x, "answer_canon")
        answer = answer_canon if answer_canon else _as_text(x["answer"])
        domain = _get(x, "domain").lower()

        system_prompt = ( SYSTEM_PROMPT_BY_DOMAIN["math"] if domain == "math" else SYSTEM_PROMPT_BY_DOMAIN["default"])

        return {
            # Model input prompt
            "prompts": TEMPLATE.format(
            system_prompt=system_prompt,
            user=question,
            ),

            # Raw fields for reward function
            "question": question,
            "answer": answer,

            # Reward routing / metadata (optional)
            "reward_model_type": _get(x, "reward_model_type"),
            "domain": _get(x, "domain"),
            "answer_type": _get(x, "answer_type"),
        }

    ds = (
        grain.MapDataset.source(records)
        .shuffle(seed=shuffle_seed)
        .map(_map_record)
    )
    return ds

In [None]:
# ============================================================
# Load + Pre-filter
# Drop samples whose question exceeds the maximum token budget (Anti-OOM)
# ============================================================

df_rl = pd.read_csv(RL_TRAIN_DATA_DIR)


def count_tokens(text: str) -> int:
    """Count tokens using the global `tokenizer`."""
    text = _as_text(text)
    return len(tokenizer.encode(text))


before_n = len(df_rl)
q_tok_lens = df_rl["question"].apply(count_tokens)
df_rl = df_rl[q_tok_lens <= RL_MAX_PROMPT_LENGTH].reset_index(drop=True)
after_n = len(df_rl)

print(f"[filter] question tokens <= {RL_MAX_PROMPT_LENGTH}: {before_n} -> {after_n}")

# ============================================================
# Build dataset
# ============================================================

rl_dataset = create_rl_dataset(df_rl).batch(RL_TRAIN_MICRO_BATCH_SIZE)[:RL_NUM_BATCHES]


if RL_TRAIN_FRACTION == 1.0:
    rl_train_dataset = rl_dataset.repeat(RL_NUM_EPOCHS)
    rl_val_dataset = None
else:
    n_total_rl = len(rl_dataset)
    n_train_rl = int(n_total_rl * RL_TRAIN_FRACTION)

    rl_train_dataset = rl_dataset[:n_train_rl].repeat(RL_NUM_EPOCHS)
    rl_val_dataset = rl_dataset[n_train_rl:].repeat(RL_NUM_EPOCHS)

rl_dataset_lengths = (
    len(rl_train_dataset),
    len(rl_val_dataset) if rl_val_dataset is not None else 0,
)
print(f"dataset contains (train, val) = {rl_dataset_lengths} batches")


In [None]:
# for ele in rl_train_dataset[:1]:
#   pprint(ele)

## 🏆 Reward Function



In [None]:
# =========================
# Weights
# =========================
@dataclass
class RuleWeights:
    # Format-related rewards and penalties
    format_exact: float = 3.0
    format_missing_penalty: float = -1.0
    format_order_penalty: float = -0.5  # Less harsh penalty to preserve exploration in early training
    format_empty_reasoning_penalty: float = -0.5
    format_empty_answer_penalty: float = -0.5

    # Answer-related rewards (string and numeric matching)
    answer_exact: float = 3.0
    answer_strip_match: float = 1.5
    answer_ratio_0p9_1p1: float = 0.5
    answer_ratio_0p8_1p2: float = 0.25
    answer_wrong_penalty: float = -1.0

    # Parse-failure penalty is recommended to be 0 to avoid over-penalizing sentence-style or non-structured tasks
    answer_parse_fail_penalty: float = 0.0

    # Reward for cases where the sentence contains a matching number (check_numbers-style)
    numbers_exact: float = 1.5


# =========================
# Parsing helpers
# =========================
_FRAC_LATEX = re.compile(
    r"""\\(?:frac|dfrac|tfrac)\s*
        \{\s*([+-]?\d+)\s*\}\s*
        \{\s*([+-]?\d+)\s*\}
    """,
    re.VERBOSE,
)
_FRAC_SLASH = re.compile(r"^\s*([+-]?\d+)\s*/\s*([+-]?\d+)\s*$")
_PERCENT = re.compile(r"^\s*([+-]?(?:\d+(?:\.\d*)?|\.\d+))\s*%\s*$")
_FLOATLIKE = re.compile(r"^\s*([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)\s*$")
_NUM_ANYWHERE = re.compile(r"([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)")


def _normalize_text(s: str) -> str:
    if s is None:
        return ""
    s = str(s)
    s = s.replace("$", "").replace(",", "")
    s = re.sub(r"\s+", " ", s).strip()
    return s


def _parse_number(s: str):
    """
    Returns:
      - Fraction for exact rational forms (latex frac / a/b / percent / integer)
      - float for float-like forms
      - None if cannot parse
    """
    s = _normalize_text(s)

    # LaTeX-style fraction appearing anywhere in the string
    m = _FRAC_LATEX.search(s)
    if m:
        num = int(m.group(1))
        den = int(m.group(2))
        if den == 0:
            return None
        return Fraction(num, den)

    # Pure slash-style fraction (a/b)
    m = _FRAC_SLASH.match(s)
    if m:
        num = int(m.group(1))
        den = int(m.group(2))
        if den == 0:
            return None
        return Fraction(num, den)

    # Percentage form (exact)
    m = _PERCENT.match(s)
    if m:
        val = Fraction(m.group(1))
        return val / 100

    # Integer form (exact)
    if re.fullmatch(r"[+-]?\d+", s):
        return Fraction(int(s), 1)

    # Floating-point or scientific notation
    m = _FLOATLIKE.match(s)
    if m:
        try:
            return float(m.group(1))
        except Exception:
            return None

    return None

def _extract_last_number_token(s: str) -> Optional[str]:
    """
    Extract the last numeric expression from a string.
    Priority (from strongest to weakest):
      1) LaTeX fraction  \\frac{a}{b}
      2) Slash fraction  a/b
      3) Percent         12%
      4) Float / integer 0.34, 5, 1e-3
    """
    if s is None:
        return None

    s = _normalize_text(s)

    # 1) LaTeX fraction (use the last occurrence)
    latex_fracs = list(_FRAC_LATEX.finditer(s))
    if latex_fracs:
        m = latex_fracs[-1]
        return m.group(0)

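    # 2) Slash-style fraction a/b
    # _FRAC_SLASH is anchored (^...$), so it only matches when the whole string is a
    # fraction; use an unanchored pattern to find a/b inside a sentence.
    slash_fracs = list(re.finditer(r"[+-]?\d+\s*/\s*[+-]?\d+", s))
    if slash_fracs:
        m = slash_fracs[-1]
        return m.group(0)

    # 3) Percentage expression (unanchored for the same reason as above)
    percents = list(re.finditer(r"[+-]?(?:\d+(?:\.\d*)?|\.\d+)\s*%", s))
    if percents:
        return percents[-1].group(0)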

    # 4) Floating-point or integer number
    nums = _NUM_ANYWHERE.findall(s)
    if nums:
        return nums[-1]

    return None


def _ratio_score(w: RuleWeights, ratio: float) -> float:
    if 0.9 <= ratio <= 1.1:
        return w.answer_ratio_0p9_1p1
    if 0.8 <= ratio <= 1.2:
        return w.answer_ratio_0p8_1p2
    return w.answer_wrong_penalty


def _to_float(x) -> float:
    # float() handles both Fraction and float inputs, so no type branch is needed.
    return float(x)


# =========================
# Reward
# =========================
class RuleBasedRewardFn:
    """
    Expected strict response format:
      <reasoning> ... </reasoning>
      <answer> ... </answer>

    This implementation focuses on:
      - format_exact
      - format_exact+answer (single pathway, numeric-aware)
    """

    def __init__(
        self,
        reasoning_start: str = "<reasoning>",
        reasoning_end: str = "</reasoning>",
        solution_start: str = "<answer>",
        solution_end: str = "</answer>",
        weights: Optional[RuleWeights] = None,
        debug: bool = False,
    ):
        self.reasoning_start = reasoning_start
        self.reasoning_end = reasoning_end
        self.solution_start = solution_start
        self.solution_end = solution_end
        self.weights = weights or RuleWeights()
        self.debug = debug

        # Strict full-match: no extra text allowed outside the tags
        self.match_format_strict = re.compile(
            rf"^[\s]*"
            rf"{re.escape(self.reasoning_start)}(?P<reasoning>.+?){re.escape(self.reasoning_end)}"
            rf"[\s]*"
            rf"{re.escape(self.solution_start)}(?P<answer>.+?){re.escape(self.solution_end)}"
            rf"[\s]*$",
            flags=re.DOTALL,
        )

    # ------------------------
    # Format reward
    # ------------------------
    def format_exact(self, prompts, completions, **kwargs) -> List[float]:
        w = self.weights
        scores: List[float] = []

        rs, re_ = self.reasoning_start, self.reasoning_end
        as_, ae = self.solution_start, self.solution_end

        for response in completions:
            m = self.match_format_strict.search(response)
            if m is not None:
                reasoning = _normalize_text(m.group("reasoning"))
                ans = _normalize_text(m.group("answer"))

                if reasoning.strip() == "":
                    scores.append(w.format_empty_reasoning_penalty)
                    continue
                if ans.strip() == "":
                    scores.append(w.format_empty_answer_penalty)
                    continue

                scores.append(w.format_exact)
                continue

            # Non-strict format → apply negative rewards
            has_all = (rs in response) and (re_ in response) and (as_ in response) and (ae in response)
            if not has_all:
                scores.append(w.format_missing_penalty)
                continue

            # All tags exist but the strict match fails (wrong order, duplication,
            # nesting, or trailing text); all of these share format_order_penalty.
            scores.append(w.format_order_penalty)

        return scores

    # ------------------------
    # Answer scoring (numeric-aware)
    # ------------------------
    def _answer_score(self, guess: str, true: str) -> float:
        """
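        Scoring cascade:
          1) exact string match (after normalization)
          2) if both parse as numbers: exact match, else ratio-band reward
          3) otherwise, extract the LAST number from the guess and compare it with the
             numeric ground truth (exact -> numbers_exact; within a ratio band ->
             half the band reward; otherwise 0)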
        """
        w = self.weights
        guess = _normalize_text(guess)
        true = _normalize_text(true)

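        # (1) String-level exact match. Both strings are already stripped by
        # _normalize_text, so the strip-match branch below is effectively unreachable.
        if guess == true:
            return w.answer_exact
        if guess.strip() == true.strip():
            return w.answer_strip_match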

        # (2) Direct numeric parsing path
        gnum = _parse_number(guess)
        tnum = _parse_number(true)
        if gnum is not None and tnum is not None:
            # Exact fraction comparison case
            if isinstance(gnum, Fraction) and isinstance(tnum, Fraction):
                if gnum == tnum:
                    return w.answer_exact
                tv = float(tnum)
                ratio = float(gnum) / tv if tv != 0.0 else math.inf
                return _ratio_score(w, ratio)

            # Floating-point comparison case
            try:
                gv = _to_float(gnum)
                tv = _to_float(tnum)
                if math.isclose(gv, tv, rel_tol=1e-6, abs_tol=1e-9):
                    return w.answer_exact
                ratio = gv / tv if tv != 0.0 else math.inf
                return _ratio_score(w, ratio)
            except Exception:
                return w.answer_parse_fail_penalty
        # (3) Extract the last numeric token from the guess and compare with the true value
        tnum2 = _parse_number(true)
        if tnum2 is None:
            return w.answer_parse_fail_penalty

        extracted = _extract_last_number_token(guess)
        if extracted is None:
            return w.answer_parse_fail_penalty

        gnum2 = _parse_number(extracted)
        if gnum2 is None:
            return w.answer_parse_fail_penalty

        try:
            gv = _to_float(gnum2)
            tv = _to_float(tnum2)
            # check_numbers-style: an exact match earns numbers_exact; near matches earn
            # half the ratio-band reward, floored at 0 (never negative)
            if math.isclose(gv, tv, rel_tol=1e-6, abs_tol=1e-12):
                return w.numbers_exact

            ratio = gv / tv if tv != 0.0 else math.inf
            raw = _ratio_score(w, ratio)
            return max(0.0, 0.5 * raw)

        except Exception:
            return w.answer_parse_fail_penalty

    def __call__(
        self,
        prompts,
        completions,
        answer=None,
        rule_mode: str = "format_exact",
        **kwargs,
    ) -> List[float]:
        rule_mode = rule_mode.strip()
        w = self.weights

        if rule_mode == "format_exact":
            return self.format_exact(prompts, completions, **kwargs)

        if rule_mode == "format_exact+answer":
            if answer is None:
                raise ValueError("rule_mode='format_exact+answer' requires `answer`.")

            fmt_scores = self.format_exact(prompts, completions, **kwargs)
            out: List[float] = []

            for response, true_raw, fscore in zip(completions, answer, fmt_scores):
                # Gate: add the answer reward only when the strict format check passes;
                # otherwise keep the format score alone.
                if fscore != w.format_exact:
                    out.append(fscore)
                    continue

                m = self.match_format_strict.search(response)
                if m is None:
                    out.append(fscore)
                    continue

                guess_raw = m.group("answer")
                ascore = self._answer_score(guess_raw, true_raw)
                out.append(fscore + ascore)

            return out

        raise ValueError(f"Unknown rule_mode: {rule_mode}")
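To make the reward table concrete, here is a simplified, self-contained sketch of the `format_exact+answer` path with a few scored responses. It is not the class above: it hard-codes the `RuleWeights` defaults, collapses every strict-format failure to `format_missing_penalty` (no order or empty-tag penalties), drops the last-number extraction fallback and float tolerance, and swaps the regex parsing helpers for Python's `fractions.Fraction` string parser. `STRICT`, `to_fraction` and `toy_reward` are illustrative names.

```python
import re
from fractions import Fraction
from typing import Optional

# Same strict layout as RuleBasedRewardFn: only whitespace outside the two tag pairs.
STRICT = re.compile(
    r"^\s*<reasoning>(?P<reasoning>.+?)</reasoning>\s*<answer>(?P<answer>.+?)</answer>\s*$",
    re.DOTALL,
)


def to_fraction(s: str) -> Optional[Fraction]:
    """Parse '3/4', '0.75', '12', '12%' or '$1,200' into an exact Fraction (None if not numeric)."""
    s = s.replace("$", "").replace(",", "").strip()
    try:
        if s.endswith("%"):
            return Fraction(s[:-1].strip()) / 100
        return Fraction(s)
    except (ValueError, ZeroDivisionError):
        return None


def toy_reward(response: str, true: str) -> float:
    m = STRICT.search(response)
    if m is None:
        return -1.0                 # format_missing_penalty (simplified)
    fmt = 3.0                       # format_exact
    g, t = to_fraction(m.group("answer")), to_fraction(true)
    if g is None or t is None:
        return fmt                  # answer_parse_fail_penalty = 0.0
    if g == t:
        return fmt + 3.0            # answer_exact
    if t == 0:
        return fmt - 1.0            # answer_wrong_penalty
    ratio = g / t
    if Fraction(9, 10) <= ratio <= Fraction(11, 10):
        return fmt + 0.5            # answer_ratio_0p9_1p1
    if Fraction(8, 10) <= ratio <= Fraction(12, 10):
        return fmt + 0.25           # answer_ratio_0p8_1p2
    return fmt - 1.0                # answer_wrong_penalty


print(toy_reward("<reasoning>3 of 4 parts</reasoning>\n<answer>3/4</answer>", "0.75"))  # 6.0
print(toy_reward("<reasoning>estimate</reasoning><answer>95</answer>", "100"))         # 3.5
print(toy_reward("The answer is 3/4.", "0.75"))                                         # -1.0
```

Using exact rationals keeps `3/4` and `0.75` equal without a float tolerance, which is the same motivation behind returning `Fraction` from `_parse_number`.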

In [None]:
_REASON_RE = re.compile(r"<reasoning>(.*?)</reasoning>", re.DOTALL)


def _extract_reasoning(text: str) -> str:
    m = _REASON_RE.search(text)
    return m.group(1).strip() if m else ""


def _encode(tokenizer, s: str) -> List[int]:
    return list(tokenizer.encode(s))


def _left_pad_to(arr: np.ndarray, target_len: int, pad_id: int) -> np.ndarray:
    if arr.shape[0] >= target_len:
        return arr[-target_len:]
    pad = np.full((target_len - arr.shape[0],), pad_id, dtype=np.int32)
    return np.concatenate([pad, arr.astype(np.int32)], axis=0)


def _right_pad_stack(token_lists: List[List[int]], pad_id: int) -> Tuple[np.ndarray, np.ndarray]:
    """Returns (tokens[B, L], mask[B, L]) where mask is 1 for non-pad."""
    maxlen = max((len(x) for x in token_lists), default=0)
    toks = np.full((len(token_lists), maxlen), pad_id, dtype=np.int32)
    msk = np.zeros((len(token_lists), maxlen), dtype=np.int32)
    for i, ids in enumerate(token_lists):
        if ids:
            toks[i, : len(ids)] = np.asarray(ids, dtype=np.int32)
            msk[i, : len(ids)] = 1
    return toks, msk


def _mean_prob_from_logps(logps: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    """
    logps: [B, L] per-token log-probabilities (log p)
    mask:  [B, L] 1 for valid tokens else 0
    returns: [B] mean token *probability* (exp(log p)) over valid tokens

    Note: this is NOT the mean log-prob; it averages exp(log p), which is typically small.
    """
    mask_f = mask.astype(jnp.float32)
    denom = jnp.maximum(mask_f.sum(axis=-1), 1.0)  # guard against all-padding rows
    probs = jnp.exp(logps)
    return (probs * mask_f).sum(axis=-1) / denom


def _apply_rlpr_std_prompt_filter(
    rewards: np.ndarray,
    *,
    num_generations: int,
    beta: float,
    beta_min: float,
    beta_max: float,
    eps: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    rewards: [B] where B == N*G (G=num_generations)
    Returns:
      filtered_rewards: [B] (groups with std < beta are zeroed)
      group_std:       [N] if filtering is applied; otherwise empty array
    """
    G = int(num_generations)
    if G <= 1 or rewards.shape[0] % G != 0:
        return rewards, np.zeros((0,), dtype=np.float32)

    R = rewards.reshape(-1, G)  # [N, G]
    sd = R.std(axis=1, ddof=1).astype(np.float32)
    sd = np.maximum(sd, eps)

    beta = float(np.clip(beta, beta_min, beta_max))
    keep = sd >= beta  # [N]
    R2 = R * keep[:, None].astype(R.dtype)  # zero-out low-std prompts
    return R2.reshape(-1), sd
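The std-based prompt filter is easiest to see on toy numbers. The NumPy sketch below follows `_apply_rlpr_std_prompt_filter` (sample std with `ddof=1` per group of G rollouts, floored at `eps`, groups below `beta` zeroed); the reward values are invented for illustration. `ema_beta` sketches the `"dynamic"` update in `_get_current_beta`, but omits its first-step initialization and the `[beta_min, beta_max]` clip.

```python
import numpy as np


def std_prompt_filter(rewards, num_generations: int, beta: float, eps: float = 1e-6):
    """Zero every group of G rollouts whose sample std (ddof=1) is below beta."""
    r = np.asarray(rewards, dtype=np.float32).reshape(-1, num_generations)  # [N, G]
    sd = np.maximum(r.std(axis=1, ddof=1), eps)                             # [N]
    keep = sd >= beta
    return (r * keep[:, None]).reshape(-1), sd


def ema_beta(prev_beta: float, group_sd: np.ndarray, decay: float = 0.99) -> float:
    """'dynamic' mode: exponential moving average of the per-step mean group std."""
    return decay * prev_beta + (1.0 - decay) * float(np.mean(group_sd))


rewards = [0.2, 0.2, 0.2, 0.2,   # identical rewards: std ~ 0, so this prompt is dropped
           0.1, 0.5, 0.3, 0.9]   # spread-out rewards: this prompt is kept
filtered, sd = std_prompt_filter(rewards, num_generations=4, beta=0.05)
print(filtered)  # first four entries zeroed, last four unchanged
```

A group whose rollouts all receive the same reward carries no relative signal for GRPO, which is why such prompts are the ones filtered out.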

In [None]:
@dataclass
class StdFilterConfig:
    enabled: bool = True
    num_generations: int = 4
    ema_decay: float = 0.99
    beta_init: float = 0.0
    beta_min: float = 0.0
    beta_max: float = 1.0
    eps: float = 1e-6
    filtering_mode: Literal["fixed", "dynamic"] = "fixed"


class ProbabilityBasedRewardFn:
    """
    RLPR-style Probability Reward using ACTOR (rollout policy) probs.

    GRPO expects:
      fn(prompts, completions, answer, **kwargs) -> List[float]

    Requires training_input include:
      - "answer" (ground truth y*)
    """

    def __init__(
        self,
        *,
        rl_cluster,
        tokenizer,
        max_prompt_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        use_debias: bool = True,
        clip_min: float = 0.0,
        clip_max: float = 1.0,
        wrap_y_with_answer_tags: bool = True,
        std_filter: Optional[StdFilterConfig] = None,
        add_eos_to_y: bool = True,
        name: str = "probability_based_reward",
    
    ):
        self.rl_cluster = rl_cluster
        self.tokenizer = tokenizer

        self.__name__ = name

        self.pad_id = int(self.rl_cluster.rollout.pad_id())
        self.eos_id = int(self.rl_cluster.rollout.eos_id())


        rc = self.rl_cluster.cluster_config.rollout_config
        if max_prompt_length is None:
            if isinstance(rc, dict):
                # default: TRAIN config (dict key is Mode enum; safest is to pick TRAIN if exists)
                train_key = None
                for k in rc.keys():
                    if str(k).lower() == "train" or str(getattr(k, "value", "")).lower() == "train":
                        train_key = k
                        break
                key = train_key if train_key is not None else next(iter(rc.keys()))
                self.max_prompt_length = int(rc[key].max_prompt_length)
            else:
                self.max_prompt_length = int(rc.max_prompt_length)
        else:
            self.max_prompt_length = int(max_prompt_length)

        self.micro_batch_size = micro_batch_size
        self.use_debias = bool(use_debias)
        self.clip_min = float(clip_min)
        self.clip_max = float(clip_max)
        self.wrap_y_with_answer_tags = bool(wrap_y_with_answer_tags)
        self.std_filter = std_filter if std_filter is not None else StdFilterConfig(enabled=False)
        self.add_eos_to_y = bool(add_eos_to_y)
        self._beta = float(self.std_filter.beta_init)
        self._beta_initialized = False

    def _actor_per_token_logps(
        self,
        prompt_tok: jax.Array,
        y_tok: jax.Array,
        y_mask: jax.Array,
    ) -> jax.Array:
        # "old" logps = logps under the rollout policy used to sample trajectories
        # (i.e., the current policy before the upcoming update step)
        return self.rl_cluster.get_old_per_token_logps(
            prompt_tokens=prompt_tok,
            completion_tokens=y_tok,
            micro_batch_size=self.micro_batch_size or int(y_tok.shape[0]),
            completion_mask=y_mask,
        )

    def _get_current_beta(self, group_sd: Optional[np.ndarray]) -> float:
        """
        Returns the beta to use this step.
        - fixed: always beta_init
        - dynamic: EMA over step_mean_sd (if group_sd provided and non-empty)
        """
        current_filtering_mode = self.std_filter.filtering_mode
    
        if current_filtering_mode == "fixed":
            beta = float(self.std_filter.beta_init)
            beta = float(np.clip(beta, self.std_filter.beta_min, self.std_filter.beta_max))
            self._beta = beta
            self._beta_initialized = True
            return beta
    
        if current_filtering_mode == "dynamic":
            if group_sd is None or group_sd.size == 0:
                # no update possible; keep previous beta (or init if first time)
                beta = self._beta if self._beta_initialized else float(self.std_filter.beta_init)
                beta = float(np.clip(beta, self.std_filter.beta_min, self.std_filter.beta_max))
                return beta
    
            step_mean_sd = float(np.mean(group_sd))
            if not self._beta_initialized:
                self._beta = step_mean_sd
                self._beta_initialized = True
            else:
                d = float(self.std_filter.ema_decay)
                self._beta = d * self._beta + (1.0 - d) * step_mean_sd
    
            self._beta = float(np.clip(self._beta, self.std_filter.beta_min, self.std_filter.beta_max))
            return float(self._beta)
    
        raise ValueError(f"Unknown filtering_mode: {current_filtering_mode!r}")
    
    def _std_filter_rewards(self, pr_np: np.ndarray) -> np.ndarray:
        """
        Apply RLPR std-based prompt filter.
        - Computes group std over G generations
        - Chooses beta (fixed or dynamic)
        - Zeroes out groups with std < beta
        """
        if not self.std_filter.enabled:
            return pr_np
    
        G = int(self.std_filter.num_generations)
        if G <= 1 or pr_np.shape[0] % G != 0:
            return pr_np
    
        # 1) compute group std (beta=0 just to get sd; keep-all)
        _, group_sd = _apply_rlpr_std_prompt_filter(
                pr_np,
                num_generations=G,
                beta=0.0,
                beta_min=self.std_filter.beta_min,
                beta_max=self.std_filter.beta_max,
                eps=self.std_filter.eps,
        )
    
        # 2) choose beta (fixed or EMA-updated)
        beta = self._get_current_beta(group_sd)
    
        # 3) apply filter with chosen beta
        pr_np_filtered, _ = _apply_rlpr_std_prompt_filter(
                pr_np,
                num_generations=G,
                beta=beta,
                beta_min=self.std_filter.beta_min,
                beta_max=self.std_filter.beta_max,
                eps=self.std_filter.eps,
        )
        return pr_np_filtered        
    
    def __call__(
        self,
        *,
        prompts: List[str],
        completions: List[str],
        answer: Sequence[str],
        **kwargs: Any,
    ) -> List[float]:
        answer_list = [str(a) for a in list(answer)]
        if len(prompts) != len(completions) or len(prompts) != len(answer_list):
            raise ValueError(
                f"Length mismatch: {len(prompts)=}, {len(completions)=}, {len(answer_list)=}"
            )

        prompt_tok_base: List[np.ndarray] = []
        prompt_tok_full: List[np.ndarray] = []
        y_tok_list: List[List[int]] = []

        for p, c, y in zip(prompts, completions, answer_list):
            # prompt only
            p_ids = _encode(self.tokenizer, p)
            prompt_tok_base.append(_left_pad_to(np.asarray(p_ids, np.int32), self.max_prompt_length, self.pad_id))

            # prompt + extracted reasoning (conditioning format for reward computation)
            z = _extract_reasoning(c)
            if z:
                z_text = f"<reasoning>{z}</reasoning>"
                full_text = p + "\n" + z_text
            else:
                full_text = p
            full_ids = _encode(self.tokenizer, full_text)
            prompt_tok_full.append(_left_pad_to(np.asarray(full_ids, np.int32), self.max_prompt_length, self.pad_id))

            # y* tokens (optionally wrap)
            y_text = f"<answer>{y}</answer>" if self.wrap_y_with_answer_tags else y
            y_ids = _encode(self.tokenizer, y_text)
            if self.add_eos_to_y:
                y_ids = y_ids + [self.eos_id]
            y_tok_list.append(y_ids)

        prompt_tok_base_j = jnp.asarray(np.stack(prompt_tok_base, axis=0), dtype=jnp.int32)
        prompt_tok_full_j = jnp.asarray(np.stack(prompt_tok_full, axis=0), dtype=jnp.int32)

        y_tok_np, y_mask_np = _right_pad_stack(y_tok_list, self.pad_id)
        y_tok_j = jnp.asarray(y_tok_np, dtype=jnp.int32)
        y_mask_j = jnp.asarray(y_mask_np, dtype=jnp.int32)

       
        actor_logps_full = self._actor_per_token_logps(prompt_tok_full_j, y_tok_j, y_mask_j)
        r_full = _mean_prob_from_logps(actor_logps_full, y_mask_j)

        if self.use_debias:
            actor_logps_base = self._actor_per_token_logps(prompt_tok_base_j, y_tok_j, y_mask_j)
            r_base = _mean_prob_from_logps(actor_logps_base, y_mask_j)  # [B]
            pr_raw = r_full - r_base
        else:
            pr_raw = r_full

        pr_clipped = jnp.clip(pr_raw, self.clip_min, self.clip_max)
        pr_np = np.asarray(jax.device_get(pr_clipped), dtype=np.float32)
        pr_np = self._std_filter_rewards(pr_np)

        return pr_np.tolist()
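The debiased probability reward reduces to a difference of two token-averaged probabilities of the reference answer y*: one conditioned on the prompt plus the extracted reasoning, one on the prompt alone. The NumPy sketch below mirrors `_mean_prob_from_logps` and the `use_debias` branch with clipping to `[0, 1]`; it skips the std filter, and the per-token probabilities are made-up numbers rather than actor outputs.

```python
import numpy as np


def mean_token_prob(logps: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Mean of exp(log p) over unmasked tokens -> shape [B]."""
    mask = mask.astype(np.float32)
    denom = np.maximum(mask.sum(axis=-1), 1.0)
    return (np.exp(logps) * mask).sum(axis=-1) / denom


def debiased_prob_reward(logps_full, logps_base, mask, clip_min=0.0, clip_max=1.0):
    """clip(P(y* | x, z) - P(y* | x)), each P being a token-mean probability."""
    r = mean_token_prob(logps_full, mask) - mean_token_prob(logps_base, mask)
    return np.clip(r, clip_min, clip_max)


mask = np.array([[1, 1, 0], [1, 1, 0]])                 # third position is padding
p_full = np.array([[0.8, 0.6, 1.0], [0.2, 0.2, 1.0]])   # context: prompt + reasoning z
p_base = np.array([[0.5, 0.3, 1.0], [0.5, 0.5, 1.0]])   # context: prompt only
reward = debiased_prob_reward(np.log(p_full), np.log(p_base), mask)
print(reward)  # approx [0.3, 0.0]
```

The second row shows the effect of debiasing: reasoning that makes the reference answer *less* likely than the bare prompt does is clipped to zero instead of earning a positive reward.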

In [None]:
RewardFn = Callable[..., List[float]]


def _as_list_domain(x, n: int, default: str = "") -> list[str]:
    if x is None:
        return [default] * n

    if isinstance(x, np.ndarray):
        x = x.reshape(-1).tolist()

    if isinstance(x, (list, tuple)):
        if len(x) != n:
            raise ValueError(f"domain length mismatch: {len(x)} vs batch {n}")
        return [str(v) for v in x]

    raise ValueError(
        f"domain must be a list/array of length {n} (got scalar {type(x).__name__})"
    )


def _slice(v: Any, idxs: list[int], B: int) -> Any:
    """If a kwarg is a length-B vector, slice it using idxs before forwarding"""
    if v is None or isinstance(v, str):
        return v
    if isinstance(v, np.ndarray):
        flat = v.reshape(-1)
        return flat[idxs] if flat.shape[0] == B else v
    if isinstance(v, (list, tuple)):
        return [v[i] for i in idxs] if len(v) == B else v
    return v


def _infer_format_pass_value(rule_based_fn: Any, fallback: float = 3.0) -> float:
    w = getattr(rule_based_fn, "weights", None)
    if w is not None and hasattr(w, "format_exact"):
        try:
            return float(getattr(w, "format_exact"))
        except Exception:
            pass
    return float(fallback)


class HybridRewardFn:
    def __init__(
        self,
        *,
        rule_based_fn: Optional[RewardFn],
        pr_fn: RewardFn,
        w_rule: float = 1.0,
        w_prob: float = 1.0,
        rule_mode: str = "format_exact+answer",
        name: str = "hybrid_reward",
        apply_format_to_prob: bool = True,
        w_format_prob: float = 1.0,
        gate_prob_on_format: bool = False,
        format_pass_value: Optional[float] = None,
    ):
        if pr_fn is None:
            raise ValueError("pr_fn must be provided.")
        self.rule_based_fn = rule_based_fn
        self.pr_fn = pr_fn
        self.w_rule = float(w_rule)
        self.w_prob = float(w_prob)
        self.rule_mode = str(rule_mode)
        self.__name__ = name

        self.apply_format_to_prob = bool(apply_format_to_prob)
        self.w_format_prob = float(w_format_prob)
        self.gate_prob_on_format = bool(gate_prob_on_format)

        if format_pass_value is not None:
            self.format_pass_value = float(format_pass_value)
        else:
            self.format_pass_value = (
                _infer_format_pass_value(rule_based_fn, fallback=3.0)
                if rule_based_fn is not None
                else 3.0
            )

    def __call__(
        self,
        *,
        prompts: List[str],
        completions: List[str],
        answer: Sequence[str],
        reward_model_type: Any = None,
        **kwargs: Any,
    ) -> List[float]:

        B = len(prompts)
        if len(completions) != B or len(answer) != B:
            raise ValueError(f"Length mismatch: {len(prompts)=}, {len(completions)=}, {len(answer)=}")

        domains = _as_list_domain(kwargs.get("domain", None), B, default="")
        domains = [d.strip().lower() for d in domains]

        rule_idxs = [i for i, d in enumerate(domains) if d == "math"]
        rule_set = set(rule_idxs)  # build once instead of once per index
        prob_idxs = [i for i in range(B) if i not in rule_set]

        out = np.zeros((B,), dtype=np.float32)

        # -----------------------------
        # 1) Rule group reward (math)
        # -----------------------------
        if rule_idxs:
            if self.rule_based_fn is None:
                raise ValueError("domain includes 'math' but rule_based_fn is None.")

            sub_kwargs = {k: _slice(v, rule_idxs, B) for k, v in kwargs.items()}
            rule_r = self.rule_based_fn(
                prompts=[prompts[i] for i in rule_idxs],
                completions=[completions[i] for i in rule_idxs],
                answer=[answer[i] for i in rule_idxs],
                rule_mode=self.rule_mode,
                **sub_kwargs,
            )
            if len(rule_r) != len(rule_idxs):
                raise RuntimeError(f"rule_based_fn returned {len(rule_r)} != {len(rule_idxs)}")

            rule_np = np.asarray([float(x) for x in rule_r], dtype=np.float32)
            for j, i in enumerate(rule_idxs):
                out[i] = self.w_rule * float(rule_np[j])

        # -----------------------------
        # 2) Prob(PR) group reward (non-math)
        #    + optionally add format_exact
        # -----------------------------
        if prob_idxs:
            if self.apply_format_to_prob and self.rule_based_fn is None:
                raise ValueError("apply_format_to_prob=True but rule_based_fn is None.")

            sub_kwargs = {k: _slice(v, prob_idxs, B) for k, v in kwargs.items()}
            
            fmt_np = None
            if self.apply_format_to_prob:
                fmt_r = self.rule_based_fn(
                    prompts=[prompts[i] for i in prob_idxs],
                    completions=[completions[i] for i in prob_idxs],
                    answer=[answer[i] for i in prob_idxs],  # format_exact ignores answer, but we pass it to keep batch lengths consistent
                    rule_mode="format_exact",
                    **sub_kwargs,
                )
                if len(fmt_r) != len(prob_idxs):
                    raise RuntimeError(f"format_exact returned {len(fmt_r)} != {len(prob_idxs)}")
                fmt_np = np.asarray([float(x) for x in fmt_r], dtype=np.float32)

            # Gating mode: if strict format fails, skip PR and apply only the format reward
            if self.gate_prob_on_format and fmt_np is not None:
                pass_val = self.format_pass_value

                keep_local = [j for j in range(len(prob_idxs)) if float(fmt_np[j]) == pass_val]
                drop_local = [j for j in range(len(prob_idxs)) if float(fmt_np[j]) != pass_val]

                for j in drop_local:
                    i = prob_idxs[j]
                    out[i] = self.w_format_prob * float(fmt_np[j])
                if keep_local:
                    keep_idxs = [prob_idxs[j] for j in keep_local]
                    keep_kwargs = {k: _slice(v, keep_idxs, B) for k, v in kwargs.items()}

                    pr_r = self.pr_fn(
                        prompts=[prompts[i] for i in keep_idxs],
                        completions=[completions[i] for i in keep_idxs],
                        answer=[answer[i] for i in keep_idxs],
                        **keep_kwargs,
                    )
                    if len(pr_r) != len(keep_idxs):
                        raise RuntimeError(f"pr_fn returned {len(pr_r)} != {len(keep_idxs)}")

                    pr_np = np.asarray([float(x) for x in pr_r], dtype=np.float32)
                    
                    for t, j_local in enumerate(keep_local):
                        i = prob_idxs[j_local]
                        fmt_bonus = self.w_format_prob * float(fmt_np[j_local])
                        out[i] = self.w_prob * float(pr_np[t]) + fmt_bonus

                return out.tolist()
            # Additive mode: compute PR for all samples and add the format bonus    
            pr_r = self.pr_fn(
                prompts=[prompts[i] for i in prob_idxs],
                completions=[completions[i] for i in prob_idxs],
                answer=[answer[i] for i in prob_idxs],
                **sub_kwargs,
            )
            if len(pr_r) != len(prob_idxs):
                raise RuntimeError(f"pr_fn returned {len(pr_r)} != {len(prob_idxs)}")

            pr_np = np.asarray([float(x) for x in pr_r], dtype=np.float32)


            for j, i in enumerate(prob_idxs):
                fmt_bonus = self.w_format_prob * float(fmt_np[j]) if fmt_np is not None else 0.0
                out[i] = self.w_prob * float(pr_np[j]) + fmt_bonus

        return out.tolist()
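The routing logic of `HybridRewardFn` in its additive mode (`gate_prob_on_format=False`) can be summarized as: math samples get the rule reward, all other samples get `w_prob * PR + w_format_prob * format`. The sketch below reproduces only that routing; `route_rewards` is a hypothetical name and the lambdas are placeholder scorers, not the real reward functions.

```python
from typing import Callable, List, Sequence

BatchFn = Callable[[List[str]], List[float]]


def route_rewards(
    completions: Sequence[str],
    domains: Sequence[str],
    rule_fn: BatchFn,
    prob_fn: BatchFn,
    fmt_fn: BatchFn,
    w_rule: float = 1.0,
    w_prob: float = 1.0,
    w_fmt: float = 1.0,
) -> List[float]:
    """Math samples -> w_rule * rule reward; others -> w_prob * PR + w_fmt * format."""
    math_idx = [i for i, d in enumerate(domains) if d.strip().lower() == "math"]
    math_set = set(math_idx)
    other_idx = [i for i in range(len(completions)) if i not in math_set]

    out = [0.0] * len(completions)
    if math_idx:
        for i, r in zip(math_idx, rule_fn([completions[i] for i in math_idx])):
            out[i] = w_rule * r
    if other_idx:
        sub = [completions[i] for i in other_idx]
        for i, p, f in zip(other_idx, prob_fn(sub), fmt_fn(sub)):
            out[i] = w_prob * p + w_fmt * f
    return out


rewards = route_rewards(
    completions=["c0", "c1", "c2", "c3"],
    domains=["math", "General", " MATH ", "science"],
    rule_fn=lambda cs: [6.0] * len(cs),   # stand-in for format_exact+answer
    prob_fn=lambda cs: [0.25] * len(cs),  # stand-in for the probability reward
    fmt_fn=lambda cs: [3.0] * len(cs),    # stand-in for format_exact
)
print(rewards)  # [6.0, 3.25, 6.0, 3.25]
```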


## 📦 Load Gemma2-2B-it Model

In [None]:
def show_hbm_usage():
  """Displays memory usage per device."""
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    stats = d.memory_stats()
    used = stats["bytes_in_use"]
    limit = stats["bytes_limit"]
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

In [None]:
params = params_lib.load_and_format_params(
      os.path.join(kaggle_ckpt_path, "gemma2-2b-it")
  )
gemma = model_lib.Transformer.from_params(params, version="2-2b-it")
checkpointer = ocp.StandardCheckpointer()
_, state = nnx.split(gemma)
checkpointer.save(os.path.join(SFT_INTERMEDIATE_CKPT_DIR, "state"), state)
checkpointer.wait_until_finished()
# Delete the intermediate model to save memory.
del params
del gemma
del state
gc.collect()

In [None]:
def get_gemma_base_model(ckpt_path):
  mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[1]))
  model_config = model_lib.ModelConfig.gemma2_2b()
  abs_gemma: nnx.Module = nnx.eval_shape(
      lambda: model_lib.Transformer(model_config, rngs=nnx.Rngs(params=0))
  )
  abs_state = nnx.state(abs_gemma)
  abs_state = jax.tree.map(
      lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
      abs_state,
      nnx.get_named_sharding(abs_state, mesh),
  )
  checkpointer = ocp.StandardCheckpointer()
  restored_params = checkpointer.restore(os.path.join(ckpt_path, "state"), target=abs_state)

  graph_def, _ = nnx.split(abs_gemma)
  gemma = nnx.merge(graph_def, restored_params)
  return gemma, mesh, model_config


def get_lora_model(base_model, mesh):
  lora_provider = qwix.LoraProvider(
      module_path=(
          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
          ".*attn_vec_einsum"
      ),
      rank=RANK,
      alpha=ALPHA,
  )

  model_input = base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(
      base_model, lora_provider, **model_input
  )

  with mesh:
    state = nnx.state(lora_model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(lora_model, sharded_state)

  return lora_model
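For intuition on what `get_lora_model` adds to each matched projection, here is a NumPy sketch of the standard LoRA formulation: the frozen weight W plus a trainable low-rank update scaled by `alpha / rank`, with B zero-initialized so the adapter starts as a no-op. Whether qwix uses exactly this scaling and initialization is an assumption here; check the qwix documentation. `lora_forward` and all sizes are hypothetical.

```python
import numpy as np


def lora_forward(x, w, a, b, alpha: float, rank: int):
    """y = x @ W + (alpha / rank) * (x @ A) @ B; W is frozen, only A and B are trained."""
    return x @ w + (alpha / rank) * (x @ a) @ b


d_in, d_out, rank, alpha = 16, 32, 4, 8.0
rng = np.random.default_rng(0)
x = rng.normal(size=(2, d_in))
w = rng.normal(size=(d_in, d_out))
a = rng.normal(size=(d_in, rank)) * 0.01
b = np.zeros((rank, d_out))  # B = 0 at init, so the output equals the base model's

y = lora_forward(x, w, a, b, alpha, rank)
print(np.allclose(y, x @ w))                # True
print(rank * (d_in + d_out), d_in * d_out)  # 192 512 (trainable vs frozen params)
```

The parameter count `rank * (d_in + d_out)` versus `d_in * d_out` is what makes LoRA fit the limited-compute setting described in the introduction.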

In [None]:
# Load Base model
base_model, mesh, model_config = get_gemma_base_model(SFT_INTERMEDIATE_CKPT_DIR)

# Apply LoRA
lora_policy = get_lora_model(base_model, mesh=mesh)

In [None]:
print("\n--- HBM Usage After Model Load ---")
show_hbm_usage()

## 🚀 Train

In [None]:
from kaggle_secrets import UserSecretsClient
os.environ['WANDB_API_KEY'] = UserSecretsClient().get_secret("WANDB_API_KEY")

### 1️⃣ SFT Stage (1)

In [None]:
sft_checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SFT_SAVE_INTERVAL_STEPS, max_to_keep=SFT_MAX_TO_KEEP
)

sft_logging_options = metrics_logger.MetricsLoggerOptions(
        log_dir=SFT_LOG_DIR, flush_every_n_steps=20
    )

In [None]:
sft_optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=SFT_LEARNING_RATE,
        warmup_steps=SFT_WARMUP_STEPS,
        decay_steps=max(SFT_MAX_STEPS, 1),
        end_value=0.0,
    ),
    weight_decay=SFT_WEIGHT_DECAY,
)

if SFT_MAX_GRAD_NORM is not None:
  sft_optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=SFT_MAX_GRAD_NORM),
      sft_optimizer,
  )

 
sft_training_config = peft_trainer.TrainingConfig(
    eval_every_n_steps=SFT_EVAL_EVERY_N_STEPS,
    max_steps=SFT_MAX_STEPS,
    gradient_accumulation_steps=SFT_GRAD_ACCUM_STEPS,
    checkpoint_root_directory=SFT_CKPT_DIR,
    checkpointing_options=sft_checkpointing_options,
    metrics_logging_options=sft_logging_options,
)
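
The SFT optimizer uses `optax.schedules.warmup_cosine_decay_schedule`. In Optax, `decay_steps` is the total schedule length *including* warmup, so passing `SFT_MAX_STEPS` makes the learning rate reach `end_value` at the last training step. The dependency-free sketch below re-implements that shape (linear warmup followed by cosine decay, with Optax's default `exponent=1`) so the curve can be inspected without a TPU session. It is our re-implementation for illustration, not Optax's code, and the step counts are made up.

```python
import math

def warmup_cosine(step, peak_value, warmup_steps, decay_steps, init_value=0.0, end_value=0.0):
    """Learning rate at `step`: linear warmup, then cosine decay to `end_value`.

    `decay_steps` counts the whole schedule, warmup included (Optax convention).
    """
    if warmup_steps > 0 and step < warmup_steps:
        return init_value + (peak_value - init_value) * step / warmup_steps
    cosine_steps = max(decay_steps - warmup_steps, 1)
    t = min(step - warmup_steps, cosine_steps)  # hold at end_value once the schedule ends
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / cosine_steps))
    return end_value + (peak_value - end_value) * cosine

# Hypothetical values, loosely mirroring the SFT stage-2 cell (warmup = 10% of max steps).
max_steps = 1000
warmup = int(0.1 * max_steps)
for s in (0, warmup // 2, warmup, max_steps // 2, max_steps):
    print(s, warmup_cosine(s, peak_value=1e-4, warmup_steps=warmup, decay_steps=max_steps))
```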

In [None]:
sft_trainer = peft_trainer.PeftTrainer(
        model=lora_policy,
        optimizer=sft_optimizer,
        training_config=sft_training_config,
    ).with_gen_model_input_fn(gen_model_input_fn)

In [None]:
with mesh:
    sft_trainer.train(sft_train_dataset, sft_val_dataset)

### 1️⃣ SFT Stage (2)

In [None]:
# Helper function to load ckpt
def get_latest_step(ckpt_root: str) -> int:
    steps = []
    for name in os.listdir(ckpt_root):
        path = os.path.join(ckpt_root, name)
        if name.isdigit() and os.path.isdir(path):
            steps.append(int(name))
    if not steps:
        raise ValueError("No checkpoint steps found")
    return max(steps)

In [None]:
wandb.init(project="tunix-SFT-2")

latest_step = get_latest_step(SFT_CKPT_DIR)
print(f"Load {latest_step} step Ckpt")
sft_stage1_ckpt_dir = os.path.join(
    SFT_CKPT_DIR, str(latest_step), "model_params"
)

abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(sft_stage1_ckpt_dir, target=abs_params)

with mesh:
    nnx.update(
        lora_policy,
        jax.tree.map(
            lambda a, b: b,
            nnx.state(lora_policy, nnx.LoRAParam),
            trained_lora_params,
        ),
    )

In [None]:
sft_stage2_optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=SFT_STAGE2_LEARNING_RATE,
        warmup_steps=int(0.1 * SFT_STAGE2_MAX_STEPS),
        decay_steps=max(SFT_STAGE2_MAX_STEPS, 1),
        end_value=0.0,
    ),
    weight_decay=SFT_WEIGHT_DECAY,
)

if SFT_MAX_GRAD_NORM is not None:
  sft_stage2_optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=SFT_MAX_GRAD_NORM),
      sft_stage2_optimizer,
  )

 
sft_stage2_training_config = peft_trainer.TrainingConfig(
    eval_every_n_steps=SFT_EVAL_EVERY_N_STEPS,
    max_steps=SFT_STAGE2_MAX_STEPS,
    gradient_accumulation_steps=SFT_GRAD_ACCUM_STEPS,
    checkpoint_root_directory=SFT_STAGE2_CKPT_DIR,
    checkpointing_options=sft_checkpointing_options,
    metrics_logging_options=sft_logging_options,
)

In [None]:
sft_stage2_trainer = peft_trainer.PeftTrainer(
        model=lora_policy,
        optimizer=sft_stage2_optimizer,
        training_config=sft_stage2_training_config,
    ).with_gen_model_input_fn(gen_model_input_fn)

In [None]:
with mesh:
    sft_stage2_trainer.train(sft_stage2_train_dataset, sft_stage2_val_dataset)

### 2️⃣ RL Stage

In [None]:
wandb.init(project="tunix-RL")

latest_step = get_latest_step(SFT_STAGE2_CKPT_DIR)

sft_two_trained_ckpt_dir = f"{SFT_STAGE2_CKPT_DIR}/{latest_step}/model_params"

abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)

checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(sft_two_trained_ckpt_dir, target=abs_params)
with mesh:
    nnx.update(
        lora_policy,
        jax.tree.map(
            lambda a, b: b,
            nnx.state(lora_policy, nnx.LoRAParam),
            trained_lora_params,
        ),
    )

In [None]:
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=RL_SAVE_INTERVAL_STEPS, max_to_keep=RL_MAX_TO_KEEP
)

# Metrics logger
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tmp/tensorboard/grpo", flush_every_n_steps=20
)

In [None]:
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=RL_LEARNING_RATE,
        warmup_steps=RL_WARMUP_STEPS,
        decay_steps=RL_MAX_STEPS,
        end_value=0.0,
    ),
    b1=RL_B1,
    b2=RL_B2,
    weight_decay=RL_WEIGHT_DECAY,
)
if RL_MAX_GRAD_NORM is not None:
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=RL_MAX_GRAD_NORM),
      optimizer,
  )

In [None]:
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=RL_EVAL_EVERY_N_STEPS,
        max_steps=RL_MAX_STEPS,
        mini_batch_size=RL_TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=RL_TRAIN_MICRO_BATCH_SIZE,
        # metrics logging
        metrics_logging_options=metrics_logging_options,
        # checkpoint saving
        checkpoint_root_directory=FINAL_CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=RL_TOTAL_GENERATION_STEPS,
        max_prompt_length=RL_MAX_PROMPT_LENGTH,
        kv_cache_size=RL_MAX_PROMPT_LENGTH + RL_TOTAL_GENERATION_STEPS + 256,
        temperature=RL_TEMPERATURE,
        top_p=RL_TOP_P,
        top_k=RL_TOP_K,
        eos_tokens=[1,106],
    ),
)

grpo_config = GRPOConfig(
    num_generations=RL_NUM_GENERATIONS,
    num_iterations=RL_NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)
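
GRPO replaces a learned value function with a group baseline. For every prompt, `num_generations` completions are sampled, and each completion's advantage is its reward relative to the other completions for the same prompt. In the usual formulation, `beta` weights a KL penalty toward the reference model and `epsilon` sets the PPO-style clipping range. The NumPy sketch below shows the commonly used mean/std normalization of rewards within a group. It illustrates the technique, not Tunix's `GRPOLearner` internals, which may differ in details such as the epsilon or how zero-variance groups are handled.

```python
import numpy as np

def group_relative_advantages(rewards, num_generations, eps=1e-6):
    """Normalize rewards within each group of completions that share a prompt.

    rewards: flat sequence of length num_prompts * num_generations, ordered so that
             each consecutive block of `num_generations` entries belongs to one prompt.
    """
    r = np.asarray(rewards, dtype=np.float64).reshape(-1, num_generations)
    mean = r.mean(axis=1, keepdims=True)
    std = r.std(axis=1, keepdims=True)
    return ((r - mean) / (std + eps)).reshape(-1)

# Two prompts, four completions each (hypothetical rewards).
rewards = [1.0, 0.0, 0.0, 1.0,   # mixed outcomes -> informative advantages
           1.0, 1.0, 1.0, 1.0]   # identical rewards -> every advantage is 0
adv = group_relative_advantages(rewards, num_generations=4)
print(adv.round(3))
```

The second group shows a practical consequence: when all completions for a prompt receive the same reward, that prompt contributes no learning signal. This is one reason reward shaping beyond a binary correctness check can help small models.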

In [None]:
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=base_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

In [None]:
pr_fn = ProbabilityBasedRewardFn(
    rl_cluster=rl_cluster,
    tokenizer=tokenizer,
    use_debias=True,
    std_filter=StdFilterConfig(
        enabled=True,
        num_generations=4,
        filtering_mode="dynamic",
        beta_init=0.0,
    ),
)

rule_fn = RuleBasedRewardFn()


hybrid = HybridRewardFn(
    rule_based_fn=rule_fn,
    pr_fn=pr_fn,
    w_rule=1.0,
    w_prob=2.0,
)
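
`HybridRewardFn` is defined earlier in the notebook and is not reproduced here. The simplest reading of `w_rule=1.0, w_prob=2.0` is a weighted sum of the two signals, with the probability-based term weighted twice as heavily as the rule-based one. The sketch below illustrates that reading with a toy rule-based reward. The `<reasoning>`/`<answer>` tag names and both function names are hypothetical placeholders, not the notebook's actual implementation.

```python
import re

# Hypothetical output format: reasoning block followed by an answer block.
ANSWER_RE = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>(.*?)</answer>", re.DOTALL)

def toy_rule_reward(completion: str, gold: str) -> float:
    """Hypothetical rule-based reward: 0.5 for the expected format, +0.5 for an exact answer."""
    m = ANSWER_RE.search(completion)
    if m is None:
        return 0.0
    return 0.5 + (0.5 if m.group(1).strip() == gold.strip() else 0.0)

def hybrid_reward(rule_score: float, prob_score: float,
                  w_rule: float = 1.0, w_prob: float = 2.0) -> float:
    """Weighted sum of a rule-based and a probability-based score (assumed combination)."""
    return w_rule * rule_score + w_prob * prob_score

completion = "<reasoning>3 + 4 = 7</reasoning>\n<answer>7</answer>"
rule = toy_rule_reward(completion, "7")
# In the notebook, the probability-based score comes from ProbabilityBasedRewardFn; 0.25 is made up.
total = hybrid_reward(rule, prob_score=0.25)
print(rule, total)
```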

In [None]:
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[hybrid],
    grpo_config=grpo_config,
)

In [None]:
with mesh:
  grpo_trainer.train(rl_train_dataset)

## ✅ Train from CKPT

In [None]:
USE_CKPT = True
ckpt_from = "username/modelname"

if USE_CKPT:
    !kaggle kernels output {ckpt_from} -p /kaggle/working/

In [None]:
def get_gemma_base_model_from_ckpt(ckpt_path):
    mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[1]))
    model_config = model_lib.ModelConfig.gemma2_2b()
    abs_gemma: nnx.Module = nnx.eval_shape(
      lambda: model_lib.Transformer(model_config, rngs=nnx.Rngs(params=0))
    )
    abs_state = nnx.state(abs_gemma)
    abs_state = jax.tree.map(
      lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
      abs_state,
      nnx.get_named_sharding(abs_state, mesh),
    )
    checkpointer = ocp.StandardCheckpointer()
    restored_params = checkpointer.restore(os.path.join(ckpt_path, "intermediate_ckpt","state"), target=abs_state)

    graph_def, _ = nnx.split(abs_gemma)
    gemma = nnx.merge(graph_def, restored_params)
    
    return gemma, mesh, model_config

In [None]:
# del base_model
# del lora_policy
# gc.collect()

print("\n--- HBM Usage BEFORE Model Load ---")
show_hbm_usage()

local_ckpt_dir = "/kaggle/working/ckpts"

base_model, mesh, model_config = get_gemma_base_model_from_ckpt(local_ckpt_dir)
lora_policy = get_lora_model(base_model, mesh=mesh)

def get_latest_step(ckpt_root: str) -> int:
    steps = []
    for name in os.listdir(ckpt_root):
        path = os.path.join(ckpt_root, name)
        if name.isdigit() and os.path.isdir(path):
            steps.append(int(name))
    if not steps:
        raise ValueError("No checkpoint steps found")
    return max(steps)

actor_dir = os.path.join(local_ckpt_dir, "actor")
latest_step = get_latest_step(actor_dir) 
local_ckpt_dir = os.path.join(actor_dir, str(latest_step), "model_params")


abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(local_ckpt_dir, target=abs_params)
with mesh:
    nnx.update(
        lora_policy,
        jax.tree.map(
            lambda a, b: b,
            nnx.state(lora_policy, nnx.LoRAParam),
            trained_lora_params,
        ),
    )

print("\n--- HBM Usage AFTER Model Load ---")

show_hbm_usage()

In [None]:
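# Re-run the RL Stage setup cells (RLCluster, reward functions, GRPOLearner) before this one
# so that grpo_trainer wraps the reloaded lora_policy instead of the previously built model.
with mesh:
  grpo_trainer.train(rl_train_dataset)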


# 💬 Other things we want the judges to know
---

## 💡 Lessons Learned
**CoT Length**

Long chain-of-thought (CoT) data naturally encode advanced reasoning behaviors, such as self-reflection, verification, and dynamic strategy adaptation, that are typically absent from short CoT.

However, for small models (at least with LoRA-based SFT on a dataset of roughly 20k examples), long CoTs are difficult to learn effectively. We observe that the model tends to overthink even easy problems (e.g., GSM8K), producing unnecessarily long reasoning traces that still end in incorrect answers. On such easy tasks, training on short CoT often works better, yielding more stable and accurate solutions. At the same time, training only on short CoT limits the model’s reasoning capacity: it fails to acquire the higher-order abilities that long CoT enables, such as extended planning and iterative verification.

Therefore, mixing long and short CoT in an appropriate ratio appears crucial during SFT, which serves as the warm-up stage for subsequent RL. Short CoT anchors correctness and efficiency on simple tasks, while long CoT exposes the model to richer reasoning patterns that RL can later amplify and refine.
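
A minimal way to control the long/short CoT ratio during SFT is to subsample each pool to a target fraction and shuffle. The sketch below assumes the two pools are plain Python lists of examples. `mix_cot` and its `long_fraction` argument are illustrative names, and the fraction used in the usage line is arbitrary, not the ratio used in this notebook.

```python
import random

def mix_cot(long_cot, short_cot, long_fraction, total, seed=0):
    """Build an SFT mixture in which about `long_fraction` of the examples are long CoT.

    Samples without replacement from each pool, then shuffles so that long and
    short examples are interleaved rather than seen in two separate blocks.
    """
    if not 0.0 <= long_fraction <= 1.0:
        raise ValueError("long_fraction must be in [0, 1]")
    rng = random.Random(seed)
    n_long = round(total * long_fraction)
    n_short = total - n_long
    if n_long > len(long_cot) or n_short > len(short_cot):
        raise ValueError("not enough examples in one of the pools")
    mixed = rng.sample(long_cot, n_long) + rng.sample(short_cot, n_short)
    rng.shuffle(mixed)
    return mixed

long_pool = [{"kind": "long", "id": i} for i in range(100)]
short_pool = [{"kind": "short", "id": i} for i in range(100)]
mix = mix_cot(long_pool, short_pool, long_fraction=0.3, total=50)
print(sum(e["kind"] == "long" for e in mix), len(mix))
```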

**Data Domain Ratio**

Verifiable data (e.g., math problems with checkable answers) provides stable anchors for correctness, while non-verifiable data supports broad general reasoning. We aim to maximize the proportion of general-domain data while maintaining, or at least not degrading, basic mathematical ability, even if improvements are only marginal. Enforcing a stable reasoning format is equally critical at this stage. We therefore tune the SFT dataset mixture by monitoring two signals: GSM8K accuracy as a proxy for core math competence, and format accuracy to check that outputs consistently follow the required structure.
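
The selection rule described above can be written down explicitly. Among candidate mixtures, keep those whose GSM8K accuracy does not fall below the baseline and whose format accuracy clears a threshold, then pick the one with the largest general-domain share. The sketch below assumes each candidate has already been evaluated. The `Candidate` fields, the tolerance, and the format threshold are hypothetical, and the numbers in the usage lines are placeholders, not results from this work.

```python
from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass
class Candidate:
    name: str
    general_fraction: float  # share of general-domain (non-verifiable) examples
    gsm8k_acc: float         # proxy for core math ability
    format_acc: float        # fraction of outputs that follow the required structure

def pick_mixture(cands: Sequence[Candidate], baseline_gsm8k: float,
                 min_format_acc: float = 0.95, tol: float = 0.0) -> Optional[Candidate]:
    """Maximize the general-domain share subject to math and format constraints."""
    ok = [c for c in cands
          if c.gsm8k_acc >= baseline_gsm8k - tol and c.format_acc >= min_format_acc]
    return max(ok, key=lambda c: c.general_fraction, default=None)

# Placeholder numbers for illustration only.
cands = [
    Candidate("A", general_fraction=0.5, gsm8k_acc=0.40, format_acc=0.99),
    Candidate("B", general_fraction=0.7, gsm8k_acc=0.38, format_acc=0.98),
    Candidate("C", general_fraction=0.9, gsm8k_acc=0.30, format_acc=0.97),
]
best = pick_mixture(cands, baseline_gsm8k=0.38)
print(best.name if best else None)
```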

**Make Reasoning Concise and Compact**

Concise and compact reasoning does not emerge from aggressive truncation alone.
Instead, it arises from a staged process. The model is first exposed to rich reasoning behaviors (SFT stage 1 on long CoT), then trained further at reasoning lengths it can actually learn (SFT stage 2, up to the 1024-token target), and finally optimized under strict generation limits (GRPO). Combined with verifiable rewards, these limits implicitly favor efficient reasoning, because correctness must be reached within a limited generation budget.
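
Here is one way to see why a hard generation limit combined with a verifiable reward favors compact reasoning. If a completion is cut off at the token budget before it emits its final answer, the answer cannot be verified and the completion earns nothing. The sketch below makes that concrete, using whitespace tokens as a stand-in for real tokenizer tokens and a hypothetical `<answer>` tag. It illustrates the incentive only and is not the notebook's reward code.

```python
import re

ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)  # hypothetical answer tag

def budgeted_reward(completion: str, gold: str, max_tokens: int) -> float:
    """1.0 if the answer visible within the first `max_tokens` tokens is correct, else 0.0.

    Whitespace splitting stands in for the real tokenizer (an assumption).
    """
    visible = " ".join(completion.split()[:max_tokens])
    m = ANSWER_RE.search(visible)
    return 1.0 if m is not None and m.group(1).strip() == gold else 0.0

concise = "2 + 2 = 4 <answer>4</answer>"
verbose = "Let me think very carefully about this. " * 20 + "<answer>4</answer>"

print(budgeted_reward(concise, "4", max_tokens=32))   # answer fits in the budget
print(budgeted_reward(verbose, "4", max_tokens=32))   # answer truncated away
print(budgeted_reward(verbose, "4", max_tokens=200))  # larger budget recovers it
```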


üí¨ **Comment**
- The integration with Weights & Biases (W&B) made experiment tracking and iteration straightforward and convenient.
- Tighter integration with standard LLM benchmarks would be highly beneficial, making the framework a much stronger post-training solution.
- Accessing TPUs required long wait times, which significantly slowed down experimentation.

---

# Competition evaluation

### Load the final checkpoint for evaluation

In [None]:
CKPT_DIR = FINAL_CKPT_DIR

import re

# Find the latest checkpoint by listing directories in CKPT_DIR/actor
actor_ckpt_dir = os.path.join(CKPT_DIR, "actor")

latest_step = -1
if os.path.exists(actor_ckpt_dir):
  for item in os.listdir(actor_ckpt_dir):
    if os.path.isdir(os.path.join(actor_ckpt_dir, item)) and re.match(r'^\d+$', item):
      step = int(item)
      if step > latest_step:
        latest_step = step

if latest_step == -1:
  raise FileNotFoundError(f"No checkpoints found in {actor_ckpt_dir}")

print(f"Latest checkpoint step: {latest_step}")

wandb.init(project='tunix-eval')  # logging bug workaround

loaded_ckpt_path = os.path.join(
    CKPT_DIR, "actor", str(latest_step), "model_params"
)

abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(lora_policy, nnx.LoRAParam),
)
checkpointer = ocp.StandardCheckpointer()
trained_lora_params = checkpointer.restore(loaded_ckpt_path, target=abs_params)

nnx.update(
    lora_policy,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(lora_policy, nnx.LoRAParam),
        trained_lora_params,
    ),
)

### Create the sampler for finetuned model

In [None]:
sampler = sampler_lib.Sampler(
    transformer=lora_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + MAX_GENERATION_STEPS + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

### Evaluate

In [None]:
class TunixHackathonJudge:
    questions = ['question1', 'question2', ...]
    judge = "ai"
    
    def __init__(self, temperature, top_k, top_p, max_generation_steps, seed):
        ...

    def evaluate(self, sampler, prompt):
        ...

Result = TunixHackathonJudge(INF_TEMPERATURE, INF_TOP_K, INF_TOP_P, MAX_GENERATION_STEPS, SEED).evaluate(sampler, PROMPT_TEMPLATE)

# Unrestricted mode (multi-session mode)

In [None]:
unrestricted_kaggle_model = "sangminlee09/dod-gemma2-2b-it-reasoning-800"