# 08 — PRIME: Implicit Process Rewards

> **Purpose:** Implement dense token-level rewards without expensive process annotations. PRIME learns a Q-function from outcome supervision only.

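**Key insight:** Parameterize the outcome reward model as a language model and train it on outcome labels only. Its per-token log probability ratio against a frozen reference model then serves as a dense, token-level reward.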

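$$r_{implicit}(s_t, a_t) = \beta \left[\log \pi_{PRM}(a_t|s_t) - \log \pi_{ref}(a_t|s_t)\right]$$

where $\beta > 0$ is a reward scale (this notebook uses $\beta = 1$ by default).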

| Method | Annotations Required | Dense Rewards | Online Updates |
|--------|---------------------|---------------|----------------|
| Explicit PRM | Per-step labels (costly: \$500K+) | ✅ | ❌ |
| ORM only | Outcome only | ❌ Sparse | ❌ |
| **PRIME** | Outcome only | ✅ | ✅ |

---

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional
import copy

# Seed the global torch RNG so every cell below is reproducible.
torch.manual_seed(42)

# Prefer a visible CUDA device; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Device: " + device)

## 1. The Credit Assignment Problem

**Sparse rewards make learning hard.** When only the final answer is rewarded, which tokens contributed?

In [None]:
def visualize_credit_assignment():
    """
    Print a sparse-vs-dense reward table for a toy arithmetic reasoning trace.

    The sparse column rewards only the final token. The dense column shows
    the per-token credit an idealised process reward model would assign.
    Output goes to stdout; the function returns None.
    """
    # Each reasoning step as (its tokens, dense credit given to every token in it).
    trace = [
        (["Let", "x", "=", "5"], 0.1),              # Step 1: define variable
        (["Then", "2x", "=", "10"], 0.2),           # Step 2: correct multiplication
        (["So", "2x", "+", "3", "=", "13"], 0.2),   # Step 3: correct addition
    ]
    tokens = [word for words, _ in trace for word in words]
    dense_rewards = [credit for words, credit in trace for _ in words]

    # The final answer tokens carry their own, individually chosen credit.
    tokens.extend(["Answer:", "13"])
    dense_rewards.extend([0.3, 1.0])

    # Sparse signal: zero everywhere except the very last token.
    last = len(tokens) - 1
    sparse_rewards = [1.0 if pos == last else 0.0 for pos in range(len(tokens))]

    print("Credit Assignment Comparison:")
    print(f"{'Token':<10} {'Sparse':<10} {'Dense':<10}")
    print("-" * 30)
    for row in zip(tokens, sparse_rewards, dense_rewards):
        word, sparse, dense = row
        print(f"{word:<10} {sparse:<10.1f} {dense:<10.1f}")

    print(f"\nTotal sparse signal: {sum(sparse_rewards):.1f}")
    print(f"Total dense signal: {sum(dense_rewards):.1f}")
    print("Dense rewards give each step credit for correctness!")


visualize_credit_assignment()

## 2. Model Setup: Policy, Reference, and Implicit PRM

In [None]:
class TinyLM(nn.Module):
    """
    Tiny embedding -> GRU -> linear language model.

    Used as a cheap stand-in for the policy, the frozen reference and the
    implicit PRM in the PRIME experiments of this notebook.
    """

    def __init__(self, vocab_size=100, hidden_size=64):
        super().__init__()
        self.vocab_size = vocab_size
        # Submodules are registered in this order: embed, rnn, head.
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Map token ids (batch, seq_len) to logits (batch, seq_len, vocab_size)."""
        embedded = self.embed(x)
        hidden_states = self.rnn(embedded)[0]  # drop the final hidden state
        return self.head(hidden_states)

    def get_token_log_probs(self, input_ids: torch.Tensor,
                            labels: torch.Tensor) -> torch.Tensor:
        """
        Log probability the model assigns to ``labels`` at every position.

        Args:
            input_ids: (batch, seq_len) token ids fed to the model.
            labels: (batch, seq_len) target token ids to score.

        Returns:
            log_probs: (batch, seq_len) per-token log probs
        """
        token_index = labels.unsqueeze(-1)
        vocab_log_probs = self(input_ids).log_softmax(dim=-1)
        return vocab_log_probs.gather(-1, token_index).squeeze(-1)


def create_reference_model(model: nn.Module) -> nn.Module:
    """
    Return a frozen, independent copy of ``model``.

    The copy is deep (no shared parameters), switched to eval mode, and has
    ``requires_grad`` disabled on every parameter. The original model is left
    untouched.
    """
    frozen = copy.deepcopy(model).eval()
    frozen.requires_grad_(False)
    return frozen

## 3. Implicit Process Reward Computation

The key PRIME formula: each token's reward is the log probability ratio between the implicit PRM and the frozen reference model.

$$r_{implicit}(t) = \beta \left[\log \pi_{PRM}(a_t|s_t) - \log \pi_{ref}(a_t|s_t)\right] \quad (\beta = 1 \text{ by default})$$

In [None]:
def compute_implicit_process_rewards(
    prm_log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    beta: float = 1.0,
) -> torch.Tensor:
    """
    Compute token-level implicit process rewards.

    This is the core PRIME insight. The scaled log probability ratio
    between a trained model and a frozen reference is a dense per-token
    reward, obtained without explicit process labels:

        r_t = beta * (log pi_PRM(a_t|s_t) - log pi_ref(a_t|s_t))

    The running sum of these rewards up to step t plays the role of a
    Q-value, so each per-token reward is a Q-value difference.

    Args:
        prm_log_probs: (batch, seq_len) log probs from implicit PRM
        ref_log_probs: (batch, seq_len) log probs from frozen reference
        mask: Optional (batch, seq_len) mask for valid tokens (bool or
            float weights). Positions where the mask is 0 are returned as
            exactly 0, even if the log probs there are -inf/NaN.
        beta: Reward scale; 1.0 reproduces the plain log ratio.

    Returns:
        implicit_rewards: (batch, seq_len) per-token rewards
    """
    implicit_rewards = beta * (prm_log_probs - ref_log_probs)

    if mask is not None:
        # Multiply to honour (possibly fractional) weights, then force
        # masked-out positions to 0: padded log probs may be -inf, and
        # -inf * 0 would otherwise be NaN.
        weighted = implicit_rewards * mask
        implicit_rewards = torch.where(mask != 0, weighted, torch.zeros_like(weighted))

    return implicit_rewards

In [None]:
# TEST: Implicit rewards
#
# Build a policy, freeze an exact copy as the reference, then perturb the
# policy's output head so the two disagree. The perturbed policy plays the
# role of the implicit PRM here. Labels are random targets (not shifted
# inputs); this cell only exercises shapes and the reward arithmetic.

policy = TinyLM().to(device)
ref_model = create_reference_model(policy).to(device)

with torch.no_grad():
    policy.head.weight.add_(torch.randn_like(policy.head.weight).mul(0.1))

batch_size, seq_len = 2, 10
input_ids = torch.randint(0, 100, (batch_size, seq_len), device=device)
labels = torch.randint(0, 100, (batch_size, seq_len), device=device)

# Score the same targets under both models (policy first, then reference).
prm_logps, ref_logps = (
    model.get_token_log_probs(input_ids, labels) for model in (policy, ref_model)
)
implicit_rewards = compute_implicit_process_rewards(prm_logps, ref_logps)

print(
    "Implicit Process Rewards (per token):\n"
    f"  Shape: {implicit_rewards.shape}\n"
    f"  Mean: {implicit_rewards.mean().item():.4f}\n"
    f"  Std: {implicit_rewards.std().item():.4f}\n"
    f"  Sample: {implicit_rewards[0, :5].tolist()}"
)

## 4. Monte Carlo Advantage Estimation

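PRIME combines implicit process rewards with sparse outcome rewards. The implicit rewards are weighted per token, and the outcome reward is added at the last valid token. The Monte Carlo return (reward-to-go) from each token is then whitened across all valid tokens in the batch.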

In [None]:
def compute_prime_advantages(
    implicit_rewards: torch.Tensor,
    outcome_rewards: torch.Tensor,
    response_lengths: torch.Tensor,
    gamma: float = 1.0,
    lambda_implicit: float = 0.5,
) -> torch.Tensor:
    """
    Compute advantages combining implicit and outcome rewards.

    Per valid token the combined reward is
    ``lambda_implicit * implicit_rewards``. The last valid token of each
    response also receives ``(1 - lambda_implicit) * outcome_reward``. The
    Monte Carlo return (discounted reward-to-go over valid tokens only) is
    then whitened with the mean/std of all valid tokens in the batch, when
    there are at least two.

    Args:
        implicit_rewards: (batch, seq_len) per-token implicit rewards
        outcome_rewards: (batch,) sparse outcome reward per sequence
        response_lengths: (batch,) number of valid tokens per row. Values
            above seq_len are clamped to seq_len; rows with length <= 0 get
            all-zero advantages.
        gamma: Discount factor (1.0 for no discounting)
        lambda_implicit: Weight for implicit vs outcome rewards

    Returns:
        advantages: (batch, seq_len) per-token advantages, 0 at padding
    """
    batch_size, seq_len = implicit_rewards.shape
    device = implicit_rewards.device

    # Clamp once and use the same lengths everywhere, so an over-long length
    # can never index past the end of the sequence.
    lengths = response_lengths.to(device).long().clamp(min=0, max=seq_len)
    valid = torch.arange(seq_len, device=device).unsqueeze(0) < lengths.unsqueeze(1)

    # Weighted implicit rewards, with padding set to exactly 0 so no padded
    # value (even inf/NaN) can reach the returns.
    combined = torch.where(
        valid, implicit_rewards * lambda_implicit, torch.zeros_like(implicit_rewards)
    )

    # Sparse outcome reward lands on the last valid token of each non-empty row.
    rows = torch.nonzero(lengths > 0, as_tuple=True)[0]
    last = lengths[rows] - 1
    outcome_term = outcome_rewards.to(combined)[rows] * (1 - lambda_implicit)
    combined[rows, last] = combined[rows, last] + outcome_term

    # Monte Carlo reward-to-go, scanned right-to-left for the whole batch at
    # once. Padding is zero, so the running sum stays 0 until each row's
    # valid region begins.
    advantages = torch.zeros_like(combined)
    running = torch.zeros(batch_size, dtype=combined.dtype, device=device)
    for t in range(seq_len - 1, -1, -1):
        running = combined[:, t] + gamma * running
        advantages[:, t] = running

    # Whiten over valid tokens only (needs >= 2 values for a std).
    valid_advantages = advantages[valid]
    if valid_advantages.numel() > 1:
        mean = valid_advantages.mean()
        std = valid_advantages.std() + 1e-8
        advantages = (advantages - mean) / std

    return advantages * valid.to(advantages.dtype)

In [None]:
# TEST: PRIME advantages
#
# Four toy responses: small random implicit rewards, alternating
# success/failure outcomes, and different valid lengths (max 20).

implicit_rewards = 0.1 * torch.randn(4, 20, device=device)
outcome_rewards = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device)
response_lengths = torch.tensor([15, 18, 12, 20], device=device)

advantages = compute_prime_advantages(
    implicit_rewards,
    outcome_rewards,
    response_lengths,
    lambda_implicit=0.5,
)

print("PRIME Advantages:")
print(f"  Shape: {advantages.shape}")
for i in range(advantages.shape[0]):
    print(f"  Seq {i} (outcome={outcome_rewards[i]:.0f}): first 5 = {advantages[i, :5].tolist()}")

## 5. PRIME Loss Function

PPO-style policy loss with implicit process rewards.

In [None]:
def compute_prime_policy_loss(
    policy_log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    mask: torch.Tensor,
    epsilon: float = 0.2,
) -> Dict[str, torch.Tensor]:
    """
    Compute PRIME policy loss (PPO-style clipped objective with dense advantages).

    Args:
        policy_log_probs: (batch, seq_len) current policy log probs
        old_log_probs: (batch, seq_len) old policy log probs
        advantages: (batch, seq_len) PRIME advantages
        mask: (batch, seq_len) valid token mask (bool or float weights)
        epsilon: PPO clip parameter

    Returns:
        Dict with
        - 'loss': masked mean of the negated clipped objective (differentiable);
        - 'clip_fraction': fraction of valid tokens whose ratio left
          [1 - epsilon, 1 + epsilon] (no grad);
        - 'mean_ratio': masked mean importance ratio (no grad).
        With an all-zero mask all three are 0 rather than NaN.
    """
    mask = mask.to(policy_log_probs.dtype)
    # Guard the normaliser so an all-padding batch yields 0, not 0/0 = NaN.
    # The clamp never changes a positive token count.
    denom = mask.sum().clamp(min=torch.finfo(mask.dtype).tiny)

    # Importance sampling ratio
    ratio = torch.exp(policy_log_probs - old_log_probs)

    # Clipped ratio
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)

    # PPO objective: pessimistic (minimum) of unclipped and clipped terms.
    unclipped_obj = ratio * advantages
    clipped_obj = clipped_ratio * advantages
    policy_loss = -torch.min(unclipped_obj, clipped_obj)

    # Apply mask and average over valid tokens
    masked_loss = (policy_loss * mask).sum() / denom

    # Tracking metrics only: keep them out of the autograd graph.
    with torch.no_grad():
        clipped = ((ratio < 1 - epsilon) | (ratio > 1 + epsilon)).to(mask.dtype)
        clip_fraction = (clipped * mask).sum() / denom
        mean_ratio = (ratio * mask).sum() / denom

    return {
        'loss': masked_loss,
        'clip_fraction': clip_fraction,
        'mean_ratio': mean_ratio,
    }

## 6. Online PRM Update

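PRIME's key innovation is to update the implicit PRM online alongside the policy, using only the outcome labels of fresh rollouts. The helper below is a pairwise (preference-style) variant of that update: the PRM should score correct responses above incorrect ones.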

In [None]:
def _masked_sequence_sum(values: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Sum each row of ``values`` over its first ``lengths[i]`` positions (padding ignored, even if -inf)."""
    positions = torch.arange(values.shape[1], device=values.device)
    keep = positions.unsqueeze(0) < lengths.to(values.device).long().unsqueeze(1)
    return torch.where(keep, values, torch.zeros_like(values)).sum(dim=-1)


def compute_prm_update_loss(
    prm_log_probs_correct: torch.Tensor,
    prm_log_probs_incorrect: torch.Tensor,
    correct_lengths: torch.Tensor,
    incorrect_lengths: torch.Tensor,
    ref_log_probs_correct: Optional[torch.Tensor] = None,
    ref_log_probs_incorrect: Optional[torch.Tensor] = None,
    beta: float = 1.0,
) -> torch.Tensor:
    """
    Compute loss for online PRM update.

    Train implicit PRM to distinguish correct vs incorrect outcomes with a
    pairwise (Bradley-Terry style) loss:

        loss = -mean(logsigmoid(beta * (score(correct) - score(incorrect))))

    ``score`` is the sum of per-token values over each sequence's first
    ``length`` tokens. Without reference log probs, the per-token value is
    the PRM log prob. With them, it is the implicit reward
    ``log pi_prm - log pi_ref``, so the comparison is on the same quantity
    the implicit process rewards are built from.

    Args:
        prm_log_probs_*: (batch, seq_len) log probs for correct/incorrect
        *_lengths: (batch,) sequence lengths; positions beyond the length
            are ignored
        ref_log_probs_*: optional (batch, seq_len) frozen-reference log
            probs; pass both or neither
        beta: scale applied to the score difference (1.0 = unscaled)

    Returns:
        PRM update loss (preference-style), a scalar tensor

    Raises:
        ValueError: if exactly one of the reference log-prob tensors is given.
    """
    if (ref_log_probs_correct is None) != (ref_log_probs_incorrect is None):
        raise ValueError("pass both ref_log_probs_correct and ref_log_probs_incorrect, or neither")

    correct_values = prm_log_probs_correct
    incorrect_values = prm_log_probs_incorrect
    if ref_log_probs_correct is not None:
        correct_values = correct_values - ref_log_probs_correct
        incorrect_values = incorrect_values - ref_log_probs_incorrect

    # Sequence-level scores over valid tokens only
    correct_sum = _masked_sequence_sum(correct_values, correct_lengths)
    incorrect_sum = _masked_sequence_sum(incorrect_values, incorrect_lengths)

    # Preference loss: PRM should prefer correct over incorrect
    logits = beta * (correct_sum - incorrect_sum)
    return -F.logsigmoid(logits).mean()

## 7. Complete PRIME Trainer

In [None]:
class PRIMETrainer:
    """
    PRIME: Process Reinforcement through Implicit Rewards.

    Key features:
    - Implicit PRM (typically initialised from the SFT model), updated online
      from outcome labels when ``update_prm=True``
    - Dense token-level rewards from the PRM / reference log probability ratio
    - Monte Carlo advantage estimation
    - PPO policy updates

    Each ``train_step`` follows the PRIME loop:
    (1) optionally update the implicit PRM on the batch's outcome labels,
    (2) score every token with the (updated) PRM,
    (3) build advantages, and
    (4) take one PPO policy step.
    """

    def __init__(self,
                 policy: nn.Module,
                 ref_model: nn.Module,
                 implicit_prm: nn.Module,
                 epsilon: float = 0.2,
                 lambda_implicit: float = 0.5,
                 policy_lr: float = 1e-5,
                 prm_lr: float = 1e-5,
                 update_prm: bool = True):
        """
        Args:
            policy: model being optimised; must expose ``get_token_log_probs``.
            ref_model: frozen reference model (same interface).
            implicit_prm: implicit process reward model (same interface).
            epsilon: PPO clip parameter.
            lambda_implicit: weight of implicit vs outcome reward in the advantages.
            policy_lr: AdamW learning rate for the policy.
            prm_lr: AdamW learning rate for the implicit PRM.
            update_prm: if True, every ``train_step`` also updates the implicit PRM.
        """
        self.policy = policy
        self.ref_model = ref_model
        self.implicit_prm = implicit_prm
        self.epsilon = epsilon
        self.lambda_implicit = lambda_implicit
        self.update_prm = update_prm

        self.policy_optimizer = torch.optim.AdamW(policy.parameters(), lr=policy_lr)
        self.prm_optimizer = torch.optim.AdamW(implicit_prm.parameters(), lr=prm_lr)

    def compute_implicit_rewards(self, input_ids: torch.Tensor,
                                 labels: torch.Tensor) -> torch.Tensor:
        """Get implicit process rewards (no grad) for a batch."""
        with torch.no_grad():
            prm_logps = self.implicit_prm.get_token_log_probs(input_ids, labels)
            ref_logps = self.ref_model.get_token_log_probs(input_ids, labels)

        return compute_implicit_process_rewards(prm_logps, ref_logps)

    def _response_mask(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Boolean (batch, seq_len) mask, True where position < response length."""
        input_ids = batch['input_ids']
        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        return positions.unsqueeze(0) < batch['response_lengths'].unsqueeze(1)

    def _update_implicit_prm(self, batch: Dict[str, torch.Tensor],
                             mask: torch.Tensor) -> float:
        """
        One online PRM step; returns the PRM loss value.

        The sequence-level implicit reward is the sum over valid tokens of
        ``log pi_prm - log pi_ref``. It is treated as a logit and trained with
        binary cross-entropy against ``batch['outcome_rewards']``. This assumes
        the outcome rewards are correctness labels in [0, 1].
        """
        self.implicit_prm.train()
        prm_logps = self.implicit_prm.get_token_log_probs(
            batch['input_ids'], batch['labels']
        )
        with torch.no_grad():
            ref_logps = self.ref_model.get_token_log_probs(
                batch['input_ids'], batch['labels']
            )

        # Ignore padding entirely (where, not multiply, so -inf cannot become NaN).
        token_ratio = torch.where(mask, prm_logps - ref_logps, torch.zeros_like(prm_logps))
        sequence_reward = token_ratio.sum(dim=-1)
        targets = batch['outcome_rewards'].to(sequence_reward)
        prm_loss = F.binary_cross_entropy_with_logits(sequence_reward, targets)

        self.prm_optimizer.zero_grad()
        prm_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.implicit_prm.parameters(), 1.0)
        self.prm_optimizer.step()
        return prm_loss.item()

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """
        Single PRIME training step.

        Args:
            batch: Dict with:
                - 'input_ids': (batch, seq_len)
                - 'labels': (batch, seq_len)
                - 'outcome_rewards': (batch,) sparse outcome
                - 'response_lengths': (batch,)
                - 'old_log_probs': (batch, seq_len) for importance sampling

        Returns:
            Training metrics: 'loss', 'clip_fraction', 'mean_ratio' and
            'implicit_reward_mean', plus 'prm_loss' when ``update_prm`` is True.
        """
        self.policy.train()
        mask = self._response_mask(batch)
        metrics: Dict[str, float] = {}

        # Step 0: Online PRM update on this batch's outcomes. It runs first so
        # the rewards below come from the updated PRM.
        if self.update_prm:
            metrics['prm_loss'] = self._update_implicit_prm(batch, mask)

        # Step 1: Compute implicit process rewards
        implicit_rewards = self.compute_implicit_rewards(
            batch['input_ids'], batch['labels']
        )

        # Step 2: Compute PRIME advantages
        advantages = compute_prime_advantages(
            implicit_rewards,
            batch['outcome_rewards'],
            batch['response_lengths'],
            lambda_implicit=self.lambda_implicit,
        )

        # Step 3: Compute current policy log probs
        policy_logps = self.policy.get_token_log_probs(
            batch['input_ids'], batch['labels']
        )

        # Step 4: Compute policy loss
        loss_dict = compute_prime_policy_loss(
            policy_logps,
            batch['old_log_probs'],
            advantages,
            mask.float(),
            epsilon=self.epsilon,
        )

        # Step 5: Update policy
        self.policy_optimizer.zero_grad()
        loss_dict['loss'].backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1.0)
        self.policy_optimizer.step()

        metrics.update({k: v.item() for k, v in loss_dict.items()})
        metrics['implicit_reward_mean'] = implicit_rewards.mean().item()

        return metrics

In [None]:
# TEST: PRIME Trainer
#
# Smoke test: train repeatedly on one fixed synthetic batch and print the
# metrics per step. The implicit PRM has its own randomly initialised
# weights; it is not a copy of the policy.

policy = TinyLM().to(device)
ref_model = create_reference_model(policy).to(device)
implicit_prm = TinyLM().to(device)

trainer = PRIMETrainer(policy=policy, ref_model=ref_model, implicit_prm=implicit_prm,
                       lambda_implicit=0.5, policy_lr=1e-4)

batch_size, seq_len = 4, 30
batch = dict(
    input_ids=torch.randint(0, 100, (batch_size, seq_len), device=device),
    labels=torch.randint(0, 100, (batch_size, seq_len), device=device),
    outcome_rewards=torch.tensor([1.0, 0.0, 1.0, 0.0], device=device),
    response_lengths=torch.tensor([25, 30, 20, 28], device=device),
    old_log_probs=torch.randn(batch_size, seq_len, device=device) - 5,
)

print("PRIME Training:")
print(f"{'Step':>5} {'Loss':>10} {'Clip%':>10} {'ImplRwd':>12}")
print("-" * 40)

for step in range(5):
    metrics = trainer.train_step(batch)
    print(
        f"{step:>5} {metrics['loss']:>10.4f} "
        f"{metrics['clip_fraction'] * 100:>9.1f}% "
        f"{metrics['implicit_reward_mean']:>12.4f}"
    )

## 8. Key Takeaways

| Aspect | PRIME Approach |
|--------|---------------|
| **Dense Rewards** | Log ratio: `log π_prm - log π_ref` |
| **No Process Labels** | Train PRM on outcomes, extract token rewards |
| **Online Updates** | PRM evolves with the policy, mitigating reward hacking |
| **Sample Efficiency** | 2.5x vs RLOO baseline |

### Production Configuration

```python
config = {
    'lambda_implicit': 0.5,    # Balance implicit vs outcome
    'epsilon': 0.2,            # PPO clip
    'policy_lr': 1e-5,         # Policy learning rate
    'prm_lr': 1e-5,            # PRM learning rate
    'gamma': 1.0,              # No discounting for reasoning
    'update_prm_every': 1,     # Online PRM updates
}
```

---
**Tier 3 Started!** Next: Conceptual deep-dives on Kimi K2, DeepSeek V3.2, etc.