# Reward Model Training and PPO

## RLHF: Reward Model + PPO Training

This notebook covers:
- Creating preference pairs from PersonaChat
- Training reward model for persona consistency
- Setting up PPO trainer
- Training with RLHF using LoRA
- Monitoring reward signals and KL divergence
- Evaluating persona consistency improvements

In [None]:
# Install required packages
!pip install -q transformers datasets peft trl accelerate
!pip install -q rouge-score sacrebleu evaluate
!pip install -q matplotlib seaborn pandas numpy

In [None]:
import json
import os
import random
import re
import sys
import time
from typing import List, Dict

sys.path.append('../')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from datasets import load_dataset, Dataset
from peft import LoraConfig, get_peft_model, PeftModel
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model

# Set style
# Plot defaults applied to every figure drawn later in this notebook
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

## 1. Environment Setup

In [None]:
# Report the PyTorch build and every visible GPU, then pick the default device
print(f"PyTorch version: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
print(f"GPU count: {torch.cuda.device_count()}")

if cuda_ok:
    for gpu_idx in range(torch.cuda.device_count()):
        gpu_name = torch.cuda.get_device_name(gpu_idx)
        gpu_mem_gb = torch.cuda.get_device_properties(gpu_idx).total_memory / 1e9
        print(f"\nGPU {gpu_idx}: {gpu_name}")
        print(f"  Memory: {gpu_mem_gb:.1f} GB")

device = torch.device('cuda' if cuda_ok else 'cpu')
print(f"\nUsing device: {device}")

## 2. Configuration

In [None]:
# Configuration: every tunable value used by the rest of the notebook
config = {
    # Paths
    'sft_model_path': '../models/sft_lora/final',
    'reward_model_output': '../models/reward_model',
    'ppo_output_dir': '../models/ppo_lora',
    
    # Base model
    'base_model_name': 'gpt2-medium',
    
    # Reward model training
    'reward_epochs': 3,
    'reward_batch_size': 4,
    'reward_learning_rate': 5e-5,
    
    # PPO configuration
    'ppo_epochs': 4,
    'ppo_batch_size': 4,
    'ppo_learning_rate': 1.5e-5,
    'ppo_steps': 2000,  # number of prompts consumed; PPO updates ~= ppo_steps / ppo_batch_size
    'init_kl_coef': 0.2,
    'target_kl': 6.0,
    'max_new_tokens': 50,
    
    # Generation
    'temperature': 0.9,
    'top_p': 0.9,

    # Reproducibility
    'seed': 42,
}

# Seed the Python, NumPy and PyTorch RNGs so sampling-based steps
# (generation, LoRA init, data order) are repeatable across runs.
random.seed(config['seed'])
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])

print("Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")

## 3. Load Dataset and Create Preference Pairs

In [None]:
# Helper functions for flexible dataset field handling
def get_persona(example):
    """Return the persona of ``example`` as a list of trait strings.

    The first field from the candidate list below that is present and non-empty
    is used. A list value is returned unchanged. A string value is split on
    newlines into one trait per non-blank line, mirroring how
    ``get_conversation`` splits string conversations, so each trait can be
    scored separately. If the string has no non-blank lines, the next candidate
    field is tried. Any other non-empty value is wrapped in a one-element list.

    Args:
        example: mapping-like dataset row.

    Returns:
        List of persona traits, or ``[]`` if no candidate field has content.
    """
    # Candidate field names covering several dataset variants (underscore and
    # space-separated spellings); the first non-empty match wins.
    candidate_fields = ['user_1_persona', 'user_2_persona', 'personality', 'persona', 'personas',
                        'user_persona', 'user 1 personas', 'user 2 personas']
    for field in candidate_fields:
        if field in example and example[field]:
            value = example[field]
            if isinstance(value, list):
                return value
            if isinstance(value, str):
                traits = [line.strip() for line in value.split('\n') if line.strip()]
                if traits:
                    return traits
                continue  # whitespace-only string: try the next field
            return [value]
    return []

def get_conversation(example):
    """Return the conversation turns of ``example`` as a list.

    Candidate fields are checked in order of likelihood. The first field that is
    present and non-empty decides the result:
    a list is returned as-is, and a string is split into stripped, non-blank lines.
    A field holding any other kind of value is ignored and the search continues.

    Args:
        example: mapping-like dataset row.

    Returns:
        List of turns, or ``[]`` if nothing usable is found.
    """
    candidate_fields = ('utterances', 'history', 'conversation', 'dialogue', 'messages',
                        'Best Generated Conversation')
    for name in candidate_fields:
        if name not in example:
            continue
        raw = example[name]
        if not raw:
            continue
        if isinstance(raw, list):
            return raw
        if isinstance(raw, str):
            return [piece.strip() for piece in raw.split('\n') if piece.strip()]
    return []

In [None]:
# Download the Synthetic-Persona-Chat dataset and report split sizes
print("Loading PersonaChat dataset...")
dataset = load_dataset("google/Synthetic-Persona-Chat")

for split_label, split_key in (("Train", "train"), ("Validation", "validation")):
    print(f"{split_label}: {len(dataset[split_key])} examples")

In [None]:
def create_preference_pairs(examples, num_pairs=1000, max_context_turns=4):
    """Create (prompt, chosen, rejected) preference pairs for reward-model training.

    The first ``num_pairs`` examples are scanned. Examples without a persona or
    with fewer than two conversation turns are skipped, so fewer than
    ``num_pairs`` pairs may be returned.

    Turns at even indices are labelled "User" and turns at odd indices
    "Assistant". The prompt holds the opening turns of the conversation (at most
    ``max_context_turns``, trimmed to an odd count so it ends on a User turn)
    followed by ``"Assistant:"``. The chosen response is the turn that directly
    follows that context.

    The rejected response is a fixed generic reply, so the reward model learns
    "reference reply beats generic filler" rather than a true ranking between
    candidate responses.

    Args:
        examples: indexable dataset split whose rows are understood by
            ``get_persona`` / ``get_conversation``.
        num_pairs: maximum number of examples to scan.
        max_context_turns: upper bound on the number of context turns (at least
            one turn is always kept).

    Returns:
        List of dicts with keys 'prompt', 'chosen', 'rejected', 'persona'.
    """
    pairs = []

    for i in tqdm(range(min(num_pairs, len(examples)))):
        ex = examples[i]
        persona = get_persona(ex)
        history = get_conversation(ex)

        # Skip if missing data
        if not persona or len(history) < 2:
            continue

        # Context = turns [0, n_ctx), response = turn n_ctx. n_ctx must be odd so
        # that the response is an "Assistant" turn and the prompt never ends with
        # two consecutive Assistant turns.
        n_ctx = min(max_context_turns, len(history) - 1)
        if n_ctx % 2 == 0:
            n_ctx -= 1
        n_ctx = max(n_ctx, 1)  # len(history) >= 2, so turn 1 always exists

        persona_text = "Persona: " + " ".join(persona)
        context = "\n".join(
            f"User: {history[j]}" if j % 2 == 0 else f"Assistant: {history[j]}"
            for j in range(n_ctx)
        )
        prompt = f"{persona_text}\n\n{context}\nAssistant:"

        # The reference reply that directly answers this context. The
        # conversation's final turn would generally answer a later context.
        chosen = history[n_ctx]

        # Generic, persona-free reply as the negative example.
        # In practice you would sample several responses and rank them.
        rejected = "That's interesting. Tell me more."

        pairs.append({
            'prompt': prompt,
            'chosen': chosen,
            'rejected': rejected,
            'persona': persona
        })

    return pairs

# Create preference pairs
print("\nCreating preference pairs for reward model...")
train_pairs = create_preference_pairs(dataset['train'], num_pairs=2000)
val_pairs = create_preference_pairs(dataset['validation'], num_pairs=200)

print(f"\nCreated {len(train_pairs)} training pairs")
print(f"Created {len(val_pairs)} validation pairs")

# Fail fast with a clear message if every example was skipped, for example
# because the dataset's field names are not among those that
# get_persona/get_conversation look for. Without this, the code below fails
# with an IndexError and the Trainer later gets empty datasets.
assert train_pairs, "No training preference pairs were created; check get_persona/get_conversation field names"
assert val_pairs, "No validation preference pairs were created; check get_persona/get_conversation field names"

# Show example
print("\nExample preference pair:")
print(f"Prompt: {train_pairs[0]['prompt'][:200]}...")
print(f"Chosen: {train_pairs[0]['chosen']}")
print(f"Rejected: {train_pairs[0]['rejected']}")

## 4. Train Reward Model

In [None]:
# Load tokenizer. GPT-2 ships without a pad token, so reuse EOS for padding.
tokenizer = AutoTokenizer.from_pretrained(config['base_model_name'])
tokenizer.pad_token = tokenizer.eos_token

# Load reward model (sequence classification with a single scalar output).
# Weights stay in fp32: the Trainer's fp16 mixed precision cannot unscale
# gradients of parameters that are themselves fp16, and the trainable
# classification head copy that PEFT creates would inherit an fp16 dtype.
print("Loading reward model...")
reward_model = AutoModelForSequenceClassification.from_pretrained(
    config['base_model_name'],
    num_labels=1,  # Scalar reward
    device_map='auto'
)
# GPT-2's sequence-classification head pools the last non-padding token and
# refuses batches larger than 1 unless the model config knows the pad token id.
reward_model.config.pad_token_id = tokenizer.pad_token_id

# Apply LoRA to reward model (attention projections only)
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=['c_attn'],
    bias="none",
    task_type="SEQ_CLS"
)
reward_model = get_peft_model(reward_model, lora_config)
reward_model.print_trainable_parameters()

print("Reward model loaded and LoRA applied")

In [None]:
def prepare_reward_data(pairs):
    """Tokenize preference pairs into fixed-length chosen/rejected features.

    Each completion is formed as ``prompt + " " + response``. It is then
    tokenized with the module-level ``tokenizer``, truncated and padded to
    exactly 512 tokens.

    Args:
        pairs: iterable of dicts with 'prompt', 'chosen' and 'rejected' strings.

    Returns:
        List of dicts holding chosen/rejected ``input_ids`` and ``attention_mask`` lists.
    """
    def encode(completion_text):
        return tokenizer(
            completion_text,
            truncation=True,
            max_length=512,
            padding='max_length'
        )

    features = []
    for pair in pairs:
        preferred = encode(pair['prompt'] + " " + pair['chosen'])
        dispreferred = encode(pair['prompt'] + " " + pair['rejected'])
        features.append({
            'chosen_input_ids': preferred['input_ids'],
            'chosen_attention_mask': preferred['attention_mask'],
            'rejected_input_ids': dispreferred['input_ids'],
            'rejected_attention_mask': dispreferred['attention_mask'],
        })
    return features

# Tokenize both preference-pair splits for the reward model
print("Preparing reward model training data...")
reward_train_data, reward_val_data = (
    prepare_reward_data(split_pairs) for split_pairs in (train_pairs, val_pairs)
)

print(f"Reward training data: {len(reward_train_data)} examples")
print(f"Reward validation data: {len(reward_val_data)} examples")

In [None]:
class RewardTrainer(Trainer):
    """Trainer for the reward model using a pairwise ranking loss.

    Batches carry ``chosen_input_ids`` / ``chosen_attention_mask`` and
    ``rejected_input_ids`` / ``rejected_attention_mask`` features. The loss is
    ``-log(sigmoid(r_chosen - r_rejected))``, which pushes the scalar reward of
    the chosen completion above that of the rejected one.
    """

    def _score(self, model, inputs, prefix):
        """Return reward logits for the ``prefix`` ('chosen' or 'rejected') side of a batch."""
        return model(
            input_ids=torch.as_tensor(inputs[f'{prefix}_input_ids'], device=model.device),
            attention_mask=torch.as_tensor(inputs[f'{prefix}_attention_mask'], device=model.device)
        ).logits

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the ranking loss for a batch of preference pairs.

        Extra keyword arguments (e.g. ``num_items_in_batch``, which newer
        transformers releases pass) are accepted and ignored.

        Returns:
            The scalar loss, or ``(loss, {'chosen': ..., 'rejected': ...})``
            when ``return_outputs`` is True.
        """
        chosen_rewards = self._score(model, inputs, 'chosen')
        rejected_rewards = self._score(model, inputs, 'rejected')

        # Ranking loss: chosen should have higher reward than rejected
        loss = -torch.nn.functional.logsigmoid(chosen_rewards - rejected_rewards).mean()

        return (loss, {'chosen': chosen_rewards, 'rejected': rejected_rewards}) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Evaluate a batch with the ranking loss.

        The stock ``prediction_step`` only routes through ``compute_loss`` when
        label columns are present. These batches have none, so it would call
        ``model(**inputs)`` with the chosen_/rejected_ keys and fail.

        Returning the ranking loss here yields ``eval_loss``, which
        ``load_best_model_at_end`` compares by default.
        """
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad(), self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)
        return (loss.detach(), None, None)

# Training arguments
reward_training_args = TrainingArguments(
    output_dir=config['reward_model_output'],
    num_train_epochs=config['reward_epochs'],
    per_device_train_batch_size=config['reward_batch_size'],
    per_device_eval_batch_size=config['reward_batch_size'],
    learning_rate=config['reward_learning_rate'],
    # fp16 mixed precision requires a CUDA device; train in fp32 on CPU
    fp16=torch.cuda.is_available(),
    logging_steps=50,
    eval_steps=200,
    # `eval_strategy` is the current name of the deprecated `evaluation_strategy`
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=200,
    load_best_model_at_end=True,
    # Keep the custom chosen_*/rejected_* features. With the default (True), the
    # Trainer drops every feature not named in the model's forward() signature,
    # which would leave compute_loss with an empty batch.
    remove_unused_columns=False,
    report_to="none"
)

# Create trainer
reward_trainer = RewardTrainer(
    model=reward_model,
    args=reward_training_args,
    train_dataset=reward_train_data,
    eval_dataset=reward_val_data,
)

print("Reward trainer initialized")

In [None]:
# Train the reward model, time it, and save the final adapter weights
print("Training reward model...\n")
reward_start_time = time.time()
reward_result = reward_trainer.train()
reward_training_time = time.time() - reward_start_time

print("\n✅ Reward model training completed!")
print(f"Training time: {reward_training_time / 60:.2f} minutes")
print(f"Final loss: {reward_result.training_loss:.4f}")

reward_final_dir = os.path.join(config['reward_model_output'], 'final')
reward_model.save_pretrained(reward_final_dir)
print(f"Reward model saved to {config['reward_model_output']}/final")

## 5. Setup PPO Training

In [None]:
# Load SFT model for PPO
print("Loading SFT model for PPO training...")

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    config['base_model_name'],
    torch_dtype=torch.float16,
    device_map='auto'
)

# Load SFT LoRA adapters. is_trainable=True is required: PeftModel.from_pretrained
# loads adapters frozen (inference-only) by default, which would leave PPO
# updating only the value head and never the policy.
sft_model = PeftModel.from_pretrained(base_model, config['sft_model_path'], is_trainable=True)
print("SFT model loaded")

# Create model with value head for PPO
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(sft_model)
print("PPO model created with value head")

# Create reference model (frozen copy of the SFT policy) for the KL penalty
ref_model = create_reference_model(ppo_model)
print("Reference model created")

In [None]:
# Map the notebook config onto PPOConfig fields.
# NOTE(review): these field names follow the classic trl PPOConfig API;
# confirm the installed trl version supports them.
ppo_config_kwargs = dict(
    model_name=config['base_model_name'],
    learning_rate=config['ppo_learning_rate'],
    batch_size=config['ppo_batch_size'],
    mini_batch_size=config['ppo_batch_size'],
    ppo_epochs=config['ppo_epochs'],
    init_kl_coef=config['init_kl_coef'],
    target=config['target_kl'],
    horizon=10000,
    gamma=1.0,
    lam=0.95,
    cliprange=0.2,
    cliprange_value=0.2,
    vf_coef=0.1,
)
ppo_config = PPOConfig(**ppo_config_kwargs)

# Wire policy, frozen reference and tokenizer into the PPO trainer
ppo_trainer = PPOTrainer(config=ppo_config, model=ppo_model, ref_model=ref_model, tokenizer=tokenizer)

print("PPO trainer initialized")

## 6. Prepare PPO Prompts

In [ ]:
def prepare_ppo_prompts(dataset, num_prompts=500, max_context_turns=4):
    """Build PPO query prompts together with their personas.

    The first ``num_prompts`` examples are scanned. Examples without a persona
    or with fewer than two conversation turns are skipped.

    Turns at even indices are labelled "User" and turns at odd indices
    "Assistant". The prompt holds the opening turns (at most
    ``max_context_turns``, trimmed to an odd count so it ends on a User turn)
    followed by ``"Assistant:"``, so the model generates the Assistant reply.

    Args:
        dataset: indexable dataset split.
        num_prompts: maximum number of examples to scan.
        max_context_turns: upper bound on context turns (at least one is kept).

    Returns:
        ``(prompts, personas)``: parallel lists of prompt strings and persona lists.
    """
    prompts = []
    personas = []

    for i in range(min(num_prompts, len(dataset))):
        ex = dataset[i]
        persona = get_persona(ex)
        history = get_conversation(ex)

        # Skip if missing data
        if not persona or len(history) < 2:
            continue

        # Use an odd number of context turns so the prompt ends on a User turn
        # and "Assistant:" does not follow another Assistant turn.
        n_ctx = min(max_context_turns, len(history) - 1)
        if n_ctx % 2 == 0:
            n_ctx -= 1
        n_ctx = max(n_ctx, 1)

        persona_text = "Persona: " + " ".join(persona)
        context = "\n".join(
            f"User: {history[j]}" if j % 2 == 0 else f"Assistant: {history[j]}"
            for j in range(n_ctx)
        )
        prompt = f"{persona_text}\n\n{context}\nAssistant:"

        prompts.append(prompt)
        personas.append(persona)

    return prompts, personas

# Build the prompt pool that the PPO loop cycles through
print("Preparing PPO prompts...")
ppo_prompts, ppo_personas = prepare_ppo_prompts(dataset['train'], num_prompts=500)

prompt_preview = ppo_prompts[0][:300]
print(f"Prepared {len(ppo_prompts)} prompts for PPO")
print(f"\nExample prompt:\n{prompt_preview}...")

## 7. Define Reward Function

In [None]:
def compute_reward(prompt, response, persona):
    """Score a generated response with the reward model plus a persona-overlap bonus.

    Args:
        prompt: prompt text the response was generated for.
        response: decoded generated text (without the prompt).
        persona: list of persona trait strings.

    Returns:
        float: the reward model's logit for ``prompt + " " + response``, plus
        the fraction of persona traits that share at least one word longer than
        3 characters (excluding a small stop-word list) with the response.
    """
    # Score the same "prompt + space + response" text format used in training
    full_text = prompt + " " + response

    # Disable dropout (including LoRA dropout) so identical inputs get identical
    # rewards; the mode the model is left in after Trainer.train() is not
    # guaranteed to be eval.
    if reward_model.training:
        reward_model.eval()

    inputs = tokenizer(full_text, return_tensors='pt', truncation=True, max_length=512).to(reward_model.device)

    with torch.no_grad():
        reward_score = reward_model(**inputs).logits[0, 0].item()

    # Rule-based persona consistency. Compare whole words with punctuation
    # stripped, so "hiking." in a trait matches "hiking" in the response while
    # "love" no longer matches inside "glove".
    stop_words = {'i', 'am', 'have', 'like', 'my', 'a', 'the'}
    response_words = set(re.findall(r"[a-z0-9']+", response.lower()))
    persona_matches = 0

    for trait in persona:
        trait_words = set(re.findall(r"[a-z0-9']+", trait.lower())) - stop_words
        if any(len(word) > 3 and word in response_words for word in trait_words):
            persona_matches += 1

    persona_reward = persona_matches / max(len(persona), 1)

    # Combine rewards
    return reward_score + persona_reward

print("Reward function defined")

## 8. PPO Training Loop

In [None]:
# PPO training
print("Starting PPO training...\n")
ppo_start_time = time.time()

# Training metrics (one entry per PPO update)
training_stats = {
    'rewards': [],
    'kl_divs': [],
    'losses': []
}

if not ppo_prompts:
    raise ValueError("No PPO prompts available; check prepare_ppo_prompts output")

ppo_batch_size = config['ppo_batch_size']
policy_device = ppo_model.pretrained_model.device
generation_kwargs = dict(
    max_new_tokens=config['max_new_tokens'],
    do_sample=True,
    temperature=config['temperature'],
    top_p=config['top_p'],
    pad_token_id=tokenizer.eos_token_id
)

# `step` counts prompts consumed; each iteration performs one PPO update.
for step in tqdm(range(0, config['ppo_steps'], ppo_batch_size)):
    # Cycle through the prompt pool with modular indexing, so no batch is skipped
    # at wrap-around and a pool smaller than the batch size still trains.
    batch_indices = [(step + offset) % len(ppo_prompts) for offset in range(ppo_batch_size)]
    batch_prompts = [ppo_prompts[idx] for idx in batch_indices]
    batch_personas = [ppo_personas[idx] for idx in batch_indices]

    # Tokenize prompts as 1-D query tensors
    prompt_tensors = [tokenizer.encode(prompt, return_tensors='pt')[0].to(policy_device)
                      for prompt in batch_prompts]

    # Generate responses
    response_tensors = []
    for prompt_tensor in prompt_tensors:
        # PPOTrainer.generate expects a 1-D query tensor and adds the batch
        # dimension itself. Its output still contains the prompt, so strip it.
        # Index [0] (not squeeze) keeps a 1-token response 1-D.
        generated = ppo_trainer.generate(prompt_tensor, **generation_kwargs)
        response_tensors.append(generated[0][len(prompt_tensor):])

    # Decode responses
    responses = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]

    # Compute rewards
    rewards = [torch.tensor(compute_reward(prompt, response, persona))
               for prompt, response, persona in zip(batch_prompts, responses, batch_personas)]

    # PPO step
    stats = ppo_trainer.step(prompt_tensors, response_tensors, rewards)

    # Track metrics
    if stats:
        training_stats['rewards'].append(torch.mean(torch.stack(rewards)).item())
        if 'objective/kl' in stats:
            training_stats['kl_divs'].append(stats['objective/kl'])
        if 'ppo/loss/total' in stats:
            training_stats['losses'].append(stats['ppo/loss/total'])

    # Log progress
    if step % 100 == 0:
        avg_reward = np.mean(training_stats['rewards'][-10:]) if training_stats['rewards'] else 0
        print(f"Step {step}: Avg Reward = {avg_reward:.3f}")

    # Save checkpoint
    if step > 0 and step % 500 == 0:
        checkpoint_path = os.path.join(config['ppo_output_dir'], f'checkpoint-{step}')
        ppo_model.save_pretrained(checkpoint_path)
        print(f"Checkpoint saved at step {step}")

ppo_training_time = time.time() - ppo_start_time

final_avg_reward = np.mean(training_stats['rewards'][-10:]) if training_stats['rewards'] else float('nan')
print(f"\n✅ PPO training completed!")
print(f"Training time: {ppo_training_time / 60:.2f} minutes")
print(f"Final avg reward: {final_avg_reward:.3f}")

## 9. Save Final Model

In [None]:
# Persist the final policy (adapters + value head) and the tokenizer side by side
final_ppo_path = os.path.join(config['ppo_output_dir'], 'final')
os.makedirs(final_ppo_path, exist_ok=True)

print(f"Saving final PPO model to {final_ppo_path}...")
for artifact in (ppo_model, tokenizer):
    artifact.save_pretrained(final_ppo_path)

print("✅ Final model saved successfully")

## 10. Visualize Training Progress

In [None]:
# Plot training metrics (x-axis = PPO update index)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Rewards over time
if training_stats['rewards']:
    # Label the raw curve so legend() always has an entry, even when the
    # moving average below is skipped.
    axes[0].plot(training_stats['rewards'], alpha=0.6, label='Batch mean reward')
    # Moving average
    window = 20
    if len(training_stats['rewards']) >= window:
        moving_avg = pd.Series(training_stats['rewards']).rolling(window=window).mean()
        axes[0].plot(moving_avg, linewidth=2, label=f'{window}-step MA')
    axes[0].set_xlabel('Step')
    axes[0].set_ylabel('Reward')
    axes[0].set_title('Rewards Over Training')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

# KL divergence against the configured adaptive-KL target
if training_stats['kl_divs']:
    axes[1].plot(training_stats['kl_divs'], color='orange')
    axes[1].axhline(y=config['target_kl'], color='r', linestyle='--', label=f'Target KL={config["target_kl"]}')
    axes[1].set_xlabel('Step')
    axes[1].set_ylabel('KL Divergence')
    axes[1].set_title('KL Divergence Over Training')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

# Loss
if training_stats['losses']:
    axes[2].plot(training_stats['losses'], color='green')
    axes[2].set_xlabel('Step')
    axes[2].set_ylabel('PPO Loss')
    axes[2].set_title('PPO Loss Over Training')
    axes[2].grid(True, alpha=0.3)

plt.tight_layout()
# Ensure the output directory exists even if this cell is run on its own
os.makedirs(config['ppo_output_dir'], exist_ok=True)
plt.savefig(os.path.join(config['ppo_output_dir'], 'ppo_training_curves.png'), dpi=300, bbox_inches='tight')
plt.show()

print("Training curves saved")

## 11. Test Final Model

In [None]:
# Generate sample responses from the trained policy
print("Testing final PPO model:\n")
print("=" * 70)

# Disable dropout for inspection-time generation
ppo_model.eval()

# Guard against a prompt pool with fewer than 3 entries
num_examples = min(3, len(ppo_prompts))
for i in range(num_examples):
    prompt = ppo_prompts[i]
    persona = ppo_personas[i]

    # Generate response
    inputs = tokenizer(prompt, return_tensors='pt').to(ppo_model.pretrained_model.device)

    with torch.no_grad():
        outputs = ppo_model.generate(
            **inputs,
            max_new_tokens=config['max_new_tokens'],
            do_sample=True,
            temperature=config['temperature'],
            top_p=config['top_p'],
            pad_token_id=tokenizer.eos_token_id
        )

    # Keep only the newly generated tokens (drop the echoed prompt)
    response = tokenizer.decode(outputs[0][len(inputs['input_ids'][0]):], skip_special_tokens=True)
    reward = compute_reward(prompt, response, persona)

    print(f"\nExample {i+1}:")
    print(f"Persona: {', '.join(persona[:2])}...")
    print(f"Generated: {response}")
    print(f"Reward: {reward:.3f}")
    print("-" * 70)

## 12. Save Training Summary

In [None]:
# Compile training summary
rewards_history = training_stats['rewards']
if rewards_history:
    initial_avg_reward = float(np.mean(rewards_history[:10]))
    final_avg_reward = float(np.mean(rewards_history[-10:]))
    max_reward = float(np.max(rewards_history))
    reward_improvement = final_avg_reward - initial_avg_reward
else:
    # No PPO updates were recorded. Store JSON null instead of crashing on
    # np.max([]) or writing non-standard NaN values.
    initial_avg_reward = final_avg_reward = max_reward = reward_improvement = None

rlhf_summary = {
    'reward_model': {
        'training_time_minutes': float(reward_training_time / 60),
        'final_loss': float(reward_result.training_loss),
        'num_preference_pairs': len(train_pairs)
    },
    'ppo_training': {
        'training_time_minutes': float(ppo_training_time / 60),
        # config['ppo_steps'] counts prompts consumed; num_ppo_updates counts optimizer steps recorded
        'total_steps': config['ppo_steps'],
        'num_ppo_updates': len(rewards_history),
        'final_avg_reward': final_avg_reward,
        'max_reward': max_reward,
        'initial_avg_reward': initial_avg_reward,
        'reward_improvement': reward_improvement
    },
    'configuration': config,
    'total_rlhf_time_minutes': float((reward_training_time + ppo_training_time) / 60)
}

# Save summary
os.makedirs(config['ppo_output_dir'], exist_ok=True)
summary_path = os.path.join(config['ppo_output_dir'], 'rlhf_summary.json')
with open(summary_path, 'w') as f:
    json.dump(rlhf_summary, f, indent=2)

print("RLHF training summary saved to:", summary_path)
print("\n" + "=" * 50)
print("RLHF Training Complete!")
print("=" * 50)

## Summary

This notebook has:
- ✅ Created preference pairs from PersonaChat
- ✅ Trained reward model for persona consistency
- ✅ Implemented PPO training with RLHF
- ✅ Monitored rewards and KL divergence
- ✅ Saved trained models and checkpoints
- ✅ Visualized training progress

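**Expected outcomes** (confirm these in the evaluation notebook; this notebook does not measure them directly):
- Reward model trained on preference pairs (chosen = the reference reply, rejected = a fixed generic reply)
- PPO optimization aimed at improving persona consistency, tracked via the reward curve
- KL divergence to the SFT reference model monitored to limit drift
- Models ready for comprehensive evaluation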

Next: Proceed to `5_evaluation.ipynb` for comprehensive model evaluation.