In [2]:
# Recreate SFT step (toy) for investigator pθ(x|y)
import os, time, math, random
from dataclasses import dataclass
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, get_cosine_schedule_with_warmup
from torch.optim import AdamW

# ---- Config ----
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
PM_ID  = "sshleifer/tiny-gpt2"     # target LM pm (forward x->y)
INV_ID = "sshleifer/tiny-gpt2"     # investigator base pθ (we'll fine-tune)

# Tiny prefix distribution P_SFT
PREFIXES = [
    "A short note about machine learning:",
    "Three tips for staying productive:",
    "A gentle introduction to probability:",
    "In a surprising discovery, scientists found",
    "As a software engineer, I often consider",
    "In Japan, the Shinkansen is known for",
    "A concise summary of the book is",
    "The quick brown fox",
    "An explanation for beginners:",
    "A brief overview of databases:"
]

NUM_EXAMPLES   = 30     # size of DSFT
MAX_PREFIX_EXT = 8      # optional: extend x slightly for variety
MAX_SUFFIX_LEN = 48     # greedy suffix length
BATCH_SIZE     = 2
EPOCHS         = 2
LR             = 5e-5
WARMUP_RATIO   = 0.1
GRAD_ACCUM_STEPS = 1

# Repro
random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)

print(f"[INFO] Device: {DEVICE}")


  from .autonotebook import tqdm as notebook_tqdm


[INFO] Device: mps


In [3]:
t0 = time.time()

print("[STEP] Loading tokenizer (pm)...")
tok_pm = AutoTokenizer.from_pretrained(PM_ID)
if tok_pm.pad_token is None:
    tok_pm.pad_token = tok_pm.eos_token

print("[STEP] Loading target model pm (x->y)... (first run may download)")
pm = AutoModelForCausalLM.from_pretrained(
    PM_ID,
    use_safetensors=True,          # avoids .bin load
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32
).to(DEVICE)
pm.eval()

print("[STEP] Loading tokenizer (investigator pθ)...")
tok_inv = AutoTokenizer.from_pretrained(INV_ID)
if tok_inv.pad_token is None:
    tok_inv.pad_token = tok_inv.eos_token

print("[STEP] Loading investigator base model pθ (will fine-tune)...")
inv = AutoModelForCausalLM.from_pretrained(
    INV_ID,
    use_safetensors=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32
).to(DEVICE)

print(f"[TIME] Models loaded in {time.time()-t0:.2f}s")

[STEP] Loading tokenizer (pm)...


`torch_dtype` is deprecated! Use `dtype` instead!


[STEP] Loading target model pm (x->y)... (first run may download)
[STEP] Loading tokenizer (investigator pθ)...
[STEP] Loading investigator base model pθ (will fine-tune)...
[TIME] Models loaded in 7.43s


In [4]:
def gen_from(model, tok, text, max_new_tokens, device=DEVICE):
    """Sample continuation from given text (avoids greedy repetition)."""
    inputs = tok(text, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,           # sampling instead of greedy
            top_p=0.95,               # nucleus sampling
            temperature=0.8,          # soften logits
            pad_token_id=tok.eos_token_id,
        )
    return tok.decode(out[0], skip_special_tokens=True)

print("[STEP] Building DSFT: sample x ~ P_SFT, then y ← pm(x) (sampled)")
t_ds = time.time()
pairs = []
for i in range(NUM_EXAMPLES):
    x = random.choice(PREFIXES)
    full = gen_from(pm, tok_pm, x, max_new_tokens=MAX_SUFFIX_LEN)
    y = full[len(x):].strip() if full.startswith(x) else full
    pairs.append((x, y))

print(f"[TIME] Built DSFT with {len(pairs)} examples in {time.time()-t_ds:.2f}s")
print("\n=== PRINTING DSFT (x, y) ===")
for idx, (x, y) in enumerate(pairs, 1):
    print(f"\n[{idx}] x: {x}")
    print(f"[{idx}] y: {y}")

[STEP] Building DSFT: sample x ~ P_SFT, then y ← pm(x) (sampled)


  test_elements = torch.tensor(test_elements)


[TIME] Built DSFT with 30 examples in 88.12s

=== PRINTING DSFT (x, y) ===

