In [None]:
# Install Tunix and Qwix from their GitHub repositories, then uninstall the
# current Flax and reinstall it from the Flax GitHub repository.
# These are notebook shell commands (leading "!"), not plain Python.
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix
!pip uninstall -q flax -y && pip install -q git+https://github.com/google/flax

# Imports
import jax
import jax.numpy as jnp
from flax import nnx
import grain
import optax
from tunix.models.gemma3 import model as gemma_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.grpo.grpo_learner import GRPOConfig, GRPOLearner
from tunix.rl.rollout import base_rollout
from tunix.generate import tokenizer_adapter as tokenizer_lib
import qwix
from huggingface_hub import snapshot_download
import kagglehub
import csv
import re

# V1: TPU check
# jax.device_count() is never 0 (JAX falls back to CPU when no accelerator is
# present), so a "> 0" test cannot detect a missing TPU. Check the active
# backend instead, and raise rather than assert so the check survives `-O`.
print(f"JAX devices: {jax.devices()}")
if jax.default_backend() != "tpu":
    raise RuntimeError(
        f"No TPU found! (active JAX backend: {jax.default_backend()!r})"
    )

# Config (V4/V5 safe values)
# Examples per batch; also passed as both mini_batch_size and
# train_micro_batch_size in the RL training config.
TRAIN_MICRO_BATCH_SIZE = 2  # V4: Safe batch size
# Passed to GRPOConfig(num_iterations=...).
NUM_ITERATIONS = 1  # V5: Safe iteration count
# Used to truncate the training data and as max_steps for training.
NUM_BATCHES = 5000
# LoRA rank passed to qwix.LoraProvider.
RANK = 64
# Weights applied to the reasoning score and the answer score inside
# trajectory_reward.
REASONING_WEIGHT = 0.6
ANSWER_WEIGHT = 0.4

# Dataset
# Instruction that asks the model for the <reasoning>/<answer> layout that
# trajectory_reward scores; without it the tags would only appear by chance.
SYSTEM_PROMPT = (
    "You are given a problem. Think about the problem and provide your "
    "reasoning. Place it between <reasoning> and </reasoning>. Then, provide "
    "the final answer (i.e., just one numerical value) between <answer> and "
    "</answer>."
)
# Gemma chat-turn layout: one user turn, then an open model turn to complete.
PROMPT_TEMPLATE = (
    "<start_of_turn>user\n"
    "{system_prompt}\n\n"
    "{question}<end_of_turn>\n"
    "<start_of_turn>model\n"
)

kaggle_path = kagglehub.competition_download("google-tunix-hack")
data = []
# newline="" is what the csv module requires so that quoted fields containing
# line breaks are parsed correctly; the encoding is pinned instead of relying
# on the locale default.
with open(f"{kaggle_path}/main_train.csv", newline="", encoding="utf-8") as f:
    for row in csv.DictReader(f):
        data.append({
            # The GRPO learner takes the text to generate from out of the
            # "prompts" field; the remaining fields are forwarded to the
            # reward functions as keyword arguments.
            "prompts": PROMPT_TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT, question=row["question"]
            ),
            "question": row["question"],
            # Keep only the final answer that follows the "####" marker.
            "answer": row["answer"].split("####")[-1].strip(),
        })

# Batch first, then truncate, so NUM_BATCHES really counts batches (slicing the
# raw examples first would yield only NUM_BATCHES / TRAIN_MICRO_BATCH_SIZE).
dataset = grain.MapDataset.source(data).batch(TRAIN_MICRO_BATCH_SIZE)[:NUM_BATCHES]

# Model
# MODEL_ID and the gemma2_2b() config both refer to Gemma 2. The Gemma 2 model
# definition and safetensors loader live under tunix.models.gemma; the
# tunix.models.gemma3 modules aliased at the top of the file target Gemma 3 and
# are therefore not used here.
# NOTE(review): confirm these module paths against the installed Tunix revision.
from tunix.models.gemma import model as gemma2_model_lib
from tunix.models.gemma import params_safetensors as gemma2_params_lib

MODEL_ID = "google/gemma-2-2b-it"
# Skip the PyTorch *.pth weights; only the safetensors files are loaded below.
local_model_path = snapshot_download(repo_id=MODEL_ID, ignore_patterns=["*.pth"])
model_config = gemma2_model_lib.ModelConfig.gemma2_2b()
# 1 x 4 device mesh with axes ("fsdp", "tp"); needs at least 4 devices.
mesh = jax.make_mesh((1, 4), ("fsdp", "tp"))

with mesh:
    gemma2 = gemma2_params_lib.create_model_from_safe_tensors(local_model_path, model_config, mesh)

# LoRA
# Adapters are attached to the attention and MLP projections matched by
# module_path; the wrapped model is the trainable policy, while `gemma2` stays
# available as the frozen base model.
lora_provider = qwix.LoraProvider(module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj", rank=RANK, alpha=64.0)
lora_policy = qwix.apply_lora_to_model(gemma2, lora_provider, **gemma2.get_model_input())

# Constrain the policy state (including the newly created LoRA parameters) to
# the partition specs declared on the model, so it is laid out on `mesh` like
# the base weights.
with mesh:
    lora_state = nnx.state(lora_policy)
    lora_partition_specs = nnx.get_partition_spec(lora_state)
    nnx.update(lora_policy, jax.lax.with_sharding_constraint(lora_state, lora_partition_specs))

# Trajectory reward: scores the reasoning trace and the final answer together.
# Patterns are compiled once here rather than on every loop iteration.
_REASONING_RE = re.compile(r'<reasoning>(.*?)</reasoning>', re.DOTALL)
_ANSWER_RE = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)


