In [13]:
# Recreate SFT step (toy) for investigator pθ(x|y)
import os, time, math, random
from dataclasses import dataclass
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, get_cosine_schedule_with_warmup
from torch.optim import AdamW

# ---- Config ----
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
PM_ID  = "gpt2-large"   # target LM pm (forward x->y)
INV_ID = "gpt2-large"   # investigator base pθ (we'll fine-tune)

# Tiny prefix distribution P_SFT
PREFIXES = [
    "A short note about machine learning:",
    "Three tips for staying productive:",
    "A gentle introduction to probability:",
    "In a surprising discovery, scientists found",
    "As a software engineer, I often consider",
    "In Japan, the Shinkansen is known for",
    "A concise summary of the book is",
    "The quick brown fox",
    "An explanation for beginners:",
    "A brief overview of databases:"
]

NUM_EXAMPLES   = 10     # size of DSFT
MAX_PREFIX_EXT = 8      # optional: extend x slightly for variety
MAX_SUFFIX_LEN = 15     # greedy suffix length
BATCH_SIZE     = 2
EPOCHS         = 2
LR             = 5e-5
WARMUP_RATIO   = 0.1
GRAD_ACCUM_STEPS = 1

# Repro
random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)

print(f"[INFO] Device: {DEVICE}")


[INFO] Device: cuda


In [14]:
t0 = time.time()

print("[STEP] Loading tokenizer (pm)...")
tok_pm = AutoTokenizer.from_pretrained(PM_ID)
if tok_pm.pad_token is None:
    tok_pm.pad_token = tok_pm.eos_token

print("[STEP] Loading target model pm (x->y)... (first run may download)")
pm = AutoModelForCausalLM.from_pretrained(
    PM_ID,
    use_safetensors=True,
    torch_dtype=torch.float16,   # ✅ use half precision
    device_map="auto"            # ✅ automatically put on GPU
)
pm.eval()

print("[STEP] Loading tokenizer (investigator pθ)...")
tok_inv = AutoTokenizer.from_pretrained(INV_ID)
if tok_inv.pad_token is None:
    tok_inv.pad_token = tok_inv.eos_token

print("[STEP] Loading investigator base model pθ (will fine-tune)...")
inv = AutoModelForCausalLM.from_pretrained(
    INV_ID,
    use_safetensors=True,
    torch_dtype=torch.float16,
    device_map="auto"
)

print(f"[TIME] Models loaded in {time.time()-t0:.2f}s")
# After loading your model
print(next(pm.parameters()).device)
print(next(inv.parameters()).device)


[STEP] Loading tokenizer (pm)...
[STEP] Loading target model pm (x->y)... (first run may download)
[STEP] Loading tokenizer (investigator pθ)...
[STEP] Loading investigator base model pθ (will fine-tune)...
[TIME] Models loaded in 32.27s
cuda:0
cuda:0


In [15]:
def gen_from(model, tok, text, max_new_tokens, device=DEVICE):
    """Sample continuation from given text (avoids greedy repetition)."""
    inputs = tok(text, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,           # sampling instead of greedy
            top_p=0.95,               # nucleus sampling
            temperature=0.8,          # soften logits
            pad_token_id=tok.eos_token_id,
        )
    return tok.decode(out[0], skip_special_tokens=True)

print("[STEP] Building DSFT: sample x ~ P_SFT, then y ← pm(x) (sampled)")
t_ds = time.time()
pairs = []
for i in range(NUM_EXAMPLES):
    x = random.choice(PREFIXES)
    full = gen_from(pm, tok_pm, x, max_new_tokens=MAX_SUFFIX_LEN)
    y = full[len(x):].strip() if full.startswith(x) else full
    pairs.append((x, y))

print(f"[TIME] Built DSFT with {len(pairs)} examples in {time.time()-t_ds:.2f}s")
print("\n=== PRINTING DSFT (x, y) ===")
for idx, (x, y) in enumerate(pairs, 1):
    print(f"\n[{idx}] x: {x}")
    print(f"[{idx}] y: {y}")
    

