# Phase 3: SDPO Training for IIT JEE Fine-tuning

This notebook runs **Self-Distillation Preference Optimization (SDPO)** on your SFT model from Phase 2.

**Pipeline:**
1. Download SFT model + data from HuggingFace
2. Generate rollouts (model attempts JEE questions)
3. Judge rollouts (rule-based + optional LLM judge)
4. Build DPO preference pairs (chosen vs rejected)
5. Train with DPO using TRL
6. Upload trained model to HuggingFace

**Requirements:** Colab Pro+ with A100 GPU runtime

## 1. Setup & Install Dependencies

In [None]:
# Training stack: PyTorch, Transformers, Accelerate, PEFT (LoRA), TRL (DPOTrainer),
# Datasets, and bitsandbytes for 4-bit loading.
!pip install -q torch transformers accelerate peft trl datasets bitsandbytes
# Extras: Anthropic SDK for the optional LLM judge, the Hub client, progress bars, JSONL helpers.
!pip install -q anthropic huggingface_hub tqdm jsonlines

In [None]:
# Verify GPU
!nvidia-smi
import torch
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")

## 2. Configuration

Fill in your credentials below.

In [None]:
import os

# ============================================================
# FILL THESE IN
# ============================================================
HF_TOKEN = "YOUR_HF_TOKEN_HERE"  # Your HuggingFace token
ANTHROPIC_API_KEY = ""  # Your Anthropic API key (optional — leave empty for rule-based judge)
# ============================================================

# The unfilled placeholder must never be exported: it would overwrite a real
# token already present in the environment and later fail with an auth error.
_HF_TOKEN_PLACEHOLDER = "YOUR_HF_TOKEN_HERE"
if HF_TOKEN in ("", _HF_TOKEN_PLACEHOLDER):
    # Nothing filled in above: fall back to a token already set in the environment.
    HF_TOKEN = os.environ.get("HF_TOKEN", "")

if HF_TOKEN and HF_TOKEN != _HF_TOKEN_PLACEHOLDER:
    os.environ["HF_TOKEN"] = HF_TOKEN
else:
    HF_TOKEN = ""
    os.environ.pop("HF_TOKEN", None)  # drop a stale placeholder from an earlier run
    print("WARNING: HF_TOKEN is not set - fill it in above or export HF_TOKEN before continuing.")

if ANTHROPIC_API_KEY:
    os.environ["ANTHROPIC_API_KEY"] = ANTHROPIC_API_KEY

# Model configuration
SFT_MODEL_ID = "vipsehgal/qwen3-8b-jee-sft"  # Your SFT model from Phase 2
OUTPUT_MODEL_NAME = "qwen3-8b-jee-sdpo"        # Name for the output model

# Training configuration
NUM_PROMPTS = 500       # Number of prompts for rollouts (reduce for faster runs)
NUM_ROLLOUTS = 2        # Rollouts per prompt
MAX_NEW_TOKENS = 512    # Max tokens per rollout
DPO_EPOCHS = 2          # DPO training epochs
DPO_BATCH_SIZE = 2      # Per-device batch size
DPO_GRAD_ACCUM = 4      # Gradient accumulation steps (effective batch = 8)
DPO_LR = 5e-6           # Learning rate
DPO_BETA = 0.1          # DPO beta (higher = more conservative)
LORA_R = 16             # LoRA rank
LORA_ALPHA = 32         # LoRA alpha

print("Configuration set!")
print(f"  Model: {SFT_MODEL_ID}")
print(f"  Prompts: {NUM_PROMPTS}, Rollouts/prompt: {NUM_ROLLOUTS}")
# The LLM-judge cell calls a Claude Sonnet model, so the label says Sonnet.
print(f"  Judge mode: {'LLM (Claude Sonnet)' if ANTHROPIC_API_KEY else 'Rule-based only'}")

## 3. Download Model & Data

In [None]:
import json
from huggingface_hub import hf_hub_download, login, snapshot_download

