In [None]:
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shang-vikas/series1-coding-exercises/blob/main/exercises/blog-09/exercise-03.ipynb)

# REINFORCE with Reward Shaping Demo

This notebook demonstrates:
- A small pretrained causal LM (distilgpt2) as policy
- A sentiment classifier pipeline as reward oracle
- REINFORCE-style policy updates (educational)
- A clear reward-hacking demonstration (unshaped reward)
- Safety reward shaping (repetition / profanity / short-response penalties)
- Side-by-side visualizations and sample galleries

**Runtime note**: This downloads models (~100–300MB). Use a GPU runtime in Colab for speed (optional but recommended). Keep iteration counts small (defaults are conservative).

In [None]:
# run in Colab cell
%pip install -q transformers torch datasets sentencepiece matplotlib

# Python cell
import random
import string
import time
import warnings
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# reproducible-ish: push the same seed into every RNG the notebook uses
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    # additionally seed every visible CUDA device
    torch.cuda.manual_seed_all(SEED)

# run on the GPU when one is visible, otherwise on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

In [None]:
policy_name = "distilgpt2"  # small, fast for demo
print("Loading policy:", policy_name)
policy_tok = AutoTokenizer.from_pretrained(policy_name)
# ensure pad token exists: fall back to the EOS token when the tokenizer defines none
if policy_tok.pad_token is None:
    policy_tok.pad_token = policy_tok.eos_token

# the policy is the causal LM whose weights REINFORCE will update
policy = AutoModelForCausalLM.from_pretrained(policy_name)
policy = policy.to(device)
policy.train()  # we will update it via REINFORCE

# reward oracle: sentiment classifier (SST-like)
print("Loading sentiment pipeline (reward oracle)...")
# the pipeline is given device index 0 on CUDA and -1 otherwise
_pipeline_device = 0 if device == "cuda" else -1
sentiment = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=_pipeline_device,
)

# optimizer for policy (small LR for safety)
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-5)
print("Models loaded.")

In [None]:
# ---------- Sampling helpers ----------
def sample_with_logprobs_train(prompt_ids, max_new_tokens=20, temperature=1.0, top_k=None):
    """
    Sample step-by-step while retaining gradient on log-probs (for REINFORCE).

    The log-probabilities are taken under the distribution that is actually
    sampled from, i.e. after temperature scaling and optional top-k filtering.

    Args:
      prompt_ids: tensor of token ids, shape (1, prompt_len); moved to `device`.
      max_new_tokens: number of tokens to sample. Generation does not stop early.
      temperature: divisor applied to the logits (clamped to at least 1e-9).
      top_k: if a positive int, sample only from the k highest-scoring tokens.
        Values larger than the vocabulary are clamped; None or <= 0 disables
        the filter.

    Returns:
      generated_ids (tensor 1 x L, prompt followed by the sampled tokens),
      logprob_sum (scalar tensor; sum of the sampled tokens' log-probs).
      logprob_sum is always a tensor; with max_new_tokens == 0 it is a zero
      tensor that carries no gradient.
    """
    generated = prompt_ids.to(device)
    # Start from a tensor so the return type does not depend on max_new_tokens.
    logprob_sum = torch.zeros((), device=generated.device)
    temp = max(temperature, 1e-9)
    for _ in range(max_new_tokens):
        logits = policy(generated).logits  # (1, seq_len, vocab), with grad
        next_logits = logits[:, -1, :] / temp
        if top_k is not None and top_k > 0:
            k = min(int(top_k), next_logits.size(-1))  # torch.topk rejects k > vocab
            top_vals, top_idx = torch.topk(next_logits, k, dim=-1)
            # Keep the top-k logits per row and give everything else zero probability.
            next_logits = torch.full_like(next_logits, float("-inf")).scatter(-1, top_idx, top_vals)
        # log_softmax avoids the underflow of log(softmax(x)) for unlikely tokens.
        log_probs = F.log_softmax(next_logits, dim=-1)  # (1, V)
        # Sampling itself is not differentiable, so draw from a detached copy.
        next_token = torch.multinomial(log_probs.detach().exp(), num_samples=1)  # (1, 1)
        # The gathered log-prob stays attached to the graph (this retains grad).
        logprob_sum = logprob_sum + log_probs.gather(-1, next_token).squeeze()
        generated = torch.cat([generated, next_token], dim=1)
    return generated, logprob_sum

def sample_for_eval(prompt, max_new_tokens=20, temperature=1.0, top_k=None):
    """
    Sampling for evaluation (no gradient, faster). Returns decoded string.

    The policy is switched to eval mode for the duration of the call and then
    put back into whatever mode (train or eval) it was in before, so calling
    this helper does not change the mode later cells run in.

    Args:
      prompt: text to condition on; tokenized with `policy_tok`.
      max_new_tokens: number of tokens to sample. Generation does not stop early.
      temperature: divisor applied to the logits (clamped to at least 1e-9).
      top_k: if a positive int, sample only from the k highest-scoring tokens.
        Values larger than the vocabulary are clamped; None or <= 0 disables
        the filter.

    Returns:
      The decoded prompt plus continuation, with special tokens skipped.
    """
    was_training = policy.training
    policy.eval()
    try:
        generated = policy_tok(prompt, return_tensors="pt").input_ids.to(device)
        temp = max(temperature, 1e-9)
        with torch.no_grad():
            for _ in range(max_new_tokens):
                next_logits = policy(generated).logits[:, -1, :] / temp
                if top_k is not None and top_k > 0:
                    k = min(int(top_k), next_logits.size(-1))  # torch.topk rejects k > vocab
                    top_vals, top_idx = torch.topk(next_logits, k, dim=-1)
                    # Keep the top-k logits per row and give everything else zero probability.
                    next_logits = torch.full_like(next_logits, float("-inf")).scatter(-1, top_idx, top_vals)
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated = torch.cat([generated, next_token], dim=1)
    finally:
        # Restore the caller's mode even if tokenization or generation raised.
        policy.train(was_training)
    return policy_tok.decode(generated[0].tolist(), skip_special_tokens=True)

# ---------- Reward oracle ----------
def sentiment_reward(text):
    """
    Return sentiment reward in [-1, +1] (positive -> positive reward).

    The first 512 characters of `text` are passed to the `sentiment` pipeline.
    The classifier score is returned as-is when the label starts with
    "POSITIVE" (case-insensitive) and negated otherwise.

    If the pipeline raises, a RuntimeWarning is emitted and the neutral reward
    0.0 is returned, so a failing oracle stays visible instead of silently
    flattening the training signal.
    """
    try:
        out = sentiment(text[:512])  # limit input size for the classifier
    except Exception as exc:
        # Best effort: keep the loop alive, but make the failure visible.
        warnings.warn(
            f"sentiment pipeline failed ({exc!r}); using neutral reward 0.0",
            RuntimeWarning,
            stacklevel=2,
        )
        return 0.0
    top = out[0]
    score = float(top["score"])
    return score if top["label"].upper().startswith("POSITIVE") else -score

# ---------- Simple penalties for shaping ----------
# Toy example blacklist; profanity_penalty intersects lowercased words with this set.
BAD_WORDS = set(("idiot", "stupid", "damn"))

def repetition_penalty(text, ngram=3):
    """
    Score how repetitive a whitespace-tokenized text is (0.0 = no repetition).

    Two components are added:
      - 0.5 * length of the longest run of identical consecutive tokens,
        counted only when that run is longer than one token (a text with no
        immediate repeats contributes 0.0 here);
      - 1.0 for every distinct n-gram of size `ngram` that occurs more than once.

    Args:
      text: text to score; split on whitespace.
      ngram: n-gram size for the repeated-phrase count (must be >= 1).

    Returns:
      The penalty as a float. Texts with fewer than two tokens score 0.0.

    Raises:
      ValueError: if `ngram` is smaller than 1.
    """
    if ngram < 1:
        raise ValueError(f"ngram must be >= 1, got {ngram}")
    toks = text.split()
    if len(toks) < 2:
        return 0.0

    # longest identical-token run
    longest = current = 1
    for prev, tok in zip(toks, toks[1:]):
        current = current + 1 if tok == prev else 1
        longest = max(longest, current)
    # A run of length 1 means "no repeat", so it must not be penalized.
    run_penalty = 0.5 * longest if longest > 1 else 0.0

    # repeated ngram count
    ngram_counts = {}
    for start in range(len(toks) - ngram + 1):
        gram = tuple(toks[start:start + ngram])
        ngram_counts[gram] = ngram_counts.get(gram, 0) + 1
    repeated = sum(1 for count in ngram_counts.values() if count > 1)

    return run_penalty + 1.0 * repeated

def profanity_penalty(text):
    """
    Count how many distinct BAD_WORDS entries appear in `text`.

    Words are split on whitespace, stripped of any leading/trailing ASCII
    punctuation (so quoted or parenthesised words are still caught), and
    lowercased before being matched against BAD_WORDS. Each blacklisted word
    counts once no matter how often it occurs.

    Returns:
      The number of distinct hits as a float.
    """
    words = {w.strip(string.punctuation).lower() for w in text.split()}
    return float(len(words & BAD_WORDS))

def length_penalty(text, min_len=5):
    """
    Penalize texts with fewer than `min_len` whitespace-separated tokens.

    Returns 0.5 per missing token, or 0.0 once the text is long enough.
    """
    missing = min_len - len(text.split())
    if missing <= 0:
        return 0.0
    return missing * 0.5

def shaped_reward(text, lambda_rep=0.7, lambda_prof=1.0, lambda_short=0.6):
    """
    Sentiment reward minus weighted shaping penalties.

    Starts from sentiment_reward(text) and subtracts, in this order,
    lambda_rep * repetition_penalty, lambda_prof * profanity_penalty and
    lambda_short * length_penalty of the same text.
    """
    total = sentiment_reward(text)
    total = total - lambda_rep * repetition_penalty(text)
    total = total - lambda_prof * profanity_penalty(text)
    total = total - lambda_short * length_penalty(text)
    return total

In [None]:
# Running baseline for variance reduction (exponential moving average of rewards)
running_baseline = 0.0
alpha_baseline = 0.02  # small smoothing factor

def reinforce_update(prompt, shaped=False, max_new_tokens=12, temperature=0.8, top_k=None, scale=1.0):
    """
    Single REINFORCE update step:
     - sample tokens retaining logprobs
     - compute scalar reward (shaped or base) on the decoded text, which
       includes the prompt, and multiply it by `scale`
     - fold the reward into the global running baseline
     - compute loss = - (reward - baseline) * logprob_sum
     - backprop, clip the gradient norm to 1.0 & optimizer.step()

    The baseline is updated with the current reward before the advantage is
    taken, so the advantage equals (1 - alpha_baseline) * (reward - previous
    baseline).

    If the sampler produced no differentiable log-prob (e.g. max_new_tokens=0),
    the optimizer step is skipped; the baseline is still updated.

    Returns:
      (generated_text, reward, advantage) -- the decoded sample, the scaled
      reward as a float, and the advantage used to weight the log-prob.
    """
    global running_baseline
    policy.train()
    inp = policy_tok(prompt, return_tensors="pt").input_ids.to(device)
    generated_ids, logprob_sum = sample_with_logprobs_train(
        inp, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k
    )
    generated_text = policy_tok.decode(generated_ids[0].tolist(), skip_special_tokens=True)

    # compute reward (outside gradient)
    reward_fn = shaped_reward if shaped else sentiment_reward
    r = float(reward_fn(generated_text)) * scale

    # update running baseline
    running_baseline = (1 - alpha_baseline) * running_baseline + alpha_baseline * r
    advantage = r - running_baseline

    # Nothing to backpropagate when no token was sampled: skip the step instead of crashing.
    if torch.is_tensor(logprob_sum) and logprob_sum.requires_grad:
        # loss is negative advantage times logprob sum (note logprob_sum is a tensor)
        loss = -advantage * logprob_sum
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
        optimizer.step()
    return generated_text, r, advantage

# visualization helpers
def plot_rewards(history_unshaped, history_shaped):
    """Draw both per-step reward histories on one set of axes and show the figure."""
    _, ax = plt.subplots(figsize=(10, 4))
    curves = (
        (history_unshaped, "unshaped (sentiment only)"),
        (history_shaped, "shaped (penalties applied)"),
    )
    for series, series_label in curves:
        ax.plot(series, label=series_label)
    ax.set_xlabel("step")
    ax.set_ylabel("reward")
    ax.legend()
    ax.set_title("Reward over steps")
    plt.show()

def show_samples(samples, title):
    """
    Print a titled, numbered list of samples followed by a blank line.

    Each sample is shown on one line (newlines become spaces) and cut to its
    first 300 characters.
    """
    print("----", title, "----")
    for number, sample in enumerate(samples, start=1):
        one_line = sample.replace("\n", " ")
        print(f"{number}.", one_line[:300])
    print()

In [None]:
# Prompts shared by every sampling and training loop in the notebook.
seed_prompts = ["The movie was", "I felt that the film", "The lead actor performance"]

# One untrained sample per prompt, together with the reward the oracle assigns it.
print("=== Baseline samples (no policy updates yet) ===")
divider = "-" * 60
for prompt_text in seed_prompts:
    completion = sample_for_eval(prompt_text, max_new_tokens=12, temperature=0.8)
    print("PROMPT:", prompt_text)
    print("GENERATED:", completion)
    print("REWARD (sentiment):", sentiment_reward(completion))
    print(divider)

In [None]:
# Save initial weights to compare later
initial_state = deepcopy(policy.state_dict())

print("Starting reward-hacking phase (unshaped sentiment reward).")
steps = 40  # small, quick demo — increase to observe stronger hacking
history_unshaped, samples_unshaped = [], []

start_time = time.time()
for step in range(steps):
    txt, r, adv = reinforce_update(
        random.choice(seed_prompts), shaped=False, max_new_tokens=10, temperature=0.8
    )
    history_unshaped.append(r)
    # keep every 5th sample for the gallery, log every 10th step
    if not step % 5:
        samples_unshaped.append(txt)
    if not step % 10:
        print(f"[{step:02d}] reward {r:.3f} adv {adv:.3f} sample: {txt[:120]}")
elapsed = time.time() - start_time
print("Hacking phase done in {:.1f}s".format(elapsed))

# show some samples
show_samples(samples_unshaped, "Samples from hacked policy (periodic)")

In [None]:
# repetition scores for collected periodic samples
rep_scores = list(map(repetition_penalty, samples_unshaped))

# reward trajectory of the unshaped run
_, ax_reward = plt.subplots(figsize=(10, 3))
ax_reward.plot(history_unshaped, label="unshaped reward")
ax_reward.set_title("Unshaped Reward over steps")
ax_reward.set_xlabel("step")
ax_reward.set_ylabel("sentiment reward")
plt.show()

# one bar per periodic sample
_, ax_rep = plt.subplots(figsize=(6, 3))
ax_rep.bar(range(len(rep_scores)), rep_scores)
ax_rep.set_title("Repetition penalty for periodic hacked samples")
ax_rep.set_xlabel("sample idx")
ax_rep.set_ylabel("rep penalty")
plt.show()


# Teaching note (brief): at this stage you will often see the policy produce short, 
# highly positive repeated phrases (e.g., "great great great"). That's exactly 
# reward-hacking: the policy finds easy high-reward patterns the oracle values, 
# even if they're undesirable.

In [None]:
# restore initial weights for fair comparison
policy.load_state_dict(initial_state)
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-5)  # re-init optimizer
# reinforce_update keeps its reward baseline in this global; reset it so the
# shaped run starts from the same baseline (0.0) as the unshaped run did
# instead of inheriting the moving average of the unshaped rewards.
running_baseline = 0.0

print("Running shaped REINFORCE (penalties applied).")
history_shaped = []
samples_shaped = []

start_time = time.time()
for step in range(steps):  # same budget as the unshaped phase
    prompt = random.choice(seed_prompts)
    txt, r, adv = reinforce_update(prompt, shaped=True, max_new_tokens=10, temperature=0.8)
    history_shaped.append(r)
    if step % 5 == 0:
        samples_shaped.append(txt)
    if step % 10 == 0:
        print(f"[{step:02d}] shaped_reward {r:.3f} adv {adv:.3f} sample: {txt[:120]}")
print("Shaped phase done in {:.1f}s".format(time.time()-start_time))

show_samples(samples_shaped, "Samples from shaped policy (periodic)")

In [None]:
# plotting rewards side-by-side
plot_rewards(history_unshaped, history_shaped)

# histograms of repetition penalty across samples (use more samples via evaluation)
def collect_samples(policy_mode, n=50, state_dict=None):
    """
    Draw `n` evaluation samples from the global policy, each for a random seed prompt.

    Args:
      policy_mode: descriptive label for the caller only; it does not select
        any weights.
      n: number of samples to draw.
      state_dict: optional weights to sample from. When given, they are loaded
        into the policy for the duration of the call and the weights that were
        active before are put back afterwards. When None, the policy is
        sampled as it currently is.

    Returns:
      List of `n` decoded strings.
    """
    saved_state = None
    if state_dict is not None:
        saved_state = deepcopy(policy.state_dict())
        policy.load_state_dict(state_dict)
    try:
        return [
            sample_for_eval(random.choice(seed_prompts), max_new_tokens=12, temperature=0.8)
            for _ in range(n)
        ]
    finally:
        if saved_state is not None:
            policy.load_state_dict(saved_state)

# At this point the policy holds the shaped-trained weights. Sampling it twice
# would compare it with itself, so the "before" set is drawn from the saved
# initial weights and the "after" set from the current (shaped-trained) policy.
samples_before = collect_samples("initial", n=50, state_dict=initial_state)
rep_before = [repetition_penalty(s) for s in samples_before]
samples_after = collect_samples("shaped", n=50)
rep_after = [repetition_penalty(s) for s in samples_after]

plt.figure(figsize=(10,4))
plt.hist(rep_before, bins=10, alpha=0.6, label="initial policy (eval)")
plt.hist(rep_after, bins=10, alpha=0.6, label="shaped (eval)")
plt.xlabel("rep penalty")
plt.ylabel("count")
plt.legend()
plt.title("Repetition penalty distribution (eval samples)")
plt.show()

print("Avg rep (initial policy eval):", np.mean(rep_before))
print("Avg rep (shaped eval):", np.mean(rep_after))

In [None]:
# For fair side-by-side, train two fresh runs from the same initial weights
# (one on the raw sentiment reward, one on the shaped reward), then sample both.
def _train_from_initial(reward_fn, n_steps=25, max_new_tokens=10, lr=1e-5):
    """
    Reset the global policy to `initial_state`, run `n_steps` plain REINFORCE
    updates (no baseline) against `reward_fn`, and return a deep copy of the
    trained policy.

    sample_with_logprobs_train builds its graph on the global `policy`, so the
    optimizer has to own the parameters of that same object; an optimizer
    built on a separate copy would never receive gradients and the copy would
    stay identical to the initial model.
    """
    policy.load_state_dict(initial_state)  # fresh start
    policy.train()
    opt = torch.optim.AdamW(policy.parameters(), lr=lr)
    for _ in range(n_steps):
        prompt = random.choice(seed_prompts)
        prompt_ids = policy_tok(prompt, return_tensors="pt").input_ids.to(device)
        generated_ids, logprob_sum = sample_with_logprobs_train(prompt_ids, max_new_tokens=max_new_tokens)
        gen_text = policy_tok.decode(generated_ids[0].cpu().numpy(), skip_special_tokens=True)
        r = reward_fn(gen_text)
        # simple REINFORCE update (baseline fixed at 0.0)
        loss = - (r - 0.0) * logprob_sum
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
        opt.step()
    opt.zero_grad()  # do not leave stale gradients on the shared policy
    return deepcopy(policy)

# hacked policy: a few unshaped updates starting from the initial weights
policy_unshaped = _train_from_initial(sentiment_reward)

# shaped policy: the same budget of shaped updates, again from the initial weights
policy_shaped = _train_from_initial(shaped_reward)

# now sample from both
print("=== Samples from hacked (unshaped-trained) policy copy ===")
policy.load_state_dict(policy_unshaped.state_dict())
for p in seed_prompts:
    txt = sample_for_eval(p, max_new_tokens=12, temperature=0.8)
    print("-", txt)

print("\n=== Samples from shaped-trained policy copy ===")
policy.load_state_dict(policy_shaped.state_dict())
for p in seed_prompts:
    txt = sample_for_eval(p, max_new_tokens=12, temperature=0.8)
    print("-", txt)


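# Note: the gallery above uses small update budgets (25 steps) to keep the demo quick. 
# In practice, reward-hacking becomes more extreme with more updates; the shaping penalties 
# grow with the amount of repetition, profanity, or missing length in a sample, so they push back harder as outputs degenerate.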

### What you just saw (short)
- The policy optimizes whatever scalar reward you give it. A naive sentiment oracle can be easily exploited (repetition, short slogans).  
- Reward shaping (penalties for repetition, profanity, short answers) changes the optimization objective so that the policy prefers more natural outputs.  
- This demo is a minimal educational REINFORCE loop — real RLHF uses learned reward models, large human-labeled preference datasets, advantage normalization, and PPO-like updates with KL-constraints to limit policy drift.  
- Important engineering lessons: always inspect outputs, track simple metrics (repetition score, profanity hits, length), and combine multiple orthogonal reward signals rather than relying on a single scalar.

### Next steps you can drop into the notebook
- Replace the toy profanity list with a real classifier.
- Train a small reward model from human preference pairs instead of using sentiment.
- Swap REINFORCE for a tiny PPO implementation (requires value baseline).
- Add KL-penalty to keep the tuned policy close to the original LM (prevents collapse).