# Week 8.3: Reinforcement Learning from Human Feedback (RLHF)

**Resource Required**: Google Colab with an A100 GPU, or a provisioned GPU (RunPod 1x A100 recommended)

## Objectives:
* Understand how the RL agent/action/environment paradigm works in the context of autoregressive transformer models
* Learn how the RLHF algorithm works and how it fits on top of PPO (Proximal Policy Optimization)
* Understand value heads and how they turn transformers into actor-critic networks
* Implement key components of RLHF including KL penalty, advantage estimation, and value heads
* Train a transformer model using RLHF with simple reward functions

⚠️ **NOTE**: This notebook requires substantial GPU memory. We recommend using an A100 GPU. If you're using a less powerful GPU (e.g., A10), set `LOW_GPU_MEM = True` in the code below to use smaller models and batch sizes.

## 1. Introduction to RLHF

Reinforcement Learning from Human Feedback (RLHF) is a technique for training language models to behave in ways that align with human preferences. It builds on top of standard reinforcement learning algorithms like PPO (Proximal Policy Optimization) but applies them to the unique setting of autoregressive language models.

The key insight of RLHF is that we can treat text generation as a sequential decision-making problem where:
- The **agent** is our language model
- The **environment** is the context/prompt and the text generated so far
- **Actions** are the tokens the model generates
- **Rewards** come from a reward model trained on human preferences

Let's start by installing the necessary dependencies:

In [None]:
!pip install "transformers>=4.40.1" "accelerate>=0.27.2" torch einops jaxtyping matplotlib

In [None]:
import torch as t
import torch.nn as nn
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import dataclass
from typing import Callable
import einops
from jaxtyping import Float, Int
from torch import Tensor
import matplotlib.pyplot as plt  # Plotting (used in the GRPO section)
import random  # Random countdown problem generation
import re  # Pattern matching in the reward functions

# Configuration for low GPU memory
LOW_GPU_MEM = False  # Set to True if using less powerful GPU
BASE_MODEL = "gpt2" if LOW_GPU_MEM else "gpt2-medium"  # "gpt2" is the small (124M) checkpoint

device = t.device("cuda" if t.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Using model: {BASE_MODEL}")

## 2. RL Concepts in Language Models

### States, Actions, and Episodes

In the context of language models:

- **States** ($s_t$): The entire sequence of tokens generated up to time $t$. This includes both the initial prompt and all tokens generated so far.
- **Actions** ($a_t$): The next token to generate. The action space is the entire vocabulary of the model.
- **Transitions**: Given state $s_t$ (sequence) and action $a_t$ (new token), the next state is simply the concatenation: $s_{t+1} = [s_t \; a_t]$
- **Episodes**: Each episode starts with a prompt and continues for a fixed number of generated tokens.

Let's visualize this with a simple example:

In [None]:
# Load a simple GPT-2 model to demonstrate
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Example of states and actions in language modeling
prompt = "The weather today is"
print(f"Initial state s_0: '{prompt}'")
print(f"Action space size: {model.config.vocab_size} (entire vocabulary)\n")

# Generate one token at a time to show the state transitions
current_text = prompt
for i in range(3):
    inputs = tokenizer(current_text, return_tensors="pt").to(device)
    with t.no_grad():
        outputs = model(**inputs)
        next_token_id = outputs.logits[0, -1, :].argmax()
        next_token = tokenizer.decode([next_token_id])
    
    print(f"Step {i+1}:")
    print(f"  Current state s_{i}: '{current_text}'")
    print(f"  Action a_{i} (next token): '{next_token}'")
    current_text += next_token
    print(f"  New state s_{i+1}: '{current_text}'\n")

### Rewards and Value Functions

In RLHF:
- **Rewards** $r_t$ are typically computed only at the end of the generation (episode), based on the complete sequence
- **Value function** $V(s_t)$ estimates the expected future reward (the return) from state $s_t$
- Episodes are short and the reward arrives only at the final token, so RLHF implementations typically do not discount future rewards (effectively $\gamma = 1$)
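To make the $\gamma = 1$ point concrete, here is a minimal sketch (the helper `discounted_returns` is hypothetical, written only for this illustration) that computes the return $G_t = \sum_{k \ge t} \gamma^{k-t} r_k$ for a single episode whose only nonzero reward arrives at the last token. With $\gamma = 1$, every state's return equals that final reward, which is what the value function is trained to predict; with $\gamma < 1$, earlier tokens would see a shrunken version of it.

```python
import torch as t

def discounted_returns(rewards: t.Tensor, gamma: float = 1.0) -> t.Tensor:
    """Compute G_t = sum_{k >= t} gamma^(k - t) * r_k for a single episode."""
    returns = t.zeros_like(rewards)
    running = 0.0
    # Walk backwards so each return reuses the next one: G_t = r_t + gamma * G_{t+1}
    for i in reversed(range(len(rewards))):
        running = rewards[i] + gamma * running
        returns[i] = running
    return returns

# A 5-token generation whose only reward (2.0) arrives at the last token
rewards = t.tensor([0.0, 0.0, 0.0, 0.0, 2.0])
print(discounted_returns(rewards, gamma=1.0))  # every return equals 2.0
print(discounted_returns(rewards, gamma=0.5))  # 0.125, 0.25, 0.5, 1.0, 2.0
```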

## 3. The RLHF Algorithm

RLHF consists of three main stages:

1. **Supervised Fine-Tuning (SFT)**: Train the base model on high-quality human demonstrations
2. **Reward Model Training**: Train a model to predict human preferences between pairs of outputs
3. **RL Fine-Tuning**: Use PPO to optimize the model to maximize the reward while staying close to the original model
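Stage 2 is commonly framed as pairwise ranking: labelers choose the better of two responses to the same prompt, and the reward model is trained so the chosen response scores higher, using the loss $-\log \sigma(r_{\text{chosen}} - r_{\text{rejected}})$ (the formulation used in the InstructGPT paper). A minimal sketch of just that loss, assuming the reward model has already produced one scalar score per response; `pairwise_reward_loss` is a hypothetical name:

```python
import torch as t
import torch.nn.functional as F

def pairwise_reward_loss(chosen_scores: t.Tensor, rejected_scores: t.Tensor) -> t.Tensor:
    """Pairwise ranking loss: -log sigmoid(r_chosen - r_rejected), averaged over pairs."""
    # F.logsigmoid stays numerically stable even for large score gaps
    return -F.logsigmoid(chosen_scores - rejected_scores).mean()

# Scores a reward model might assign to three (chosen, rejected) pairs
chosen = t.tensor([2.0, 0.5, 1.0])
rejected = t.tensor([0.0, 0.5, 3.0])
print(pairwise_reward_loss(chosen, rejected))
```

The per-pair loss is $\log 2$ when the model cannot tell the two responses apart and shrinks toward 0 as the chosen response's score pulls ahead; the third pair above, where the rejected response scores higher, contributes the most.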

Today we'll focus on stage 3, using a predefined reward function instead of training one from human feedback.

### Key Components of RLHF:

1. **PPO Objective**: Maximize expected reward while constraining policy changes
2. **KL Penalty**: Prevent the model from diverging too far from the reference model
3. **Value Head**: Estimate future rewards for advantage calculation

## 4. PPO Intuitions

PPO (Proximal Policy Optimization) is the RL algorithm underlying RLHF. The key ideas are:

1. **Clipped Objective**: Prevent large policy updates that could destabilize training
2. **Advantage Estimation**: Use the difference between actual returns and value estimates
3. **Multiple Update Epochs**: Reuse collected data for multiple gradient updates

The PPO objective function is:

$$L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t \right) \right]$$

Where:
- $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ is the probability ratio
- $\hat{A}_t$ is the advantage estimate
- $\epsilon$ is the clipping parameter (typically 0.2)
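To use $L^{CLIP}$ as a loss to minimize, it is negated. A minimal sketch, assuming the log-probabilities of the sampled tokens under the current and old policies have already been gathered; the name `ppo_clip_loss` is hypothetical and is not used by the later cells of this notebook:

```python
import torch as t

def ppo_clip_loss(
    log_probs: t.Tensor,      # log pi_theta(a_t | s_t) for the sampled tokens (requires grad)
    old_log_probs: t.Tensor,  # log pi_theta_old(a_t | s_t), fixed during the update
    advantages: t.Tensor,     # A_t, treated as constants
    clip_eps: float = 0.2,
) -> t.Tensor:
    """Negative clipped surrogate objective, averaged over tokens."""
    # r_t(theta), computed in log space for numerical stability
    ratio = t.exp(log_probs - old_log_probs.detach())
    advantages = advantages.detach()
    unclipped = ratio * advantages
    clipped = t.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    # Pessimistic bound: take the smaller term, then negate to get a loss
    return -t.minimum(unclipped, clipped).mean()

# Ratio 0.6 / 0.4 = 1.5 with a positive advantage: the objective is capped at (1 + eps) * A = 1.2
loss = ppo_clip_loss(t.log(t.tensor([0.6])), t.log(t.tensor([0.4])), t.tensor([1.0]))
print(loss)  # about -1.2
```

The `min` makes clipping one-sided: it only removes the incentive to push the ratio further in the direction the advantage already favors, so a single batch cannot move the policy far from $\pi_{\theta_{old}}$.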

Let's implement a simple advantage calculation:

In [None]:
@t.no_grad()
def compute_advantages(
    values: Float[Tensor, "batch seq_len"],
    rewards: Float[Tensor, "batch"],
    prefix_len: int,
) -> Float[Tensor, "batch gen_len"]:
    """
    Compute advantages for RLHF using simple one-step estimation.
    
    In RLHF, advantages are computed as:
    A(s_t, a_t) = Q(s_t, a_t) - V(s_t)
    
    Where Q(s_t, a_t) is approximated by V(s_{t+1}) for all tokens
    except the last one, where it equals the actual reward.
    """
    batch_size, seq_len = values.shape
    gen_len = seq_len - prefix_len
    
    # For tokens before the last, Q = V(next state): values at positions
    # prefix_len .. seq_len-2, i.e. gen_len - 1 entries
    # For the last token, Q = reward
    q_values = t.cat([
        values[:, prefix_len:-1],  # V(s_{t+1}) for t < T
        rewards[:, None]           # r for t = T
    ], dim=1)
    
    # V(s_t) for all generated tokens
    v_values = values[:, prefix_len-1:-1]
    
    # Advantages = Q - V
    advantages = q_values - v_values
    
    return advantages

# Example
batch_size, seq_len, prefix_len = 2, 10, 3
values = t.randn(batch_size, seq_len)
rewards = t.randn(batch_size)
advantages = compute_advantages(values, rewards, prefix_len)
print(f"Values shape: {values.shape}")
print(f"Rewards shape: {rewards.shape}")
print(f"Advantages shape: {advantages.shape}")
print(f"Expected advantages shape: ({batch_size}, {seq_len - prefix_len})")

## 5. Value Heads and Actor-Critic Architecture

In RLHF, we use a **value head** - a small neural network attached to the transformer's final layer that estimates the value function. This creates an actor-critic architecture where:

- **Actor**: The language model itself (generates actions/tokens)
- **Critic**: The value head (estimates future rewards)

Both share the same transformer backbone, which allows them to share learned representations.

Let's implement a transformer with a value head:

In [None]:
class TransformerWithValueHead(nn.Module):
    """
    A transformer model with an additional value head for RLHF.
    
    The value head is attached to the final layer's output (after LayerNorm)
    and produces a scalar value estimate for each token position.
    """
    
    def __init__(self, base_model_name: str):
        super().__init__()
        
        # Load the base transformer model
        self.base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
        self.config = self.base_model.config
        
        # Create value head: a 2-layer MLP
        # Input: hidden states from transformer (d_model)
        # Output: scalar value estimate
        d_model = self.config.hidden_size
        self.value_head = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, 1)
        )
    
    def forward(
        self, 
        input_ids: Int[Tensor, "batch seq_len"],
        attention_mask: Float[Tensor, "batch seq_len"] = None
    ) -> tuple[Float[Tensor, "batch seq_len vocab_size"], Float[Tensor, "batch seq_len"]]:
        """
        Forward pass through both the language model and value head.
        
        Returns:
            logits: Token predictions from the language model
            values: Value estimates for each position
        """
        # Get transformer outputs
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        
        # Extract hidden states from the final layer
        # Shape: (batch, seq_len, d_model)
        hidden_states = outputs.hidden_states[-1]
        
        # Compute value estimates
        # Shape: (batch, seq_len)
        values = self.value_head(hidden_states).squeeze(-1)
        
        return outputs.logits, values

# Create and test the model
model_with_value_head = TransformerWithValueHead("gpt2").to(device)

# Test forward pass
test_input = t.randint(0, 1000, (2, 10)).to(device)
logits, values = model_with_value_head(test_input)

print(f"Input shape: {test_input.shape}")
print(f"Logits shape: {logits.shape} (batch_size, seq_len, vocab_size)")
print(f"Values shape: {values.shape} (batch_size, seq_len)")
print(f"\nValue head parameters: {sum(p.numel() for p in model_with_value_head.value_head.parameters()):,}")

## 6. KL Divergence Penalty

A crucial component of RLHF is the KL divergence penalty, which prevents the model from deviating too far from the reference model. This helps maintain coherent outputs while optimizing for rewards.

The KL penalty is computed as:

$$D_{KL}(\pi_{\text{new}} || \pi_{\text{ref}}) = \sum_{t} \mathbb{E}_{a \sim \pi_{\text{new}}} \left[ \log \frac{\pi_{\text{new}}(a|s_t)}{\pi_{\text{ref}}(a|s_t)} \right]$$

Let's implement this:

In [None]:
def calc_kl_penalty(
    logits: Float[Tensor, "batch gen_len vocab_size"],
    ref_logits: Float[Tensor, "batch gen_len vocab_size"],
    kl_coef: float = 0.1,
) -> Float[Tensor, ""]:
    """
    Calculate KL divergence between new and reference policies.
    
    This penalizes the model for generating tokens that would be
    very unlikely under the reference model.
    """
    # Convert logits to log probabilities (numerically stable)
    log_probs = logits.log_softmax(dim=-1)
    ref_log_probs = ref_logits.log_softmax(dim=-1)
    
    # Convert to probabilities
    probs = log_probs.exp()
    
    # KL divergence: sum_a p(a) log(p(a)/q(a))
    kl_div = (probs * (log_probs - ref_log_probs)).sum(dim=-1)
    
    # Average over batch and sequence
    kl_penalty = kl_coef * kl_div.mean()
    
    return kl_penalty

# Example
batch_size, gen_len, vocab_size = 2, 5, 100
logits = t.randn(batch_size, gen_len, vocab_size)
ref_logits = t.randn(batch_size, gen_len, vocab_size)

kl_penalty = calc_kl_penalty(logits, ref_logits)
print(f"KL penalty: {kl_penalty.item():.4f}")

# When logits are identical, KL should be 0
kl_penalty_same = calc_kl_penalty(logits, logits)
print(f"KL penalty (identical distributions): {kl_penalty_same.item():.6f}")
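`calc_kl_penalty` computes the exact KL by summing over the whole vocabulary at every position. A common alternative, following InstructGPT's objective and used by many PPO implementations, folds the penalty into the reward instead: for each sampled token, $\log \pi_{\text{new}}(a_t|s_t) - \log \pi_{\text{ref}}(a_t|s_t)$ is a single-sample (unbiased) estimate of that step's KL, so $-\beta$ times it becomes a per-token reward, and the task reward is added at the final token. A minimal sketch, assuming the log-probabilities of the sampled tokens were already gathered; `shaped_rewards` is a hypothetical helper, not used elsewhere in this notebook:

```python
import torch as t

def shaped_rewards(
    log_probs: t.Tensor,      # (batch, gen_len): log pi_new of each sampled token
    ref_log_probs: t.Tensor,  # (batch, gen_len): log pi_ref of the same tokens
    task_rewards: t.Tensor,   # (batch,): e.g. the reward-model score per sequence
    kl_coef: float = 0.1,
) -> t.Tensor:
    """Per-token rewards: a KL penalty on every token plus the task reward on the last one."""
    # log pi_new(a_t) - log pi_ref(a_t) is a single-sample estimate of the per-step KL;
    # rewards are treated as constants, so gradients are cut here
    rewards = -kl_coef * (log_probs - ref_log_probs).detach()
    # The task reward only arrives at the end of the episode
    rewards[:, -1] += task_rewards
    return rewards

log_probs = t.log(t.tensor([[0.5, 0.5, 0.5]]))
ref_log_probs = t.log(t.tensor([[0.5, 0.25, 0.5]]))
print(shaped_rewards(log_probs, ref_log_probs, t.tensor([1.0])))
```

Only the middle token, which the new policy likes twice as much as the reference does, is penalized; the other two agree with the reference, and the last one also carries the task reward.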

## 7. Simple RLHF Training Loop

Now let's put everything together in a simplified RLHF training loop. We'll use a toy reward function that counts periods (.) in the generated text:

In [None]:
# Simple reward function: count periods
def reward_fn_period_count(texts: list[str]) -> Float[Tensor, "batch"]:
    """Reward function that counts periods in generated text."""
    rewards = t.tensor([text.count('.') for text in texts], dtype=t.float32)
    return rewards.to(device)

@dataclass
class RLHFConfig:
    """Configuration for RLHF training."""
    # Model
    base_model: str = BASE_MODEL  # "gpt2" or "gpt2-medium", depending on LOW_GPU_MEM
    
    # Training
    batch_size: int = 4 if LOW_GPU_MEM else 8
    num_epochs: int = 10
    learning_rate: float = 1e-5
    
    # Generation
    max_gen_len: int = 20
    temperature: float = 1.0
    
    # RLHF specific
    kl_coef: float = 0.1
    clip_coef: float = 0.2
    value_coef: float = 0.5
    
    # Prompts
    prompts: list[str] | None = None  # filled with defaults in __post_init__
    
    def __post_init__(self):
        if self.prompts is None:
            self.prompts = [
                "The weather today is",
                "I think that",
                "Once upon a time",
                "In my opinion",
            ]

In [None]:
def simple_rlhf_step(model, ref_model, tokenizer, config, optimizer):
    """
    Perform one RLHF training step.
    
    This is a simplified version that demonstrates the key concepts:
    a policy-gradient loss weighted by advantages, a value loss for the
    critic (value head), and a KL penalty against the frozen reference model.
    It uses one sampled sequence per prompt and a single gradient update.
    """
    model.train()
    
    # 1. Generate samples from the current policy (no gradients needed here)
    generated_texts = []
    all_input_ids = []
    prompt_lens = []
    
    for prompt in config.prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with t.no_grad():
            generated_ids = model.base_model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=config.max_gen_len,
                temperature=config.temperature,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )
        generated_texts.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
        all_input_ids.append(generated_ids)
        prompt_lens.append(inputs.input_ids.shape[1])
    
    # 2. Compute rewards for the generated texts
    rewards = reward_fn_period_count(generated_texts)
    
    # 3. Compute loss components (this forward pass keeps gradients)
    total_loss = 0.0
    
    for i, input_ids in enumerate(all_input_ids):
        prefix_len = prompt_lens[i]
        
        # Forward pass through the language model and value head
        logits, values = model(input_ids)
        
        # Logits at position p predict token p+1, so the distributions that
        # produced the generated tokens sit at positions prefix_len-1 .. seq_len-2
        gen_logits = logits[:, prefix_len - 1:-1]
        gen_tokens = input_ids[:, prefix_len:]
        token_log_probs = gen_logits.log_softmax(dim=-1).gather(-1, gen_tokens.unsqueeze(-1)).squeeze(-1)
        
        # Advantages are built from detached values (compute_advantages runs under no_grad)
        advantages = compute_advantages(values.detach(), rewards[i:i+1], prefix_len)
        
        # Policy loss: raise the log-probs of tokens with positive advantage
        policy_loss = -(advantages * token_log_probs).mean()
        
        # Value loss: with gamma = 1 and a single terminal reward, the return
        # from every generated state is the final reward
        value_preds = values[:, prefix_len - 1:-1]
        value_loss = config.value_coef * (value_preds - rewards[i]).pow(2).mean()
        
        # KL penalty against the frozen reference model, on the same positions
        with t.no_grad():
            ref_logits = ref_model(input_ids).logits
        kl_loss = calc_kl_penalty(gen_logits, ref_logits[:, prefix_len - 1:-1], config.kl_coef)
        
        total_loss = total_loss + policy_loss + value_loss + kl_loss
    
    # 4. Optimization step
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    return {
        'loss': total_loss.item(),
        'mean_reward': rewards.mean().item(),
        'generated_texts': generated_texts
    }

In [None]:
# Initialize models and training
config = RLHFConfig()
model = TransformerWithValueHead(config.base_model).to(device)
ref_model = AutoModelForCausalLM.from_pretrained(config.base_model).to(device)
ref_model.eval()  # Reference model stays frozen

tokenizer = AutoTokenizer.from_pretrained(config.base_model)
tokenizer.pad_token = tokenizer.eos_token

# Optimizer with different learning rates for base model and value head
optimizer = t.optim.Adam([
    {'params': model.base_model.parameters(), 'lr': config.learning_rate},
    {'params': model.value_head.parameters(), 'lr': config.learning_rate * 10}
])

print("Starting RLHF training...\n")

# Training loop
for epoch in range(config.num_epochs):
    results = simple_rlhf_step(model, ref_model, tokenizer, config, optimizer)
    
    print(f"Epoch {epoch + 1}/{config.num_epochs}")
    print(f"  Loss: {results['loss']:.4f}")
    print(f"  Mean Reward: {results['mean_reward']:.2f}")
    print(f"  Sample: {results['generated_texts'][0][:100]}...")
    print()

## 8. Observing RLHF Behavior

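As training progresses, you may observe the following (with the default 10 steps of this simplified loop the effects can be modest; increase `num_epochs` to see them more clearly):

1. **Increasing rewards**: The model learns to generate more periods
2. **Behavioral changes**: Sentences may become shorter or the model may find creative ways to include periods
3. **KL constraint effect**: The KL penalty discourages the model from collapsing into degenerate text such as nothing but periods, since that text is very unlikely under the reference model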

Common strategies the model might learn:
- Shorter sentences
- Abbreviations (e.g., "U.S.", "Dr.")
- Decimal numbers
- Lists with periods

Let's examine the model's behavior before and after training:

In [None]:
def compare_models(prompt, ref_model, trained_model, tokenizer, max_new_tokens=50):
    """Compare outputs from reference and RLHF-trained models."""
    trained_model.eval()  # Disable dropout for generation
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    # Generate from reference model
    with t.no_grad():
        ref_output = ref_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )
        ref_text = tokenizer.decode(ref_output[0], skip_special_tokens=True)
        
        # Generate from trained model
        trained_output = trained_model.base_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )
        trained_text = tokenizer.decode(trained_output[0], skip_special_tokens=True)
    
    print(f"Prompt: {prompt}\n")
    print(f"Reference model output:\n{ref_text}\n")
    print(f"Period count: {ref_text.count('.')}\n")
    print("-" * 50)
    print(f"RLHF model output:\n{trained_text}\n")
    print(f"Period count: {trained_text.count('.')}")

# Compare outputs
test_prompts = [
    "The most important thing about science is",
    "I went to the store and bought"
]

for prompt in test_prompts:
    compare_models(prompt, ref_model, model, tokenizer)
    print("\n" + "="*70 + "\n")

## Summary and Key Takeaways

In this notebook, we've covered the fundamentals of RLHF:

1. **RL Framework for Language**: How text generation maps to the RL paradigm with states (sequences), actions (tokens), and rewards

2. **PPO Algorithm**: The core RL algorithm used in RLHF, with its clipped objective and advantage estimation

3. **Value Heads**: How to extend transformers with value estimation capabilities for actor-critic training

4. **KL Penalty**: The crucial constraint that keeps RLHF models from diverging too far from sensible outputs

5. **Implementation**: A simplified but functional RLHF training loop demonstrating these concepts

### Real-World RLHF

In practice, RLHF systems are more complex:
- **Human feedback**: Reward models are trained on actual human preferences
- **Scale**: Training happens on much larger models with more sophisticated infrastructure
- **Safety**: Additional constraints ensure helpful, harmless, and honest outputs
- **Efficiency**: Techniques like LoRA reduce computational requirements

### Further Reading

- [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155) (InstructGPT paper)
- [Constitutional AI: Harmlessness from AI Feedback](https://arxiv.org/abs/2212.08073)
- [Direct Preference Optimization](https://arxiv.org/abs/2305.18290) (DPO - a simpler alternative to RLHF)

## 9. GRPO: A Simpler Approach to RL for LLMs

### What is GRPO?

**Group Relative Policy Optimization (GRPO)** is a simplified variant of PPO, introduced in the DeepSeekMath paper, that has gained popularity in recent LLM training, particularly through DeepSeek-R1. Instead of training a value head (critic) to estimate per-token advantages, GRPO samples a group of responses for each prompt, scores each complete response, and normalizes the rewards within the group; the resulting advantage is shared by every token of that response.

### The GRPO Recipe

Here's the high-level procedure:

1. **Start** with a base LLM and a dataset containing problem prompts paired with their final answers
2. For each training iteration:
   - Sample a batch of prompts from the dataset
   - For each prompt, sample G responses from the model (forming a "group")
   - Compute a reward for each response
   - Normalize rewards within each group to calculate advantages
   - Update the model using these advantages
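To see the recipe end to end without an LLM, here is a toy sketch under heavily simplified assumptions: a single prompt, a "model" that is just a learnable categorical distribution over four possible answers, a hypothetical `CORRECT_ANSWER` that earns reward 1, and a plain policy-gradient update with no clipping or KL term. Each iteration samples a group of G responses, scores them, normalizes the rewards within the group, and takes one gradient step.

```python
import torch as t
from torch.distributions import Categorical

t.manual_seed(0)

logits = t.zeros(4, requires_grad=True)  # toy "policy": 4 possible answers to one prompt
optimizer = t.optim.Adam([logits], lr=0.1)
CORRECT_ANSWER = 2  # hypothetical: the only answer that earns reward 1
G = 8               # group size: responses sampled per prompt

for step in range(100):
    dist = Categorical(logits=logits)
    responses = dist.sample((G,))                    # sample a group of G responses
    rewards = (responses == CORRECT_ANSWER).float()  # binary reward per response
    # Group-relative advantages; a group where all rewards are equal gives zeros
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
    # Policy-gradient loss: raise log-probs of above-average responses, lower the rest
    loss = -(advantages * dist.log_prob(responses)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(t.softmax(logits, dim=-1))  # most probability mass should now be on answer 2
```

Because the advantages sum to zero within each group, every step with a mix of right and wrong answers pushes the correct answer up and the sampled wrong answers down; groups where every answer is right (or every answer is wrong) contribute no gradient.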

### Key Differences from Standard PPO

1. **Response-Level Advantages**: While PPO typically computes advantages token-by-token, GRPO assigns the same advantage to all tokens in a response
2. **Group Normalization**: For each prompt, multiple responses are generated (a "group"), and advantages are normalized within this group
3. **Simpler Implementation**: No need for complex value function bootstrapping or GAE (Generalized Advantage Estimation)
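4. **Fully Online**: In the simple variant used here, each batch of data is used for only one gradient update (the GRPO formulation also allows several PPO-style clipped updates per batch)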
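The response-level advantage is easiest to see in the loss itself. A minimal sketch, assuming a padded batch of responses with a 0/1 mask over real tokens and per-token log-probs already gathered; `grpo_policy_loss` is a hypothetical name. With a single update per batch, the PPO ratio is exactly 1 when the gradient is taken, so the clipped objective has the same gradient as this plain log-prob form. Averaging over all real tokens in the batch is one normalization choice; implementations differ here.

```python
import torch as t

def grpo_policy_loss(
    token_log_probs: t.Tensor,  # (batch, resp_len): log pi(token) for each response token
    advantages: t.Tensor,       # (batch,): one group-normalized advantage per response
    mask: t.Tensor,             # (batch, resp_len): 1 for real tokens, 0 for padding
) -> t.Tensor:
    """Policy-gradient loss where every token in a response shares its response's advantage."""
    # Broadcast each response's scalar advantage across all of its tokens
    per_token_adv = advantages[:, None].expand_as(token_log_probs)
    per_token_loss = -per_token_adv * token_log_probs * mask
    # Average over real tokens only, so padding does not dilute the loss
    return per_token_loss.sum() / mask.sum()

# Two responses of lengths 3 and 2 (the second is padded by one token)
token_log_probs = t.log(t.tensor([[0.5, 0.5, 0.5], [0.5, 0.5, 1.0]]))
advantages = t.tensor([1.0, -1.0])
mask = t.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])
print(grpo_policy_loss(token_log_probs, advantages, mask))
```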

Let's understand GRPO through a concrete example - the Countdown task:

In [None]:
# The Countdown Task: A Perfect Example for GRPO

"""
The Countdown game is a numerical puzzle where you must reach a target number 
using a set of given numbers and basic arithmetic operations (+, -, *, /).
In the version used in this notebook, every given number must be used exactly once.

Example:
    Target: 622
    Available Numbers: [25, 3, 6, 100]
    Solution: (100 × 6) + (25 − 3) = 622

This task is ideal for GRPO because:
1. Rewards are clear (correct answer = 1, wrong = 0)
2. The task encourages multi-step reasoning
3. Models can learn to verify and self-correct
"""

# Let's create a simple countdown problem generator
def generate_countdown_problem():
    """
    Generate a random countdown problem.
    
    The target is built by combining ALL four numbers, in a random order,
    with random operations. This guarantees at least one solution that uses
    every number exactly once, as the reward function below requires.
    (This is simplified - real countdown uses more complex generation)
    """
    # Generate random numbers
    numbers = [random.choice([1, 2, 5, 10, 25, 50, 75, 100]) for _ in range(4)]
    
    # Combine the numbers left to right with random operations
    ops = ['+', '-', '*']
    order = random.sample(numbers, len(numbers))
    target = order[0]
    for n in order[1:]:
        op = random.choice(ops)
        if op == '+':
            target = target + n
        elif op == '-':
            target = abs(target - n)  # Keep non-negative
        else:  # '*'
            target = target * n
    
    return numbers, target

# Generate an example problem
numbers, target = generate_countdown_problem()
print(f"Countdown Problem:")
print(f"Target: {target}")
print(f"Available Numbers: {numbers}")
print(f"\nGoal: Create an equation using these numbers to reach {target}")
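Because the reward function below only accepts equations that use every given number exactly once, it can help to check that a problem is actually solvable. A brute-force sketch (the helper `solve_countdown` is hypothetical and not part of the training code): it repeatedly picks an ordered pair of remaining values, combines them with +, -, * or /, and recurses until a single value is left, which covers every possible parenthesization.

```python
from typing import Optional

def solve_countdown(numbers: list[int], target: float, tol: float = 1e-6) -> Optional[str]:
    """Return an expression that uses every number exactly once and equals target, or None."""

    def search(items: list[tuple[float, str]]) -> Optional[str]:
        # items holds (value, expression) pairs that still have to be combined
        if len(items) == 1:
            value, expr = items[0]
            return expr if abs(value - target) < tol else None
        for i in range(len(items)):
            for j in range(len(items)):
                if i == j:
                    continue
                (a, ea), (b, eb) = items[i], items[j]
                rest = [items[k] for k in range(len(items)) if k not in (i, j)]
                candidates = [
                    (a + b, f"({ea} + {eb})"),
                    (a - b, f"({ea} - {eb})"),
                    (a * b, f"({ea} * {eb})"),
                ]
                if b != 0:
                    candidates.append((a / b, f"({ea} / {eb})"))
                for value, expr in candidates:
                    found = search(rest + [(value, expr)])
                    if found is not None:
                        return found
        return None

    return search([(float(n), str(n)) for n in numbers])

print(solve_countdown([25, 3, 6, 100], 622))
```

With four numbers the search visits fewer than ten thousand complete expressions, so it runs instantly; it would be far too slow as a reward function for much larger problems.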

### GRPO Advantage Computation

The core of GRPO is its group-relative advantage computation. Let's implement and visualize this:

In [None]:
def compute_grpo_advantages(
    rewards: Float[Tensor, "group_size"],
) -> Float[Tensor, "group_size"]:
    """
    Compute GRPO advantages using group-relative normalization.
    
    For each prompt x with generated responses y_1, ..., y_G and rewards R_1, ..., R_G:
    1. Compute mean: μ = mean(R_1, ..., R_G)
    2. Compute std: σ = std(R_1, ..., R_G)
    3. Compute advantage for response i: A_i = (R_i - μ) / σ
    
    This normalization encourages responses better than average and
    discourages those worse than average.
    """
    mean_reward = rewards.mean()
    std_reward = rewards.std() + 1e-4  # Add epsilon for numerical stability
    
    # Normalize rewards to get advantages
    advantages = (rewards - mean_reward) / std_reward
    
    return advantages

# Let's see how GRPO advantages work in different scenarios
print("GRPO Advantage Examples for Countdown Task:\n")

# Scenario 1: Binary rewards (correct/incorrect)
print("1. Binary rewards (typical for countdown):")
rewards_binary = t.tensor([1.0, 1.0, 0.0, 0.0, 0.0])  # 2 correct, 3 incorrect
advantages_binary = compute_grpo_advantages(rewards_binary)
print(f"   Rewards: {rewards_binary.tolist()}")
print(f"   Advantages: {[f'{a:.3f}' for a in advantages_binary.tolist()]}")
print(f"   → Correct answers get positive advantage, incorrect get negative\n")

# Scenario 2: All incorrect (hard problem)
print("2. All incorrect responses:")
rewards_all_wrong = t.tensor([0.0, 0.0, 0.0, 0.0])
advantages_all_wrong = compute_grpo_advantages(rewards_all_wrong)
print(f"   Rewards: {rewards_all_wrong.tolist()}")
print(f"   Advantages: {advantages_all_wrong.tolist()}")
print(f"   → No learning signal when all responses are equally bad\n")

# Scenario 3: All correct (easy problem)
print("3. All correct responses:")
rewards_all_correct = t.tensor([1.0, 1.0, 1.0, 1.0])
advantages_all_correct = compute_grpo_advantages(rewards_all_correct)
print(f"   Rewards: {rewards_all_correct.tolist()}")
print(f"   Advantages: {advantages_all_correct.tolist()}")
print(f"   → No learning signal when all responses are equally good\n")

# Scenario 4: One standout response
print("4. One exceptional response:")
rewards_one_great = t.tensor([1.0, 0.0, 0.0, 0.0])
advantages_one_great = compute_grpo_advantages(rewards_one_great)
print(f"   Rewards: {rewards_one_great.tolist()}")
print(f"   Advantages: {[f'{a:.3f}' for a in advantages_one_great.tolist()]}")
print(f"   → Strong positive signal for the correct response")

### Reward Functions for Countdown

Now let's implement the reward functions used in R1-style training. These rewards encourage both correctness and good reasoning format:

In [None]:
def format_reward_func(completion: str) -> float:
    """
    Reward proper formatting: <think>...</think>\\n<answer>...</answer>
    
    This encourages the model to:
    1. Show its reasoning process in <think> tags
    2. Provide a clean final answer in <answer> tags
    """
    # Check for the expected format
    pattern = r"<think>.*?</think>\s*\n\s*<answer>.*?</answer>"
    if re.search(pattern, completion, re.DOTALL):
        # Check if answer contains only valid mathematical expression
        answer_match = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
        if answer_match:
            answer_content = answer_match.group(1).strip()
            # Only numbers, operators, parentheses, and whitespace allowed
            if re.match(r"^[\d+\-*/().\s]+$", answer_content):
                return 1.0  # Perfect format
            else:
                return 0.5  # Good structure, but answer has extra text
    return 0.0  # Wrong format

def equation_reward_func(completion: str, nums: list[int], target: int) -> float:
    """
    Reward correct mathematical solutions.
    
    Checks:
    1. The answer contains only digits, operators, parentheses, and whitespace
    2. Every given number is used exactly once (no more, no fewer)
    3. The equation evaluates to the target
    """
    # Extract answer from tags (DOTALL so multi-line answers are also found)
    match = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    if not match:
        return 0.0
    equation = match.group(1).strip()
    
    # Only allow arithmetic characters before calling eval
    if not re.match(r"^[\d+\-*/().\s]+$", equation):
        return 0.0
    
    # Extract all numbers from the equation and compare them as multisets
    used_numbers = [int(n) for n in re.findall(r"\d+", equation)]
    if sorted(used_numbers) != sorted(nums):
        return 0.0
    
    # Evaluate with no builtins available; any error (e.g. division by zero,
    # malformed syntax) counts as a wrong answer
    try:
        result = eval(equation, {"__builtins__": {}}, {})
    except Exception:
        return 0.0
    
    return 1.0 if abs(float(result) - float(target)) < 1e-5 else 0.0

# Test the reward functions
print("Testing Reward Functions:\n")

# Good example
good_completion = """<think>
I need to make 24 using [2, 3, 4, 6].
6 * 4 = 24, and (3 - 2) = 1 uses up the remaining numbers, so 6 * 4 * (3 - 2) = 24.
</think>
<answer>6 * 4 * (3 - 2)</answer>"""

print("1. Well-formatted correct answer:")
print(f"   Format reward: {format_reward_func(good_completion)}")
print(f"   Equation reward: {equation_reward_func(good_completion, [2, 3, 4, 6], 24)}")

# Bad format example
bad_format = "The answer is 6 * 4 * (3 - 2) = 24"
print("\n2. Correct answer, wrong format:")
print(f"   Format reward: {format_reward_func(bad_format)}")
print(f"   Equation reward: {equation_reward_func(bad_format, [2, 3, 4, 6], 24)}")

# Wrong answer example
wrong_answer = """<think>
Let me calculate...
</think>
<answer>2 + 3 + 4</answer>"""

print("\n3. Well-formatted wrong answer:")
print(f"   Format reward: {format_reward_func(wrong_answer)}")
print(f"   Equation reward: {equation_reward_func(wrong_answer, [2, 3, 4, 6], 24)}")

# Complex correct example
complex_correct = """<think>
I have [10, 25, 5, 2] and need to make 100.
Let me think: 25 * 5 = 125, too big.
Actually, (25 - 5) * 10 / 2 = 20 * 10 / 2 = 200 / 2 = 100!
</think>
<answer>(25 - 5) * 10 / 2</answer>"""

print("\n4. Complex reasoning with correct answer:")
print(f"   Format reward: {format_reward_func(complex_correct)}")
print(f"   Equation reward: {equation_reward_func(complex_correct, [10, 25, 5, 2], 100)}")

# Demonstrate combined reward
def compute_total_reward(completion: str, nums: list[int], target: int) -> float:
    """Combined reward function as used in R1-style training."""
    format_r = format_reward_func(completion)
    equation_r = equation_reward_func(completion, nums, target)
    return format_r + equation_r

print(f"\n   Total reward: {compute_total_reward(complex_correct, [10, 25, 5, 2], 100)}")

### GRPO Episode Creation

In GRPO, episode creation differs from standard RL: we generate several responses per prompt, normalize their rewards within the group, and assign each response's advantage uniformly to all of its tokens:

In [None]:
def create_grpo_training_episodes(
    samples: list[dict],
    all_generations: list[list[int]], 
    tokenizer,
    generations_per_sample: int = 4
) -> tuple[dict, dict]:
    """
    Create training episodes for GRPO from generated responses.
    
    Key aspects:
    1. Groups responses by prompt (generations_per_sample per prompt)
    2. Computes rewards and normalizes them within each group
    3. Assigns the same advantage to ALL tokens in a response
    
    This is the key difference from PPO: uniform advantages across tokens!
    """
    # Group indices for responses
    groups = [
        list(range(i, i + generations_per_sample))
        for i in range(0, len(all_generations), generations_per_sample)
    ]
    
    all_query_ids = []
    all_response_ids = []
    all_advantages = []
    
    stats = {
        'rewards': [],
        'format_rewards': [],
        'equation_rewards': [],
        'response_lengths': []
    }
    
    for sample, group_indices in zip(samples, groups):
        # Get responses for this group
        response_ids = [all_generations[i] for i in group_indices]
        responses = tokenizer.batch_decode(response_ids, skip_special_tokens=False)
        
        # Compute rewards for each response
        rewards = []
        format_rewards = []
        equation_rewards = []
        
        for resp in responses:
            f_reward = format_reward_func(resp)
            e_reward = equation_reward_func(resp, sample['nums'], sample['target'])
            total_reward = f_reward + e_reward
            
            rewards.append(total_reward)
            format_rewards.append(f_reward)
            equation_rewards.append(e_reward)
        
        # Convert to tensors and compute GRPO advantages
        rewards_tensor = t.tensor(rewards, dtype=t.float32)
        advantages = compute_grpo_advantages(rewards_tensor)
        
        # Create episodes with uniform advantages
        for resp_ids, adv in zip(response_ids, advantages):
            # CRITICAL: Assign same advantage to ALL tokens in the response
            uniform_advantages = [adv.item()] * len(resp_ids)
            
            all_query_ids.append(sample['input_ids'])
            all_response_ids.append(resp_ids)
            all_advantages.append(uniform_advantages)
            
            stats['response_lengths'].append(len(resp_ids))
        
        # Record statistics
        stats['rewards'].extend(rewards)
        stats['format_rewards'].extend(format_rewards)
        stats['equation_rewards'].extend(equation_rewards)
    
    episodes = {
        'all_query_token_ids': all_query_ids,
        'all_response_token_ids': all_response_ids,
        'all_advantages': all_advantages
    }
    
    return episodes, stats

# Demonstrate episode creation with a mock example
print("GRPO Episode Creation Example:\n")

# Mock data
mock_sample = {
    'input_ids': [1, 2, 3, 4, 5],  # Tokenized prompt
    'nums': [2, 3, 4, 6],
    'target': 24
}

# Mock generated responses (token IDs)
mock_generations = [
    [10, 11, 12, 13],  # Response 1
    [20, 21, 22],      # Response 2  
    [30, 31, 32, 33, 34],  # Response 3
    [40, 41]           # Response 4
]

# Mock tokenizer decode (for demonstration)
class MockTokenizer:
    def batch_decode(self, ids, skip_special_tokens=False):
        # Return different quality responses
        return [
            "<think>6 * 4 = 24 and 3 - 2 = 1</think>\n<answer>6 * 4 * (3 - 2)</answer>",  # Perfect: uses every number once
            "<think>Let me see...</think>\n<answer>2 + 3</answer>",  # Wrong
            "The answer is 24",  # Bad format
            "<think>Hmm</think>\n<answer>(3 - 2) * 4 * 6</answer>"  # Good: correct, minimal reasoning
        ]

mock_tokenizer = MockTokenizer()

# Create episodes
episodes, stats = create_grpo_training_episodes(
    [mock_sample], 
    mock_generations,
    mock_tokenizer,
    generations_per_sample=4
)

print(f"Generated {len(episodes['all_response_token_ids'])} episodes\n")

for i in range(4):
    print(f"Episode {i+1}:")
    print(f"  Response length: {stats['response_lengths'][i]} tokens")
    print(f"  Rewards: format={stats['format_rewards'][i]:.1f}, "
          f"equation={stats['equation_rewards'][i]:.1f}, "
          f"total={stats['rewards'][i]:.1f}")
    print(f"  Advantage: {episodes['all_advantages'][i][0]:.3f}")
    print(f"  All tokens get same advantage: {len(episodes['all_advantages'][i])} × {episodes['all_advantages'][i][0]:.3f}")
    print()
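
Before these episodes can be fed to `compute_grpo_loss`, the ragged per-episode lists have to become padded tensors. A minimal sketch, assuming a hypothetical `collate_grpo_episodes` helper, a pad id of 0, and the convention that label `-100` marks positions excluded from the loss:

```python
import torch as t

def collate_grpo_episodes(episodes: dict, pad_id: int = 0) -> dict:
    """Hypothetical helper: pad query+response sequences into batch tensors.

    Prompt and padding positions get label -100 and advantage 0.0,
    so only response tokens contribute to the loss.
    """
    seqs, labels, advs = [], [], []
    for q, r, a in zip(episodes['all_query_token_ids'],
                       episodes['all_response_token_ids'],
                       episodes['all_advantages']):
        seqs.append(list(q) + list(r))
        labels.append([-100] * len(q) + list(r))
        advs.append([0.0] * len(q) + list(a))

    max_len = max(len(s) for s in seqs)

    def pad(rows, value):
        return [row + [value] * (max_len - len(row)) for row in rows]

    return {
        'input_ids': t.tensor(pad(seqs, pad_id)),
        'attention_mask': t.tensor(pad([[1] * len(s) for s in seqs], 0)),
        'labels': t.tensor(pad(labels, -100)),
        'advantages': t.tensor(pad(advs, 0.0)),
    }

toy_episodes = {
    'all_query_token_ids': [[1, 2], [1, 2]],
    'all_response_token_ids': [[10, 11, 12], [20]],
    'all_advantages': [[0.5] * 3, [-0.5]],
}
toy_batch = collate_grpo_episodes(toy_episodes)
print(toy_batch['labels'])
print(toy_batch['advantages'])
```

Because a causal LM's logits at position i predict token i + 1, you would pass `logits[:, :-1]`, `labels[:, 1:]`, and `advantages[:, 1:]` to the loss.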

### GRPO vs PPO: Visualizing the Difference

Let's visualize how GRPO's uniform advantage assignment differs from PPO's token-level advantages:

In [None]:
# Visualizing GRPO vs PPO advantage assignment
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# Simulated token sequence for a countdown solution
token_labels = ['<think>', 'I', 'need', 'to', 'calculate', '25', '*', '4', '=', '100', '</think>', '\\n', '<answer>', '25', '*', '4', '</answer>']
num_tokens = len(token_labels)

# GRPO: Successful response (all tokens get same positive advantage)
grpo_advantages_success = [0.8] * num_tokens
ax1.bar(range(num_tokens), grpo_advantages_success, color='green', alpha=0.7)
ax1.set_xticks(range(num_tokens))
ax1.set_xticklabels(token_labels, rotation=45, ha='right')
ax1.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
ax1.set_ylabel('Advantage')
ax1.set_title('GRPO: Successful Response (Correct Answer)\nAll tokens get the same positive advantage')
ax1.set_ylim(-1.5, 1.5)
ax1.grid(axis='y', alpha=0.3)

# GRPO: Failed response (all tokens get same negative advantage)
token_labels_fail = ['<think>', 'Maybe', "it's", '25', '+', '4', '?', '</think>', '\\n', '<answer>', '25', '+', '4', '</answer>']
num_tokens_fail = len(token_labels_fail)
grpo_advantages_fail = [-0.8] * num_tokens_fail
ax2.bar(range(num_tokens_fail), grpo_advantages_fail, color='red', alpha=0.7)
ax2.set_xticks(range(num_tokens_fail))
ax2.set_xticklabels(token_labels_fail, rotation=45, ha='right')
ax2.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
ax2.set_ylabel('Advantage')
ax2.set_xlabel('Token Position')
ax2.set_title('GRPO: Failed Response (Wrong Answer)\nAll tokens get the same negative advantage')
ax2.set_ylim(-1.5, 1.5)
ax2.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

print("\nKey GRPO Characteristics:")
print("• ✅ Correct responses: ALL tokens are reinforced equally")
print("• ❌ Wrong responses: ALL tokens are discouraged equally")
print("• 📊 The model learns which COMPLETE responses lead to success")
print("• 🎯 Credit goes to whole reasoning chains, which can favor coherent multi-step reasoning")
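
The plots above only show GRPO. For contrast, PPO-style RLHF typically derives per-token advantages from a learned value head via Generalized Advantage Estimation (GAE), so tokens in the same response usually receive different advantages. A minimal sketch, assuming a single response whose only reward arrives at the final token, made-up critic values, γ = 1, and λ = 0.95 (a real PPO setup would also add per-token KL terms to the rewards):

```python
import torch as t

def gae_advantages(rewards: t.Tensor, values: t.Tensor,
                   gamma: float = 1.0, lam: float = 0.95) -> t.Tensor:
    """GAE over one response; no bootstrap value after the last token."""
    T = len(rewards)
    advantages = t.zeros(T)
    last = 0.0
    for i in reversed(range(T)):
        next_value = values[i + 1] if i + 1 < T else 0.0
        delta = rewards[i] + gamma * next_value - values[i]  # TD error at token i
        last = delta + gamma * lam * last                     # discounted sum of TD errors
        advantages[i] = last
    return advantages

# Illustrative: reward 1.0 only at the final token, hypothetical value-head outputs
rewards = t.tensor([0.0, 0.0, 0.0, 0.0, 1.0])
values = t.tensor([0.2, 0.3, 0.6, 0.8, 0.9])
example_adv = gae_advantages(rewards, values)
print(example_adv)  # differs from token to token, unlike GRPO
```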

### GRPO Policy Gradient

Because each batch of fresh samples is used for exactly one gradient step, the sampling policy π_old is the current policy π, and the PPO objective simplifies considerably:
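
This claim can be checked numerically. The sketch below uses a toy policy (random logits over a 10-token vocabulary, random advantages, and clip ε = 0.2, all chosen only for illustration) and compares the gradient of the PPO clipped surrogate at π_old = π with the gradient of the plain REINFORCE objective:

```python
import torch as t

t.manual_seed(0)
toy_logits = t.randn(5, 10, requires_grad=True)  # toy policy: 5 positions, vocabulary of 10
actions = t.randint(0, 10, (5,))                  # tokens that were sampled
adv = t.randn(5)                                  # arbitrary per-token advantages
clip_eps = 0.2

logp = toy_logits.log_softmax(-1).gather(-1, actions[:, None]).squeeze(-1)
old_logp = logp.detach()  # fully online: the sampling policy IS the current policy

# PPO clipped surrogate (to maximize); the ratio equals 1 in value
ratio = (logp - old_logp).exp()
surrogate = t.min(ratio * adv, ratio.clamp(1 - clip_eps, 1 + clip_eps) * adv).mean()
(grad_ppo,) = t.autograd.grad(surrogate, toy_logits, retain_graph=True)

# Vanilla policy-gradient (REINFORCE) objective
reinforce = (logp * adv).mean()
(grad_pg,) = t.autograd.grad(reinforce, toy_logits)

print(t.allclose(grad_ppo, grad_pg, atol=1e-6))  # the two gradients match
```

The ratio is 1 in value, but its gradient is ∇ exp(log π − log π_old) = ∇ log π, which is exactly the REINFORCE gradient.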

In [None]:
def compute_grpo_loss(
    logits: Float[Tensor, "batch seq_len vocab_size"],
    ref_logits: Float[Tensor, "batch seq_len vocab_size"],
    advantages: Float[Tensor, "batch seq_len"],
    labels: Int[Tensor, "batch seq_len"],
    kl_coef: float = 0.001
) -> tuple[Float[Tensor, ""], dict]:
    """
    Compute GRPO loss with KL penalty.
    
    Key insight: In the fully online setting (one update per batch),
    π_old is the current policy, so the importance ratio π/π_old equals 1 in value
    and clipping never activates. Its gradient is still ∇log π, so PPO reduces
    to vanilla policy gradient (REINFORCE with advantages)!
    
    Loss = -E[log π(a|s) * A] + β * KL(π || π_ref)
    
    Assumes logits are already aligned with labels (logits[:, i] scores labels[:, i]).
    """
    # Get log probabilities for taken actions
    log_probs = logits.log_softmax(dim=-1)
    ref_log_probs = ref_logits.log_softmax(dim=-1)
    
    # Create mask for valid tokens (label -100 marks prompt/padding positions)
    mask = (labels != -100).float()
    # gather() cannot index with -100, so substitute a valid dummy index (0);
    # those positions are zeroed out by the mask below
    safe_labels = labels.masked_fill(labels == -100, 0)
    
    # Gather log probs of actual tokens
    token_log_probs = log_probs.gather(dim=-1, index=safe_labels.unsqueeze(-1)).squeeze(-1)
    ref_token_log_probs = ref_log_probs.gather(dim=-1, index=safe_labels.unsqueeze(-1)).squeeze(-1)
    
    num_tokens = mask.sum().clamp(min=1.0)
    
    # Policy gradient loss (REINFORCE with advantages)
    pg_loss = -(token_log_probs * advantages * mask).sum() / num_tokens
    
    # KL penalty using the k3 estimator (unbiased for KL(π || π_ref) under samples from π)
    # KL ≈ E[π_ref/π - log(π_ref/π) - 1]
    log_ratio = ref_token_log_probs - token_log_probs
    kl_per_token = (log_ratio.exp() - log_ratio - 1) * mask
    kl = kl_per_token.sum() / num_tokens
    kl_loss = kl_coef * kl
    
    # Total loss
    total_loss = pg_loss + kl_loss
    
    metrics = {
        'pg_loss': pg_loss.item(),
        'kl_penalty': kl.item(),  # Mean per-token KL estimate
        'kl_loss': kl_loss.item()
    }
    
    return total_loss, metrics

# Demonstrate the loss calculation
print("GRPO Loss Calculation Example:\n")

# Mock data
batch_size, seq_len, vocab_size = 2, 10, 100
logits = t.randn(batch_size, seq_len, vocab_size)
ref_logits = t.randn(batch_size, seq_len, vocab_size)

# Advantages: one response has positive, one has negative
advantages = t.tensor([
    [0.5] * seq_len,  # Good response
    [-0.5] * seq_len  # Bad response
])

# Labels (token IDs, with -100 for padding)
labels = t.randint(0, vocab_size, (batch_size, seq_len))
labels[:, -2:] = -100  # Last 2 positions are padding

# Compute loss
loss, metrics = compute_grpo_loss(logits, ref_logits, advantages, labels)

print(f"Total loss: {loss:.4f}")
print(f"Policy gradient loss: {metrics['pg_loss']:.4f}")
print(f"KL penalty: {metrics['kl_penalty']:.4f}")
print(f"KL loss component: {metrics['kl_loss']:.4f}")

print("\nInterpretation:")
print("• PG loss encourages/discourages actions based on advantages")
print("• KL penalty prevents model from deviating too far from reference")
print("• The balance is controlled by kl_coef")
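
The k3 estimator used in `compute_grpo_loss` is worth a quick sanity check. When both policies share support, E_π[π_ref/π] = 1 and E_π[-log(π_ref/π)] = KL(π || π_ref), so E_π[k3] = 1 + KL - 1 = KL; and since e^x - x - 1 ≥ 0, every single sample is non-negative. A sketch on a toy three-token distribution (the probabilities are arbitrary):

```python
import torch as t

t.manual_seed(0)
pi = t.tensor([0.5, 0.3, 0.2])      # current policy over a 3-token vocabulary (toy values)
pi_ref = t.tensor([0.4, 0.4, 0.2])  # reference policy

exact_kl = (pi * (pi / pi_ref).log()).sum()  # KL(π || π_ref)

tokens = t.multinomial(pi, 200_000, replacement=True)  # samples drawn from π
log_ratio = pi_ref[tokens].log() - pi[tokens].log()    # log(π_ref / π) per sample
k3 = log_ratio.exp() - log_ratio - 1                    # k3 estimator, >= 0 for every sample

print(f"exact KL = {exact_kl:.4f}, mean k3 = {k3.mean():.4f}, min k3 = {k3.min():.4f}")
```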

### Emergent Behaviors from GRPO

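One of the most striking results reported for GRPO (in training DeepSeek-R1-Zero) is that it can lead to emergent reasoning behaviors. The example responses below are hand-written illustrations of that kind of progression, not actual outputs from a training run: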

In [None]:
# Demonstrating emergent behaviors in GRPO training
print("Emergent Behaviors in GRPO Training:\n")

# Example responses showing progression over training
print("Early Training (Simple attempts):")
early_responses = [
    "<think>25 * 4 = 100</think>\n<answer>25 * 4</answer>",
    "<think>Try multiplication</think>\n<answer>5 * 20</answer>",
    "<think>100 = 100</think>\n<answer>100</answer>",
    "<think>25 + 75</think>\n<answer>25 + 75</answer>"
]

# Simulate rewards for early training
early_problem = {'nums': [25, 4, 75, 1], 'target': 100}
print(f"\nProblem: Make {early_problem['target']} using {early_problem['nums']}")

for i, resp in enumerate(early_responses):
    f_r = format_reward_func(resp)
    e_r = equation_reward_func(resp, early_problem['nums'], early_problem['target'])
    print(f"  Response {i+1}: format={f_r}, equation={e_r}, total={f_r + e_r}")

print("\n" + "-"*60 + "\n")

print("Later Training (Emergent verification & self-correction):")
later_responses = [
    """<think>
I need to make 100 using [25, 4, 75, 1].
Let me try: 25 * 4 = 100. 
Wait, let me verify: 25 * 4 = 100 ✓
</think>
<answer>25 * 4</answer>""",
    
    """<think>
Target is 100 with [25, 4, 75, 1].
Maybe 75 + 25 = 100? Yes!
Actually, let me double-check: 75 + 25 = 100 ✓
</think>
<answer>75 + 25</answer>""",
    
    """<think>
I have [25, 4, 75, 1] to make 100.
First attempt: 25 * 4 = 100
Hmm, but I also have 75 and 1 unused.
Actually, the problem says use the given numbers, not all.
So 25 * 4 = 100 is correct!
</think>
<answer>25 * 4</answer>""",
    
    """<think>
Need 100 from [25, 4, 75, 1].
Let's see... 75 + 25 = 100
Or alternatively: 25 * 4 = 100
Both work! I'll go with the simpler one.
</think>
<answer>25 * 4</answer>"""
]

for i, resp in enumerate(later_responses):
    f_r = format_reward_func(resp)
    e_r = equation_reward_func(resp, early_problem['nums'], early_problem['target'])
    print(f"  Response {i+1}: format={f_r}, equation={e_r}, total={f_r + e_r}")
    
print("\n" + "-"*60 + "\n")

print("Emergent Behaviors Reported in R1-Zero-Style Training:")
print("1. ✓ Self-verification: Model starts checking its own answers")
print("2. 🔄 Backtracking: Model corrects itself when noticing errors")
print("3. 🤔 Multiple attempts: Model tries different approaches")
print("4. 📝 Explanation: Model explains its reasoning process")
print("5. 🎯 Meta-reasoning: Model reasons about the problem constraints")

print("\nKey Insight:")
print("In reported R1-Zero-style runs, these behaviors were NOT explicitly programmed or instructed!")
print("They emerged from GRPO training with rewards that only check")
print("the output format and the correctness of the final answer.")
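
During a real run, one rough way to watch for these behaviors is to count marker phrases in sampled responses at each iteration. This is only a heuristic (a phrase like "wait" is not proof of reasoning), and the marker lists and helper names below are hypothetical choices for illustration:

```python
import re

# Hypothetical marker phrases; this list is illustrative, not a standard metric
MARKERS = {
    'verification': r"\b(?:verify|double-check|check)\b",
    'backtracking': r"\b(?:wait|actually|hmm)\b",
    'alternatives': r"\b(?:alternatively|another way)\b",
}

def count_markers(text: str) -> dict[str, int]:
    """Count occurrences of each marker category in one response (case-insensitive)."""
    lower = text.lower()
    return {name: len(re.findall(pattern, lower)) for name, pattern in MARKERS.items()}

def marker_rates(responses: list[str]) -> dict[str, float]:
    """Average marker counts per response, e.g. to log once per training iteration."""
    totals = {name: 0 for name in MARKERS}
    for resp in responses:
        for name, count in count_markers(resp).items():
            totals[name] += count
    return {name: total / max(len(responses), 1) for name, total in totals.items()}

print(count_markers("Let me try 25 * 4. Wait, let me verify: 25 * 4 = 100. Actually, fine."))
```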

### GRPO Training Loop Structure

Here's the structure of a GRPO training loop, showing how all components fit together:

In [None]:
# Pseudocode for GRPO training loop
print("GRPO Training Loop Structure:\n")

print("""
for iteration in range(num_iterations):
    # 1. Sample batch of prompts from dataset
    prompts = sample_prompts(dataset, batch_size)
    
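    # 2. Generate multiple responses per prompt (sampling, not greedy decoding)
    all_responses = []
    for prompt in prompts:
        for _ in range(generations_per_sample):
            response_ids = model.generate(prompt['input_ids'], do_sample=True, temperature=1.0)
            all_responses.append(response_ids)  # response tokens only (strip the prompt)
    
    # 3. Create training episodes (group-normalized advantages per response)
    episodes, stats = create_grpo_training_episodes(
        prompts, 
        all_responses,
        tokenizer,
        generations_per_sample
    )
    
    # 4. Compute loss and update model
    # Key: Only ONE optimizer step per batch of fresh samples!
    optimizer.zero_grad()
    
    micro_batches = make_batches(episodes)
    for batch in micro_batches:
        # Get model and reference model outputs
        logits = model(batch.input_ids).logits
        with torch.no_grad():
            ref_logits = ref_model(batch.input_ids).logits
        
        # Compute GRPO loss (shifted: logits at position i predict token i + 1)
        loss, metrics = compute_grpo_loss(
            logits[:, :-1], 
            ref_logits[:, :-1],
            batch.advantages[:, 1:],
            batch.labels[:, 1:],
            kl_coef
        )
        
        # Gradient accumulation: average over micro-batches
        (loss / len(micro_batches)).backward()
    
    optimizer.step()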
    
    # 5. Log metrics
    log_metrics(stats)
""")

print("\nKey Differences from Standard PPO:")
print("1. Generation: G responses per prompt (not just 1)")
print("2. Advantages: Computed at response level, not token level")
print("3. Updates: One optimizer step per batch of fresh samples (fully online, in this simplified variant)")
print("4. Simplicity: No value function training or bootstrapping")

print("\n" + "="*60 + "\n")

# Show a complete mini example
print("Mini GRPO Training Example:\n")

# Setup
@dataclass
class GRPOConfig:
    model_name: str = "gpt2"
    generations_per_sample: int = 4
    kl_coef: float = 0.001
    learning_rate: float = 1e-5
    batch_size: int = 2

config = GRPOConfig()

# Create countdown prompts
def create_countdown_prompt(nums: list[int], target: int) -> str:
    return f"""Using the numbers {nums}, create an equation that equals {target}.
You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.
Show your work in <think> </think> tags. 
Return the final equation in <answer> </answer> tags.

Let me solve this step by step.
<think>"""

# Example training data
training_problems = [
    {'nums': [2, 3, 4, 6], 'target': 24},
    {'nums': [5, 10, 25, 2], 'target': 100}
]

print("Training Problems:")
for i, prob in enumerate(training_problems):
    print(f"{i+1}. Make {prob['target']} using {prob['nums']}")
    
print("\nPrompt format:")
print(create_countdown_prompt([2, 3, 4, 6], 24)[:200] + "...")

print("\n" + "-"*60)
print("\nGRPO Training Process:")
print("1. Generate 4 responses for each problem")
print("2. Compute rewards (format + correctness)")  
print("3. Normalize rewards within each group → advantages")
print("4. Update model to increase probability of good responses")
print("5. KL penalty prevents excessive deviation")

print("\nExpected Outcomes:")
print("• Model learns to format responses correctly")
print("• Model discovers successful solution patterns")
print("• Emergent behaviors like verification appear")
print("• Performance improves on held-out test problems")
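
The outline above can be made concrete with a toy that runs in seconds. In this sketch the "model" is just a categorical distribution over four complete responses, one of which is correct; the reward function, group size of 8, KL coefficient, and Adam learning rate of 0.1 are all illustrative assumptions, not the notebook's settings. Each iteration samples a group, computes group-relative advantages, and takes one policy-gradient step with a k3 KL penalty:

```python
import torch as t

t.manual_seed(0)
NUM_ANSWERS, CORRECT, G, KL_COEF = 4, 2, 8, 0.001

# Toy "policy": logits over 4 whole responses, a stand-in for a language model
policy_logits = t.zeros(NUM_ANSWERS, requires_grad=True)
ref_logits = t.zeros(NUM_ANSWERS)  # frozen reference policy
optimizer = t.optim.Adam([policy_logits], lr=0.1)

def reward_fn(answer: int) -> float:
    return 1.0 if answer == CORRECT else 0.0  # verifiable, response-level reward

initial_p = policy_logits.softmax(-1)[CORRECT].item()
for step in range(100):
    # 1. Sample a group of G responses from the current policy
    with t.no_grad():
        group = t.multinomial(policy_logits.softmax(-1), G, replacement=True)
    # 2. Group-relative advantages (all-equal rewards give all-zero advantages)
    rewards = t.tensor([reward_fn(a.item()) for a in group])
    adv = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
    # 3. One policy-gradient step (fully online) with a k3 KL penalty
    logp = policy_logits.log_softmax(-1)[group]
    ref_logp = ref_logits.log_softmax(-1)[group]
    log_ratio = ref_logp - logp
    kl = (log_ratio.exp() - log_ratio - 1).mean()
    loss = -(logp * adv).mean() + KL_COEF * kl
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

final_p = policy_logits.softmax(-1)[CORRECT].item()
print(f"P(correct response): {initial_p:.3f} -> {final_p:.3f}")
```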

## 4.6 Limitations of Task-Specific GRPO: Why Countdown Success Doesn't Generalize

While the GRPO implementation we've explored shows impressive results on countdown tasks and mathematical reasoning, it's crucial to understand why this success doesn't automatically translate to general reasoning capabilities. This section examines the fundamental limitations when optimizing GRPO for specific tasks.

### Task-Specific Reward Functions Create Narrow Optimization

The countdown task uses two specific reward functions that create a very narrow optimization target:

1. **Format Reward**: Enforces strict `<think>...</think><answer>...</answer>` structure
2. **Equation Reward**: Validates mathematical correctness and number usage

```python
# Highly specific reward structure
# (check_format / check_equation stand in for format_reward_func / equation_reward_func)
def compute_countdown_reward(completion, nums, target):
    format_reward = check_format(completion)  # 0 or 1 (some variants give 0.5 partial credit)
    equation_reward = check_equation(completion, nums, target)  # Binary: 0 or 1
    return format_reward + equation_reward  # Total: between 0 and 2
```
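
The `check_format` and `check_equation` calls above are placeholders. One runnable version is sketched below under stated assumptions: an exact `<think>...</think><answer>...</answer>` layout, integers combined only with `+ - * /`, and every number used exactly once. It evaluates the answer with Python's `ast` module rather than `eval`, so arbitrary code in a completion cannot run. The notebook's own reward functions may differ in details:

```python
import ast
import operator
import re

# Hypothetical stand-ins for check_format / check_equation
ANSWER_RE = re.compile(r"^<think>.*?</think>\s*<answer>(.*?)</answer>$", re.DOTALL)
_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.Div: operator.truediv}

def check_format(completion: str) -> float:
    """1.0 if the completion is <think>...</think> followed by <answer>...</answer>."""
    return 1.0 if ANSWER_RE.match(completion.strip()) else 0.0

def _safe_eval(node: ast.AST) -> float:
    """Evaluate an expression tree containing only integers and + - * /."""
    if isinstance(node, ast.Expression):
        return _safe_eval(node.body)
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_safe_eval(node.left), _safe_eval(node.right))
    if isinstance(node, ast.Constant) and type(node.value) is int:
        return node.value
    raise ValueError("disallowed expression")

def check_equation(completion: str, nums: list[int], target: int) -> float:
    """1.0 if the answer uses every number exactly once (assumed rule) and hits the target."""
    match = ANSWER_RE.match(completion.strip())
    if match is None:
        return 0.0
    expr = match.group(1).strip()
    used = [int(n) for n in re.findall(r"\d+", expr)]
    if sorted(used) != sorted(nums):
        return 0.0
    try:
        value = _safe_eval(ast.parse(expr, mode="eval"))
    except (SyntaxError, ValueError, ZeroDivisionError):
        return 0.0
    return 1.0 if abs(value - target) < 1e-6 else 0.0

def compute_countdown_reward(completion: str, nums: list[int], target: int) -> float:
    return check_format(completion) + check_equation(completion, nums, target)

print(compute_countdown_reward("<think>try multiply</think>\n<answer>25 * 4 - 3</answer>", [25, 4, 3], 97))
```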

This creates several issues for generalization:

### Problem 1: Overfitting to Format Rather Than Reasoning

GRPO with uniform advantages reinforces whatever raises the reward, regardless of how the reward was earned. In countdown tasks, the model learns that:

- Always outputting the exact format gets partial reward (0.5-1.0)
- Correct equations in the right format get maximum reward (2.0)
- Any deviation from format gets zero reward

This leads to "format hacking" where the model prioritizes structure over reasoning quality:

In [None]:
# Example: Model learns to game the format reward
# (assumed problem for illustration: make 97 from [25, 4, 3])
example_responses = [
    # Response 1: Good reasoning, wrong format → 0 reward
    "Let me think about this. I'll try 25 × 4 = 100, then 100 - 3 = 97. Actually, let me try...",
    
    # Response 2: No reasoning, correct format → format reward only (up to 1)
    "<think>numbers</think>\n<answer>random equation</answer>",
    
    # Response 3: Minimal reasoning, correct answer → 2 reward
    "<think>try multiply</think>\n<answer>25 * 4 - 3</answer>"
]

# Within a group, GRPO pushes Responses 2 and 3 up relative to Response 1,
# even though Response 1 shows the most actual reasoning!

### Problem 2: Binary Rewards Discourage Exploration

The countdown task uses binary rewards (0 or 1 for each component), which creates a sparse reward landscape:

In [None]:
# Sparse reward problem in countdown tasks
def visualize_reward_sparsity():
    """
    In countdown, there's no gradient between "almost correct" and "totally wrong"
    """
    attempts = [
        ("25 × 4 - 3", 97, 97),      # Correct → reward = 2
        ("25 × 4 - 2", 98, 97),      # Off by 1 → reward = 1 (format only)
        ("25 × 3 + 6", 81, 97),      # Off by 16 → reward = 1 (format only)
        ("Invalid syntax", None, 97), # Syntax error → reward = 0
    ]
    
    print("Equation        | Result | Target | Reward | Learning Signal")
    print("-" * 60)
    for eq, result, target in attempts:
        if result is None:
            reward = 0
        elif result == target:
            reward = 2
        else:
            reward = 1  # Format is correct
        
        signal = "STRONG" if reward == 2 else ("WEAK" if reward == 1 else "NONE")
        print(f"{eq:15} | {str(result):6} | {target:6} | {reward:6} | {signal}")

visualize_reward_sparsity()
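
Group normalization makes the sparsity problem sharper: if every response in a group earns the same reward, all advantages are zero and that group contributes no policy-gradient signal at all. The sketch below contrasts the binary scheme with a hypothetical shaped reward that gives partial credit by distance to the target; the attempts and the shaping formula are illustrative assumptions:

```python
import torch as t

def group_advantages(rewards: list[float], eps: float = 1e-4) -> t.Tensor:
    """Group-relative advantages: (r - mean) / (std + eps)."""
    r = t.tensor(rewards, dtype=t.float32)
    return (r - r.mean()) / (r.std() + eps)

target = 97
results = [98, 81, 120, 90]  # four well-formatted but wrong attempts (illustrative)

# Binary scheme: format reward 1.0 + equation reward 0.0 for every attempt
binary_rewards = [1.0 for _ in results]

# Hypothetical shaped scheme: partial credit that shrinks with distance from the target
shaped_rewards = [1.0 + max(0.0, 1.0 - abs(res - target) / target) for res in results]

binary_adv = group_advantages(binary_rewards)
shaped_adv = group_advantages(shaped_rewards)
print("binary advantages:", binary_adv)  # all zero: no learning signal from this group
print("shaped advantages:", shaped_adv)  # off-by-1 attempt ranked highest
```

Shaped rewards like this one can be easier to game (a policy may settle for "close enough"), so they are a trade-off rather than a free improvement.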

### Problem 3: Uniform Advantages Inhibit General Reasoning

GRPO's uniform advantage assignment means all tokens in a response get the same credit. This is particularly problematic for general reasoning:

In [None]:
# Problem with uniform advantages in general reasoning
general_reasoning_example = """
Question: Why do leaves change color in autumn?

Model Response:
<think>
Leaves change color because of temperature. [INCORRECT]
Actually, it's due to chlorophyll breakdown. [CORRECT]
This reveals other pigments like carotenoids. [CORRECT]
The process is triggered by shorter days. [CORRECT]
Trees also stop producing chlorophyll. [CORRECT]
</think>
<answer>
Leaves change color due to chlorophyll breakdown revealing other pigments.
</answer>
"""

# In GRPO with uniform advantages:
# - ALL tokens get positive advantage if answer is correct
# - INCLUDING the incorrect statement about temperature!
# - Model doesn't learn which parts of reasoning were actually helpful

print("With uniform advantages:")
print("- Correct final answer → All tokens rewarded equally")
print("- Incorrect reasoning steps get same positive advantage")
print("- Model learns: 'Any thinking + correct answer = good'")
print("- No pressure to improve reasoning quality")
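
To make the credit-assignment problem concrete, the sketch below contrasts GRPO's uniform assignment with a toy step-level scheme for the leaf-color response. The step labels, the 0.8 response advantage, whitespace "tokens", and the rule "advantage × step label" are all hypothetical stand-ins for what a process reward model or human annotation would provide:

```python
# Hypothetical step labels: +1 = correct step, -1 = incorrect step
steps = [
    ("Leaves change color because of temperature.", -1),
    ("Actually, it's due to chlorophyll breakdown.", +1),
    ("This reveals other pigments like carotenoids.", +1),
    ("The process is triggered by shorter days.", +1),
]
response_advantage = 0.8  # assume the final answer beat the group average

uniform_adv, step_adv = [], []
for text, label in steps:
    n_tokens = len(text.split())  # crude whitespace "tokens", for illustration only
    uniform_adv += [response_advantage] * n_tokens       # GRPO: one value for every token
    step_adv += [response_advantage * label] * n_tokens  # toy step-level credit

first_step_len = len(steps[0][0].split())
print("uniform, first step:   ", uniform_adv[:first_step_len])
print("step-level, first step:", step_adv[:first_step_len])
```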

### Problem 4: Mathematical vs General Reasoning Requirements

Mathematical reasoning (like countdown) has fundamentally different properties than general reasoning:

In [None]:
# Comparison of task properties
task_comparison = {
    "Property": ["Countdown/Math", "General Reasoning"],
    "Answer Space": ["Fixed target value", "Multiple valid answers"],
    "Verification": ["Algorithmic (eval)", "Requires human judgment"],
    "Reasoning Path": ["Any path to correct result", "Quality of path matters"],
    "Partial Credit": ["Binary (right/wrong)", "Degrees of correctness"],
    "Format Importance": ["Just for parsing", "Part of communication"],
}

# Display comparison
for prop, values in task_comparison.items():
    if prop == "Property":
        print(f"{'Property':<20} | {'Countdown/Math':<25} | {'General Reasoning':<25}")
        print("-" * 75)
    else:
        print(f"{prop:<20} | {values[0]:<25} | {values[1]:<25}")

### The Generalization Gap: Why Countdown Success Doesn't Transfer

When we train GRPO specifically for countdown tasks, the model learns highly specialized behaviors that don't generalize:

In [None]:
# What the model actually learns from countdown-specific GRPO
learned_behaviors = {
    "Positive Behaviors (Countdown-Specific)": [
        "✓ Always use <think></think> tags",
        "✓ Try multiple arithmetic combinations",
        "✓ Verify equations evaluate correctly",
        "✓ Use each number exactly once",
        "✓ Backtrack when result doesn't match",
    ],
    
    "Missing Behaviors (Needed for General Reasoning)": [
        "✗ Explain WHY a solution works",
        "✗ Consider multiple valid perspectives",
        "✗ Build on previous knowledge",
        "✗ Admit uncertainty appropriately",
        "✗ Provide context and nuance",
        "✗ Adapt explanation to audience",
    ],
    
    "Harmful Behaviors (Over-optimization)": [
        "⚠ Rigid adherence to format over clarity",
        "⚠ Preference for arithmetic over logic",
        "⚠ Binary thinking (right/wrong only)",
        "⚠ Ignoring reasoning quality if answer is correct",
    ]
}

for category, behaviors in learned_behaviors.items():
    print(f"\n{category}:")
    for behavior in behaviors:
        print(f"  {behavior}")

### Key Takeaways: Requirements for General Reasoning with GRPO

To make GRPO work for general reasoning (not just mathematical tasks), we would need:

In [None]:
# Requirements for general reasoning with GRPO
requirements = {
    "1. Richer Reward Functions": [
        "- Continuous rewards (not just binary)",
        "- Multiple evaluation criteria",
        "- Partial credit for good reasoning",
        "- Human feedback or strong reward models"
    ],
    
    "2. Better Credit Assignment": [
        "- Token-level advantages (like PPO)",
        "- Process rewards during reasoning",
        "- Identifying which steps helped/hurt",
        "- Not one uniform advantage for every token in a response"
    ],
    
    "3. Diverse Training Tasks": [
        "- Not just mathematical problems",
        "- Open-ended questions",
        "- Creative and analytical tasks",
        "- Real-world reasoning scenarios"
    ],
    
    "4. Evaluation Beyond Correctness": [
        "- Reasoning quality metrics",
        "- Explanation clarity",
        "- Appropriate uncertainty",
        "- Adaptability to context"
    ]
}

for requirement, details in requirements.items():
    print(f"\n{requirement}:")
    for detail in details:
        print(f"  {detail}")

print("\n" + "="*80)
print("\nConclusion:")
print("While GRPO excels at specific tasks like countdown through emergent behaviors")
print("(self-verification, backtracking), these behaviors are task-specific optimizations.")
print("General reasoning requires fundamentally different reward structures and training")
print("approaches that value reasoning quality, not just final correctness.")

### The R1-Zero Insight: Task-Specific Excellence vs General Intelligence

DeepSeek-R1-Zero demonstrates an important principle: **task-specific optimization can produce impressive emergent behaviors** (like self-verification and backtracking in mathematical reasoning), but these behaviors are deeply tied to the reward structure and task characteristics.

The success of GRPO on countdown tasks shows:
- ✅ GRPO can train models without human feedback
- ✅ Emergent behaviors arise from simple reward signals
- ✅ Efficient training is possible with proper implementation

But it also reveals fundamental limitations:
- ❌ Task-specific rewards create narrow capabilities
- ❌ Uniform advantages don't teach reasoning quality
- ❌ Binary rewards discourage exploration
- ❌ Mathematical verification doesn't transfer to general reasoning

**The key lesson**: While GRPO represents an important step toward self-improving AI systems, achieving general reasoning capabilities requires moving beyond task-specific optimization to more sophisticated reward modeling and credit assignment mechanisms.

### Summary: GRPO vs PPO

**GRPO (Group Relative Policy Optimization)**:
- ✅ Simpler to implement and understand
- ✅ No value function needed
- ✅ Effective for response-level rewards
- ✅ Encourages emergent reasoning behaviors
- ❌ Less fine-grained credit assignment
- ❌ Requires multiple generations per prompt

**PPO (Proximal Policy Optimization)**:
- ✅ Token-level credit assignment
- ✅ More flexible and general
- ✅ Can handle dense rewards
- ✅ Established track record
- ❌ More complex implementation
- ❌ Requires value function training

### Key Takeaways on GRPO

1. **Simplicity is Powerful**: GRPO shows that simpler approaches can yield impressive results
2. **Emergent Behaviors**: With the right reward structure, models develop sophisticated reasoning patterns
3. **Group Normalization**: Comparing responses within groups provides strong learning signals
4. **R1-Zero Success**: GRPO enabled DeepSeek-R1-Zero to develop reasoning abilities from rule-based rewards, without human feedback or supervised fine-tuning
5. **Practical Choice**: Use GRPO when you have clear response-level rewards and want simpler implementation

### The Countdown Task as a Learning Tool

The countdown task perfectly illustrates GRPO's strengths:
- **Clear success criteria**: Solutions either work or don't
- **Multiple valid approaches**: Encourages exploration
- **Reasoning required**: Not just pattern matching
- **Verifiable**: Answers can be checked automatically by a program, and models can check their own arithmetic

This makes it an ideal testbed for understanding how GRPO enables models to develop complex reasoning behaviors from simple reward signals.