# Multi-Domain GRPO Training

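**Goal:** Train Gemma 3 1B with GRPO on 15K multi-domain samples (note: the configuration below currently draws `TRAIN_SIZE = 10000` training and `VAL_SIZE = 500` validation samples)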

**Domains:** Math, Coding, Science, Logic, Summarization, Creative Writing, Creative Ideation

## Install Dependencies

In [None]:
# Environment setup. The lines starting with `!` are IPython shell escapes,
# so this cell only runs inside a notebook, not as a plain Python script.

# Clean up
# Removes preinstalled copies of jax/jaxlib/flax/qwix/tunix (reinstalled at
# pinned versions below) plus packages presumably in conflict with them
# (gensim, bigframes, TF decision forests, tf-keras) — NOTE(review): confirm.
!pip uninstall -q -y gensim bigframes tensorflow-decision-forests tf-keras flax jax jaxlib qwix tunix

# Install Google Cloud SDKs
!pip install -U -q google-cloud-storage google-cloud-automl google-cloud-bigquery protobuf

# Install NumPy 2.0
# NOTE(review): upgrading numpy inside a running kernel may require a kernel
# restart before the new version is importable — confirm for this runtime.
!pip install -q "numpy>=2.0" "ml_dtypes>=0.4.0"

# Install EXACT working versions
# JAX (TPU build), Flax, Qwix, Optax, Orbax and Chex are pinned to versions
# used together with google-tunix 0.1.3; tensorflow, kagglehub, grain and
# humanize are left unpinned.
!pip install -q \
    "jax[tpu]==0.8.1" \
    "flax==0.12.1" \
    "qwix==0.1.4" \
    "optax==0.2.6" \
    "orbax-checkpoint==0.11.31" \
    "chex==0.1.91" \
    "google-tunix[prod]==0.1.3" \
    tensorflow \
    kagglehub \
    grain \
    humanize

## Imports

In [None]:
import functools
import gc
import os
import re
import csv
import shutil
from pathlib import Path
from pprint import pprint
from tqdm import tqdm

import jax
import jax.numpy as jnp
import kagglehub
from flax import nnx
import grain
import optax
import humanize
from orbax import checkpoint as ocp

# Tunix imports
from tunix.models.gemma3 import model, params
from tunix.generate import sampler as sampler_lib

# GRPO-specific imports
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.rollout import base_rollout
from tunix.rl.grpo.grpo_learner import GRPOLearner, GRPOConfig
from tunix.sft import metrics_logger

import qwix
import numpy as np 
import json
import pandas as pd
import re
from typing import Optional, Dict, List, Tuple
from abc import ABC, abstractmethod
import sys
from typing import Dict, List
import pandas as pd
import grain
from typing import Dict, Tuple, Optional, List
from dataclasses import dataclass

## Configuration & Hyperparameters

In [None]:
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
DATASET_PATH = "/kaggle/input/harmonic-oscillation/full_dataset_pool.jsonl"  # multi-domain JSONL pool
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"  # base weights re-saved for sharded restore
CKPT_DIR = "/tmp/content/grpo_checkpoints/"  # GRPO training checkpoints

# ---------------------------------------------------------------------------
# Data sampling
# ---------------------------------------------------------------------------
TRAIN_SIZE, VAL_SIZE = 10000, 500  # samples drawn from the pool for each split
BATCH_SIZE = 2
DISCO_TEMP = 0.5  # T < 1.0 flattens the natural domain mix toward uniform
SEED = 42

# ---------------------------------------------------------------------------
# GRPO schedule and objective
# ---------------------------------------------------------------------------
NUM_EPOCHS = 1
NUM_ITERATIONS = 4   # presumably policy updates per rollout batch — confirm in GRPOConfig
NUM_GENERATIONS = 4  # presumably completions sampled per prompt — confirm in GRPOConfig

BETA = 0.04     # presumably the KL-penalty coefficient — confirm in GRPOConfig
EPSILON = 0.2   # presumably the ratio-clipping range — confirm in GRPOConfig

# Optimisation: total steps = batches per epoch x iterations x epochs
TRAIN_MICRO_BATCH_SIZE = 2
MAX_STEPS = int(TRAIN_SIZE / BATCH_SIZE * NUM_ITERATIONS * NUM_EPOCHS)
LEARNING_RATE = 3e-6
B1, B2 = 0.9, 0.99           # AdamW beta1 / beta2
WEIGHT_DECAY = 0.1
WARMUP_STEPS = int(MAX_STEPS * 0.1)  # first 10% of MAX_STEPS
MAX_GRAD_NORM = 0.1          # global-norm gradient clipping threshold

# Checkpoint / evaluation cadence, in training steps
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 3
EVAL_EVERY_N_STEPS = 500

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
MODEL_CP_PATH = params.GEMMA3_1B_IT
MESH = ((1, 4), ("fsdp", "tp"))  # (mesh shape, axis names), unpacked into jax.make_mesh

LORA_RANK, LORA_ALPHA = 16, 32.0

# ---------------------------------------------------------------------------
# Generation
# NOTE(review): the original notes say the prompt limit must be shared by the
# training rollout and the inference sampler, and that the inference
# generation length must be >= the training one — confirm where consumed.
# ---------------------------------------------------------------------------
MAX_PROMPT_LENGTH = 1024
TOTAL_GENERATION_STEPS = 512
TEMPERATURE = 0.7            # presumably rollout sampling temperature — confirm
INFERENCE_TEMPERATURE = 0.0  # presumably greedy decoding for evaluation — confirm
TOP_P = 0.95
TOP_K = 50

# ---------------------------------------------------------------------------
# Prompting: SYSTEM_PROMPT is inserted into TEMPLATE together with the task.
# ---------------------------------------------------------------------------
SYSTEM_PROMPT = """Provide your reasoning in <reasoning> tags, then your final answer in <answer> tags.
Format:
<reasoning>Your step-by-step thinking</reasoning>
<answer>Your final answer</answer>"""

TEMPLATE = """<start_of_turn>user
{system_prompt}
Task: {question}<end_of_turn>
<start_of_turn>model"""

## Data Loader

In [None]:
class DatasetLoader:
    """Loads a multi-domain JSONL pool and draws DISCO-balanced splits.

    The pool must provide ``domain``, ``prompt`` and ``answer`` columns;
    ``metadata`` is optional (it is read by ``_format_for_grpo``). Train and
    validation splits are drawn from disjoint rows of the pool.
    """

    def __init__(self, jsonl_path: str):
        """
        Initialize loader.

        Args:
            jsonl_path: Path to JSONL file with dataset pool
        """
        self.jsonl_path = jsonl_path
        self.df = None  # populated lazily by load()

    def load(self) -> pd.DataFrame:
        """Read the JSONL pool (one JSON object per line) into ``self.df``."""
        print(f"Loading dataset from {self.jsonl_path}...")
        self.df = pd.read_json(self.jsonl_path, lines=True)
        return self.df

    def compute_disco_proportions(self, temperature: float = 1.0) -> Dict[str, float]:
        """
        Compute DISCO-adjusted domain proportions.

        Each domain's natural share p_i is raised to ``temperature`` and the
        results are renormalised: p_i**T / sum_j p_j**T.

        Args:
            temperature: DISCO temperature
                - T=1.0: natural distribution (unchanged)
                - T=0.5: moderate balancing
                - T=0.0: uniform split across domains

        Returns:
            Mapping of domain name -> sampling proportion (sums to 1.0).
        """
        if self.df is None:
            self.load()

        domain_counts = self.df['domain'].value_counts()
        total = len(self.df)
        natural_props = {domain: count / total for domain, count in domain_counts.items()}

        if temperature == 1.0:
            # Natural distribution; skip the power/renormalise round trip.
            return natural_props

        adjusted_unnormalized = {
            domain: prop ** temperature
            for domain, prop in natural_props.items()
        }
        total_adjusted = sum(adjusted_unnormalized.values())
        return {
            domain: val / total_adjusted
            for domain, val in adjusted_unnormalized.items()
        }

    @staticmethod
    def _allocate_counts(total_size: int, proportions: Dict[str, float]) -> Dict[str, int]:
        """Split ``total_size`` into per-domain counts that add up to ``total_size``.

        Largest-remainder rounding: each domain gets floor(total_size * p) and
        the slots lost to flooring go to the domains with the largest
        fractional parts (ties keep ``proportions`` order).
        """
        raw = {domain: total_size * prop for domain, prop in proportions.items()}
        counts = {domain: int(value) for domain, value in raw.items()}
        shortfall = total_size - sum(counts.values())
        if shortfall > 0:
            by_remainder = sorted(raw, key=lambda d: raw[d] - counts[d], reverse=True)
            for domain in by_remainder[:shortfall]:
                counts[domain] += 1
        return counts

    def _sample_with_source_index(
        self,
        total_size: int,
        temperature: float,
        seed: int,
        exclude_index: Optional[pd.Index] = None,
    ) -> pd.DataFrame:
        """Draw a DISCO-balanced, shuffled sample that keeps ``self.df`` index labels."""
        if self.df is None:
            self.load()

        # Target mix always comes from the full pool so every split aims at
        # the same domain proportions.
        proportions = self.compute_disco_proportions(temperature)
        if exclude_index is None:
            pool = self.df
        else:
            pool = self.df.drop(index=exclude_index, errors="ignore")
        targets = self._allocate_counts(total_size, proportions)

        sampled_dfs = []
        print(f"\nSampling {total_size} samples:")
        for domain, target_count in targets.items():
            domain_df = pool[pool['domain'] == domain]
            available = len(domain_df)

            if target_count > available:
                sampled = domain_df  # not enough rows: take everything left
            else:
                sampled = domain_df.sample(n=target_count, random_state=seed)

            sampled_dfs.append(sampled)
            print(f"  {domain:<20}: Sampled {len(sampled):4d} samples")

        if sampled_dfs:
            combined = pd.concat(sampled_dfs)
            combined = combined.sample(frac=1.0, random_state=seed)
        else:
            combined = pool.iloc[0:0]  # empty pool: pd.concat([]) would raise

        print(f"\nTotal: {len(combined)} samples")
        return combined

    def sample_dataset(
        self,
        total_size: int,
        temperature: float,
        seed: int,
        exclude_index: Optional[pd.Index] = None,
    ) -> pd.DataFrame:
        """
        Sample dataset using DISCO proportions.

        Per-domain counts add up to ``total_size`` unless a domain has fewer
        rows available than its share, in which case all of its rows are used.

        Args:
            total_size: Total number of samples to draw
            temperature: DISCO temperature (e.g. 0.5)
            seed: Random seed for reproducibility
            exclude_index: Optional ``self.df`` index labels that must not be
                drawn (e.g. rows already used for training).

        Returns:
            Shuffled DataFrame with a fresh 0..n-1 index.
        """
        sampled = self._sample_with_source_index(total_size, temperature, seed, exclude_index)
        return sampled.reset_index(drop=True)

    def create_datasets(
        self,
        train_size: int,
        temperature: float,
        batch_size: int,
        seed: int,
        val_size: Optional[int] = None
    ) -> Tuple["grain.MapDataset", Optional["grain.MapDataset"]]:
        """
        Create train and validation Grain datasets.

        The validation split is drawn only from rows NOT used for training,
        so it is a true held-out set.

        Args:
            train_size: Number of training samples
            temperature: DISCO temperature
            batch_size: Batch size
            seed: Random seed (validation sampling uses seed + 1)
            val_size: Number of validation samples (None or 0 to skip)

        Returns:
            (train_dataset, val_dataset) as Grain datasets; val_dataset is None
            when no validation split was requested.
        """
        print("="*60)
        print("TRAINING DATA")
        print("="*60)
        train_sampled = self._sample_with_source_index(train_size, temperature, seed)
        train_df = train_sampled.reset_index(drop=True)
        train_dataset = self._to_grain_dataset(train_df, batch_size, shuffle=True, seed=seed)

        val_dataset = None
        if val_size:
            print("\n" + "="*60)
            print("VALIDATION DATA")
            print("="*60)
            val_df = self.sample_dataset(
                total_size=val_size,
                temperature=temperature,
                seed=seed + 1,
                # Exclude training rows so train and validation never overlap.
                exclude_index=train_sampled.index,
            )
            val_dataset = self._to_grain_dataset(val_df, batch_size, seed, shuffle=False)

        return train_dataset, val_dataset

    def _to_grain_dataset(
        self,
        df: pd.DataFrame,
        batch_size: int,
        seed: int,
        shuffle: bool = True,
    ) -> "grain.MapDataset":
        """Convert a DataFrame into a (optionally shuffled) batched Grain dataset."""
        data = df.to_dict('records')
        dataset = grain.MapDataset.source(data)

        if shuffle:
            dataset = dataset.shuffle(seed=seed)

        # Map to GRPO format (includes prompt truncation).
        dataset = dataset.map(self._format_for_grpo)
        return dataset.batch(batch_size)

    def _format_for_grpo(self, item: Dict) -> Dict:
        """
        Format a single item for GRPO training.

        Truncates very long prompts, wraps the task in TEMPLATE with
        SYSTEM_PROMPT, and serialises metadata to a JSON string.
        """
        import json

        prompt_text = item['prompt']
        MAX_CHARS = 2500  # rough heuristic: ~625 tokens, under the 1024 prompt limit

        if len(prompt_text) > MAX_CHARS:
            prompt_text = prompt_text[:MAX_CHARS] + "..."

        formatted_prompt = TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT,
            question=prompt_text
        )

        metadata = item.get('metadata', {})
        if not isinstance(metadata, dict):
            metadata = {}  # missing metadata arrives from pandas as NaN

        # Keep the batched "answer" field homogeneous (all strings): missing
        # answers (None / pandas NaN) become "", other scalars are stringified,
        # so the string-based validators never receive a float or None.
        answer = item['answer']
        if answer is None or (isinstance(answer, float) and answer != answer):
            answer = ""
        elif not isinstance(answer, str):
            answer = str(answer)

        return {
            "prompts": formatted_prompt,
            "domain": item['domain'],
            "question": item['prompt'],
            "answer": answer,
            "metadata_str": json.dumps(metadata),
        }


# ============================================================================
# CONVENIENCE FUNCTIONS
# ============================================================================

def load_dataset_for_training(
    jsonl_path: str,
    train_size: int,
    batch_size: int,
    disco_temperature: float,
    seed: int = 42,
    val_size: Optional[int] = None
) -> Tuple[grain.MapDataset, Optional[grain.MapDataset]]:
    """
    Load the JSONL pool and build train/validation Grain datasets in one call.

    Thin wrapper around ``DatasetLoader(jsonl_path).create_datasets(...)``.

    Args:
        jsonl_path: Path to JSONL dataset
        train_size: Number of training samples
        batch_size: Batch size (required)
        disco_temperature: DISCO temperature
            - 1.0 = Natural distribution
            - 0.5 = Moderate balancing (recommended)
            - 0.3 = Aggressive balancing
        seed: Random seed
        val_size: Number of validation samples (None or 0 to skip)

    Returns:
        (train_dataset, val_dataset); val_dataset is None when skipped.

    Example:
        train_ds, val_ds = load_dataset_for_training(
            '/kaggle/input/dataset/pool.jsonl',
            train_size=15000,
            batch_size=2,
            disco_temperature=0.5,
            val_size=500,
        )
    """
    loader = DatasetLoader(jsonl_path)
    return loader.create_datasets(
        train_size=train_size,
        val_size=val_size,
        temperature=disco_temperature,
        batch_size=batch_size,
        seed=seed
    )


## Validators

In [None]:
class XMLValidator:
    """Structural checker for the <reasoning>...</reasoning><answer>...</answer> layout.

    A response is well-formed when each of the four tags occurs exactly once,
    they appear in the order reasoning-open, reasoning-close, answer-open,
    answer-close, and both enclosed sections contain non-whitespace text.
    Text before, between or after the two sections is tolerated.
    """

    REASONING_START = "<reasoning>"
    REASONING_END = "</reasoning>"
    ANSWER_START = "<answer>"
    ANSWER_END = "</answer>"

    @classmethod
    def is_valid(cls, response: str) -> bool:
        """Return True iff *response* satisfies every structural rule above."""
        tags = (cls.REASONING_START, cls.REASONING_END, cls.ANSWER_START, cls.ANSWER_END)

        # Uniqueness: every tag occurs exactly once.
        if any(response.count(tag) != 1 for tag in tags):
            return False

        # Strictly increasing positions imply correct pairing, the
        # reasoning-before-answer order, and no overlap between sections.
        r_open, r_close, a_open, a_close = (response.find(tag) for tag in tags)
        if not (r_open < r_close < a_open < a_close):
            return False

        # The extractors already strip whitespace, so truthiness means
        # "contains something other than whitespace".
        return bool(cls.extract_reasoning(response)) and bool(cls.extract_answer(response))

    @classmethod
    def _between(cls, response: str, start: str, end: str) -> Optional[str]:
        """Return stripped text between the first *start* and the next *end*, else None."""
        found = re.search(re.escape(start) + "(.*?)" + re.escape(end), response, re.DOTALL)
        if found is None:
            return None
        return found.group(1).strip()

    @classmethod
    def extract_reasoning(cls, response: str) -> Optional[str]:
        """Stripped contents of the first <reasoning> section, or None if absent."""
        return cls._between(response, cls.REASONING_START, cls.REASONING_END)

    @classmethod
    def extract_answer(cls, response: str) -> Optional[str]:
        """Stripped contents of the first <answer> section, or None if absent."""
        return cls._between(response, cls.ANSWER_START, cls.ANSWER_END)


In [None]:
class DomainValidator(ABC):
    """Interface for per-domain answer checkers.

    Concrete subclasses decide whether a predicted answer string should be
    accepted against a reference answer, optionally using per-sample metadata.
    """

    @abstractmethod
    def is_correct(self, predicted: str, ground_truth: str, metadata: Optional[Dict] = None) -> bool:
        """Return True when *predicted* is accepted against *ground_truth*."""
        ...


class MathValidator(DomainValidator):
    """Validator for math domain - compares the number after ``####``."""

    # Formatting removed before number extraction; U+2212 (Unicode minus)
    # is mapped to ASCII "-" so negative answers still parse.
    _STRIP_CHARS = str.maketrans({",": None, "$": None, "%": None, "\u2212": "-"})
    # Signed integer or decimal, including forms like "3." and ".5".
    _NUMBER_RE = re.compile(r"-?(?:\d+\.?\d*|\.\d+)")

    def is_correct(self, predicted: str, ground_truth: str, metadata: Optional[Dict] = None) -> bool:
        """
        Math answers are in format: "reasoning\\n#### 60".
        Extract the number after #### and compare numerically (abs tol 1e-6).

        Returns False (instead of raising) when either side is missing or
        does not contain a parseable number.
        """
        if predicted is None or ground_truth is None:
            return False
        # Ground truth may arrive as a non-string (int, or NaN float from
        # pandas); "####" in <float> would raise TypeError, so coerce first.
        ground_truth = str(ground_truth)
        predicted = str(predicted)

        if "####" in ground_truth:
            gt_value = ground_truth.split("####")[1].strip()
        else:
            gt_value = ground_truth.strip()

        try:
            gt_num = float(self._normalize_number(gt_value))
            pred_num = float(self._normalize_number(predicted))
        except (ValueError, AttributeError):
            return False
        # NaN on either side makes this comparison False.
        return abs(gt_num - pred_num) < 1e-6

    @staticmethod
    def _normalize_number(text: str) -> str:
        """Return the first number found in *text* (after stripping $ , %), else *text*."""
        cleaned = text.translate(MathValidator._STRIP_CHARS)
        numbers = MathValidator._NUMBER_RE.findall(cleaned)
        return numbers[0] if numbers else text


class CodingValidator(DomainValidator):
    """Validator for coding domain - executes test cases against the answer.

    SECURITY NOTE(review): ``exec`` runs model-generated code in this process
    with full privileges and no timeout. A completion can touch the
    filesystem, loop forever (stalling training) or block on ``input()``.
    Prefer a sandboxed subprocess with a time limit for untrusted output.
    """

    # First fenced block (```lang\n ... ```); models often wrap code this way.
    _FENCE_RE = re.compile(r"```[^\n]*\n(.*?)```", re.DOTALL)

    def is_correct(self, predicted: str, ground_truth: str, metadata: Optional[Dict] = None) -> bool:
        """
        Execute the answer, then every test case from ``metadata['test_cases']``.

        Returns True only if the code and ALL test cases run without raising.
        Missing or empty test cases return False (nothing was verified).
        """
        if not metadata or 'test_cases' not in metadata:
            return False

        test_cases = metadata['test_cases']
        if isinstance(test_cases, str):
            # A single test string; don't iterate it character by character.
            test_cases = [test_cases]
        if not test_cases:
            return False  # an empty suite would otherwise pass vacuously

        code = self._strip_code_fence(predicted)
        namespace = {}
        try:
            exec(code, namespace)
            for test in test_cases:
                exec(test, namespace)
        except (Exception, SystemExit):
            # Compile/runtime error, failed assert, or exit()/sys.exit() in the
            # generated code (SystemExit would otherwise abort training).
            return False
        return True

    @classmethod
    def _strip_code_fence(cls, text):
        """Return the body of the first Markdown code fence in *text*, else *text* unchanged."""
        if not isinstance(text, str):
            return text  # exec() will reject it and the caller returns False
        match = cls._FENCE_RE.search(text)
        return match.group(1) if match else text


class ScienceValidator(DomainValidator):
    """Validator for science domain - case-insensitive exact match."""

    def is_correct(self, predicted: str, ground_truth: str, metadata: Optional[Dict] = None) -> bool:
        """
        Case-insensitive comparison after trimming surrounding whitespace.

        Returns False (instead of raising) when either side is missing
        (None or a NaN float, as pandas produces for empty cells).
        """
        pred_text = self._as_text(predicted)
        gt_text = self._as_text(ground_truth)
        if pred_text is None or gt_text is None:
            return False
        return pred_text.strip().lower() == gt_text.strip().lower()

    @staticmethod
    def _as_text(value) -> Optional[str]:
        """Coerce *value* to str; None/NaN become None."""
        if value is None or (isinstance(value, float) and value != value):
            return None
        return value if isinstance(value, str) else str(value)


class LogicValidator(DomainValidator):
    """Validator for logic domain - Yes/No normalization.

    "yes"/"no" are matched as whole words (case-insensitive), so words like
    "know", "not" or "cannot" are not read as "no". The prediction must
    commit to exactly one label: a hedge such as "yes or no" never matches.
    """

    _YES_NO_RE = re.compile(r"\b(yes|no)\b")

    def is_correct(self, predicted: str, ground_truth: str, metadata: Optional[Dict] = None) -> bool:
        """
        Compare yes/no labels; handles yes, Yes, YES, no, No, NO (with punctuation).

        If the ground truth is not a single yes/no label, falls back to a
        case-insensitive exact match. Missing values (None/NaN) return False.
        """
        pred_text = self._as_text(predicted)
        gt_text = self._as_text(ground_truth)
        if pred_text is None or gt_text is None:
            return False
        pred_text = pred_text.strip().lower()
        gt_text = gt_text.strip().lower()

        gt_labels = set(self._YES_NO_RE.findall(gt_text))
        if len(gt_labels) != 1:
            return pred_text == gt_text

        pred_labels = set(self._YES_NO_RE.findall(pred_text))
        return pred_labels == gt_labels

    @staticmethod
    def _as_text(value) -> Optional[str]:
        """Coerce *value* to str; None/NaN become None."""
        if value is None or (isinstance(value, float) and value != value):
            return None
        return value if isinstance(value, str) else str(value)



## Reward Functions

In [None]:
class HeuristicRewards:
    """Heuristic-based quality rewards for creative domains."""

    # Function words ignored by prompt_relevance. Only tokens longer than three
    # characters pass its length filter, so the longer entries are the ones
    # that take effect; the three-letter entries are kept for completeness.
    STOP_WORDS = frozenset({
        'the', 'and', 'for', 'are', 'but', 'not', 'you', 'all', 'can',
        'her', 'was', 'one', 'our', 'out', 'day', 'get', 'has', 'him',
        'his', 'how', 'man', 'new', 'now', 'old', 'see', 'two', 'way',
        'who', 'boy', 'did', 'its', 'let', 'put', 'say', 'she', 'too', 'use',
        'about', 'also', 'been', 'being', 'because', 'could', 'does', 'each',
        'from', 'have', 'here', 'into', 'just', 'like', 'more', 'most', 'only',
        'other', 'over', 'should', 'some', 'such', 'than', 'that', 'their',
        'them', 'then', 'there', 'these', 'they', 'this', 'those', 'very',
        'were', 'what', 'when', 'where', 'which', 'while', 'will', 'with',
        'would', 'your',
    })

    @staticmethod
    def length_score(text: str, target: int, min_len: int, max_len: int) -> float:
        """
        Score based on word-count appropriateness.

        Returns 1.0 if the word count is within [min_len, max_len]. Below the
        range the score decays linearly to 0 at zero words; above it, it
        decays linearly to 0 at ``max_len + target`` words. A non-positive
        ``target`` gives 0.0 for anything above the range.
        """
        length = len(text.split())

        if min_len <= length <= max_len:
            return 1.0

        if length < min_len:
            distance = min_len - length
            max_distance = min_len
        else:  # length > max_len
            distance = length - max_len
            max_distance = target

        if max_distance <= 0:
            return 0.0  # no decay scale available; avoid ZeroDivisionError
        return max(0.0, 1.0 - distance / max_distance)

    @staticmethod
    def lexical_diversity(text: str) -> float:
        """
        Calculate lexical diversity: unique words / total words (0.0 if empty).
        """
        words = text.lower().split()
        if len(words) == 0:
            return 0.0

        unique_words = len(set(words))
        return unique_words / len(words)

    @staticmethod
    def _keywords(text: str) -> set:
        """Lower-cased word tokens (punctuation dropped) longer than 3 chars, minus stop words."""
        return {
            word for word in re.findall(r"\w+", text.lower())
            if len(word) > 3 and word not in HeuristicRewards.STOP_WORDS
        }

    @staticmethod
    def prompt_relevance(prompt: str, reasoning: str) -> float:
        """
        Fraction of the prompt's keywords that also appear in the reasoning.

        Returns 0.5 when the prompt has no keywords at all.
        """
        prompt_words = HeuristicRewards._keywords(prompt)
        if not prompt_words:
            return 0.5  # Default score if no meaningful words in prompt

        reasoning_words = HeuristicRewards._keywords(reasoning)
        return len(prompt_words & reasoning_words) / len(prompt_words)


# ============================================================================
# MAIN REWARD CALCULATOR
# ============================================================================

class RewardCalculator:
    """
    Main reward calculator that routes to appropriate validators.

    Reward breakdown (total capped at 1.0):
    - Format: 0.2 for a structurally valid response (all domains);
      an invalid structure scores 0.0 overall.
    - Verifiable domains: +0.8 if the answer is correct, else +0.0.
    - Creative/unknown domains: up to +0.3 length (0.15 reasoning + 0.15
      answer) + 0.25 lexical diversity + 0.25 prompt relevance.
    """

    VERIFIABLE_DOMAINS = {"math", "coding", "science", "logic"}
    CREATIVE_DOMAINS = {"creative_writing", "creative_ideation", "summarization"}

    def __init__(self):
        """Initialize validators."""
        self.xml_validator = XMLValidator()
        self.validators = {
            "math": MathValidator(),
            "coding": CodingValidator(),
            "science": ScienceValidator(),
            "logic": LogicValidator(),
        }
        self.heuristics = HeuristicRewards()

    def compute_reward(
        self,
        domain: str,
        prompt: str,
        response: str,
        ground_truth: Optional[str] = None,
        metadata: Optional[Dict] = None
    ) -> float:
        """
        Compute total reward for a response.

        Args:
            domain: Task domain (math, coding, science, ...); matched
                case-insensitively after trimming whitespace.
            prompt: Original prompt/question (used for creative relevance)
            response: Model's generated response
            ground_truth: Expected answer (for verifiable domains)
            metadata: Additional data (e.g., test_cases for coding)

        Returns:
            float: Total reward score in [0.0, 1.0]
        """
        # HARD DEPENDENCY: Format must be valid
        if not self.xml_validator.is_valid(response):
            return 0.0

        reward = 0.2  # format reward

        reasoning = self.xml_validator.extract_reasoning(response)
        answer = self.xml_validator.extract_answer(response)

        # Tolerate label variants such as "Math" or " math ".
        domain_key = domain.strip().lower() if isinstance(domain, str) else domain

        if domain_key in self.VERIFIABLE_DOMAINS:
            reward += self._compute_verifiable_reward(
                domain_key, answer, ground_truth, metadata
            )
        else:
            # CREATIVE_DOMAINS and any unknown domain share the heuristics.
            reward += self._compute_creative_reward(
                prompt, reasoning, answer
            )

        return min(reward, 1.0)

    def _compute_verifiable_reward(
        self,
        domain: str,
        answer: str,
        ground_truth: str,
        metadata: Optional[Dict]
    ) -> float:
        """Return 0.8 if the domain validator accepts the answer, else 0.0."""
        validator = self.validators[domain]
        if validator.is_correct(answer, ground_truth, metadata):
            return 0.8
        return 0.0

    def _compute_creative_reward(
        self,
        prompt: str,
        reasoning: str,
        answer: str
    ) -> float:
        """Compute reward for creative domains using heuristics (max 0.8)."""
        score = 0.0

        # Length appropriateness (0.3 points)
        reasoning_score = self.heuristics.length_score(
            reasoning, target=250, min_len=20, max_len=500
        )
        answer_score = self.heuristics.length_score(
            answer, target=150, min_len=10, max_len=300
        )
        score += 0.15 * reasoning_score
        score += 0.15 * answer_score

        # Lexical diversity of the answer (0.25 points)
        score += 0.25 * self.heuristics.lexical_diversity(answer)

        # Prompt relevance of the reasoning (0.25 points)
        score += 0.25 * self.heuristics.prompt_relevance(prompt, reasoning)

        return score



In [None]:
def compute_reward_batch(
    prompts: List[str],
    completions: List[str],
    domain: Optional[List[str]] = None,
    answer: Optional[List[str]] = None,
    metadata_str: Optional[List[str]] = None,
    **kwargs  # Catch any other fields
) -> List[float]:
    """
    Compute rewards for a batch of responses.

    Presumably called by the GRPO trainer with:
    - prompts: formatted prompts (chat template + SYSTEM_PROMPT + task)
    - completions: model responses, aligned with ``prompts``
    - domain / answer / metadata_str: per-sample dataset fields
    - **kwargs: any remaining dataset fields (e.g. ``question``)

    Missing fields default to domain "unknown", no ground truth and empty
    metadata. ``metadata_str`` entries that are not JSON objects are treated
    as empty metadata.

    Returns:
        List[float]: one reward in [0.0, 1.0] per completion.
    """
    import json

    calculator = RewardCalculator()
    num_completions = len(completions)

    if domain is None:
        domain = ["unknown"] * num_completions
    if answer is None:
        answer = [None] * num_completions
    if metadata_str is None:
        metadata_str = ["{}"] * num_completions

    # Score relevance against the raw task text when it is forwarded:
    # ``prompts`` also contain the chat template and SYSTEM_PROMPT, whose
    # words would dilute the keyword overlap used for creative domains.
    questions = kwargs.get("question")
    if questions is None or len(questions) != num_completions:
        questions = prompts

    metadatas = []
    for meta_str in metadata_str:
        parsed = {}
        if isinstance(meta_str, str):
            try:
                parsed = json.loads(meta_str)
            except ValueError:  # includes json.JSONDecodeError
                parsed = {}
        metadatas.append(parsed if isinstance(parsed, dict) else {})

    rewards = []
    for dom, question, completion, gt, meta in zip(
        domain, questions, completions, answer, metadatas
    ):
        rewards.append(calculator.compute_reward(
            domain=dom,
            prompt=question,
            response=completion,  # GRPO calls it completion
            ground_truth=gt,
            metadata=meta
        ))

    return rewards

## Memory Utility

In [None]:
def show_hbm_usage():
    """Print bytes-in-use / bytes-limit for every local JAX device.

    Devices whose ``memory_stats()`` is empty/None, lacks the expected keys,
    or reports a zero limit are listed as unavailable instead of raising.
    """
    fmt_size = functools.partial(humanize.naturalsize, binary=True)
    for d in jax.local_devices():
        stats = d.memory_stats()
        if not stats or "bytes_in_use" not in stats or not stats.get("bytes_limit"):
            print(f"Memory stats unavailable on {d}")
            continue
        used = stats["bytes_in_use"]
        limit = stats["bytes_limit"]
        print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:.1%}) on {d}")

show_hbm_usage()

## Cell 6: Load Dataset

In [None]:
# Build the DISCO-balanced train/validation Grain datasets from the pool.
print("="*60)
print("LOADING DATASET")
print("="*60)

train_dataset, val_dataset = load_dataset_for_training(
    jsonl_path=DATASET_PATH,
    train_size=TRAIN_SIZE,
    val_size=VAL_SIZE,
    batch_size=BATCH_SIZE,
    disco_temperature=DISCO_TEMP,
    seed=SEED
)

print(f"\n✓ Train dataset: {type(train_dataset)}")
print(f"✓ Val dataset: {type(val_dataset)}")

## Cell 8: Load Base Model to Intermediate Checkpoint

In [None]:
# Load the base Gemma weights once and re-save them as an Orbax checkpoint
# in INTERMEDIATE_CKPT_DIR, so the next cell can restore them with sharding.
try:
    print("="*60)
    print("LOADING BASE MODEL")
    print("="*60)
    
    # Clear any existing intermediate checkpoint
    if os.path.exists(INTERMEDIATE_CKPT_DIR):
        shutil.rmtree(INTERMEDIATE_CKPT_DIR)
    os.makedirs(INTERMEDIATE_CKPT_DIR, exist_ok=True)
    os.makedirs(CKPT_DIR, exist_ok=True)
    
    # Load base Gemma model
    config = model.ModelConfig.gemma3_1b()
    gemma = params.create_model_from_checkpoint(MODEL_CP_PATH, config)
    tokenizer = params.create_tokenizer()
    
    # Save to intermediate checkpoint
    checkpointer = ocp.StandardCheckpointer()
    _, state = nnx.split(gemma)
    checkpointer.save(os.path.join(INTERMEDIATE_CKPT_DIR, "state"), state)
    checkpointer.wait_until_finished()
    
    # Free memory
    del gemma
    del state
    gc.collect()
    
    print("✓ Base model saved to intermediate checkpoint")
    # show_hbm_usage()
except Exception as err:
    # Don't swallow failures: the next cell restores from INTERMEDIATE_CKPT_DIR
    # (and later code presumably uses `tokenizer`), so continuing would only
    # fail later with a less informative error.
    print(f"✗ Failed to load base model / write intermediate checkpoint: {err!r}")
    raise

## Cell 9: Create Reference and Policy Models

In [None]:
# Section banner: rule, title, rule (one per line).
print("=" * 60, "CREATING REFERENCE AND POLICY MODELS", "=" * 60, sep="\n")

def get_gemma_ref_model(ckpt_path):
    """Restore Gemma 3 1B weights from an Orbax checkpoint as a sharded nnx model.

    Builds a device mesh from the module-level ``MESH``, derives the abstract
    parameter structure by shape-evaluating the checkpoint loader for
    ``MODEL_CP_PATH``, then restores the saved state from ``ckpt_path`` with
    every leaf as bfloat16 and placed according to the mesh's named sharding.

    Args:
        ckpt_path: Orbax checkpoint directory, e.g. the ``state``
            subdirectory of INTERMEDIATE_CKPT_DIR written earlier.

    Returns:
        (gemma, mesh, model_config): the restored model, the device mesh and
        the ``ModelConfig`` used to build it.
    """
    mesh = jax.make_mesh(*MESH)
    model_config = model.ModelConfig.gemma3_1b()
    
    # Create abstract model for shape inference (presumably no real weights
    # are materialised under nnx.eval_shape — confirm).
    abs_gemma = nnx.eval_shape(
        lambda: params.create_model_from_checkpoint(MODEL_CP_PATH, model_config)
    )
    
    # Create sharded state specification: same shapes, bfloat16 dtype, and
    # the mesh sharding for each leaf.
    # NOTE(review): every leaf is forced to bfloat16 regardless of the dtype
    # it was saved with — confirm this is intended.
    abs_state = nnx.state(abs_gemma)
    abs_state = jax.tree.map(
        lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
        abs_state,
        nnx.get_named_sharding(abs_state, mesh),
    )
    
    # Restore checkpoint into the target structure above
    checkpointer = ocp.StandardCheckpointer()
    restored_params = checkpointer.restore(ckpt_path, target=abs_state)
    
    # Merge graph and params: graph definition from the abstract model,
    # values from the restored checkpoint.
    graph_def, _ = nnx.split(abs_gemma)
    gemma = nnx.merge(graph_def, restored_params)
    
    return gemma, mesh, model_config


def get_lora_model(base_model, mesh):
    """Attach LoRA adapters to ``base_model`` and shard the result on ``mesh``.

    Adapters of rank ``LORA_RANK`` and scale ``LORA_ALPHA`` are added to every
    module whose path matches one of the patterns below (judging by the names:
    attention einsums and MLP gate/up/down projections).

    Args:
        base_model: The Gemma module to adapt. It must provide
            ``get_model_input()``, which supplies the example inputs qwix uses
            when applying LoRA.
        mesh: Device mesh whose partition specs the new model's state is
            constrained to.

    Returns:
        The LoRA-augmented model.
    """
    adapted_modules = "|".join([
        ".*q_einsum",
        ".*kv_einsum",
        ".*gate_proj",
        ".*down_proj",
        ".*up_proj",
        ".*attn_vec_einsum",
    ])
    provider = qwix.LoraProvider(
        module_path=adapted_modules,
        rank=LORA_RANK,
        alpha=LORA_ALPHA,
    )

    example_inputs = base_model.get_model_input()
    adapted = qwix.apply_lora_to_model(base_model, provider, **example_inputs)

    # Constrain every variable of the adapted model to its partition spec.
    with mesh:
        adapted_state = nnx.state(adapted)
        constrained = jax.lax.with_sharding_constraint(
            adapted_state, nnx.get_partition_spec(adapted_state)
        )
        nnx.update(adapted, constrained)

    return adapted


# Frozen reference model (anchor for the KL penalty), restored with sharding
# from the "state" checkpoint written by the base-model export cell.
_ref_state_dir = os.path.join(INTERMEDIATE_CKPT_DIR, "state")
ref_model, mesh, model_config = get_gemma_ref_model(ckpt_path=_ref_state_dir)
print("‚úì Reference model loaded")

# Trainable GRPO policy: LoRA adapters applied on top of the reference model.
lora_policy = get_lora_model(ref_model, mesh=mesh)
print("‚úì Policy model with LoRA created")

# show_hbm_usage()

## Cell 10: Create Optimizer

In [None]:
print("=" * 60, "CREATING OPTIMIZER", "=" * 60, sep="\n")

# Learning rate: linear warmup from 0 to LEARNING_RATE over WARMUP_STEPS, then
# cosine decay to 0. optax counts the warmup inside `decay_steps`, so the whole
# schedule spans MAX_STEPS.
lr_schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    decay_steps=MAX_STEPS,
    end_value=0.0,
)
optimizer = optax.adamw(
    learning_rate=lr_schedule,
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)

# Optional global-norm gradient clipping, applied before the AdamW update.
if MAX_GRAD_NORM is not None:
    optimizer = optax.chain(optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM), optimizer)

print("‚úì AdamW optimizer with warmup cosine decay")
print(f"  - Peak LR: {LEARNING_RATE}")
print(f"  - Warmup steps: {WARMUP_STEPS}")
print(f"  - Grad clip norm: {MAX_GRAD_NORM}")

## Cell 11: Configure GRPO Training

In [None]:
print("="*60)
print("CONFIGURING GRPO TRAINING")
print("="*60)

# Rollout KV-cache sizing. The cache has to hold the prompt plus every
# generated token. Deriving it from the two budgets keeps it in sync when either
# constant changes; it used to be hard-coded to 1536 (= 1024 + 256 + 256).
KV_CACHE_BUFFER = 256  # extra slack slots on top of prompt + generation
KV_CACHE_SIZE = MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + KV_CACHE_BUFFER

# Checkpointing: save every SAVE_INTERVAL_STEPS steps, keep at most MAX_TO_KEEP.
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS,
    max_to_keep=MAX_TO_KEEP,
)

# Metrics are written under log_dir and flushed every 20 steps.
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/content/tensorboard/grpo",
    flush_every_n_steps=20,
)

# Actor, reference and rollout all share the single mesh returned by
# get_gemma_ref_model; nothing is offloaded to CPU.
cluster_config = rl_cluster_lib.ClusterConfig(
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    offload_to_cpu=False,
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=MAX_STEPS,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        metrics_logging_options=metrics_logging_options,
        checkpoint_root_directory=CKPT_DIR,
        checkpointing_options=checkpointing_options,
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=TOTAL_GENERATION_STEPS,
        max_prompt_length=MAX_PROMPT_LENGTH,
        kv_cache_size=KV_CACHE_SIZE,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        # NOTE(review): presumably Gemma's <eos> (1) and <end_of_turn> (106)
        # ids — confirm against the tokenizer.
        eos_tokens=[1, 106],
    ),
)

grpo_config = GRPOConfig(
    num_generations=NUM_GENERATIONS,
    num_iterations=NUM_ITERATIONS,
    beta=BETA,
    epsilon=EPSILON,
)

print("‚úì GRPO configuration complete")
print(f"  - Group size: {NUM_GENERATIONS} generations per prompt")
print(f"  - Iterations: {NUM_ITERATIONS}")
print(f"  - KL beta: {BETA}")
print(f"  - Clip epsilon: {EPSILON}")
print(f"  - KV cache size: {KV_CACHE_SIZE}")

## Cell 12: Create GRPO Trainer

In [None]:
print("=" * 60, "CREATING GRPO TRAINER", "=" * 60, sep="\n")

# Wire the trainable LoRA actor, the frozen reference model and the tokenizer
# onto the meshes and training/rollout settings held in cluster_config.
rl_cluster = rl_cluster_lib.RLCluster(
    actor=lora_policy,
    reference=ref_model,
    tokenizer=tokenizer,
    cluster_config=cluster_config,
)

# The custom multi-domain reward is the only reward signal for GRPO.
reward_functions = [compute_reward_batch]
grpo_trainer = GRPOLearner(
    rl_cluster=rl_cluster,
    reward_fns=reward_functions,
    grpo_config=grpo_config,
)

print(
    "‚úì GRPO trainer created",
    "  - Actor: LoRA policy model",
    "  - Reference: Frozen base model",
    "  - Reward function: Multi-domain (7 domains)",
    sep="\n",
)

# show_hbm_usage()

## Cell 13: Run Training

**This will take ~7 hours. Monitor the logs for:**
- Average reward increasing
- Format compliance >95%
- No OOM errors

In [None]:
# print("="*80)
# print("STARTING GRPO TRAINING")
# print("="*80)
# print(f"Training steps: {MAX_STEPS}")
# print(f"Checkpoint interval: {SAVE_INTERVAL_STEPS} steps")
# print(f"Estimated time: ~7 hours")
# print("="*80)

# with mesh:
#     grpo_trainer.train(train_dataset)

# print("="*80)
# print("TRAINING COMPLETE")
# print("="*80)

## Cell 14: Load Latest Checkpoint

In [None]:
# print("Loading latest checkpoint for evaluation...")

# # Find the latest checkpoint
# actor_ckpt_dir = os.path.join(CKPT_DIR, "actor")

# latest_step = -1
# if os.path.exists(actor_ckpt_dir):
#     for item in os.listdir(actor_ckpt_dir):
#         if os.path.isdir(os.path.join(actor_ckpt_dir, item)) and re.match(r'^\d+$', item):
#             step = int(item)
#             if step > latest_step:
#                 latest_step = step

# if latest_step == -1:
#     print("‚ö† No checkpoints found, using current model state")
# else:
#     print(f"Loading checkpoint from step {latest_step}...")
    
#     trained_ckpt_path = os.path.join(CKPT_DIR, "actor", str(latest_step), "model_params")
    
#     abs_params = jax.tree.map(
#         lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
#         nnx.state(lora_policy, nnx.LoRAParam),
#     )
#     checkpointer = ocp.StandardCheckpointer()
#     trained_lora_params = checkpointer.restore(trained_ckpt_path, target=abs_params)
    
#     nnx.update(
#         lora_policy,
#         jax.tree.map(
#             lambda a, b: b,
#             nnx.state(lora_policy, nnx.LoRAParam),
#             trained_lora_params,
#         ),
#     )
#     print(f"‚úì Loaded checkpoint from step {latest_step}")

## Cell 15: Validation

## Cell 16: Save Final Model

In [None]:
# print("="*60)
# print("SAVING FINAL MODEL")
# print("="*60)

# # Save to /kaggle/working (persists after session)
# final_ckpt_path = "/kaggle/working/grpo_multi_domain_final"

# if os.path.exists(final_ckpt_path):
#     shutil.rmtree(final_ckpt_path)

# # Save LoRA parameters
# abs_params = jax.tree.map(
#     lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
#     nnx.state(lora_policy, nnx.LoRAParam),
# )
# checkpointer = ocp.StandardCheckpointer()
# lora_params = nnx.state(lora_policy, nnx.LoRAParam)
# checkpointer.save(final_ckpt_path, lora_params)
# checkpointer.wait_until_finished()

# print(f"‚úì Model saved to {final_ckpt_path}")

# # Create metadata for Kaggle dataset
# import json
# metadata = {
#     "title": "GRPO Multi-Domain Reasoning Model",
#     "id": "vserifoglu/grpo-multi-domain-final",
# }

# with open('/kaggle/working/dataset-metadata.json', 'w') as f:
#     json.dump(metadata, f)

# print("‚úì Metadata created")
# print("\nüéâ TRAINING COMPLETE!")
# print("\nNext steps:")
# print("1. Save notebook version")
# print("2. Go to Output tab")
# print("3. Create new dataset from output")
# print("4. Use for competition submission")

In [None]:
# print("="*60)
# print("SAVING FINAL MODEL")
# print("="*60)

# # Save to /kaggle/working (persists after session)
# final_ckpt_path = "/kaggle/working/grpo_multi_domain_final"

# if os.path.exists(final_ckpt_path):
#     shutil.rmtree(final_ckpt_path)

# # Save LoRA parameters
# abs_params = jax.tree.map(
#     lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
#     nnx.state(lora_policy, nnx.LoRAParam),
# )
# checkpointer = ocp.StandardCheckpointer()
# lora_params = nnx.state(lora_policy, nnx.LoRAParam)
# checkpointer.save(final_ckpt_path, lora_params)
# checkpointer.wait_until_finished()

# print(f"‚úì Model saved to {final_ckpt_path}")

# # Upload to Kaggle Datasets (automatic!)
# import kagglehub

# DATASET_HANDLE = "fissalalsharef/grpo-multi-domain-final_v2"

# print(f"\nUploading to Kaggle: {DATASET_HANDLE}")
# kagglehub.dataset_upload(
#     handle=DATASET_HANDLE,
#     local_dataset_dir=final_ckpt_path,
#     version_notes="Multi-domain GRPO training with DISCO balancing - Gemma 3 1B + LoRA"
# )

# print("‚úì Dataset uploaded!")
# print("\nüéâ TRAINING COMPLETE!")
# print(f"\nYour model is available at:")
# print(f"https://www.kaggle.com/datasets/{DATASET_HANDLE}")

## Verify That the Trained Model Loads

In [None]:
print("="*80)
print("PRODUCTION VALIDATION - MODEL SETUP")
print("="*80)

# KV-cache capacity (prompt + generated tokens) of the validation sampler.
# Defined once so the sampler config and the summary printout cannot disagree;
# the summary previously reported 1624 while the cache was actually 2048.
VERIFICATION_CACHE_SIZE = 2048

# Step 1: Download model from Kaggle
print("\n1. Downloading trained model from Kaggle...")
import kagglehub

DATASET_HANDLE = "fissalalsharef/grpo-multi-domain-final"
downloaded_path = kagglehub.dataset_download(DATASET_HANDLE)

print(f"‚úì Dataset downloaded to: {downloaded_path}")

# Step 2: Load a fresh copy of the base Gemma model
print("\n2. Loading base Gemma 3 1B model...")
verification_base = params.create_model_from_checkpoint(
    MODEL_CP_PATH,
    model.ModelConfig.gemma3_1b()
)
print("‚úì Base model loaded")

# Step 3: Attach (not yet trained) LoRA adapters, sharded on the training mesh
print("\n3. Creating LoRA model structure...")
verification_policy = get_lora_model(verification_base, mesh)

# Step 4: Restore the trained LoRA weights. The abstract target fixes the
# expected tree structure, shapes and dtypes of the restored parameters.
print(f"\n4. Loading LoRA checkpoint from Kaggle...")
abs_params = jax.tree.map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    nnx.state(verification_policy, nnx.LoRAParam),
)

checkpointer = ocp.StandardCheckpointer()
loaded_params = checkpointer.restore(
    downloaded_path,
    target=abs_params
)

# Step 5: Apply the LoRA params. Mapping over both trees makes jax check that
# their structures match before the restored values (b) are taken.
print("\n5. Applying LoRA parameters...")
nnx.update(
    verification_policy,
    jax.tree.map(
        lambda a, b: b,
        nnx.state(verification_policy, nnx.LoRAParam),
        loaded_params,
    ),
)
print("‚úì LoRA applied successfully")

# Step 6: Create sampler for testing
print("\n6. Creating sampler...")
verification_sampler = sampler_lib.Sampler(
    transformer=verification_policy,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=VERIFICATION_CACHE_SIZE,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)

print("\n" + "="*80)
print("‚úÖ MODEL LOADED & READY FOR VALIDATION!")
print("="*80)
print(f"‚úì Model: {DATASET_HANDLE}")
print(f"‚úì Sampler: verification_sampler")
print(f"‚úì Cache size: {VERIFICATION_CACHE_SIZE} tokens")
print("\n‚ñ∂ Run Cell 1 to begin validation tests...")

In [None]:
print("="*80)
print("PRODUCTION-GRADE MODEL VALIDATION")
print("="*80)
import random
import json
from collections import Counter, defaultdict

VALIDATION_SEED = 42       # fixed seed -> the same subsample on every run
MAX_POOL_SAMPLES = 200     # stop pulling validation batches past this many
samples_per_domain = 15    # per-domain cap for the balanced test set

# Sample diverse test set from validation data
print("\n1. Sampling test data...")
test_samples = []
for batch in val_dataset:
    for domain, question, answer in zip(batch['domain'], batch['question'], batch['answer']):
        test_samples.append({
            'domain': domain,
            'question': question,
            'answer': answer,
        })
    if len(test_samples) >= MAX_POOL_SAMPLES:
        break

# Stratify by domain for balanced testing
domain_groups = defaultdict(list)
for sample in test_samples:
    domain_groups[sample['domain']].append(sample)

# Sample evenly per domain with a private, seeded RNG so the test set is
# reproducible without reseeding Python's global `random` state.
rng = random.Random(VALIDATION_SEED)
balanced_test_set = []
for domain, samples in domain_groups.items():
    sample_count = min(len(samples), samples_per_domain)
    balanced_test_set.extend(rng.sample(samples, sample_count))
print(f"\n‚úì Test set created: {len(balanced_test_set)} samples")
print("\nDomain distribution:")
for domain, count in Counter(s['domain'] for s in balanced_test_set).items():
    print(f"  {domain:20s}: {count:2d} samples")

# Edge cases: degenerate prompts that only probe format compliance. They are
# appended to the test set so the validation loop actually exercises them
# (previously they were announced and counted but never tested).
edge_cases = [
    {"domain": "edge_short", "question": "What?", "answer": "N/A"},
    {"domain": "edge_short", "question": "?", "answer": "N/A"},
    {"domain": "edge_ambiguous", "question": "What is it?", "answer": "N/A"},
    {"domain": "edge_ambiguous", "question": "Can you help?", "answer": "N/A"},
]
balanced_test_set.extend(edge_cases)
print(f"\n‚úì Added {len(edge_cases)} edge cases")
print(f"\nTotal test samples: {len(balanced_test_set)}")

In [None]:
# only_logic = []
# for i in balanced_test_set:
#     if i["domain"] == "math":
#         only_logic.append(i)

In [None]:
print("=" * 80, "TEST 1: FORMAT COMPLIANCE (with Truncation Detection)", "=" * 80, sep="\n")

# Per-sample generation budget for validation. It is kept generous so that
# format failures are less likely to be caused by running out of tokens.
MAX_TOKENS = 700

def analyze_format_failure(output, output_length, max_tokens, near_limit_margin=10):
    """Classify a format-invalid generation as truncation or a true model failure.

    Args:
        output: Generated text that failed format validation.
        output_length: Token count of ``output``.
        max_tokens: Generation budget the sample was produced with.
        near_limit_margin: Outputs within this many tokens of ``max_tokens``
            count as having hit the limit. Presumably this absorbs small
            differences between the re-encoded length and the generated count.

    Returns:
        ``(failure_type, reason)``. ``failure_type`` is ``"TRUNCATION"`` when
        the output ran into the token limit with a tag still open, otherwise
        ``"TRUE_FAILURE"``.
    """
    has_reasoning_open = "<reasoning>" in output
    has_reasoning_close = "</reasoning>" in output
    has_answer_open = "<answer>" in output
    has_answer_close = "</answer>" in output

    tags_present = sum([has_reasoning_open, has_reasoning_close,
                        has_answer_open, has_answer_close])

    near_limit = output_length >= (max_tokens - near_limit_margin)

    if not has_reasoning_open and not has_answer_open:
        return "TRUE_FAILURE", "Missing opening tags - model didn't learn format"

    # From here on at least one opening tag is present. Hitting the limit with
    # a closing tag missing is truncation, even when only a single opening tag
    # was emitted. Previously a generation cut off inside <reasoning> was
    # misreported as a true failure.
    if near_limit:
        if not has_answer_close:
            return "TRUNCATION", "Hit token limit before closing </answer>"
        if not has_reasoning_close:
            return "TRUNCATION", "Hit token limit before closing </reasoning>"

    if output_length < 100 and tags_present < 2:
        return "TRUE_FAILURE", "Short output with no format - model didn't try"

    # Not near the limit here, so an unclosed answer is the model's fault.
    if tags_present >= 3 and not has_answer_close:
        return "TRUE_FAILURE", "Had room but didn't close tags"

    return "TRUE_FAILURE", "Other format issue"


def validate_format(output):
    """Check that ``output`` uses the reasoning/answer tag format.

    Returns ``(is_valid, checks)``. ``checks`` maps each ``has_*`` key to
    whether that tag occurs in ``output``. Only when all four tags occur is a
    ``correct_order`` entry added. It is true when the first occurrences appear
    in the order ``<reasoning>``, ``</reasoning>``, ``<answer>``, ``</answer>``.
    ``is_valid`` equals that flag, and is ``False`` whenever any tag is missing.
    """
    expected_sequence = (
        ('has_reasoning_open', '<reasoning>'),
        ('has_reasoning_close', '</reasoning>'),
        ('has_answer_open', '<answer>'),
        ('has_answer_close', '</answer>'),
    )
    checks = {key: tag in output for key, tag in expected_sequence}

    if not all(checks.values()):
        return False, checks

    first_r_open, first_r_close, first_a_open, first_a_close = (
        output.find(tag) for _, tag in expected_sequence
    )
    ordered = first_r_open < first_r_close < first_a_open < first_a_close
    checks['correct_order'] = ordered
    return ordered, checks


# Run validation
MAX_QUESTION_CHARS = 4000  # questions longer than this are skipped, not generated

print(f"\nTesting {len(balanced_test_set)} samples with {MAX_TOKENS} token limit...")
print("(This may take 10-15 minutes)\n")

results = []
true_failures = []
truncation_failures = []
skipped_samples = []   # over-long questions, never sent to the sampler
errored_samples = []   # samples whose generation raised an exception

for i, sample in enumerate(balanced_test_set):
    if len(sample['question']) > MAX_QUESTION_CHARS:
        skipped_samples.append(sample)
        continue

    prompt = TEMPLATE.format(
        system_prompt=SYSTEM_PROMPT,
        question=sample['question']
    )

    # Only generation and re-tokenization are guarded. A failure there is
    # logged and recorded, and the run continues with the next sample.
    try:
        output_data = verification_sampler(
            input_strings=[prompt],
            max_generation_steps=MAX_TOKENS,
            temperature=0.7,
            echo=False,
            eos_tokens=[1, 106],
        )
        output = output_data.text[0]
        output_tokens = len(tokenizer.encode(output))
    except Exception as e:
        print(f"‚ùå Error on sample {i+1}: {e}")
        errored_samples.append({'index': i, 'domain': sample['domain'], 'error': repr(e)})
        continue

    is_valid, checks = validate_format(output)

    result = {
        'domain': sample['domain'],
        'question': sample['question'],
        'valid': is_valid,
        'output': output,
        'output_tokens': output_tokens,
        'checks': checks,
        'failure_type': None,
        'failure_reason': None,
        'prompt': prompt,
    }

    if not is_valid:
        failure_type, reason = analyze_format_failure(output, output_tokens, MAX_TOKENS)
        result['failure_type'] = failure_type
        result['failure_reason'] = reason

        if failure_type == "TRUNCATION":
            truncation_failures.append(result)
        else:
            true_failures.append(result)

    results.append(result)

    # Progress
    if (i + 1) % 10 == 0:
        valid = sum(1 for r in results if r['valid'])
        trunc = len(truncation_failures)
        true_f = len(true_failures)
        print(f"Progress: {i+1}/{len(balanced_test_set)} | "
              f"Valid: {valid} | Truncated: {trunc} | True Failures: {true_f}")

# Account for every sample so the summary totals are not silently short.
print(f"\nDone: {len(results)} evaluated | "
      f"{len(skipped_samples)} skipped (> {MAX_QUESTION_CHARS} chars) | "
      f"{len(errored_samples)} errored")

In [None]:
# only_false = []
# for i in results:
#     if i["valid"] is False:
#         only_false.append(i)

# only_false

In [None]:
# ============================================================
# RESULTS SUMMARY
# ============================================================
def _pct(part, whole):
    """Return ``part / whole`` as a percentage, or 0.0 when ``whole`` is 0."""
    return part / whole * 100 if whole else 0.0


print("\n" + "="*80)
print("VALIDATION RESULTS")
print("="*80)

total = len(results)
valid_count = sum(1 for r in results if r['valid'])
trunc_count = len(truncation_failures)
true_fail_count = len(true_failures)

# Overall stats. Percentages are guarded because `results` can be empty when
# every sample was skipped or errored.
print(f"\nüìä OVERALL:")
print(f"  Total tested: {total}")
print(f"  ‚úÖ Valid format: {valid_count} ({_pct(valid_count, total):.1f}%)")
print(f"  ‚ö†Ô∏è Truncation failures: {trunc_count} ({_pct(trunc_count, total):.1f}%)")
print(f"  ‚ùå True model failures: {true_fail_count} ({_pct(true_fail_count, total):.1f}%)")

# Adjusted compliance: truncated outputs are left out of the denominator,
# since they reflect the token budget rather than the model's formatting.
adjusted_total = valid_count + true_fail_count
adjusted_rate = _pct(valid_count, adjusted_total)
print(f"\nüìà ADJUSTED (excluding truncation):")
print(f"  Format compliance: {valid_count}/{adjusted_total} ({adjusted_rate:.1f}%)")

# Per-domain breakdown. "Rate" is the truncation-adjusted rate; a domain whose
# samples were all truncated reports 100% because it has no true failures.
print(f"\nüìã PER-DOMAIN BREAKDOWN:")
print(f"{'Domain':<20} {'Valid':>6} {'Trunc':>6} {'Fail':>6} {'Total':>6} {'Rate':>8}")
print("-" * 60)

domain_stats = defaultdict(lambda: {'valid': 0, 'trunc': 0, 'fail': 0, 'total': 0})

for r in results:
    stats = domain_stats[r['domain']]
    stats['total'] += 1
    if r['valid']:
        stats['valid'] += 1
    elif r['failure_type'] == 'TRUNCATION':
        stats['trunc'] += 1
    else:
        stats['fail'] += 1

for domain, stats in sorted(domain_stats.items()):
    judged = stats['valid'] + stats['fail']
    adj_rate = _pct(stats['valid'], judged) if judged else 100
    status = "‚úÖ" if adj_rate >= 90 else "‚ö†Ô∏è" if adj_rate >= 75 else "‚ùå"
    print(f"{status} {domain:<18} {stats['valid']:>6} {stats['trunc']:>6} {stats['fail']:>6} {stats['total']:>6} {adj_rate:>7.1f}%")

# ============================================================
# DECISION LOGIC
# ============================================================
print("\n" + "="*80)
print("DIAGNOSIS & DECISION")
print("="*80)

if total == 0:
    # Without any evaluated sample, "no true failures" would be a false verdict.
    print("\nNO SAMPLES EVALUATED - check the skipped/errored samples above")
    print("   -> No diagnosis possible")

elif true_fail_count == 0:
    print("\n‚úÖ NO TRUE MODEL FAILURES!")
    if trunc_count > 0:
        print(f"   ‚Üí {trunc_count} truncation issues: Increase max_tokens to fix")
    print("   ‚Üí Model learned format correctly")
    print("   ‚Üí NO RETRAINING NEEDED")

elif adjusted_rate >= 90:
    print(f"\n‚úÖ HIGH COMPLIANCE ({adjusted_rate:.1f}%)")
    print(f"   ‚Üí {true_fail_count} true failures (acceptable)")
    print(f"   ‚Üí {trunc_count} truncation issues (increase tokens)")
    print("   ‚Üí SHIP IT (competition ready)")

elif adjusted_rate >= 75:
    print(f"\n‚ö†Ô∏è MODERATE COMPLIANCE ({adjusted_rate:.1f}%)")
    print(f"   ‚Üí {true_fail_count} true failures (investigate)")
    print("   ‚Üí BORDERLINE: Your decision to ship or retrain")

else:
    print(f"\n‚ùå LOW COMPLIANCE ({adjusted_rate:.1f}%)")
    print(f"   ‚Üí {true_fail_count} true failures (critical)")
    print("   ‚Üí RETRAIN RECOMMENDED")

# Show sample failures
if true_failures:
    print("\n" + "="*80)
    print("SAMPLE TRUE FAILURES (First 3)")
    print("="*80)

    for i, fail in enumerate(true_failures[:3]):
        print(f"\n--- Failure {i+1} ---")
        print(f"Domain: {fail['domain']}")
        print(f"Reason: {fail['failure_reason']}")
        print(f"Tokens: {fail['output_tokens']}/{MAX_TOKENS}")
        print(f"Output:\n{fail['output']}")