### Step 1: Install necesscary packages

In [1]:
!pip install matplotlib
!pip install torch numpy transformers datasets tiktoken wandb tqdm

Collecting torch
  Downloading torch-2.9.0-cp312-none-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting transformers
  Downloading transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
Collecting datasets
  Downloading datasets-4.2.0-py3-none-any.whl.metadata (18 kB)
Collecting tiktoken
  Downloading tiktoken-0.12.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.7 kB)
Collecting wandb
  Downloading wandb-0.22.2-py3-none-macosx_12_0_arm64.whl.metadata (10 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers)
  Downloading huggingface_hub-0.35.3-py3-none-any.whl.metadata (14 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers)
  Downloading tokenizers-0.22.1-cp39-abi3-macosx_11_0_arm64.whl.metadata (6.8 kB)
Collecting safetensors>=0.4.3 (from transformers)
  Downloading safetensors-0.6.2-cp38-abi3-macosx_11_0_arm64.whl.metadata (4.1 kB)
Collecting pyarrow>=21.0.0 (from da

In [3]:
# Bring protobuf and rich back into Streamlit’s requested ranges
!pip install "protobuf<6" "rich<14"


Collecting protobuf<6
  Downloading protobuf-5.29.5-cp38-abi3-macosx_10_9_universal2.whl.metadata (592 bytes)
Collecting rich<14
  Using cached rich-13.9.4-py3-none-any.whl.metadata (18 kB)
Downloading protobuf-5.29.5-cp38-abi3-macosx_10_9_universal2.whl (418 kB)
Using cached rich-13.9.4-py3-none-any.whl (242 kB)
Installing collected packages: protobuf, rich
  Attempting uninstall: protobuf
    Found existing installation: protobuf 6.32.1
    Uninstalling protobuf-6.32.1:
      Successfully uninstalled protobuf-6.32.1
  Attempting uninstall: rich
    Found existing installation: rich 14.1.0
    Uninstalling rich-14.1.0:
      Successfully uninstalled rich-14.1.0
Successfully installed protobuf-5.29.5 rich-13.9.4


### Step 2: Package imports and configuration

In [4]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
# Configuration
beta = 0.5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_lr = 1e-4
epochs = 5
batch_size = 64
max_length =64
num_samples = 1
max_new_tokens = 200
temperature = 0.8
top_k = 200
# tokenizer
with open("../sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
def encode(s): return [stoi[c] for c in s]
def decode(l): return ''.join([itos[i] for i in l])

### Step 3: Define helper functions

In [5]:
def compute_logprob(input_ids):
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = gpt(inputs, full_seq=True)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    return -loss 

def pad_or_truncate(seq, max_length):
    return seq[-max_length:] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

def get_batches(lines, batch_size):
    random.shuffle(lines)
    #for l in lines:
    #    print(l[1])
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        if len(batch) < batch_size:
            continue
        neg_inputs = [pad_or_truncate(encode(p['negative'] + '\n\n\n\n'), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive'] + '\n\n\n\n'), max_length) for p in batch]
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        yield neg_tensor, pos_tensor

### Step 4: Load the pretrained NanoGPT model

In [7]:
ckpt = torch.load("../sft/gpt.pt", map_location=device)
gptconf = GPTConfig(**ckpt['model_args'])
gpt = GPT(gptconf)
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
gpt.to(device).train()

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(74, 348)
    (wpe): Embedding(256, 348)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=348, out_features=1044, bias=False)
          (c_proj): Linear(in_features=348, out_features=348, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=348, out_features=1392, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1392, out_features=348, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=348, out_features=74, bias=False)
)

### Step 5: Load Data (**students are required to complete this part!**)

In [8]:
import os, json, random

def build_dataset(n=100000, out_path="./data/pos_neg_pairs.json"):
    pairs = []
    for _ in range(n):
        a, b = random.randint(1,100), random.randint(1,100)
        op = random.choice(["+", "-", "*"])
        if op == "+":
            ans = a+b; reason = f"{a}+{b} equals {ans}"
        elif op == "-":
            ans = a-b; reason = f"{a}-{b} equals {ans}"
        else:
            ans = a*b; reason = f"{a}*{b} equals {ans}"
        q = f"{a}{op}{b}, x=?"
        pos = f"{q} The answer is {ans} because {reason}."
        neg = f"{q} Sorry, I do not know!"
        pairs.append({"negative": neg, "positive": pos})

    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(pairs, f, indent=2)
    print(f"Saved {len(pairs)} pairs to {out_path}")

### Step 6: Build the optimizer and scheduler (**students are required to complete this part!**)

In [9]:
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

#optimizer (AdamW)
optimizer = AdamW(gpt.parameters(), lr=1e-5, weight_decay=0.01)

#scheduler
num_training_steps = 1000  # for example
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=num_training_steps
)

### Step 7: Begin training (**students are required to complete this part!**)

In [28]:
CKPT_PATH = "./dpo.pt"   # <-- use this everywhere


In [13]:
import json, itertools, torch

# ---------- 1) Load the raw JSON ----------
with open("./pos_neg_pairs.json", "r") as f:
    raw = json.load(f)

# Some files wrap the list, e.g. {"pairs":[...]} / {"data":[...]}
if isinstance(raw, dict):
    for k in ("pairs","data","examples","items"):
        if k in raw and isinstance(raw[k], list):
            pairs = raw[k]
            break
    else:
        raise ValueError("Couldn't find a list of pairs inside the JSON dict.")
elif isinstance(raw, list):
    pairs = raw
else:
    raise ValueError("JSON must be a list or a dict containing a list.")

print(f"Loaded {len(pairs)} preference pairs (raw).")

# ---------- 2) Figure out the schema ----------
def extract_pair(p):
    """
    Return (neg_text, pos_text) as strings from one item p, trying several common schemas.
    If a 'prompt' exists, it will be prefixed to both completions.
    """
    # Helper to optionally prepend prompt/context
    def join_prompt(prompt, ans):
        if prompt:
            # keep a simple, consistent format for math/QA
            return f"{prompt.strip()}\n{ans.strip()}"
        return ans.strip()

    # Case A: dict-like with various key names
    if isinstance(p, dict):
        d = {k.lower(): v for k, v in p.items()}

        # Candidate key pairs in order of likelihood
        key_pairs = [
            ("neg","pos"),
            ("negative","positive"),
            ("rejected","chosen"),
            ("dispreferred","preferred"),
            ("bad","good"),
            ("worse","better"),
            ("lose","win"),
            ("neg_resp","pos_resp"),
            ("neg_text","pos_text"),
            ("completion_b","completion_a"),  # some datasets name A=preferred
        ]

        prompt = d.get("prompt") or d.get("question") or d.get("input") or d.get("context") or ""

        for neg_k, pos_k in key_pairs:
            if neg_k in d and pos_k in d:
                neg_text = join_prompt(prompt, str(d[neg_k]))
                pos_text = join_prompt(prompt, str(d[pos_k]))
                return neg_text, pos_text

        # Another common form: have "chosen"/"rejected"
        if "chosen" in d and "rejected" in d:
            neg_text = join_prompt(prompt, str(d["rejected"]))
            pos_text = join_prompt(prompt, str(d["chosen"]))
            return neg_text, pos_text

        # Another: have "pos"/"neg" but plus a separate "answer" keys
        if "pos" in d and "neg" in d:
            neg_text = join_prompt(prompt, str(d["neg"]))
            pos_text = join_prompt(prompt, str(d["pos"]))
            return neg_text, pos_text

        # If there are only two text fields but weird names, try to guess:
        # look for the two longest string fields as candidates
        text_fields = [(k,v) for k,v in d.items() if isinstance(v, str)]
        if len(text_fields) >= 2:
            # heuristic: prefer keys hinting pos/neg if present
            name = " ".join(d.keys()).lower()
            # fallback: pick two and assume the first is negative if its key looks like neg/bad/reject
            # else we’ll try length heuristic
            neg_like = [k for k,_ in text_fields if any(tag in k for tag in ["neg","bad","reject","worse","lose","dispref"])]
            pos_like = [k for k,_ in text_fields if any(tag in k for tag in ["pos","good","chosen","better","win","pref"])]
            if neg_like and pos_like:
                neg_text = join_prompt(prompt, str(d[neg_like[0]]))
                pos_text = join_prompt(prompt, str(d[pos_like[0]]))
                return neg_text, pos_text
            # length heuristic
            text_fields_sorted = sorted(text_fields, key=lambda kv: -len(kv[1]))
            pos_text = join_prompt(prompt, text_fields_sorted[0][1])
            neg_text = join_prompt(prompt, text_fields_sorted[1][1])
            return neg_text, pos_text

        raise KeyError("Unrecognized dict schema for a pair item.")

    # Case B: list/tuple of two strings
    if isinstance(p, (list, tuple)) and len(p) >= 2 and all(isinstance(x, str) for x in p[:2]):
        # Heuristic: assume index 0 = positive if marker present, else index 1 = positive.
        # To be safe for DPO, we’ll assume [pos, neg]; flip if you know your file is [neg, pos].
        pos_text, neg_text = p[0].strip(), p[1].strip()
        return neg_text, pos_text

    raise KeyError("Unrecognized pair item type.")

# ---------- 3) Build encoded dataset ----------
# Require the same tokenizer mapping (stoi/itos) you loaded earlier from meta.pkl
def encode_text(txt):
    return torch


Loaded 100000 preference pairs (raw).


In [25]:
import torch
import torch.nn.functional as F

PAD_ID = 0  # you padded with 0 above; keep consistent

def make_shifted_targets(x, lengths):
    """
    x: (B, T) input ids (already padded with PAD_ID)
    lengths: list/1D tensor of true lengths for each row (without right pad)
    Returns y with shape (B, T) where:
      y[:, :-1] = x[:, 1:], y[:, -1] = PAD_ID, and
      positions >= (length-1) are set to -100 (ignore_index) so CE ignores them.
    """
    B, T = x.shape
    y = x.clone()
    y[:, :-1] = x[:, 1:]
    y[:, -1] = PAD_ID
    y = y.long()
    # mask everything at/after each (length-1)
    y_mask = torch.zeros_like(y).bool()
    for i, L in enumerate(lengths):
        cut = max(int(L) - 1, 0)
        if cut < T:
            y_mask[i, cut:] = True
    y[y_mask] = -100  # ignore_index
    return y

def sequence_mean_logprob(model, x, lengths):
    """
    Compute mean log-prob per example, ignoring padding (via -100 targets).
    Returns tensor of shape (B,) with average logprob per token for each sequence.
    """
    y = make_shifted_targets(x, lengths)             # (B, T) with -100 masked
    logits, _ = model(x, y=None)                     # logits: (B, T, V)
    logprobs = F.log_softmax(logits, dim=-1)         # (B, T, V)
    # gather at target ids (valid positions have y>=0; ignored are -100)
    B, T = y.shape
    gather_y = y.clone()
    gather_y[gather_y < 0] = 0                       # placeholder index for ignored
    token_lp = logprobs.gather(2, gather_y.unsqueeze(-1)).squeeze(-1)  # (B, T)
    # zero out ignored positions
    mask = (y != -100).float()
    # sum per sequence and divide by number of valid tokens
    seq_sum = (token_lp * mask).sum(dim=1)           # (B,)
    seq_cnt = mask.sum(dim=1).clamp_min(1.0)         # avoid /0
    return seq_sum / seq_cnt                         # (B,)


In [None]:
import torch.nn.functional as F
from tqdm import tqdm

beta = 0.1   # DPO temperature (adjustable)
total_steps = len(lines) // batch_size

for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = 0   # <-- add this line

    pbar = tqdm(get_batches(lines, batch_size))
    for step, (neg_tensor, pos_tensor) in enumerate(pbar):
        # ---------------------------------------------------------
        # 1. Move batches to device
        # ---------------------------------------------------------
        neg_tensor = neg_tensor.to(device)
        pos_tensor = pos_tensor.to(device)

        # ---------------------------------------------------------
        # 2. Compute log-probs (mean log-likelihoods) for both
        # ---------------------------------------------------------
        with torch.autocast(device_type=device.type,
                            dtype=torch.float16 if device.type != "cpu" else torch.float32):
            pos_logits, pos_loss = gpt(pos_tensor)
            neg_logits, neg_loss = gpt(neg_tensor)

            pos_logprob = -pos_loss
            neg_logprob = -neg_loss

            # -----------------------------------------------------
            # 3. DPO objective
            # -----------------------------------------------------
            loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean()

        # ---------------------------------------------------------
        # 4. Backprop + step
        # ---------------------------------------------------------
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1   # <-- increment count
        pbar.set_description(
            f"Epoch {epoch+1} | Step {step+1}/{total_steps} | Loss {loss.item():.4f}"
        )

    # -------------------------------------------------------------
    # 5. End-of-epoch reporting & checkpoint
    # -------------------------------------------------------------
    mean_loss = epoch_loss / num_batches if num_batches > 0 else float('nan')
    print(f"Epoch {epoch+1} mean loss: {mean_loss:.4f}")

    ckpt_path = f"./dpo.pt"
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": ckpt['model_args'],
    }, ckpt_path)
    print(f"✅ Saved checkpoint to {ckpt_path}")