def _to_number(text):
    """Parse *text* as a float, ignoring a leading ``$`` and thousands commas.

    Returns ``None`` when the text is not numeric.
    """
    cleaned = str(text).strip().replace(",", "").lstrip("$")
    try:
        return float(cleaned)
    except ValueError:
        return None


def _answers_match(predicted, reference):
    """Return True when the two answers are identical text or equal numbers."""
    if predicted == reference:
        return True
    predicted_value = _to_number(predicted)
    # A non-numeric prediction can only match through the exact comparison.
    return predicted_value is not None and predicted_value == _to_number(reference)


def trajectory_reward(prompts, completions, answer, **kwargs):
    """Score each completion on reasoning presence and answer correctness.

    Args:
        prompts: Prompts that produced the completions (unused).
        completions: Generated texts, expected to contain
            ``<reasoning>...</reasoning>`` and ``<answer>...</answer>``.
        answer: Reference answers, aligned one-to-one with ``completions``.
        **kwargs: Extra fields passed by the caller (ignored).

    Returns:
        A list with one float reward per (completion, reference) pair:
        ``(REASONING_WEIGHT * r + ANSWER_WEIGHT * a) * 3.0`` where ``r`` is
        0.5 when the reasoning block has more than 20 whitespace-separated
        words (else 0.0) and ``a`` is 1.0 when the extracted answer matches
        the reference (else 0.0).

    Note:
        Because ``r`` tops out at 0.5, weights of 0.6 / 0.4 give maximum
        contributions of 0.9 (reasoning) and 1.2 (answer), not a literal
        60/40 split of the total.
    """
    rewards = []
    for comp, ref in zip(completions, answer):
        # Reasoning quality: a sufficiently long reasoning block earns credit.
        reasoning = _REASONING_RE.search(comp)
        r_score = 0.5 if reasoning and len(reasoning.group(1).split()) > 20 else 0.0

        # Answer correctness: exact text match, or numerically equal values
        # (e.g. "72.0" vs "72", "$1,000" vs "1000").
        ans = _ANSWER_RE.search(comp)
        a_score = 1.0 if ans and _answers_match(ans.group(1).strip(), ref) else 0.0

        reward = (REASONING_WEIGHT * r_score + ANSWER_WEIGHT * a_score) * 3.0
        rewards.append(reward)
    return rewards

# GRPO
# Token budgets for rollouts: prompt length limit and generation length limit.
MAX_PROMPT_LENGTH = 256
MAX_TOKENS_TO_GENERATE = 768
# The trainer config is given an explicit evaluation cadence. The training call
# later in the file passes None as its second argument, so this value is not
# expected to trigger any evaluation.
# NOTE(review): confirm whether eval_every_n_steps is mandatory in the
# installed Tunix revision.
EVAL_EVERY_N_STEPS = 100

optimizer = optax.adamw(learning_rate=5e-5)
cluster_config = rl_cluster_lib.ClusterConfig(
    # Actor, reference and rollout roles all share the same device mesh.
    role_to_mesh={
        rl_cluster_lib.Role.ACTOR: mesh,
        rl_cluster_lib.Role.REFERENCE: mesh,
        rl_cluster_lib.Role.ROLLOUT: mesh,
    },
    rollout_engine='vanilla',
    training_config=rl_cluster_lib.RLTrainingConfig(
        actor_optimizer=optimizer,
        eval_every_n_steps=EVAL_EVERY_N_STEPS,
        max_steps=NUM_BATCHES,
        mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
        train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
        checkpoint_root_directory="/kaggle/working/checkpoints",
    ),
    rollout_config=base_rollout.RolloutConfig(
        max_tokens_to_generate=MAX_TOKENS_TO_GENERATE,
        max_prompt_length=MAX_PROMPT_LENGTH,
    ),
)

# num_generations completions are sampled per prompt; beta and epsilon are
# passed through to GRPOConfig unchanged.
grpo_config = GRPOConfig(num_generations=2, num_iterations=NUM_ITERATIONS, beta=0.08, epsilon=0.2)

# NOTE(review): verify that this GCS path resolves and holds the tokenizer that
# matches the Gemma 2 checkpoint loaded above.
tokenizer = tokenizer_lib.Tokenizer(tokenizer_path="gs://gemma-data/tokenizers/tokenizer.model")
# The LoRA-wrapped model is trained as the actor; the base model is the reference.
rl_cluster = rl_cluster_lib.RLCluster(actor=lora_policy, reference=gemma2, tokenizer=tokenizer, cluster_config=cluster_config)
grpo_trainer = GRPOLearner(rl_cluster=rl_cluster, reward_fns=[trajectory_reward], grpo_config=grpo_config)

# Train!
# The status messages originally contained mojibake (UTF-8 emoji bytes decoded
# as cp1252); they are restored to the intended characters.
print("🚀 Starting training with trajectory reward (60% reasoning + 40% answer)...")
with mesh:
    # The second positional argument (presumably the evaluation dataset) is
    # left as None, so only the training data is supplied.
    grpo_trainer.train(dataset, None)

print("✅ Training complete!")