# Reference Models in RLHF

**Creating and managing frozen reference models**

## What is a Reference Model?

The **reference model** is a frozen copy of the policy taken at the start of RLHF training (i.e., the SFT model). It serves as an anchor that keeps the policy from drifting too far from its initial behavior.

```
SFT Model
    │
    ├── → Policy Model (trainable)
    │
    └── → Reference Model (frozen)
```

## Why Reference Models Matter

Without a reference model, the policy can:

1. **Reward hack** — Find degenerate high-reward outputs
2. **Mode collapse** — Generate repetitive responses
3. **Forget language** — Lose coherent generation ability

The KL penalty against the reference mitigates these failure modes by penalizing the policy whenever its token distribution moves away from the reference's.

In [None]:
import torch
import copy
from transformers import AutoModelForCausalLM

def create_reference_model(policy_model):
    """
    Build the frozen reference model from the current policy.

    Returns an independent deep copy of ``policy_model`` that is in
    evaluation mode and whose parameters all have gradients disabled.
    ``policy_model`` itself is not modified.
    """
    frozen_copy = copy.deepcopy(policy_model)

    # Evaluation mode and gradient freezing are independent switches,
    # so the order in which they are applied does not matter.
    frozen_copy.eval()
    for weight in frozen_copy.parameters():
        weight.requires_grad_(False)

    return frozen_copy

# Example: load GPT-2 as the policy and snapshot it as the reference
policy_model = AutoModelForCausalLM.from_pretrained("gpt2")
reference_model = create_reference_model(policy_model)

# Verify: count the elements of every parameter that still requires gradients
policy_trainable, ref_trainable = (
    sum(weight.numel() for weight in net.parameters() if weight.requires_grad)
    for net in (policy_model, reference_model)
)

print(f"Policy trainable params: {policy_trainable:,}")
print(f"Reference trainable params: {ref_trainable:,}")

## Memory Optimization

Keeping both the policy and the reference in memory roughly doubles the memory needed for model weights. Solutions:

In [None]:
# Option 1: Keep reference in half precision
def create_reference_model_fp16(policy_model):
    """Return a frozen, eval-mode half-precision copy of the policy to save memory."""
    # Copy first so the cast to FP16 never touches the policy's own weights
    half_copy = copy.deepcopy(policy_model).half()
    half_copy.eval()

    for weight in half_copy.parameters():
        weight.requires_grad_(False)

    return half_copy

# Option 2: Move reference to CPU (slower but saves GPU memory)
def create_reference_model_cpu(policy_model):
    """Return a frozen, eval-mode copy of the policy placed on the CPU to save GPU memory."""
    # Copy first so moving to the CPU never relocates the policy itself
    cpu_copy = copy.deepcopy(policy_model).cpu()
    cpu_copy.eval()

    for weight in cpu_copy.parameters():
        weight.requires_grad_(False)

    return cpu_copy

# Emit the summary as one call; sep="\n" puts each entry on its own line
print(
    "Memory optimization strategies:",
    "  1. FP16 reference: ~50% memory reduction",
    "  2. CPU reference: Full GPU memory for policy (slower)",
    "  3. Compute KL only periodically (approximation)",
    sep="\n",
)

## Computing Reference Log Probabilities

In [None]:
import torch.nn.functional as F

def get_log_probs(model, input_ids, attention_mask):
    """
    Score the tokens of ``input_ids`` under ``model``.

    Runs a gradient-free forward pass and returns, for every position after
    the first, the log-probability the model assigns to the token that
    actually appears there. The result has one fewer position along the
    sequence dimension than ``input_ids``.
    """
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Logits at step t score the token at step t+1, so pair logits[:-1]
    # with input_ids[1:].
    next_token_ids = input_ids[:, 1:].unsqueeze(-1)
    vocab_log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)

    # Pick out the log-probability of the token that was actually observed
    return vocab_log_probs.gather(-1, next_token_ids).squeeze(-1)

# Example usage
from transformers import AutoTokenizer

# Use EOS as the padding token
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

text = "Hello, how are you doing today?"
inputs = tokenizer(text, return_tensors="pt")

