# 05 â€” GRPO Implementation: Group Relative Policy Optimization

> **Purpose:** Build a complete GRPO trainer from scratch. GRPO is DeepSeek's critic-free algorithm that powers R1 and DeepSeekMath.

**Key innovation:** Eliminate the value network by computing advantages relative to a group of sampled responses.

$$J_{GRPO}(\theta) = \mathbb{E}\left[\min(r_t A_t, \text{clip}(r_t, 1-\epsilon, 1+\epsilon) A_t)\right] - \beta D_{KL}(\pi_\theta \| \pi_{ref})$$

---

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from typing import List, Dict, Tuple, Optional
import copy
import numpy as np

torch.manual_seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

Device: cpu


## 1. TinyLM + Reference Model Wrapper

GRPO requires a frozen reference model to compute KL divergence.

In [2]:
class TinyLM(nn.Module):
    """Minimal language model for RL experiments."""
    def __init__(self, vocab_size=100, hidden_size=64, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)
        self.head = nn.Linear(hidden_size, vocab_size)
        self.vocab_size = vocab_size
    
    def forward(self, x):
        h = self.embed(x)
        h, _ = self.rnn(h)
        return self.head(h)
    
    def get_log_probs(self, input_ids: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Compute log probabilities for each token in labels.
        
        Args:
            input_ids: (batch, seq_len) input token IDs
            labels: (batch, seq_len) target token IDs
        
        Returns:
            log_probs: (batch, seq_len) log probability of each label token
        """
        logits = self(input_ids)  # (batch, seq_len, vocab)
        log_probs = F.log_softmax(logits, dim=-1)
        
        # Gather log probs for the actual tokens
        token_log_probs = torch.gather(
            log_probs, dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)
        
        return token_log_probs


class ReferenceModel:
    """
    Wrapper for frozen reference model.
    Used to compute KL divergence penalty.
    """
    
    def __init__(self, model: nn.Module):
        self.model = copy.deepcopy(model)
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False
    
    def get_log_probs(self, input_ids: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Get log probs from frozen reference model."""
        with torch.no_grad():
            return self.model.get_log_probs(input_ids, labels)
    
    def to(self, device):
        self.model = self.model.to(device)
        return self

In [3]:
# TEST: Model and reference model

policy = TinyLM().to(device)
ref_model = ReferenceModel(policy)

# Sample batch
batch_size, seq_len = 4, 20
input_ids = torch.randint(0, 100, (batch_size, seq_len), device=device)
labels = torch.randint(0, 100, (batch_size, seq_len), device=device)

policy_log_probs = policy.get_log_probs(input_ids, labels)
ref_log_probs = ref_model.get_log_probs(input_ids, labels)

print(f"Policy log probs shape: {policy_log_probs.shape}")
print(f"Reference log probs shape: {ref_log_probs.shape}")
print(f"Initially equal (same model): {torch.allclose(policy_log_probs, ref_log_probs)}")

Policy log probs shape: torch.Size([4, 20])
Reference log probs shape: torch.Size([4, 20])
Initially equal (same model): True


## 2. Group-Relative Advantage Computation

The core innovation: compute advantages by comparing responses within a group.

$$A_i = \frac{r_i - \bar{r}}{\sigma_r + \delta}$$

In [4]:
def compute_grpo_advantages(rewards: torch.Tensor, 
                            group_size: int,
                            normalize: str = 'local') -> torch.Tensor:
    """
    Compute group-relative advantages for GRPO.
    
    Args:
        rewards: (batch_size,) reward for each response
        group_size: Number of responses per prompt (K)
        normalize: 'local' (per-group) or 'global' (batch-level std)
    
    Returns:
        advantages: (batch_size,) normalized advantages
    """
    batch_size = rewards.shape[0]
    num_groups = batch_size // group_size
    eps = 1e-8
    
    # Reshape to (num_groups, group_size)
    grouped = rewards.view(num_groups, group_size)
    
    # Compute group mean (baseline)
    group_mean = grouped.mean(dim=1, keepdim=True)
    
    if normalize == 'local':
        # GRPO original: per-group std
        group_std = grouped.std(dim=1, keepdim=True)
        advantages = (grouped - group_mean) / (group_std + eps)
    elif normalize == 'global':
        # Hybrid: use batch-level std for stability
        batch_std = rewards.std()
        advantages = (grouped - group_mean) / (batch_std + eps)
    else:
        raise ValueError(f"Unknown normalize: {normalize}")
    
    return advantages.view(batch_size)

In [5]:
# TEST: Group-relative advantages

# 4 prompts, 4 responses each = 16 total
rewards = torch.tensor([
    1.0, 0.8, 0.6, 0.4,   # Prompt 1: varied
    1.0, 1.0, 1.0, 1.0,   # Prompt 2: all correct (easy)
    0.0, 0.0, 0.0, 0.0,   # Prompt 3: all wrong (hard)
    1.0, 0.0, 0.5, 0.5,   # Prompt 4: mixed
])

adv_local = compute_grpo_advantages(rewards, group_size=4, normalize='local')
adv_global = compute_grpo_advantages(rewards, group_size=4, normalize='global')

print("Rewards per prompt:")
for i in range(4):
    print(f"  Prompt {i+1}: {rewards[i*4:(i+1)*4].numpy()}")

print(f"\nLocal normalization (max |A|): {adv_local.abs().max():.2f}")
print(f"Global normalization (max |A|): {adv_global.abs().max():.2f}")
print("\nLocal explodes on easy/hard prompts (std=0), global stays stable")

Rewards per prompt:
  Prompt 1: [1.  0.8 0.6 0.4]
  Prompt 2: [1. 1. 1. 1.]
  Prompt 3: [0. 0. 0. 0.]
  Prompt 4: [1.  0.  0.5 0.5]

Local normalization (max |A|): 1.22
Global normalization (max |A|): 1.15

Local explodes on easy/hard prompts (std=0), global stays stable


## 3. KL Divergence Computation

In [6]:
def compute_kl_divergence(policy_log_probs: torch.Tensor,
                          ref_log_probs: torch.Tensor,
                          mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Compute KL divergence between policy and reference.
    
    KL(policy || ref) = E_policy[log(policy) - log(ref)]
    
    Args:
        policy_log_probs: (batch, seq_len) log probs from policy
        ref_log_probs: (batch, seq_len) log probs from reference
        mask: Optional (batch, seq_len) mask for valid tokens
    
    Returns:
        kl: (batch,) per-sequence KL divergence
    """
    kl_per_token = policy_log_probs - ref_log_probs  # (batch, seq_len)
    
    if mask is not None:
        kl_per_token = kl_per_token * mask
        kl_per_sequence = kl_per_token.sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    else:
        kl_per_sequence = kl_per_token.mean(dim=1)
    
    return kl_per_sequence

## 4. GRPO Loss Function

Clipped objective with KL penalty and token-level or sequence-level aggregation.

In [7]:
def compute_grpo_loss(policy_log_probs: torch.Tensor,
                      old_log_probs: torch.Tensor,
                      ref_log_probs: torch.Tensor,
                      advantages: torch.Tensor,
                      response_lengths: torch.Tensor,
                      epsilon: float = 0.2,
                      beta: float = 0.1,
                      loss_aggregation: str = 'token') -> Dict[str, torch.Tensor]:
    """
    Compute complete GRPO loss with clipping and KL penalty.
    
    Args:
        policy_log_probs: (batch, seq_len) current policy log probs
        old_log_probs: (batch, seq_len) log probs when responses were generated
        ref_log_probs: (batch, seq_len) reference model log probs
        advantages: (batch,) per-response advantages
        response_lengths: (batch,) length of each response
        epsilon: Clipping parameter
        beta: KL penalty coefficient
        loss_aggregation: 'token' or 'sequence'
    
    Returns:
        Dict with 'loss', 'policy_loss', 'kl_loss', 'clip_fraction'
    """
    batch_size, seq_len = policy_log_probs.shape
    
    # Create mask for valid tokens
    mask = torch.arange(seq_len, device=policy_log_probs.device).unsqueeze(0) < response_lengths.unsqueeze(1)
    mask = mask.float()
    
    # Expand advantages to token level
    advantages_expanded = advantages.unsqueeze(1)  # (batch, 1)
    
    # Compute probability ratio
    log_ratio = policy_log_probs - old_log_probs
    ratio = torch.exp(log_ratio)  # (batch, seq_len)
    
    # Clipped ratio
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    
    # PPO-style objective: min(ratio * A, clipped_ratio * A)
    obj1 = ratio * advantages_expanded
    obj2 = clipped_ratio * advantages_expanded
    ppo_obj = torch.min(obj1, obj2)
    
    # Track clip fraction (for monitoring)
    clipped = (ratio < 1 - epsilon) | (ratio > 1 + epsilon)
    clip_fraction = (clipped * mask).sum() / mask.sum()
    
    # Aggregate policy loss
    if loss_aggregation == 'token':
        # Each token contributes equally
        policy_loss = -(ppo_obj * mask).sum() / mask.sum()
    elif loss_aggregation == 'sequence':
        # Each sequence contributes equally
        per_seq_loss = -(ppo_obj * mask).sum(dim=1) / response_lengths.float()
        policy_loss = per_seq_loss.mean()
    else:
        raise ValueError(f"Unknown loss_aggregation: {loss_aggregation}")
    
    # KL divergence penalty
    kl = compute_kl_divergence(policy_log_probs, ref_log_probs, mask)
    kl_loss = beta * kl.mean()
    
    # Total loss
    total_loss = policy_loss + kl_loss
    
    return {
        'loss': total_loss,
        'policy_loss': policy_loss,
        'kl_loss': kl_loss,
        'kl': kl.mean(),
        'clip_fraction': clip_fraction,
    }

## 5. Experience Buffer for Multiple PPO Epochs

Store generated experiences to train on them multiple times.

In [8]:
class ExperienceBuffer:
    """
    Store rollout data for multiple PPO epochs.
    
    Each entry contains:
    - input_ids: Prompt + response tokens
    - labels: Response tokens (for log prob computation)
    - old_log_probs: Log probs when response was generated
    - rewards: Reward for each response
    - response_lengths: Length of each response
    """
    
    def __init__(self):
        self.data = []
    
    def add(self, input_ids: torch.Tensor, labels: torch.Tensor,
            old_log_probs: torch.Tensor, rewards: torch.Tensor,
            response_lengths: torch.Tensor):
        """Add a batch of experiences."""
        self.data.append({
            'input_ids': input_ids.detach(),
            'labels': labels.detach(),
            'old_log_probs': old_log_probs.detach(),
            'rewards': rewards.detach(),
            'response_lengths': response_lengths.detach(),
        })
    
    def get_all(self) -> Dict[str, torch.Tensor]:
        """Concatenate all stored experiences."""
        if not self.data:
            return None
        
        return {
            key: torch.cat([d[key] for d in self.data], dim=0)
            for key in self.data[0].keys()
        }
    
    def clear(self):
        """Clear all stored experiences."""
        self.data = []
    
    def __len__(self):
        return sum(d['input_ids'].shape[0] for d in self.data)

## 6. Complete GRPO Trainer

In [9]:
class GRPOTrainer:
    """
    Full GRPO trainer implementation.
    
    Features:
    - Group sampling with configurable group size
    - Group-relative advantage computation
    - Reference model KL penalty
    - Multiple PPO epochs per rollout
    - Configurable loss aggregation
    """
    
    def __init__(self,
                 policy: nn.Module,
                 ref_model: ReferenceModel,
                 reward_fn,
                 group_size: int = 4,
                 epsilon: float = 0.2,
                 beta: float = 0.1,
                 ppo_epochs: int = 4,
                 mini_batch_size: int = 8,
                 loss_aggregation: str = 'token',
                 normalize_advantages: str = 'global',
                 lr: float = 1e-5,
                 max_grad_norm: float = 1.0):
        
        self.policy = policy
        self.ref_model = ref_model
        self.reward_fn = reward_fn
        self.group_size = group_size
        self.epsilon = epsilon
        self.beta = beta
        self.ppo_epochs = ppo_epochs
        self.mini_batch_size = mini_batch_size
        self.loss_aggregation = loss_aggregation
        self.normalize_advantages = normalize_advantages
        self.max_grad_norm = max_grad_norm
        
        self.optimizer = torch.optim.AdamW(policy.parameters(), lr=lr)
        self.experience_buffer = ExperienceBuffer()
    
    def collect_rollouts(self, prompts: torch.Tensor, max_length: int = 50) -> Dict:
        """
        Generate responses for prompts and collect experience.
        
        Args:
            prompts: (num_prompts, prompt_len) prompt token IDs
            max_length: Maximum response length
        
        Returns:
            Stats dict with generation metrics
        """
        num_prompts = prompts.shape[0]
        device = prompts.device
        
        # Repeat each prompt group_size times
        expanded_prompts = prompts.repeat_interleave(self.group_size, dim=0)
        
        # Generate responses
        self.policy.eval()
        all_tokens = []
        all_log_probs = []
        
        with torch.no_grad():
            tokens = expanded_prompts.clone()
            for _ in range(max_length):
                logits = self.policy(tokens)[:, -1, :]
                dist = Categorical(logits=logits)
                next_token = dist.sample()
                log_prob = dist.log_prob(next_token)
                
                all_log_probs.append(log_prob)
                tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)
        
        # Stack log probs
        response_log_probs = torch.stack(all_log_probs, dim=1)  # (batch, max_length)
        
        # For simplicity, use full length as response_lengths
        response_lengths = torch.full((tokens.shape[0],), max_length, device=device)
        
        # Get rewards
        rewards = self.reward_fn(tokens)
        
        # Store in buffer
        self.experience_buffer.add(
            input_ids=tokens,
            labels=tokens,  # Simplified: same as input
            old_log_probs=response_log_probs,
            rewards=rewards,
            response_lengths=response_lengths,
        )
        
        return {
            'mean_reward': rewards.mean().item(),
            'num_responses': tokens.shape[0],
        }
    
    def train_step(self) -> Dict:
        """
        Run multiple PPO epochs on collected experience.
        
        Returns:
            Training stats
        """
        experience = self.experience_buffer.get_all()
        if experience is None:
            return {}
        
        self.policy.train()
        
        # Compute advantages once (fixed for all epochs)
        advantages = compute_grpo_advantages(
            experience['rewards'],
            self.group_size,
            normalize=self.normalize_advantages
        )
        
        # Get reference log probs
        ref_log_probs = self.ref_model.get_log_probs(
            experience['input_ids'][:, :-1],
            experience['labels'][:, 1:]
        )
        
        # Truncate old_log_probs to match
        max_len = min(experience['old_log_probs'].shape[1], ref_log_probs.shape[1])
        old_log_probs = experience['old_log_probs'][:, :max_len]
        ref_log_probs = ref_log_probs[:, :max_len]
        
        stats = {'loss': 0, 'kl': 0, 'clip_fraction': 0}
        num_updates = 0
        
        # Multiple PPO epochs
        batch_size = experience['input_ids'].shape[0]
        indices = np.arange(batch_size)
        
        for epoch in range(self.ppo_epochs):
            np.random.shuffle(indices)
            
            # Mini-batch updates
            for start in range(0, batch_size, self.mini_batch_size):
                end = min(start + self.mini_batch_size, batch_size)
                mb_indices = indices[start:end]
                
                # Get mini-batch data
                mb_input_ids = experience['input_ids'][mb_indices]
                mb_labels = experience['labels'][mb_indices]
                mb_old_log_probs = old_log_probs[mb_indices]
                mb_advantages = advantages[mb_indices]
                mb_response_lengths = torch.full((len(mb_indices),), max_len, device=mb_input_ids.device)
                mb_ref_log_probs = ref_log_probs[mb_indices]
                
                # Forward pass
                policy_log_probs = self.policy.get_log_probs(
                    mb_input_ids[:, :-1], mb_labels[:, 1:]
                )[:, :max_len]
                
                # Compute loss
                loss_dict = compute_grpo_loss(
                    policy_log_probs=policy_log_probs,
                    old_log_probs=mb_old_log_probs,
                    ref_log_probs=mb_ref_log_probs,
                    advantages=mb_advantages,
                    response_lengths=mb_response_lengths,
                    epsilon=self.epsilon,
                    beta=self.beta,
                    loss_aggregation=self.loss_aggregation,
                )
                
                # Backward pass
                self.optimizer.zero_grad()
                loss_dict['loss'].backward()
                
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                
                self.optimizer.step()
                
                # Accumulate stats
                stats['loss'] += loss_dict['loss'].item()
                stats['kl'] += loss_dict['kl'].item()
                stats['clip_fraction'] += loss_dict['clip_fraction'].item()
                num_updates += 1
        
        # Average stats
        for key in stats:
            stats[key] /= max(num_updates, 1)
        
        # Clear buffer after training
        self.experience_buffer.clear()
        
        return stats

In [10]:
# TEST: GRPO Trainer on toy problem

# Simple reward: higher if response contains more 1s
def simple_reward(tokens):
    return (tokens == 1).float().sum(dim=1) / tokens.shape[1]

# Initialize
policy = TinyLM(vocab_size=10).to(device)
ref_model = ReferenceModel(policy).to(device)

trainer = GRPOTrainer(
    policy=policy,
    ref_model=ref_model,
    reward_fn=simple_reward,
    group_size=4,
    ppo_epochs=2,
    mini_batch_size=4,
    lr=1e-4,
    beta=0.1,
)

# Training loop
print("Training GRPO...")
print(f"{'Step':>5} {'Reward':>10} {'Loss':>10} {'KL':>10} {'Clip%':>10}")
print("-" * 50)

for step in range(5):
    # Generate prompts (just fixed tokens)
    prompts = torch.zeros((4, 5), dtype=torch.long, device=device)
    
    # Collect rollouts
    rollout_stats = trainer.collect_rollouts(prompts, max_length=20)
    
    # Train
    train_stats = trainer.train_step()
    
    print(f"{step:>5} {rollout_stats['mean_reward']:>10.4f} "
          f"{train_stats.get('loss', 0):>10.4f} "
          f"{train_stats.get('kl', 0):>10.4f} "
          f"{train_stats.get('clip_fraction', 0)*100:>9.1f}%")

Training GRPO...
 Step     Reward       Loss         KL      Clip%
--------------------------------------------------
    0     0.0600    -0.0033    -0.0006      16.1%
    1     0.0975    -0.0011     0.0009      14.5%
    2     0.0625     0.0010     0.0020      10.0%
    3     0.0725     0.0031     0.0024      13.0%
    4     0.0825     0.0030     0.0046      12.3%


## 7. Production Hyperparameters Reference

Based on DeepSeek R1 and production systems:

```python
config = {
    # Group sampling
    'group_size': 8,              # Responses per prompt (K)
    
    # PPO parameters
    'epsilon': 0.2,               # Clip range
    'ppo_epochs': 4,              # Epochs per rollout
    'mini_batch_size': 64,        # Per mini-batch
    
    # KL penalty
    'beta': 0.1,                  # KL coefficient
    'target_kl': 6.0,             # For adaptive KL
    
    # Optimization
    'learning_rate': 1e-5,        # Policy LR
    'max_grad_norm': 1.0,         # Gradient clipping
    
    # Aggregation (model-dependent)
    'loss_aggregation': 'token',  # For base models
    'normalize_advantages': 'global',  # Stable
    
    # Generation
    'max_response_length': 8192,
    'temperature': 1.0,
}
```

---
**Next:** `06_infrastructure_frameworks.ipynb` (OpenRLHF, Ray, vLLM)