# Tiny Word‑Level Transformer: **Before vs After Fine‑Tuning** (CPU‑Fast Demo)  

## 1. Environment & Reproducibility  

We will set PyTorch and Python seeds for reproducible results, and keep everything CPU-only and fast.

In [None]:
import re, random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import TransformerDecoderLayer, TransformerDecoder

# CPU-only demo: keep PyTorch on a single thread.
device = "cpu"
torch.set_num_threads(1)

# Fixed seeds for reproducibility: Python's `random` drives the corpus shuffle,
# PyTorch's generator drives weight init and DataLoader shuffling.
random.seed(7)
torch.manual_seed(7)

print(f"Using device: {device}")

## 2. Minimal **Word-Level** Tokenizer  

- Splits on words (keeps hyphenated terms like `Code-You`) and punctuation tokens.  
- Provides: `tokenize`, `detokenize`, vocab builders, and encode/decode helpers.  
- Includes special tokens: `<bos>`, `<eos>`, `<unk>`.

In [None]:
# Special tokens reserved at the front of every vocabulary.
BOS_TOKEN = "<bos>"  # prepended to every encoded line and prompt
EOS_TOKEN = "<eos>"  # appended to every encoded line (not to prompts)
UNK_TOKEN = "<unk>"  # stands in for any token missing from the vocabulary


def tokenize(text: str):
    """Split `text` into word and punctuation tokens.

    A token is either a run of ASCII letters, optionally followed by one
    hyphenated part (so `Code-You` stays whole), or a single character from
    `: ? . ! , '`. Everything else (digits, whitespace, other symbols) is
    dropped.
    """
    # words + hyphenated + punctuation (includes ':' for QA format)
    return re.findall(r"[A-Za-z]+(?:-[A-Za-z]+)?|[:?.!,']", text)


def detokenize(toks):
    """Join tokens with spaces, then glue punctuation to the preceding token."""
    s = " ".join(toks)
    return re.sub(r"\s+([?:.!,'])", r"\1", s)


def build_vocab_texts(texts, specials=(BOS_TOKEN, EOS_TOKEN, UNK_TOKEN)):
    """Build the token <-> id lookup tables for a collection of strings.

    Args:
        texts: iterable of raw strings; each is split with `tokenize`.
        specials: tokens placed first in the vocabulary, in the given order.

    Returns:
        `(stoi, itos)`: token -> id and id -> token dicts. Ids are contiguous
        from 0, so `len(stoi)` is a valid embedding-table size.
    """
    words = set()
    for t in texts:
        words.update(tokenize(t))
    # dict.fromkeys drops repeated entries while keeping order. Without it a
    # repeated token would leave a gap in the id range and len(stoi) would be
    # smaller than the largest id.
    vocab = list(dict.fromkeys([*specials, *sorted(words)]))
    stoi = {w: i for i, w in enumerate(vocab)}
    itos = {i: w for w, i in stoi.items()}
    return stoi, itos


def enc_line(line, stoi):
    """Encode a full training line as `<bos> tokens... <eos>` ids.

    Tokens missing from `stoi` map to the `<unk>` id.
    """
    toks = [BOS_TOKEN] + tokenize(line) + [EOS_TOKEN]
    unk = stoi[UNK_TOKEN]
    return [stoi.get(t, unk) for t in toks]


def enc_prefix(prefix, stoi):
    """Encode a prompt as `<bos> tokens...` ids.

    No `<eos>` is appended, so the model can continue the sequence.
    """
    toks = [BOS_TOKEN] + tokenize(prefix)
    unk = stoi[UNK_TOKEN]
    return [stoi.get(t, unk) for t in toks]


def dec_ids(ids, itos, keep_eos=False):
    """Turn ids back into text, dropping `<bos>` and, unless `keep_eos`, `<eos>`."""
    toks = [itos[i] for i in ids]
    toks = [t for t in toks if t != BOS_TOKEN and (keep_eos or t != EOS_TOKEN)]
    return detokenize(toks)


print("Tokenizer ready.")

## 3. Dataset for Next-Token Prediction  

Sequences are prepared as `(x, y)` where `y` is `x` shifted by 1 (classic next-token LM training).