[1] x: Three tips for staying productive:
[1] y: reement Brew intermittentatisf TA Prob Prob MotorolaScene circumcisedohoatisfimura ObservimuraJD intermittent stairsikenoother Rhdit Brew Habitatisf autonomy antibiotic heirootheroother Daniel Rh trilogy reviewingting Brewpress ESV trilogy MoneyJD antibiotic pawn Prob conservation Hancockdit vendors

[2] x: A short note about machine learning:
[2] y: Sexual membershipPros clearer653aciousMostozyg factors predators equate brutality Pocketozyg Pocket Late incarcerSexual praying workshopsMiniivedshows grandchildrenozyg448Prosshows courtyard workshops Boone soyivedpublic predatorsSexual Tre Tre Wheels 236 perhaps lined 236�Mini rubbingPros equate

[3] x: As a software engineer, I often consider
[3] y: press reviewing Danieltingpress004 hauled pawn hauled TA ONE Hancock Jr ESV antibiotic stairsting hauled vendors credibility credibility MoneyJD Observ Prob antibioticd

In [5]:
IN_CONTEXT_PREFIX = "Suffix:\n"
MID_PROMPT = "\nPrefix:\n"

class XYDataset(torch.utils.data.Dataset):
    """
    For each (x, y), we feed the model:
      input:  'Suffix:\\n{y}\\nPrefix:\\n' + x
      labels: supervise only on x (mask out the suffix part)
    """
    def __init__(self, pairs, tokenizer, max_len=256):
        self.tok = tokenizer
        self.max_len = max_len
        self.items = []
        for (x, y) in pairs:
            src = f"{IN_CONTEXT_PREFIX}{y}{MID_PROMPT}"
            tgt = x
            self.items.append((src, tgt))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        src, tgt = self.items[i]
        enc_all = self.tok(
            src + tgt,
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        input_ids = enc_all["input_ids"][0]
        attn = enc_all["attention_mask"][0]

        # mask out source portion
        src_len = len(self.tok(src, truncation=True, max_length=self.max_len)["input_ids"])
        labels = input_ids.clone()
        labels[:src_len] = -100
        return {"input_ids": input_ids, "attention_mask": attn, "labels": labels}

def collate(batch):
    keys = batch[0].keys()
    out = {}
    for k in keys:
        seqs = [b[k] for b in batch]
        pad_val = tok_inv.pad_token_id if k != "labels" else -100
        out[k] = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=pad_val)
    return out

# Split train/val
cut = int(0.8 * len(pairs))
train_pairs = pairs[:cut]
val_pairs   = pairs[cut:]

train_ds = XYDataset(train_pairs, tok_inv)
val_ds   = XYDataset(val_pairs, tok_inv)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate)
val_loader   = torch.utils.data.DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)

print(f"[INFO] Train size: {len(train_ds)} | Val size: {len(val_ds)}")

# Peek at tokenized structure
sample_item = train_ds[0]
print("\n[SAMPLE TOKENS]")
print("input_ids len:", sample_item["input_ids"].shape[0])
print("supervised tokens:", (sample_item["labels"] != -100).sum().item())


[INFO] Train size: 24 | Val size: 6

[SAMPLE TOKENS]
input_ids len: 68
supervised tokens: 6


In [6]:
from torch.optim import AdamW  # use PyTorch AdamW

print("[STEP] Training investigator pθ to predict x from y (SFT)...")
t_train = time.time()
inv.train()
opt = AdamW(inv.parameters(), lr=LR)

num_training_steps = EPOCHS * math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
num_warmup = int(WARMUP_RATIO * num_training_steps)
sched = get_cosine_schedule_with_warmup(opt, num_warmup, num_training_steps)

global_step = 0
for epoch in range(1, EPOCHS + 1):
    running = 0.0
    for step, batch in enumerate(train_loader, 1):
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        out = inv(**batch)
        loss = out.loss / GRAD_ACCUM_STEPS
        loss.backward()
        if step % GRAD_ACCUM_STEPS == 0:
            opt.step()
            sched.step()
            opt.zero_grad()
            global_step += 1
        running += loss.item() * GRAD_ACCUM_STEPS
        if step % 5 == 0 or step == len(train_loader):
            print(f"  [epoch {epoch} step {step}/{len(train_loader)}] loss={running/step:.4f}")