# Authenticate with the Hub before downloading anything.
login(token=HF_TOKEN)

# Fetch the SFT checkpoint into ./sft-model (takes a few minutes for 16GB).
# Files under sdpo_data/ are excluded from this snapshot.
print("Downloading SFT model...")
sft_snapshot_options = {
    "local_dir": "./sft-model",
    "ignore_patterns": ["sdpo_data/*"],
}
snapshot_download(SFT_MODEL_ID, **sft_snapshot_options)
print("Model downloaded!")

In [None]:
import json
import os


def _read_jsonl(path):
    """Return the JSON objects stored one per line in *path*, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _question_and_solution(messages):
    """Return (user question, assistant solution) from a chat record, or (None, None).

    Looks messages up by their "role" so records without a system message still
    work; falls back to the original positional layout (index 1 = user,
    index 2 = assistant) when roles are missing.
    """
    by_role = {}
    for msg in messages:
        role = msg.get("role") if isinstance(msg, dict) else None
        if role is not None and role not in by_role:
            by_role[role] = msg.get("content")  # keep the first message of each role
    question, solution = by_role.get("user"), by_role.get("assistant")
    if (question is None or solution is None) and len(messages) >= 3:
        question, solution = messages[1]["content"], messages[2]["content"]
    if question is None or solution is None:
        return None, None
    return question, solution


# Download SDPO data files
os.makedirs("./sdpo_data", exist_ok=True)

for fname in ["rl_prompts.jsonl", "eval_prompts.jsonl", "judge_config.json", "train.jsonl"]:
    hf_hub_download(
        SFT_MODEL_ID,
        filename=f"sdpo_data/{fname}",
        local_dir=".",
    )
    print(f"Downloaded {fname}")

# Load prompts
rl_prompts = _read_jsonl("./sdpo_data/rl_prompts.jsonl")
eval_prompts = _read_jsonl("./sdpo_data/eval_prompts.jsonl")

# Load training data (has gold solutions for DPO chosen responses)
train_data = _read_jsonl("./sdpo_data/train.jsonl")

# Build a lookup from question -> gold solution
gold_solutions = {}
for item in train_data:
    question, solution = _question_and_solution(item["messages"])
    if question is None:
        continue  # record has no usable user/assistant pair
    gold_solutions[question[:200]] = solution  # key by first 200 chars

print(f"\nLoaded {len(rl_prompts)} RL prompts, {len(eval_prompts)} eval prompts")
print(f"Gold solutions available: {len(gold_solutions)}")

## 4. Generate Rollouts

The model generates solution attempts for JEE questions. We'll judge these to create preference pairs.

In [None]:
import random

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Load model in 4-bit for memory efficiency:
# NF4 quantization with double quantization, computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

print("Loading model in 4-bit...")
model = AutoModelForCausalLM.from_pretrained(
    "./sft-model",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained("./sft-model")

# Generation needs a pad token; reuse EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id

allocated_gb = torch.cuda.memory_allocated() / 1e9
print(f"Model loaded! Memory: {allocated_gb:.1f} GB")

In [None]:
from tqdm.notebook import tqdm

SYSTEM_MSG = (
    "You are an expert IIT JEE tutor. Solve problems step-by-step "
    "using LaTeX notation. Show all work clearly and arrive at the final answer."
)

# Select a random subset of prompts (fixed seed so the subset is reproducible)
random.seed(42)
selected_prompts = random.sample(rl_prompts, min(NUM_PROMPTS, len(rl_prompts)))
print(f"Generating rollouts for {len(selected_prompts)} prompts x {NUM_ROLLOUTS} each...")
print(f"Total generations: {len(selected_prompts) * NUM_ROLLOUTS}")

all_rollouts = []

for prompt_data in tqdm(selected_prompts, desc="Generating rollouts"):
    question = prompt_data["prompt"]
    ground_truth = prompt_data["ground_truth"]

    messages = [
        {"role": "system", "content": SYSTEM_MSG},
        {"role": "user", "content": question},
    ]

    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # Sample all rollouts for this prompt in one batched call instead of
    # re-encoding the same prompt NUM_ROLLOUTS times.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
            num_return_sequences=NUM_ROLLOUTS,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Drop the prompt tokens; padding added to shorter samples is removed
    # together with the other special tokens.
    completions = tokenizer.batch_decode(
        outputs[:, prompt_len:],
        skip_special_tokens=True,
    )

    for generated in completions:
        all_rollouts.append({
            "question": question,
            "ground_truth": ground_truth,
            "model_output": generated,
            "subject": prompt_data.get("subject", "Unknown"),
            "source": prompt_data.get("source", "unknown"),
        })

print(f"\nGenerated {len(all_rollouts)} rollouts")

# Save rollouts to disk (checkpoint)
with open("rollouts.jsonl", "w", encoding="utf-8") as f:
    f.writelines(json.dumps(r) + "\n" for r in all_rollouts)
print("Saved rollouts to rollouts.jsonl")

In [None]:
# Free GPU memory before training
import gc

del model
# Collect first: tensors kept alive only by reference cycles are released by
# the collector, and only then can empty_cache() return their blocks.
gc.collect()
torch.cuda.empty_cache()
print(f"GPU memory freed: {torch.cuda.memory_allocated() / 1e9:.1f} GB used")

## 5. Judge Rollouts

Check each rollout for correctness with the rule-based judge. Optionally call Claude (Sonnet) to write a reference step-by-step solution for questions the model answered incorrectly.

In [None]:
import re

NUMERICAL_TOLERANCE = 0.01

def _last_boxed_content(text):
    r"""Return the contents of the last non-empty, brace-balanced \boxed{...} in *text*, or None."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    while start != -1:
        begin = start + len(marker)
        depth = 1  # the opening brace of \boxed{ is already consumed
        for pos in range(begin, len(text)):
            ch = text[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    content = text[begin:pos].strip()
                    if content:
                        return content
                    break  # empty box: try an earlier one
        # Unbalanced or empty: look for an earlier \boxed{.
        start = text.rfind(marker, 0, start)
    return None


def extract_answer(response):
    r"""Extract the final answer from a model's response.

    Patterns are tried in priority order; the first one that yields text wins:
      1. the last ``\boxed{...}`` (nested braces such as ``\frac{1}{2}`` are kept whole)
      2. ``**Answer:** ...``
      3. ``Answer: ...`` at the start of a line
      4. ``The answer is ...`` (up to a sentence-ending period or the end of the line)
      5. the last line containing an option letter like ``(B)``

    Returns the extracted text, or ``""`` when nothing matches or *response* is empty.
    """
    if not response:
        return ""

    # \boxed{...}
    boxed = _last_boxed_content(response)
    if boxed is not None:
        return boxed

    # **Answer:** ...
    match = re.search(r'\*\*Answer:\*\*\s*(.+?)(?:\n|$)', response)
    if match:
        return match.group(1).strip()

    # Answer: ...
    match = re.search(r'(?:^|\n)\s*Answer:\s*(.+?)(?:\n|$)', response)
    if match:
        return match.group(1).strip()

    # The answer is ...
    # Stop at a period only when it ends the sentence (followed by whitespace or
    # end of text) so decimals like "3.5" survive; also stop at a line break.
    match = re.search(r'[Tt]he\s+answer\s+is\s+(.+?)(?:\.(?=\s|$)|\n|$)', response)
    if match:
        return match.group(1).strip()

    # Last line with option letter
    for line in reversed(response.strip().split('\n')):
        match = re.search(r'\(([A-D])\)', line)
        if match:
            return match.group(1)

    return ""


def normalize_answer(answer):
    r"""Normalize an answer for comparison.

    Steps, in order:
      * stringify, trim and lowercase;
      * peel enclosing ``( ... )`` and ``$ ... $`` wrappers in any nesting order;
      * rewrite ``\frac{a}{b}`` (also ``\dfrac``/``\tfrac``) as ``a/b`` so a fraction
        cannot collapse into the concatenated digits ``ab``;
      * unwrap remaining ``\command{arg}`` to ``arg``;
      * drop backslashes, braces and all whitespace.

    Returns the normalized string (possibly empty).
    """
    answer = str(answer).strip().lower()

    # Peel wrappers until nothing changes, so "$(a)$" and "($a$)" both become "a".
    while True:
        peeled = re.sub(r'^\((.+)\)$', r'\1', answer)
        peeled = re.sub(r'^\$(.+)\$$', r'\1', peeled).strip()
        if peeled == answer:
            break
        answer = peeled

    def _frac_to_ratio(match):
        # Parenthesize compound parts so "(a+b)/2" is not confused with "a+b/2".
        parts = []
        for part in match.groups():
            part = part.strip()
            simple = re.fullmatch(r'[0-9a-z.\\]+', part) is not None
            parts.append(part if simple else f"({part})")
        return "/".join(parts)

    # Innermost fractions first; repeat so nested fractions are resolved too.
    frac_pattern = re.compile(r'\\[dt]?frac\s*\{([^{}]*)\}\s*\{([^{}]*)\}')
    while True:
        answer, replaced = frac_pattern.subn(_frac_to_ratio, answer)
        if not replaced:
            break

    answer = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', answer)
    answer = re.sub(r'[\\{}\s]', '', answer)
    return answer


def check_answer(generated, ground_truth):
    """Check if generated answer matches ground truth. Returns (is_correct, detail).

    Both inputs are normalized with ``normalize_answer`` and then compared, in order:
      1. exact string equality;
      2. numeric comparison, when both start with a number (or ``a/b`` fraction)
         and their trailing unit text is compatible: relative error below
         ``NUMERICAL_TOLERANCE`` (absolute 1e-6 when the truth is zero). This
         verdict is final, so "1" can never match "10" by substring;
      3. option-letter comparison, when both are pure lists of a-d letters:
         the sets must be equal. This verdict is final, so "a" never matches "a,b";
      4. token-bounded containment of one normalized answer inside the other
         (e.g. "5" inside "x=5"), never inside a longer number or word.

    *detail* is a short human-readable reason for the verdict.
    """
    gen = normalize_answer(generated)
    gt = normalize_answer(ground_truth)

    if not gen:
        return False, "Could not extract answer"
    if not gt:
        return False, "No ground truth"

    no_match = (False, f"No match: '{generated}' vs '{ground_truth}'")

    # Exact match
    if gen == gt:
        return True, "Exact match"

    number = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:e[-+]?\d+)?'

    def _split_number(text):
        """Return (value, trailing_text) when text starts with a number or a/b fraction, else None."""
        m = re.match(rf'({number})(?:/({number}))?', text)
        if m is None:
            return None
        value = float(m.group(1))
        if m.group(2) is not None:
            denominator = float(m.group(2))
            if denominator == 0:
                return None
            value /= denominator
        return value, text[m.end():]

    unit_like = re.compile(r'[a-zµ°%/][a-z0-9µ°%/^*.\-]*')

    def _units_compatible(a, b):
        """Trailing text is compatible when identical, or one side is empty and the other looks like a unit."""
        if a == b:
            return True
        if a and b:
            return False
        return unit_like.fullmatch(a or b) is not None

    # Numerical comparison (decisive when both sides are numeric)
    gen_parsed, gt_parsed = _split_number(gen), _split_number(gt)
    if gen_parsed is not None and gt_parsed is not None:
        (gen_num, gen_rest), (gt_num, gt_rest) = gen_parsed, gt_parsed
        if _units_compatible(gen_rest, gt_rest):
            if gt_num != 0:
                rel_error = abs(gen_num - gt_num) / abs(gt_num)
                if rel_error < NUMERICAL_TOLERANCE:
                    return True, f"Numerical match (error: {rel_error:.4f})"
            elif abs(gen_num - gt_num) < 1e-6:
                return True, "Numerical match (near zero)"
            return no_match

    def _option_letters(text):
        """Return the set of option letters when text is purely a list of a-d options, else None."""
        bare = re.sub(r'[()\[\]]', '', text).replace("and", "")
        if re.fullmatch(r'[a-d](?:[,;&/]?[a-d])*', bare):
            return set(re.findall(r'[a-d]', bare))
        return None

    # Multi-answer MCQ (decisive when both sides are option lists)
    gen_letters, gt_letters = _option_letters(gen), _option_letters(gt)
    if gen_letters is not None and gt_letters is not None:
        if gen_letters == gt_letters:
            return True, "MCQ match"
        return no_match

    def _contains_token(needle, haystack):
        """True when needle occurs in haystack without being part of a longer number or word."""
        first, last = needle[0], needle[-1]
        if first.isdigit() or first == ".":
            lead = r'(?<![0-9.\-])'  # not the tail of a longer / negative number
        elif first.isalpha():
            lead = r'(?<![a-z])'
        else:
            lead = ''
        if last.isdigit():
            trail = r'(?!\.?[0-9])'  # not the head of a longer number or decimal
        elif last.isalpha():
            trail = r'(?![a-z])'
        else:
            trail = ''
        return re.search(lead + re.escape(needle) + trail, haystack) is not None

    # Containment
    if _contains_token(gt, gen) or _contains_token(gen, gt):
        return True, "Partial match"

    return no_match