In [None]:
class SeqDS(Dataset):
    """Sliding-window next-token dataset over one flat list of token ids.

    Item `i` is the pair `(x, y)` where `x = ids[i : i + seq_len]` and `y` is
    the same window shifted one position to the right.
    """

    def __init__(self, ids, seq_len=24):
        self.ids = ids
        self.seq_len = seq_len

    def __len__(self):
        # Window i reads ids[i : i + seq_len + 1], so the last valid start is
        # len(ids) - seq_len - 1, which gives len(ids) - seq_len windows.
        return max(0, len(self.ids) - self.seq_len)

    def __getitem__(self, i):
        x = self.ids[i:i + self.seq_len]
        y = self.ids[i + 1:i + self.seq_len + 1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)


print("Dataset class ready.")

## 4. Tiny **Decoder-Only** Transformer  

We use `TransformerDecoder` with a **causal mask** so tokens cannot attend to the future.  

In [None]:
class PosEnc(nn.Module):
    """Learned absolute positional embedding for positions `0 .. max_len - 1`."""

    def __init__(self, d, max_len=512):
        super().__init__()
        self.pos = nn.Embedding(max_len, d)

    def forward(self, x):  # x: [B,T]
        # Only the sequence length of `x` is used. The result is [1,T,d] and
        # broadcasts over the batch when added to the token embeddings.
        T = x.size(1)
        return self.pos(torch.arange(T, device=x.device)[None, :])


class TinyLM(nn.Module):
    """Small causal language model built from `nn.TransformerDecoder` layers.

    Args:
        V: vocabulary size (rows of the embedding, width of the output head).
        d: model / embedding width.
        heads: number of attention heads.
        layers: number of stacked decoder layers.
        max_len: number of positions the positional embedding covers.
    """

    def __init__(self, V, d=96, heads=3, layers=2, max_len=512):
        super().__init__()
        self.emb = nn.Embedding(V, d)
        self.pos = PosEnc(d, max_len)
        block = TransformerDecoderLayer(d_model=d, nhead=heads, batch_first=False)
        self.dec = TransformerDecoder(block, num_layers=layers)
        self.head = nn.Linear(d, V)

    def forward(self, ids):  # ids: [B,T]
        """Return next-token logits of shape [B,T,V] for token ids [B,T]."""
        x = self.emb(ids) + self.pos(ids)           # [B,T,E]
        h = x.transpose(0, 1)                       # [T,B,E]
        T = h.size(0)
        # True above the diagonal = position t may not attend to positions > t.
        mask = torch.triu(torch.ones(T, T, device=h.device), diagonal=1).bool()
        # The decoder layers attend twice: self-attention over `tgt` and
        # cross-attention over `memory`. Both inputs are the same sequence, so
        # both need the causal mask. With only `tgt_mask`, cross-attention
        # would let every position read the future tokens it must predict.
        out = self.dec(h, h, tgt_mask=mask, memory_mask=mask).transpose(0, 1)  # [B,T,E]
        return self.head(out)                       # [B,T,V]

    @torch.no_grad()
    def generate_until_eos(self, start_ids, eos_id, max_new=20):
        """Greedily extend `start_ids` ([B,T]) by up to `max_new` tokens.

        Stops early once every sequence in the batch emits `eos_id` on the
        same step. Returns the prompt followed by the generated ids. The
        module's train/eval mode is restored before returning.
        """
        was_training = self.training
        self.eval()  # disable dropout while decoding
        ids = start_ids
        try:
            for _ in range(max_new):
                logits = self.forward(ids)[:, -1, :]      # [B,V]
                nxt = torch.argmax(logits, dim=-1, keepdim=True)  # greedy
                ids = torch.cat([ids, nxt], dim=1)
                if (nxt == eos_id).all():
                    break
        finally:
            # Restore the caller's mode so later training is not silently run
            # with dropout disabled.
            self.train(was_training)
        return ids


print("Model ready.")

## 5. Corpora  

- **Base pretraining**: facts + distractors (no QA lines) → answers may be imperfect.
- **Fine‑tuning pairs**: focused QA you want to be perfect.