NameError: name 'lr' is not defined

### Step 8: Begin testing (**students are required to complete this part!**)

In [55]:
import pickle, torch, re

# device = torch.device(
#     "cuda" if torch.cuda.is_available()
#     else "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
#     else "cpu"
# )
device = torch.device("cpu")
gpt = gpt.to(device)

print("Using device:", device)

with open("../sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]

PAD_ID = 0  # we pad with 0 in batching

import torch

def encode(text: str) -> torch.Tensor:
    # return 1D ids (we'll add batch dim in generate)
    ids = [stoi.get(ch, 0) for ch in text]
    return torch.tensor(ids, dtype=torch.long)

@torch.no_grad()
def generate(model, prompt: str, max_new_tokens=64, temperature=1.0, top_k=1):
    # 1) encode
    x = encode(prompt)
    # 2) ensure 2D: (1, T)
    if x.dim() == 1:
        x = x.unsqueeze(0)
    # 3) trim to block_size if needed
    try:
        block_size = model.config.block_size
        if x.size(1) > block_size:
            x = x[:, -block_size:]
    except Exception:
        pass
    # 4) move to device
    x = x.to(device)

    # normalize sampler args for your model's generate()
    if top_k is not None and top_k <= 0:
        top_k = 1
    if temperature is None or temperature <= 0:
        temperature = 1e-5

    # 5) generate
    out = model.generate(
        x, max_new_tokens=max_new_tokens,
        temperature=temperature, top_k=top_k
    )
    if isinstance(out, tuple):
        out = out[0]
    return decode(out)


def _to_list(ids):
    if isinstance(ids, torch.Tensor): ids = ids.detach().cpu().tolist()
    out = []
    def flat(x):
        if isinstance(x, (list, tuple)):
            for y in x: flat(y)
        else:
            out.append(int(x))
    flat(ids)
    return out

def decode(ids) -> str:
    ids = _to_list(ids)
    if isinstance(itos, dict):
        return "".join(itos.get(i, "") for i in ids)
    L = len(itos)
    return "".join(itos[i] if 0 <= i < L else "" for i in ids)

def parse_last_number(text: str):
    nums = re.findall(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?", text)
    return float(nums[-1]) if nums else None




Using device: cpu


In [44]:
import json, os

with open("./pos_neg_pairs.json", "r") as f:
    raw = json.load(f)

pairs_list = raw["pairs"] if isinstance(raw, dict) and isinstance(raw.get("pairs"), list) else raw
print("Raw pairs:", len(pairs_list))

def extract_pair(p):
    # Returns (NEG_text, POS_text). Try common schemas; adjust if needed.
    if isinstance(p, dict):
        d = {k.lower(): v for k,v in p.items()}
        prompt = d.get("prompt") or d.get("question") or d.get("input") or ""
        def join(pr, ans): return (pr.strip()+"\n"+str(ans).strip()) if pr else str(ans).strip()

        candidates = [
            ("neg","pos"),
            ("negative","positive"),
            ("rejected","chosen"),
            ("dispreferred","preferred"),
            ("bad","good"),
            ("completion_b","completion_a"),
        ]
        for nk, pk in candidates:
            if nk in d and pk in d:
                return join(prompt, d[nk]), join(prompt, d[pk])
        if "chosen" in d and "rejected" in d:
            return join(prompt, d["rejected"]), join(prompt, d["chosen"])
    if isinstance(p, (list, tuple)) and len(p) >= 2 and all(isinstance(x, str) for x in p[:2]):
        # assume [POS, NEG] is common; flip to (NEG, POS)
        return p[1], p[0]
    raise KeyError("Unrecognized pair format")

# Build encoded lines (NEG_ids, POS_ids)
lines = []
skipped = 0
for i, p in enumerate(pairs_list):
    try:
        neg_txt, pos_txt = extract_pair(p)
        lines.append((encode(neg_txt), encode(pos_txt)))
    except Exception as e:
        skipped += 1
print(f"Usable pairs: {len(lines)}  | Skipped: {skipped}")

# >>> VERY IMPORTANT: inspect a few samples <<<
for i in range(2):
    n, p = lines[i]
    print("\n--- Sample", i, "---")
    print("NEG:", decode(n[:120]))
    print("POS:", decode(p[:120]))


Raw pairs: 100000
Usable pairs: 100000  | Skipped: 0

--- Sample 0 ---
NEG: 95-62, x=? Sorry, I do not know

POS: 95-62, x=? The answer is 33 because 95-62 equals 33.

--- Sample 1 ---
NEG: 41*42, x=? Sorry, I do not know

POS: 41*42, x=? The answer is 1722 because 41*42 equals 1722.


In [45]:
import torch.nn.utils.rnn as rnn

def get_batches(lines, batch_size):
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        negs = [n for n, _ in batch]
        poss = [p for _, p in batch]
        neg_len = torch.tensor([len(t) for t in negs], dtype=torch.long)
        pos_len = torch.tensor([len(t) for t in poss], dtype=torch.long)
        neg_pad = rnn.pad_sequence(negs, batch_first=True, padding_value=PAD_ID)
        pos_pad = rnn.pad_sequence(poss, batch_first=True, padding_value=PAD_ID)
        yield (neg_pad, neg_len), (pos_pad, pos_len)


In [53]:
import torch
import torch.nn.functional as F

PAD_ID = 0  # keep consistent with your padding

# def make_targets_shifted(x, lengths):
#     """
#     Shift inputs to create next-token targets and mask padding with -100.
#     x: (B, T) token ids
#     lengths: (B,) true lengths before right padding
#     return: y (B, T) with -100 at ignored positions
#     """
#     y = x.clone()
#     y[:, :-1] = x[:, 1:]
#     y[:, -1] = PAD_ID
#     B, T = x.shape
#     mask = torch.zeros_like(y, dtype=torch.bool)
#     for i, L in enumerate(lengths.tolist()):
#         cut = max(int(L) - 1, 0)
#         if cut < T:
#             mask[i, cut:] = True
#     y[mask] = -100  # ignore_index for CrossEntropy
#     return y

# def mean_logprob_per_seq(model, x, lengths):
#     """
#     One forward pass over the whole sequence.
#     Works with NanoGPT-style forward that returns (logits, loss) when given targets POSITIONALLY.
#     """
#     # 1) build masked targets
#     y = make_targets_shifted(x, lengths)              # (B, T), -100 where ignored

#     # 2) forward: IMPORTANT — pass targets POSITIONALLY, not as keyword
#     # Expected: logits shape (B, T, V)
#     out = model(x, y)                                 # <-- key change
#     logits = out[0] if isinstance(out, (tuple, list)) else out

#     if logits.dim() != 3:
#         raise RuntimeError(f"Expected logits of shape (B, T, V), got {tuple(logits.shape)}")

#     # 3) compute token log-probs and average per sequence (ignore -100)
#     logp = F.log_softmax(logits, dim=-1)              # (B, T, V)

#     y_gather = y.clone()
#     y_gather[y_gather < 0] = 0                        # safe index for ignored
#     token_lp = logp.gather(2, y_gather.unsqueeze(-1)).squeeze(-1)  # (B, T)

#     valid = (y != -100).float()
#     seq_sum = (token_lp * valid).sum(dim=1)           # (B,)
#     seq_cnt = valid.sum(dim=1).clamp_min(1.0)
#     return seq_sum / seq_cnt                          # (B,)


def mean_logprob_per_seq(model, x, lengths):
    """
    Works when model only returns last-step logits.
    Average per-token log-prob via teacher forcing:
      for t in [0..T-2], run model on x[:, :t+1] and score next token x[:, t+1].
    Ignores padding using 'lengths'.
    Returns: (B,) of mean log-prob per valid token.
    """
    B, T = x.shape
    device = x.device
    total_logp = torch.zeros(B, device=device)
    total_cnt  = torch.zeros(B, device=device)

    # DO NOT use autocast here on MPS (saves memory)
    for t in range(T - 1):
        # rows that still have a next token at t+1
        valid_mask = (lengths > (t + 1))  # (B,)
        if not bool(valid_mask.any()):
            break

        prefix  = x[:, :t+1]          # (B, t+1)
        next_id = x[:, t+1]           # (B,)

        out = model(prefix)           # logits for last step only
        logits = out[0] if isinstance(out, (tuple, list)) else out
        if logits.dim() == 3:         # (B,1,V) -> (B,V)
            logits = logits[:, -1, :]

        logp = F.log_softmax(logits, dim=-1)            # (B,V)
        lp_next = logp.gather(1, next_id.unsqueeze(1)).squeeze(1)  # (B,)

        total_logp = total_logp + lp_next * valid_mask.float()
        total_cnt  = total_cnt  + valid_mask.float()

    return total_logp / total_cnt.clamp_min(1.0)



In [50]:
# small batch if you still see OOM; try 16 first
batch_size = min(batch_size, 16)  

for epoch in range(epochs):
    epoch_loss, num_batches = 0.0, 0
    pbar = tqdm(get_batches(lines, batch_size), total=(len(lines)+batch_size-1)//batch_size)
    gpt.train()

    for (neg_pad, neg_len), (pos_pad, pos_len) in pbar:
        neg_pad = neg_pad.to(device);  neg_len = neg_len.to(device)
        pos_pad = pos_pad.to(device);  pos_len = pos_len.to(device)

        # IMPORTANT: no autocast on MPS to reduce memory pressure
        neg_lp = mean_logprob_per_seq(gpt, neg_pad, neg_len)   # (B,)
        pos_lp = mean_logprob_per_seq(gpt, pos_pad, pos_len)   # (B,)

        loss = -F.logsigmoid((pos_lp - neg_lp) / beta).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item(); num_batches += 1
        pbar.set_description(f"Epoch {epoch+1} | Loss {loss.item():.4f}")

    print(f"Epoch {epoch+1} mean loss: {epoch_loss / max(num_batches,1):.4f}")

    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": checkpoint['model_args'],
    }, CKPT_PATH)
    print(f"✅ Saved to {CKPT_PATH}")


Epoch 1 | Loss 0.0000: 100%|██████████| 6250/6250 [10:35:05<00:00,  6.10s/it] 

Epoch 1 mean loss: 0.0064
✅ Saved to ./dpo.pt





In [37]:
from tqdm import tqdm

beta = 0.1
epochs = 1           # start with 1 to sanity-check; increase later
batch_size = 64      # whatever you used earlier
lr = 3e-4            # example; keep your original

optimizer = torch.optim.AdamW(gpt.parameters(), lr=lr)

total_steps = (len(lines) + batch_size - 1) // batch_size
for epoch in range(epochs):
    gpt.train()
    running, batches = 0.0, 0
    pbar = tqdm(get_batches(lines, batch_size), total=total_steps)
    for (neg_pad, neg_len), (pos_pad, pos_len) in pbar:
        neg_pad = neg_pad.to(device);  neg_len = neg_len.to(device)
        pos_pad = pos_pad.to(device);  pos_len = pos_len.to(device)

        with torch.autocast(device_type=device.type,
                    dtype=torch.float16 if device.type != "cpu" else torch.float32):
                neg_lp = mean_logprob_per_seq(gpt, neg_pad, neg_len)   # (B,)
                pos_lp = mean_logprob_per_seq(gpt, pos_pad, pos_len)   # (B,)
                loss   = -F.logsigmoid((pos_lp - neg_lp) / beta).mean()


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running += loss.item(); batches += 1
        pbar.set_description(f"Epoch {epoch+1} | Loss {loss.item():.4f}")

    print(f"Epoch {epoch+1} mean loss: {running / max(batches,1):.4f}")

    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": checkpoint['model_args'],
    }, CKPT_PATH)
    print(f"✅ Saved to {CKPT_PATH}")


  0%|          | 0/1563 [00:24<?, ?it/s]