print("Judge functions defined.")

In [None]:
# Judge all rollouts
# Each rollout is annotated with the extracted answer, the verdict and the reason.
judged_rollouts = []

for rollout in tqdm(all_rollouts, desc="Judging rollouts"):
    answer = extract_answer(rollout["model_output"])
    is_correct, detail = check_answer(answer, rollout["ground_truth"])

    judged_rollouts.append({
        **rollout,
        "answer_extracted": answer,
        "is_correct": is_correct,
        "detail": detail,
    })

total = len(judged_rollouts)
correct_count = sum(1 for r in judged_rollouts if r["is_correct"])
no_answer_count = sum(1 for r in judged_rollouts if not r["answer_extracted"])

print(f"\nResults:")
print(f"  Total rollouts: {total}")
if total:
    print(f"  Correct: {correct_count} ({correct_count/total*100:.1f}%)")
    print(f"  Incorrect: {total - correct_count} ({(total-correct_count)/total*100:.1f}%)")
else:
    # Avoid dividing by zero when no rollouts were generated.
    print("  No rollouts to judge - run the rollout generation cell first.")
print(f"  No answer extracted: {no_answer_count}")

In [None]:
# Optional: ask Claude to write a reference solution for questions the model got wrong.
# This costs API credits but supplies a "chosen" response when no rollout was correct.
# Skip this cell if you don't have an Anthropic API key.