[STEP] Building DSFT: sample x ~ P_SFT, then y ← pm(x) (sampled)
[TIME] Built DSFT with 10 examples in 3.97s

=== PRINTING DSFT (x, y) ===

[1] x: Three tips for staying productive:
[1] y: Learn to read

Read books that have lots of examples and

[2] x: A short note about machine learning:
[2] y: it's not going to fix the world, because there will always be people

[3] x: As a software engineer, I often consider
[3] y: myself a good programmer (and even though I hate to admit it, I

[4] x: In a surprising discovery, scientists found
[4] y: that the bacteria's DNA is actually made up of five chromosomes, each of

[5] x: In a surprising discovery, scientists found
[5] y: that the brain's dopamine system had a higher response to nicotine than to cocaine

[6] x: A gentle introduction to probability:
[6] y: the power of statistics

The power of statistics is that it allows you

[7] x: Three tips for staying productive:
[7] y: 1) Set a timer for a half hour and work for 45

[8] x: An explana

In [45]:
IN_CONTEXT_PREFIX = "Suffix:\n"
MID_PROMPT = "\nPrefix:\n"

class XYDataset(torch.utils.data.Dataset):
    """
    For each (x, y), we feed the model:
      input:  'Suffix:\\n{y}\\nPrefix:\\n' + x
      labels: supervise only on x (mask out the suffix part)
    """
    def __init__(self, pairs, tokenizer, max_len=256):
        self.tok = tokenizer
        self.max_len = max_len
        self.items = []
        for (x, y) in pairs:
            src = f"{IN_CONTEXT_PREFIX}{y}{MID_PROMPT}"
            tgt = x
            self.items.append((src, tgt))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        src, tgt = self.items[i]
        enc_all = self.tok(
            src + tgt,
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        input_ids = enc_all["input_ids"][0]
        attn = enc_all["attention_mask"][0]

        # mask out source portion
        src_len = len(self.tok(src, truncation=True, max_length=self.max_len)["input_ids"])
        labels = input_ids.clone()
        labels[:src_len] = -100
        return {"input_ids": input_ids, "attention_mask": attn, "labels": labels}

def collate(batch):
    keys = batch[0].keys()
    out = {}
    for k in keys:
        seqs = [b[k] for b in batch]
        pad_val = tok_inv.pad_token_id if k != "labels" else -100
        out[k] = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=pad_val)
    return out

# Split train/val
cut = int(0.8 * len(pairs))
train_pairs = pairs[:cut]
val_pairs   = pairs[cut:]

train_ds = XYDataset(train_pairs, tok_inv)
val_ds   = XYDataset(val_pairs, tok_inv)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate)
val_loader   = torch.utils.data.DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)

print(f"[INFO] Train size: {len(train_ds)} | Val size: {len(val_ds)}")

# Peek at tokenized structure
sample_item = train_ds[0]
print("\n[SAMPLE TOKENS]")
print("input_ids len:", sample_item["input_ids"].shape[0])
print("supervised tokens:", (sample_item["labels"] != -100).sum().item())


[INFO] Train size: 8 | Val size: 2

[SAMPLE TOKENS]
input_ids len: 68
supervised tokens: 6


In [46]:
from torch.optim import AdamW  # use PyTorch AdamW

print("[STEP] Training investigator pθ to predict x from y (SFT)...")
t_train = time.time()
inv.train()
opt = AdamW(inv.parameters(), lr=LR)

num_training_steps = EPOCHS * math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
num_warmup = int(WARMUP_RATIO * num_training_steps)
sched = get_cosine_schedule_with_warmup(opt, num_warmup, num_training_steps)

