# PPO (Proximal Policy Optimization) Training Demo

This notebook demonstrates how to train language models using PPO for RLHF (Reinforcement Learning from Human Feedback).

**Key Features:**
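- PPO fine-tuning with TRL's `PPOTrainer`
- Simulated feedback: a simple heuristic reward function stands in for human feedback
- Reward model: a placeholder only — no reward model is trained here; swap a trained one in for `compute_reward`
- Policy optimization with a KL penalty against a reference model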

In [None]:
# Install TRL and dependencies
!pip install -q trl transformers torch accelerate peft
!pip install -q datasets wandb

import torch
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import Dataset
import numpy as np

In [None]:
# PPO Configuration (legacy TRL PPOConfig API).
# Legacy TRL requires batch_size to be a multiple of
# mini_batch_size * gradient_accumulation_steps (here 16 = 4 * 4), and
# ppo_trainer.step() must receive exactly batch_size samples per call.
config = PPOConfig(
    model_name="microsoft/DialoGPT-medium",
    learning_rate=1.41e-5,
    batch_size=16,
    mini_batch_size=4,
    gradient_accumulation_steps=4,
    optimize_device_cache=True,  # replaces the deprecated `optimize_cuda_cache`
    early_stopping=False,
    target_kl=0.1,  # only consulted when early_stopping=True
    ppo_epochs=4,
    seed=0,
    init_kl_coef=0.2,
    adap_kl_ctrl=True,
)

print("PPO Configuration:")
print(f"Model: {config.model_name}")
print(f"Learning Rate: {config.learning_rate}")
print(f"Batch Size: {config.batch_size}")

In [None]:
# Load the policy, the reference model and the tokenizer.
# Legacy TRL's PPOTrainer accepts only a value-head wrapper (or None) as
# ref_model; a plain AutoModelForCausalLM is rejected with a ValueError.
# The reference model is what the KL penalty (init_kl_coef / adap_kl_ctrl in
# the config) is measured against.
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

# If the tokenizer defines no pad token, reuse EOS so padding is possible.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Model loaded: {model.__class__.__name__}")
print(f"Tokenizer vocab size: {len(tokenizer)}")

In [None]:
# Heuristic stand-in for a learned reward model / human preference signal.
def compute_reward(texts):
    """
    Score each text with a toy length-plus-diversity heuristic.

    This is a placeholder for a trained reward model or real human feedback.
    For a text with N whitespace-separated words and U distinct words
    (compared case-insensitively), the score is

        min(N / 20, 1) + U / max(N, 1)

    so it grows with length up to 20 words and with lexical variety.

    Args:
        texts: iterable of strings to score.

    Returns:
        list containing one single-element list ``[score]`` per input text,
        in input order.
    """
    def _score(text):
        n_words = len(text.split())
        n_unique = len(set(text.lower().split()))
        # Length term saturates at 1.0 once the text reaches 20 words;
        # max(..., 1) avoids dividing by zero on empty text.
        return min(n_words / 20.0, 1.0) + n_unique / max(n_words, 1)

    # Each score is wrapped in its own one-element list.
    return [[_score(text)] for text in texts]

# Sanity-check the heuristic: the longer, lexically richer sentence should
# out-score the two-word greeting.
test_texts = ["Hello world", "This is a longer and more diverse sentence with various words"]
test_rewards = compute_reward(test_texts)
print(f"Test rewards: {test_rewards}")
assert test_rewards[1][0] > test_rewards[0][0], "reward heuristic should prefer the richer sentence"

In [None]:
# Create sample training data
prompts = [
    "What is artificial intelligence?",
    "Explain machine learning",
    "How does deep learning work?",
    "What are neural networks?",
    "Describe natural language processing"
]

# Tokenize prompts.
# DialoGPT's model card formats each conversational turn as text followed by
# the EOS token, so EOS is appended to mark the prompt as a finished user turn.
# squeeze(0) drops only the batch dimension, so every query stays 1-D even if
# it tokenizes to a single token (a bare squeeze() would make it 0-d).
tokenized_prompts = [
    tokenizer.encode(prompt + tokenizer.eos_token, return_tensors="pt").squeeze(0)
    for prompt in prompts
]
assert all(t.dim() == 1 for t in tokenized_prompts), "each query must be a 1-D token tensor"

print(f"Number of prompts: {len(prompts)}")
print(f"Sample prompt: {prompts[0]}")

In [None]:
# Initialize PPO trainer (legacy TRL API: PPOTrainer(config, model, ref_model, tokenizer))
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)

NUM_EPOCHS = 2  # small number for demo

# Sampling settings for rollouts. max_new_tokens bounds only the generated part,
# independent of prompt length.
generation_kwargs = dict(
    max_new_tokens=50,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    temperature=0.7,
    pad_token_id=tokenizer.eos_token_id,
)

n_prompts = len(tokenized_prompts)
if n_prompts == 0:
    raise ValueError("No prompts to train on")

# ppo_trainer.step() requires exactly config.batch_size samples per call, so the
# prompt list is recycled round-robin to fill every batch (here 5 prompts are
# repeated to make one batch of 16 per epoch).
batches_per_epoch = -(-n_prompts // config.batch_size)  # ceil division
device = ppo_trainer.accelerator.device

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}")

    for batch_num in range(batches_per_epoch):
        start = batch_num * config.batch_size
        batch_prompts = [
            tokenized_prompts[(start + offset) % n_prompts].to(device)
            for offset in range(config.batch_size)
        ]

        # Generate one response per prompt. The trainer takes a 1-D query tensor
        # (it adds the batch dimension itself), and return_prompt=False strips
        # the prompt so only response tokens are scored and passed to step().
        response_tensors = []
        for prompt in batch_prompts:
            response = ppo_trainer.generate(prompt, return_prompt=False, **generation_kwargs)
            # Flatten to 1-D; reshape(-1) never produces a 0-d tensor, unlike squeeze().
            response_tensors.append(response.reshape(-1))

        # Decode responses and score them with the heuristic reward
        responses = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]
        rewards = [torch.tensor(r) for r in compute_reward(responses)]

        # PPO step: queries, response-only tokens, one reward per sample
        stats = ppo_trainer.step(batch_prompts, response_tensors, rewards)

        mean_reward = np.mean([r.item() for r in rewards])
        print(f"Batch {batch_num + 1}: Mean reward = {mean_reward:.3f}")

print("\nTraining completed!")