# 08. Reinforcement Learning from Human Feedback

This notebook explores how to fine-tune large language models with reinforcement learning in domains where model responses can be verified automatically. We begin with **Proximal Policy Optimization (PPO)** and then extend the same ingredients to **Group Relative Policy Optimization (GRPO)**, working with the public [Anthropic HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset of human preference pairs.


## Learning goals

* Load and inspect the Anthropic human preference dataset.
* Build a minimal pipeline for PPO that can optimize a policy model against automatically computed rewards.
* Adapt the PPO implementation into GRPO, which replaces PPO's value baseline with reward statistics computed over a group of completions sampled for the same prompt.
* Discuss evaluation considerations for verifiable domains, such as math or factual QA.


> **Note:** This notebook focuses on presenting a *reference implementation*. The code is intended to run on small models such as `distilgpt2` for demonstration, but you should expect to adjust batch sizes and precision if you train longer or on larger checkpoints.


## Environment setup

We rely on the Hugging Face ecosystem (`transformers` for models and tokenizers, `datasets` for data) and on PyTorch for automatic differentiation.


In [2]:
import math
import random
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import torch
# Hugging Face Dataset (supports .shuffle() / .select()); do not shadow it with
# torch.utils.data.Dataset, which the functions below do not use.
from datasets import Dataset, load_dataset
from torch import nn
from torch.nn import functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {DEVICE}')


Using device: cpu



### Loading Anthropic preference pairs

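The Anthropic HH-RLHF dataset contains paired multi-turn conversations: each example has a `chosen` and a `rejected` transcript that share the same conversation history and differ in the final assistant reply. There is no separate `prompt` column; the prompt is the shared history up to the last `Assistant:` turn. We downsample aggressively to keep the demo lightweight.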


In [1]:
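# HH-RLHF ships only `chosen` and `rejected` columns, each holding a full transcript.
dataset = load_dataset('Anthropic/hh-rlhf', split='train')
print(dataset)

ASSISTANT_TAG = '\n\nAssistant:'


def format_pair(example: Dict[str, str]) -> Dict[str, str]:
    # Both transcripts share the conversation history and differ in the final
    # assistant turn, so split each one at its last "\n\nAssistant:" marker.
    chosen_cut = example['chosen'].rfind(ASSISTANT_TAG) + len(ASSISTANT_TAG)
    rejected_cut = example['rejected'].rfind(ASSISTANT_TAG) + len(ASSISTANT_TAG)
    return {
        'prompt': example['chosen'][:chosen_cut].strip(),  # ends with "Assistant:"
        'chosen': example['chosen'][chosen_cut:].strip(),
        'rejected': example['rejected'][rejected_cut:].strip(),
    }


# Downsample before mapping so we only process the examples we use, then keep
# short prompts so prompt + generated tokens fit in distilgpt2's 1024-token context.
small_dataset = dataset.select(range(256)).map(format_pair)
small_dataset = small_dataset.filter(lambda ex: len(ex['prompt']) < 1000)
small_dataset[0]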


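
Each HH-RLHF example stores two full transcripts rather than a separate prompt. The sketch below uses a hypothetical helper, `split_transcript`, on toy transcripts (no download needed) to show how the shared prompt and the final assistant reply can be separated, assuming the reply starts after the last `\n\nAssistant:` marker.

```python
# Hypothetical helper: split an HH-RLHF style transcript into (prompt, reply).
# Assumes the final assistant turn starts at the last "\n\nAssistant:" marker.
ASSISTANT_TAG = "\n\nAssistant:"


def split_transcript(transcript: str) -> tuple[str, str]:
    """Return (prompt ending with the last Assistant tag, final assistant reply)."""
    cut = transcript.rfind(ASSISTANT_TAG)
    if cut == -1:
        raise ValueError("transcript has no Assistant turn")
    cut += len(ASSISTANT_TAG)
    return transcript[:cut], transcript[cut:].strip()


chosen = "\n\nHuman: What is 2 + 3?\n\nAssistant: Answer: 5"
rejected = "\n\nHuman: What is 2 + 3?\n\nAssistant: I don't know."

prompt_c, reply_c = split_transcript(chosen)
prompt_r, reply_r = split_transcript(rejected)
assert prompt_c == prompt_r  # both transcripts share the same history
print(repr(reply_c), repr(reply_r))
```

Using `rfind` rather than `find` matters for multi-turn conversations, where earlier assistant turns belong to the prompt.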

### Tokenization utilities

We operate on concatenated prompt+response strings and keep track of the prompt length so we can mask out the log-probabilities that correspond to the prompt tokens during policy optimization.


In [None]:
TOKENIZER_NAME = 'distilgpt2'
tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS
tokenizer.padding_side = 'right'  # prompt tokens must sit at the start of each row

@dataclass
class TokenizedBatch:
    input_ids: torch.LongTensor
    attention_mask: torch.LongTensor
    prompt_mask: torch.BoolTensor  # True where the token belongs to the prompt


def tokenize_batch(prompts: List[str], responses: List[str]) -> TokenizedBatch:
    # generate_response() strips leading whitespace, so re-insert one separator space.
    combined = [p + ' ' + r for p, r in zip(prompts, responses)]
    enc = tokenizer(combined, padding=True, return_tensors='pt')
    # Assumes the prompt alone tokenizes to a prefix of the combined tokens.
    prompt_lens = torch.tensor(
        [len(tokenizer(p).input_ids) for p in prompts], dtype=torch.long
    )
    positions = torch.arange(enc.input_ids.size(1)).unsqueeze(0)
    prompt_mask = positions < prompt_lens.unsqueeze(1)
    return TokenizedBatch(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        prompt_mask=prompt_mask,
    )
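
The one-token shift between `logits[:, :-1]` and `labels = input_ids[:, 1:]` is easy to get wrong. The sketch below uses a hypothetical pure-Python helper, `response_positions`, to mark which shifted positions score a response token, assuming right padding (the GPT-2 tokenizer's default).

```python
def response_positions(prompt_len: int, real_len: int, padded_len: int) -> list[bool]:
    """Mask over the padded_len - 1 shifted log-prob positions.

    Position t holds the log-prob of token t + 1, so it belongs to the response
    when prompt_len <= t + 1 < real_len (tokens from real_len on are padding).
    """
    return [prompt_len <= t + 1 < real_len for t in range(padded_len - 1)]


# 3 prompt tokens + 2 response tokens, right-padded to 7 tokens.
print(response_positions(prompt_len=3, real_len=5, padded_len=7))
# [False, False, True, True, False, False]
```

This is the same rule as `~prompt_mask[:, 1:]` combined with `attention_mask[:, 1:]`: the last prompt position already predicts the first response token.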


## Implementing PPO

PPO optimizes a policy by constraining updates to remain close to the previous policy through a clipped surrogate loss. For language models we operate in log-probability space and mask the prompt tokens because they are provided by the environment.
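
For reference, the clipped surrogate from the PPO paper can be written as below, with $x$ the prompt, $y_t$ the $t$-th response token, and $t$ ranging over response tokens only. $r_t$ corresponds to `ratio` (the exponentiated `logratio`), $\hat{A}_t$ to the advantage, and $\epsilon$ to `PPOConfig.clip_range`. Because optimizers minimize, `ppo_loss` uses the negative of $L^{\text{CLIP}}$ as `policy_loss`, then adds `kl_coef` times a KL estimate to the reference model and subtracts `ent_coef` times an entropy term.

$$
r_t(\theta) = \exp\!\big(\log \pi_\theta(y_t \mid x, y_{<t}) - \log \pi_{\theta_{\text{old}}}(y_t \mid x, y_{<t})\big)
$$

$$
L^{\text{CLIP}}(\theta) = \mathbb{E}_t\!\left[\min\!\big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(r_t(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_t\big)\right]
$$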


### Configuration dataclasses


In [None]:
@dataclass
class PPOConfig:
    kl_coef: float = 0.1
    clip_range: float = 0.2
    vf_coef: float = 0.1
    ent_coef: float = 0.01
    target_kl: float = 0.1
    num_epochs: int = 1
    batch_size: int = 4
    mini_batch_size: int = 2
    max_new_tokens: int = 64


### Storage for rollouts


In [None]:
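from torch.nn.utils.rnn import pad_sequence


@dataclass
class PPORollout:
    prompt: str
    response: str
    reward: float
    logprobs: torch.Tensor      # [T] policy log-probs at sampling time
    ref_logprobs: torch.Tensor  # [T] reference-model log-probs
    values: torch.Tensor        # [1] baseline estimate
    masks: torch.BoolTensor     # [T] True at response-token positions


def stack_rollouts(rollouts: List[PPORollout]) -> Dict[str, torch.Tensor]:
    # Each rollout was tokenized on its own, so per-token tensors differ in length.
    # Right-pad them to a common length; padded positions get mask == False.
    return {
        'logprobs': pad_sequence([r.logprobs for r in rollouts], batch_first=True),
        'ref_logprobs': pad_sequence([r.ref_logprobs for r in rollouts], batch_first=True),
        'values': torch.stack([r.values for r in rollouts]),
        'masks': pad_sequence([r.masks.long() for r in rollouts], batch_first=True).bool(),
        'rewards': torch.tensor([r.reward for r in rollouts], dtype=torch.float),
    }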


### Policy and value helper functions


In [None]:
def build_models(model_name: str = TOKENIZER_NAME) -> Tuple[PreTrainedModel, PreTrainedModel]:
    policy = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
    ref_model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
    return policy, ref_model


def compute_logprobs(model: PreTrainedModel, batch: TokenizedBatch) -> torch.Tensor:
    with torch.no_grad():
        outputs = model(
            input_ids=batch.input_ids.to(DEVICE),
            attention_mask=batch.attention_mask.to(DEVICE)
        )
    logits = outputs.logits[:, :-1]
    labels = batch.input_ids[:, 1:].to(DEVICE)
    log_probs = F.log_softmax(logits, dim=-1)
    chosen = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return chosen.cpu()


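def compute_values(model: PreTrainedModel, batch: TokenizedBatch) -> torch.Tensor:
    # Crude, parameter-free baseline: mean token log-probability under the policy.
    # It is computed without gradients, so it is not a trained value head; a full
    # PPO implementation would use a learned critic here.
    logprobs = compute_logprobs(model, batch)
    token_mask = batch.attention_mask[:, 1:].float()  # ignore padding positions
    return (logprobs * token_mask).sum(dim=-1, keepdim=True) / token_mask.sum(dim=-1, keepdim=True)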


### Reward shaping for verifiable domains

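To keep the demo self-contained, we simulate a verifiable task: the target is the sum of all whitespace-separated integers in the prompt, and the response must contain an explicit `'Answer:'` tag followed by that number. The verifier returns 1.0 for a correct answer, -0.2 for a wrong number, -0.5 when the tag is not followed by a parsable integer, and -1.0 when the tag is missing. HH-RLHF prompts are not arithmetic problems, so on this dataset the reward is purely illustrative. Real projects would plug in a programmatic verifier such as a unit-test harness for code or a math proof checker.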


In [None]:
def synthetic_verifier(prompt: str, response: str) -> float:
    target = 0
    for token in prompt.split():
        if token.isdigit():
            target += int(token)
    reward = -1.0
    if 'Answer:' in response:
        try:
            prediction = int(response.split('Answer:')[1].split()[0])
            reward = 1.0 if prediction == target else -0.2
        except (ValueError, IndexError):
            reward = -0.5
    return reward
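
`int(response.split('Answer:')[1].split()[0])` rejects answers with trailing punctuation such as `Answer: 7.`, which then earn -0.5 even when the number is right. A hypothetical regex-based variant, `regex_verifier`, keeps the same reward scale and target rule but parses the integer more tolerantly; treat it as one possible design, not a drop-in requirement.

```python
import re

# Hypothetical variant of the verifier with the same reward scale.
# Accepts an optional sign and trailing punctuation ("Answer: 7."), but rejects
# decimals such as "Answer: 5.5" via the negative lookahead.
ANSWER_RE = re.compile(r"Answer:\s*(-?\d+)(?!\.?\d)")


def regex_verifier(prompt: str, response: str) -> float:
    # Target: sum of whitespace-separated integers in the prompt.
    # isdecimal() avoids characters like superscripts that int() cannot parse.
    target = sum(int(tok) for tok in prompt.split() if tok.isdecimal())
    if "Answer:" not in response:
        return -1.0  # no answer tag at all
    match = ANSWER_RE.search(response)
    if match is None:
        return -0.5  # tag present but no parsable integer
    return 1.0 if int(match.group(1)) == target else -0.2


print(regex_verifier("What is 2 + 3 ?", "Answer: 5."))
print(regex_verifier("What is 2 + 3 ?", "Answer: 4"))
```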


### Collecting PPO rollouts


In [None]:
def generate_response(model: PreTrainedModel, prompt: str, config: PPOConfig) -> str:
    enc = tokenizer(prompt, return_tensors='pt').to(DEVICE)
    output = model.generate(
        **enc,  # passes input_ids and attention_mask (pad == eos, so the mask matters)
        max_new_tokens=config.max_new_tokens,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Keep only the newly generated tokens.
    generated = tokenizer.decode(output[0][enc['input_ids'].shape[1]:], skip_special_tokens=True)
    return generated.strip()


def collect_rollouts(policy: PreTrainedModel, ref_model: PreTrainedModel, dataset: Dataset, config: PPOConfig, num_rollouts: int) -> List[PPORollout]:
    rollouts: List[PPORollout] = []
    for example in dataset.shuffle(seed=42).select(range(num_rollouts)):
        prompt = example['prompt']
        response = generate_response(policy, prompt, config)
        tokenized = tokenize_batch([prompt], [response])
        logprobs = compute_logprobs(policy, tokenized)[0]
        ref_logprobs = compute_logprobs(ref_model, tokenized)[0]
        values = compute_values(policy, tokenized)[0]
        reward = synthetic_verifier(prompt, response)
        rollout = PPORollout(
            prompt=prompt,
            response=response,
            reward=reward,
            logprobs=logprobs,
            ref_logprobs=ref_logprobs,
            values=values,
            masks=~tokenized.prompt_mask[:, 1:][0],
        )
        rollouts.append(rollout)
    return rollouts


### PPO loss computation


In [None]:
def ppo_loss(policy: PreTrainedModel, rollouts: List[PPORollout], config: PPOConfig) -> Tuple[torch.Tensor, Dict[str, float]]:
    stacked = stack_rollouts(rollouts)
    # One scalar advantage per rollout (reward minus baseline), normalized across
    # the batch, then broadcast over that rollout's tokens.
    advantages = stacked['rewards'] - stacked['values'].squeeze(-1)      # [B]
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    advantages = advantages.unsqueeze(-1).to(DEVICE)                     # [B, 1]

    batch = tokenize_batch(
        [r.prompt for r in rollouts],
        [r.response for r in rollouts],
    )
    outputs = policy(
        input_ids=batch.input_ids.to(DEVICE),
        attention_mask=batch.attention_mask.to(DEVICE)
    )
    logits = outputs.logits[:, :-1]
    labels = batch.input_ids[:, 1:].to(DEVICE)
    log_probs = F.log_softmax(logits, dim=-1)
    chosen = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)      # [B, T]

    mask = stacked['masks'].to(DEVICE).float()                           # response tokens only
    n_tokens = mask.sum().clamp(min=1.0)
    old_logprobs = stacked['logprobs'].to(DEVICE)
    ref_logprobs = stacked['ref_logprobs'].to(DEVICE)

    # Clipped surrogate objective (masked before exp() so padding gives ratio == 1).
    ratio = ((chosen - old_logprobs) * mask).exp()
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - config.clip_range, 1 + config.clip_range) * advantages
    policy_loss = -(torch.min(unclipped, clipped) * mask).sum() / n_tokens

    # KL(policy || reference) estimated on the sampled tokens with the current
    # log-probs, so the penalty contributes gradients.
    kl = ((chosen - ref_logprobs) * mask).sum() / n_tokens

    # Entropy of the full next-token distribution at response positions.
    token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
    entropy = (token_entropy * mask).sum() / n_tokens

    # The baseline is a fixed tensor (no trainable critic), so this term has no
    # gradient; it is kept for monitoring.
    value_loss = F.mse_loss(stacked['values'].squeeze(-1), stacked['rewards']).to(DEVICE)
    total_loss = policy_loss + config.vf_coef * value_loss - config.ent_coef * entropy + config.kl_coef * kl

    metrics = {
        'policy_loss': policy_loss.item(),
        'value_loss': value_loss.item(),
        'kl': kl.item(),
        'entropy': entropy.item(),
    }
    return total_loss, metrics
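
To see why the `torch.min(unclipped, clipped)` term limits how far a single update can move the policy, here is a scalar sketch of the per-token objective with a hypothetical helper, `clipped_surrogate`, using the default `clip_range=0.2`.

```python
def clipped_surrogate(ratio: float, advantage: float, clip_range: float = 0.2) -> float:
    """Per-token PPO objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
    return min(ratio * advantage, clipped_ratio * advantage)


# Positive advantage: the gain stops growing once the ratio exceeds 1 + eps.
print(clipped_surrogate(1.5, advantage=1.0))   # 1.2
# Negative advantage: min() keeps the unclipped, more pessimistic term.
print(clipped_surrogate(1.5, advantage=-1.0))  # -1.5
# Ratio below 1 - eps with positive advantage: no floor, the loss stays unclipped.
print(clipped_surrogate(0.5, advantage=1.0))   # 0.5
```

The asymmetry is deliberate: clipping removes the incentive to push probabilities further in the direction the advantage favors, but it never hides a change that made the objective worse.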


### PPO training loop


In [None]:
def train_ppo(dataset: Dataset, config: PPOConfig) -> Tuple[PreTrainedModel, List[Dict[str, float]]]:
    policy, ref_model = build_models()
    optimizer = torch.optim.AdamW(policy.parameters(), lr=5e-6)
    all_metrics: List[Dict[str, float]] = []

    for epoch in range(config.num_epochs):
        rollouts = collect_rollouts(policy, ref_model, dataset, config, config.batch_size)
        loss, metrics = ppo_loss(policy, rollouts, config)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
        optimizer.step()
        metrics['epoch'] = epoch
        all_metrics.append(metrics)
        print(f'Epoch {epoch}: {metrics}')
    return policy, all_metrics


Even a single PPO epoch is computationally demanding because it needs to generate fresh samples. For larger experiments you would:

* Use a separate **reward model** instead of the synthetic verifier.
* Accumulate rollouts across multiple gradient steps.
* Monitor the KL divergence to adjust `kl_coef` adaptively.
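
One common way to adapt `kl_coef` is the proportional controller described by Ziegler et al. (2019), which TRL also implemented as an adaptive KL controller. The dataclass below is a minimal sketch of that idea; the class name and the `target` and `horizon` defaults are assumptions, with `target` playing the role of the otherwise unused `PPOConfig.target_kl`.

```python
from dataclasses import dataclass


@dataclass
class AdaptiveKLController:
    """Proportional KL controller in the style of Ziegler et al. (2019)."""

    kl_coef: float = 0.1   # current penalty coefficient (like PPOConfig.kl_coef)
    target: float = 0.1    # desired KL per step (assumed value)
    horizon: int = 10_000  # larger horizon -> slower adaptation (assumed value)

    def update(self, observed_kl: float, n_steps: int) -> float:
        # Relative error, clipped to [-0.2, 0.2] so one noisy KL estimate
        # cannot swing the coefficient too far.
        error = min(max(observed_kl / self.target - 1.0, -0.2), 0.2)
        # KL above target -> raise the penalty; below target -> lower it.
        self.kl_coef *= 1.0 + error * n_steps / self.horizon
        return self.kl_coef


controller = AdaptiveKLController(kl_coef=0.1, target=0.1, horizon=100)
for observed in [0.05, 0.3, 0.12]:
    print(controller.update(observed, n_steps=4))
```

In a training loop you would feed `metrics['kl']` into `update` after each step and copy `controller.kl_coef` back into the config.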

Next we upgrade the algorithm to GRPO.


## Implementing GRPO

Group Relative Policy Optimization (GRPO) modifies PPO by removing the value baseline. For each prompt we draw several candidate completions, score them with the verifier, and compute each completion's advantage relative to the group's reward statistics (mean and standard deviation).
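
The group-relative advantage is just a per-group standardization of rewards. The sketch below uses a hypothetical pure-Python helper, `group_advantages`, with the sample standard deviation (Bessel's correction), which matches the default of `torch.Tensor.std`.

```python
import statistics


def group_advantages(rewards: list[float], eps: float = 1e-8) -> list[float]:
    """GRPO-style advantages (r_i - mean) / (std + eps) for one prompt's group.

    Needs at least two rewards, because the sample standard deviation is undefined
    for a single value.
    """
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards)  # sample std, like torch.Tensor.std's default
    return [(r - mean) / (std + eps) for r in rewards]


# Four completions for one prompt: one correct (1.0), three wrong numbers (-0.2).
print(group_advantages([1.0, -0.2, -0.2, -0.2]))
# Every completion scored the same: zero advantage, so no learning signal.
print(group_advantages([-1.0, -1.0, -1.0, -1.0]))  # [0.0, 0.0, 0.0, 0.0]
```

The second case shows a practical consequence: prompts where every sample fails (or every sample succeeds) contribute nothing to the policy gradient, so the verifier needs to produce some variation within a group.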


### GRPO utilities


In [None]:
@dataclass
class GRPOConfig(PPOConfig):
    group_size: int = 4


def collect_group_rollouts(policy: PreTrainedModel, ref_model: PreTrainedModel, dataset: Dataset, config: GRPOConfig, num_groups: int) -> List[List[PPORollout]]:
    groups: List[List[PPORollout]] = []
    for example in dataset.shuffle(seed=1337).select(range(num_groups)):
        prompt = example['prompt']
        group: List[PPORollout] = []
        for _ in range(config.group_size):
            response = generate_response(policy, prompt, config)
            tokenized = tokenize_batch([prompt], [response])
            logprobs = compute_logprobs(policy, tokenized)[0]
            ref_logprobs = compute_logprobs(ref_model, tokenized)[0]
            values = compute_values(policy, tokenized)[0]
            reward = synthetic_verifier(prompt, response)
            rollout = PPORollout(
                prompt=prompt,
                response=response,
                reward=reward,
                logprobs=logprobs,
                ref_logprobs=ref_logprobs,
                values=values,
                masks=~tokenized.prompt_mask[:, 1:][0],
            )
            group.append(rollout)
        groups.append(group)
    return groups


def grpo_group_advantages(group: List[PPORollout]) -> torch.Tensor:
    # Standardize rewards within the group of completions for the same prompt:
    # the group mean is the baseline and the group std sets the scale.
    rewards = torch.tensor([r.reward for r in group], dtype=torch.float)
    return (rewards - rewards.mean()) / (rewards.std() + 1e-8)


### GRPO loss


In [None]:
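from torch.nn.utils.rnn import pad_sequence


def grpo_loss(policy: PreTrainedModel, groups: List[List[PPORollout]], config: GRPOConfig) -> Tuple[torch.Tensor, Dict[str, float]]:
    total_loss = torch.tensor(0.0, device=DEVICE)
    metrics = {'policy_loss': 0.0, 'kl': 0.0, 'entropy': 0.0}
    total_tokens = 0

    for group in groups:
        advantages = grpo_group_advantages(group).to(DEVICE).unsqueeze(-1)  # [G, 1]
        batch = tokenize_batch([r.prompt for r in group], [r.response for r in group])
        outputs = policy(
            input_ids=batch.input_ids.to(DEVICE),
            attention_mask=batch.attention_mask.to(DEVICE)
        )
        logits = outputs.logits[:, :-1]
        labels = batch.input_ids[:, 1:].to(DEVICE)
        log_probs = F.log_softmax(logits, dim=-1)
        chosen = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # [G, T]

        # Completions in a group differ in length: right-pad the stored tensors.
        old_logprobs = pad_sequence([r.logprobs for r in group], batch_first=True).to(DEVICE)
        ref_logprobs = pad_sequence([r.ref_logprobs for r in group], batch_first=True).to(DEVICE)
        mask = pad_sequence([r.masks.long() for r in group], batch_first=True).to(DEVICE).float()
        n_tokens = mask.sum().clamp(min=1.0)

        # Clipped surrogate with group-relative advantages (masked before exp()
        # so padded positions give ratio == 1 and are then zeroed out).
        ratio = ((chosen - old_logprobs) * mask).exp()
        clipped_ratio = torch.clamp(ratio, 1 - config.clip_range, 1 + config.clip_range)
        surrogate = torch.min(ratio * advantages, clipped_ratio * advantages)
        policy_loss = -(surrogate * mask).sum() / n_tokens

        # Non-negative "k3" estimate of KL(policy || reference) on the sampled
        # tokens, computed from the current log-probs so it contributes gradients.
        log_ref_ratio = (ref_logprobs - chosen) * mask
        kl = ((log_ref_ratio.exp() - log_ref_ratio - 1) * mask).sum() / n_tokens

        # Entropy of the full next-token distribution at response positions.
        token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
        entropy = (token_entropy * mask).sum() / n_tokens

        group_loss = policy_loss - config.ent_coef * entropy + config.kl_coef * kl
        total_loss = total_loss + group_loss
        total_tokens += mask.sum().item()
        metrics['policy_loss'] += policy_loss.item()
        metrics['kl'] += kl.item()
        metrics['entropy'] += entropy.item()

    total_loss = total_loss / len(groups)
    for key in metrics:
        metrics[key] /= len(groups)
    metrics['tokens_per_step'] = total_tokens / len(groups)
    return total_loss, metrics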
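
The KL penalty is estimated from sampled tokens, so the choice of estimator matters. The DeepSeekMath paper that introduced GRPO uses the "k3" form $e^{\log\pi_{\text{ref}} - \log\pi_\theta} - (\log\pi_{\text{ref}} - \log\pi_\theta) - 1$, which is never negative per token, whereas the plain log-ratio ("k1") can be. The sketch below, built around a hypothetical helper `kl_estimators`, checks on a toy three-token distribution that both estimators average to the exact KL.

```python
import math


def kl_estimators(logp: float, ref_logp: float) -> tuple[float, float]:
    """Per-token estimates of KL(pi_theta || pi_ref) for a token sampled from pi_theta.

    k1 = log pi_theta - log pi_ref            (unbiased, may be negative per token)
    k3 = exp(ref - logp) - (ref - logp) - 1   (unbiased, always >= 0)
    """
    log_ratio = ref_logp - logp
    k1 = -log_ratio
    k3 = math.exp(log_ratio) - log_ratio - 1.0
    return k1, k3


# Exact expectation over a toy 3-token vocabulary under the policy p.
p = [0.5, 0.3, 0.2]  # policy probabilities
q = [0.2, 0.5, 0.3]  # reference probabilities
true_kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
exp_k1 = sum(pi * kl_estimators(math.log(pi), math.log(qi))[0] for pi, qi in zip(p, q))
exp_k3 = sum(pi * kl_estimators(math.log(pi), math.log(qi))[1] for pi, qi in zip(p, q))
print(true_kl, exp_k1, exp_k3)
```

k3 is unbiased because the expectation of $\pi_{\text{ref}}/\pi_\theta$ under $\pi_\theta$ is 1, so the extra terms cancel in expectation while keeping every per-token sample non-negative.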


### GRPO training loop


In [None]:
def train_grpo(dataset: Dataset, config: GRPOConfig) -> Tuple[PreTrainedModel, List[Dict[str, float]]]:
    policy, ref_model = build_models()
    optimizer = torch.optim.AdamW(policy.parameters(), lr=5e-6)
    history: List[Dict[str, float]] = []

    for epoch in range(config.num_epochs):
        groups = collect_group_rollouts(policy, ref_model, dataset, config, num_groups=config.batch_size)
        loss, metrics = grpo_loss(policy, groups, config)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
        optimizer.step()
        metrics['epoch'] = epoch
        history.append(metrics)
        print(f'Epoch {epoch}: {metrics}')
    return policy, history


## Evaluation and verification

In verifiable domains you can often rely on automated scoring, which greatly simplifies evaluation compared with subjective tasks. Consider:

* **Hold-out prompts** with deterministic solutions, allowing exact accuracy measurement.
* **Programmatic validators** that can execute generated code, check unit tests, or validate mathematical proofs.
* **Safety filters** that refuse to answer when verification fails or the prompt is unsafe.

Below is a simple evaluation helper that reuses the synthetic verifier but can be swapped out for richer test harnesses.


In [None]:
def evaluate_model(model: PreTrainedModel, prompts: Iterable[str], verifier: Callable[[str, str], float], config: PPOConfig) -> Dict[str, float]:
    rewards = []
    for prompt in prompts:
        response = generate_response(model, prompt, config)
        rewards.append(verifier(prompt, response))
    return {
        'avg_reward': sum(rewards) / len(rewards),
        'success_rate': sum(r > 0 for r in rewards) / len(rewards),
    }
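
For the hold-out idea above, prompts must have answers that the verifier's rule can check. The sketch below generates a hypothetical arithmetic hold-out set, `make_arithmetic_prompts`, in the HH-RLHF turn format; operands are space-separated so that `str.split()` yields them as standalone digit tokens, which is what `synthetic_verifier` sums. The prompt wording and value range are assumptions.

```python
import random


def make_arithmetic_prompts(n: int, seed: int = 0, max_value: int = 50) -> list[tuple[str, int]]:
    """Hypothetical hold-out set of (prompt, expected_answer) pairs.

    Only the two operands appear as standalone digit tokens, so the sum of
    integers in the prompt equals the expected answer.
    """
    rng = random.Random(seed)  # local RNG: reproducible and independent of global state
    items = []
    for _ in range(n):
        a, b = rng.randint(0, max_value), rng.randint(0, max_value)
        prompt = (
            f"\n\nHuman: What is {a} + {b} ? Reply in the form 'Answer: <number>'."
            "\n\nAssistant:"
        )
        items.append((prompt, a + b))
    return items


for prompt, answer in make_arithmetic_prompts(2, seed=1):
    print(repr(prompt), answer)
```

Usage with the helper above might look like `evaluate_model(policy, [p for p, _ in make_arithmetic_prompts(100)], synthetic_verifier, config)`, keeping the seed fixed so that scores are comparable across checkpoints.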


## Next steps

* Replace the synthetic verifier with a domain-specific scoring function.
* Fine-tune a reward model on the `chosen` vs. `rejected` responses to approximate human preferences when automation is impossible.
* Scale up batch sizes, gradient accumulation, and model sizes using distributed training utilities.
* Track metrics like KL divergence, entropy, and accuracy to maintain alignment with the reference policy.
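
When fitting a reward model on the `chosen` vs. `rejected` pairs, a common objective (used, for example, in InstructGPT) is the Bradley-Terry pairwise loss $-\log\sigma(r_{\text{chosen}} - r_{\text{rejected}})$. The scalar sketch below, with a hypothetical `pairwise_reward_loss`, shows a numerically stable way to evaluate it; in practice the two rewards would be outputs of a reward head on each transcript.

```python
import math


def pairwise_reward_loss(r_chosen: float, r_rejected: float) -> float:
    """Bradley-Terry loss -log(sigmoid(r_chosen - r_rejected)) for one pair."""
    margin = r_chosen - r_rejected
    # -log(sigmoid(m)) == log(1 + exp(-m)). log1p keeps small values accurate,
    # and the branch keeps exp() from overflowing for large negative margins.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


print(pairwise_reward_loss(0.0, 0.0))  # log(2): the model cannot tell the pair apart
print(pairwise_reward_loss(2.0, 0.0))  # small: chosen already scored higher
print(pairwise_reward_loss(0.0, 2.0))  # large: ranking is inverted
```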
