# Experiment 2:
# Direct Preference Optimization (DPO)

### Creating Preference Pairs from GSM8K

In [None]:
def create_preference_pairs(examples, model, tokenizer, *, max_examples=500):
    """Create preference pairs for DPO training.

    Each input row is turned into one ``{"prompt", "chosen", "rejected"}``
    record:

    * ``prompt``   - a fixed instruction wrapped around ``example['question']``.
    * ``chosen``   - ``example['answer']`` verbatim.
    * ``rejected`` - the same answer cut off before the first ``'####'``,
      stripped of surrounding whitespace, with the ``'<<'`` / ``'>>'``
      markers removed (the text between the markers is kept).

    Args:
        examples: Iterable of mapping-like rows providing ``'question'`` and
            ``'answer'`` string entries. Any iterable works (list, generator,
            streaming dataset); it does not have to support slicing.
        model: Unused here; accepted so existing call sites keep working.
        tokenizer: Unused here; accepted so existing call sites keep working.
        max_examples: Maximum number of rows to consume from ``examples``.
            Defaults to 500 (the previously hard-coded subset size). Pass
            ``None`` to use every row.

    Returns:
        A list of dicts with the string keys ``"prompt"``, ``"chosen"`` and
        ``"rejected"``, in input order.

    Raises:
        ValueError: If ``max_examples`` is negative (raised by ``islice``).
    """
    from itertools import islice

    # islice instead of examples[:N]: it takes the first N *rows* of any
    # iterable, whereas slicing requires a sequence and, on some dataset
    # containers, returns a column mapping rather than a list of rows.
    if max_examples is None:
        subset = examples
    else:
        subset = islice(examples, max_examples)

    preference_data = []

    for example in subset:
        prompt = f"Solve this math problem:\n{example['question']}\n\nProvide step-by-step reasoning."

        # The "chosen" response is the full reference answer.
        chosen = example['answer']

        # The "rejected" response is a degraded copy of the reference answer:
        # the part from '####' onward is dropped and the '<<' / '>>' markers
        # are stripped. In practice, you'd generate these or use model outputs.
        rejected = example['answer'].split('####')[0].strip()
        rejected = rejected.replace('<<', '').replace('>>', '')

        preference_data.append({
            "prompt": prompt,
            "chosen": chosen,
            "rejected": rejected
        })

    return preference_data

# Build the DPO dataset from the training split.
preference_data = create_preference_pairs(train_data, model, tokenizer)

# Summarise what was built, then preview the first pair (each field is
# truncated to its first 100 characters).
print(f"Created {len(preference_data)} preference pairs")
print("\nExample preference pair:")
first_pair = preference_data[0]
print(f"Prompt: {first_pair['prompt'][:100]}...")
print(f"\nChosen: {first_pair['chosen'][:100]}...")
print(f"\nRejected: {first_pair['rejected'][:100]}...")


### DPO Training and Results

In [None]:
from tunix.trainers import DPOTrainer

# Start from the SFT model produced in the previous step
model = FlaxAutoModelForCausalLM.from_pretrained("./gemma-2b-gsm8k-lora")

# DPO hyperparameters, gathered in one place and handed to the trainer
dpo_settings = {
    "beta": 0.1,  # KL penalty coefficient
    "learning_rate": 5e-7,
    "num_epochs": 2,
    "batch_size": 2,
    "max_length": 512,
}
dpo_trainer = DPOTrainer(model=model, tokenizer=tokenizer, **dpo_settings)

# Run preference optimization on the pairs built earlier
print("Starting DPO training...")
dpo_metrics = dpo_trainer.train(preference_data)

# Report each metric as: label, key in dpo_metrics, format spec, suffix
print("\nDPO Training Metrics:")
metric_rows = (
    ("Reward Margin", "reward_margin", ".3f", ""),
    ("Preference Accuracy", "accuracy", ".2%", ""),
    ("KL Divergence", "kl_div", ".4f", ""),
    ("Training time", "training_time_hours", ".1f", " hours"),
)
for label, key, spec, suffix in metric_rows:
    print(f"  {label}: {format(dpo_metrics[key], spec)}{suffix}")

# Persist the aligned model
dpo_trainer.save_checkpoint("./gemma-2b-gsm8k-dpo")