<a href="https://www.kaggle.com/code/pragnyanramtha/ai-math?scriptVersionId=282874042" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
"""
SFT Training Script for Phi-4-Reasoning-Plus
Unsloth + Maximum Quality Settings
"""

import os
import time
import torch
from pathlib import Path
from datetime import datetime, timedelta

from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported

from transformers import TrainingArguments, TrainerCallback
from datasets import load_dataset
from trl import SFTTrainer

# Export WANDB_DISABLED=true into the process environment before any trainer
# is built (presumably to keep Weights & Biases reporting off; the training
# arguments further down also pass report_to="none").
os.environ.update({"WANDB_DISABLED": "true"})


class RealTimeLogger(TrainerCallback):
    """Trainer callback that prints a one-line progress report on each log event.

    Each report shows percent complete, step, loss, learning rate, epoch,
    VRAM usage and an ETA. Banners are printed when training starts, when a
    checkpoint is saved and when training ends.

    The ETA is based on the steps completed *since this process started
    training*, not on the absolute ``state.global_step``. When a run is
    resumed from a checkpoint the absolute step is already large while the
    elapsed time is tiny, which would otherwise make the throughput look far
    higher than it is.
    """

    def __init__(self):
        # Wall-clock time at which on_train_begin fired; None until then.
        self.start_time = None
        # Value of state.global_step when on_train_begin fired. Stays 0 for a
        # fresh run; expected to be the checkpoint step on a resumed run
        # (assumes the Trainer has restored its state before the event fires;
        # if it has not, this is 0 and the ETA matches the naive formula).
        self.start_step = 0

    def on_train_begin(self, args, state, control, **kwargs):
        """Record the start time / start step and print the start banner."""
        self.start_time = time.time()
        self.start_step = state.global_step
        print("\n" + "="*70)
        print("üöÄ TRAINING STARTED (Unsloth + LoRA r=512 + Max Quality)")
        print("="*70 + "\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print a progress line for log events that carry a ``loss`` entry."""
        # Log events without a training loss (e.g. the final summary) are skipped.
        if not logs or "loss" not in logs:
            return

        step = state.global_step
        total = state.max_steps
        pct = (step / total) * 100 if total > 0 else 0
        loss = logs.get("loss", 0)
        lr = logs.get("learning_rate", 0)
        epoch = logs.get("epoch", 0)

        # Guard against on_log firing before on_train_begin has set the clock.
        elapsed = time.time() - self.start_time if self.start_time is not None else 0
        # Throughput counts only the steps done in this process (see class docstring).
        steps_done = step - self.start_step
        steps_per_sec = steps_done / elapsed if elapsed > 0 else 0
        remaining = (total - step) / steps_per_sec if steps_per_sec > 0 else 0

        if torch.cuda.is_available():
            # Query used and total memory for the same (current) device.
            device = torch.cuda.current_device()
            mem_used = torch.cuda.memory_allocated(device) / 1e9
            mem_total = torch.cuda.get_device_properties(device).total_memory / 1e9
        else:
            mem_used, mem_total = 0, 1

        print(f"[{pct:5.1f}%] Step {step:>5}/{total} | "
              f"Loss: {loss:.4f} | LR: {lr:.2e} | "
              f"Epoch: {epoch:.2f} | VRAM: {mem_used:.0f}/{mem_total:.0f}GB | "
              f"ETA: {timedelta(seconds=int(remaining))}")

    def on_save(self, args, state, control, **kwargs):
        """Announce that a checkpoint was written at the current step."""
        print(f"\nüíæ Checkpoint saved at step {state.global_step}\n")

    def on_train_end(self, args, state, control, **kwargs):
        """Print the completion banner with the elapsed wall-clock time."""
        elapsed = time.time() - self.start_time if self.start_time is not None else 0
        print("\n" + "="*70)
        print(f"‚úÖ TRAINING COMPLETE | Time: {timedelta(seconds=int(elapsed))}")
        print("="*70 + "\n")


# === MAXIMUM QUALITY Configuration ===
# Single dictionary of paths and hyperparameters; load_model_and_tokenizer()
# and main() read every setting from here.
CONFIG = {
    "model_name": "microsoft/Phi-4-reasoning-plus",
    "max_seq_length": 8192,  # Increased from 4096
    "dataset_path": "/kaggle/input/aimath-train/data/sft_dataset.jsonl",
    "output_dir": "/kaggle/working/outputs/sft",
    
    # LoRA - MAXIMUM RANK
    "lora_r": 512,           # Increased from 256
    "lora_alpha": 512,       # Match rank
    "lora_dropout": 0,
    
    # Training - LARGER BATCHES
    "num_train_epochs": 3,   # Increased from 2
    "per_device_train_batch_size": 8,   # Increased from 2
    "gradient_accumulation_steps": 4,   # Effective batch = 8 * 4 = 32 per device
    "learning_rate": 1e-4,   # Slightly lower for stability
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.03,
    "weight_decay": 0.01,
    "max_grad_norm": 1.0,
    
    # Saving
    "save_steps": 200,       # Checkpoint interval, in steps
    "save_total_limit": 2,   # Passed to TrainingArguments as save_total_limit
    "logging_steps": 10,
    "seed": 42,              # Used for both LoRA init (random_state) and the trainer seed
}


def load_model_and_tokenizer():
    """Load the base model via Unsloth, wrap it with LoRA adapters, and return it.

    All settings come from the module-level ``CONFIG`` dictionary. Progress,
    parameter counts and (when CUDA is available) VRAM usage are printed
    along the way.

    Returns:
        A ``(model, tokenizer)`` tuple: the PEFT-wrapped model and its
        tokenizer, with a pad token guaranteed to be set and right padding.
    """
    rule = "=" * 70
    model_id = CONFIG["model_name"]
    rank = CONFIG["lora_r"]
    seq_len = CONFIG["max_seq_length"]

    print("\n" + rule)
    print("üì• Loading Model with Unsloth (MAX QUALITY)")
    print(rule)

    print(f"\n   Model: {model_id}")
    print(f"   LoRA rank: {rank} (MAXIMUM)")
    print(f"   Max seq length: {seq_len}")

    # Full-precision (non-quantised) bf16 load of the base checkpoint.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=seq_len,
        dtype=torch.bfloat16,
        load_in_4bit=False,
        trust_remote_code=True,
    )

    print("   ‚úÖ Base model loaded")

    # Attention + MLP projections, plus the output head and input embeddings.
    adapter_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "lm_head",
        "embed_tokens",
    ]
    model = FastLanguageModel.get_peft_model(
        model,
        r=rank,
        target_modules=adapter_targets,
        lora_alpha=CONFIG["lora_alpha"],
        lora_dropout=CONFIG["lora_dropout"],
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=CONFIG["seed"],
    )

    # Tally all parameters and the trainable subset.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    trainable_pct = 100*trainable_params/total_params
    print(f"   ‚úÖ LoRA applied (r={rank})")
    print(f"   üìä Total params: {total_params / 1e9:.2f}B")
    print(f"   üìä Trainable params: {trainable_params / 1e6:.0f}M ({trainable_pct:.1f}%)")

    # Fall back to EOS as the pad token only when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    if torch.cuda.is_available():
        used_gb = torch.cuda.memory_allocated() / 1e9
        capacity_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"   üìä VRAM: {used_gb:.1f} / {capacity_gb:.0f} GB")

    return model, tokenizer