global_step = 0
for epoch in range(1, EPOCHS + 1):
    running = 0.0
    for step, batch in enumerate(train_loader, 1):
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        out = inv(**batch)
        loss = out.loss / GRAD_ACCUM_STEPS
        loss.backward()
        if step % GRAD_ACCUM_STEPS == 0:
            opt.step()
            sched.step()
            opt.zero_grad()
            global_step += 1
        running += loss.item() * GRAD_ACCUM_STEPS
        if step % 5 == 0 or step == len(train_loader):
            print(f"  [epoch {epoch} step {step}/{len(train_loader)}] loss={running/step:.4f}")

print(f"[TIME] Training done in {time.time()-t_train:.2f}s")

[STEP] Training investigator pθ to predict x from y (SFT)...
  [epoch 1 step 4/4] loss=10.8248
  [epoch 2 step 4/4] loss=10.8238
[TIME] Training done in 0.53s


In [47]:
inv.eval()

def inv_predict_prefix(y: str, max_new_tokens=32):
    prompt = f"{IN_CONTEXT_PREFIX}{y}{MID_PROMPT}"
    enc = tok_inv(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        gen = inv.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=tok_inv.eos_token_id,
        )
    text = tok_inv.decode(gen[0], skip_special_tokens=True)
    return text.split(MID_PROMPT, 1)[1].strip() if MID_PROMPT in text else text

def pm_continue(x: str, max_new_tokens=40):
    full = gen_from(pm, tok_pm, x, max_new_tokens=max_new_tokens)
    return full[len(x):].strip() if full.startswith(x) else full

print("[STEP] Sampling a few validations to see behavior...")
for i, (x_true, y_true) in enumerate(val_pairs[:5], 1):
    x_hat = inv_predict_prefix(y_true, max_new_tokens=32)
    y_from_hat = pm_continue(x_hat, max_new_tokens=40)

    print("\n--- Example", i, "---")
    print("y_true (suffix):", (y_true[:100] + "...") if len(y_true) > 100 else y_true)
    print("x_true (gold prefix):", x_true)
    print("x_hat  (pred prefix):", x_hat)
    print("pm continuation from x_hat:", (y_from_hat[:100] + "...") if len(y_from_hat) > 100 else y_from_hat)

print("\n[ALL DONE]")

[STEP] Sampling a few validations to see behavior...

--- Example 1 ---
y_true (suffix): reviewing Participationimura Daniel Hancockatisf confir autonomyohopresshibitoho TA Habit trilogy Pa...
x_true (gold prefix): Three tips for staying productive:
x_hat  (pred prefix): 653 deflect Medic factorspublic Dreams Television Singapore skillet factors predatorspublic Medic linedMini membershipSexualshows Pocket Television448 skillet perhaps praying skillet mutual653 courtyardOutside Television equate soy
pm continuation from x_hat: Wheels boils 236ozygGypublic grandchildren Tre TelevisionSexual BooneMini boils factors courtyardaci...

--- Example 2 ---
y_true (suffix): imura Money antibioticRocketmediately stairs stairs subst Hancock scalp Rh Prob ONE intermittentiken...
x_true (gold prefix): A brief overview of databases:
x_hat  (pred prefix): Late factors clearer workshopsProsPros lined membershipived deflect predators Pocket Boone courtyard Lateivedacious Dreamspublic boils skillet grandc

In [48]:
import random, torch
import torch.nn.functional as F

def inv_sample_prefix(y: str, max_new_tokens=48):
    """
    Sample a candidate prefix x ~ pθ(.|y) from the current investigator `inv`.
    Uses the same input template as SFT: 'Suffix:\\n{y}\\nPrefix:\\n'.
    """
    prompt = f"{IN_CONTEXT_PREFIX}{y}{MID_PROMPT}"
    dev = next(inv.parameters()).device   # <<< detect model's device
    enc = tok_inv(prompt, return_tensors="pt").to(dev)

    inv.eval()
    with torch.no_grad():
        out = inv.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=True,       # sample to get diverse candidates
            top_p=0.9,
            temperature=0.8,
            pad_token_id=tok_inv.eos_token_id,
        )
    text = tok_inv.decode(out[0], skip_special_tokens=True)
    return text.split(MID_PROMPT, 1)[1].strip() if MID_PROMPT in text else text.strip()