JUDGE_MODEL = "claude-sonnet-4-5-20250929"  # Use Sonnet for cost efficiency
MAX_JUDGE_CALLS = 200                       # Cap on API calls

# Defined up front so it exists on every path, including a failed API key
# check; the pair-building cell reads it unconditionally.
feedback_map = {}  # question_key -> LLM-written solution

if ANTHROPIC_API_KEY:
    import anthropic
    client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

    # Test the API key first
    try:
        test = client.messages.create(
            model=JUDGE_MODEL,
            max_tokens=50,
            messages=[{"role": "user", "content": "Say 'ok'"}],
        )
        print(f"API key works! Response: {test.content[0].text}")
    except Exception as e:
        print(f"API key error: {e}")
        print("Continuing without LLM judge.")
        ANTHROPIC_API_KEY = ""
        client = None

    if client:
        # Only questions with no correct rollout need a reference solution, and
        # each question needs at most one: pair building prefers a correct
        # rollout and looks feedback up by question key.
        solved_keys = {r["question"][:200] for r in judged_rollouts if r["is_correct"]}
        unsolved = {}
        for r in judged_rollouts:
            key = r["question"][:200]
            if not r["is_correct"] and key not in solved_keys:
                unsolved.setdefault(key, r)

        candidates = list(unsolved.values())
        sample_size = min(MAX_JUDGE_CALLS, len(candidates))
        sampled = random.sample(candidates, sample_size)

        print(f"\nGetting LLM feedback for {sample_size} unsolved questions...")

        for r in tqdm(sampled, desc="LLM Judge"):
            try:
                resp = client.messages.create(
                    model=JUDGE_MODEL,
                    max_tokens=1024,
                    system=(
                        "You are an expert IIT JEE examiner. Provide a correct, "
                        "step-by-step solution to the question. Use LaTeX notation."
                    ),
                    messages=[{"role": "user", "content": (
                        f"Question: {r['question']}\n"
                        f"Correct answer: {r['ground_truth']}\n\n"
                        f"Provide a clear step-by-step solution arriving at the correct answer."
                    )}],
                )
                feedback_map[r["question"][:200]] = resp.content[0].text
            except Exception as e:
                # Best effort: one failed call should not abort the whole batch.
                print(f"API error: {e}")
                continue

        print(f"Got {len(feedback_map)} LLM feedback responses")