# Score the same tokens under the policy and under the reference
policy_logprobs, ref_logprobs = (
    get_log_probs(net, inputs["input_ids"], inputs["attention_mask"])
    for net in (policy_model, reference_model)
)

print(f"Policy log probs shape: {policy_logprobs.shape}")
print(f"Reference log probs shape: {ref_logprobs.shape}")

# Average per-token log-ratio between policy and reference on this text
kl = torch.mean(policy_logprobs - ref_logprobs)
print(f"KL divergence: {kl.item():.4f}")

## Verifying Reference is Frozen

In [None]:
def verify_reference_frozen(policy_model, reference_model):
    """
    Verify that reference model is properly frozen.

    Prints a short report and returns ``True`` when no parameter of
    ``reference_model`` requires gradients, ``False`` otherwise.

    The weight comparison is informational only and does not affect the
    return value. It looks at the first parameter of each model and compares
    them at the reference's precision and on the reference's device, so a
    half-precision or CPU-resident reference can be checked against a
    full-precision or GPU-resident policy without raising.
    """
    # Check no gradients
    ref_requires_grad = any(p.requires_grad for p in reference_model.parameters())

    # Compare the first parameter of each model: equal right after the copy,
    # expected to differ once the policy has been trained.
    first_policy_param = next(policy_model.parameters())
    first_ref_param = next(reference_model.parameters())

    with torch.no_grad():
        # Round the policy weights to the reference's dtype/device so a
        # lower-precision reference still matches the weights it was copied from.
        policy_as_ref = first_policy_param.to(
            device=first_ref_param.device, dtype=first_ref_param.dtype
        )
        # Do the tolerance check in at least float32; both operands already
        # hold values representable in the reference dtype, so this is exact.
        compare_dtype = torch.promote_types(first_ref_param.dtype, torch.float32)
        weights_equal = torch.allclose(
            policy_as_ref.to(compare_dtype), first_ref_param.to(compare_dtype)
        )

    print("Reference Model Verification:")
    print(f"  Requires grad: {ref_requires_grad} (should be False)")
    print(f"  Weights equal to policy: {weights_equal} (True initially, False after training)")

    return not ref_requires_grad

# Run the check on the policy/reference pair built above: prints the report
# and evaluates to True when no reference parameter requires gradients.
verify_reference_frozen(policy_model, reference_model)

## Monitoring Divergence

In [None]:
def compute_weight_divergence(policy_model, reference_model):
    """
    Compute how far policy weights have diverged from reference.

    Parameters of the two models are paired in iteration order. For each
    pair the L2 norm of the difference and the L2 norm of the reference
    parameter are accumulated.

    Returns a dict with:
        'absolute_divergence': sum over parameters of ||policy - reference||
        'relative_divergence': that sum divided by the sum of reference
            norms (plus 1e-8 to avoid division by zero)

    The reference may live on a different device or use a different dtype
    than the policy (e.g. a CPU or half-precision reference); each pair is
    brought to the policy's device and a common dtype before comparing.
    """
    total_diff = 0.0
    total_norm = 0.0

    # Pure measurement: keep autograd out of it even though the policy is trainable
    with torch.no_grad():
        for p_param, r_param in zip(
            policy_model.parameters(),
            reference_model.parameters()
        ):
            # Same promotion the subtraction would apply, made explicit so the
            # reference norm is computed at that precision too.
            common_dtype = torch.promote_types(p_param.dtype, r_param.dtype)
            p_aligned = p_param.to(dtype=common_dtype)
            r_aligned = r_param.to(device=p_param.device, dtype=common_dtype)

            total_diff += (p_aligned - r_aligned).norm().item()
            total_norm += r_aligned.norm().item()

    relative_divergence = total_diff / (total_norm + 1e-8)

    return {
        'absolute_divergence': total_diff,
        'relative_divergence': relative_divergence
    }

# No training step has run yet, so both numbers should be ~0
divergence = compute_weight_divergence(policy_model, reference_model)
print(
    "Weight divergence:",
    f"  Absolute: {divergence['absolute_divergence']:.6f}",
    f"  Relative: {divergence['relative_divergence']:.6f}",
    sep="\n",
)

## Next Steps

Now that we understand the complete RLHF pipeline, let's explore DPO — a simpler alternative that doesn't require a reward model.