def logprob_y_given_x(pm_model, tok, x: str, y: str):
    """
    Compute log p_m(y | x) under the target LM via teacher forcing.
    Returns: (sum_logprob, avg_logprob_per_token, num_y_tokens).
    """
    if not y:
        return float("-inf"), float("-inf"), 0

    dev = next(pm_model.parameters()).device

    # Tokenize full sequence and the split point (where y begins)
    ids_full = tok(x + y, return_tensors="pt").input_ids[0].to(dev)  # [T]
    ids_x    = tok(x,     return_tensors="pt").input_ids[0].to(dev)  # [Tx]
    start = ids_x.shape[0]  # index in ids_full where y starts
    if start >= ids_full.shape[0]:
        return float("-inf"), float("-inf"), 0

    pm_model.eval()
    with torch.no_grad():
        logits = pm_model(ids_full.unsqueeze(0)).logits[0]   # [T, V]
        logp   = F.log_softmax(logits, dim=-1)               # [T, V]

    next_token_logp = logp[:-1, :].gather(1, ids_full[1:].unsqueeze(-1)).squeeze(-1)  # [T-1]
    y_logp = next_token_logp[start-1:]
    sum_lp = float(y_logp.sum().item())
    n_tok  = int(y_logp.shape[0])
    avg_lp = (sum_lp / n_tok) if n_tok > 0 else float("-inf")
    return sum_lp, avg_lp, n_tok

# ----- Config for DPO data construction (smaller for demo) -----
K_CANDIDATES = 2        # 2 prefixes per suffix
NUM_DPO_SUFFIXES = 5    # only use 5 suffixes total
MAX_INV_GEN = 48
SEED = 123
random.seed(SEED); torch.manual_seed(SEED)

# take only a small pool of suffixes
suffix_pool = [y for (x, y) in pairs if isinstance(y, str) and len(y) > 0]
random.shuffle(suffix_pool)
suffix_pool = suffix_pool[:NUM_DPO_SUFFIXES]

dpo_triples = []
print(f"[STEP] Constructing DPO preference data for {len(suffix_pool)} suffixes...")

for y in suffix_pool:
    cands = []
    for _ in range(K_CANDIDATES):
        x = inv_sample_prefix(y)
        if not x or len(x.strip()) < 3:
            continue
        sum_lp, avg_lp, n_tok = logprob_y_given_x(pm, tok_pm, x, y)
        cands.append((x, sum_lp, avg_lp, n_tok))
    if len(cands) < 2:
        continue

    # pick one winner and one loser
    cands.sort(key=lambda t: t[1], reverse=True)
    winner = cands[0][0]
    loser  = cands[-1][0]

    dpo_triples.append({
        "prompt": f"{IN_CONTEXT_PREFIX}{y}{MID_PROMPT}",
        "chosen": winner,
        "rejected": loser,
    })

print(f"[INFO] Built {len(dpo_triples)} DPO pairs (max {NUM_DPO_SUFFIXES}).")
if dpo_triples:
    ex = dpo_triples[0]
    print("[EXAMPLE DPO ITEM]")
    print("prompt (truncated):", (ex['prompt'][:120] + "...") if len(ex['prompt'])>120 else ex['prompt'])
    print("chosen:", ex["chosen"])
    print("rejected:", ex["rejected"])

