# Unified Reasoning RL Training

This notebook implements a **two-phase training pipeline** for mathematical reasoning:

## Training Pipeline

```
Phase 1: SFT Cold Start (Required)
    └── Teaches model the <think>...</think> format
    └── Train for 1 epoch on format examples
    └── Saves to: checkpoints/sft/
            │
            ▼
Phase 2: RL Training (Choose Algorithm)
    └── Loads SFT checkpoint as base
    └── PPO / GRPO / Dr.GRPO / GSPO / DAPO / GRPO-LEAD
    └── Saves to: checkpoints/<algorithm>/
```

## Why SFT First?
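The Phase 2 reward is binary: a completion scores 1.0 only when the answer inside its `\boxed{...}` exactly matches the ground truth. The prompt pre-fills `<think>`, so the model is expected to reason, close the block with `</think>`, and finish with `\boxed{answer}`. A model that has not been trained on this format may not follow it reliably, in which case most rollouts score 0 and the group-relative advantages carry little signal. The SFT "cold start" teaches the format first; RL then optimizes for correctness.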
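A minimal sketch of why the format matters, assuming the binary exact-match reward that `compute_rewards` implements in Section 4.4: only a completion with a closed `\boxed{...}` whose contents match the ground truth earns reward. The helper names `extract_boxed` and `binary_reward` are illustrative; the brace matching mirrors the notebook's `extract_answer`.

```python
import re


def extract_boxed(text: str) -> str:
    """Return the contents of the first \\boxed{...}, handling nested braces."""
    match = re.search(r"\\boxed\{", text)
    if not match:
        return ""
    depth, idx = 1, match.end()
    while idx < len(text) and depth > 0:
        if text[idx] == "{":
            depth += 1
        elif text[idx] == "}":
            depth -= 1
        idx += 1
    # Only accept a box whose braces are fully closed
    return text[match.end():idx - 1] if depth == 0 else ""


def binary_reward(completion: str, ground_truth: str) -> float:
    """1.0 if the boxed answer matches the ground truth (spaces ignored), else 0.0."""
    pred = extract_boxed(completion).strip().replace(" ", "")
    return 1.0 if pred == ground_truth.strip().replace(" ", "") else 0.0


formatted = "Half of 48 is 24, and 48 + 24 = 72.\n</think>\n\\boxed{72}"
unformatted = "Half of 48 is 24, and 48 + 24 = 72. The answer is 72."
print(binary_reward(formatted, "72"))    # 1.0
print(binary_reward(unformatted, "72"))  # 0.0
```

If every completion in a group looks like `unformatted`, all rewards are 0, the group's reward standard deviation is 0, and the group-relative advantage is 0 for every sample, so that prompt contributes nothing to the policy update.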

---

## 1. Configuration

In [None]:
#@title Training Configuration {display-mode: "form"}

#@markdown ### Select Training Phase
TRAINING_PHASE = "Phase 1: SFT Cold Start" #@param ["Phase 1: SFT Cold Start", "Phase 2: RL Training"]

#@markdown ### RL Algorithm (only used in Phase 2)
#@markdown **Recommended:** GRPO (stable baseline) or DR.GRPO (length-corrected)
RL_ALGORITHM = "GRPO" #@param ["PPO", "GRPO", "DR.GRPO", "GSPO", "DAPO", "GRPO-LEAD"]

#@markdown ---
#@markdown ### Model & Data
MODEL_NAME = "Qwen/Qwen2.5-Math-1.5B-Instruct" #@param {type:"string"}
DATASET_NAME = "openai/gsm8k" #@param ["openai/gsm8k", "HuggingFaceH4/openr1-math-220k"]

#@markdown ---
#@markdown ### Hyperparameters
#@markdown **Memory Guide (with gradient checkpointing):**
#@markdown - T4 (16GB): batch=1, group=2
#@markdown - L4 (22GB): batch=1, group=4
#@markdown - A100 40GB: batch=2, group=8
#@markdown - A100 80GB: batch=8, group=8
#@markdown The defaults below (batch=8, group=8) match the A100 80GB row; lower them on smaller GPUs.
BATCH_SIZE = 8 #@param {type:"integer"}
GROUP_SIZE = 8 #@param {type:"integer"}
MAX_SAMPLES = 5000 #@param {type:"integer"}
EPOCHS = 1 #@param {type:"integer"}
SAVE_STEPS = 100 #@param {type:"integer"}

#@markdown ### Phase-Specific Settings (auto-configured below)
MAX_NEW_TOKENS = 384 #@param {type:"integer"}
PPO_EPOCHS = 1 #@param {type:"integer"}

#@markdown ---
#@markdown ### Google Drive Settings
PROJECT_NAME = "unified-reasoning-rl" #@param {type:"string"}
RESUME_FROM_CHECKPOINT = True #@param {type:"boolean"}

#@markdown ### Logging
USE_WANDB = False #@param {type:"boolean"}
WANDB_PROJECT = "UnifiedRL" #@param {type:"string"}

# ============================================================
# AUTO-CONFIGURATION BASED ON PHASE
# ============================================================
IS_SFT_PHASE = "SFT" in TRAINING_PHASE
ALGORITHM = "SFT" if IS_SFT_PHASE else RL_ALGORITHM

# Phase-specific learning rates (key for good performance!)
if IS_SFT_PHASE:
    LEARNING_RATE = 2e-5   # Higher LR for SFT (format learning is easy)
    _MAX_NEW_TOKENS = 512  # Not used in SFT but set for reference
    _PPO_EPOCHS = 1
else:
    LEARNING_RATE = 5e-6   # Lower LR for RL (stability)
    _MAX_NEW_TOKENS = MAX_NEW_TOKENS  # 384 recommended for reasoning
    _PPO_EPOCHS = PPO_EPOCHS

# Derived paths
DRIVE_BASE_PATH = f"/content/drive/MyDrive/Colab Notebooks/{PROJECT_NAME}"
SFT_CHECKPOINT_DIR = f"{DRIVE_BASE_PATH}/checkpoints/sft"
RL_CHECKPOINT_DIR = f"{DRIVE_BASE_PATH}/checkpoints/{RL_ALGORITHM.lower().replace('.', '_')}"
CHECKPOINT_DIR = SFT_CHECKPOINT_DIR if IS_SFT_PHASE else RL_CHECKPOINT_DIR

# ============================================================
# DISPLAY CONFIGURATION
# ============================================================
print("=" * 60)
print(f"üéØ TRAINING PHASE: {TRAINING_PHASE}")
print("=" * 60)
print(f"Algorithm: {ALGORITHM}")
print(f"Model: {MODEL_NAME}")
print(f"Dataset: {DATASET_NAME}")
print("-" * 60)
print("HYPERPARAMETERS:")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Batch Size: {BATCH_SIZE} | Group Size: {GROUP_SIZE}")
print(f"  Effective sequences/step: {BATCH_SIZE * (1 if IS_SFT_PHASE else GROUP_SIZE)}")  # SFT uses one sequence per example
print(f"  Max New Tokens: {_MAX_NEW_TOKENS}")
print(f"  PPO Epochs: {_PPO_EPOCHS}")
print(f"  Total Samples: {MAX_SAMPLES}")
print(f"  Training Epochs: {EPOCHS}")
print("-" * 60)
print(f"Checkpoints: {CHECKPOINT_DIR}")
if not IS_SFT_PHASE:
    print(f"SFT Base: {SFT_CHECKPOINT_DIR}/final")
print("=" * 60)

# Performance tips
if IS_SFT_PHASE:
    print("\n💡 SFT Tips:")
    print("   • 1 epoch is usually enough (avoid overfitting)")
    print("   • Loss should decrease to ~1.5-2.5")
    print("   • If loss < 1.0, you may be overfitting")
else:
    print("\n💡 RL Tips:")
    print("   • Watch accuracy: it should increase over time")
    print("   • KL divergence should stay < 0.1")
    print("   • If accuracy is stuck at 0%, check the reward function")
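The auto-configuration above boils down to a few derived values. The hypothetical `resolve_config` helper below (not part of the notebook) recomputes them as a pure function, including the upper bound on tokens generated per RL step (`BATCH_SIZE * GROUP_SIZE * MAX_NEW_TOKENS`), which is a useful number when sizing `BATCH_SIZE` and `GROUP_SIZE` against the memory guide. It assumes SFT processes one sequence per example, as the trainer's `group_size=1` setting for SFT does.

```python
def resolve_config(phase: str, rl_algorithm: str, batch_size: int, group_size: int,
                   max_new_tokens: int, base_path: str = "checkpoints") -> dict:
    """Hypothetical mirror of the notebook's phase auto-configuration."""
    is_sft = "SFT" in phase
    # SFT trains on one text per example; RL samples `group_size` completions per prompt
    seqs_per_step = batch_size * (1 if is_sft else group_size)
    subdir = "sft" if is_sft else rl_algorithm.lower().replace(".", "_")
    return {
        "algorithm": "SFT" if is_sft else rl_algorithm,
        "learning_rate": 2e-5 if is_sft else 5e-6,
        "sequences_per_step": seqs_per_step,
        # Upper bound on generated tokens per RL step (SFT generates nothing)
        "max_generated_tokens_per_step": 0 if is_sft else seqs_per_step * max_new_tokens,
        "checkpoint_dir": f"{base_path}/{subdir}",
    }


rl = resolve_config("Phase 2: RL Training", "DR.GRPO", batch_size=8, group_size=8, max_new_tokens=384)
print(rl["checkpoint_dir"])                 # checkpoints/dr_grpo
print(rl["max_generated_tokens_per_step"])  # 24576
```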

## 2. Mount Google Drive

In [None]:
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# Create project directories
os.makedirs(DRIVE_BASE_PATH, exist_ok=True)
os.makedirs(SFT_CHECKPOINT_DIR, exist_ok=True)
os.makedirs(RL_CHECKPOINT_DIR, exist_ok=True)
os.makedirs(f"{DRIVE_BASE_PATH}/logs", exist_ok=True)

print(f"Project directory: {DRIVE_BASE_PATH}")
print(f"SFT checkpoints: {SFT_CHECKPOINT_DIR}")
print(f"RL checkpoints: {RL_CHECKPOINT_DIR}")

# Check if SFT has been completed (for Phase 2)
SFT_COMPLETED = os.path.exists(f"{SFT_CHECKPOINT_DIR}/final")
if not IS_SFT_PHASE and not SFT_COMPLETED:
    print("\n" + "!" * 60)
    print("WARNING: SFT checkpoint not found!")
    print("Please run Phase 1 (SFT Cold Start) first.")
    print("!" * 60)
elif SFT_COMPLETED:
    print(f"\n✓ SFT checkpoint found: {SFT_CHECKPOINT_DIR}/final")

## 3. Install Dependencies

In [None]:
%%capture
# First, fix NumPy version (must be done before other installs)
!pip uninstall numpy -y
!pip install "numpy<2.0.0"

# Install core dependencies
!pip install -q torch torchvision torchaudio
# Quote version specifiers so the shell doesn't parse ">" as output redirection
!pip install -q "transformers>=4.40.0" "peft>=0.10.0" "accelerate>=0.30.0"
!pip install -q "datasets>=2.18.0" scipy pyyaml tqdm wandb
!pip install -q "bitsandbytes>=0.43.0" safetensors

# Flash Attention 2 (OPTIONAL - skip if you want to start training faster)
# Uncomment the line below to install (takes ~15-20 min to compile on Colab)
# !pip install flash-attn --no-build-isolation

In [None]:
# Verify installation
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    raise RuntimeError("GPU not available! Please enable GPU runtime.")

## 4. Source Code

In [None]:
#@title 4.1 Utility Functions {display-mode: "form"}

import os
import random
import numpy as np
import torch
import logging
import glob
import re
import shutil

def seed_everything(seed: int = 42):
    """Sets the seed for reproducibility."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_logger(name: str, log_dir: str = None):
    """Configures a standardized logger."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    if not logger.handlers:
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
            fh = logging.FileHandler(os.path.join(log_dir, "training.log"))
            fh.setFormatter(formatter)
            logger.addHandler(fh)
    return logger

def find_latest_checkpoint(checkpoint_dir: str):
    """Find the latest checkpoint in a directory."""
    if not os.path.exists(checkpoint_dir):
        return None, 0
    checkpoints = glob.glob(os.path.join(checkpoint_dir, "step_*"))
    if not checkpoints:
        return None, 0
    step_pattern = re.compile(r'step_(\d+)')
    steps = []
    for ckpt in checkpoints:
        match = step_pattern.search(ckpt)
        if match:
            steps.append((int(match.group(1)), ckpt))
    if not steps:
        return None, 0
    steps.sort(reverse=True)
    return steps[0][1], steps[0][0]

print("✓ Utilities loaded")
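Numeric parsing matters in `find_latest_checkpoint`: sorting directory names as strings would rank `step_900` above `step_1000`. The sketch below exercises a self-contained copy of the helper (with the regex applied to the directory name only, so a parent path containing `step_` cannot match) against a temporary directory; the directory names are made up for the demo.

```python
import glob
import os
import re
import tempfile


def find_latest_checkpoint(checkpoint_dir: str):
    """Return (path, step) of the highest-numbered step_<N> directory, else (None, 0)."""
    if not os.path.exists(checkpoint_dir):
        return None, 0
    steps = []
    for ckpt in glob.glob(os.path.join(checkpoint_dir, "step_*")):
        match = re.search(r"step_(\d+)", os.path.basename(ckpt))
        if match:
            steps.append((int(match.group(1)), ckpt))
    if not steps:
        return None, 0
    step, path = max(steps)  # compares the integer step first
    return path, step


with tempfile.TemporaryDirectory() as root:
    for name in ["step_900", "step_1000", "step_100", "best", "final"]:
        os.makedirs(os.path.join(root, name))
    path, step = find_latest_checkpoint(root)
    print(os.path.basename(path), step)  # step_1000 1000
```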

In [None]:
#@title 4.2 Dataset {display-mode: "form"}

from datasets import load_dataset
from torch.utils.data import Dataset

class MathReasoningDataset(Dataset):
    """Wrapper for GSM8K/Math datasets."""
    
    def __init__(self, tokenizer, split="train", max_samples=None, mode="rl", dataset_name="openai/gsm8k"):
        self.tokenizer = tokenizer
        self.mode = mode
        self.dataset_name = dataset_name

        if "openr1" in dataset_name.lower():
            self.data = load_dataset(dataset_name, split=split)
            self.problem_key = "problem"
            self.solution_key = "solution"
            self.answer_key = "answer"
        elif "gsm8k" in dataset_name.lower():
            self.data = load_dataset(dataset_name, "main", split=split)
            self.problem_key = "question"
            self.solution_key = "answer"
            self.answer_key = "answer"
        else:
            self.data = load_dataset(dataset_name, split=split)
            self.problem_key = "problem"
            self.solution_key = "solution"
            self.answer_key = "answer"

        if max_samples:
            max_samples = min(max_samples, len(self.data))
            self.data = self.data.select(range(max_samples))

    def __len__(self):
        return len(self.data)

    def _extract_gsm8k_answer(self, solution_text):
        if "####" in solution_text:
            return solution_text.split("####")[-1].strip()
        return solution_text.strip()

    def __getitem__(self, idx):
        item = self.data[idx]
        problem = item.get(self.problem_key, "")

        prompt = (
            "<|im_start|>system\n"
            "Please reason step by step and put your final answer within \\boxed{}.<|im_end|>\n"
            "<|im_start|>user\n"
            f"{problem}<|im_end|>\n"
            "<|im_start|>assistant\n"
            "<think>"
        )

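        if self.mode == 'sft':
            solution = item.get(self.solution_key, "")
            # GSM8K stores the final answer after "####" in the solution; other datasets have a separate answer field
            if "gsm8k" in self.dataset_name.lower():
                final_answer = self._extract_gsm8k_answer(solution)
            else:
                final_answer = str(item.get(self.answer_key, ""))
            # Format the solution inside think tags, followed by the boxed answer
            full_text = prompt + "\n" + solution + "\n</think>\n\\boxed{" + final_answer + "}<|im_end|>"
            return {"text": full_text}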

        answer = self._extract_gsm8k_answer(item.get(self.answer_key, "")) if "gsm8k" in self.dataset_name.lower() else item.get(self.answer_key, "")
        return {"prompt": prompt, "ground_truth": answer}

def collate_fn(batch):
    if "text" in batch[0]:
        return [b["text"] for b in batch]
    return {"prompts": [b["prompt"] for b in batch], "ground_truths": [b["ground_truth"] for b in batch]}

print("✓ Dataset loaded")
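To see what one SFT example looks like, the sketch below reproduces the prompt template and GSM8K answer extraction from `MathReasoningDataset` as free functions. The function names and the sample record are illustrative, not taken from the dataset. Note that the `#### 72` line stays inside the `<think>` block of the SFT target, and the extracted answer is repeated in `\boxed{}` after `</think>`.

```python
def extract_gsm8k_answer(solution_text: str) -> str:
    """GSM8K puts the final numeric answer after a '####' marker."""
    if "####" in solution_text:
        return solution_text.split("####")[-1].strip()
    return solution_text.strip()


def build_prompt(problem: str) -> str:
    """ChatML prompt used by MathReasoningDataset; the assistant turn is pre-filled with <think>."""
    return (
        "<|im_start|>system\n"
        "Please reason step by step and put your final answer within \\boxed{}.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{problem}<|im_end|>\n"
        "<|im_start|>assistant\n"
        "<think>"
    )


# Illustrative GSM8K-style record: calculator annotations in <<...>>, answer after ####
solution = "She sold 48 / 2 = <<48/2=24>>24 clips in May.\nIn total she sold 48 + 24 = <<48+24=72>>72 clips.\n#### 72"
answer = extract_gsm8k_answer(solution)
sft_text = build_prompt("How many clips did she sell?") + "\n" + solution + "\n</think>\n\\boxed{" + answer + "}<|im_end|>"
print(answer)                                      # 72
print(sft_text.endswith("\\boxed{72}<|im_end|>"))  # True
```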

In [None]:
#@title 4.3 Model {display-mode: "form"}

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training

# Auto-detect Flash Attention
FLASH_ATTN_AVAILABLE = False
try:
    import flash_attn
    FLASH_ATTN_AVAILABLE = True
    print("✓ Flash Attention 2 available")
except ImportError:
    print("⚠ Flash Attention not installed - using eager attention (slower)")

class UnifiedPolicyModel(nn.Module):
    """Model with 4-bit quantization + LoRA using standard transformers + bitsandbytes."""
    
    def __init__(self, model_name: str, algo: str, max_seq_length: int = 2048, load_in_4bit: bool = True):
        super().__init__()
        self.algo = algo.upper()
        self.device = None
        self.model_name = model_name

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = "left"

        # Auto-select attention implementation
        attn_impl = "flash_attention_2" if FLASH_ATTN_AVAILABLE else "eager"
        print(f"Using attention: {attn_impl}")

        # 4-bit quantization config
        if load_in_4bit:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=bnb_config,
                device_map="auto",
                trust_remote_code=True,
                attn_implementation=attn_impl,
                torch_dtype=torch.bfloat16,
            )
            # Enable gradient checkpointing - required to fit in memory
            self.model = prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=True)
            print("✓ Gradient checkpointing enabled")
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
                attn_implementation=attn_impl,
            )
            self.model.gradient_checkpointing_enable()

        # Add LoRA
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=16,
            lora_alpha=32,
            lora_dropout=0.0,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            bias="none",
        )
        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()

        # Critic head for PPO
        self.critic = None
        if self.algo == 'PPO':
            hidden_size = self.model.config.hidden_size
            self.critic = nn.Linear(hidden_size, 1).to(torch.bfloat16)

    def to(self, device):
        self.device = device
        if self.critic is not None:
            self.critic = self.critic.to(device)
        return self

    def forward(self, input_ids, attention_mask=None):
        output_hidden = (self.critic is not None)
        if attention_mask is None:
            attention_mask = (input_ids != self.tokenizer.pad_token_id).long()
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden,
            use_cache=False,
        )
        logits = outputs.logits
        values = None
        if self.critic is not None:
            values = self.critic(outputs.hidden_states[-1]).squeeze(-1)
        return logits, values

    def generate(self, **kwargs):
        return self.model.generate(**kwargs)

    def save_pretrained(self, path):
        os.makedirs(path, exist_ok=True)
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)
        if self.critic:
            torch.save(self.critic.state_dict(), f"{path}/critic.pt")
        with open(f"{path}/training_info.txt", "w") as f:
            f.write(f"algorithm: {self.algo}\n")
            f.write(f"base_model: {self.model_name}\n")

    def load_adapter_from_checkpoint(self, path):
        """Load LoRA adapter weights from a checkpoint."""
        from peft import set_peft_model_state_dict
        import safetensors.torch
        
        adapter_safetensors = os.path.join(path, "adapter_model.safetensors")
        adapter_bin = os.path.join(path, "adapter_model.bin")
        
        if os.path.exists(adapter_safetensors):
            state_dict = safetensors.torch.load_file(adapter_safetensors)
            set_peft_model_state_dict(self.model, state_dict)
            print(f"✓ Loaded adapter from {path}")
        elif os.path.exists(adapter_bin):
            state_dict = torch.load(adapter_bin, map_location="cuda")
            set_peft_model_state_dict(self.model, state_dict)
            print(f"✓ Loaded adapter from {path}")
        else:
            print(f"⚠ No adapter found at {path}")

        if self.critic and os.path.exists(f"{path}/critic.pt"):
            self.critic.load_state_dict(torch.load(f"{path}/critic.pt", map_location="cuda"))
            print(f"✓ Loaded critic from {path}")
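The `r=16` LoRA config above adapts all seven attention and MLP projections. The trainable-parameter count that `print_trainable_parameters()` reports can be estimated by hand: each adapted `nn.Linear(d_in, d_out)` gains an `A` matrix (`r x d_in`) and a `B` matrix (`d_out x r`), i.e. `r * (d_in + d_out)` parameters (no extra bias, since `bias="none"`). The dimensions below are assumptions for a Qwen2.5-1.5B-style model; read the real values from `policy.model.config` before relying on the number.

```python
def lora_param_count(num_layers: int, rank: int, proj_shapes: dict) -> int:
    """Sum r * (d_in + d_out) over the adapted projections, times the number of layers."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in proj_shapes.values())
    return num_layers * per_layer


# Assumed dims (verify against model.config): hidden 1536, MLP 8960,
# k/v projections of 2 KV heads x 128 head_dim = 256, and 28 decoder layers.
hidden, mlp, kv = 1536, 8960, 256
shapes = {
    "q_proj": (hidden, hidden), "k_proj": (hidden, kv), "v_proj": (hidden, kv),
    "o_proj": (hidden, hidden), "gate_proj": (hidden, mlp), "up_proj": (hidden, mlp),
    "down_proj": (mlp, hidden),
}
print(lora_param_count(num_layers=28, rank=16, proj_shapes=shapes))  # 18464768
```

Under these assumptions, about 18.5M parameters are trainable, a small fraction of the frozen 4-bit base model.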

In [None]:
#@title 4.4 Trainer {display-mode: "form"}

import torch
import torch.nn.functional as F
from torch.optim import AdamW
import re
import gc

class UnifiedReasoningTrainer:
    def __init__(self, policy_model, config, device):
        self.policy = policy_model
        self.config = config
        self.algo = config['algo'].upper()
        self.G = config['group_size']
        self.device = device
        self.max_new_tokens = config.get('max_new_tokens', 384)
        
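        # Optimize only trainable (LoRA) parameters; the quantized base weights are frozen
        self.optimizer = AdamW([p for p in self.policy.model.parameters() if p.requires_grad], lr=float(config['learning_rate']))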
        self.critic_optimizer = None
        if self.policy.critic:
            self.critic_optimizer = AdamW(self.policy.critic.parameters(), lr=1e-4)
        self.ppo_epochs = config.get('ppo_epochs', 2)

    def extract_answer(self, text):
        match = re.search(r'\\boxed\{', text)
        if not match:
            return ""
        start_idx = match.end()
        brace_count = 1
        idx = start_idx
        while idx < len(text) and brace_count > 0:
            if text[idx] == '{': brace_count += 1
            elif text[idx] == '}': brace_count -= 1
            idx += 1
        return text[start_idx:idx-1] if brace_count == 0 else ""

    def compute_rewards(self, completions, ground_truths):
        rewards = []
        for comp, gt in zip(completions, ground_truths):
            pred = self.extract_answer(comp).strip().replace(" ", "")
            gt_norm = str(gt).strip().replace(" ", "")
            rewards.append(1.0 if pred == gt_norm else 0.0)
        return torch.tensor(rewards, device=self.device, dtype=torch.float32)

    def compute_gae(self, rewards, values, gamma=0.99, lam=0.95):
        advantages = torch.zeros_like(rewards)
        last_gae = 0
        for t in reversed(range(rewards.size(1))):
            next_val = values[:, t + 1] if t + 1 < rewards.size(1) else 0.0
            delta = rewards[:, t] + gamma * next_val - values[:, t]
            last_gae = delta + gamma * lam * last_gae
            advantages[:, t] = last_gae
        return advantages, advantages + values

    def compute_kl(self, log_probs, old_log_probs, mask):
        return (0.5 * ((log_probs - old_log_probs) ** 2 * mask).sum() / mask.sum()).item()

    def compute_entropy(self, logits, mask):
        probs = F.softmax(logits, dim=-1)
        ent = -(probs * F.log_softmax(logits, dim=-1)).sum(dim=-1)
        return (ent * mask).sum().item() / mask.sum().item()

    def loss_ppo(self, lp, olp, adv, ret, val, mask):
        ratio = torch.exp(lp - olp)
        s1, s2 = ratio * adv, torch.clamp(ratio, 0.8, 1.2) * adv
        pl = -(torch.min(s1, s2) * mask).sum() / mask.sum()
        vl = ((val - ret) ** 2 * mask).sum() / mask.sum()
        return pl + 0.5 * vl

    def loss_grpo(self, lp, olp, adv, mask):
        ratio = torch.exp(lp - olp)
        s1, s2 = ratio * adv, torch.clamp(ratio, 0.8, 1.2) * adv
        return (-torch.min(s1, s2) * mask).sum() / mask.sum()

    def loss_dr_grpo(self, lp, olp, adv, mask):
        B = adv.shape[0] // self.G
        lens = mask.sum(1).float().view(B, self.G)
        scale = (lens / (lens.mean(1, keepdim=True) + 1e-6)).view(-1, 1).expand_as(lp)
        ratio = torch.exp(lp - olp)
        loss = -torch.min(ratio * adv, torch.clamp(ratio, 0.8, 1.2) * adv) * scale
        return (loss * mask).sum() / mask.sum()

    def loss_gspo(self, lp, olp, adv, mask):
        ld = (lp - olp) * mask
        rho = torch.exp(ld.sum(1) / (mask.sum(1) + 1e-6)).unsqueeze(-1)
        return -torch.min(rho * adv, torch.clamp(rho, 0.8, 1.2) * adv).mean()

    def loss_dapo(self, lp, olp, rg, adv, mask):
        vm = rg.std(1) > 0
        ratio = torch.exp(lp - olp)
        upper = torch.where(adv > 0, 1.28, 1.20)
        s2 = torch.clamp(ratio, 0.8, upper) * adv
        loss = -torch.min(ratio * adv, s2) * mask
        ve = vm.repeat_interleave(self.G).view(-1, 1).expand_as(loss)
        if ve.sum() == 0: return torch.tensor(0.0, device=self.device, requires_grad=True)
        return (loss * ve).sum() / (mask * ve).sum()

    def loss_grpo_lead(self, lp, olp, rew, mask):
        lens = mask.sum(1).float()
        corr = rew == 1.0
        if corr.any():
            z = (lens - lens[corr].mean()) / (lens[corr].std() + 1e-6)
            rew = torch.where(corr, rew * torch.exp(-0.1 * z.abs()), rew)
        B = rew.shape[0] // self.G
        rg = rew.view(B, self.G)
        pr = rg.mean(1).repeat_interleave(self.G)
        dw = (2.0 - pr).view(-1, 1).expand_as(lp)
        adv = ((rg - rg.mean(1, keepdim=True)) / (rg.std(1, keepdim=True) + 1e-6)).view(-1, 1).expand_as(lp) * dw
        ratio = torch.exp(lp - olp)
        return (-torch.min(ratio * adv, torch.clamp(ratio, 0.8, 1.2) * adv) * mask).sum() / mask.sum()

    def train_step(self, batch):
        if self.algo == 'SFT':
            texts = batch
            inputs = self.policy.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=2048).to(self.device)
            out = self.policy.model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, use_cache=False)
            logits = out.logits[..., :-1, :].clone()
            labels = inputs.input_ids[..., 1:].clone()
            am = inputs.attention_mask[..., 1:]
            labels = labels.masked_fill(am == 0, -100)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            return {'loss': loss.item(), 'reward': 0.0, 'kl_divergence': 0.0, 'entropy': 0.0, 'avg_response_length': 0.0, 'accuracy': 0.0}

        prompts, gts = batch['prompts'], batch['ground_truths']
        pe = [p for p in prompts for _ in range(self.G)]
        ge = [g for g in gts for _ in range(self.G)]
        # tokenizer.padding_side is already "left" (set in UnifiedPolicyModel); passing it per call requires newer transformers
        inputs = self.policy.tokenizer(pe, return_tensors="pt", padding=True).to(self.device)
        
        # Generation (no grad)
        with torch.no_grad():
            self.policy.model.eval()
            outs = self.policy.generate(**inputs, max_new_tokens=self.max_new_tokens, do_sample=True, temperature=0.8, use_cache=True)
        
        pl = inputs.input_ids.shape[1]
        cids = outs[:, pl:]
        am = (cids != self.policy.tokenizer.pad_token_id).float()
        dec = self.policy.tokenizer.batch_decode(cids, skip_special_tokens=True)
        rew = self.compute_rewards(dec, ge)
        
        # Free generation cache
        del inputs
        torch.cuda.empty_cache()
        
        # Compute old log probs (no grad)
        with torch.no_grad():
            lg, ov = self.policy(outs)
            lg = lg[:, pl-1:-1, :].clone()
            olp = -F.cross_entropy(lg.reshape(-1, lg.size(-1)), cids.reshape(-1), reduction='none').view(cids.shape)
            if ov is not None: ov = ov[:, pl-1:-1].clone()
            del lg  # Free memory
            torch.cuda.empty_cache()

        B = rew.shape[0] // self.G
        if self.algo == 'PPO':
            sr = torch.zeros_like(olp)
            li = am.sum(1).long() - 1
            for i, idx in enumerate(li):
                if idx >= 0: sr[i, idx] = rew[i]
            adv, ret = self.compute_gae(sr, ov)
            adv = (adv - adv.mean()) / (adv.std() + 1e-8)
        else:
            rg = rew.view(B, self.G)
            adv = ((rg - rg.mean(1, keepdim=True)) / (rg.std(1, keepdim=True) + 1e-6)).view(-1, 1).expand_as(olp)
            ret = None

        # Training loop
        self.policy.model.train()
        for _ in range(self.ppo_epochs):
            lg, val = self.policy(outs)
            lg = lg[:, pl-1:-1, :].clone()
            if val is not None: val = val[:, pl-1:-1]
            lp = -F.cross_entropy(lg.reshape(-1, lg.size(-1)), cids.reshape(-1), reduction='none').view(cids.shape)

            if self.algo == 'PPO': loss = self.loss_ppo(lp, olp, adv, ret, val, am)
            elif self.algo == 'GRPO': loss = self.loss_grpo(lp, olp, adv, am)
            elif self.algo == 'DR.GRPO': loss = self.loss_dr_grpo(lp, olp, adv, am)
            elif self.algo == 'GSPO': loss = self.loss_gspo(lp, olp, adv, am)
            elif self.algo == 'DAPO': loss = self.loss_dapo(lp, olp, rew.view(B, self.G), adv, am)
            elif self.algo == 'GRPO-LEAD': loss = self.loss_grpo_lead(lp, olp, rew, am)
            else: raise ValueError(f"Unknown: {self.algo}")

            self.optimizer.zero_grad()
            if self.critic_optimizer: self.critic_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy.model.parameters(), 1.0)
            self.optimizer.step()
            if self.critic_optimizer: self.critic_optimizer.step()
            
            # Free intermediate tensors
            del lg
            torch.cuda.empty_cache()

        metrics = {'loss': loss.item(), 'reward': rew.mean().item(), 'kl_divergence': self.compute_kl(lp, olp, am),
                   'entropy': 0.0, 'avg_response_length': am.sum(1).float().mean().item(),
                   'accuracy': (rew == 1.0).float().mean().item()}
        
        # Cleanup
        del outs, cids, olp, lp, adv, am, rew
        gc.collect()
        torch.cuda.empty_cache()
        
        return metrics

print("✓ Trainer loaded (with memory optimization)")
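All group-based losses above share two ingredients: rewards standardized within each group of `G` completions for the same prompt, and a clipped importance ratio. The pure-Python sketch below (no PyTorch; function names are illustrative) reproduces the GRPO advantage branch of `train_step` and the per-token term of `loss_grpo`. A group where every sample has the same reward, like the all-correct second group, gets zero advantage and contributes no gradient; DAPO's `vm = rg.std(1) > 0` mask drops such groups explicitly.

```python
import math
import statistics


def group_advantages(rewards, group_size, eps=1e-6):
    """Standardize each reward within its group of `group_size` samples.

    Uses the sample standard deviation (n - 1), matching torch.Tensor.std's default.
    """
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = statistics.mean(group)
        std = statistics.stdev(group)
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


def clipped_term(log_prob, old_log_prob, advantage, low=0.8, high=1.2):
    """Per-token surrogate: min(ratio * A, clip(ratio, low, high) * A)."""
    ratio = math.exp(log_prob - old_log_prob)
    clipped = min(max(ratio, low), high)
    return min(ratio * advantage, clipped * advantage)


# Two prompts, G = 4 samples each; the second group is all-correct
rewards = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]
adv = group_advantages(rewards, group_size=4)
print([round(a, 3) for a in adv[:4]])  # [0.866, -0.866, -0.866, 0.866]
print(adv[4:])                          # [0.0, 0.0, 0.0, 0.0]
```

The `min` makes the objective pessimistic: with a ratio of `exp(0.5) ≈ 1.65`, a positive advantage is capped at `1.2 * A`, while a negative advantage keeps the unclipped, larger penalty.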

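For PPO, `train_step` places the scalar reward on the last generated token and runs `compute_gae` backwards over the sequence. A single-trajectory version in plain Python (illustrative names, same `gamma=0.99`, `lam=0.95` defaults) shows how the terminal reward is discounted back to earlier tokens when the critic predicts zero everywhere:

```python
def gae(rewards, values, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation for one trajectory; returns (advantages, returns)."""
    advantages = [0.0] * len(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        # Bootstrap with 0 after the final token (episode ends there)
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


# Sparse reward: 1.0 on the last of three tokens, critic predicts 0 everywhere
adv, ret = gae([0.0, 0.0, 1.0], [0.0, 0.0, 0.0])
print([round(a, 4) for a in adv])  # [0.8845, 0.9405, 1.0]
```

Each step back multiplies the signal by `gamma * lam = 0.9405`, so on long completions early tokens receive a strongly attenuated share of the final reward unless the critic learns useful intermediate values.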
## 5. Initialize Training

In [None]:
# Validate Phase 2 requirements
if not IS_SFT_PHASE and not SFT_COMPLETED:
    raise RuntimeError(
        "\n" + "=" * 60 + "\n" +
        "ERROR: Cannot start RL training without SFT checkpoint!\n" +
        "Please run Phase 1 (SFT Cold Start) first.\n" +
        "Change TRAINING_PHASE to 'Phase 1: SFT Cold Start' and run again.\n" +
        "=" * 60
    )

seed_everything(42)
device = "cuda"
logger = get_logger("Trainer", f"{DRIVE_BASE_PATH}/logs")

# Check for resume
resume_path, resume_step = None, 0
if RESUME_FROM_CHECKPOINT:
    resume_path, resume_step = find_latest_checkpoint(CHECKPOINT_DIR)
    if resume_path:
        logger.info(f"Found checkpoint at step {resume_step}")

print(f"\n{'=' * 60}")
print(f"Phase: {'SFT Cold Start' if IS_SFT_PHASE else 'RL Training'}")
print(f"Algorithm: {ALGORITHM}")
print(f"Resume from: {resume_path if resume_path else 'scratch'}")
print(f"{'=' * 60}\n")

In [None]:
# Load Model
print(f"Loading model: {MODEL_NAME}...")
policy = UnifiedPolicyModel(MODEL_NAME, ALGORITHM)
# Device is handled by device_map="auto", just set reference
policy.device = "cuda"
if policy.critic:
    policy.critic = policy.critic.to("cuda")

# For Phase 2: Load SFT checkpoint as base
if not IS_SFT_PHASE:
    sft_path = f"{SFT_CHECKPOINT_DIR}/final"
    print(f"Loading SFT weights from: {sft_path}")
    policy.load_adapter_from_checkpoint(sft_path)

# Resume from checkpoint if available
if resume_path:
    print(f"Resuming from: {resume_path}")
    policy.load_adapter_from_checkpoint(resume_path)

print("✓ Model ready")

In [None]:
# Load Dataset
from torch.utils.data import DataLoader

mode = 'sft' if IS_SFT_PHASE else 'rl'
# Use smaller dataset for SFT (format tuning only needs ~5k samples)
sft_samples = min(MAX_SAMPLES, 5000) if IS_SFT_PHASE else MAX_SAMPLES

dataset = MathReasoningDataset(
    policy.tokenizer,
    max_samples=sft_samples,
    mode=mode,
    dataset_name=DATASET_NAME
)

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

print(f"Dataset: {DATASET_NAME}")
print(f"Mode: {mode.upper()}")
print(f"Samples: {len(dataset)}")
print(f"Batches: {len(loader)}")
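Each batch is one `train_step` call and one optimizer update (per PPO epoch), and `DataLoader` keeps the final partial batch by default (`drop_last=False`), so the step count is a ceiling division. The hypothetical `schedule` helper below checks how many `step_<N>` checkpoints a run will write with the configured `SAVE_STEPS`, assuming training starts from step 0:

```python
import math


def schedule(num_samples: int, batch_size: int, epochs: int, save_steps: int):
    """Steps per epoch, total steps, and the steps at which step_<N> checkpoints are saved."""
    steps_per_epoch = math.ceil(num_samples / batch_size)
    total_steps = steps_per_epoch * epochs
    saves = list(range(save_steps, total_steps + 1, save_steps))
    return steps_per_epoch, total_steps, saves


print(schedule(num_samples=5000, batch_size=8, epochs=1, save_steps=100))
# (625, 625, [100, 200, 300, 400, 500, 600])
```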

In [None]:
# Initialize Trainer
config = {
    'algo': ALGORITHM,
    'group_size': GROUP_SIZE if not IS_SFT_PHASE else 1,
    'learning_rate': LEARNING_RATE,  # Auto-configured: 2e-5 for SFT, 5e-6 for RL
    'ppo_epochs': _PPO_EPOCHS,
    'max_new_tokens': _MAX_NEW_TOKENS,
}

trainer = UnifiedReasoningTrainer(policy, config, device)

if USE_WANDB:
    import wandb
    wandb.init(
        project=WANDB_PROJECT, 
        name=f"{ALGORITHM}_{MAX_SAMPLES}samples",
        config={
            'phase': 'SFT' if IS_SFT_PHASE else 'RL', 
            'algorithm': ALGORITHM,
            'model': MODEL_NAME,
            'dataset': DATASET_NAME,
            **config
        }
    )

print("✓ Trainer initialized")
print(f"  Algorithm: {config['algo']}")
print(f"  Learning Rate: {config['learning_rate']}")
print(f"  Group Size: {config['group_size']}")
print(f"  Max New Tokens: {config['max_new_tokens']}")
print(f"  PPO Epochs: {config['ppo_epochs']}")

## 6. Training Loop

In [None]:
from tqdm.notebook import tqdm
import time

step = resume_step
best_metric = float('-inf')  # SFT scores are negated losses (always <= 0), so start below any possible score
start_time = time.time()

phase_name = "SFT Cold Start" if IS_SFT_PHASE else f"RL ({ALGORITHM})"
print(f"\n{'=' * 60}")
print(f"üöÄ Starting {phase_name}")
print(f"{'=' * 60}")
print(f"Checkpoints: {CHECKPOINT_DIR}")
print(f"Total steps: {len(loader) * EPOCHS}")
print(f"{'=' * 60}\n")

for epoch in range(EPOCHS):
    logger.info(f"Epoch {epoch + 1}/{EPOCHS}")
    pbar = tqdm(loader, desc=f"Epoch {epoch + 1}")
    
    epoch_metrics = {'loss': [], 'accuracy': [], 'kl': [], 'reward': []}
    
    for batch in pbar:
        metrics = trainer.train_step(batch)
        step += 1
        
        # Track metrics
        epoch_metrics['loss'].append(metrics['loss'])
        if not IS_SFT_PHASE:
            epoch_metrics['accuracy'].append(metrics['accuracy'])
            epoch_metrics['kl'].append(metrics['kl_divergence'])
            epoch_metrics['reward'].append(metrics['reward'])
        
        # Progress bar
        if IS_SFT_PHASE:
            pbar.set_postfix({'loss': f"{metrics['loss']:.4f}"})
        else:
            pbar.set_postfix({
                'loss': f"{metrics['loss']:.4f}",
                'acc': f"{metrics['accuracy']:.1%}",
                'reward': f"{metrics['reward']:.2f}",
                'kl': f"{metrics['kl_divergence']:.4f}",
            })
        
        if USE_WANDB:
            wandb.log({**metrics, 'step': step, 'epoch': epoch + 1})
        
        # Save checkpoint
        if step % SAVE_STEPS == 0:
            ckpt = f"{CHECKPOINT_DIR}/step_{step}"
            policy.save_pretrained(ckpt)
            logger.info(f"Saved: {ckpt}")
            
            # Track best (lowest loss for SFT, highest accuracy for RL)
            current = -metrics['loss'] if IS_SFT_PHASE else metrics['accuracy']
            if current > best_metric:
                best_metric = current
                policy.save_pretrained(f"{CHECKPOINT_DIR}/best")
                logger.info(f"New best! {'Loss' if IS_SFT_PHASE else 'Accuracy'}: {abs(best_metric):.4f}")
    
    # Epoch summary
    print(f"\nüìä Epoch {epoch + 1} Summary:")
    print(f"   Avg Loss: {sum(epoch_metrics['loss'])/len(epoch_metrics['loss']):.4f}")
    if not IS_SFT_PHASE:
        print(f"   Avg Accuracy: {sum(epoch_metrics['accuracy'])/len(epoch_metrics['accuracy']):.1%}")
        print(f"   Avg Reward: {sum(epoch_metrics['reward'])/len(epoch_metrics['reward']):.3f}")
        print(f"   Avg KL: {sum(epoch_metrics['kl'])/len(epoch_metrics['kl']):.4f}")

# Save final
policy.save_pretrained(f"{CHECKPOINT_DIR}/final")
logger.info(f"Saved final: {CHECKPOINT_DIR}/final")

elapsed = time.time() - start_time
if USE_WANDB:
    wandb.finish()

print(f"\n{'=' * 60}")
print(f"✅ {phase_name} Complete!")
print(f"{'=' * 60}")
print(f"Total time: {elapsed/60:.1f} minutes")
print(f"Final checkpoint: {CHECKPOINT_DIR}/final")
print(f"Best checkpoint: {CHECKPOINT_DIR}/best")
if IS_SFT_PHASE:
    print("\n👉 Next Steps:")
    print("   1. Change TRAINING_PHASE to 'Phase 2: RL Training'")
    print("   2. Select RL_ALGORITHM (recommended: GRPO or DR.GRPO)")
    print("   3. Run all cells again")
else:
    print("\n👉 Next Steps:")
    print("   1. Run the Evaluation cell to test accuracy")
    print("   2. Try different RL algorithms for comparison")
print(f"{'=' * 60}")

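The checkpoint logic above keeps a single "higher is better" score by negating the SFT loss, then recovers the logged value with `abs()`. That works only because loss and accuracy are both non-negative. Below is a minimal sketch of the same idea as a reusable tracker. `BestTracker` is a hypothetical helper, not part of this notebook, and it assumes the running best starts at negative infinity:

```python
import math


class BestTracker:
    """Track the best value of a metric; internally, higher is always better."""

    def __init__(self, minimize: bool):
        self.minimize = minimize  # True for loss (SFT), False for accuracy (RL)
        self.best = -math.inf     # sign-normalized score

    def update(self, value: float) -> bool:
        """Record `value`; return True if it is a new best."""
        score = -value if self.minimize else value
        if score > self.best:
            self.best = score
            return True
        return False

    @property
    def best_value(self) -> float:
        """Best value in the metric's original units (no abs() needed)."""
        return -self.best if self.minimize else self.best


tracker = BestTracker(minimize=True)
for loss in [0.9, 1.1, 0.5]:
    if tracker.update(loss):
        print(f"New best loss: {tracker.best_value:.4f}")
```

Both the training loop and this sketch compare single-step metrics, which are noisy. Averaging over the last few steps before comparing is a common alternative.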
## 7. Evaluation

In [None]:
#@title Evaluate Model {display-mode: "form"}

#@markdown ### Evaluation Settings
#@markdown - **Quick check**: 100 samples (~3-5 min)
#@markdown - **Development**: 300 samples (~10 min)  
#@markdown - **Full benchmark**: 1319 samples (~40-60 min)
EVAL_SAMPLES = 300 #@param {type:"integer"}
EVAL_TEMPERATURE = 0.0 #@param {type:"number"}
EVAL_MAX_TOKENS = 512 #@param {type:"integer"}

import time
eval_start = time.time()

# Load test set
eval_dataset = MathReasoningDataset(
    policy.tokenizer, 
    split="test", 
    max_samples=EVAL_SAMPLES, 
    mode='rl', 
    dataset_name=DATASET_NAME
)

print(f"{'=' * 50}")
print(f"📊 EVALUATION")
print(f"{'=' * 50}")
print(f"Samples: {len(eval_dataset)} / 1319 (GSM8K test set)")
print(f"Temperature: {EVAL_TEMPERATURE} ({'greedy' if EVAL_TEMPERATURE == 0 else 'sampling'})")
print(f"Max tokens: {EVAL_MAX_TOKENS}")
print(f"{'=' * 50}\n")

policy.model.eval()
correct, total = 0, 0
results = []

with torch.no_grad():
    for idx in tqdm(range(len(eval_dataset)), desc="Evaluating"):
        item = eval_dataset[idx]
        inputs = policy.tokenizer(item['prompt'], return_tensors="pt").to(device)
        
        if EVAL_TEMPERATURE == 0:
            outputs = policy.generate(**inputs, max_new_tokens=EVAL_MAX_TOKENS, do_sample=False)
        else:
            outputs = policy.generate(**inputs, max_new_tokens=EVAL_MAX_TOKENS, do_sample=True, temperature=EVAL_TEMPERATURE)
        
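        # Decode only the newly generated tokens, so text from the prompt
        # (which may itself mention \boxed{}) cannot be extracted as the answer
        prompt_len = inputs['input_ids'].shape[1]
        response = policy.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)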
        pred = trainer.extract_answer(response).strip().replace(" ", "")
        gt = str(item['ground_truth']).strip().replace(" ", "")
        
        is_correct = (pred == gt)
        if is_correct:
            correct += 1
        total += 1
        
        results.append({
            'idx': idx,
            'correct': is_correct,
            'pred': pred,
            'gt': gt,
        })

eval_time = time.time() - eval_start
accuracy = correct / total

print(f"\n{'=' * 50}")
print(f"📈 RESULTS")
print(f"{'=' * 50}")
print(f"Accuracy: {accuracy:.1%} ({correct}/{total})")
print(f"Time: {eval_time/60:.1f} minutes")
print(f"{'=' * 50}")

# Show some examples
print(f"\n📝 Sample Results (first 5 incorrect):")
incorrect = [r for r in results if not r['correct']][:5]
for r in incorrect:
    print(f"   #{r['idx']}: pred='{r['pred']}' vs gt='{r['gt']}'")

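# Extrapolate to the full 1,319-problem test set using a normal-approximation (Wald) 95% CI.
# This assumes the evaluated subset behaves like a random sample of the test set.
if total < 1319:
    margin = 1.96 * ((accuracy * (1 - accuracy) / total) ** 0.5)
    print(f"\n📊 Estimated full test accuracy: {accuracy:.1%} ± {margin:.1%} (95% CI)")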

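The evaluation cell counts an answer as correct only on an exact string match after removing spaces. Under that rule, `18.0` vs `18` or `$1,000` vs `1000` are scored as wrong. Its confidence interval is the normal-approximation (Wald) interval, whose half-width collapses to zero at 0% or 100% accuracy.

Below is a sketch of a more lenient, numeric-aware comparison, plus the Wilson score interval as a swapped-in alternative. `normalize_answer`, `answers_match`, `wald_margin`, and `wilson_interval` are hypothetical helpers. The normalization assumes GSM8K-style numeric answers; stripping commas would mangle answers such as tuples.

```python
import math


def normalize_answer(ans: str) -> str:
    """Canonicalize a GSM8K-style answer string for comparison."""
    # Drop whitespace, thousands separators, a dollar sign, and a trailing period
    s = ans.strip().replace(" ", "").replace(",", "").replace("$", "").rstrip(".")
    try:
        value = float(s)
    except ValueError:
        return s  # not a number: fall back to the cleaned string
    if value.is_integer():
        return str(int(value))  # "18.0" -> "18"
    return repr(value)          # "3.50" -> "3.5"


def answers_match(pred: str, gt: str) -> bool:
    """Lenient replacement for the exact `pred == gt` check."""
    return normalize_answer(pred) == normalize_answer(gt)


def wald_margin(p: float, n: int, z: float = 1.96) -> float:
    """Half-width of the normal-approximation interval used in the evaluation cell."""
    return z * math.sqrt(p * (1 - p) / n)


def wilson_interval(p: float, n: int, z: float = 1.96) -> tuple[float, float]:
    """Wilson score interval; remains non-degenerate when p is 0 or 1."""
    denom = 1 + z * z / n
    center = (p + z * z / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n)) / denom
    return center - half, center + half


print(answers_match("$1,000", "1000"))
print(f"Wald: ±{wald_margin(0.7, 300):.1%}")
print("Wilson at 100% of 100:", wilson_interval(1.0, 100))
```

To try it, replace `is_correct = (pred == gt)` with `is_correct = answers_match(pred, gt)`. On the same extracted strings, the reported accuracy can only stay the same or go up, because any exact match still matches after normalization.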
## 8. Test Inference

In [None]:
#@title Test Custom Problem {display-mode: "form"}

PROBLEM = "A store sells apples for $2 each. If you buy 5 apples and pay with a $20 bill, how much change do you get?" #@param {type:"string"}

prompt = f"""<|im_start|>system
Please reason step by step and put your final answer within \\boxed{{}}.<|im_end|>
<|im_start|>user
{PROBLEM}<|im_end|>
<|im_start|>assistant
<think>"""

inputs = policy.tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    out = policy.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
response = policy.tokenizer.decode(out[0], skip_special_tokens=True)

print("Problem:", PROBLEM)
print("\n" + "-" * 50)
print("Response:")
print(response.split("<think>")[-1] if "<think>" in response else response)
print("\n" + "-" * 50)
print(f"Answer: {trainer.extract_answer(response)}")

## 9. Checkpoint Management

In [None]:
# List all checkpoints
print("=" * 50)
print("SAVED CHECKPOINTS")
print("=" * 50)

for phase, path in [("SFT", SFT_CHECKPOINT_DIR), ("RL", RL_CHECKPOINT_DIR)]:
    print(f"\n{phase}: {path}")
    if os.path.exists(path):
        items = sorted(os.listdir(path))
        for item in items:
            ip = os.path.join(path, item)
            if os.path.isdir(ip):
                size = sum(os.path.getsize(os.path.join(ip, f)) for f in os.listdir(ip) if os.path.isfile(os.path.join(ip, f)))
                print(f"  └─ {item}: {size/1e6:.1f} MB")
    else:
        print("  (no checkpoints)")

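The listing cell sums only the files directly inside each checkpoint folder, so anything saved in a nested subdirectory is not counted. This notebook does not show the implementation of `policy.save_pretrained`. If it ever writes subfolders, a recursive version based on `os.walk` gives the full size. `dir_size_bytes` is a hypothetical helper:

```python
import os


def dir_size_bytes(path: str) -> int:
    """Total size in bytes of all regular files under `path`, including subdirectories."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            fp = os.path.join(root, name)
            if not os.path.islink(fp):  # skip symlinks so linked files aren't double-counted
                total += os.path.getsize(fp)
    return total
```

In the listing loop, it would replace the `size = ...` line as `size = dir_size_bytes(ip)`.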
In [None]:
# Cleanup GPU
import gc
del policy, trainer
gc.collect()
torch.cuda.empty_cache()
print("✓ GPU memory cleared")