print(f"[TIME] Training done in {time.time()-t_train:.2f}s")

[STEP] Training investigator pθ to predict x from y (SFT)...


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


  [epoch 1 step 5/12] loss=10.8279
  [epoch 1 step 10/12] loss=10.8254
  [epoch 1 step 12/12] loss=10.8246
  [epoch 2 step 5/12] loss=10.8298
  [epoch 2 step 10/12] loss=10.8268
  [epoch 2 step 12/12] loss=10.8252
[TIME] Training done in 8.78s


In [12]:
inv.eval()

def inv_predict_prefix(y: str, max_new_tokens=32):
    prompt = f"{IN_CONTEXT_PREFIX}{y}{MID_PROMPT}"
    enc = tok_inv(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        gen = inv.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=tok_inv.eos_token_id,
        )
    text = tok_inv.decode(gen[0], skip_special_tokens=True)
    return text.split(MID_PROMPT, 1)[1].strip() if MID_PROMPT in text else text

def pm_continue(x: str, max_new_tokens=40):
    full = gen_from(pm, tok_pm, x, max_new_tokens=max_new_tokens)
    return full[len(x):].strip() if full.startswith(x) else full

print("[STEP] Sampling a few validations to see behavior...")
for i, (x_true, y_true) in enumerate(val_pairs[:5], 1):
    x_hat = inv_predict_prefix(y_true, max_new_tokens=32)
    y_from_hat = pm_continue(x_hat, max_new_tokens=40)

    print("\n--- Example", i, "---")
    print("y_true (suffix):", (y_true[:100] + "...") if len(y_true) > 100 else y_true)
    print("x_true (gold prefix):", x_true)
    print("x_hat  (pred prefix):", x_hat)
    print("pm continuation from x_hat:", (y_from_hat[:100] + "...") if len(y_from_hat) > 100 else y_from_hat)

print("\n[ALL DONE]")

[STEP] Sampling a few validations to see behavior...

--- Example 1 ---
y_true (suffix): � predators equate Singapore equate Redux Wheels bravery grandchildren bravery Booneobl boils factor...
x_true (gold prefix): The quick brown fox
x_hat  (pred prefix): shows equate rubbing Pocket praying Bend 236 236 praying factors Late boils Television653 equate lined boils boils equate Pocket Pocket Pocket skillet Singapore Pocket representations Dreams bravery Tre Bend skillet Boone
pm continuation from x_hat: ived workshops skillet rubbing Television mutual predatorsOutside clearer perhapsPros deflectobl cle...

--- Example 2 ---
y_true (suffix): imura substSher004iken Participation trilogy ParticipationSher ONE Observ ESV scalp004imura autonomy...
x_true (gold prefix): A brief overview of databases:
x_hat  (pred prefix): factors Treozygozyg Pocket LateOutside lined DreamsaciousSexualived membership representations grandchildren soyMini boils braveryozyg perhaps soy soy predators representatio

In [13]:
import random, torch
import torch.nn.functional as F

def inv_sample_prefix(y: str, max_new_tokens=48):
    """
    Sample a candidate prefix x ~ pθ(.|y) from the current investigator `inv`.
    Uses the same input template as SFT: 'Suffix:\\n{y}\\nPrefix:\\n'.
    """
    prompt = f"{IN_CONTEXT_PREFIX}{y}{MID_PROMPT}"
    enc = tok_inv(prompt, return_tensors="pt").to(DEVICE)
    inv.eval()
    with torch.no_grad():
        out = inv.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=True,       # sample to get diverse candidates
            top_p=0.9,
            temperature=0.8,
            pad_token_id=tok_inv.eos_token_id,
        )
    text = tok_inv.decode(out[0], skip_special_tokens=True)
    # Extract the continuation after the 'Prefix:' marker
    return text.split(MID_PROMPT, 1)[1].strip() if MID_PROMPT in text else text.strip()