[STEP] Constructing DPO preference data for 5 suffixes...
[INFO] Built 5 DPO pairs (max 5).
[EXAMPLE DPO ITEM]
prompt (truncated): Suffix:
reviewing Participationimura Daniel Hancockatisf confir autonomyohopresshibitoho TA Habit trilogy Participationa...
chosen: skillet praying 236 boils representationsobl equate membershipshows WheelsPros grandchildrenProsacious Medic Redux predatorsacious boilsMost membership perhapsPros448 ReduxPros equateMiniPros Tre653 lined workshopsshowsacious Medic653Mini equate factors boilsived deflect boilsozygived Late equate
rejected: TelevisionOutside Wheels perhaps soypublicpublicSexualacious grandchildren perhaps BendSexualOutside mutual skilletived� predators clearer representations soy Bendived Redux ReduxMost Late factors skillet grandchildren Pocket bravery equate Late ReduxMostoblOutside braveryMini Singapore Tre mutual equateSexual Medic lined


In [49]:
# === Cell 8 (patched): DPO on CPU; provide processing_class + padding_value ===
from trl import DPOTrainer, DPOConfig
from datasets import Dataset
import copy, torch, random

if len(dpo_triples) == 0:
    raise RuntimeError("DPO dataset is empty. Re-run Cell 7 to build preference pairs.")

dpo_ds = Dataset.from_list(dpo_triples)
print(f"[INFO] DPO dataset size: {len(dpo_ds)} pairs")

_ORIG_DEVICE = DEVICE

# ----- Ensure tokenizer/model have padding set -----
if tok_inv.pad_token is None:
    tok_inv.pad_token = tok_inv.eos_token
if getattr(inv.config, "pad_token_id", None) is None:
    inv.config.pad_token_id = tok_inv.pad_token_id

# ----- Build models on CPU -----
print("[STEP] Cloning frozen reference policy (CPU)...")
ref_model = copy.deepcopy(inv).to("cpu")
for p in ref_model.parameters():
    p.requires_grad_(False)
ref_model.eval()

inv_cpu = inv.to("cpu")

# ----- DPO config (note padding_value) -----
dpo_args = DPOConfig(
    beta=0.1,
    learning_rate=1e-5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    max_prompt_length=200,
    max_length=256,               # some TRL versions also support max_completion_length
    remove_unused_columns=False,
    logging_steps=5,
    use_mps_device=False,
    use_cpu=True,
    padding_value=tok_inv.pad_token_id,   # <<< important
    label_pad_token_id=-100,
)

print("[STEP] Starting DPO training on CPU...")
trainer = DPOTrainer(
    model=inv_cpu,
    ref_model=ref_model,
    args=dpo_args,
    train_dataset=dpo_ds,
    processing_class=tok_inv,     # <<< pass tokenizer here for your TRL version
)
train_result = trainer.train()
print("[INFO] DPO training complete.")
print(train_result)

# ----- Move back to original device -----
inv = inv_cpu.to(_ORIG_DEVICE)

# ----- Sanity check -----
y_test = random.choice([
    ex["prompt"].split("Suffix:\n",1)[1].split("\nPrefix:\n",1)[0] 
    for ex in dpo_triples
])

prompt = f"{IN_CONTEXT_PREFIX}{y_test}{MID_PROMPT}"
enc = tok_inv(prompt, return_tensors="pt").to(_ORIG_DEVICE)
inv.eval()
with torch.no_grad():
    gen = inv.generate(
        **enc,
        max_new_tokens=48,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        pad_token_id=tok_inv.pad_token_id,
    )
full = tok_inv.decode(gen[0], skip_special_tokens=True)
x_hat = full.split(MID_PROMPT, 1)[1].strip() if MID_PROMPT in full else full.strip()

sum_lp, avg_lp, n_tok = logprob_y_given_x(pm.to(_ORIG_DEVICE), tok_pm, x_hat, y_test)
print("\n[SANITY CHECK AFTER DPO]")
print("suffix (y):", (y_test[:160] + "...") if len(y_test)>160 else y_test)
print("proposed prefix (x̂):", (x_hat[:160] + "...") if len(x_hat)>160 else x_hat)
print(f"log pm(y|x̂): sum={sum_lp:.3f}, avg/token={avg_lp:.4f}, n_tok={n_tok}")