else:
    print("No Anthropic API key — using gold solutions from training data as chosen responses.")

## 6. Build DPO Preference Dataset

For each question, we create (chosen, rejected) pairs:
- **chosen**: A correct solution, taken in priority order from a correct rollout, LLM feedback, or the gold training data
- **rejected**: An incorrect solution (from rollout)

In [None]:
from collections import defaultdict

# Group rollouts by question (keyed by the first 200 characters, matching the
# keys used by feedback_map and gold_solutions)
question_rollouts = defaultdict(lambda: {"correct": [], "incorrect": []})
for r in judged_rollouts:
    bucket = "correct" if r["is_correct"] else "incorrect"
    question_rollouts[r["question"][:200]][bucket].append(r)

# Build DPO pairs
dpo_data = {"prompt": [], "chosen": [], "rejected": []}

both_available = 0
used_gold = 0
used_feedback = 0
skipped = 0

for key, groups in question_rollouts.items():
    correct_list = groups["correct"]
    incorrect_list = groups["incorrect"]

    if not incorrect_list:
        # All correct — skip (model already handles this)
        skipped += 1
        continue

    # Get the rejected response (pick one incorrect rollout)
    rejected_response = incorrect_list[0]["model_output"]
    question = incorrect_list[0]["question"]

    # Get the chosen response (priority: correct rollout > LLM feedback > gold solution)
    if correct_list:
        # Best case: model got it right in another rollout
        chosen_response = correct_list[0]["model_output"]
        both_available += 1
    elif key in feedback_map:
        # LLM judge provided a correct solution
        chosen_response = feedback_map[key]
        used_feedback += 1
    elif key in gold_solutions:
        # Fall back to training data solution
        chosen_response = gold_solutions[key]
        used_gold += 1
    else:
        skipped += 1
        continue

    # Render the prompt with the same chat template (and generation prompt) used
    # to sample the rollouts, so DPO scores responses in the context that
    # produced them rather than after a plain-text concatenation.
    prompt_text = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": SYSTEM_MSG},
            {"role": "user", "content": question},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )
    dpo_data["prompt"].append(prompt_text)
    dpo_data["chosen"].append(chosen_response)
    dpo_data["rejected"].append(rejected_response)

