# 02 — The Four Technical Pillars: Real Implementations

> **Purpose:** Implement four RL training techniques (advantage normalization, Clip-Higher clipping, loss aggregation, and overlong filtering) as runnable code. Each is a modular function you can plug into a training loop.

**Prerequisites:** This notebook builds on `01_rl_training_loop_foundations.ipynb`. We reuse the TinyLM model.

---

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np

torch.manual_seed(42)
np.random.seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

Device: cpu


## Setup: TinyLM from Section 01

A minimal language model for teaching purposes.

In [2]:
class TinyLM(nn.Module):
    """Minimal language model for RL experiments."""
    def __init__(self, vocab_size=100, hidden_size=64, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)
        self.head = nn.Linear(hidden_size, vocab_size)
        self.vocab_size = vocab_size
    
    def forward(self, x):
        h = self.embed(x)
        h, _ = self.rnn(h)
        return self.head(h)
    
    def generate(self, prompt, max_len=20, temperature=1.0):
        """Generate tokens autoregressively."""
        self.eval()
        tokens = prompt.clone()
        log_probs = []
        
        with torch.no_grad():
            for _ in range(max_len):
                logits = self(tokens)[:, -1, :] / temperature
                dist = Categorical(logits=logits)
                next_token = dist.sample()
                log_probs.append(dist.log_prob(next_token))
                tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)
        
        return tokens, torch.stack(log_probs, dim=1)

# Instantiate
model = TinyLM().to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

Parameters: 62,820


## 1. Advantage Normalization: Three Real Implementations

Each function takes rewards and returns normalized advantages.

In [3]:
def normalize_grpo_local(rewards, group_size):
    """
    GRPO-style local normalization.
    Normalizes within each prompt's group of responses.
    
    Args:
        rewards: Tensor of shape (batch_size,) - reward per response
        group_size: int - number of responses per prompt (K)
    
    Returns:
        advantages: Tensor of shape (batch_size,)
    """
    batch_size = rewards.shape[0]
    num_groups = batch_size // group_size
    eps = 1e-8
    
    # Reshape to (num_groups, group_size)
    grouped = rewards.view(num_groups, group_size)
    
    # Compute per-group statistics
    group_mean = grouped.mean(dim=1, keepdim=True)  # (num_groups, 1)
    group_std = grouped.std(dim=1, keepdim=True)    # (num_groups, 1)
    
    # Normalize within each group
    advantages = (grouped - group_mean) / (group_std + eps)
    
    return advantages.view(batch_size)


def normalize_global(rewards):
    """
    REINFORCE++-style global normalization.
    Uses batch-level mean and std.
    
    Args:
        rewards: Tensor of shape (batch_size,)
    
    Returns:
        advantages: Tensor of shape (batch_size,)
    """
    eps = 1e-8
    batch_mean = rewards.mean()
    batch_std = rewards.std()
    
    return (rewards - batch_mean) / (batch_std + eps)


def normalize_hybrid(rewards, group_size):
    """
    Hybrid normalization: local mean + global std.
    The production-recommended approach.
    
    Args:
        rewards: Tensor of shape (batch_size,)
        group_size: int - number of responses per prompt (K)
    
    Returns:
        advantages: Tensor of shape (batch_size,)
    """
    batch_size = rewards.shape[0]
    num_groups = batch_size // group_size
    eps = 1e-8
    
    # Global std (near zero only when rewards are almost identical across the whole batch)
    batch_std = rewards.std()
    
    # Reshape to (num_groups, group_size)
    grouped = rewards.view(num_groups, group_size)
    
    # Local mean (preserves intra-prompt competition)
    group_mean = grouped.mean(dim=1, keepdim=True)
    
    # Hybrid: local mean, global std
    advantages = (grouped - group_mean) / (batch_std + eps)
    
    return advantages.view(batch_size)

In [4]:
# TEST: Compare normalization strategies when some groups have zero variance

# Scenario: 4 prompts, 4 responses each. First prompt is "easy" (all correct).
rewards = torch.tensor([
    1.0, 1.0, 1.0, 1.0,   # Prompt 1: all correct (std=0!)
    1.0, 0.0, 1.0, 0.0,   # Prompt 2: mixed
    0.0, 0.0, 1.0, 0.0,   # Prompt 3: mostly wrong
    0.0, 0.0, 0.0, 0.0,   # Prompt 4: all wrong (std=0!)
])

group_size = 4

adv_local = normalize_grpo_local(rewards, group_size)
adv_global = normalize_global(rewards)
adv_hybrid = normalize_hybrid(rewards, group_size)

print("Rewards:", rewards.numpy())
print()
print(f"{'Method':<20} {'Max |Advantage|':>18} {'Prompt 1 (easy)':>20}")
print("-" * 60)
print(f"{'GRPO (local)':<20} {adv_local.abs().max().item():>18.2f} {adv_local[:4].numpy()}")
print(f"{'Global':<20} {adv_global.abs().max().item():>18.2f} {adv_global[:4].numpy()}")
print(f"{'Hybrid':<20} {adv_hybrid.abs().max().item():>18.2f} {adv_hybrid[:4].numpy()}")
print()
print("Prompts 1 and 4 (zero variance): local and hybrid both give advantage 0, so no gradient signal")
print("Global still assigns Prompt 1 a positive advantage because it compares against the batch mean")

Rewards: [1. 1. 1. 1. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0.]

Method                  Max |Advantage|      Prompt 1 (easy)
------------------------------------------------------------
GRPO (local)                       1.50 [0. 0. 0. 0.]
Global                             1.10 [1.0978876 1.0978876 1.0978876 1.0978876]
Hybrid                             1.46 [0. 0. 0. 0.]

Prompts 1 and 4 (zero variance): local and hybrid both give advantage 0, so no gradient signal
Global still assigns Prompt 1 a positive advantage because it compares against the batch mean
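
The binary rewards above never produce a small but non-zero group standard deviation, which is where local and hybrid normalization actually diverge. The sketch below is written in plain Python so the arithmetic stays visible. It uses made-up continuous rewards (an assumption, not data from this notebook), and the helper names `local_advantages` and `hybrid_advantages` are illustrative rather than part of the notebook's API.

```python
from statistics import mean, stdev  # stdev is the sample std, like torch.std's default

EPS = 1e-8

def split_groups(rewards, group_size):
    """Split a flat reward list into consecutive groups of group_size."""
    return [rewards[i:i + group_size] for i in range(0, len(rewards), group_size)]

def local_advantages(rewards, group_size):
    """Per-group mean and per-group std (GRPO-style)."""
    out = []
    for group in split_groups(rewards, group_size):
        mu, sigma = mean(group), stdev(group)
        out.extend((r - mu) / (sigma + EPS) for r in group)
    return out

def hybrid_advantages(rewards, group_size):
    """Per-group mean, batch-level std."""
    sigma = stdev(rewards)
    out = []
    for group in split_groups(rewards, group_size):
        mu = mean(group)
        out.extend((r - mu) / (sigma + EPS) for r in group)
    return out

# Hypothetical rewards: group 1 is nearly uniform, group 2 is genuinely mixed.
rewards = [0.90, 0.90, 0.90, 0.91,
           0.00, 1.00, 0.00, 1.00]

local = local_advantages(rewards, group_size=4)
hybrid = hybrid_advantages(rewards, group_size=4)

print("local  group 1:", [round(a, 3) for a in local[:4]])
print("hybrid group 1:", [round(a, 3) for a in hybrid[:4]])
```

In group 1 the rewards differ by only 0.01, yet local normalization stretches that gap to advantages of roughly -0.5 and +1.5, the same magnitude a genuinely mixed group would get. Dividing by the batch-level std keeps those advantages small, in proportion to the actual reward difference.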


## 2. Clip-Higher: Asymmetric PPO Clipping

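Standard PPO clips the probability ratio to `[1-ε, 1+ε]`. Clip-Higher uses `[1-ε_low, 1+ε_high]` with `ε_high > ε_low`, which raises the ceiling on how far a token's probability can grow within one update while leaving the lower bound unchanged.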

In [5]:
def compute_ppo_loss_symmetric(log_probs, old_log_probs, advantages, epsilon=0.2):
    """
    Standard PPO loss with symmetric clipping.
    
    Args:
        log_probs: Current policy log probs, shape (batch, seq_len)
        old_log_probs: Old policy log probs, shape (batch, seq_len)
        advantages: Advantage values, shape (batch,) or (batch, 1)
        epsilon: Clipping parameter (default 0.2)
    
    Returns:
        loss: Scalar tensor (to minimize, so we negate the objective)
    """
    # Expand advantages to match token dimension
    if advantages.dim() == 1:
        advantages = advantages.unsqueeze(1)  # (batch, 1)
    
    # Probability ratio
    ratio = torch.exp(log_probs - old_log_probs)  # (batch, seq_len)
    
    # Clipped ratio (symmetric)
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    
    # PPO objective: min of clipped and unclipped
    obj1 = ratio * advantages
    obj2 = clipped_ratio * advantages
    ppo_obj = torch.min(obj1, obj2)
    
    # Loss is negative objective (we minimize loss)
    return -ppo_obj.mean()


def compute_ppo_loss_clip_higher(log_probs, old_log_probs, advantages, 
                                  epsilon_low=0.2, epsilon_high=0.28):
    """
    PPO loss with asymmetric clipping (Clip-Higher).
    Allows more growth for advantageous tokens.
    
    Args:
        log_probs: Current policy log probs
        old_log_probs: Old policy log probs
        advantages: Advantage values
        epsilon_low: Lower bound clipping (default 0.2)
        epsilon_high: Upper bound clipping (default 0.28)
    
    Returns:
        loss: Scalar tensor
    """
    if advantages.dim() == 1:
        advantages = advantages.unsqueeze(1)
    
    ratio = torch.exp(log_probs - old_log_probs)
    
    # Asymmetric clipping: lower bound is tighter than upper bound
    clipped_ratio = torch.clamp(ratio, 1 - epsilon_low, 1 + epsilon_high)
    
    obj1 = ratio * advantages
    obj2 = clipped_ratio * advantages
    ppo_obj = torch.min(obj1, obj2)
    
    return -ppo_obj.mean()

In [6]:
# TEST: Clip-Higher allows more growth for low-probability tokens

# Simulate a scenario: token with low old probability, positive advantage
# Model wants to increase probability significantly

old_log_prob = torch.tensor([[-3.9]])  # exp(-3.9) ≈ 0.02 (2% probability)
new_log_prob = torch.tensor([[-2.3]])  # exp(-2.3) ≈ 0.10 (10% probability)
advantage = torch.tensor([1.0])        # Positive advantage

ratio = torch.exp(new_log_prob - old_log_prob).item()
print(f"Desired ratio: {ratio:.2f}x (model wants to 5x the probability)")
print()

# Symmetric clipping clips at 1.2
clipped_sym = min(ratio, 1.2)
print(f"Symmetric (ε=0.2): clipped to {clipped_sym:.2f}x")

# Clip-Higher clips at 1.28
clipped_high = min(ratio, 1.28)
print(f"Clip-Higher (ε_high=0.28): clipped to {clipped_high:.2f}x")

print()
print(f"Clip-Higher allows {(clipped_high/clipped_sym - 1)*100:.0f}% more update per step")

Desired ratio: 4.95x (model wants to 5x the probability)

Symmetric (ε=0.2): clipped to 1.20x
Clip-Higher (ε_high=0.28): clipped to 1.28x

Clip-Higher allows 7% more update per step
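
The ratio ceiling matters because of how the `min` in the PPO objective interacts with the clamp: once a positive-advantage token's ratio passes `1 + ε`, the clipped term is selected and that token stops contributing gradient. The plain-Python sketch below evaluates the per-token objective for a few hand-picked cases. The helper `clipped_objective` is illustrative; it mirrors the logic of `compute_ppo_loss_clip_higher` for a single token rather than calling it.

```python
def clipped_objective(ratio, advantage, eps_low=0.2, eps_high=0.2):
    """Return (objective, gradient_flows) for one token.

    gradient_flows is True when the unclipped term ratio * advantage is the
    one selected by min(), i.e. the objective still depends on the ratio.
    """
    clipped = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)
    unclipped_term = ratio * advantage
    clipped_term = clipped * advantage
    objective = min(unclipped_term, clipped_term)
    return objective, unclipped_term <= clipped_term

# Hand-picked (label, ratio, advantage) cases.
cases = [
    ("A>0, ratio 1.25", 1.25, +1.0),
    ("A>0, ratio 1.50", 1.50, +1.0),
    ("A<0, ratio 0.70", 0.70, -1.0),
    ("A<0, ratio 1.50", 1.50, -1.0),
]

for label, ratio, adv in cases:
    sym = clipped_objective(ratio, adv, eps_low=0.2, eps_high=0.2)
    high = clipped_objective(ratio, adv, eps_low=0.2, eps_high=0.28)
    print(f"{label}: symmetric={sym}, clip_higher={high}")
```

With a positive advantage and a ratio of 1.25, symmetric clipping has already cut the gradient while Clip-Higher still lets it through. The negative-advantage cases come out the same under both settings, because for `A < 0` the `min` selects the unclipped term whenever the ratio exceeds the upper bound, so that bound never binds.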


## 3. Loss Aggregation: Token-Level vs Sequence-Level

In [7]:
def aggregate_loss_sequence_level(per_token_loss, response_lengths):
    """
    Sequence-level aggregation: each response contributes equally.
    
    Args:
        per_token_loss: Tensor of shape (batch, max_seq_len)
        response_lengths: Tensor of shape (batch,) - actual length of each response
    
    Returns:
        loss: Scalar tensor
    """
    batch_size = per_token_loss.shape[0]
    
    # Create mask for valid tokens
    max_len = per_token_loss.shape[1]
    mask = torch.arange(max_len, device=per_token_loss.device).unsqueeze(0) < response_lengths.unsqueeze(1)
    
    # Average loss per response, then average across responses
    masked_loss = per_token_loss * mask
    # clamp(min=1) avoids a 0/0 NaN if a response has length zero
    per_response_loss = masked_loss.sum(dim=1) / response_lengths.float().clamp(min=1)  # (batch,)
    
    return per_response_loss.mean()


def aggregate_loss_token_level(per_token_loss, response_lengths):
    """
    Token-level aggregation: each token contributes equally.
    
    Args:
        per_token_loss: Tensor of shape (batch, max_seq_len)
        response_lengths: Tensor of shape (batch,) - actual length of each response
    
    Returns:
        loss: Scalar tensor
    """
    max_len = per_token_loss.shape[1]
    mask = torch.arange(max_len, device=per_token_loss.device).unsqueeze(0) < response_lengths.unsqueeze(1)
    
    # Sum all token losses, divide by total token count
    total_loss = (per_token_loss * mask).sum()
    total_tokens = response_lengths.sum()
    
    return total_loss / total_tokens

In [8]:
# TEST: Show the difference in weighting

# Two responses with the same per-token loss but different lengths
response_lengths = torch.tensor([100, 1000])  # Short vs long

# Pad the shorter (first) response with zeros up to the longest length
per_token_loss_padded = torch.zeros(2, 1000)
per_token_loss_padded[0, :100] = 0.1
per_token_loss_padded[1, :1000] = 0.1

seq_loss = aggregate_loss_sequence_level(per_token_loss_padded, response_lengths)
tok_loss = aggregate_loss_token_level(per_token_loss_padded, response_lengths)

print("Two responses: 100 tokens and 1000 tokens, same per-token loss")
print()
print(f"Sequence-level loss: {seq_loss.item():.4f}")
print(f"Token-level loss: {tok_loss.item():.4f}")
print()

# Compute effective weight per response
print("Effective weight of each response:")
print(f"  Sequence-level: Response 1 = 50%, Response 2 = 50%")
print(f"  Token-level: Response 1 = {100/1100*100:.1f}%, Response 2 = {1000/1100*100:.1f}%")

Two responses: 100 tokens and 1000 tokens, same per-token loss

Sequence-level loss: 0.1000
Token-level loss: 0.1000

Effective weight of each response:
  Sequence-level: Response 1 = 50%, Response 2 = 50%
  Token-level: Response 1 = 9.1%, Response 2 = 90.9%
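
Because both responses above share the same per-token loss, the two aggregates coincide at 0.1000 and only the implied weights differ. The loss values themselves diverge once per-token loss varies with length. The following plain-Python sketch uses made-up numbers (a short response with high loss, a long one with zero loss) and reimplements both reductions on lists of lists instead of padded tensors; the helper names are illustrative.

```python
def sequence_level(per_token_losses):
    """Mean over tokens within each response, then mean over responses."""
    per_response = [sum(tokens) / len(tokens) for tokens in per_token_losses]
    return sum(per_response) / len(per_response)

def token_level(per_token_losses):
    """Sum over every token in the batch, divided by the total token count."""
    total = sum(sum(tokens) for tokens in per_token_losses)
    count = sum(len(tokens) for tokens in per_token_losses)
    return total / count

# Hypothetical batch: a 2-token response with loss 1.0 per token and
# an 8-token response with loss 0.0 per token.
batch = [[1.0] * 2, [0.0] * 8]

seq_loss = sequence_level(batch)   # (1.0 + 0.0) / 2
tok_loss = token_level(batch)      # 2.0 / 10

print(f"sequence-level: {seq_loss:.2f}")
print(f"token-level:    {tok_loss:.2f}")
```

The short, high-loss response dominates the sequence-level value (0.50) but is diluted by the long response under token-level aggregation (0.20).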


## 4. Overlong Filtering

In [9]:
def create_overlong_mask(response_lengths, max_length, has_eos):
    """
    Create mask to filter out truncated (overlong) responses.
    
    Args:
        response_lengths: Tensor of shape (batch,) - length of each response
        max_length: int - the maximum generation length limit
        has_eos: Tensor of shape (batch,) - whether response has EOS token
    
    Returns:
        mask: Tensor of shape (batch,) - True for responses to KEEP
    """
    # Response is truncated if it hit max_length AND has no EOS
    is_truncated = (response_lengths >= max_length) & (~has_eos)
    
    # Keep responses that are NOT truncated
    return ~is_truncated


def apply_overlong_filter(rewards, advantages, valid_mask):
    """
    Zero out rewards/advantages for truncated responses.
    
    Args:
        rewards: Tensor of shape (batch,)
        advantages: Tensor of shape (batch,)
        valid_mask: Tensor of shape (batch,) - True for valid responses
    
    Returns:
        filtered_rewards, filtered_advantages
    """
    return rewards * valid_mask.float(), advantages * valid_mask.float()

In [10]:
# TEST: Overlong filtering

max_length = 8192

response_lengths = torch.tensor([500, 3000, 8192, 8192, 400])
has_eos = torch.tensor([True, True, False, True, True])
rewards = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0])  # Response 3 is false negative

mask = create_overlong_mask(response_lengths, max_length, has_eos)

print(f"{'Response':<12} {'Length':>8} {'Has EOS':>10} {'Reward':>8} {'Keep?':>8}")
print("-" * 55)
for i in range(len(response_lengths)):
    note = " ← FILTERED (truncated, might be false neg)" if not mask[i] else ""
    print(f"{i+1:<12} {response_lengths[i].item():>8} {str(has_eos[i].item()):>10} {rewards[i].item():>8.1f} {str(mask[i].item()):>8}{note}")

Response       Length    Has EOS   Reward    Keep?
-------------------------------------------------------
1                 500       True      1.0     True
2                3000       True      1.0     True
3                8192      False      0.0    False ← FILTERED (truncated, might be false neg)
4                8192       True      1.0     True
5                 400       True      0.0     True
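
One subtlety when filtering is combined with mean-centered advantages: a reward of zero is not a neutral signal, because the group mean is subtracted afterwards. The plain-Python sketch below uses a hypothetical group of four responses in which the last one was truncated. It compares zeroing that response's reward before normalization with masking its advantage after normalization. Variable names are illustrative.

```python
from statistics import mean

rewards = [1.0, 1.0, 1.0, 0.0]       # hypothetical group; last response was truncated
keep = [True, True, True, False]     # mask in the style of create_overlong_mask

def centered(values):
    """Subtract the group mean (std scaling omitted to keep the arithmetic visible)."""
    mu = mean(values)
    return [v - mu for v in values]

# Option A: zero the reward first, then center.
zeroed_rewards = [r if k else 0.0 for r, k in zip(rewards, keep)]
adv_reward_zeroed = centered(zeroed_rewards)

# Option B: center first, then zero the advantage of filtered responses.
adv_masked = [a if k else 0.0 for a, k in zip(centered(rewards), keep)]

print("reward zeroed:   ", adv_reward_zeroed)
print("advantage masked:", adv_masked)
```

Under option A the truncated response still receives an advantage of -0.75 and is pushed down. Under option B it contributes nothing to the policy gradient, although its reward still influences the group mean used for the other responses.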


## 5. Complete Configurable Trainer

Putting it all together with switchable options.

In [11]:
class ConfigurableRLTrainer:
    """
    RL trainer with configurable technical pillars.
    
    Supports:
    - Normalization: 'local', 'global', 'hybrid'
    - Loss aggregation: 'token', 'sequence'
    - Clip-Higher: True/False
    - Overlong filtering: True/False
    """
    
    def __init__(self, model, 
                 norm_type='hybrid',
                 loss_agg='token',
                 clip_higher=False,
                 overlong_filter=False,
                 group_size=4,
                 max_length=8192,
                 epsilon=0.2,
                 epsilon_high=0.28,
                 lr=1e-5):
        
        self.model = model
        self.norm_type = norm_type
        self.loss_agg = loss_agg
        self.clip_higher = clip_higher
        self.overlong_filter = overlong_filter
        self.group_size = group_size
        self.max_length = max_length
        self.epsilon = epsilon
        self.epsilon_high = epsilon_high
        
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    def compute_advantages(self, rewards):
        """Compute advantages using configured normalization."""
        if self.norm_type == 'local':
            return normalize_grpo_local(rewards, self.group_size)
        elif self.norm_type == 'global':
            return normalize_global(rewards)
        elif self.norm_type == 'hybrid':
            return normalize_hybrid(rewards, self.group_size)
        else:
            raise ValueError(f"Unknown norm_type: {self.norm_type}")
    
    def compute_loss(self, log_probs, old_log_probs, advantages, response_lengths):
        """Compute PPO loss with configured clipping and aggregation."""
        
        # Expand advantages to token level
        advantages_expanded = advantages.unsqueeze(1).expand_as(log_probs)
        
        # Probability ratio
        ratio = torch.exp(log_probs - old_log_probs)
        
        # Clipping
        if self.clip_higher:
            clipped_ratio = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon_high)
        else:
            clipped_ratio = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon)
        
        # PPO objective (per token)
        obj1 = ratio * advantages_expanded
        obj2 = clipped_ratio * advantages_expanded
        per_token_obj = torch.min(obj1, obj2)
        
        # We want to maximize objective, so loss = -objective
        per_token_loss = -per_token_obj
        
        # Aggregation
        if self.loss_agg == 'token':
            return aggregate_loss_token_level(per_token_loss, response_lengths)
        elif self.loss_agg == 'sequence':
            return aggregate_loss_sequence_level(per_token_loss, response_lengths)
        else:
            raise ValueError(f"Unknown loss_agg: {self.loss_agg}")
    
    def train_step(self, prompts, rewards, log_probs, old_log_probs, 
                   response_lengths, has_eos):
        """
        Single training step with all configured techniques.
        
        Returns:
            loss: The computed loss value
        """
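        # Compute advantages
        advantages = self.compute_advantages(rewards)
        
        # Overlong filtering: mask advantages AFTER normalization. Zeroing the
        # reward instead would still leave a truncated response with a non-zero
        # (usually negative) advantage once the mean is subtracted.
        if self.overlong_filter:
            valid_mask = create_overlong_mask(response_lengths, self.max_length, has_eos)
            advantages = advantages * valid_mask.float()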
        
        # Compute loss
        loss = self.compute_loss(log_probs, old_log_probs, advantages, response_lengths)
        
        # Optimization step
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        return loss.item()

In [12]:
# DEMONSTRATION: Different trainer configurations

configs = {
    'Lite PPO (base model)': {
        'norm_type': 'hybrid',
        'loss_agg': 'token',
        'clip_higher': False,
        'overlong_filter': False,
    },
    'Aligned Model Config': {
        'norm_type': 'hybrid',
        'loss_agg': 'sequence',
        'clip_higher': True,
        'overlong_filter': False,
    },
    'GRPO (original)': {
        'norm_type': 'local',
        'loss_agg': 'sequence',
        'clip_higher': False,
        'overlong_filter': False,
    },
}

print("Recommended configurations by model type:\n")
for name, config in configs.items():
    print(f"{name}:")
    for k, v in config.items():
        print(f"  {k}: {v}")
    print()

Recommended configurations by model type:

Lite PPO (base model):
  norm_type: hybrid
  loss_agg: token
  clip_higher: False
  overlong_filter: False

Aligned Model Config:
  norm_type: hybrid
  loss_agg: sequence
  clip_higher: True
  overlong_filter: False

GRPO (original):
  norm_type: local
  loss_agg: sequence
  clip_higher: False
  overlong_filter: False



## 6. Summary: Decision Rules

```python
# Base model configuration
trainer = ConfigurableRLTrainer(
    model,
    norm_type='hybrid',    # Always use hybrid
    loss_agg='token',      # Token-level for base models
    clip_higher=False,     # No benefit for base models
    overlong_filter=False, # Enable only if max_length <= 8K
)

# Aligned model configuration  
trainer = ConfigurableRLTrainer(
    model,
    norm_type='hybrid',     # Always use hybrid
    loss_agg='sequence',    # Sequence-level for aligned models
    clip_higher=True,       # Prevents entropy collapse
    overlong_filter=False,  # Enable only if max_length <= 8K
)
```

**Key insight:** The effectiveness of each technique depends on model state, not just the technique itself.

---
**Next:** `03_reward_design_fundamentals.ipynb`