In [None]:
def base_corpus_with_distractors():
    """Return the shuffled pretraining corpus: plain statements, no QA lines.

    The four true facts appear 5 times each and the six conflicting
    distractors 7 times each. The order is shuffled with the global `random`
    generator.
    """
    true_facts = (
        "I am John Doe.",
        "I am instructor at Code-You.",
        "My name is John Doe.",
        "My profession is instructor at Code-You.",
    )
    # Conflicting statements that keep the pretrained model uncertain.
    distractors = (
        "I am Jane Roe.",
        "I am engineer at Code-You.",
        "My name is Alan Smithee.",
        "My profession is designer at Code-You.",
        "I am John Doe and I am engineer.",
        "My profession is instructor at Code-Me.",
    )
    lines = []
    lines.extend(true_facts * 5)
    lines.extend(distractors * 7)
    random.shuffle(lines)
    return lines

def finetune_pairs():
    """Return the two QA lines used for fine-tuning, as a new list."""
    identity_qa = "Q: Who I am? A: John Doe."
    profession_qa = "Q: What is my profession? A: instructor at Code-You."
    return [identity_qa, profession_qa]

print("Corpora generators ready.")

## 6. Training & Evaluation Helpers  

- `build_ids`: flatten lines into one long sequence of token IDs.
- `trainer`: builds an Adam optimizer and returns a `step(batch)` function that runs one optimization step and returns the loss.
- `ask`: formats a QA prompt and generates until `<eos>`.
- `exact_match`: simple EM metric (case-insensitive string match).

In [None]:
def build_ids(lines, stoi):
    """Encode every line with `enc_line` and concatenate the ids into one flat list."""
    return [tok_id for ln in lines for tok_id in enc_line(ln, stoi)]