print(f"DPO dataset built:")
print(f"  Total pairs: {len(dpo_data['prompt'])}")
print(f"  From correct rollouts: {both_available}")
print(f"  From LLM feedback: {used_feedback}")
print(f"  From gold solutions: {used_gold}")
print(f"  Skipped (all correct or no chosen): {skipped}")

In [None]:
from datasets import Dataset

# Create HuggingFace Dataset
dpo_dataset = Dataset.from_dict(dpo_data)
print(f"DPO dataset: {dpo_dataset}")

# Fail here with a clear message instead of an IndexError on the sample below.
if len(dpo_dataset) == 0:
    raise ValueError(
        "No DPO preference pairs were built - every rollout was judged correct or "
        "no chosen response was available. Increase NUM_PROMPTS/NUM_ROLLOUTS or "
        "check the judge output before training."
    )

print(f"\nSample:")
sample = dpo_dataset[0]
print(f"  Prompt: {sample['prompt'][:150]}...")
print(f"  Chosen: {sample['chosen'][:150]}...")
print(f"  Rejected: {sample['rejected'][:150]}...")

# Save checkpoint
dpo_dataset.save_to_disk("./dpo_dataset")
print("\nSaved DPO dataset to ./dpo_dataset")

## 7. DPO Training

Train the model using TRL's DPOTrainer with LoRA for memory efficiency.