def _latest_checkpoint(output_dir):
    """Return the highest-step ``checkpoint-<step>`` directory in *output_dir*.

    Only directories whose name is exactly ``checkpoint-`` followed by ASCII
    digits are considered. Stray entries such as ``checkpoint-best`` or
    ``checkpoint-200.zip`` are ignored instead of raising ``ValueError``.

    Returns:
        The checkpoint path as a string, or ``None`` when there is none.
    """
    best_step = -1
    best_path = None
    for candidate in Path(output_dir).glob("checkpoint-*"):
        suffix = candidate.name.partition("-")[2]
        if not (suffix.isascii() and suffix.isdigit() and candidate.is_dir()):
            continue
        step = int(suffix)
        if step > best_step:
            best_step, best_path = step, candidate
    return str(best_path) if best_path is not None else None


def main():
    """Run the whole SFT pipeline and return the merged-model directory.

    Steps: load model and tokenizer, load the JSONL dataset, look for a
    checkpoint to resume from, build the trainer, train, then save both the
    LoRA adapters and a 16-bit merged model under ``CONFIG["output_dir"]``.

    Returns:
        The path of the merged model directory, as a string.
    """
    start_time = time.time()
    rule = "=" * 70

    print("\n" + rule)
    print("üéØ Phi-4 Math SFT Training (MAXIMUM QUALITY)")
    print(rule)
    print(f"   Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("   Platform: Kaggle H100 80GB")
    # Epoch count comes from CONFIG so the banner cannot drift from the real setting.
    print(f"   Method: Unsloth + LoRA r={CONFIG['lora_r']} + {CONFIG['num_train_epochs']} epochs")

    model, tokenizer = load_model_and_tokenizer()

    print("\nüìÇ Loading dataset...")
    dataset = load_dataset("json", data_files=CONFIG["dataset_path"], split="train")
    print(f"   ‚úÖ Loaded {len(dataset):,} examples")

    output_dir = Path(CONFIG["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)

    # Resume from the newest well-formed checkpoint directory, if any.
    resume_checkpoint = _latest_checkpoint(output_dir)
    if resume_checkpoint:
        print(f"\nüîÑ Found checkpoint: {resume_checkpoint}")

    # Evaluate hardware support once so bf16 and fp16 are guaranteed to be opposites.
    use_bf16 = is_bfloat16_supported()
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=CONFIG["num_train_epochs"],
        per_device_train_batch_size=CONFIG["per_device_train_batch_size"],
        gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],
        learning_rate=CONFIG["learning_rate"],
        lr_scheduler_type=CONFIG["lr_scheduler_type"],
        warmup_ratio=CONFIG["warmup_ratio"],
        weight_decay=CONFIG["weight_decay"],
        max_grad_norm=CONFIG["max_grad_norm"],
        bf16=use_bf16,
        fp16=not use_bf16,
        optim="adamw_torch_fused",
        save_strategy="steps",
        save_steps=CONFIG["save_steps"],
        save_total_limit=CONFIG["save_total_limit"],
        logging_steps=CONFIG["logging_steps"],
        logging_first_step=True,
        report_to="none",
        seed=CONFIG["seed"],
        dataloader_num_workers=2,
    )

    # Rough step estimate from the raw example count.
    # NOTE(review): packing=True below probably changes the number of training
    # sequences, so the trainer's real step count may differ; confirm.
    effective_batch = CONFIG["per_device_train_batch_size"] * CONFIG["gradient_accumulation_steps"]
    steps_per_epoch = len(dataset) // effective_batch
    total_steps = steps_per_epoch * CONFIG["num_train_epochs"]

    print("\n" + rule)
    print("üìã Training Configuration (MAXIMUM QUALITY)")
    print(rule)
    print(f"   Dataset:         {len(dataset):,} examples")
    print(f"   Max seq length:  {CONFIG['max_seq_length']}")
    print(f"   LoRA rank:       {CONFIG['lora_r']} (+ embed + lm_head)")
    print(f"   Batch size:      {CONFIG['per_device_train_batch_size']} x {CONFIG['gradient_accumulation_steps']} = {effective_batch}")
    print(f"   Epochs:          {CONFIG['num_train_epochs']}")
    print(f"   Learning rate:   {CONFIG['learning_rate']}")
    print(f"   Steps/epoch:     {steps_per_epoch:,}")
    print(f"   Total steps:     {total_steps:,}")
    print(f"   Est. time:       {timedelta(seconds=total_steps * 2)}")  # assumes ~2s/step
    print(rule)

    print("\nüèãÔ∏è Creating SFT Trainer...")
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        args=training_args,
        dataset_text_field="text",
        max_seq_length=CONFIG["max_seq_length"],
        packing=True,
        callbacks=[RealTimeLogger()],
    )

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        mem = torch.cuda.memory_allocated() / 1e9
        total = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"\nüìä VRAM before training: {mem:.1f} / {total:.0f} GB")

    if resume_checkpoint:
        print(f"\nüîÑ Resuming from {resume_checkpoint}")
        trainer.train(resume_from_checkpoint=resume_checkpoint)
    else:
        trainer.train()

    # Save the adapters on their own first, then a merged 16-bit copy.
    print("\nüíæ Saving LoRA model...")
    lora_path = output_dir / "lora_model"
    model.save_pretrained(str(lora_path))
    tokenizer.save_pretrained(str(lora_path))
    print(f"   ‚úÖ LoRA adapters saved to: {lora_path}")

    print("\nüíæ Merging LoRA into base model (16-bit)...")
    merged_path = output_dir / "merged_model"
    model.save_pretrained_merged(
        str(merged_path),
        tokenizer,
        save_method="merged_16bit",
    )
    print(f"   ‚úÖ Merged model saved to: {merged_path}")

    elapsed = time.time() - start_time
    print("\n‚úÖ Training complete!")
    print(f"   Total time: {timedelta(seconds=int(elapsed))}")

    return str(merged_path)


# Entry point: run the pipeline and report where the merged model was written.
# final_model_path is deliberately bound at module level, so it stays
# available to whatever runs after this block (e.g. a later notebook cell).
if __name__ == "__main__":
    final_model_path = main()
    print(f"\nüéâ SFT Complete! Model at: {final_model_path}")