RuntimeError: MPS backend out of memory (MPS allocated: 9.06 GiB, other allocations: 14.69 MiB, max allowed: 9.07 GiB). Tried to allocate 5.95 MiB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [57]:
import re

def first_int_after_answer(text: str):
    # Take only the substring after "Answer"
    m = re.search(r'answer.*?:\s*(.*)', text, flags=re.IGNORECASE|re.DOTALL)
    tail = m.group(1) if m else text
    m2 = re.search(r'[-+]?\d+', tail)  # first integer sequence
    return int(m2.group()) if m2 else None


In [59]:
import torch
import torch.nn.functional as F

# Build the allowed charset -> token ids once
ALLOWED_CHARS = "0123456789-"
ALLOWED_IDS = [stoi.get(ch, None) for ch in ALLOWED_CHARS]
ALLOWED_IDS = [i for i in ALLOWED_IDS if i is not None]
assert len(ALLOWED_IDS) > 0, "Tokenizer has no digit tokens?!"

def generate_number_only(model, prompt: str, max_digits: int = 6):
    """
    Greedy decode but only allow digits (and optional '-').
    Stops when it produced at least 1 digit and then hits a non-digit next, or max_digits reached.
    """
    # 1) encode and batchify
    x = torch.tensor([stoi.get(ch, 0) for ch in prompt], dtype=torch.long)
    if x.dim() == 1: x = x.unsqueeze(0)                # (1, T)
    # 2) trim to block_size if necessary
    try:
        block = model.config.block_size
        if x.size(1) > block:
            x = x[:, -block:]
    except Exception:
        pass
    x = x.to(device)

    produced = []
    for t in range(max_digits):
        # forward; many tiny GPTs return last-step logits only
        out = model(x)
        logits = out[0] if isinstance(out, (tuple, list)) else out   # (1, V) or (1,1,V)
        if logits.dim() == 3:
            logits = logits[:, -1, :]  # (1, V)

        # mask everything except allowed ids
        mask = torch.full_like(logits, float("-inf"))
        mask[:, ALLOWED_IDS] = 0.0
        masked = logits + mask

        # greedy pick
        next_id = masked.argmax(dim=-1)  # (1,)
        nid = next_id.item()

        # append
        produced.append(nid)

        # append to input (incremental generation)
        x = torch.cat([x, next_id.view(1, 1)], dim=1)

        # simple stopping: if we already have ≥1 digit and model tries to repeat '-' as 2nd char, ignore;
        # we’ll just rely on max_digits to stop.
        # (You can also stop if the next best non-digit is far away; not needed here.)

    # decode just the produced part
    txt = "".join(itos[i] if isinstance(itos, list) else itos.get(i, "") for i in produced)
    # sanitize: keep leading '-' then digits
    import re
    m = re.match(r"^-?\d+", txt)
    return m.group(0) if m else ""


