# Multi-Domain GRPO Training for Reasoning

## Overview

This notebook demonstrates how to fine-tune **Gemma 3 1B** using **Tunix** and **GRPO (Group Relative Policy Optimization)** to teach the model structured reasoning across multiple domains.

### Why This Approach?

**GRPO** (inspired by [DeepSeek-R1](https://arxiv.org/abs/2501.12948)) is a reinforcement learning technique that we use here to train the model to follow a specific output format while maintaining response quality. Unlike supervised fine-tuning, GRPO learns from comparative feedback: it samples several responses per prompt, scores each one, and reinforces the responses that score above the group average.
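To make "comparative feedback" concrete, here is a minimal sketch of the group-relative advantage commonly used in GRPO: each response's reward is standardized against the other responses sampled for the same prompt. The function name and the `eps` stabilizer are illustrative assumptions, not Tunix's API, and Tunix's exact normalization may differ.

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards, eps=1e-6):
    """Standardize rewards within one prompt's group of sampled responses."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population standard deviation over the group
    # eps (an assumed stabilizer) avoids division by zero when all rewards match
    return [(r - mu) / (sigma + eps) for r in rewards]

# Four responses to a single prompt, scored on a 0.0-1.0 scale
advantages = group_relative_advantages([1.0, 0.2, 0.2, 0.0])
```

Only the response that beats the group mean gets a positive advantage. A group in which every response earns the same reward produces all-zero advantages, so it contributes no learning signal.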

**Gemma 3 1B** offers an excellent balance of capability and efficiency:
- Fits comfortably within Kaggle TPU memory constraints
- Fast iteration cycles for experimentation
- Strong baseline reasoning capabilities to build upon

### What We Built

A multi-domain reasoning model that:
1. **Shows its work** in `<reasoning>` tags before answering
2. **Generalizes** across 7 domains: Math, Coding, Science, Logic, Creative Writing, Summarization, and Creative Ideation
3. **Uses DISCO sampling** to balance domain representation during training

### Key Results

| Domain Type | Format Compliance |
|-------------|-------------------|
| Verifiable (Math, Coding, Science, Logic) | **>95%** |
| Unverifiable (Creative, Summarization) | **85-95%** |

### Lessons Learned

- **Prompt consistency matters**: The training prompt format must match inference exactly
- **Token limits**: Max output tokens during training should match or exceed inference requirements
- **Domain balance**: DISCO temperature sampling prevents overfitting to majority domains

---

**Domains covered:** Math • Coding • Science • Logic • Creative Writing • Summarization • Creative Ideation

---

## Install Dependencies

In [None]:
# Clean up
!pip uninstall -q -y gensim bigframes tensorflow-decision-forests tf-keras flax jax jaxlib qwix tunix

# Install Google Cloud SDKs
!pip install -U -q google-cloud-storage google-cloud-automl google-cloud-bigquery protobuf

# Install NumPy 2.0
!pip install -q "numpy>=2.0" "ml_dtypes>=0.4.0"

# Install EXACT working versions
!pip install -q \
    "jax[tpu]==0.8.1" \
    "flax==0.12.1" \
    "qwix==0.1.4" \
    "optax==0.2.6" \
    "orbax-checkpoint==0.11.31" \
    "chex==0.1.91" \
    "google-tunix[prod]==0.1.3" \
    tensorflow \
    kagglehub \
    grain \
    humanize

---

## Imports

In [None]:
# Standard library
import functools
import gc
import json
import os
import re
import shutil
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Optional, Tuple

# Third-party
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import numpy as np
import optax
import pandas as pd
import qwix
from flax import nnx
from orbax import checkpoint as ocp
from tqdm import tqdm

# Tunix
from tunix.models.gemma3 import model, params
from tunix.generate import sampler as sampler_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.sft import metrics_logger

---

## Hyperparameters

### GRPO Settings

**Group Size (`NUM_GENERATIONS = 4`):** Each prompt generates 4 responses, which are compared to compute relative rewards. More generations provide a stronger signal, but 4 is the sweet spot for a 9-hour session.

**Iterations (`NUM_ITERATIONS = 4`):** We update the policy 4 times per batch, extracting more learning from each data sample. This matters when the amount of training data is limited by session time.

**KL Penalty (`BETA = 0.04`):** Controls how much the policy can diverge from the base model. We use a low value so the model can learn new behaviors (XML formatting) while still preventing collapse.

**Clip Range (`EPSILON = 0.2`):** Standard PPO clipping. It prevents destructively large updates while allowing meaningful learning.
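As a rough illustration of what the clip range does, the sketch below evaluates the standard PPO-style clipped surrogate for a single token. The function name is hypothetical and this is not Tunix's implementation. The KL penalty (the `BETA` term) is left out to keep the example small.

```python
import math

def clipped_surrogate(logp_new, logp_old, advantage, epsilon=0.2):
    """PPO-style clipped objective for one token (to be maximized)."""
    ratio = math.exp(logp_new - logp_old)  # new policy prob / old policy prob
    clipped_ratio = max(1.0 - epsilon, min(1.0 + epsilon, ratio))
    # Taking the minimum removes any incentive to push the ratio outside
    # the [1 - epsilon, 1 + epsilon] band.
    return min(ratio * advantage, clipped_ratio * advantage)

# The policy raised this token's probability by 50%, but with EPSILON = 0.2
# the objective only credits a 20% increase.
capped = clipped_surrogate(math.log(1.5), 0.0, advantage=1.0)
```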

---

### Learning Rate & Optimization

**Learning Rate (`3e-6`):** Conservative for RL. Unlike SFT, GRPO gradients can be noisy, so smaller steps improve stability.

**Warmup (10% of steps):** Gradual ramp-up while early reward signals stabilize.

**Gradient Clipping (`0.1`):** Tight clipping catches RL gradient spikes that could destabilize training.
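The sketch below spells out the arithmetic behind these settings in plain Python, using the values from the configuration cell. The notebook itself relies on Optax for this. Both helper functions are illustrative, and the constant learning rate after warmup is an assumption, since the text only specifies the warmup phase.

```python
import math

# Values from the configuration cell
TRAIN_SIZE, BATCH_SIZE, NUM_ITERATIONS, NUM_EPOCHS = 10000, 2, 4, 1
MAX_STEPS = (TRAIN_SIZE // BATCH_SIZE) * NUM_ITERATIONS * NUM_EPOCHS  # 20000
WARMUP_STEPS = MAX_STEPS // 10                                        # 2000

def learning_rate(step, peak_lr=3e-6, warmup_steps=WARMUP_STEPS):
    """Linear warmup to peak_lr, then constant (post-warmup shape is assumed)."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr

def clip_by_global_norm(grads, max_norm=0.1):
    """Rescale a flat list of gradient values so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grads]

# A gradient with norm 0.5 is scaled down by 5x to norm 0.1
clipped = clip_by_global_norm([0.3, 0.4])
```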

---

### LoRA Configuration

**Rank 16, Alpha 32:** Middle-ground expressiveness. Rank 16 is sufficient for learning format patterns; alpha at 2× rank is standard scaling.

**Target Modules:** All attention projections (q, kv, attn_vec) plus MLP layers (gate, up, down); comprehensive coverage with ~0.5% of total parameters.
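For intuition about these numbers, here is a small sketch of the LoRA scaling factor and of the parameters one adapter adds. The layer shape is a made-up example, not an actual Gemma 3 dimension, and the helper names are hypothetical.

```python
def lora_scaling(alpha=32.0, rank=16):
    """Factor applied to the low-rank update: alpha / rank."""
    return alpha / rank

def lora_extra_params(d_in, d_out, rank=16):
    """Parameters added by one adapter: A is (d_in x rank), B is (rank x d_out)."""
    return rank * (d_in + d_out)

# Hypothetical 1024 x 4096 projection, chosen only for illustration
d_in, d_out = 1024, 4096
extra = lora_extra_params(d_in, d_out)
fraction = extra / (d_in * d_out)
print(f"scaling={lora_scaling()}, extra params={extra:,} ({fraction:.2%} of the layer)")
```

The fraction of the whole model is smaller than this per-layer figure, because embeddings and other untargeted weights receive no adapter.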

---

### Generation Settings

**Max Prompt Length (1024):** Fits long inputs (articles, code) while leaving headroom for generation.

**Generation Steps (512):** Enough for reasoning + answer. The format pattern generalizes to longer inference (700-1000 tokens).

**Temperature (0.7):** Encourages diverse reasoning paths during training.
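A quick sketch of how temperature reshapes the sampling distribution, using made-up logits: values below 1.0 sharpen the distribution, while anything above 0 still leaves room for alternative tokens. The function is illustrative and is not the Tunix sampler.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # hypothetical next-token logits
p_train = softmax_with_temperature(logits, 0.7)
p_plain = softmax_with_temperature(logits, 1.0)
```

A temperature of exactly 0 cannot be plugged into this formula. It is conventionally treated as greedy decoding, where the highest-scoring token is always chosen.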

In [None]:
# ==================== PATHS ====================
DATASET_PATH = "/kaggle/input/harmonic-oscillation/full_dataset_pool.jsonl"
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/grpo_checkpoints/"

# ==================== TRAINING DATA CONFIG ====================
TRAIN_SIZE = 10000
VAL_SIZE = 500
BATCH_SIZE = 2
DISCO_TEMP = 0.5
SEED = 42

# ==================== GRPO CONFIG ====================
NUM_EPOCHS = 1
NUM_ITERATIONS = 4
NUM_GENERATIONS = 4

# RL hyperparameters
BETA = 0.04
EPSILON = 0.2

# Training config
TRAIN_MICRO_BATCH_SIZE = 2
MAX_STEPS = int((TRAIN_SIZE / BATCH_SIZE) * NUM_ITERATIONS * NUM_EPOCHS)
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
WARMUP_STEPS = int(0.1 * MAX_STEPS)
MAX_GRAD_NORM = 0.1

# Checkpointing
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 3
EVAL_EVERY_N_STEPS = 500

# ==================== MODEL CONFIG ====================
MODEL_CP_PATH = params.GEMMA3_1B_IT
MESH = ((1, 4), ("fsdp", "tp"))

# LoRA config
LORA_RANK = 16
LORA_ALPHA = 32.0

# ==================== GENERATION CONFIG ====================
MAX_PROMPT_LENGTH = 1024          # Use the same limit for training and inference
TOTAL_GENERATION_STEPS = 512      # Inference limit must be >= this training value
TEMPERATURE = 0.7
INFERENCE_TEMPERATURE = 0.0
TOP_P = 0.95
TOP_K = 50

# ==================== PROMPTING ====================
SYSTEM_PROMPT = """Provide your reasoning in <reasoning> tags, then your final answer in <answer> tags.
Format:
<reasoning>Your step-by-step thinking</reasoning>
<answer>Your final answer</answer>"""

TEMPLATE = """<start_of_turn>user
{system_prompt}
Task: {question}<end_of_turn>
<start_of_turn>model"""

---

## Training Data Preparation

### Data Sources

We curated a multi-domain dataset from 8 public sources, unified into a single JSONL file:

| Domain | Source Dataset | Source | Samples |
|--------|---------------|--------|---------|
| Math | [GSM8K](https://www.kaggle.com/datasets/thedevastator/grade-school-math-8k-q-a) | Kaggle | 7,473 |
| Coding | [MBPP](https://www.kaggle.com/datasets/mbppjsonl) | Kaggle | 974 |
| Science | [SciQ](https://www.kaggle.com/datasets/thedevastator/sciq-a-dataset-for-science-question-answering) | Kaggle | 13,610 |
| Summarization | [CNN/DailyMail](https://www.kaggle.com/datasets/newspaper-text-summarization-cnn-dailymail) | Kaggle | 10,000 |
| Logic | [StrategyQA](https://huggingface.co/datasets/wics/strategy-qa) | HuggingFace | 2,260 |
| Creative Writing | [WritingPrompts](https://huggingface.co/datasets/euclaise/writingprompts) | HuggingFace | 5,968 |
| Creative Ideation | [Dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) + [No Robots](https://huggingface.co/datasets/HuggingFaceH4/no_robots) | HuggingFace | 2,831 |

**Total pool: 43,116 samples**

### DISCO Temperature Sampling

Raw datasets have imbalanced domain distributions (Science: 31.6% vs Coding: 2.3%). We use **DISCO sampling** with temperature T=0.5 to create a more balanced training set:

| Domain | Natural % | DISCO T=0.5 % | Training Samples |
|--------|-----------|---------------|------------------|
| Science | 31.6% | 22.8% | 2,277 |
| Summarization | 23.2% | 19.5% | 1,951 |
| Math | 17.3% | 16.9% | 1,687 |
| Creative Writing | 13.8% | 15.1% | 1,507 |
| Creative Ideation | 6.6% | 10.4% | 1,038 |
| Logic | 5.2% | 9.3% | 927 |
| Coding | 2.3% | 6.1% | 609 |

DISCO "flattens" the distribution, boosting underrepresented domains (Coding: 2.3% → 6.1%) while preserving some natural proportions.
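The percentages in the table can be reproduced from the pool sizes alone. This sketch raises each natural proportion to the power T and renormalizes, which is the same computation `DatasetLoader.compute_disco_proportions` performs. The standalone function name is illustrative.

```python
# Pool sizes from the data sources table
pool_counts = {
    "science": 13610, "summarization": 10000, "math": 7473,
    "creative_writing": 5968, "creative_ideation": 2831,
    "logic": 2260, "coding": 974,
}

def temperature_proportions(counts, temperature):
    """Raise natural proportions to the power `temperature`, then renormalize."""
    total = sum(counts.values())
    weights = {d: (n / total) ** temperature for d, n in counts.items()}
    norm = sum(weights.values())
    return {d: w / norm for d, w in weights.items()}

props = temperature_proportions(pool_counts, temperature=0.5)
for domain, p in props.items():
    print(f"{domain:<18} {p:6.1%}")
```

With T=1.0 the function returns the natural proportions unchanged. As T approaches 0, every domain approaches an equal share.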

### Data Schema

Each sample follows a unified schema:
```json
{
  "domain": "math",
  "prompt": "If a train travels 60 mph for 2 hours, how far does it go?",
  "answer": "120 miles",
  "metadata": {"source": "GSM8K"}
}
```

For details on data curation, see the [dataset curation pipeline](https://www.kaggle.com/code/fissalalsharef/dataset-curation-pipeline).
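As a hedged illustration of the schema, the helper below parses one JSONL line and checks the fields the loader reads. `parse_sample` and `REQUIRED_FIELDS` are hypothetical names. Treating `metadata` as optional mirrors how `_format_for_grpo` falls back to an empty dict.

```python
import json

REQUIRED_FIELDS = {"domain", "prompt", "answer"}

def parse_sample(line):
    """Parse one JSONL line and check it against the unified schema."""
    record = json.loads(line)
    missing = REQUIRED_FIELDS - set(record)
    if missing:
        raise ValueError(f"missing fields: {sorted(missing)}")
    # metadata is optional; normalize anything that is not a dict to {}
    if not isinstance(record.get("metadata"), dict):
        record["metadata"] = {}
    return record

sample = parse_sample('{"domain": "math", "prompt": "2+2?", "answer": "4"}')
```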

---

## Core Components

The following classes handle data loading, format validation, and domain-specific answer checking:

- **DatasetLoader**: Loads JSONL data with DISCO sampling
- **XMLValidator**: Ensures `<reasoning>` and `<answer>` tags are properly formatted
- **DomainValidator**: Checks answer correctness per domain (Math, Coding, Science, Logic)

In [None]:
class DatasetLoader:
    """Loads and preprocesses multi-domain JSONL dataset with DISCO sampling."""
    
    def __init__(self, jsonl_path: str):
        """
        Initialize loader.
        
        Args:
            jsonl_path: Path to JSONL file with dataset pool
        """
        self.jsonl_path = jsonl_path
        self.df = None
    
    def load(self) -> pd.DataFrame:
        """Load JSONL file into DataFrame."""
        print(f"Loading dataset from {self.jsonl_path}...")
        
        self.df = pd.read_json(self.jsonl_path, lines=True)
        
        return self.df
    
    def compute_disco_proportions(self, temperature: float = 1.0) -> Dict[str, float]:
        """
        Compute DISCO-adjusted domain proportions.
        
        Args:
            temperature: DISCO temperature
                - T=1.0: Natural distribution (no rebalancing)
                - T=0.5: Moderate balancing
        """
        if self.df is None:
            self.load()
        
        # Get natural proportions
        domain_counts = self.df['domain'].value_counts()
        total = len(self.df)
        natural_props = {domain: count / total for domain, count in domain_counts.items()}
        
        # Apply DISCO temperature
        if temperature == 1.0:
            # Natural distribution
            adjusted = natural_props
        else:
            # Temperature-adjusted distribution
            adjusted_unnormalized = {
                domain: prop ** temperature
                for domain, prop in natural_props.items()
            }
            
            # Normalize to sum to 1.0
            total_adjusted = sum(adjusted_unnormalized.values())
            adjusted = {
                domain: val / total_adjusted
                for domain, val in adjusted_unnormalized.items()
            }
        
        return adjusted
    
    def sample_dataset(
        self,
        total_size: int,
        temperature: float,
        seed: int
    ) -> pd.DataFrame:
        """
        Sample dataset using DISCO proportions.
        
        Args:
            total_size: Total number of samples to draw
            temperature: DISCO temperature (e.g. 0.5)
            seed: Random seed for reproducibility
        
        Returns:
            DataFrame with sampled data
        """
        if self.df is None:
            self.load()
        
        # Compute DISCO proportions
        proportions = self.compute_disco_proportions(temperature)
        
        # Sample from each domain
        sampled_dfs = []
        
        print(f"\nSampling {total_size} samples:")
        for domain, proportion in proportions.items():
            target_count = int(total_size * proportion)
            domain_df = self.df[self.df['domain'] == domain]
            available = len(domain_df)
            
            # Check if we have enough samples
            if target_count > available:
                sampled = domain_df
            else:
                sampled = domain_df.sample(n=target_count, random_state=seed)
            
            sampled_dfs.append(sampled)
            print(f"  {domain:<20}: Sampled {len(sampled):4d} samples")
        
        # Combine and shuffle
        combined = pd.concat(sampled_dfs, ignore_index=True)
        combined = combined.sample(frac=1.0, random_state=seed).reset_index(drop=True)
        
        print(f"\nTotal: {len(combined)} samples")
        return combined
    
    def create_datasets(
        self,
        train_size: int,
        temperature: float,
        batch_size: int,
        seed: int,
        val_size: Optional[int] = None
    ) -> Tuple[grain.MapDataset, Optional[grain.MapDataset]]:
        """
        Create train and validation Grain datasets.
        
        Args:
            train_size: Number of training samples
            val_size: Number of validation samples (optional)
            temperature: DISCO temperature
            batch_size: Batch size
            seed: Random seed
        
        Returns:
            (train_dataset, val_dataset) as Grain datasets
        """
        # Sample training data
        print("="*60)
        print("TRAINING DATA")
        print("="*60)
        train_df = self.sample_dataset(
            total_size=train_size,
            temperature=temperature,
            seed=seed
        )
        
        train_dataset = self._to_grain_dataset(train_df, batch_size, shuffle=True, seed=seed)
        
        # Validation dataset (optional)
        val_dataset = None
        if val_size:
            print("\n" + "="*60)
            print("VALIDATION DATA")
            print("="*60)
            val_df = self.sample_dataset(
                total_size=val_size,
                temperature=temperature,
                seed=seed + 1  # Different seed
            )
            val_dataset = self._to_grain_dataset(val_df, batch_size, seed, shuffle=False)
        
        return train_dataset, val_dataset
    
    def _to_grain_dataset(
        self,
        df: pd.DataFrame,
        batch_size: int,
        seed: int,
        shuffle: bool = True,
    ) -> grain.MapDataset:
        """Convert DataFrame to batched Grain dataset."""
        # Convert to list of dicts
        data = df.to_dict('records')
        
        # Create Grain dataset
        dataset = grain.MapDataset.source(data)
        
        if shuffle:
            dataset = dataset.shuffle(seed=seed)
        
        # Map to GRPO format (includes truncation!)
        dataset = dataset.map(self._format_for_grpo)
        
        # Batch
        dataset = dataset.batch(batch_size)
        
        return dataset
    
    def _format_for_grpo(self, item: Dict) -> Dict:
        """
        Format a single item for GRPO training.
        
        Truncates very long prompts to fit within token limits.
        """
        
        # Truncate long prompts
        prompt_text = item['prompt']
        MAX_CHARS = 2500  # ~625 tokens (safe for 1024 limit)
        
        if len(prompt_text) > MAX_CHARS:
            prompt_text = prompt_text[:MAX_CHARS] + "..."
        
        # Format prompt
        formatted_prompt = TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=prompt_text
        )
        
        # Normalize metadata
        metadata = item.get('metadata', {})
        if not isinstance(metadata, dict):
            metadata = {}
        
        return {
            "prompts": formatted_prompt,
            "domain": item['domain'],
            "question": item['prompt'],
            "answer": item['answer'],
            "metadata_str": json.dumps(metadata),
        }

def load_dataset_for_training(
    jsonl_path: str,
    train_size: int,
    batch_size: int,
    disco_temperature: float,
    seed: int = 42,
    val_size: Optional[int] = None
) -> Tuple[grain.MapDataset, Optional[grain.MapDataset]]:
    """
    Convenience function to load dataset in one call.
    
    Args:
        jsonl_path: Path to JSONL dataset
        train_size: Number of training samples
        val_size: Number of validation samples (None to skip)
        batch_size: Batch size
        disco_temperature: DISCO temperature
            - 0.5 = Moderate balancing
        seed: Random seed
    
    Returns:
        (train_dataset, val_dataset)
    
    Example:
        train_ds, val_ds = load_dataset_for_training(
            '/kaggle/input/dataset/pool.jsonl',
            train_size=15000,
            batch_size=2,
            disco_temperature=0.5
        )
    """
    loader = DatasetLoader(jsonl_path)
    return loader.create_datasets(
        train_size=train_size,
        val_size=val_size,
        temperature=disco_temperature,
        batch_size=batch_size,
        seed=seed
    )


In [None]:
# Load Training & Validation dataset.
print("="*60)
print("LOADING DATASET")
print("="*60)

train_dataset, val_dataset = load_dataset_for_training(
    jsonl_path=DATASET_PATH,
    train_size=TRAIN_SIZE,
    val_size=VAL_SIZE,
    batch_size=BATCH_SIZE,
    disco_temperature=DISCO_TEMP,
    seed=SEED
)

print(f"\n✓ Train dataset: {type(train_dataset)}")
print(f"✓ Val dataset: {type(val_dataset)}")

In [None]:
class XMLValidator:
    """Validates XML structure with strict ordering and uniqueness checks."""
    
    REASONING_START = "<reasoning>"
    REASONING_END = "</reasoning>"
    ANSWER_START = "<answer>"
    ANSWER_END = "</answer>"
    
    @classmethod
    def is_valid(cls, response: str) -> bool:
        """
        Check if response has valid XML structure.
        
        Requirements:
        1. Exactly ONE <reasoning> tag pair
        2. Exactly ONE <answer> tag pair
        3. <reasoning> appears BEFORE <answer>
        4. No overlapping/nested tags
        5. Content not empty
        
        Returns:
            bool: True if valid, False otherwise
        """
        # Check tag counts
        if response.count(cls.REASONING_START) != 1 or response.count(cls.REASONING_END) != 1:
            return False
        if response.count(cls.ANSWER_START) != 1 or response.count(cls.ANSWER_END) != 1:
            return False
        
        # Check order
        reasoning_start_pos = response.find(cls.REASONING_START)
        reasoning_end_pos = response.find(cls.REASONING_END)
        answer_start_pos = response.find(cls.ANSWER_START)
        answer_end_pos = response.find(cls.ANSWER_END)
        
        # Reasoning must come before answer
        if reasoning_start_pos >= answer_start_pos:
            return False
        
        # Tags must be properly paired (start before end)
        if reasoning_start_pos >= reasoning_end_pos:
            return False
        if answer_start_pos >= answer_end_pos:
            return False
        
        # No overlapping (reasoning must fully end before answer starts)
        if reasoning_end_pos >= answer_start_pos:
            return False
        
        # Extract content and check not empty
        reasoning = cls.extract_reasoning(response)
        answer = cls.extract_answer(response)
        
        if not reasoning or not answer:
            return False
        if len(reasoning.strip()) == 0 or len(answer.strip()) == 0:
            return False
        
        return True
    
    @classmethod
    def extract_reasoning(cls, response: str) -> Optional[str]:
        """Extract reasoning content between tags."""
        pattern = rf"{re.escape(cls.REASONING_START)}(.*?){re.escape(cls.REASONING_END)}"
        match = re.search(pattern, response, re.DOTALL)
        return match.group(1).strip() if match else None
    
    @classmethod
    def extract_answer(cls, response: str) -> Optional[str]:
        """Extract answer content between tags."""
        pattern = rf"{re.escape(cls.ANSWER_START)}(.*?){re.escape(cls.ANSWER_END)}"
        match = re.search(pattern, response, re.DOTALL)
        return match.group(1).strip() if match else None


In [None]:
class DomainValidator(ABC):
    """Abstract base class for domain-specific answer validation."""
    
    @abstractmethod
    def is_correct(self, predicted: str, ground_truth: str, metadata: Dict = None) -> bool:
        """Check if predicted answer matches ground truth."""
        pass


class MathValidator(DomainValidator):
    """Validator for math domain - extracts number after ####."""
    
    def is_correct(self, predicted: str, ground_truth: str, metadata: Dict = None) -> bool:
        """
        Math answers are in format: "reasoning\n#### 60"
        Extract number after #### and compare.
        """
        # Extract ground truth number
        if "####" in ground_truth:
            gt_value = ground_truth.split("####")[1].strip()
        else:
            gt_value = ground_truth.strip()
        
        # Normalize both to float for comparison
        try:
            gt_num = float(self._normalize_number(gt_value))
            pred_num = float(self._normalize_number(predicted))
            return abs(gt_num - pred_num) < 1e-6  # Float comparison tolerance
        except (ValueError, AttributeError):
            return False
    
    @staticmethod
    def _normalize_number(text: str) -> str:
        """Extract and normalize numeric value."""
        # Remove common formatting: commas, dollar signs, percent
        cleaned = text.replace(",", "").replace("$", "").replace("%", "")
        # Extract first number (handles cases like "answer is 42")
        numbers = re.findall(r'-?\d+\.?\d*', cleaned)
        return numbers[0] if numbers else text


class CodingValidator(DomainValidator):
    """Validator for coding domain - executes test cases."""
    
    def is_correct(self, predicted: str, ground_truth: str, metadata: Dict = None) -> bool:
        """
        Execute test cases from metadata.
        Returns True only if ALL test cases pass.
        """
        if not metadata or 'test_cases' not in metadata:
            return False
        
        test_cases = metadata['test_cases']
        
        try:
            # Create namespace with the predicted code.
            # NOTE: exec runs model-generated code in-process, with no sandbox or timeout.
            namespace = {}
            exec(predicted, namespace)
            
            # Run each test case
            for test in test_cases:
                try:
                    exec(test, namespace)
                except AssertionError:
                    return False  # Test failed
                except Exception:
                    return False  # Execution error
            
            return True  # All tests passed
        except Exception:
            return False  # Code doesn't execute


class ScienceValidator(DomainValidator):
    """Validator for science domain - case-insensitive exact match."""
    
    def is_correct(self, predicted: str, ground_truth: str, metadata: Dict = None) -> bool:
        """Case-insensitive comparison after normalization."""
        pred_normalized = predicted.strip().lower()
        gt_normalized = ground_truth.strip().lower()
        return pred_normalized == gt_normalized


class LogicValidator(DomainValidator):
    """Validator for logic domain - Yes/No normalization."""
    
    def is_correct(self, predicted: str, ground_truth: str, metadata: Dict = None) -> bool:
        """
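        Normalize Yes/No answers.
        Handles: yes, Yes, YES, no, No, NO

        Matches whole words only, so "know" or "cannot" is not read as "no".
        An answer containing both words, or neither, is treated as incorrect.
        """
        pred_label = self._yes_no_label(predicted)
        gt_label = self._yes_no_label(ground_truth)
        return pred_label is not None and pred_label == gt_label

    @staticmethod
    def _yes_no_label(text: str) -> Optional[str]:
        """Return "yes" or "no" if exactly one appears as a whole word, else None."""
        words = set(re.findall(r"[a-z]+", text.lower()))
        has_yes, has_no = "yes" in words, "no" in words
        if has_yes == has_no:
            return None
        return "yes" if has_yes else "no"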



---

## Reward Function Design

GRPO learns from reward signals. Our reward function has two key properties:

1. **Format compliance is mandatory** — invalid XML = 0 reward
2. **Domain-specific scoring** — verifiable domains check correctness, creative domains use heuristics

### Reward Structure

| Component | Points | Condition |
|-----------|--------|-----------|
| **Format (Hard Gate)** | 0.2 | Valid `<reasoning>` and `<answer>` tags |
| **Correctness** (verifiable) | 0.8 | Answer matches ground truth |
| **Quality** (unverifiable) | up to 0.8 | Length + Diversity + Relevance |

**Total: 0.0 to 1.0**

### Format Validation Rules

The XML validator enforces strict requirements:
- Exactly ONE `<reasoning>` tag pair
- Exactly ONE `<answer>` tag pair  
- `<reasoning>` must appear BEFORE `<answer>`
- No overlapping or nested tags
- Both sections must have non-empty content

**Why a hard gate?** If output isn't parseable, we can't evaluate it. Reward = 0 teaches the model: "format first, content second."
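The rules above can be exercised with a compact check like the one below. It is a simplified stand-in for the notebook's `XMLValidator`, written only to show which responses pass the gate. `has_valid_format` is a hypothetical name.

```python
import re

TAGS = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")
PATTERN = re.compile(
    r"<reasoning>(.*?)</reasoning>.*?<answer>(.*?)</answer>", re.DOTALL
)

def has_valid_format(response):
    """True if each tag appears exactly once, in order, with non-empty content."""
    if any(response.count(tag) != 1 for tag in TAGS):
        return False
    match = PATTERN.search(response)
    return bool(match and match.group(1).strip() and match.group(2).strip())

good = "<reasoning>2 + 2 = 4</reasoning>\n<answer>4</answer>"
swapped = "<answer>4</answer><reasoning>2 + 2 = 4</reasoning>"
empty = "<reasoning></reasoning><answer>4</answer>"
print(has_valid_format(good), has_valid_format(swapped), has_valid_format(empty))
```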

### Verifiable Domain Validators

| Domain | Validation Method |
|--------|-------------------|
| **Math** | Extract number after `####`, compare with tolerance |
| **Coding** | Execute code against test cases |
| **Science** | Case-insensitive exact match |
| **Logic** | Normalize Yes/No answers |
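To illustrate the Math row, here is a minimal sketch of the `####` convention: the reference answer is whatever follows the marker, and both sides are reduced to a number before comparison. The helper names are illustrative. The notebook's `MathValidator` follows the same idea.

```python
import re

def extract_number(text):
    """Return the first number in text as a float, ignoring , $ and %."""
    cleaned = text.replace(",", "").replace("$", "").replace("%", "")
    match = re.search(r"-?\d+(?:\.\d+)?", cleaned)
    return float(match.group()) if match else None

def math_answers_match(predicted, ground_truth, tol=1e-6):
    """Compare the predicted number with the number after '####'."""
    target = ground_truth.split("####")[-1]  # whole string if no marker
    p, t = extract_number(predicted), extract_number(target)
    return p is not None and t is not None and abs(p - t) < tol

# Numbers in the worked solution before '####' are ignored
matched = math_answers_match("$1,200", "12 * 100 = 1200\n#### 1200")
```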

### Unverifiable Domain Heuristics

For creative domains (writing, ideation, summarization), we use quality heuristics:

| Heuristic | Weight | What It Measures |
|-----------|--------|------------------|
| **Length Score** | 0.30 | Reasoning: 20-500 words, Answer: 10-300 words |
| **Lexical Diversity** | 0.25 | Unique words / Total words (vocabulary richness) |
| **Prompt Relevance** | 0.25 | Keyword overlap between prompt and reasoning |

### Example Reward Calculations

**Math problem (correct answer):**
- Format valid: +0.2
- Answer correct: +0.8
- Total: 1.0 ✓

**Creative writing (good response):**
- Format valid: +0.2
- Length appropriate: +0.25
- High diversity: +0.20
- Good relevance: +0.20
- Total: 0.85 ✓

**Any domain (invalid format):**
- Format invalid: 0.0
- Total: 0.0 ✗
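The same arithmetic as a short sketch. The weights come from the tables above, the function names are hypothetical, and each component score is assumed to lie in [0, 1].

```python
FORMAT_REWARD = 0.2

def verifiable_reward(format_ok, correct):
    """Format gate plus all-or-nothing correctness."""
    if not format_ok:
        return 0.0
    return FORMAT_REWARD + (0.8 if correct else 0.0)

def creative_reward(format_ok, reasoning_len, answer_len, diversity, relevance):
    """Format gate plus weighted heuristic scores, each in [0, 1]."""
    if not format_ok:
        return 0.0
    quality = (0.15 * reasoning_len + 0.15 * answer_len
               + 0.25 * diversity + 0.25 * relevance)
    return min(FORMAT_REWARD + quality, 1.0)

# One combination of component scores that yields the 0.85 total above:
# length 0.15 + 0.10, diversity 0.20, relevance 0.20
example = creative_reward(True, 1.0, 2 / 3, 0.8, 0.8)
```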


In [None]:
class HeuristicRewards:
    """Heuristic-based quality rewards for creative domains."""
    
    @staticmethod
    def length_score(text: str, target: int, min_len: int, max_len: int) -> float:
        """
        Score based on length appropriateness.
        Returns 1.0 if within [min_len, max_len], decays outside.
        """
        length = len(text.split())
        
        if min_len <= length <= max_len:
            return 1.0
        else:
            # Linear decay based on distance from range
            if length < min_len:
                distance = min_len - length
                max_distance = min_len
            else:  # length > max_len
                distance = length - max_len
                max_distance = target
            
            return max(0.0, 1.0 - distance / max_distance)
    
    @staticmethod
    def lexical_diversity(text: str) -> float:
        """
        Calculate lexical diversity: unique words / total words.
        """
        words = text.lower().split()
        if len(words) == 0:
            return 0.0
        
        unique_words = len(set(words))
        return unique_words / len(words)
    
    @staticmethod
    def prompt_relevance(prompt: str, reasoning: str) -> float:
        """
        Calculate relevance by keyword overlap between prompt and reasoning.
        """
        # Extract meaningful words (>3 chars, not common stop words)
        stop_words = {'the', 'and', 'for', 'are', 'but', 'not', 'you', 'all', 'can', 
                     'her', 'was', 'one', 'our', 'out', 'day', 'get', 'has', 'him',
                     'his', 'how', 'man', 'new', 'now', 'old', 'see', 'two', 'way',
                     'who', 'boy', 'did', 'its', 'let', 'put', 'say', 'she', 'too', 'use'}
        
        prompt_words = set(w.lower() for w in prompt.split() if len(w) > 3 and w.lower() not in stop_words)
        reasoning_words = set(w.lower() for w in reasoning.split() if len(w) > 3 and w.lower() not in stop_words)
        
        if len(prompt_words) == 0:
            return 0.5  # Default score if no meaningful words in prompt
        
        overlap = len(prompt_words & reasoning_words)
        return overlap / len(prompt_words)


# ============================================================================
# MAIN REWARD CALCULATOR
# ============================================================================
class RewardCalculator:
    """
    Main reward calculator that routes to appropriate validators.
    
    Reward breakdown:
    - Format: 0.2 (all domains, hard gate)
    - Verifiable: 0.8 for a correct answer, else 0.0
    - Creative: up to 0.3 length + 0.25 diversity + 0.25 relevance
    """
    
    VERIFIABLE_DOMAINS = {"math", "coding", "science", "logic"}
    CREATIVE_DOMAINS = {"creative_writing", "creative_ideation", "summarization"}
    
    def __init__(self):
        """Initialize validators."""
        self.xml_validator = XMLValidator()
        self.validators = {
            "math": MathValidator(),
            "coding": CodingValidator(),
            "science": ScienceValidator(),
            "logic": LogicValidator(),
        }
        self.heuristics = HeuristicRewards()
    
    def compute_reward(
        self,
        domain: str,
        prompt: str,
        response: str,
        ground_truth: Optional[str] = None,
        metadata: Optional[Dict] = None
    ) -> float:
        """
        Compute total reward for a response.
        
        Args:
            domain: Task domain (math, coding, science, etc.)
            prompt: Original prompt/question
            response: Model's generated response
            ground_truth: Expected answer (for verifiable domains)
            metadata: Additional data (e.g., test_cases for coding)
        
        Returns:
            float: Total reward score [0.0, 1.0]
        """
        # HARD DEPENDENCY: Format must be valid
        if not self.xml_validator.is_valid(response):
            return 0.0
        
        # Format is valid - start with format reward
        reward = 0.2
        
        # Extract content
        reasoning = self.xml_validator.extract_reasoning(response)
        answer = self.xml_validator.extract_answer(response)
        
        # Domain-specific rewards
        if domain in self.VERIFIABLE_DOMAINS:
            reward += self._compute_verifiable_reward(
                domain, answer, ground_truth, metadata
            )
        elif domain in self.CREATIVE_DOMAINS:
            reward += self._compute_creative_reward(
                prompt, reasoning, answer
            )
        else:
            # Unknown domain - use creative heuristics as fallback
            reward += self._compute_creative_reward(
                prompt, reasoning, answer
            )
        
        return min(reward, 1.0)  # Cap at 1.0
    
    def _compute_verifiable_reward(
        self,
        domain: str,
        answer: str,
        ground_truth: str,
        metadata: Optional[Dict]
    ) -> float:
        """Compute reward for verifiable domains."""
        validator = self.validators[domain]
        
        # Correctness is all-or-nothing: 0.8 on top of the 0.2 format reward
        if validator.is_correct(answer, ground_truth, metadata):
            return 0.8
        else:
            return 0.0
    
    def _compute_creative_reward(
        self,
        prompt: str,
        reasoning: str,
        answer: str
    ) -> float:
        """Compute reward for creative domains using heuristics."""
        score = 0.0
        
        # Length appropriateness (0.3 points)
        reasoning_score = self.heuristics.length_score(
            reasoning, target=250, min_len=20, max_len=500
        )
        answer_score = self.heuristics.length_score(
            answer, target=150, min_len=10, max_len=300
        )
        score += 0.15 * reasoning_score
        score += 0.15 * answer_score
        
        # Lexical diversity (0.25 points)
        diversity = self.heuristics.lexical_diversity(answer)
        score += 0.25 * diversity
        
        # Prompt relevance (0.25 points)
        relevance = self.heuristics.prompt_relevance(prompt, reasoning)
        score += 0.25 * relevance
        
        return score

def compute_reward_batch(
    prompts: List[str],
    completions: List[str],
    domain: Optional[List[str]] = None,
    answer: Optional[List[str]] = None,
    metadata_str: Optional[List[str]] = None,
    **kwargs  # Catch any other fields
) -> List[float]:
    """
    Compute rewards for a batch of responses.
    
    Called by GRPO with:
    - prompts: List of prompts
    - completions: List of model responses
    - domain, answer, metadata_str: Per-example dataset fields, passed as keyword arguments
    - **kwargs: Any other dataset fields (ignored)
    
    Returns:
        List[float]: Reward scores for each response
    """    
    calculator = RewardCalculator()
    
    # Handle missing fields
    if domain is None:
        domain = ["unknown"] * len(completions)
    if answer is None:
        answer = [None] * len(completions)
    if metadata_str is None:
        metadata_str = ["{}"] * len(completions)
    
    # Parse metadata from JSON strings
    metadatas = []
    for meta_str in metadata_str:
        try:
            metadatas.append(json.loads(meta_str) if isinstance(meta_str, str) else {})
        except ValueError:  # Malformed JSON (json.JSONDecodeError) falls back to empty metadata
            metadatas.append({})
    
    # Compute rewards
    rewards = []
    for dom, prompt, completion, gt, meta in zip(
        domain, prompts, completions, answer, metadatas
    ):
        reward = calculator.compute_reward(
            domain=dom,
            prompt=prompt,
            response=completion,  # GRPO calls it completion
            ground_truth=gt,
            metadata=meta
        )
        rewards.append(reward)
    
    return rewards
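
As a quick check of the reward budget in the code above, the sketch below recomputes the maximum total for each path from the weights that appear in `compute_reward` and its helpers. The function name `max_reward` and the dictionary layout are illustrative and not part of the notebook, and the creative bound assumes each heuristic returns a score in [0, 1].

```python
# Weights copied from the reward code above; the names here are illustrative only.
FORMAT_REWARD = 0.2
VERIFIABLE_CORRECT = 0.8  # returned when the validator accepts the answer
CREATIVE_WEIGHTS = {
    "reasoning_length": 0.15,
    "answer_length": 0.15,
    "lexical_diversity": 0.25,
    "prompt_relevance": 0.25,
}


def max_reward(path: str) -> float:
    """Upper bound on the total reward for a response whose format is valid."""
    if path == "verifiable":
        bonus = VERIFIABLE_CORRECT
    else:
        # Assumption: every heuristic score lies in [0, 1].
        bonus = sum(CREATIVE_WEIGHTS.values())
    return min(FORMAT_REWARD + bonus, 1.0)


print(round(max_reward("verifiable"), 2))
print(round(max_reward("creative"), 2))
```

Both paths top out at the same value, so the final `min(reward, 1.0)` cap only matters if the weights are changed later. Note the asymmetry at the low end: a well-formatted but wrong verifiable answer earns only the 0.2 format reward, while a well-formatted creative answer earns 0.2 plus partial heuristic credit.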


---

## Training Pipeline

The following cells set up and run the GRPO training loop:

1. **Load Base Model** → Save to intermediate checkpoint for reference model
2. **Create Models** → Reference (frozen) + Policy (LoRA-wrapped)
3. **Create Optimizer** → AdamW with warmup cosine schedule
4. **Configure GRPO** → Rollout config, checkpointing, metrics logging
5. **Create Trainer** → Wire up models, tokenizer, and reward function
6. **Run Training** → Execute the GRPO loop
7. **Save Checkpoint** → Save the final LoRA parameters

**Note:** Training takes ~7-9 hours on Kaggle TPU v5-8. Checkpoints are saved every 500 steps to `/tmp/content/grpo_checkpoints/`.
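
The learning-rate schedule in step 3 can be pictured with a small pure-Python sketch. It reproduces the shape of a linear warmup followed by cosine decay, assuming (as in optax's `warmup_cosine_decay_schedule`) that the decay length counts the whole schedule including warmup. The function name `warmup_cosine` and the values of `peak`, `warmup_steps`, and `total_steps` are placeholders, not the notebook's hyperparameters.

```python
import math


def warmup_cosine(step: int, peak: float, warmup_steps: int, total_steps: int,
                  init_value: float = 0.0, end_value: float = 0.0) -> float:
    """Linear warmup from init_value to peak, then cosine decay to end_value."""
    if warmup_steps > 0 and step < warmup_steps:
        frac = step / warmup_steps
        return init_value + frac * (peak - init_value)
    # The cosine phase covers the steps that remain after warmup.
    decay_len = max(total_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / decay_len, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak - end_value) * cosine


# Placeholder hyperparameters, for illustration only.
for s in (0, 50, 100, 550, 1000):
    lr = warmup_cosine(s, peak=1e-5, warmup_steps=100, total_steps=1000)
    print(s, f"{lr:.2e}")
```

Because the decay length includes warmup, the learning rate reaches its end value exactly at the last training step when the decay length is set to `MAX_STEPS`, as the optimizer cell does.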

In [None]:
print("="*60)
print("1. LOADING BASE MODEL")
print("="*60)

# Clear any existing intermediate checkpoint
if os.path.exists(INTERMEDIATE_CKPT_DIR):
    shutil.rmtree(INTERMEDIATE_CKPT_DIR)
os.makedirs(INTERMEDIATE_CKPT_DIR, exist_ok=True)
os.makedirs(CKPT_DIR, exist_ok=True)

# Load base Gemma model
config = model.ModelConfig.gemma3_1b()
gemma = params.create_model_from_checkpoint(MODEL_CP_PATH, config)
tokenizer = params.create_tokenizer()

# Save to intermediate checkpoint
checkpointer = ocp.StandardCheckpointer()
_, state = nnx.split(gemma)
checkpointer.save(os.path.join(INTERMEDIATE_CKPT_DIR, "state"), state)
checkpointer.wait_until_finished()

# Free memory
del gemma
del state
gc.collect()

print("✓ Base model saved to intermediate checkpoint")

In [None]:
print("="*60)
print("2. CREATING REFERENCE AND POLICY MODELS")
print("="*60)

def get_gemma_ref_model(ckpt_path):
    """Load Gemma model with proper sharding."""
    mesh = jax.make_mesh(*MESH)
    model_config = model.ModelConfig.gemma3_1b()
    
    # Create abstract model for shape inference
    abs_gemma = nnx.eval_shape(
        lambda: params.create_model_from_checkpoint(MODEL_CP_PATH, model_config)
    )
    
    # Create sharded state specification
    abs_state = nnx.state(abs_gemma)
    abs_state = jax.tree.map(
        lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
        abs_state,
        nnx.get_named_sharding(abs_state, mesh),
    )
    
    # Restore checkpoint
    checkpointer = ocp.StandardCheckpointer()
    restored_params = checkpointer.restore(ckpt_path, target=abs_state)
    
    # Merge graph and params
    graph_def, _ = nnx.split(abs_gemma)
    gemma = nnx.merge(graph_def, restored_params)
    
    return gemma, mesh, model_config


def get_lora_model(base_model, mesh):
    """Apply LoRA adapters to the model."""
    lora_provider = qwix.LoraProvider(
        module_path=(
            ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
            ".*attn_vec_einsum"
        ),
        rank=LORA_RANK,
        alpha=LORA_ALPHA,
    )
    
    model_input = base_model.get_model_input()
    lora_model = qwix.apply_lora_to_model(
        base_model, lora_provider, **model_input
    )
    
    # Apply sharding constraints
    with mesh:
        state = nnx.state(lora_model)
        pspecs = nnx.get_partition_spec(state)
        sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
        nnx.update(lora_model, sharded_state)
    
    return lora_model


# Create reference model (frozen, for KL penalty)
ref_model, mesh, model_config = get_gemma_ref_model(
    ckpt_path=os.path.join(INTERMEDIATE_CKPT_DIR, "state")
)
print("✓ Reference model loaded")

# Create policy model (will be trained with GRPO)
lora_policy = get_lora_model(ref_model, mesh=mesh)
print("✓ Policy model with LoRA created")
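
To make the effect of `LORA_RANK` concrete: in the standard LoRA formulation, each targeted weight gains two low-rank factors, so a `d_in × d_out` projection adds `rank * (d_in + d_out)` trainable parameters while the original weight stays frozen. The sketch below works through that arithmetic. The layer shape and rank are made up for illustration and are not Gemma 3 1B's actual dimensions, and it does not describe how qwix lays out its factors for einsum weights.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters added by one LoRA adapter (factors A and B)."""
    return rank * d_in + rank * d_out


def full_param_count(d_in: int, d_out: int) -> int:
    """Parameters in the frozen dense weight the adapter is attached to."""
    return d_in * d_out


# Hypothetical layer shape and rank, for illustration only.
d_in, d_out, rank = 1024, 4096, 16
added = lora_param_count(d_in, d_out, rank)
frozen = full_param_count(d_in, d_out)
print(f"LoRA params: {added:,} ({added / frozen:.2%} of the frozen weight)")
```

This is also why the final checkpoint can store only the LoRA parameters: they are a small fraction of the model, and the base weights can be reloaded separately.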

In [None]:
print("="*60)
print("3. CREATING OPTIMIZER")
print("="*60)

optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)

if MAX_GRAD_NORM is not None:
    optimizer = optax.chain(
        optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
        optimizer,
    )

print("AdamW optimizer with warmup cosine decay")
print(f"  - Peak LR: {LEARNING_RATE}")
print(f"  - Warmup steps: {WARMUP_STEPS}")
print(f"  - Grad clip norm: {MAX_GRAD_NORM}")

In [None]:
print("="*60)
print("4. CONFIGURING GRPO TRAINING")
print("="*60)

# Checkpoint saving options
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, 
    max_to_keep=MAX_TO_KEEP
)

# Metrics logging
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tensorboard/grpo", 
    flush_every_n_steps=20
)

# Training config
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        metrics_logging_options=metrics_logging_options,
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=1536,  # 1024 + 256 + 256 buffer
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        eos_tokens=[1, 106],
    ),
)

grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

print("✓ GRPO configuration complete")
print(f"  - Group size: {NUM_GENERATIONS} generations per prompt")
print(f"  - Iterations: {NUM_ITERATIONS}")
print(f"  - KL beta: {BETA}")
print(f"  - Clip epsilon: {EPSILON}")
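
The group size printed above is what GRPO normalizes rewards over: each prompt's `NUM_GENERATIONS` completions are scored, and every reward is standardized against the mean and standard deviation of its own group. A minimal sketch of that normalization follows. The function name, the `eps` constant, and the use of the population standard deviation are assumptions; Tunix's internal implementation may differ in these details.

```python
from statistics import mean, pstdev
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Standardize each reward against the other completions for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # Assumption: population std; some implementations use sample std
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four hypothetical completions for one prompt: one correct, two format-only, one invalid.
rewards = [1.0, 0.2, 0.2, 0.0]
advantages = group_relative_advantages(rewards)
print([round(a, 3) for a in advantages])

# If every completion earns the same reward, all advantages are zero.
print(group_relative_advantages([0.2, 0.2, 0.2, 0.2]))
```

When all completions in a group receive the same reward, every advantage is zero and that prompt contributes no policy gradient signal, which is one reason partial credit in the reward function can help early in training.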

In [None]:
print("="*60)
print("5. CREATING GRPO TRAINER")
print("="*60)

# Create RL cluster
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

# Create GRPO trainer with our multi-domain reward function
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=[compute_reward_batch],  # Our custom multi-domain reward!
    grpo_config=grpo_config,
)

print("✓ GRPO trainer created")
print("  - Actor: LoRA policy model")
print("  - Reference: Frozen base model")
print("  - Reward function: Multi-domain (7 domains)")

In [None]:
print("="*80)
print("6. STARTING GRPO TRAINING")
print("="*80)
print(f"Training steps: {MAX_STEPS}")
print(f"Checkpoint interval: {SAVE_INTERVAL_STEPS} steps")
print("="*80)

with mesh:
    grpo_trainer.train(train_dataset)

print("="*80)
print("TRAINING COMPLETE")
print("="*80)

In [None]:
# NOTE: This saves LoRA parameters only.
# To use: Load base Gemma 3 1B → Apply LoRA → Restore these params.

print("="*60)
print("7. SAVING FINAL CHECKPOINT")
print("="*60)

# Save to /kaggle/working (persists after session)
final_ckpt_path = "/kaggle/working/grpo_multi_domain_final"

if os.path.exists(final_ckpt_path):
    shutil.rmtree(final_ckpt_path)

# Save LoRA parameters only (the base weights are not included)
lora_params = nnx.state(lora_policy, nnx.LoRAParam)
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(final_ckpt_path, lora_params)
checkpointer.wait_until_finished()

print(f"✓ LoRA parameters saved to {final_ckpt_path}")

# Create metadata for Kaggle dataset
metadata = {
    "title": "GRPO Multi-Domain Reasoning Model",
    "id": "fissalalsharef/grpo-multi-domain-final",
}

with open('/kaggle/working/dataset-metadata.json', 'w') as f:
    json.dump(metadata, f)

print("✓ Metadata created")
print("\nTRAINING COMPLETE!")

---

## Sample Results

**Note:** Inference is not run in this notebook so that training stays within the 9-hour session limit.
The results below come from separate validation runs on held-out test data.

Datasets used for validation:

- [Unverifiable Domains](https://www.kaggle.com/code/fissalalsharef/stress-testing-unverifiable-domain-dataset)
- [Verifiable & Unverifiable Domains](https://www.kaggle.com/code/fissalalsharef/stress-testing-verifiable-domain-dataset)

### Format Compliance by Domain

**Verifiable Domains:**

| Domain | Accuracy |
|--------|----------|
| Math | **93.9%** |
| Coding | **93.9%** |

**Unverifiable Domains:**

| Domain | Valid Rate |
|--------|------------|
| Summarization | **95.0%** |
| Conversation | 87.9% |
| Creative Writing | 82.1% |

### LLM-as-a-Judge Scores (DeepSeek Reasoner)

| Domain | Reasoning | Answer | Overall |
|--------|-----------|--------|---------|
| Creative Writing | 4.00/5 | 3.78/5 | **7.33/10** |
| Summarization | 4.37/5 | 3.33/5 | 6.92/10 |
| Conversation | 1.90/5 | 2.62/5 | 4.07/10 |

### Key Observations

- **Verifiable domains achieve ~94% accuracy**: the reward signal is clear and learnable
- **Summarization has the highest valid rate among unverifiable domains (95.0%)**: the structured task aligns with the reasoning format
- **Conversation underperforms**: the model sometimes requests clarification instead of answering
- **Common failures**: truncation at the token limit or missing closing tags
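
The failure modes listed above can be separated with a small format check. The sketch below is illustrative and is not the validator used in training: the function name `classify_format` and its labels are hypothetical, and it assumes the answer is wrapped in `<answer>` tags after the `<reasoning>` block. A response cut off at the token limit typically shows up as an unclosed tag.

```python
import re

TAGS = ("reasoning", "answer")  # Assumption: `<answer>` is the answer tag name
FULL_PATTERN = re.compile(
    r"<reasoning>.+?</reasoning>\s*<answer>.+?</answer>", re.DOTALL
)


def classify_format(response: str) -> str:
    """Return 'valid', 'unclosed_tag', or 'malformed' for a model response."""
    if FULL_PATTERN.search(response):
        return "valid"
    for tag in TAGS:
        # An opening tag without its closing tag usually means truncation.
        if f"<{tag}>" in response and f"</{tag}>" not in response:
            return "unclosed_tag"
    return "malformed"


samples = [
    "<reasoning>2 + 2 = 4</reasoning><answer>4</answer>",
    "<reasoning>2 + 2 = 4</reasoning><answer>4",
    "The answer is 4.",
]
for s in samples:
    print(classify_format(s))
```

Counting these labels over a validation set gives a quick way to tell whether raising the generation token limit or adjusting the format reward is the more promising fix.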