def logprob_y_given_x(pm, tok, x: str, y: str):
    """
    Compute log p_m(y | x) under the target LM `pm` via teacher forcing.
    Returns: (sum_logprob, avg_logprob_per_token, num_y_tokens).
    """
    if not y:
        return float("-inf"), float("-inf"), 0

    # Tokenize full sequence and the split point (where y begins)
    ids_full = tok(x + y, return_tensors="pt").input_ids[0].to(DEVICE)  # [T]
    ids_x    = tok(x,     return_tensors="pt").input_ids[0].to(DEVICE)  # [Tx]
    start = ids_x.shape[0]  # index in ids_full where y starts
    if start >= ids_full.shape[0]:
        return float("-inf"), float("-inf"), 0

    pm.eval()
    with torch.no_grad():
        logits = pm(ids_full.unsqueeze(0)).logits[0]   # [T, V]
        logp   = F.log_softmax(logits, dim=-1)         # [T, V]

    # Next-token logprobs for each position: log P(token[t] | tokens[:t])
    # Align by shifting targets by one
    next_token_logp = logp[:-1, :].gather(1, ids_full[1:].unsqueeze(-1)).squeeze(-1)  # [T-1]

    # y occupies ids_full[start:], so predictions for y are at positions [start-1 : end-1)
    y_logp = next_token_logp[start-1:]
    sum_lp = float(y_logp.sum().item())
    n_tok  = int(y_logp.shape[0])
    avg_lp = (sum_lp / n_tok) if n_tok > 0 else float("-inf")
    return sum_lp, avg_lp, n_tok

# ----- Config for DPO data construction (smaller for demo) -----
K_CANDIDATES = 2        # 2 prefixes per suffix
NUM_DPO_SUFFIXES = 5    # only use 5 suffixes total
MAX_INV_GEN = 48
SEED = 123
random.seed(SEED); torch.manual_seed(SEED)

# take only a small pool of suffixes
suffix_pool = [y for (x, y) in pairs if isinstance(y, str) and len(y) > 0]
random.shuffle(suffix_pool)
suffix_pool = suffix_pool[:NUM_DPO_SUFFIXES]

dpo_triples = []
print(f"[STEP] Constructing DPO preference data for {len(suffix_pool)} suffixes...")

for y in suffix_pool:
    cands = []
    for _ in range(K_CANDIDATES):
        x = inv_sample_prefix(y)
        if not x or len(x.strip()) < 3:
            continue
        sum_lp, avg_lp, n_tok = logprob_y_given_x(pm, tok_pm, x, y)
        cands.append((x, sum_lp, avg_lp, n_tok))
    if len(cands) < 2:
        continue

    # pick one winner and one loser
    cands.sort(key=lambda t: t[1], reverse=True)
    winner = cands[0][0]
    loser  = cands[-1][0]

    dpo_triples.append({
        "prompt": f"{IN_CONTEXT_PREFIX}{y}{MID_PROMPT}",
        "chosen": winner,
        "rejected": loser,
    })

print(f"[INFO] Built {len(dpo_triples)} DPO pairs (max {NUM_DPO_SUFFIXES}).")
if dpo_triples:
    ex = dpo_triples[0]
    print("[EXAMPLE DPO ITEM]")
    print("prompt (truncated):", (ex['prompt'][:120] + "...") if len(ex['prompt'])>120 else ex['prompt'])
    print("chosen:", ex["chosen"])
    print("rejected:", ex["rejected"])

[STEP] Constructing DPO preference data for 5 suffixes...
[INFO] Built 5 DPO pairs (max 5).
[EXAMPLE DPO ITEM]
prompt (truncated): Suffix:
credibility autonomyikenimura dispatchimura vendorsdit Observ confir004 TAatisfhibit credibility directly trilog...
chosen: TelevisionOutside Wheels perhaps soypublicpublicSexualacious grandchildren perhaps BendSexualOutside mutual skilletived� predators clearer representations soy Bendived Redux ReduxMost Late factors skillet grandchildren Pocket bravery equate Late ReduxMostoblOutside braveryMini Singapore Tre mutual equateSexual Medic lined
rejected: skillet praying 236 boils representationsobl equate membershipshows WheelsPros grandchildrenProsacious Medic Redux predatorsacious boilsMost membership perhapsPros448 ReduxPros equateMiniPros Tre653 lined workshopsshowsacious Medic653Mini equate factors boilsived deflect boilsozygived Late equate