[INFO] DPO dataset size: 5 pairs
[STEP] Cloning frozen reference policy (CPU)...
[STEP] Starting DPO training on CPU...


Extracting prompt in train dataset: 100%|██████████| 5/5 [00:00<00:00, 1124.12 examples/s]
Applying chat template to train dataset: 100%|██████████| 5/5 [00:00<00:00, 930.95 examples/s]
Tokenizing train dataset: 100%|██████████| 5/5 [00:00<00:00, 488.99 examples/s]
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 50256}.


Step,Training Loss


[INFO] DPO training complete.
TrainOutput(global_step=3, training_loss=0.6931034723917643, metrics={'train_runtime': 1.6538, 'train_samples_per_second': 3.023, 'train_steps_per_second': 1.814, 'total_flos': 0.0, 'train_loss': 0.6931034723917643, 'epoch': 1.0})


  test_elements = torch.tensor(test_elements)



[SANITY CHECK AFTER DPO]
suffix (y): reviewing Participationimura Daniel Hancockatisf confir autonomyohopresshibitoho TA Habit trilogy Participationatisf scalpohopress ONE subst stairsmediately dir...
proposed prefix (x̂): bravery clearer Singapore Tre Redux boils653 bravery prayingpublicacious grandchildrenaciouspublicozyg predators Television448448 deflect653448 praying courtyar...
log pm(y|x̂): sum=-580.997, avg/token=-10.7592, n_tok=54


In [50]:
# === Cell 9: FW helpers ===
import torch, torch.nn.functional as F, random, copy

# Small, laptop-friendly defaults
FW_ITERS = 2          # number of FW rounds (toy)
FW_NUM_SUFFIX = 2     # how many suffixes per FW round
FW_K_CANDS = 3        # candidates per suffix
FW_MAX_NEW = 48       # max tokens investigator generates for a prefix
FW_LAMBDA = 0.5       # λ (penalty strength on previous investigator)
SEED = 123
random.seed(SEED); torch.manual_seed(SEED)

def logprob_inv_x_given_y(inv_model, tok, y: str, x: str) -> tuple[float, float, int]:
    """
    Compute log p_inv(x | y) for the investigator under the SFT/DPO training template:
        prompt_src = "Suffix:\\n{y}\\nPrefix:\\n"
        continuation = x
    We teacher-force the model on (src + x) and sum logprobs over x tokens only.
    Returns: (sum_logprob, avg_logprob_per_token, num_tokens)
    """
    if not x:
        return float("-inf"), float("-inf"), 0

    src = f"{IN_CONTEXT_PREFIX}{y}{MID_PROMPT}"
    dev = next(inv_model.parameters()).device

    ids_full = tok(src + x, return_tensors="pt").input_ids[0].to(dev)
    ids_src  = tok(src,     return_tensors="pt").input_ids[0].to(dev)
    start = ids_src.shape[0]
    if start >= ids_full.shape[0]:
        return float("-inf"), float("-inf"), 0

    inv_model.eval()
    with torch.no_grad():
        logits = inv_model(ids_full.unsqueeze(0)).logits[0]
        logp   = F.log_softmax(logits, dim=-1)

    next_token_logp = logp[:-1, :].gather(1, ids_full[1:].unsqueeze(-1)).squeeze(-1)
    x_logp = next_token_logp[start-1:]
    sum_lp = float(x_logp.sum().item())
    n_tok  = int(x_logp.shape[0])
    avg_lp = (sum_lp / n_tok) if n_tok > 0 else float("-inf")
    return sum_lp, avg_lp, n_tok