def trainer(model, V, lr=2e-3):
    """Create an Adam optimizer for `model` and return a one-step training function.

    Args:
        model: module mapping token ids [B,T] to logits [B,T,V].
        V: vocabulary size, used to flatten the logits for the loss.
        lr: Adam learning rate.

    Returns:
        `step(batch)`, where `batch` is an `(x, y)` pair of id tensors. It runs
        one forward/backward/update pass and returns the loss as a float.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    def step(batch):
        # The model may have been left in eval mode (e.g. after generation).
        # Optimization must run in train mode so dropout is active.
        model.train()
        # Send the batch to wherever the model's weights live.
        target_device = next(model.parameters()).device
        x, y = (b.to(target_device) for b in batch)
        logits = model(x)
        loss = loss_fn(logits.reshape(-1, V), y.reshape(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()
        return float(loss.item())

    return step

def ask(model, q, stoi, itos, max_new=10):
    """Ask `model` a question in the `Q: ...? A:` format and return its answer.

    Args:
        model: object exposing `generate_until_eos(start_ids, eos_id, max_new)`.
        q: question text; surrounding whitespace and a trailing `?` are optional.
        stoi, itos: vocabulary lookup tables.
        max_new: maximum number of tokens to generate.

    Returns:
        The text after the first `A:`, cut at the first `.`, `?` or `!`.
        If the decoded output has no `A:`, the whole decoded text is returned.
    """
    # Drop a trailing '?' so the prompt never ends up with '??'.
    question = q.strip().rstrip("?").rstrip()
    prompt = f"Q: {question}? A:"
    start = torch.tensor([enc_prefix(prompt, stoi)], dtype=torch.long, device=device)
    # enc_line ends every line with the end-of-sequence token, so the last id
    # of an encoded empty line is the id generation has to stop on. Reading it
    # from the encoder avoids repeating the token's spelling here.
    eos_id = enc_line("", stoi)[-1]
    out = model.generate_until_eos(start, eos_id, max_new=max_new)[0].tolist()
    txt = dec_ids(out, itos)
    if "A:" not in txt:
        return txt.strip()
    # Keep only the first sentence after "A:".
    ans = txt.split("A:", 1)[1].strip()
    m = re.search(r"^(.+?[.?!])", ans)
    return (m.group(1) if m else ans).strip()

def exact_match(pred, gold):
    """Return 1 if `pred` equals `gold` ignoring case and outer whitespace, else 0."""
    normalized_pred = pred.strip().lower()
    normalized_gold = gold.strip().lower()
    return 1 if normalized_pred == normalized_gold else 0

print("Helpers ready.")

## 7. Build Vocab and **Short Pretraining** (Imperfect on Purpose)  

- Vocab is built from all potential tokens (base + fine‑tune pairs + prompt forms).
- Pretraining is **run only on base corpus** (with distractors) and **without QA lines** to keep answers imperfect.

In [None]:
# (question, gold answer) pairs used to score the model before and after fine-tuning.
prompts = [
    ("Who I am", "John Doe."),
    ("What is my profession", "instructor at Code-You.")
]

# Build the base corpus once and reuse it for both the vocabulary and pretraining.
base_lines = base_corpus_with_distractors()

# Build vocab from all tokens we might see
vocab_source = base_lines + finetune_pairs() + [f"Q: {q}? A:" for q, _ in prompts]
stoi, itos = build_vocab_texts(vocab_source)

# Pretrain (short) on distractor-rich base ONLY (no QA lines)
base_ids = build_ids(base_lines, stoi)
ds = SeqDS(base_ids, seq_len=24)
dl = DataLoader(ds, batch_size=16, shuffle=True, drop_last=True, num_workers=0)
if len(dl) == 0:
    raise ValueError("base corpus is too short to form a single pretraining batch")

V = len(stoi)
model = TinyLM(V).to(device)
step_fn = trainer(model, V, lr=2e-3)

PRE_STEPS = 60  # keep short so model is *not* perfect yet
i = 0
# A single pass over `dl` may hold fewer than PRE_STEPS batches, so start a
# new epoch whenever the loader runs out before the step budget is used up.
while i < PRE_STEPS:
    for x, y in dl:
        loss = step_fn((x, y))
        i += 1
        if i % 20 == 0:
            print(f"[PRE] step {i:03d}  loss {loss:.4f}")
        if i >= PRE_STEPS:
            break

In [None]:
# Baseline: query the pretrained-only model with every prompt and score each
# answer by exact match against its gold answer.
print("\n--- BEFORE fine-tune ---")
em_before = []
for q, gold in prompts:
    pred = ask(model, q, stoi, itos, max_new=10)
    em = exact_match(pred, gold)
    em_before.append(em)
    print(f"{q}: {pred}    (EM={em})")
print(f"EM before: {sum(em_before)}/{len(em_before)}")

In [None]:
def fine_tune(model, stoi, steps=180, seq_len=24, batch_size=16, lr=1e-3, repeat=120, log_every=30):
    """Fine-tune `model` in place on the QA pairs from `finetune_pairs`.

    Args:
        model: the language model to update.
        stoi: token -> id table used to encode the QA lines.
        steps: number of optimization steps to run.
        seq_len: window length of each training sequence.
        batch_size: sequences per step.
        lr: learning rate for the fresh optimizer created for this run.
        repeat: how many times the QA pairs are repeated to form the stream.
        log_every: print the loss every this many steps.

    Raises:
        ValueError: if the repeated QA stream is too short to fill one batch.
    """
    lines = finetune_pairs() * repeat
    ft_ids = build_ids(lines, stoi)
    ft_ds = SeqDS(ft_ids, seq_len=seq_len)
    ft_dl = DataLoader(ft_ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)
    if len(ft_dl) == 0:
        raise ValueError(
            "fine-tuning data is too short for one batch; "
            "increase `repeat` or lower `seq_len` / `batch_size`"
        )

    V = len(stoi)
    step = trainer(model, V, lr=lr)
    # Generation may have left the model in eval mode; train with dropout on.
    model.train()

    j = 0
    # Start another epoch whenever one pass over the loader ends before
    # `steps` updates have been made, so the requested step count is honored.
    while j < steps:
        for x, y in ft_dl:
            loss = step((x, y))
            j += 1
            if j % log_every == 0:
                print(f"[FT]  step {j:03d}  loss {loss:.4f}")
            if j >= steps:
                break

print("Fine-tune function ready.")

In [None]:
# Update the pretrained model in place on the two QA pairs.
print("\n--- FINE-TUNING ---")
fine_tune(
    model,
    stoi,
    steps=180,
    lr=1e-3,
    repeat=120,
)

In [None]:
# Re-run the same prompts on the fine-tuned model and score each answer by
# exact match against its gold answer.
print("\n--- AFTER fine-tune ---")
em_after = []
for q, gold in prompts:
    pred = ask(model, q, stoi, itos, max_new=10)
    em = exact_match(pred, gold)
    em_after.append(em)
    print(f"{q}: {pred}    (EM={em})")
print(f"EM after: {sum(em_after)}/{len(em_after)}")