In [1]:
# tiny_llm_fast_smoketest.py
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import TransformerDecoderLayer, TransformerDecoder

torch.manual_seed(0)
torch.set_num_threads(1)  # keep startup snappy on laptops

# ----- tiny char vocab utils -----
def build_char_vocab(text: str):
    chars = sorted(set(text))
    stoi = {c: i for i, c in enumerate(chars)}
    itos = {i: c for c, i in stoi.items()}
    return stoi, itos

def encode(text, stoi): return [stoi[c] for c in text if c in stoi]
def decode(ids, itos): return "".join(itos[i] for i in ids)

# ----- tiny dataset -----
class TextDataset(Dataset):
    def __init__(self, tokens, seq_len: int):
        self.tokens = tokens
        self.seq_len = seq_len
    def __len__(self):
        return max(0, len(self.tokens) - self.seq_len - 1)
    def __getitem__(self, idx):
        x = self.tokens[idx: idx+self.seq_len]
        y = self.tokens[idx+1: idx+self.seq_len+1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

# ----- positional enc -----
class LearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 1024):
        super().__init__()
        self.pos = nn.Embedding(max_len, d_model)
    def forward(self, x):  # x: [B, T]
        T = x.size(1)
        return self.pos(torch.arange(T, device=x.device)[None, :])

# ----- tiny decoder-only via TransformerDecoder -----
class TinyLLM(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 64, nhead: int = 2, num_layers: int = 1, max_len: int = 1024):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.posenc = LearnedPositionalEncoding(d_model, max_len=max_len)
        layer = TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=False)
        self.decoder = TransformerDecoder(layer, num_layers=num_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, ids: torch.Tensor):
        """
        ids: [B, T]; we feed same sequence as src & tgt and rely on causal mask.
        Returns logits: [B, T, V]
        """
        x = self.embed(ids) + self.posenc(ids)         # [B, T, E]
        h = x.transpose(0, 1)                           # [T, B, E]
        T = h.size(0)
        tgt_mask = torch.triu(torch.ones(T, T, device=h.device), diagonal=1).bool()
        out = self.decoder(tgt=h, memory=h, tgt_mask=tgt_mask)  # [T, B, E]
        return self.lm_head(out.transpose(0, 1))        # [B, T, V]

    @torch.no_grad()
    def generate(self, start_ids: torch.Tensor, max_new_tokens: int, topk: int = 1):
        """
        Greedy (topk=1) by default. start_ids: [B, T0]
        """
        self.eval()
        ids = start_ids
        for _ in range(max_new_tokens):
            logits = self.forward(ids)[:, -1, :]            # [B, V]
            if topk == 1:
                next_id = torch.argmax(logits, dim=-1, keepdim=True)  # [B,1]
            else:
                vals, idxs = torch.topk(logits, k=topk, dim=-1)
                probs = torch.softmax(vals, dim=-1)
                choice = torch.multinomial(probs, num_samples=1)      # [B,1]
                next_id = idxs.gather(-1, choice)
            ids = torch.cat([ids, next_id], dim=1)           # [B, T+1]
        return ids

# ====== MAIN: ultra-small, fast demo ======
def main():
    # 1) Tiny in-memory corpus (repeat to ~2–3k chars)
    base = "To be, or not to be, that is the question. "
    raw_text = (base * 200).strip()

    # 2) Vocab + tokens
    stoi, itos = build_char_vocab(raw_text)
    tokens = encode(raw_text, stoi)

    # 3) Small hparams for speed
    seq_len = 32
    batch_size = 16
    vocab_size = len(stoi)
    epochs = 1
    max_batches = 20  # limit steps so it finishes in seconds

    # 4) Data
    ds = TextDataset(tokens, seq_len=seq_len)
    if len(ds) == 0:
        raise RuntimeError("Corpus too small; increase repeats or lower seq_len.")
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)

    # 5) Model/opt/loss (CPU)
    device = "cpu"
    model = TinyLLM(vocab_size=vocab_size, d_model=64, nhead=2, num_layers=1).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=2e-3)
    loss_fn = nn.CrossEntropyLoss()

    # 6) Train just a few mini-batches
    model.train()
    step = 0
    for epoch in range(epochs):
        for x, y in dl:
            x, y = x.to(device), y.to(device)
            logits = model(y[:, :-1])                              # predict next char
            loss = loss_fn(logits.reshape(-1, vocab_size), y[:, 1:].reshape(-1))
            opt.zero_grad()
            loss.backward()
            opt.step()
            step += 1
            if step % 5 == 0:
                print(f"step {step:02d}  loss {loss.item():.4f}")
            if step >= max_batches:
                break
        if step >= max_batches:
            break

    # 7) Quick greedy generation
    prompt = "To be, or n"
    start = torch.tensor([encode(prompt, stoi)], dtype=torch.long, device=device)  # [1, T0]
    gen = model.generate(start, max_new_tokens=80, topk=1)[0].tolist()
    print("\n=== SAMPLE ===")
    print(decode(gen, itos))

if __name__ == "__main__":
    main()

step 05  loss 1.8262
step 10  loss 1.3533
step 15  loss 1.1664
step 20  loss 1.0646

=== SAMPLE ===
To be, or no be, the, to the, t the, be, the, thathe, be, the thato the, the, thathe, t tha
