# 07 — DPO from Scratch: The Elegant Shortcut

In notebook 06, we built the full RLHF pipeline: train a reward model, then use PPO to optimize the policy. It works, but it's complex — 4 models, unstable training, reward hacking risks.

**DPO (Direct Preference Optimization)** optimizes the same objective with a single elegant loss function. The key insight: the KL-regularized objective that RLHF maximizes has a **closed-form optimal policy**, expressed in terms of the reward and the reference model. Rearranging that expression lets you skip both the reward model and the RL loop and directly optimize the language model on preference pairs.

The DPO loss:

$$\mathcal{L}_{\text{DPO}} = -\log \sigma\left(\beta \left[\log \frac{\pi_\theta(y_w | x)}{\pi_{\text{ref}}(y_w | x)} - \log \frac{\pi_\theta(y_l | x)}{\pi_{\text{ref}}(y_l | x)}\right]\right)$$

Where:
- $\pi_\theta$ = current model (being trained)
- $\pi_{\text{ref}}$ = frozen reference model (the SFT model)
- $y_w$ = chosen ("winning") response
- $y_l$ = rejected ("losing") response
- $\beta$ = temperature, i.e. the KL-penalty coefficient (larger $\beta$ keeps the policy closer to the reference model)

In plain English: **increase the probability of chosen responses relative to rejected ones, but don't stray too far from the reference model.**
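To make the formula concrete, here is a minimal pure-Python sketch that evaluates the loss for a single preference pair. The four log-probabilities are made-up numbers (each one summed over the response tokens), and `dpo_pair_loss` / `log_sigmoid` are illustrative helpers, not part of any library.

```python
import math

def log_sigmoid(z):
    # Numerically stable log(sigmoid(z)) for large positive or negative z
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))

def dpo_pair_loss(pol_w, ref_w, pol_l, ref_l, beta=0.1):
    """DPO loss for one pair, given summed response log-probs under policy and reference."""
    chosen_log_ratio = pol_w - ref_w      # log pi_theta(y_w|x) - log pi_ref(y_w|x)
    rejected_log_ratio = pol_l - ref_l    # log pi_theta(y_l|x) - log pi_ref(y_l|x)
    margin = beta * (chosen_log_ratio - rejected_log_ratio)
    return -log_sigmoid(margin)

# Hypothetical values: the policy raised y_w by 2 nats and lowered y_l by 3 nats
loss = dpo_pair_loss(pol_w=-40.0, ref_w=-42.0, pol_l=-48.0, ref_l=-45.0, beta=0.1)
print(f"{loss:.4f}")
```

Here the margin is $0.1 \times (2 - (-3)) = 0.5$, so the loss falls below $\log 2 \approx 0.693$, which is its value when the policy and reference agree exactly.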

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import matplotlib.pyplot as plt

device = ('cuda' if torch.cuda.is_available()
          else 'mps' if torch.backends.mps.is_available()
          else 'cpu')
print(f'Using device: {device}')
torch.manual_seed(1337)

## Part 0: Rebuild the SFT Model

Same as notebook 06 — we need a trained SFT model as our starting point.

In [None]:
# Load data and build vocabulary
with open('../data/shakespeare/input.txt', 'r') as f:
    text = f.read()

chars = sorted(list(set(text)))
base_vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

# Special tokens
SPECIAL_TOKENS = {
    '<|user|>': base_vocab_size,
    '<|assistant|>': base_vocab_size + 1,
    '<|end|>': base_vocab_size + 2,
}
ID_TO_SPECIAL = {v: k for k, v in SPECIAL_TOKENS.items()}
sft_vocab_size = base_vocab_size + len(SPECIAL_TOKENS)

def encode_sft(text):
    tokens = []
    i = 0
    while i < len(text):
        matched = False
        for token_str, token_id in SPECIAL_TOKENS.items():
            if text[i:].startswith(token_str):
                tokens.append(token_id)
                i += len(token_str)
                matched = True
                break
        if not matched:
            tokens.append(stoi[text[i]])
            i += 1
    return tokens

def decode_sft(tokens):
    result = []
    for t in tokens:
        if t in ID_TO_SPECIAL:
            result.append(ID_TO_SPECIAL[t])
        else:
            result.append(itos[t])
    return ''.join(result)

print(f'Vocab: {base_vocab_size} base + 3 special = {sft_vocab_size}')

In [None]:
# Hyperparameters and model architecture
batch_size = 32
block_size = 64
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.2

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key   = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # scale by 1/sqrt(head_size), not 1/sqrt(n_embd)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        return wei @ self.value(x)

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd), nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPT(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        if targets is None:
            return logits, None
        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

print(f'Model architecture defined.')

In [None]:
# Pre-train base model and run SFT (same as notebooks 05-06)
def get_batch(split):
    d = train_data if split == 'train' else val_data
    ix = torch.randint(len(d) - block_size, (batch_size,))
    x = torch.stack([d[i:i+block_size] for i in ix])
    y = torch.stack([d[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)

# Pre-train
base_model = GPT(base_vocab_size).to(device)
optimizer = torch.optim.AdamW(base_model.parameters(), lr=3e-4)
print("Pre-training base model...")
for step in range(3000):
    if step % 1000 == 0:
        base_model.eval()
        losses = torch.zeros(200)
        with torch.no_grad():  # evaluation only: don't build the autograd graph
            for k in range(200):
                X, Y = get_batch('val')
                _, loss = base_model(X, Y)
                losses[k] = loss.item()
        print(f"  step {step}: val loss {losses.mean():.4f}")
        base_model.train()
    xb, yb = get_batch('train')
    _, loss = base_model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print("Pre-training done.")

# SFT
conversations = [
    ("Write a greeting", "Good morrow to thee, noble friend!"),
    ("Say hello", "Hail and well met, good sir!"),
    ("Greet a friend", "Welcome, dear companion of mine heart!"),
    ("Say goodbye", "Farewell, and may fortune smile upon thee."),
    ("Write a farewell", "Good night, sweet friend, till we meet again."),
    ("Who are you", "A humble player upon this stage of words."),
    ("Speak of love", "Love is a smoke raised with the fume of sighs."),
    ("Tell me of love", "It is the star to every wandering bark."),
    ("What is love", "A madness most discreet, a bitter sweet."),
    ("What is honor", "Honor is a mere word, yet it moves the heart."),
    ("Speak of duty", "Duty binds us all, from king to common man."),
    ("Describe the night", "The moon doth hang like silver in the sky."),
    ("Write of morning", "Dawn breaks golden upon the sleeping earth."),
    ("Give me counsel", "Be wise, be patient, and trust thy heart."),
    ("Give advice", "Think well before thou speak, and speak the truth."),
    ("What should I do", "Follow the path that honor bids thee walk."),
    ("Tell me a story", "Once there lived a king both wise and bold."),
    ("I am sad", "Take heart, for sorrow fades as morning comes."),
    ("I am afraid", "Fear not, for courage lives within thy soul."),
    ("Write a line of verse", "Shall I compare thee to a summer day?"),
    ("Give a toast", "To health and joy, and friends both old and new!"),
    ("Say something kind", "Thou art more lovely than the fairest dawn."),
    ("Cheer me up", "Smile, for the world is bright and full of wonder."),
]

def prepare_sft_example(user_msg, asst_msg, max_len=block_size):
    text = f"<|user|>{user_msg}<|end|><|assistant|>{asst_msg}<|end|>"
    tokens = encode_sft(text)
    if len(tokens) > max_len:
        tokens = tokens[:max_len]
    pad_len = max_len - len(tokens)
    tokens = tokens + [0] * pad_len
    input_ids = tokens[:-1]
    labels = tokens[1:]
    asst_token_id = SPECIAL_TOKENS['<|assistant|>']
    asst_pos = None
    for i, t in enumerate(input_ids):
        if t == asst_token_id:
            asst_pos = i
            break
    masked_labels = []
    for i, t in enumerate(labels):
        # Mask prompt targets; labels[asst_pos] is the first response token, so keep it
        if asst_pos is not None and i < asst_pos:
            masked_labels.append(-100)
        elif i >= (max_len - 1 - pad_len):
            masked_labels.append(-100)
        else:
            masked_labels.append(t)
    return input_ids, masked_labels

sft_model = GPT(sft_vocab_size).to(device)
pretrained_state = base_model.state_dict()
sft_state = sft_model.state_dict()
for key in pretrained_state:
    if key in sft_state:
        if pretrained_state[key].shape == sft_state[key].shape:
            sft_state[key] = pretrained_state[key]
        elif 'token_embedding' in key:
            sft_state[key][:base_vocab_size] = pretrained_state[key]
        elif 'lm_head.weight' in key:
            sft_state[key][:base_vocab_size] = pretrained_state[key]
        elif 'lm_head.bias' in key:
            sft_state[key][:base_vocab_size] = pretrained_state[key]
sft_model.load_state_dict(sft_state)

all_input_ids, all_labels = [], []
for user_msg, asst_msg in conversations:
    ids, labs = prepare_sft_example(user_msg, asst_msg)
    all_input_ids.append(ids)
    all_labels.append(labs)
sft_input_ids = torch.tensor(all_input_ids, dtype=torch.long, device=device)
sft_labels = torch.tensor(all_labels, dtype=torch.long, device=device)

optimizer = torch.optim.AdamW(sft_model.parameters(), lr=1e-4)
print("SFT training...")
sft_model.train()
for iter in range(1000):
    batch_idx = torch.randint(len(sft_input_ids), (8,))
    xb = sft_input_ids[batch_idx]
    yb = sft_labels[batch_idx]
    logits, _ = sft_model(xb)
    B, T, C = logits.shape
    loss = F.cross_entropy(logits.view(B*T, C), yb.view(B*T), ignore_index=-100)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if iter % 200 == 0:
        print(f"  step {iter}: SFT loss {loss.item():.4f}")
print("SFT done.")

In [None]:
def chat(model, user_msg, max_tokens=50):
    prompt = f"<|user|>{user_msg}<|end|><|assistant|>"
    tokens = encode_sft(prompt)
    idx = torch.tensor([tokens], dtype=torch.long, device=device)
    model.eval()
    with torch.no_grad():
        output = model.generate(idx, max_new_tokens=max_tokens)[0].tolist()
    model.train()
    full_text = decode_sft(output)
    if '<|assistant|>' in full_text:
        response = full_text.split('<|assistant|>')[-1]
        if '<|end|>' in response:
            response = response.split('<|end|>')[0]
        return response.strip()
    return full_text

# Verify SFT model works
print("SFT model check:")
for p in ["Write a greeting", "Speak of love"]:
    print(f"  {p} → {chat(sft_model, p)}")

## Part 1: The DPO Loss — Surprisingly Simple

The entire DPO algorithm comes down to one loss function. Let's build it step by step.

### The Math

For a preference pair (prompt $x$, chosen $y_w$, rejected $y_l$):

$$\mathcal{L}_{\text{DPO}} = -\log \sigma\left(\beta \left[\log \frac{\pi_\theta(y_w | x)}{\pi_{\text{ref}}(y_w | x)} - \log \frac{\pi_\theta(y_l | x)}{\pi_{\text{ref}}(y_l | x)}\right]\right)$$

### In plain English

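1. Compute how much the **current model** likes the chosen response compared with the reference model: this is the "chosen log-ratio"
2. Compute the same quantity for the rejected response: this is the "rejected log-ratio"
3. Push the **difference** (chosen log-ratio minus rejected log-ratio) **up**. The loss only sees this margin, so it is satisfied whether the chosen ratio rises, the rejected ratio falls, or both fall with the rejected one falling faster
4. $\beta$ scales the margin inside the sigmoid; it is the same coefficient that weights the KL penalty in the RLHF objective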
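A small pure-Python check shows that the loss depends only on the margin between the two log-ratios. The log-ratio values (in nats) are hypothetical, and `dpo_loss_from_ratios` is an illustrative helper: a pair where both log-ratios fall gets exactly the same loss as one where the chosen ratio rises, as long as the gap is the same.

```python
import math

def dpo_loss_from_ratios(chosen_log_ratio, rejected_log_ratio, beta=0.1):
    # -log(sigmoid(beta * margin)) == log(1 + exp(-beta * margin))
    margin = chosen_log_ratio - rejected_log_ratio
    return math.log1p(math.exp(-beta * margin))

start = dpo_loss_from_ratios(0.0, 0.0)    # policy == reference: loss = log 2
case_a = dpo_loss_from_ratios(+2.0, -2.0)  # chosen up, rejected down (margin 4)
case_b = dpo_loss_from_ratios(-3.0, -7.0)  # both down, rejected faster (margin 4)
print(f"start {start:.4f}  A {case_a:.4f}  B {case_b:.4f}")
```

This is why the implicit reward of the chosen response can decrease during DPO training even while preference accuracy improves.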

### Why this works

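The key insight from the [DPO paper](https://arxiv.org/abs/2305.18290): if you solve the RLHF optimization problem (maximize expected reward minus $\beta$ times the KL divergence from the reference model) analytically, the optimal policy $\pi^*$ satisfies:

$$r(x, y) = \beta \log \frac{\pi^*(y|x)}{\pi_{\text{ref}}(y|x)} + \beta \log Z(x)$$

where $Z(x)$ is a normalizing constant that depends only on the prompt. Preferences always compare two responses to the same prompt, so the $\beta \log Z(x)$ term cancels, and the reward is **implicit** in the log-ratio between the policy and reference model. So instead of learning a separate reward model, we can directly optimize the policy using preference pairs. The language model **is** the reward model.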
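For readers who want the missing steps, here is the derivation in the DPO paper's notation; the prompt-dependent constant in the equation above is $\beta \log Z(x)$. The RLHF objective is

$$\max_{\pi} \; \mathbb{E}_{x,\, y \sim \pi(\cdot|x)}\big[r(x,y)\big] - \beta\, \mathbb{D}_{\text{KL}}\big[\pi(y|x)\,\|\,\pi_{\text{ref}}(y|x)\big]$$

and its maximizer is

$$\pi^*(y|x) = \frac{1}{Z(x)}\,\pi_{\text{ref}}(y|x)\exp\!\left(\frac{1}{\beta}\, r(x,y)\right), \qquad Z(x) = \sum_{y} \pi_{\text{ref}}(y|x)\exp\!\left(\frac{1}{\beta}\, r(x,y)\right)$$

Taking logs and rearranging gives $r(x,y) = \beta \log \frac{\pi^*(y|x)}{\pi_{\text{ref}}(y|x)} + \beta \log Z(x)$. Under the Bradley–Terry preference model, $p(y_w \succ y_l | x) = \sigma\big(r(x,y_w) - r(x,y_l)\big)$, so the intractable $\beta \log Z(x)$ term cancels:

$$p(y_w \succ y_l | x) = \sigma\left(\beta \log \frac{\pi^*(y_w | x)}{\pi_{\text{ref}}(y_w | x)} - \beta \log \frac{\pi^*(y_l | x)}{\pi_{\text{ref}}(y_l | x)}\right)$$

Replacing $\pi^*$ with the trainable $\pi_\theta$ and minimizing the negative log-likelihood of this probability over the preference pairs gives exactly $\mathcal{L}_{\text{DPO}}$.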

In [None]:
# Preference dataset (same as notebook 06)
preference_data = [
    ("Write a greeting",
     "Good morrow to thee, noble friend!",
     "What dost thou want of me now?"),
    ("Say hello",
     "Hail and well met, good sir!",
     "Go away, I have no time."),
    ("Speak of love",
     "Love is a smoke raised with the fume of sighs.",
     "I know not of love nor care."),
    ("What is love",
     "A madness most discreet, a bitter sweet.",
     "It matters not. Ask another."),
    ("Give me counsel",
     "Be wise, be patient, and trust thy heart.",
     "Do what thou wilt, it matters not."),
    ("Give advice",
     "Think well before thou speak, and speak the truth.",
     "I care not for thy troubles."),
    ("I am sad",
     "Take heart, for sorrow fades as morning comes.",
     "Then be sad. What can I do?"),
    ("I am afraid",
     "Fear not, for courage lives within thy soul.",
     "Thou shouldst be afraid, fool."),
    ("Tell me a story",
     "Once there lived a king both wise and bold.",
     "No. I shall not. Leave me be."),
    ("Say goodbye",
     "Farewell, and may fortune smile upon thee.",
     "Good riddance to thee then."),
    ("Write a farewell",
     "Good night, sweet friend, till we meet again.",
     "Be gone from my sight at once."),
    ("Who are you",
     "A humble player upon this stage of words.",
     "None of thy concern, stranger."),
    ("Describe the night",
     "The moon doth hang like silver in the sky.",
     "It is dark. What more to say?"),
    ("Write of morning",
     "Dawn breaks golden upon the sleeping earth.",
     "Morning comes as always it does."),
    ("What should I do",
     "Follow the path that honor bids thee walk.",
     "How should I know? Decide thyself."),
    ("Give a toast",
     "To health and joy, and friends both old and new!",
     "Drink and be done with it."),
    ("Say something kind",
     "Thou art more lovely than the fairest dawn.",
     "Thou art adequate, I suppose."),
    ("Cheer me up",
     "Smile, for the world is bright and full of wonder.",
     "Why should I cheer thee? Cheer thyself."),
]

print(f"Preference dataset: {len(preference_data)} examples")

In [None]:
# Prepare preference data for DPO
#
# For each preference pair, we need the FULL tokenized sequences
# (prompt + chosen) and (prompt + rejected), plus masks to identify
# which tokens are the response (where we compute log probs).

def prepare_dpo_example(prompt, response, max_len=block_size):
    """Tokenize a (prompt, response) pair and create a response mask."""
    text = f"<|user|>{prompt}<|end|><|assistant|>{response}<|end|>"
    tokens = encode_sft(text)
    if len(tokens) > max_len:
        tokens = tokens[:max_len]

    # Find where the response starts (after <|assistant|>)
    asst_token_id = SPECIAL_TOKENS['<|assistant|>']
    asst_pos = None
    for i, t in enumerate(tokens):
        if t == asst_token_id:
            asst_pos = i
            break

    # Response mask: 1 for response tokens, 0 for prompt tokens
    response_mask = [0] * len(tokens)
    if asst_pos is not None:
        for i in range(asst_pos + 1, len(tokens)):
            response_mask[i] = 1

    # Pad
    pad_len = max_len - len(tokens)
    tokens = tokens + [0] * pad_len
    response_mask = response_mask + [0] * pad_len

    return tokens, response_mask

# Build tensors
chosen_ids_list, chosen_masks_list = [], []
rejected_ids_list, rejected_masks_list = [], []

for prompt, chosen, rejected in preference_data:
    c_ids, c_mask = prepare_dpo_example(prompt, chosen)
    r_ids, r_mask = prepare_dpo_example(prompt, rejected)
    chosen_ids_list.append(c_ids)
    chosen_masks_list.append(c_mask)
    rejected_ids_list.append(r_ids)
    rejected_masks_list.append(r_mask)

chosen_ids = torch.tensor(chosen_ids_list, dtype=torch.long, device=device)
chosen_masks = torch.tensor(chosen_masks_list, dtype=torch.float, device=device)
rejected_ids = torch.tensor(rejected_ids_list, dtype=torch.long, device=device)
rejected_masks = torch.tensor(rejected_masks_list, dtype=torch.float, device=device)

print(f"Chosen:   {chosen_ids.shape}")
print(f"Rejected: {rejected_ids.shape}")

# Visualize one example
print(f"\nExample: '{preference_data[0][0]}'")
print(f"  Chosen tokens:  {decode_sft(chosen_ids_list[0][:40])}...")
print(f"  Chosen mask:    {chosen_masks_list[0][:40]}")
print(f"  (1 = response token, 0 = prompt/padding)")

In [None]:
# The core of DPO: compute log probabilities of response tokens

def get_response_log_probs(model, token_ids, response_mask):
    """
    Compute the sum of log probabilities for response tokens only.

    This is log pi(response | prompt) — the model's log-probability
    of generating the response given the prompt.

    Args:
        model: the language model
        token_ids: (B, T) full sequence tokens
        response_mask: (B, T) binary mask, 1 for response tokens

    Returns:
        (B,) sum of log probs over response tokens for each example
    """
    logits, _ = model(token_ids)  # (B, T, vocab)

    # Shift: logits[t] predicts token[t+1]
    shift_logits = logits[:, :-1, :]         # (B, T-1, vocab)
    shift_labels = token_ids[:, 1:]           # (B, T-1)
    shift_mask = response_mask[:, 1:]         # (B, T-1)

    # Log probabilities
    log_probs = F.log_softmax(shift_logits, dim=-1)  # (B, T-1, vocab)

    # Gather log probs of actual tokens
    token_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)  # (B, T-1)

    # Mask to only response tokens and sum
    masked_log_probs = token_log_probs * shift_mask  # (B, T-1)
    return masked_log_probs.sum(dim=-1)  # (B,)

print("get_response_log_probs defined.")
print("This is the building block of DPO — everything else is just algebra.")

In [None]:
# The DPO loss function — this is the entire algorithm

def dpo_loss(policy_model, ref_model, chosen_ids, chosen_masks, rejected_ids, rejected_masks, beta=0.1):
    """
    Compute the DPO loss.

    L = -log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected)))

    where log_ratio = log(pi_theta(y|x)) - log(pi_ref(y|x))
    """
    # Log probs under the current policy
    policy_chosen_logps = get_response_log_probs(policy_model, chosen_ids, chosen_masks)
    policy_rejected_logps = get_response_log_probs(policy_model, rejected_ids, rejected_masks)

    # Log probs under the reference model (frozen, no grad)
    with torch.no_grad():
        ref_chosen_logps = get_response_log_probs(ref_model, chosen_ids, chosen_masks)
        ref_rejected_logps = get_response_log_probs(ref_model, rejected_ids, rejected_masks)

    # Log ratios: how much more does the policy like this vs the reference?
    chosen_log_ratio = policy_chosen_logps - ref_chosen_logps
    rejected_log_ratio = policy_rejected_logps - ref_rejected_logps

    # DPO loss: push chosen_log_ratio up and rejected_log_ratio down
    logits = beta * (chosen_log_ratio - rejected_log_ratio)
    loss = -F.logsigmoid(logits).mean()

    # Useful metrics
    with torch.no_grad():
        chosen_reward = beta * chosen_log_ratio.mean()
        rejected_reward = beta * rejected_log_ratio.mean()
        accuracy = (chosen_log_ratio > rejected_log_ratio).float().mean()

    return loss, chosen_reward.item(), rejected_reward.item(), accuracy.item()

print("dpo_loss defined.")
print("\nThat's it. The ENTIRE DPO algorithm is ~20 lines of code.")
print("Compare with RLHF: reward model + PPO + KL tuning + 4 models.")

## Part 2: DPO Training

Now we train. The setup is beautifully simple:
- **Policy model**: starts as a copy of SFT model (gets updated)
- **Reference model**: frozen copy of SFT model (never changes)
- **No reward model needed**
- **No RL (PPO) needed**

Just a standard training loop with the DPO loss.

In [None]:
# Setup: policy model (trainable) + reference model (frozen)
dpo_model = GPT(sft_vocab_size).to(device)
dpo_model.load_state_dict(sft_model.state_dict())

ref_model = GPT(sft_vocab_size).to(device)
ref_model.load_state_dict(sft_model.state_dict())
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad = False

print("DPO setup:")
print(f"  Policy model: {sum(p.numel() for p in dpo_model.parameters()):,} params (trainable)")
print(f"  Ref model:    frozen copy of SFT model")
print(f"  Total models: 2 (vs 4 for RLHF)")

In [None]:
# DPO Training Loop
dpo_lr = 1e-5
dpo_beta = 0.1  # Temperature — higher = more conservative changes
dpo_iters = 500
dpo_batch_size = 8

dpo_optimizer = torch.optim.AdamW(dpo_model.parameters(), lr=dpo_lr)

# Track metrics
dpo_losses = []
dpo_chosen_rewards = []
dpo_rejected_rewards = []
dpo_accuracies = []

print(f"DPO training (beta={dpo_beta}, lr={dpo_lr})...\n")
dpo_model.train()

for iter in range(dpo_iters):
    # Sample a batch
    batch_idx = torch.randint(len(chosen_ids), (dpo_batch_size,))

    loss, c_reward, r_reward, acc = dpo_loss(
        dpo_model, ref_model,
        chosen_ids[batch_idx], chosen_masks[batch_idx],
        rejected_ids[batch_idx], rejected_masks[batch_idx],
        beta=dpo_beta
    )

    dpo_optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(dpo_model.parameters(), 1.0)
    dpo_optimizer.step()

    dpo_losses.append(loss.item())
    dpo_chosen_rewards.append(c_reward)
    dpo_rejected_rewards.append(r_reward)
    dpo_accuracies.append(acc)

    if iter % 100 == 0:
        print(f"  step {iter}: loss {loss.item():.4f}, acc {acc:.2f}, "
              f"chosen_r {c_reward:+.3f}, rejected_r {r_reward:+.3f}")

print(f"\nDPO training done.")
print(f"Final: loss {dpo_losses[-1]:.4f}, accuracy {dpo_accuracies[-1]:.2f}")

In [None]:
# Plot DPO training metrics
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
window = 30

def smooth(data, w):
    # Trailing moving average over the last (up to) w points
    return [sum(data[max(0, i - w + 1):i + 1]) / min(i + 1, w) for i in range(len(data))]

# Loss
axes[0].plot(dpo_losses, alpha=0.3, color='blue')
axes[0].plot(smooth(dpo_losses, window), color='blue', linewidth=2)
axes[0].set_title('DPO Loss')
axes[0].set_xlabel('Step')
axes[0].grid(True, alpha=0.3)

# Implicit rewards (chosen vs rejected)
axes[1].plot(smooth(dpo_chosen_rewards, window), color='green', linewidth=2, label='chosen')
axes[1].plot(smooth(dpo_rejected_rewards, window), color='red', linewidth=2, label='rejected')
axes[1].set_title('Implicit Rewards')
axes[1].set_xlabel('Step')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Accuracy
axes[2].plot(dpo_accuracies, alpha=0.3, color='green')
axes[2].plot(smooth(dpo_accuracies, window), color='green', linewidth=2)
axes[2].axhline(y=0.5, color='red', linestyle='--', label='random')
axes[2].set_title('Preference Accuracy')
axes[2].set_xlabel('Step')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Key observation: the implicit reward for chosen responses goes UP")
print("while rejected responses go DOWN — the model learns the preference.")

## Part 3: Results — SFT vs. DPO

Let's see if DPO improved the model's responses.

In [None]:
# Compare SFT vs DPO
test_prompts = [
    "Write a greeting",
    "Speak of love",
    "I am sad",
    "Give advice",
    "Tell me a story",
    "Describe the moon",  # unseen
    "Write a farewell",
]

print("=" * 70)
print("SFT MODEL  vs  DPO MODEL")
print("=" * 70)

for prompt in test_prompts:
    sft_response = chat(sft_model, prompt, max_tokens=35)
    dpo_response = chat(dpo_model, prompt, max_tokens=35)
    print(f"\nUser: {prompt}")
    print(f"  SFT: {sft_response}")
    print(f"  DPO: {dpo_response}")
    print("-" * 70)

In [None]:
# Quantitative: compute implicit rewards for the DPO model
# Remember: in DPO, the reward is implicitly r(x,y) = beta * log(pi/pi_ref)

dpo_model.eval()
ref_model.eval()

print("Implicit reward scores (higher = model prefers this response more):")
print("(DPO model should assign higher implicit reward to chosen vs rejected)\n")

correct = 0
total = 0
with torch.no_grad():
    for prompt, chosen, rejected in preference_data[:8]:
        c_ids, c_mask = prepare_dpo_example(prompt, chosen)
        r_ids, r_mask = prepare_dpo_example(prompt, rejected)

        c_ids_t = torch.tensor([c_ids], dtype=torch.long, device=device)
        c_mask_t = torch.tensor([c_mask], dtype=torch.float, device=device)
        r_ids_t = torch.tensor([r_ids], dtype=torch.long, device=device)
        r_mask_t = torch.tensor([r_mask], dtype=torch.float, device=device)

        # Implicit reward = beta * (log_pi - log_pi_ref)
        policy_c = get_response_log_probs(dpo_model, c_ids_t, c_mask_t)
        ref_c = get_response_log_probs(ref_model, c_ids_t, c_mask_t)
        policy_r = get_response_log_probs(dpo_model, r_ids_t, r_mask_t)
        ref_r = get_response_log_probs(ref_model, r_ids_t, r_mask_t)

        chosen_reward = dpo_beta * (policy_c - ref_c).item()
        rejected_reward = dpo_beta * (policy_r - ref_r).item()
        is_correct = chosen_reward > rejected_reward
        correct += is_correct
        total += 1

        print(f"  '{prompt}'")
        print(f"    Chosen:   {chosen_reward:+.3f}  '{chosen[:35]}...'")
        print(f"    Rejected: {rejected_reward:+.3f}  '{rejected[:35]}...'")
        print(f"    {'CORRECT' if is_correct else 'WRONG'}")
        print()

print(f"Preference accuracy: {correct}/{total} = {correct/total:.1%}")

## Part 4: DPO vs RLHF — Comparison

| | RLHF | DPO |
|---|---|---|
| **Models needed** | 4 (policy, ref, reward, value) | 2 (policy, ref) |
| **Training steps** | Reward model + PPO (two phases) | Single phase |
| **Stability** | PPO is notoriously finicky | Standard supervised loss |
| **Hyperparameters** | KL beta, clip eps, LR, GAE lambda... | Just beta and LR |
| **Reward hacking** | Major risk | Less risk (no explicit reward model to hack) |
| **Code complexity** | ~200+ lines for PPO | ~20 lines for DPO loss |
| **Memory** | 4 models in GPU memory | 2 models |

### When to use which?

**DPO is preferred when:**
- You have preference data available
- You want simplicity and stability
- You're doing offline optimization (not interactive)

**RLHF is preferred when:**
- You need online learning (model generates, humans rate in real-time)
- You want to reuse the reward model for evaluation
- You need to explore diverse responses (PPO samples from the current policy, so it can learn from responses that are not in a fixed dataset)

In practice, many teams now use **DPO or its variants** (IPO, KTO, ORPO) because they are simpler to implement and tune.

In [None]:
# Visual: the effect of beta on DPO
#
# beta controls how conservative the changes are:
# - Small beta → aggressive changes, risk of overfitting to preferences
# - Large beta → conservative, stays close to SFT
#
# beta is the KL-penalty coefficient from the RLHF objective, baked directly into the loss.

print("The role of beta in DPO:")
print()
print("  beta = 0.01 (very low)")
print("    → Aggressive preference optimization")
print("    → Model may overfit to preference patterns")
print("    → Can lose general language quality")
print()
print("  beta = 0.1 (moderate — what we used)")
print("    → Balanced preference learning")
print("    → Good trade-off between preferences and general quality")
print()
print("  beta = 1.0 (high)")
print("    → Very conservative changes")
print("    → Stays very close to the SFT model")
print("    → Preferences are weakly enforced")
print()
print("In the DPO paper: beta=0.1 for summarization, beta=0.5 for dialogue.")
print("The right value depends on your data and how much you trust the preferences.")

## Key Takeaways

**What DPO does:**
- Directly optimizes the language model on preference pairs
- No reward model, no RL — just a clever loss function
- The language model implicitly learns its own reward function

**The math intuition:**
- RLHF objective: maximize reward - KL penalty
- Solve analytically → the optimal policy has a closed form
- Substitute back → DPO loss, which only needs preference pairs
- The reward is implicit: `r(x,y) = beta * log(pi/pi_ref)`, up to a prompt-only term that cancels between chosen and rejected

**The full post-training stack we built:**

```
Pre-training (notebook 04)     → Language capability (autocomplete)
  ↓
SFT (notebook 05)              → Instruction-following format
  ↓
RLHF (notebook 06) or          → Response quality (preference alignment)
DPO  (notebook 07)
```

You've now built every stage of this pipeline from scratch, at toy scale.

### References
- [DPO paper](https://arxiv.org/abs/2305.18290) — the original paper
- [Cameron Wolfe's DPO deep dive](https://cameronrwolfe.substack.com/p/direct-preference-optimization) — best explanation
- [Raschka's DPO notebook](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb)
- [HuggingFace: From RLHF to DPO](https://huggingface.co/blog/ariG23498/rlhf-to-dpo)