# RLHF Training Dynamics

**Rollouts, GAE, and the complete training loop**

## The Two-Phase Training Loop

Each RLHF iteration consists of:

1. **Rollout Phase** — Generate data with current policy
2. **Update Phase** — Improve policy using PPO

## Phase 1: Rollout Generation

```python
# For each batch of prompts:
#   1. Generate responses with the policy model
#   2. Score each (prompt, response) pair with the reward model
#   3. Estimate per-token values with the value network
#   4. Compute log probabilities of the generated tokens under the policy and the reference model
#   5. Store everything in the rollout buffer
```
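Computing the log probabilities in step 4 reduces to a gather over the model's next-token distribution. The following is a minimal sketch. It assumes the model returns raw logits of shape `(batch, seq_len, vocab_size)` that are already aligned with the tokens they score. For a causal LM, the logits at position t predict token t+1, so in practice you would pair `logits[:, :-1]` with `tokens[:, 1:]`. The helper name `logprobs_from_logits` is hypothetical:

```python
import torch
import torch.nn.functional as F


def logprobs_from_logits(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Log probability of each token under the distribution given by `logits`.

    Args:
        logits: (batch, seq_len, vocab_size) unnormalized scores
        tokens: (batch, seq_len) token IDs, aligned position-by-position with `logits`
    Returns:
        (batch, seq_len) per-token log probabilities
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Pick out the log prob of the token that was actually generated at each position
    return torch.gather(log_probs, dim=-1, index=tokens.unsqueeze(-1)).squeeze(-1)


# Usage: rollout-time scoring needs no gradients
torch.manual_seed(0)
logits = torch.randn(2, 4, 10)            # batch=2, seq_len=4, vocab=10
tokens = torch.randint(0, 10, (2, 4))
with torch.no_grad():
    lp = logprobs_from_logits(logits, tokens)
print(lp.shape)  # torch.Size([2, 4])
```

The same helper serves both the policy and the frozen reference model; only the logits differ.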

```python
import torch
from dataclasses import dataclass

@dataclass
class RolloutBatch:
    """A batch of rollout data for PPO training."""

    # Input data
    query_tensors: torch.Tensor       # Prompt token IDs
    response_tensors: torch.Tensor    # Generated response token IDs

    # Model outputs recorded during generation (no gradients)
    logprobs: torch.Tensor            # Policy log probs of the generated tokens
    ref_logprobs: torch.Tensor        # Reference model log probs of the same tokens
    values: torch.Tensor              # Value estimates

    # Rewards and advantages
    rewards: torch.Tensor             # Reward model scores
    advantages: torch.Tensor          # GAE advantages
    returns: torch.Tensor             # Value targets (advantages + values)

print("RolloutBatch stores all data needed for PPO updates:")
print("  - Queries and responses")
print("  - Log probabilities (policy and reference)")
print("  - Values, rewards, advantages, returns")
```

```
RolloutBatch stores all data needed for PPO updates:
  - Queries and responses
  - Log probabilities (policy and reference)
  - Values, rewards, advantages, returns
```


## Generalized Advantage Estimation (GAE)

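GAE estimates the advantage of each token as an exponentially weighted sum of temporal-difference (TD) errors, with discount factor $\gamma$ and a parameter $\lambda \in [0, 1]$ that trades bias against variance: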

$$\hat{A}_t = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}$$

where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ is the TD error.
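Evaluating the sum term by term makes the definition concrete. The following is a minimal sketch for a single sequence of length $T$. It assumes the episode ends after the last token, so $V(s_T) = 0$ and the infinite sum is truncated at $l = T - 1 - t$. The two limits are worth checking. With $\lambda = 0$, the advantage is the one-step TD error $\delta_t$ (low variance, but biased by $V$). With $\lambda = 1$, it is the Monte Carlo return minus $V(s_t)$ (unbiased, but higher variance). This O(T²) version, `gae_direct` (a name chosen here for illustration), is only meant for building intuition; the backward recursion in `compute_gae` computes the same quantity in one pass.

```python
import torch


def gae_direct(rewards: torch.Tensor, values: torch.Tensor,
               gamma: float = 0.99, lam: float = 0.95) -> torch.Tensor:
    """Evaluate A_t = sum_l (gamma*lam)^l * delta_{t+l} directly for one sequence."""
    T = rewards.shape[0]
    # V(s_{t+1}) for each t, with V = 0 after the final step
    next_values = torch.cat([values[1:], values.new_zeros(1)])
    deltas = rewards + gamma * next_values - values   # TD errors delta_t
    advantages = torch.zeros_like(rewards)
    for t in range(T):
        for l in range(T - t):
            advantages[t] += (gamma * lam) ** l * deltas[t + l]
    return advantages


rewards = torch.tensor([0.0, 0.0, 1.0])
# Zero values, gamma = 1, lam = 0.5: deltas are [0, 0, 1]
print(gae_direct(rewards, torch.zeros(3), gamma=1.0, lam=0.5))         # tensor([0.2500, 0.5000, 1.0000])
# lam = 1: A_t = (Monte Carlo return 1.0) - V(s_t) = 1.0 - 0.5 at every step
print(gae_direct(rewards, torch.full((3,), 0.5), gamma=1.0, lam=1.0))  # tensor([0.5000, 0.5000, 0.5000])
```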

```python
def compute_gae(
    rewards: torch.Tensor,
    values: torch.Tensor,
    gamma: float = 0.99,
    lam: float = 0.95
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute Generalized Advantage Estimation.

    Args:
        rewards: Rewards at each timestep, shape (batch, seq_len)
        values: Value estimates, shape (batch, seq_len)
        gamma: Discount factor
        lam: GAE lambda (bias-variance tradeoff)

    Returns:
        advantages: GAE advantages, shape (batch, seq_len)
        returns: Lambda-returns (advantages + values), the value-function targets
    """
    seq_len = rewards.shape[1]
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0

    # Process backwards through time
    for t in reversed(range(seq_len)):
        if t == seq_len - 1:
            next_value = 0.0  # Episode ends after the last token, so V(s_T) = 0
        else:
            next_value = values[:, t + 1]

        # TD error: r + γV(s') - V(s)
        delta = rewards[:, t] + gamma * next_value - values[:, t]

        # GAE recursion: A_t = δ_t + γλ A_{t+1}
        last_gae = delta + gamma * lam * last_gae
        advantages[:, t] = last_gae

    # Returns = advantages + values
    returns = advantages + values

    return advantages, returns

# Example with random data
batch_size, seq_len = 2, 10
rewards = torch.randn(batch_size, seq_len) * 0.5
values = torch.randn(batch_size, seq_len)

advantages, returns = compute_gae(rewards, values)

print(f"Rewards shape: {rewards.shape}")
print(f"Advantages shape: {advantages.shape}")
print(f"Returns shape: {returns.shape}")
```

```
Rewards shape: torch.Size([2, 10])
Advantages shape: torch.Size([2, 10])
Returns shape: torch.Size([2, 10])
```
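In RLHF the reward model typically returns one scalar per response rather than a reward per token. A common convention, assumed here, places that score on the final response token and sets every other token's reward to zero. Implementations differ, and many also add a per-token KL penalty to the rewards, which this sketch omits. With zero value estimates, GAE then spreads the score backward with weight $(\gamma\lambda)^{T-1-t}$. The sketch repeats the backward recursion from `compute_gae` so that it runs on its own; the name `gae_backward` is ours:

```python
import torch


def gae_backward(rewards, values, gamma=1.0, lam=0.95):
    """Backward GAE recursion over (batch, seq_len) tensors, with V = 0 after the last step."""
    batch, seq_len = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros(batch)
    for t in reversed(range(seq_len)):
        next_value = values[:, t + 1] if t + 1 < seq_len else torch.zeros(batch)
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        last_gae = delta + gamma * lam * last_gae
        advantages[:, t] = last_gae
    return advantages, advantages + values


# One 4-token response; the reward-model score (1.0) sits on the last token only
rewards = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
values = torch.zeros(1, 4)
advantages, returns = gae_backward(rewards, values, gamma=1.0, lam=0.95)
print(advantages)  # tensor([[0.8574, 0.9025, 0.9500, 1.0000]])
```

Earlier tokens receive less credit as $\lambda$ shrinks. With $\lambda = 1$ (and $\gamma = 1$), every token would receive the full score.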


## Advantage Normalization

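Whitening (standardizing) the advantages to zero mean and unit variance within a batch keeps the scale of the policy-gradient update roughly constant from batch to batch, regardless of the raw reward scale: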

```python
def whiten_advantages(advantages: torch.Tensor) -> torch.Tensor:
    """
    Normalize advantages to have zero mean and unit variance.

    This stabilizes training by keeping the gradient scale consistent across batches.
    """
    mean = advantages.mean()
    std = advantages.std() + 1e-8
    return (advantages - mean) / std

# Example (uses the random `advantages` from the GAE cell, so the "before" numbers vary by run)
print("Before whitening:")
print(f"  Mean: {advantages.mean():.4f}, Std: {advantages.std():.4f}")

whitened = whiten_advantages(advantages)
print("After whitening:")
print(f"  Mean: {whitened.mean():.4f}, Std: {whitened.std():.4f}")
```

```
Before whitening:
  Mean: 1.5281, Std: 0.9532
After whitening:
  Mean: 0.0000, Std: 1.0000
```
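Responses in a batch usually have different lengths, so in practice the advantage tensor contains padding. The following is a minimal sketch that assumes a hypothetical 0/1 `mask` marking the real response tokens. It computes the mean and standard deviation over valid positions only, so padding values cannot skew the statistics, and then zeroes the padded positions:

```python
import torch


def masked_whiten(values: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Whiten `values` using statistics from positions where mask == 1 only.

    Assumes at least two valid positions (the variance uses n - 1, like torch.std).
    """
    mask = mask.to(values.dtype)
    n = mask.sum()
    mean = (values * mask).sum() / n
    var = (((values - mean) ** 2) * mask).sum() / (n - 1)   # unbiased variance
    whitened = (values - mean) / torch.sqrt(var + eps)
    return whitened * mask   # padding positions become 0


# Usage: 100.0 and -50.0 are padding and must not affect the statistics
adv = torch.tensor([[1.0, 2.0, 3.0, 100.0],
                    [4.0, 5.0, -50.0, -50.0]])
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
w = masked_whiten(adv, mask)
print(w)
```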


## Phase 2: PPO Update

```python
def ppo_update_step(
    policy_model,
    value_network,
    rollout: RolloutBatch,
    optimizer,
    config
):
    """
    Perform one PPO update step.

    This is called multiple times (ppo_epochs) per rollout.
    Assumes `optimizer` holds both the policy and the value-network parameters.
    """
    # Re-compute log probs and values WITH gradients. After the first epoch,
    # the policy has also drifted from the one that generated the rollout.
    current_logprobs = policy_model.get_logprobs(
        rollout.query_tensors,
        rollout.response_tensors
    )

    current_values = value_network(
        rollout.query_tensors,
        rollout.response_tensors
    )

    # Probability ratio π_θ / π_old, computed in log space
    ratio = torch.exp(current_logprobs - rollout.logprobs)

    # Clipped surrogate objective
    unclipped = ratio * rollout.advantages
    clipped_ratio = torch.clamp(ratio, 1 - config['clip_ratio'], 1 + config['clip_ratio'])
    clipped = clipped_ratio * rollout.advantages

    policy_loss = -torch.min(unclipped, clipped).mean()

    # Value loss: regress value estimates onto the GAE returns
    value_loss = ((current_values - rollout.returns) ** 2).mean()

    # KL penalty: Monte Carlo estimate of KL(π_θ || π_ref) from the sampled tokens
    kl_penalty = (current_logprobs - rollout.ref_logprobs).mean()

    # Total loss
    total_loss = (
        policy_loss +
        config['vf_coef'] * value_loss +
        config['kl_coef'] * kl_penalty
    )

    # Update: clip the gradient norm over every parameter the optimizer steps
    # (policy AND value network), not only the policy
    optimizer.zero_grad()
    total_loss.backward()
    params = [p for group in optimizer.param_groups for p in group['params']]
    torch.nn.utils.clip_grad_norm_(params, config['max_grad_norm'])
    optimizer.step()

    return {
        'policy_loss': policy_loss.item(),
        'value_loss': value_loss.item(),
        'kl': kl_penalty.item()
    }

print("PPO Update performs:")
print("  1. Re-compute log probs with current policy")
print("  2. Compute clipped policy loss")
print("  3. Compute value loss")
print("  4. Add KL penalty")
print("  5. Backprop and update")
```

```
PPO Update performs:
  1. Re-compute log probs with current policy
  2. Compute clipped policy loss
  3. Compute value loss
  4. Add KL penalty
  5. Backprop and update
```
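To see what the clip does, the following sketch computes the clipped objective for three tokens and inspects the gradient. The numbers are toy values, not from any real run. Once the ratio has moved past $1 \pm \epsilon$ in the direction the advantage favors, `torch.min` selects the clamped term. That term is constant in the ratio, so the token contributes no gradient and the update cannot push it further:

```python
import torch

clip_ratio = 0.2
old_logprobs = torch.tensor([0.5, 0.5, 0.5]).log()
# Current policy: token 0 stays inside the clip range, token 1 has moved above 1 + eps,
# token 2 has moved below 1 - eps
new_logprobs = torch.tensor([0.55, 0.70, 0.30]).log().requires_grad_()
advantages = torch.tensor([1.0, 1.0, -1.0])

ratio = torch.exp(new_logprobs - old_logprobs)          # approximately [1.1, 1.4, 0.6]
unclipped = ratio * advantages
clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * advantages
policy_loss = -torch.min(unclipped, clipped).mean()
policy_loss.backward()

print(ratio.detach())
print(new_logprobs.grad)  # only token 0 has a nonzero gradient
```

Token 2 has a negative advantage and its probability has already dropped below the clip range, so it is also frozen for the rest of this rollout's epochs.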


## Complete Training Loop

```python
def train_rlhf_loop(config):
    """
    Simplified RLHF training loop structure (model calls are commented out).
    """
    for iteration in range(config['num_iterations']):
        # ===== PHASE 1: ROLLOUT (run under torch.no_grad()) =====
        # Sample prompts
        # prompts = sample_prompts(config['batch_size'])

        # Generate responses
        # responses = policy_model.generate(prompts)

        # Score with reward model (one scalar per response); a common choice is to
        # place the score on the last response token to get (batch, seq_len) rewards
        # rewards = reward_model(prompts, responses)

        # Get values and log probs
        # values = value_network(prompts, responses)
        # policy_logprobs = policy_model.get_logprobs(prompts, responses)
        # ref_logprobs = reference_model.get_logprobs(prompts, responses)

        # Compute GAE
        # advantages, returns = compute_gae(rewards, values)
        # advantages = whiten_advantages(advantages)

        # Store in rollout buffer
        # rollout = RolloutBatch(...)

        # ===== PHASE 2: PPO UPDATE =====
        # for epoch in range(config['ppo_epochs']):
        #     metrics = ppo_update_step(policy_model, value_network, rollout, optimizer, config)

        # ===== LOGGING =====
        # log_metrics(metrics)

        pass

    print("RLHF Training Loop Structure:")
    print("  For each iteration:")
    print("    1. Generate rollouts (responses + rewards)")
    print("    2. Compute advantages with GAE")
    print("    3. Run PPO updates (multiple epochs)")
    print("    4. Log metrics and checkpoint")

train_rlhf_loop({'num_iterations': 1000, 'ppo_epochs': 4})
```

```
RLHF Training Loop Structure:
  For each iteration:
    1. Generate rollouts (responses + rewards)
    2. Compute advantages with GAE
    3. Run PPO updates (multiple epochs)
    4. Log metrics and checkpoint
```
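The loop structure can be made concrete on a toy problem. The sketch below is entirely hypothetical. The "policy" is a single learnable categorical distribution over a 5-token vocabulary with no prompts. Each response is one token, the "reward model" pays 1.0 for token 0, and the value network is one scalar. Because every episode is a single step, GAE reduces to $\hat{A} = r - V$. The remaining pieces mirror the loop above: a no-grad rollout, whitening, the clipped loss, the value loss, a KL term against a frozen reference, and several PPO epochs per rollout. The hyperparameters are illustrative, not tuned recommendations.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
VOCAB = 5
# Hypothetical toy configuration (illustrative values only)
config = {'num_iterations': 30, 'ppo_epochs': 4, 'batch_size': 64,
          'clip_ratio': 0.2, 'vf_coef': 0.5, 'kl_coef': 0.01}

policy_logits = torch.zeros(VOCAB, requires_grad=True)   # trainable policy
ref_logits = torch.zeros(VOCAB)                          # frozen reference (uniform)
value = torch.zeros(1, requires_grad=True)               # scalar value estimate V(s)
optimizer = torch.optim.Adam([policy_logits, value], lr=0.1)


def toy_reward(tokens: torch.Tensor) -> torch.Tensor:
    """Hypothetical reward model: 1.0 for token 0, 0.0 otherwise."""
    return (tokens == 0).float()


for iteration in range(config['num_iterations']):
    # ===== PHASE 1: ROLLOUT (no gradients) =====
    with torch.no_grad():
        dist = Categorical(logits=policy_logits)
        tokens = dist.sample((config['batch_size'],))
        old_logprobs = dist.log_prob(tokens)
        ref_logprobs = Categorical(logits=ref_logits).log_prob(tokens)
        rewards = toy_reward(tokens)
        advantages = rewards - value          # single-step GAE: A = r - V
        returns = rewards                     # A + V = r
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # ===== PHASE 2: PPO UPDATE (several epochs on the same rollout) =====
    for epoch in range(config['ppo_epochs']):
        logprobs = Categorical(logits=policy_logits).log_prob(tokens)
        ratio = torch.exp(logprobs - old_logprobs)
        clipped = torch.clamp(ratio, 1 - config['clip_ratio'], 1 + config['clip_ratio'])
        policy_loss = -torch.min(ratio * advantages, clipped * advantages).mean()
        value_loss = ((value - returns) ** 2).mean()
        kl = (logprobs - ref_logprobs).mean()
        loss = policy_loss + config['vf_coef'] * value_loss + config['kl_coef'] * kl
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

probs = torch.softmax(policy_logits.detach(), dim=-1)
print(probs)  # probability mass should have shifted toward token 0
```

Raising `kl_coef` would hold the policy closer to the uniform reference, which is the same trade-off the KL penalty controls in full-scale RLHF.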


## Key Hyperparameters

| Parameter | Typical Value | Notes |
|-----------|---------------|-------|
| `gamma` | 0.99 | Discount factor |
| `gae_lambda` | 0.95 | GAE bias-variance tradeoff |
| `ppo_epochs` | 4 | Updates per rollout |
| `batch_size` | 4-16 | Prompts per iteration |

## Next Steps

Let's learn about reference models — how to create and manage them for stable RLHF training.