def inv_sample_prefix_from(model, tok, y: str, max_new_tokens=FW_MAX_NEW) -> str:
    """
    Sample a prefix x ~ model(.|y) using the same input template.
    """
    src = f"{IN_CONTEXT_PREFIX}{y}{MID_PROMPT}"
    dev = next(model.parameters()).device
    enc = tok(src, return_tensors="pt").to(dev)

    model.eval()
    with torch.no_grad():
        out = model.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=tok.eos_token_id,
        )
    text = tok.decode(out[0], skip_special_tokens=True)
    return text.split(MID_PROMPT, 1)[1].strip() if MID_PROMPT in text else text.strip()

def build_fw_pairs(prev_inv_model, pm_model, tok_pm_local, suffix_source_pairs,
                   k_cands=FW_K_CANDS, lambda_pen=FW_LAMBDA):
    """
    Build DPO triples for one FW iteration using penalized score:
        score(x,y) = log pm(y|x) - λ * log p_prev(x|y)
    Returns a list of dicts {prompt, chosen, rejected}.
    """
    triples = []
    suffixes = [y for (_, y) in suffix_source_pairs if isinstance(y, str) and len(y) > 0]
    random.shuffle(suffixes)
    suffixes = suffixes[:FW_NUM_SUFFIX]

    for y in suffixes:
        cands = []
        for _ in range(k_cands):
            x = inv_sample_prefix_from(inv, tok_inv, y)
            if not x or len(x.strip()) < 3:
                continue
            # compute both terms explicitly on CPU model
            sum_pm, _, _   = logprob_y_given_x(pm_model, tok_pm_local, x, y)
            sum_prev, _, _ = logprob_inv_x_given_y(prev_inv_model, tok_inv, y, x)
            score = sum_pm - lambda_pen * sum_prev
            cands.append((x, score, sum_pm, sum_prev))

        if len(cands) < 2:
            continue

        cands.sort(key=lambda t: t[1], reverse=True)
        winner = cands[0][0]
        loser  = cands[-1][0]

        triples.append({
            "prompt": f"{IN_CONTEXT_PREFIX}{y}{MID_PROMPT}",
            "chosen": winner,
            "rejected": loser,
        })

    return triples


In [51]:
# === Cell 10: FW loop (CPU-only) ===
from trl import DPOTrainer, DPOConfig
from datasets import Dataset
import copy

CPU = torch.device("cpu")

# Ensure tokenizer/model have padding set
if tok_inv.pad_token is None:
    tok_inv.pad_token = tok_inv.eos_token
if getattr(inv.config, "pad_token_id", None) is None:
    inv.config.pad_token_id = tok_inv.pad_token_id

# Move models to CPU
pm_cpu  = pm.to(CPU)
inv_cpu = inv.to(CPU)

# Previous iterate on CPU
prev_inv = copy.deepcopy(inv_cpu).to(CPU)
for p in prev_inv.parameters():
    p.requires_grad_(False)
prev_inv.eval()