In [64]:
import re
import torch

FEWSHOT = (
    "3+4=? The answer is 7 because 3+4 equals 7.\n"
    "12*3=? The answer is 36 because 12*3 equals 36.\n"
)

def format_prompt(q: str) -> str:
    # match POS style exactly; we end with a trailing space
    return FEWSHOT + f"{q} The answer is "

@torch.no_grad()
def gen_and_parse(model, q, max_new_tokens=16, temperature=0.6, top_k=40):
    prompt = format_prompt(q)
    x = torch.tensor([stoi.get(ch,0) for ch in prompt], dtype=torch.long).unsqueeze(0).to(device)

    y = model.generate(x, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)
    if isinstance(y, tuple): y = y[0]
    txt = decode(y)

    # *** Parse ONLY the completion part ***
    comp = txt[len(prompt):]  # everything the model added after our prompt

    # optional: cut off at common stop tokens to avoid trailing junk
    for stop in [" because", "\n", ".", " Answer", "The answer is"]:
        idx = comp.find(stop)
        if idx > 0:
            comp = comp[:idx]
            break

    # now extract the first integer from the completion
    m = re.search(r"[-+]?\d+", comp)
    pred = int(m.group()) if m else None
    return prompt, comp, pred


In [65]:
# load the same ./dpo.pt as you just saved earlier (you already do this above)
FEWSHOT = (
    "3+4=? The answer is 7 because 3+4 equals 7.\n"
    "12*3=? The answer is 36 because 12*3 equals 36.\n"
)