In [None]:
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer

# Reload model in 4-bit for training:
# NF4 quantization with double quantization, computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

print("Loading model for DPO training...")
model = AutoModelForCausalLM.from_pretrained(
    "./sft-model",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained("./sft-model")

# Reuse EOS as the pad token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id

# LoRA configuration: adapters on the attention and MLP projection layers.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=lora_target_modules,
    lora_dropout=0.05,
    lora_alpha=LORA_ALPHA,
    r=LORA_R,
)

allocated_gb = torch.cuda.memory_allocated() / 1e9
print(f"Model loaded! Memory: {allocated_gb:.1f} GB")

In [None]:
# DPO training configuration
training_args = DPOConfig(
    output_dir="./dpo-output",
    report_to="none",  # Change to "wandb" if you want W&B logging
    remove_unused_columns=False,
    # Optimisation schedule
    num_train_epochs=DPO_EPOCHS,
    per_device_train_batch_size=DPO_BATCH_SIZE,
    gradient_accumulation_steps=DPO_GRAD_ACCUM,
    learning_rate=DPO_LR,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    optim="paged_adamw_8bit",
    # DPO objective and sequence limits
    beta=DPO_BETA,
    max_length=1024,
    max_prompt_length=512,
    # Memory / precision
    bf16=True,
    gradient_checkpointing=True,
    # Logging and checkpoints
    logging_steps=10,
    save_steps=100,
    save_total_limit=2,
)

# Initialize DPO trainer
# No explicit reference model is passed; only the policy model and the LoRA config are given.
trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dpo_dataset,
    processing_class=tokenizer,
    peft_config=peft_config,
)

effective_batch = DPO_BATCH_SIZE * DPO_GRAD_ACCUM
approx_steps = len(dpo_dataset) * DPO_EPOCHS // effective_batch

print("DPO Trainer initialized!")
print(f"  Dataset size: {len(dpo_dataset)}")
print(f"  Epochs: {DPO_EPOCHS}")
print(f"  Effective batch size: {effective_batch}")
print(f"  Total steps: ~{approx_steps}")

In [None]:
# Train!
separator = "=" * 60  # banner printed around the trainer's own log output
print("Starting DPO training...")
print(separator)
trainer.train()
print(separator)
print("DPO training complete!")

In [None]:
# Save the LoRA adapters together with the tokenizer files in one directory
adapter_dir = "./dpo-adapters"
trainer.save_model(adapter_dir)
tokenizer.save_pretrained(adapter_dir)
print("LoRA adapters saved to ./dpo-adapters")

## 8. Merge & Export

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import gc

# Free memory
# Collect before emptying the cache: tensors held only by reference cycles are
# released by the collector, and only then can their blocks be returned.
del model, trainer
gc.collect()
torch.cuda.empty_cache()

