# RLHF with a Toy Example

The RLHF / RLFT procedure can be illustrated in a very compact form:

1. **Policy** (language model as policy):
   
   $$
   \pi_\theta(y \mid x) = \prod_{t=1}^T \pi_\theta(y_t \mid x, y_{<t})
   $$

2. **Reward model** assigns a scalar score to the prompt–completion pair:
   
   $$
   R = r_\phi(x,y) \in \mathbb{R}
   $$

3. **Baseline and advantage**  
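   To reduce variance we subtract the batch-mean reward (the baseline) from each sample's reward $R_i$, giving a simple advantage signal (the code below additionally divides by the batch standard deviation):
   
   $$
   \bar{R} = \frac{1}{N}\sum_{i=1}^N R_i,
   \qquad
   \hat{A}_t = R_i - \bar{R} \quad (\text{broadcast to every token } t \text{ of sample } i)
   $$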

4. **Log-probabilities and importance ratio**  
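   We compare the new policy to the old one that generated the samples. In RL notation the state is the context so far, $s_t = (x, y_{<t})$, and the action is the next token, $a_t = y_t$: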
   
   $$
   \ell^{\text{old}}_t = \log \pi_{\text{old}}(a_t \mid s_t),
   \qquad
   \ell_t = \log \pi_\theta(a_t \mid s_t)
   $$
   
   $$
   \rho_t = \exp(\ell_t - \ell^{\text{old}}_t)
   $$

5. **PPO surrogate objective**  
   Encourages improvement while clipping updates to avoid divergence:
   
   $$
   L_{\text{PPO}}(\theta) = \mathbb{E}_t \Big[
   \min\big(\rho_t \hat{A}_t,\;
   \text{clip}(\rho_t, 1-\epsilon, 1+\epsilon)\,\hat{A}_t \big)
   \Big]
   $$

6. **KL-to-reference penalty**  
   Keeps the policy close to the supervised (SFT) reference model:
   
   $$
   \widehat{\mathrm{KL}} = \mathbb{E}_t\big[\ell_t - \ell^{\text{ref}}_t\big],
   \qquad
   \ell^{\text{ref}}_t = \log \pi_{\text{ref}}(a_t \mid s_t)
   $$

7. **Final loss** (to minimize):
   
   $$
   \mathcal{L}(\theta) = - L_{\text{PPO}}(\theta)\;+\;\beta \,\widehat{\mathrm{KL}}
   $$

---

**Summary:**  
- The policy generates completions.  
- The reward model (or preference signal) gives a scalar reward.  
- We normalize rewards to compute a baseline-adjusted advantage.  
- The clipped PPO surrogate keeps policy updates stable.  
- KL regularization prevents the model from drifting too far from the supervised reference.  


In [None]:
# toy_rlft_bigram_ppo.py
import math, random, copy
from dataclasses import dataclass
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0); random.seed(0)

# -----------------------------
# Tiny tokenizer (char-level)
# -----------------------------
# Arithmetic prompts the toy policy is asked to complete.
PROMPTS = [
    "Q: 2+3=\nA:",
    "Q: 1+4=\nA:",
    "Q: 4+4=\nA:",
    "Q: 3+2=\nA:",
    "Q: 5+0=\nA:",
    "Q: 6-1=\nA:",
    "Q: 7-2=\nA:",
    "Q: 9-4=\nA:",
]
# Gold short answer for every prompt; the automatic judge compares sampled
# completions against these when preference data is built.
GOLD = {
    "Q: 2+3=\nA:": "5",
    "Q: 1+4=\nA:": "5",
    "Q: 4+4=\nA:": "8",
    "Q: 3+2=\nA:": "5",
    "Q: 5+0=\nA:": "5",
    "Q: 6-1=\nA:": "5",
    "Q: 7-2=\nA:": "5",
    "Q: 9-4=\nA:": "5",
}

# Character-level vocabulary: the special tokens take the lowest ids and are
# followed by every character seen in the prompts, the gold answers and the
# fixed arithmetic alphabet, in sorted order.
SPECIALS = ["<pad>", "<bos>", "<eos>"]
vocab = sorted(set().union(" 0123456789+-=:Q\nA", *PROMPTS, *GOLD.values()))
itos = [*SPECIALS, *vocab]                  # id -> token string
stoi = dict(zip(itos, range(len(itos))))    # token string -> id
PAD, BOS, EOS = (stoi[tok] for tok in SPECIALS)
V = len(itos)                               # vocabulary size

def encode(s: str):
    """Return the token ids of ``s`` wrapped as ``[BOS, ..., EOS]``.

    Every character is looked up in ``stoi``; a character outside the
    vocabulary raises ``KeyError``.
    """
    body = list(map(stoi.__getitem__, s))
    return [BOS, *body, EOS]
def decode(ids: List[int]):
    """Turn token ids back into text with the BOS/EOS markers removed.

    Negative ids are skipped.  Only the ``<bos>`` and ``<eos>`` marker strings
    are stripped from the joined text; any other special token is kept as-is.
    """
    pieces = [itos[i] for i in ids if i >= 0]
    text = "".join(pieces)
    for marker in ("<bos>", "<eos>"):
        text = text.replace(marker, "")
    return text

# -----------------------------
# Bigram LM (Karpathy-style)
# -----------------------------
class BigramLM(nn.Module):
    """Bigram language model: token embedding followed by a bias-free linear head.

    Both layers are applied independently at every position, so the logits at
    a position depend only on the token at that position.
    """

    def __init__(self, vocab_size: int, d_model: int = 64):
        super().__init__()
        # Embedding first, head second: the creation order determines how the
        # RNG is consumed during parameter initialisation.
        self.tok = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        """Map token ids ``idx`` of shape (B, T) to logits of shape (B, T, V)."""
        hidden = self.tok(idx)        # (B, T, d)
        return self.head(hidden)      # (B, T, V)

    def logprob_of(self, idx_in, idx_out):
        """Teacher-forced log-probabilities of the targets ``idx_out``.

        ``idx_in`` and ``idx_out`` are both (B, T); entry [b, t] of the result
        is the log-probability assigned to ``idx_out[b, t]`` by the
        distribution predicted at position t of ``idx_in[b]``.
        """
        log_dist = F.log_softmax(self.forward(idx_in), dim=-1)
        # Select the log-probability of each target token along the vocab axis.
        picked = torch.gather(log_dist, -1, idx_out[..., None])
        return picked[..., 0]         # (B, T)

# -----------------------------
# Tiny Reward Model r_phi(x,y)
# Encoder = mean of token embeddings -> linear scalar
# -----------------------------
class RewardModel(nn.Module):
    """Scalar scorer: mean-pooled token embeddings fed through a linear layer."""

    def __init__(self, vocab_size: int, d: int = 64):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d)
        self.head = nn.Linear(d, 1)

    def forward(self, xy):
        """Score token-id sequences ``xy`` of shape (B, T); returns shape (B,).

        Every position contributes equally to the mean, including any padding
        ids that are present in ``xy``.
        """
        pooled = torch.mean(self.emb(xy), dim=1)   # (B, T, d) -> (B, d)
        score = self.head(pooled)                  # (B, 1)
        return score.squeeze(-1)

# -----------------------------
# Utilities
# -----------------------------
def pad_to_batch(seqs:List[List[int]], pad_id:int=PAD):
    """Right-pad id sequences to a common length and stack them.

    Args:
        seqs: token-id sequences, possibly of different lengths.
        pad_id: id appended to the shorter sequences.

    Returns:
        A ``torch.long`` tensor of shape ``(len(seqs), T)`` where ``T`` is the
        length of the longest sequence.  An empty ``seqs`` yields an empty
        ``(0, 0)`` tensor instead of raising.
    """
    # default=0 keeps max() from raising on an empty batch.
    width = max((len(s) for s in seqs), default=0)
    rows = [list(s) + [pad_id] * (width - len(s)) for s in seqs]
    # The reshape only matters for the empty batch, where torch.tensor([])
    # would otherwise be one-dimensional.
    return torch.tensor(rows, dtype=torch.long).reshape(len(seqs), width)

@torch.no_grad()
def sample_completion(model: BigramLM, prompt_ids: List[int], max_new: int = 10, temperature: float = 1.0):
    """Sample up to ``max_new`` tokens after ``prompt_ids``.

    The model is switched to eval mode (and left there).  At every step the
    logits of the last position are divided by ``temperature`` (floored at
    1e-6) and one token is drawn from the resulting softmax.  Sampling stops
    early once EOS has been drawn.

    Returns:
        A new list holding the prompt ids followed by the sampled ids; the
        caller's ``prompt_ids`` is not modified.
    """
    model.eval()
    temp = max(1e-6, temperature)
    seq = prompt_ids[:]  # work on a copy of the prompt (BOS ... tokens)
    for _ in range(max_new):
        context = torch.tensor(seq, dtype=torch.long).unsqueeze(0)  # (1, t)
        last_logits = model(context)[:, -1, :]
        dist = F.softmax(last_logits / temp, dim=-1)
        token = torch.multinomial(dist, num_samples=1).item()
        seq.append(token)
        if token == EOS:
            break
    return seq

# -----------------------------
# SFT stage (optional but recommended)
# Train policy to imitate a few gold answers
# -----------------------------
def sft_train(policy:BigramLM, pairs:List[Tuple[str,str]], iters=200, lr=1e-2):
    """Supervised fine-tuning: maximise the likelihood of prompt+answer sequences.

    Each ``(x, y)`` pair becomes the sequence ``encode(x) + encode(y)[1:]``
    (the answer's BOS is dropped, so the prompt's closing EOS sits between
    prompt and answer).  The model is trained with next-token prediction over
    the whole sequence using AdamW on the full batch at every iteration.

    Args:
        policy: model to train in place (switched to train mode).
        pairs: ``(prompt, answer)`` strings.
        iters: number of full-batch optimisation steps.
        lr: AdamW learning rate.

    Prints the loss every 100 iterations.  Does nothing when ``pairs`` is
    empty.
    """
    opt = torch.optim.AdamW(policy.parameters(), lr=lr)
    policy.train()
    if not pairs:
        return  # nothing to fit

    # The batch is identical at every iteration, so build it once.
    inputs, targets = [], []
    for x, y in pairs:
        xy = encode(x) + encode(y)[1:]  # share one BOS
        inputs.append(xy[:-1])          # model input: every token but the last
        targets.append(xy[1:])          # target: the following token
    xb = pad_to_batch(inputs)
    yb = pad_to_batch(targets)
    is_pad = yb == PAD
    # Average over real target tokens only; a plain .mean() would also count
    # the padded positions and shrink the loss for batches of mixed lengths.
    n_real = (~is_pad).sum().clamp(min=1)

    for it in range(iters):
        logp = policy.logprob_of(xb, yb)
        loss = -logp.masked_fill(is_pad, 0).sum() / n_real
        opt.zero_grad(); loss.backward(); opt.step()
        if (it+1)%100==0:
            print(f"[SFT] iter {it+1} loss={loss.item():.3f}")

# -----------------------------
# Build preference data to train Reward Model
# Use an AUTOMATIC judge: score 1 if the completion's first visible character matches GOLD, else 0.
# You can flip to RANDOM to illustrate noisy preferences.
# -----------------------------
def build_pref_dataset(policy:BigramLM, prompts:List[str], per_prompt:int=4,
                       judge="AUTO") -> List[Tuple[List[int], List[int], int]]:
    """Sample completions per prompt and label every pair with a preference.

    For each prompt, ``per_prompt`` completions are sampled (max_new=3,
    temperature=1.2) and every unordered pair ``(i, j)`` with ``i < j`` is
    judged:

    * ``judge == "RANDOM"``: the preference is a coin flip.
    * otherwise (AUTO): a completion scores 1 when its stripped text starts
      with the gold answer ``GOLD[x]``, else 0; the higher score wins.  On
      equal scores the first candidate wins unless its text is shorter.

    Returns:
        A flat list with TWO CONSECUTIVE entries per judged pair:
        ``(prompt_ids, ids_i, pref)`` followed by ``(prompt_ids, ids_j, 1 - pref)``,
        where ``ids_*`` are the full prompt+completion ids and ``pref`` is 1
        when candidate ``i`` is preferred.
    """
    data = []
    for x in prompts:
        px = encode(x)
        cands = [sample_completion(policy, px, max_new=3, temperature=1.2)
                 for _ in range(per_prompt)]
        if len(cands) < 2:
            continue  # no pair to compare for this prompt

        if judge != "RANDOM":
            gold = GOLD[x]
            # Decode and score every candidate once instead of once per pair.
            texts = [decode(c)[len(x):].strip() for c in cands]
            # startswith() equals the first-character test for one-character
            # gold answers and also works for longer or non-numeric ones.
            hits = [int(t.startswith(gold)) for t in texts]

        for i in range(len(cands)):
            for j in range(i+1, len(cands)):
                if judge == "RANDOM":
                    pref = random.choice([0,1])   # 1 if candidate i preferred, else 0
                elif hits[i] == hits[j]:          # tie-break on text length
                    pref = 1 if len(texts[i]) >= len(texts[j]) else 0
                else:
                    pref = 1 if hits[i] > hits[j] else 0
                # The two sides of one comparison are stored back to back.
                data.append((px, cands[i], pref))
                data.append((px, cands[j], 1-pref))
    return data

def train_reward_model(rm:RewardModel, pref_data, iters=400, lr=1e-3):
    """Fit the reward model on pairwise preferences with a Bradley-Terry loss.

    Args:
        rm: reward model mapping a (B, T) id tensor to (B,) scores; trained
            in place with AdamW.
        pref_data: flat list of ``(prompt_ids, full_ids, pref)`` entries in
            which every two consecutive entries are the two sides of one
            judged comparison (same prompt, complementary ``pref`` flags) --
            the layout produced by ``build_pref_dataset``.
        iters: number of optimisation steps; each uses up to 16 random pairs.
        lr: AdamW learning rate.

    Raises:
        ValueError: if ``pref_data`` does not follow the paired layout.

    Prints the loss every 100 iterations.  Does nothing when ``pref_data``
    is empty.
    """
    if len(pref_data) % 2:
        raise ValueError("pref_data must hold two consecutive entries per comparison")

    def as_row(ids):
        return torch.tensor([ids], dtype=torch.long)  # (1, T)

    # Recover (winner, loser) from the two entries of the same comparison.
    # Pairing an entry with a random opponent would reuse a flag that was
    # decided against a different completion and can produce contradictory
    # (or self-paired) training targets.
    pairs = []
    for (p_a, y_a, flag_a), (p_b, y_b, flag_b) in zip(pref_data[0::2], pref_data[1::2]):
        if p_a != p_b or flag_a + flag_b != 1:
            raise ValueError("pref_data must hold two consecutive entries per comparison")
        winner, loser = (y_a, y_b) if flag_a == 1 else (y_b, y_a)
        pairs.append((as_row(winner), as_row(loser)))
    if not pairs:
        return  # nothing to train on

    opt = torch.optim.AdamW(rm.parameters(), lr=lr)
    for it in range(iters):
        batch = random.sample(pairs, k=min(16, len(pairs)))
        # Completions differ in length, so (1, T) rows cannot be concatenated
        # into one batch; scoring them one by one also keeps pad tokens out
        # of the reward model's mean over positions.
        sw = torch.cat([rm(w) for w, _ in batch])
        sl = torch.cat([rm(l) for _, l in batch])
        # Bradley-Terry: maximise sigmoid(score_winner - score_loser).
        loss = F.binary_cross_entropy_with_logits(sw - sl, torch.ones_like(sw))
        opt.zero_grad(); loss.backward(); opt.step()
        if (it+1)%100==0:
            print(f"[RM ] iter {it+1} loss={loss.item():.3f}")

# -----------------------------
# PPO fine-tuning against r_phi
# critic-free PPO with batch baseline + KL to reference
# -----------------------------
@dataclass
class PPOCfg:
    """Hyper-parameters for the critic-free PPO loop."""

    # --- rollout ---
    max_new: int = 3            # maximum number of tokens sampled per completion
    batch_size: int = 16
    # --- optimisation schedule ---
    epochs: int = 50            # outer iterations (one rollout each)
    ppo_steps: int = 4          # gradient steps taken on each rollout
    lr: float = 5e-3            # AdamW learning rate
    # --- objective ---
    eps: float = 0.2            # PPO clip range
    beta: float = 0.05          # KL-to-reference strength
    temperature: float = 1.0    # sampling temperature for rollouts

def rollout(policy:BigramLM, prompts:List[str], max_new:int, temperature:float):
    """Sample one completion per prompt and record the sampling log-probs.

    The prompts are visited in a random order.  For each one, up to
    ``max_new`` tokens are drawn from the softmax of the last-position logits
    divided by ``temperature`` (floored at 1e-6); sampling stops after EOS.
    ``policy`` is switched to eval mode and no gradients are tracked.

    Returns:
        A list of ``(prompt_ids, full_ids, actions, old_logprobs)`` tuples
        where ``full_ids == prompt_ids + actions`` and ``old_logprobs[k]`` is
        the log-probability of ``actions[k]`` under the temperature-scaled
        distribution it was sampled from.
    """
    policy.eval()
    temp = max(1e-6, temperature)
    records = []
    with torch.no_grad():
        for x in random.sample(prompts, len(prompts)):
            px = encode(x)
            ids = px[:]
            actions, old_lps = [], []
            for _ in range(max_new):
                x_in = torch.tensor(ids, dtype=torch.long).unsqueeze(0)
                logits = policy(x_in)[:, -1, :] / temp
                a = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1).item()
                # Exact log-probability via log_softmax: log(p + 1e-9) is
                # biased and bottoms out at log(1e-9) for very unlikely tokens.
                lp = F.log_softmax(logits, dim=-1)[0, a].item()
                actions.append(a); old_lps.append(lp)
                ids.append(a)
                if a==EOS: break
            records.append((px, ids, actions, old_lps))
    return records

def _action_logprobs(model, ids, n_prompt, temperature):
    """Return the log-probs of the sampled tokens ``ids[n_prompt:]`` under ``model``.

    One teacher-forced pass over ``ids[:-1]`` replaces a separate pass per
    growing prefix.  The two are equivalent as long as the logits at position
    t depend only on tokens up to t, which holds for BigramLM (embedding and
    linear head are applied per position).  Logits are divided by
    ``temperature`` (floored at 1e-6), the same scaling ``rollout`` applies
    when it samples.
    """
    context = torch.tensor(ids[:-1], dtype=torch.long).unsqueeze(0)
    # Row k (counted from the last prompt position) predicts sampled token k.
    logits = model(context)[0, n_prompt - 1:, :] / max(1e-6, temperature)
    taken = torch.tensor(ids[n_prompt:], dtype=torch.long).unsqueeze(-1)
    return F.log_softmax(logits, dim=-1).gather(-1, taken).squeeze(-1)


def ppo_train(policy:BigramLM, ref:BigramLM, rm:RewardModel, cfg:PPOCfg):
    """Critic-free PPO against the reward model with a KL penalty to ``ref``.

    Every epoch:
      1. sample one completion per prompt in ``PROMPTS`` from a frozen copy
         of the current policy;
      2. score each prompt+completion with ``rm`` and standardise the rewards
         over the batch (mean/std) to obtain one advantage per sequence,
         broadcast to all of its tokens;
      3. take ``cfg.ppo_steps`` AdamW steps on
         ``-L_clip + cfg.beta * KL`` where ``L_clip`` is the clipped PPO
         surrogate and ``KL`` the mean of ``log pi - log pi_ref`` over the
         sampled tokens.

    Args:
        policy: model updated in place.
        ref: frozen reference model for the KL penalty (never updated).
        rm: reward model mapping a (B, T) id tensor to (B,) scores.
        cfg: hyper-parameters.

    Prints one summary line per epoch.
    """
    opt = torch.optim.AdamW(policy.parameters(), lr=cfg.lr)
    for epoch in range(cfg.epochs):
        # freeze a snapshot as old policy for this epoch
        old = copy.deepcopy(policy).eval()
        recs = rollout(old, PROMPTS, cfg.max_new, cfg.temperature)
        if not recs:
            return  # no prompts, nothing to optimise

        with torch.no_grad():
            # Rollouts differ in length (early EOS), so (1, T) rows cannot be
            # concatenated; score each sequence on its own.
            Rs = [rm(torch.tensor([ids], dtype=torch.long)).item()
                  for _, ids, _, _ in recs]
            # The reference never changes: evaluate it once per rollout and
            # without autograd instead of once per PPO step.
            ref_lps_all = [_action_logprobs(ref, ids, len(px), cfg.temperature)
                           for px, ids, _, _ in recs]

        # normalize rewards (baseline)
        meanR = sum(Rs)/len(Rs)
        stdR = max(1e-6, (sum((r-meanR)**2 for r in Rs)/len(Rs))**0.5)

        batch = [
            (ids, len(px), torch.tensor(old_lps), ref_lps, (R - meanR)/stdR)
            for (px, ids, _, old_lps), R, ref_lps in zip(recs, Rs, ref_lps_all)
        ]

        summary = None
        # multiple PPO steps over the same batch (toy)
        for _ in range(cfg.ppo_steps):
            L_clip, KL = 0.0, 0.0
            n_tok = 0
            for ids, n_prompt, old_lps_t, ref_lps, A in batch:
                # The old log-probs were recorded at cfg.temperature, so the
                # new ones must use it too or the ratio is biased even
                # before the first update.
                new_lps = _action_logprobs(policy, ids, n_prompt, cfg.temperature)

                ratio = torch.exp(new_lps - old_lps_t)
                unclipped = ratio * A
                clipped   = torch.clamp(ratio, 1-cfg.eps, 1+cfg.eps) * A
                L_clip += torch.minimum(unclipped, clipped).sum()

                KL += (new_lps - ref_lps).sum()  # MC estimate of KL on sampled actions
                n_tok += new_lps.numel()

            # normalize by token count
            L_clip = L_clip / max(1, n_tok)
            KL = KL / max(1, n_tok)
            loss = -L_clip + cfg.beta * KL
            opt.zero_grad(); loss.backward(); opt.step()
            summary = (loss.item(), L_clip.item(), KL.item())

        if summary is None:
            # cfg.ppo_steps == 0: there is no loss to report.
            print(f"[PPO] epoch {epoch+1:03d} meanR={meanR:.3f} (no update)")
        else:
            print(f"[PPO] epoch {epoch+1:03d} loss={summary[0]:.3f} L_clip={summary[1]:.3f} KL={summary[2]:.3f} meanR={meanR:.3f}")

# -----------------------------
# Run the pipeline
# -----------------------------
if __name__ == "__main__":
    # 1) init policy and reward model
    policy = BigramLM(V, d_model=64)
    rm     = RewardModel(V, d=64)

    # 2) (optional) SFT on a few gold exemplars
    sft_pairs = [(x, GOLD[x]) for x in PROMPTS]
    sft_train(policy, sft_pairs, iters=200, lr=1e-2)
    # The reference anchor is the SFT model, so it is copied only here; a copy
    # taken before SFT would be discarded unused.
    ref = copy.deepcopy(policy).eval()  # freeze SFT as reference

    # 3) Build preference data using AUTO judge (or "RANDOM")
    pref_data = build_pref_dataset(policy, PROMPTS, per_prompt=4, judge="AUTO")

    # 4) Train reward model from preferences (Bradley–Terry)
    train_reward_model(rm, pref_data, iters=300, lr=1e-3)

    # 5) PPO fine-tune policy against r_phi with KL-to-ref
    cfg = PPOCfg(epochs=30, ppo_steps=3, beta=0.05, eps=0.2, lr=5e-3, max_new=3, batch_size=16, temperature=1.2)
    ppo_train(policy, ref, rm, cfg)

    # 6) Inspect generations: print each prompt followed by the sampled answer
    for x in PROMPTS[:4]:
        ids = sample_completion(policy, encode(x), max_new=3, temperature=0.8)
        print(x, decode(ids)[len(x):].strip())