def format_prompt(q: str) -> str:
    return FEWSHOT + f"{q} The answer is "



test_set = [
    ("17+19=?", 36),
    ("3*17=?", 51),
    ("72/4=?", 18),
    ("72-x=34, x=?", 38),
    ("x*11=44, x=?", 4),
]

correct = 0
print("\n=== EVALUATION (few-shot; parse completion only) ===")
for q, tgt in test_set:
    prompt, comp, pred = gen_and_parse(gpt, q, max_new_tokens=16, temperature=0.6, top_k=40)
    ok = (pred is not None) and (pred == int(tgt))
    correct += int(ok)
    print("Q:", q)
    print("Completion:", comp.strip().replace("\n"," ⏎ "))
    print("Parsed:", pred, "| Target:", tgt, "|", "✓" if ok else "✗")
    print("-"*70)
print(f"Accuracy: {correct}/{len(test_set)}")





=== EVALUATION (few-shot; parse completion only) ===
Q: 17+19=?
Completion: 4142462704982724
Parsed: 4142462704982724 | Target: 36 | ✗
----------------------------------------------------------------------
Q: 3*17=?
Completion: 2704228e7294=8+1
Parsed: 2704228 | Target: 51 | ✗
----------------------------------------------------------------------
Q: 72/4=?
Completion: 346444464478+664
Parsed: 346444464478 | Target: 18 | ✗
----------------------------------------------------------------------
Q: 72-x=34, x=?
Completion: 6447463544414=14
Parsed: 6447463544414 | Target: 38 | ✗
----------------------------------------------------------------------
Q: x*11=44, x=?
Completion: 37418487472-6949
Parsed: 37418487472 | Target: 4 | ✗
----------------------------------------------------------------------
Accuracy: 0/5


In [61]:
(neg_pad, neg_len), (pos_pad, pos_len) = next(get_batches(lines[:64], 8))
neg_pad, neg_len = neg_pad.to(device), neg_len.to(device)
pos_pad, pos_len = pos_pad.to(device), pos_len.to(device)
gpt.eval()
with torch.no_grad():
    d = (mean_logprob_per_seq(gpt, pos_pad, pos_len) - 
         mean_logprob_per_seq(gpt, neg_pad, neg_len)).mean().item()
print("Avg Δ (pos-neg) =", d)  # should be > 0


Avg Δ (pos-neg) = 2.746781349182129


did not work new try

Utilities: find common prefix + build completion masks

In [75]:
import torch
import torch.nn.functional as F

PAD_ID = 0  # keep consistent

def common_prefix_len(a: torch.Tensor, b: torch.Tensor):
    """
    a,b: 1D Long tensors
    Returns the length of the exact shared prefix.
    """
    L = min(a.numel(), b.numel())
    i = 0
    while i < L and int(a[i]) == int(b[i]):
        i += 1
    return i


def build_completion_targets(padded: torch.Tensor, true_lens: torch.Tensor, comp_starts: torch.Tensor):
    """
    y[:, t] = next token id to predict at t; -100 outside completion.
    """
    padded = padded.long()
    true_lens = true_lens.long()
    comp_starts = comp_starts.long()

    B, T = padded.shape
    y = padded.clone()
    y[:, :-1] = padded[:, 1:]
    y[:, -1] = PAD_ID

    mask = torch.zeros_like(y, dtype=torch.bool)
    for i in range(B):
        L = int(true_lens[i])
        start = max(int(comp_starts[i]) - 1, 0)  # shift by one because y predicts next
        end = max(L - 1, 0)                      # last valid prediction index
        # mask BEFORE start
        if start > 0: mask[i, :start] = True
        # mask AFTER end
        if end < T:  mask[i, end:] = True

    y[mask] = -100
    return y


    # Slow path: last-token logits → loop over time
    B, T = padded.shape
    device = padded.device
    total_lp = torch.zeros(B, device=device)
    total_cnt = torch.zeros(B, device=device)

    for t in range(T):
        y_t = y[:, t]
        valid = (y_t >= 0)  # only positions inside the completion span
        if not bool(valid.any()):
            continue

        prefix = padded[:, :t+1]
        out = model(prefix)
        logits = out[0] if isinstance(out, (tuple, list)) else out
        if logits.dim() == 3:
            logits = logits[:, -1, :]  # (B, V)

        logp = F.log_softmax(logits, dim=-1)  # (B, V)

        # >>> Gather only on valid rows to avoid indexing -100 <<<
        y_valid = y_t[valid].long().unsqueeze(1)       # (Bv, 1)
        lp_next = logp[valid].gather(1, y_valid).squeeze(1)  # (Bv,)

        total_lp[valid] += lp_next
        total_cnt[valid] += 1.0

    return total_lp / total_cnt.clamp_min(1.0)