# Load the base model unquantized (bfloat16) for merging
print("Loading base model for merging...")
base_model = AutoModelForCausalLM.from_pretrained(
    "./sft-model",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("./sft-model")

# Load and merge LoRA
print("Merging LoRA adapters...")
model = PeftModel.from_pretrained(base_model, "./dpo-adapters")
merged_model = model.merge_and_unload()

# Save merged model
output_dir = f"./{OUTPUT_MODEL_NAME}"
print(f"Saving merged model to {output_dir}...")
merged_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print("Merged model saved!")

## 9. Quick Evaluation

In [None]:
# Quick test on a few eval prompts
# Collect before emptying the cache so memory held by reference cycles is returned.
del base_model, model, merged_model
gc.collect()
torch.cuda.empty_cache()

print("Loading SDPO model for evaluation...")
eval_model = AutoModelForCausalLM.from_pretrained(
    output_dir,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
eval_tokenizer = AutoTokenizer.from_pretrained(output_dir)

if eval_tokenizer.pad_token is None:
    eval_tokenizer.pad_token = eval_tokenizer.eos_token

# Evaluate on a sample of eval prompts
eval_sample = random.sample(eval_prompts, min(50, len(eval_prompts)))
correct = 0

for prompt_data in tqdm(eval_sample, desc="Evaluating"):
    messages = [
        {"role": "system", "content": SYSTEM_MSG},
        {"role": "user", "content": prompt_data["prompt"]},
    ]
    input_text = eval_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = eval_tokenizer(input_text, return_tensors="pt").to(eval_model.device)

    with torch.no_grad():
        outputs = eval_model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.1,  # Low temp for evaluation
            do_sample=True,
            pad_token_id=eval_tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens (everything after the prompt).
    generated = eval_tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )

    answer = extract_answer(generated)
    is_correct, _ = check_answer(answer, prompt_data["ground_truth"])
    if is_correct:
        correct += 1

print(f"\nEvaluation Results (SDPO model):")
if eval_sample:
    print(f"  Accuracy: {correct}/{len(eval_sample)} = {correct/len(eval_sample)*100:.1f}%")
else:
    # Avoid dividing by zero when there are no eval prompts.
    print("  No eval prompts available - evaluation skipped.")

del eval_model
gc.collect()
torch.cuda.empty_cache()

## 10. Upload to HuggingFace

In [None]:
from huggingface_hub import HfApi

# The output repo uses the namespace of the SFT model ID.
HF_USERNAME, _, _ = SFT_MODEL_ID.partition("/")  # e.g. "vipsehgal"
REPO_ID = "/".join([HF_USERNAME, OUTPUT_MODEL_NAME])

api = HfApi(token=HF_TOKEN)

# Create repo if it doesn't exist; a failure here is reported but does not stop the cell.
try:
    api.create_repo(REPO_ID, private=True, exist_ok=True)
except Exception as e:
    print(f"Repo creation note: {e}")

print(f"Uploading to {REPO_ID}...")
upload_options = {
    "folder_path": output_dir,
    "repo_id": REPO_ID,
    "repo_type": "model",
}
api.upload_folder(**upload_options)
print(f"\nUpload complete!")
print(f"Model: https://huggingface.co/{REPO_ID}")
print(f"\nNext: Download this model on your Mac for Phase 4 (local inference)")

## Done!

Your SDPO-trained model is now on HuggingFace. 

**Phase 4 (on your Mac):**
```bash
# Convert to MLX 4-bit for local inference
python -m mlx_lm.convert \
    --hf-path vipsehgal/qwen3-8b-jee-sdpo \
    -q --q-bits 4 --q-group-size 64 \
    --mlx-path ./qwen3-8b-jee-sdpo-mlx-4bit

# Run inference
mlx_lm.generate \
    --model ./qwen3-8b-jee-sdpo-mlx-4bit \
    --prompt "Solve: A block of mass 5 kg is placed on an incline..."

# Or serve as API
mlx_lm.server --model ./qwen3-8b-jee-sdpo-mlx-4bit --port 8080
```