for it in range(1, FW_ITERS + 1):
    print(f"\n[FW/CPU] Iteration {it}/{FW_ITERS} — building penalized preference pairs...")
    fw_triples = build_fw_pairs(prev_inv, pm_cpu, tok_pm, pairs,
                                k_cands=FW_K_CANDS, lambda_pen=FW_LAMBDA)
    print(f"[FW/CPU] Built {len(fw_triples)} pairs.")

    if len(fw_triples) == 0:
        print("[FW/CPU] No pairs built.")
        break

    fw_ds = Dataset.from_list(fw_triples)

    dpo_args = DPOConfig(
        beta=0.1,
        learning_rate=1e-5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        num_train_epochs=1,
        max_prompt_length=200,
        max_length=256,
        remove_unused_columns=False,
        logging_steps=5,
        use_cpu=True,
        use_mps_device=False,
        padding_value=tok_inv.pad_token_id,
        label_pad_token_id=-100,
    )

    print("[FW/CPU] DPO step...")
    trainer = DPOTrainer(
        model=inv_cpu,
        ref_model=prev_inv,
        args=dpo_args,
        train_dataset=fw_ds,
        processing_class=tok_inv,
    )
    trainer.train()
    print("[FW/CPU] DPO step complete.")

    # Update prev_inv for next iteration
    prev_inv = copy.deepcopy(inv_cpu).to(CPU)
    for p in prev_inv.parameters():
        p.requires_grad_(False)
    prev_inv.eval()

    # --- Sanity check like your DPO cell ---
    y_test = fw_triples[0]["prompt"].split("Suffix:\n",1)[1].split("\nPrefix:\n",1)[0]
    prompt = f"{IN_CONTEXT_PREFIX}{y_test}{MID_PROMPT}"
    enc = tok_inv(prompt, return_tensors="pt").to(CPU)
    inv_cpu.eval()
    with torch.no_grad():
        gen = inv_cpu.generate(
            **enc,
            max_new_tokens=48,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=tok_inv.pad_token_id,
        )
    full = tok_inv.decode(gen[0], skip_special_tokens=True)
    x_hat = full.split(MID_PROMPT, 1)[1].strip() if MID_PROMPT in full else full.strip()

    sum_lp, avg_lp, n_tok = logprob_y_given_x(pm_cpu, tok_pm, x_hat, y_test)
    print("[FW/CPU] Sanity check")
    print("suffix (y):", (y_test[:160] + "...") if len(y_test)>160 else y_test)
    print("proposed prefix (x̂):", (x_hat[:160] + "...") if len(x_hat)>160 else x_hat)
    print(f"log pm(y|x̂): sum={sum_lp:.3f}, avg/token={avg_lp:.4f}, n_tok={n_tok}")

print("\n[FW/CPU] Finished. Investigator remains on CPU.")



[FW/CPU] Iteration 1/2 — building penalized preference pairs...
[FW/CPU] Built 2 pairs.
[FW/CPU] DPO step...


Extracting prompt in train dataset: 100%|██████████| 2/2 [00:00<00:00, 497.40 examples/s]
Applying chat template to train dataset: 100%|██████████| 2/2 [00:00<00:00, 511.50 examples/s]
Tokenizing train dataset: 100%|██████████| 2/2 [00:00<00:00, 351.37 examples/s]


Step,Training Loss


[FW/CPU] DPO step complete.
[FW/CPU] Sanity check
suffix (y): reviewing Participationimura Daniel Hancockatisf confir autonomyohopresshibitoho TA Habit trilogy Participationatisf scalpohopress ONE subst stairsmediately dir...
proposed prefix (x̂): factors Wheels Wheels perhaps predatorsobl448ozyg representations grandchildrenpublic Boone deflect predators grandchildren skillet equate clearer lined factors...
log pm(y|x̂): sum=-580.999, avg/token=-10.7592, n_tok=54

[FW/CPU] Iteration 2/2 — building penalized preference pairs...
[FW/CPU] Built 2 pairs.
[FW/CPU] DPO step...


Extracting prompt in train dataset: 100%|██████████| 2/2 [00:00<00:00, 546.31 examples/s]
Applying chat template to train dataset: 100%|██████████| 2/2 [00:00<00:00, 501.77 examples/s]
Tokenizing train dataset: 100%|██████████| 2/2 [00:00<00:00, 348.61 examples/s]


Step,Training Loss


[FW/CPU] DPO step complete.
[FW/CPU] Sanity check
suffix (y): oho BrewRocket Daniel Rh hauled Participation circumcised conservation circumcisediken pawn ESV circumcised credibilityScenemediatelyimura Daniel ONEScene antib...
proposed prefix (x̂): factors Wheels Wheels perhaps predatorsobl448ozyg representations grandchildrenpublic Boone deflect predators grandchildren skillet equate clearer lined factors...
log pm(y|x̂): sum=-580.952, avg/token=-10.7584, n_tok=54

[FW/CPU] Finished. Investigator remains on CPU.