In [76]:
# One mini-batch check
(neg_pad, neg_len, comp_start), (pos_pad, pos_len, comp_start2) = next(get_batches_with_compinfo(lines[:64], 8))
y_pos = build_completion_targets(pos_pad, pos_len, comp_start2)
assert (y_pos >= 0).any(dim=1).all(), "Found a sample with zero completion tokens."


Build batches that also carry completion start indices
We need the index where the completion (the part after the shared prompt) starts for each pair. 

In [77]:
import torch.nn.utils.rnn as rnn

def get_batches_with_compinfo(lines, batch_size):
    """
    lines: list of (neg_ids, pos_ids) 1D tensors (prompt+completion).
    For each pair, compute the shared-prefix length and keep it as the completion start.
    """
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        negs, poss = [], []
        neg_len, pos_len, comp_start = [], [], []
        for (n, p) in batch:
            L = common_prefix_len(n, p)
            negs.append(n); poss.append(p)
            neg_len.append(len(n)); pos_len.append(len(p))
            comp_start.append(L)
        neg_len = torch.tensor(neg_len, dtype=torch.long)
        pos_len = torch.tensor(pos_len, dtype=torch.long)
        comp_start = torch.tensor(comp_start, dtype=torch.long)
        neg_pad = rnn.pad_sequence(negs, batch_first=True, padding_value=PAD_ID)
        pos_pad = rnn.pad_sequence(poss, batch_first=True, padding_value=PAD_ID)
        yield (neg_pad, neg_len, comp_start), (pos_pad, pos_len, comp_start)


In [78]:
import torch
import torch.nn.functional as F

def mean_completion_logprob(model, padded: torch.Tensor, y: torch.Tensor):
    """
    Returns mean log-prob per sequence computed ONLY over completion tokens.
    'y' must be (B,T) with -100 at positions to ignore.
    Works with:
      - Fast path: model(padded) -> logits (B,T,V)
      - Slow path: model(prefix) -> last-step logits (B,V) or (B,1,V)
    """
    B, T = padded.shape
    device = padded.device

    # ---------- FAST PATH: full-sequence logits (B,T,V) ----------
    out = model(padded)
    logits = out[0] if isinstance(out, (tuple, list)) else out
    if logits.dim() == 3 and logits.size(1) == T:
        logp = F.log_softmax(logits, dim=-1)          # (B,T,V)
        mask = (y != -100)                            # (B,T) bool

        gather_y = y.clone()
        gather_y[~mask] = 0                           # safe index
        token_lp = logp.gather(2, gather_y.unsqueeze(-1)).squeeze(-1)  # (B,T)

        seq_sum = (token_lp * mask.float()).sum(dim=1)                 # (B,)
        seq_cnt = mask.float().sum(dim=1).clamp_min(1.0)               # (B,)
        return seq_sum / seq_cnt

    # ---------- SLOW PATH: last-step logits; loop over time ----------
    total_lp = torch.zeros(B, device=device)
    total_cnt = torch.zeros(B, device=device)

    for t in range(T):
        y_t = y[:, t]
        valid_idx = (y_t >= 0).nonzero(as_tuple=False).squeeze(1)  # indices where we have a real target
        if valid_idx.numel() == 0:
            continue

        # run ONLY the valid rows for this time-step
        prefix = padded[valid_idx, :t+1]                            # (Bv, t+1)
        out = model(prefix)
        logits = out[0] if isinstance(out, (tuple, list)) else out  # (Bv,V) or (Bv,1,V)
        if logits.dim() == 3:
            logits = logits[:, -1, :]                               # (Bv,V)

        logp = F.log_softmax(logits, dim=-1)                        # (Bv,V)
        tgt  = y_t[valid_idx].long().unsqueeze(1)                   # (Bv,1)
        lp_next = logp.gather(1, tgt).squeeze(1)                    # (Bv,)

        total_lp[valid_idx] += lp_next
        total_cnt[valid_idx] += 1.0

    return total_lp / total_cnt.clamp_min(1.0)


Tiny SFT warm-start on POS completions (1 short epoch)
This teaches the pattern “… The answer is 36” explicitly.

In [79]:
# one quick SFT pass on POS completions only
gpt.train()
opt_sft = torch.optim.AdamW(gpt.parameters(), lr=3e-4)

