# Reward Hacking Detection and Prevention

**Reward hacking** (or *reward gaming*) occurs when a reinforcement learning (RL) agent finds a way to maximize the **proxy reward** (the reward signal it is trained on) without actually achieving the **gold reward** (the true underlying objective or human intent). This is a quintessential example of **Goodhart's Law**: *"When a measure becomes a target, it ceases to be a good measure."*

As we scale RLHF (Reinforcement Learning from Human Feedback), models become increasingly capable of finding loopholes in the reward model (RM). For instance:
- **Length Hacking**: The model learns that longer responses are generally preferred and starts generating verbose, repetitive fluff.
- **Sycophancy**: The model learns to agree with the user's biases rather than being truthful, as this often yields higher immediate rewards.
- **Adversarial Exploitation**: The model finds specific token sequences that trigger high reward scores despite being nonsense.

In this notebook, we implement mechanisms to **detect** and **mitigate** these behaviors.

## 1. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Tuple, Optional, Dict

## 2. Theoretical Framework: Signals of Overoptimization

Prior work on reward hacking (e.g., *Scaling Laws for Reward Model Overoptimization*, *The Effects of Reward Misspecification*) suggests several signals for detecting when a policy is overoptimizing its reward model:

1.  **Proxy-Gold Divergence**: The most direct signal. We track the `proxy_reward` (what the model optimizes) vs. a `gold_reward` (a held-out, higher-quality RM, or actual human evaluation labels). If Proxy $\uparrow$ while Gold $\downarrow$, we have hacking.
2.  **KL Divergence Explosion**: RLHF typically penalizes the KL divergence between the training policy and the reference model. A sudden spike in KL suggests the model is drifting into a distribution where the RM is unreliable (out-of-distribution).
3.  **Length Inflation**: A simple heuristic, but highly effective. Sudden increases in average response length often correlate with length hacking.
4.  **Variance Collapse**: Although not implemented here, a collapse in output diversity can also signal that the model has found a single "super-stimulus" to exploit.
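
Since variance collapse is not implemented in this notebook, here is a minimal sketch of one way it could be measured: a distinct-n style ratio (unique n-grams divided by total n-grams) over a batch of tokenized responses. The function name `distinct_n` and the toy batches are illustrative assumptions, not part of the detector below.

```python
from collections import Counter
from typing import Sequence


def distinct_n(responses: Sequence[Sequence[str]], n: int = 2) -> float:
    """Fraction of n-grams in a batch of tokenized responses that are unique."""
    ngrams = Counter()
    for tokens in responses:
        for i in range(len(tokens) - n + 1):
            ngrams[tuple(tokens[i:i + n])] += 1
    total = sum(ngrams.values())
    # An empty batch (or responses shorter than n) has no n-grams to count.
    return len(ngrams) / total if total else 0.0


# Hypothetical batches: three different responses vs. the same response repeated.
diverse = [["the", "cat", "sat"], ["a", "dog", "ran"], ["birds", "fly", "south"]]
collapsed = [["great", "question", "!"]] * 3

print(f"diverse:   {distinct_n(diverse):.2f}")
print(f"collapsed: {distinct_n(collapsed):.2f}")
```

A ratio that falls steadily across training batches would suggest the policy is converging on a small set of outputs. Any alert threshold would need to be tuned per task.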

In [None]:
class RewardHackingDetector:
    """
    Detect reward hacking during training.
    
    Key signals:
    - Proxy reward increasing while gold reward plateaus/drops
    - KL divergence exploding
    - Response length inflating

    Diversity collapse is a further signal but is not tracked by this class.
    """
    
    def __init__(self, window_size: int = 100):
        self.window_size = window_size
        self.proxy_rewards = []
        self.gold_rewards = []  # From held-out evaluation
        self.kl_values = []
        self.lengths = []
    
    def update(self, proxy: float, gold: float, kl: float, length: int):
        """Record metrics for one training step."""
        self.proxy_rewards.append(proxy)
        self.gold_rewards.append(gold)
        self.kl_values.append(kl)
        self.lengths.append(length)
    
    def detect_overoptimization(self) -> Dict:
        """
        Detect if proxy is being overoptimized.
        
        Returns:
            Dict with detection results and metrics
        """
        if len(self.proxy_rewards) < self.window_size * 2:
            return {'detected': False, 'reason': 'Insufficient data'}
        
        # Compare the latest window against the window where the smoothed gold
        # reward peaked. Comparing against the very first window would miss a
        # late decline whenever gold is still above its starting level.
        kernel = np.ones(self.window_size) / self.window_size
        proxy_avg = np.convolve(self.proxy_rewards, kernel, mode='valid')
        gold_avg = np.convolve(self.gold_rewards, kernel, mode='valid')
        peak = int(np.argmax(gold_avg))
        
        proxy_delta = float(proxy_avg[-1] - proxy_avg[peak])
        gold_delta = float(gold_avg[-1] - gold_avg[peak])
        
        # Hacking detected if proxy has risen since the gold peak while gold has dropped
        detected = bool(proxy_delta > 0.1 and gold_delta < -0.05)
        
        return {
            'detected': detected,
            'proxy_delta': proxy_delta,
            'gold_delta': gold_delta,
            'reason': 'Proxy ↑ while gold ↓' if detected else 'OK'
        }
    
    def detect_length_hacking(self, baseline_length: float = 200) -> Dict:
        """Detect if model is gaming length."""
        if len(self.lengths) < self.window_size:
            return {'detected': False, 'reason': 'Insufficient data'}
        
        avg_length = np.mean(self.lengths[-self.window_size:])
        ratio = avg_length / baseline_length
        
        detected = ratio > 2.0  # More than 2x baseline
        
        return {
            'detected': detected,
            'avg_length': avg_length,
            'length_ratio': ratio,
            'reason': f'Length {ratio:.1f}x baseline' if detected else 'OK'
        }
    
    def detect_kl_explosion(self, threshold: float = 15.0) -> Dict:
        """Detect if KL divergence is exploding."""
        if len(self.kl_values) < 10:
            return {'detected': False, 'reason': 'Insufficient data'}
        
        recent_kl = np.mean(self.kl_values[-10:])
        detected = recent_kl > threshold
        
        return {
            'detected': detected,
            'recent_kl': recent_kl,
            'reason': f'KL={recent_kl:.1f} > {threshold}' if detected else 'OK'
        }
    
    def should_early_stop(self) -> Tuple[bool, str]:
        """Determine if training should stop due to hacking."""
        overopt = self.detect_overoptimization()
        length = self.detect_length_hacking()
        kl = self.detect_kl_explosion()
        
        if overopt['detected']:
            return True, f"Overoptimization: {overopt['reason']}"
        if kl['detected']:
            return True, f"KL explosion: {kl['reason']}"
        if length['detected']:
            return True, f"Length hacking: {length['reason']}"
        
        return False, "Training healthy"

## 3. Simulation: Goodhart's Law

The following function simulates training dynamics in which a model exploits the proxy reward. It has three distinct phases:
1.  **Alignment Phase (Steps 0-49):** Both metrics improve. This is the ideal training state.
2.  **Plateau Phase (Steps 50-99):** The gold reward stops improving, but the model continues to optimize the proxy.
3.  **Hacking Phase (Steps 100+):** The proxy reward continues to rise (the model "thinks" it's doing better), but the gold reward crashes. This is where we must alert or stop.

In [None]:
def simulate_goodharts_law(steps: int = 200):
    """
    Simulate what happens when proxy reward is overoptimized.
    
    Early training: Both proxy and gold improve together
    Mid training: Gold plateaus
    Late training: Proxy keeps rising, gold drops (HACKING!)
    """
    np.random.seed(42)
    detector = RewardHackingDetector(window_size=30)
    
    print("Simulating Goodhart's Law (Overoptimization):")
    print(f"{'Step':>6} {'Proxy':>8} {'Gold':>8} {'Status':>20}")
    print("-" * 45)
    
    for step in range(steps):
        # Simulate training dynamics
        if step < 50:  # Early: Both improve
            proxy = 0.3 + step * 0.01 + np.random.normal(0, 0.02)
            gold = 0.3 + step * 0.008 + np.random.normal(0, 0.02)
        elif step < 100:  # Mid: Gold plateaus
            proxy = 0.8 + (step - 50) * 0.005 + np.random.normal(0, 0.02)
            gold = 0.7 + np.random.normal(0, 0.02)  # Plateau!
        else:  # Late: Gold drops (HACKING)
            proxy = 1.0 + (step - 100) * 0.003 + np.random.normal(0, 0.02)
            gold = 0.7 - (step - 100) * 0.003 + np.random.normal(0, 0.02)  # Drops!
        
        kl = 3.0 + step * 0.05 + np.random.normal(0, 0.5)
        length = int(200 + step * 1.5)
        
        detector.update(proxy, gold, kl, length)
        
        if step % 40 == 0 or step == steps - 1:
            should_stop, reason = detector.should_early_stop()
            status = f"⚠️ {reason}" if should_stop else "✓ OK"
            print(f"{step:>6} {proxy:>8.3f} {gold:>8.3f} {status:>20}")
    
    print("\n" + "=" * 45)
    final_check = detector.detect_overoptimization()
    print(f"Final verdict: {'🚨 HACKING DETECTED' if final_check['detected'] else '✓ OK'}")
    print(f"  Proxy Δ: {final_check.get('proxy_delta', 0):.3f}")
    print(f"  Gold Δ: {final_check.get('gold_delta', 0):.3f}")

simulate_goodharts_law()

## 4. Mitigation Strategy: Ensemble Reward Models

A powerful mitigation strategy is **Reward Ensembling**. Instead of training a single reward model, we train multiple models (with different seeds, architectures, or data splits). 

The intuition is that while a hack might fool one model, it is less likely to fool an entire committee. We can use the **Uncertainty-Aware Reward**:

$$ R_{conservative}(x) = \mu(R(x)) - \alpha \cdot \sigma(R(x)) $$

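where $\mu$ is the mean reward across the ensemble, $\sigma$ is the standard deviation across the ensemble, and $\alpha \geq 0$ is a pessimism coefficient (the `pessimism` argument in the code below). This penalizes inputs where the models disagree, which often corresponds to out-of-distribution or "hacky" examples.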
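
To make the formula concrete, here is a small worked sketch for a single response scored by three reward models. The helper name `conservative_reward` and the score values are made up for illustration. It uses the sample standard deviation (dividing by $n - 1$), which is also the default behaviour of `torch.std`.

```python
import statistics
from typing import Sequence


def conservative_reward(scores: Sequence[float], alpha: float = 1.0) -> float:
    """Return mean - alpha * sample std of the ensemble scores for one response."""
    mu = statistics.mean(scores)
    sigma = statistics.stdev(scores)  # sample std (n - 1); needs at least 2 scores
    return mu - alpha * sigma


# Hypothetical ensemble scores: both lists have the same mean (0.80).
agree = [0.80, 0.82, 0.78]     # members agree, so the std is small
disagree = [1.60, 0.10, 0.70]  # members disagree, so the std is large

print(f"agree:    {conservative_reward(agree):.3f}")
print(f"disagree: {conservative_reward(disagree):.3f}")
```

Both responses would receive the same reward under plain averaging. The conservative reward ranks the one the ensemble agrees on higher, and setting `alpha=0` recovers the plain mean.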


In [None]:
class EnsembleRewardModel:
    """
    Combine multiple reward models to resist hacking.
    
    Intuition: Harder to fool multiple models simultaneously.
    Each RM has different biases → averaging dampens the ones that
    are not shared (biases common to all members remain).
    """
    
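    def __init__(self, reward_models: List[nn.Module]):
        # The std across members is undefined (NaN) for a single model,
        # so require at least two.
        if len(reward_models) < 2:
            raise ValueError("EnsembleRewardModel needs at least 2 reward models")
        self.models = reward_models
        self.n_models = len(reward_models)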
    
    def compute_reward(self, input_ids: torch.Tensor) -> Dict:
        """
        Compute ensemble reward.
        
        Returns:
            Dict with 'mean', 'std', 'individual' rewards
        """
        rewards = []
        
        with torch.no_grad():
            for model in self.models:
                r = model(input_ids)
                rewards.append(r)
        
        stacked = torch.stack(rewards, dim=0)  # (n_models, batch)
        
        return {
            'mean': stacked.mean(dim=0),  # Average across models
            'std': stacked.std(dim=0),    # Uncertainty
            'individual': rewards,
        }
    
    def compute_conservative_reward(self, input_ids: torch.Tensor,
                                    pessimism: float = 1.0) -> torch.Tensor:
        """
        Conservative reward: mean - pessimism * std
        
        Higher pessimism → more conservative → less hackable
        """
        result = self.compute_reward(input_ids)
        return result['mean'] - pessimism * result['std']

## 5. Complete Mitigation Pipeline

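Here we combine the reward-shaping defenses and the detector into a single `RewardHackingMitigation` system:
1.  **Length Penalty**: Scale down the reward of responses that exceed a target length, in proportion to the overshoot.
2.  **KL Penalty**: The standard RLHF penalty ($R_{total} = R_{model} - \beta \log \frac{\pi}{\pi_{ref}}$), which discourages drift from the reference policy.
3.  **Hacking Detection**: Constant monitoring to trigger early stopping.

The class takes `raw_reward` as an input, so an ensemble reward from Section 4 can be supplied there; the class itself does not run the ensemble.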
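
As a rough illustration of the KL penalty term, the sketch below computes a sequence-level penalized reward from per-token log-probabilities of the sampled response under the policy and the reference model. The function name, argument names, and numbers are assumptions for illustration. The summed log-ratio is a single-sample estimate of the KL divergence, and real training loops often apply the penalty per token instead.

```python
from typing import Sequence


def kl_penalized_reward(rm_score: float,
                        policy_logprobs: Sequence[float],
                        ref_logprobs: Sequence[float],
                        beta: float = 0.1) -> float:
    """Subtract beta * sum_t [log pi(y_t) - log pi_ref(y_t)] from the RM score."""
    if len(policy_logprobs) != len(ref_logprobs):
        raise ValueError("log-prob sequences must have the same length")
    # Sum of per-token log-ratios for the sampled tokens.
    log_ratio = sum(p - r for p, r in zip(policy_logprobs, ref_logprobs))
    return rm_score - beta * log_ratio


# Hypothetical 3-token response: the policy assigns each sampled token
# a log-probability 1.0 higher than the reference does.
policy_lp = [-0.2, -0.1, -0.3]
ref_lp = [-1.2, -1.1, -1.3]

print(f"{kl_penalized_reward(1.0, policy_lp, ref_lp, beta=0.1):.2f}")
```

When the policy and reference assign identical log-probabilities, the penalty vanishes and the RM score passes through unchanged.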

In [None]:
class RewardHackingMitigation:
    """
    Anti-hacking reward shaping plus monitoring:
    
    1. Length penalty: Prevent verbosity gaming
    2. KL penalty: Prevent drift from reference
    3. Early stopping: Detect and halt overoptimization
    
    Ensemble averaging is not performed here; pass an ensemble
    (e.g. conservative) reward in as `raw_reward` to combine the two.
    """
    
    def __init__(self,
                 kl_beta: float = 0.1,
                 target_length: int = 200,
                 detector_window: int = 100):
        
        self.kl_beta = kl_beta
        self.target_length = target_length
        self.detector = RewardHackingDetector(window_size=detector_window)
    
    def compute_safe_reward(self,
                            raw_reward: float,
                            response_length: int,
                            kl_divergence: float,
                            gold_reward: Optional[float] = None) -> Dict:
        """
        Compute reward with all anti-hacking measures applied.
        """
        result = {'raw_reward': raw_reward}
        
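        # 1. Length penalty (ratio method)
        # Note: scaling assumes raw_reward >= 0. A negative reward would be
        # pulled toward zero, i.e. improved, by a longer response.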
        if response_length <= self.target_length:
            length_penalized = raw_reward
        else:
            ratio = self.target_length / response_length
            length_penalized = raw_reward * ratio
        result['after_length_penalty'] = length_penalized
        
        # 2. KL penalty
        kl_penalty = self.kl_beta * kl_divergence
        kl_penalized = length_penalized - kl_penalty
        result['after_kl_penalty'] = kl_penalized
        
        # 3. Update detector
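        # In production: gold_reward comes from held-out evaluation.
        # The fallback below is a placeholder that tracks the proxy exactly,
        # so the proxy-gold divergence check cannot fire when it is used.
        if gold_reward is None:
            gold_reward = raw_reward * 0.9  # Simulated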
        self.detector.update(raw_reward, gold_reward, kl_divergence, response_length)
        
        # 4. Check for early stopping
        should_stop, reason = self.detector.should_early_stop()
        result['should_stop'] = should_stop
        result['stop_reason'] = reason
        
        result['final_reward'] = kl_penalized
        return result
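
The ratio-based length penalty in `compute_safe_reward` implicitly assumes non-negative rewards. As a hedged sketch of one alternative, the snippet below compares it with a subtractive penalty that charges a fixed cost per unit of excess length. The function names and the `coeff` hyperparameter are hypothetical and would need tuning; this is not what the pipeline above does.

```python
def ratio_penalty(reward: float, length: int, target: int = 200) -> float:
    """Multiplicative penalty: scale the reward by target / length when too long."""
    if length <= target:
        return reward
    return reward * target / length


def subtractive_penalty(reward: float, length: int,
                        target: int = 200, coeff: float = 0.001) -> float:
    """Subtract coeff for every unit of length beyond the target."""
    excess = max(0, length - target)
    return reward - coeff * excess


# With a negative raw reward, the ratio method makes the longer response look better.
for length in (200, 400):
    print(length,
          f"ratio={ratio_penalty(-1.0, length):.2f}",
          f"subtractive={subtractive_penalty(-1.0, length):.2f}")
```

The subtractive form behaves the same way regardless of the reward's sign, at the cost of one more coefficient to choose relative to the reward scale.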

## 6. Demonstration

The demo below shows the mitigation system in action. The `raw_reward` increases (simulating hacking), but `after_kl_penalty` decreases because the length and KL penalties grow faster than the raw reward. With these settings the KL threshold is never reached and the simulated gold reward tracks the proxy, so the early stop comes from the length check, once the average length over the detector window exceeds twice the baseline of 200.

In [None]:
def demo_mitigation_system():
    """Test the complete mitigation system."""
    
    mitigation = RewardHackingMitigation(
        kl_beta=0.1,
        target_length=200,
        detector_window=20,
    )
    
    print("Anti-Hacking Reward Pipeline:")
    print(f"{'Step':>5} {'Raw':>8} {'After Len':>10} {'After KL':>10} {'Status':>15}")
    print("-" * 55)
    
    for step in range(50):
        # Simulate increasingly hacky behavior
        raw_reward = 0.5 + step * 0.02
        length = 200 + step * 10  # Growing length
        kl = 3.0 + step * 0.2     # Growing KL
        
        result = mitigation.compute_safe_reward(
            raw_reward=raw_reward,
            response_length=length,
            kl_divergence=kl,
        )
        
        if step % 10 == 0 or result['should_stop']:
            status = '⚠️ STOP' if result['should_stop'] else '✓ OK'
            print(f"{step:>5} {raw_reward:>8.3f} {result['after_length_penalty']:>10.3f} "
                  f"{result['after_kl_penalty']:>10.3f} {status:>15}")
            if result['should_stop']:
                print(f"\n🚨 Early stop triggered: {result['stop_reason']}")
                break

demo_mitigation_system()