from tqdm import tqdm
BATCH = 16
steps = 0
for (neg_pad, neg_len, comp_start), (pos_pad, pos_len, comp_start2) in tqdm(get_batches_with_compinfo(lines[:5000], BATCH), total=(min(5000,len(lines))+BATCH-1)//BATCH):
    # targets only over completion span
    y_pos = build_completion_targets(pos_pad, pos_len, comp_start2).to(device)
    pos_pad = pos_pad.to(device)
    out = gpt(pos_pad)                      # forward WITHOUT y; we compute CE ourselves
    logits = out[0] if isinstance(out,(tuple,list)) else out   # (B,T,V) or (B,1,V)
    if logits.dim()==3 and logits.size(1)==pos_pad.size(1):
        loss = F.cross_entropy(logits.transpose(1,2), y_pos, ignore_index=-100)
    else:
        # fallback: loop over time (slower)
        loss = 0.0
        B,T = pos_pad.shape
        for t in range(T):
            valid = (y_pos[:, t] >= 0)
            if not bool(valid.any()): continue
            out = gpt(pos_pad[:, :t+1])
            logits = out[0] if isinstance(out,(tuple,list)) else out
            if logits.dim() == 3: logits = logits[:, -1, :]
            loss = loss + F.cross_entropy(logits[valid], y_pos[valid, t], ignore_index=-100)
        loss = loss / T

    opt_sft.zero_grad()
    loss.backward()
    opt_sft.step()
    steps += 1
    if steps % 50 == 0:
        print("SFT step", steps, "loss", float(loss))
    if steps >= 300:     # ~300 updates is enough for a warm-start
        break
print("SFT warm-start done.")


 16%|█▌        | 50/313 [03:21<15:56,  3.64s/it]

SFT step 50 loss 0.17940084636211395


 32%|███▏      | 100/313 [06:22<12:48,  3.61s/it]

SFT step 100 loss 0.15409192442893982


 48%|████▊     | 150/313 [09:35<12:43,  4.68s/it]

SFT step 150 loss 0.13823696970939636


 64%|██████▍   | 200/313 [12:44<06:41,  3.55s/it]

SFT step 200 loss 0.16314968466758728


 80%|███████▉  | 250/313 [15:53<04:02,  3.85s/it]

SFT step 250 loss 0.13672390580177307


 96%|█████████▌| 299/313 [19:31<00:54,  3.92s/it]

SFT step 300 loss 0.1342238485813141
SFT warm-start done.





DPO — completion-only objective
Now redo (or continue) DPO with the completion-only mean log-prob:

In [80]:
beta = 0.1
optimizer = torch.optim.AdamW(gpt.parameters(), lr=3e-4)

EPOCHS = 1  # start with 1, you can bump to 2 later
BATCH = 16
total = (len(lines)+BATCH-1)//BATCH

for epoch in range(EPOCHS):
    gpt.train()
    running, nst = 0.0, 0
    for (neg_pad, neg_len, comp_start), (pos_pad, pos_len, comp_start2) in tqdm(get_batches_with_compinfo(lines, BATCH), total=total):
        neg_pad = neg_pad.to(device); pos_pad = pos_pad.to(device)
        neg_len = neg_len.to(device); pos_len = pos_len.to(device)
        comp_start = comp_start.to(device)            # same for both

        y_neg = build_completion_targets(neg_pad, neg_len, comp_start)
        y_pos = build_completion_targets(pos_pad, pos_len, comp_start)

        neg_lp = mean_completion_logprob(gpt, neg_pad, y_neg)   # (B,)
        pos_lp = mean_completion_logprob(gpt, pos_pad, y_pos)   # (B,)

        loss = -F.logsigmoid((pos_lp - neg_lp) / beta).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running += float(loss); nst += 1
    print(f"Epoch {epoch+1} mean DPO loss: {running/max(nst,1):.4f}")

torch.save({
    "model_state_dict": gpt.state_dict(),
    "model_args": checkpoint['model_args'],
}, "./dpo.pt")
print("✅ Saved ./dpo.pt")


100%|██████████| 6250/6250 [9:19:45<00:00,  5.37s/it]     

Epoch 1 mean DPO loss: 0.0000
✅ Saved ./dpo.pt





In [84]:
# reuse: build_completion_targets, get_batches_with_compinfo, mean_completion_logprob helpers you have

gpt.train()
opt_sft = torch.optim.AdamW(gpt.parameters(), lr=3e-4)

STEPS = 3000             # do ~600–1000 updates total; you did 300 before
BATCH = 16
done = 0
from tqdm import tqdm

for (neg_pad, neg_len, comp_start), (pos_pad, pos_len, comp_start2) in tqdm(
        get_batches_with_compinfo(lines[:50000], BATCH),
        total=(min(50000, len(lines))+BATCH-1)//BATCH):
    # targets = completion only for POS
    y_pos = build_completion_targets(pos_pad, pos_len, comp_start2).to(device)
    pos_pad = pos_pad.to(device)
    out = gpt(pos_pad)
    logits = out[0] if isinstance(out,(tuple,list)) else out
    if logits.dim()==3 and logits.size(1)==pos_pad.size(1):
        loss = F.cross_entropy(logits.transpose(1,2), y_pos, ignore_index=-100)
    else:
        # slow fallback
        B,T = pos_pad.shape
        loss = 0.0
        for t in range(T):
            valid = (y_pos[:, t] >= 0)
            if not bool(valid.any()): continue
            out = gpt(pos_pad[:, :t+1])
            logits = out[0] if isinstance(out,(tuple,list)) else out
            if logits.dim()==3: logits = logits[:, -1, :]
            loss = loss + F.cross_entropy(logits[valid], y_pos[valid, t], ignore_index=-100)
        loss = loss / T

    opt_sft.zero_grad()
    loss.backward()
    opt_sft.step()

    done += 1
    if done % 50 == 0:
        print(f"SFT step {done} loss {float(loss):.4f}")
    if done >= STEPS:
        break

torch.save({"model_state_dict": gpt.state_dict(),
            "model_args": checkpoint['model_args']},
           "./dpo.pt")
print("✅ SFT top-off saved ./dpo.pt")


  2%|▏         | 50/3125 [03:25<3:11:54,  3.74s/it]

SFT step 50 loss 0.1099


  3%|▎         | 100/3125 [06:16<2:33:42,  3.05s/it]

SFT step 100 loss 0.1155


  5%|▍         | 150/3125 [09:08<2:57:53,  3.59s/it]

SFT step 150 loss 0.1051


  6%|▋         | 200/3125 [11:57<2:43:27,  3.35s/it]

SFT step 200 loss 0.1265


  8%|▊         | 250/3125 [14:39<2:23:20,  2.99s/it]

SFT step 250 loss 0.1158


 10%|▉         | 300/3125 [17:41<3:09:02,  4.02s/it]

SFT step 300 loss 0.1005


 11%|█         | 350/3125 [21:01<2:52:30,  3.73s/it]

SFT step 350 loss 0.1378


 13%|█▎        | 400/3125 [23:47<2:19:32,  3.07s/it]

SFT step 400 loss 0.1049


 14%|█▍        | 450/3125 [26:31<2:21:25,  3.17s/it]

SFT step 450 loss 0.0947


 16%|█▌        | 500/3125 [29:23<2:23:02,  3.27s/it]

SFT step 500 loss 0.1307


 18%|█▊        | 550/3125 [32:02<2:19:10,  3.24s/it]

SFT step 550 loss 0.1029


 19%|█▉        | 600/3125 [34:51<2:32:51,  3.63s/it]

SFT step 600 loss 0.1252


 21%|██        | 650/3125 [37:32<2:25:21,  3.52s/it]

SFT step 650 loss 0.1109


 22%|██▏       | 700/3125 [40:21<2:24:59,  3.59s/it]

SFT step 700 loss 0.1338


 24%|██▍       | 750/3125 [42:59<2:01:40,  3.07s/it]

SFT step 750 loss 0.1074


 26%|██▌       | 800/3125 [45:37<1:55:44,  2.99s/it]

SFT step 800 loss 0.1438


 27%|██▋       | 850/3125 [48:12<2:00:12,  3.17s/it]

SFT step 850 loss 0.1062


 29%|██▉       | 900/3125 [51:06<2:03:09,  3.32s/it]

SFT step 900 loss 0.0949


 30%|███       | 950/3125 [53:52<1:57:48,  3.25s/it]

SFT step 950 loss 0.0932


 32%|███▏      | 1000/3125 [56:35<2:00:08,  3.39s/it]

SFT step 1000 loss 0.1182


 34%|███▎      | 1050/3125 [59:33<2:18:05,  3.99s/it]

SFT step 1050 loss 0.0931


 35%|███▌      | 1100/3125 [1:02:49<2:10:09,  3.86s/it]

SFT step 1100 loss 0.0922


 37%|███▋      | 1150/3125 [1:35:15<2:20:56,  4.28s/it]   

SFT step 1150 loss 0.0815


 38%|███▊      | 1200/3125 [1:38:32<2:18:25,  4.31s/it]

SFT step 1200 loss 0.1172


 40%|████      | 1250/3125 [1:41:54<2:00:49,  3.87s/it]

SFT step 1250 loss 0.1095


 42%|████▏     | 1300/3125 [1:45:11<1:54:22,  3.76s/it]

SFT step 1300 loss 0.0894


 43%|████▎     | 1350/3125 [1:48:26<1:57:26,  3.97s/it]

SFT step 1350 loss 0.0999


 45%|████▍     | 1400/3125 [1:51:34<1:47:20,  3.73s/it]

SFT step 1400 loss 0.1045


 46%|████▋     | 1450/3125 [1:54:49<1:46:11,  3.80s/it]

SFT step 1450 loss 0.1018


 48%|████▊     | 1500/3125 [1:58:08<1:41:23,  3.74s/it]

SFT step 1500 loss 0.0773


 50%|████▉     | 1550/3125 [2:01:04<1:21:08,  3.09s/it]

SFT step 1550 loss 0.0905


 51%|█████     | 1600/3125 [2:03:49<1:32:49,  3.65s/it]

SFT step 1600 loss 0.1033


 53%|█████▎    | 1650/3125 [2:06:58<1:30:29,  3.68s/it]

SFT step 1650 loss 0.0865


 54%|█████▍    | 1700/3125 [2:13:41<8:52:00, 22.40s/it] 

SFT step 1700 loss 0.0898


 56%|█████▌    | 1750/3125 [2:16:26<1:09:17,  3.02s/it]

SFT step 1750 loss 0.1217


 58%|█████▊    | 1800/3125 [2:19:05<1:31:30,  4.14s/it]

SFT step 1800 loss 0.0780


 59%|█████▉    | 1850/3125 [2:21:46<1:10:04,  3.30s/it]

SFT step 1850 loss 0.0880


 61%|██████    | 1900/3125 [2:24:40<1:08:52,  3.37s/it]

SFT step 1900 loss 0.0921


 62%|██████▏   | 1950/3125 [2:27:34<1:26:08,  4.40s/it]

SFT step 1950 loss 0.1053


 64%|██████▍   | 2000/3125 [2:30:20<1:03:41,  3.40s/it]

SFT step 2000 loss 0.0809


 66%|██████▌   | 2050/3125 [2:32:59<53:57,  3.01s/it]  

SFT step 2050 loss 0.0909


 67%|██████▋   | 2100/3125 [2:35:35<54:00,  3.16s/it]  

SFT step 2100 loss 0.0816


 69%|██████▉   | 2150/3125 [2:38:11<48:51,  3.01s/it]

SFT step 2150 loss 0.0852


 70%|███████   | 2200/3125 [2:40:47<45:51,  2.97s/it]

SFT step 2200 loss 0.1058


 72%|███████▏  | 2250/3125 [2:43:29<47:56,  3.29s/it]

SFT step 2250 loss 0.1043


 74%|███████▎  | 2300/3125 [2:46:23<45:21,  3.30s/it]  

SFT step 2300 loss 0.1064


 75%|███████▌  | 2350/3125 [2:49:01<37:27,  2.90s/it]

SFT step 2350 loss 0.1101


 77%|███████▋  | 2400/3125 [2:51:57<47:26,  3.93s/it]

SFT step 2400 loss 0.0846


 78%|███████▊  | 2450/3125 [2:55:28<56:20,  5.01s/it]

SFT step 2450 loss 0.1149


 80%|████████  | 2500/3125 [2:58:34<37:38,  3.61s/it]  

SFT step 2500 loss 0.0872


 82%|████████▏ | 2550/3125 [3:01:30<38:51,  4.06s/it]

SFT step 2550 loss 0.0935


 83%|████████▎ | 2600/3125 [3:04:31<31:35,  3.61s/it]

SFT step 2600 loss 0.0868


 85%|████████▍ | 2650/3125 [3:07:10<23:41,  2.99s/it]

SFT step 2650 loss 0.0702


 86%|████████▋ | 2700/3125 [3:09:46<22:01,  3.11s/it]

SFT step 2700 loss 0.1031


 88%|████████▊ | 2750/3125 [3:12:28<23:20,  3.73s/it]

SFT step 2750 loss 0.0977


 90%|████████▉ | 2800/3125 [3:15:24<20:16,  3.74s/it]

SFT step 2800 loss 0.1012


 91%|█████████ | 2850/3125 [3:18:24<16:32,  3.61s/it]

SFT step 2850 loss 0.0877


 93%|█████████▎| 2900/3125 [3:21:12<12:26,  3.32s/it]

SFT step 2900 loss 0.0988


 94%|█████████▍| 2950/3125 [3:24:14<12:30,  4.29s/it]

SFT step 2950 loss 0.0800


 96%|█████████▌| 2999/3125 [3:27:45<08:43,  4.16s/it]

SFT step 3000 loss 0.0930
✅ SFT top-off saved ./dpo.pt





Evaluation — simple, template-matched, short
Keep it minimal and parse only the completion:

In [85]:
import re, torch

def format_prompt(q: str) -> str:
    return f"{q} The answer is "

@torch.no_grad()
def gen_and_parse(model, q, max_new_tokens=12, temperature=0.6, top_k=40):
    prompt = format_prompt(q)
    x = torch.tensor([stoi.get(ch,0) for ch in prompt], dtype=torch.long).unsqueeze(0).to(device)
    y = model.generate(x, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)
    if isinstance(y, tuple): y = y[0]
    txt = decode(y)
    comp = txt[len(prompt):]
    # stop at common delimiters
    for stop in [" because", "\n", ".", " Answer", "The answer is"]:
        j = comp.find(stop)
        if j > 0:
            comp = comp[:j]
            break
    m = re.search(r"^[-+]?\d+", comp.strip())
    pred = int(m.group()) if m else None
    return comp, pred

tests = [("17+19=?", 36), ("3*17=?", 51), ("72/4=?", 18), ("72-x=34, x=?", 38), ("x*11=44, x=?", 4)]
ok = 0
for q,t in tests:
    comp, pred = gen_and_parse(gpt, q, max_new_tokens=12, temperature=0.6, top_k=40)
    good = (pred is not None) and (pred == int(t))
    ok += int(good)
    print(f"{q:<16} → out={repr(comp)} | pred={pred} | tgt={t} | {'✓' if good else '✗'}")
print(f"Accuracy: {ok}/{len(tests)}")


17+19=?          → out='1 15' | pred=1 | tgt=36 | ✗
3*17=?           → out='6 138 becaus' | pred=6 | tgt=51 | ✗
72/4=?           → out='4+49' | pred=4 | tgt=18 | ✗
72-x=34, x=?     → out='iis 5-542 eq' | pred=None | tgt=38 | ✗
x*11=44, x=?     → out='because 1*12' | pred=None | tgt=4 | ✗
Accuracy: 0/5
