In [None]:
import math, random, torch
import torch.nn as nn
import torch.nn.functional as F

# Prefer the Apple-silicon GPU (MPS) when present. torch.backends.mps.is_available()
# is the documented availability check.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Seed every RNG the notebook draws from: torch (model init, make_batch) and
# Python's `random` (used by `sample` to pick a first token for an empty prompt).
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)

# -----------------------
# Data: a+b=(a+b)%10   e.g. "1+2=3", "3+7=0"
# -----------------------
itos = list("0123456789+=")                 # index -> char
stoi = {c: i for i, c in enumerate(itos)}   # char -> index
V = len(itos)                               # vocabulary size: 10 digits + '+' + '='

def make_batch(batch_size, device):
    """Sample a batch of single-digit modular-addition problems "a+b=c".

    Each example is the 5-token sequence [a, '+', b, '=', (a + b) % 10], with
    a and b drawn uniformly from 0..9. For next-token prediction the sequence
    is split into an input (first 4 tokens) and a target (last 4 tokens), e.g.
    input "3+7=" with target "+7=0".

    Args:
        batch_size: number of examples B.
        device: torch device the tensors are created on.

    Returns:
        (inp, tgt): two integer tensors of shape [B, 4].
    """
    # Operands are drawn a first, then b; the answer is their sum mod 10.
    lhs = torch.randint(0, 10, (batch_size,), device=device)
    rhs = torch.randint(0, 10, (batch_size,), device=device)
    ans = (lhs + rhs) % 10

    plus = torch.full_like(lhs, stoi['+'])
    equals = torch.full_like(lhs, stoi['='])

    # Full sequence [a, '+', b, '=', c] -> shape [B, 5]
    seq = torch.stack((lhs, plus, rhs, equals, ans), dim=1)

    # Shift by one position: inputs are tokens 0..3, targets are tokens 1..4.
    return seq[:, :-1], seq[:, 1:]



In [None]:
# -----------------------
# Tiny GPT (decoder-only)
# -----------------------
class TinyGPT(nn.Module):
    """Minimal decoder-only transformer for next-token prediction.

    Built from ``nn.TransformerEncoder`` layers plus a causal (upper-triangular)
    attention mask, so position t can only attend to positions <= t.

    Args:
        vocab_size: number of token ids; also the width of the output logits.
        d_model: embedding / hidden width.
        n_heads: attention heads per layer (must divide ``d_model``).
        n_layers: number of stacked transformer layers.
        max_len: longest sequence accepted by ``forward``; sizes the learned
            positional table and the causal mask.
        dropout: dropout probability inside each transformer layer
            (0.1 is the ``nn.TransformerEncoderLayer`` default).
    """

    def __init__(self, vocab_size, d_model=64, n_heads=2, n_layers=2, max_len=4, dropout=0.1):
        super().__init__()
        self.max_len = max_len
        self.tok = nn.Embedding(vocab_size, d_model)
        # Learned absolute positional embeddings, one row per position.
        self.pos = nn.Parameter(torch.randn(max_len, d_model) * 0.02)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=dropout, activation="gelu", batch_first=True,
        )
        self.tr = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

        # Boolean mask: True above the diagonal = "may not attend", which blocks
        # attention to future positions. Registered as a buffer so .to(device)
        # moves it together with the parameters.
        mask = torch.triu(torch.ones(max_len, max_len), diagonal=1).bool()
        self.register_buffer("causal_mask", mask)

    def forward(self, inp):
        """Compute next-token logits.

        Args:
            inp: integer token ids of shape [B, T] with T <= ``max_len``.

        Returns:
            Logits of shape [B, T, vocab_size]; row t scores the token after position t.

        Raises:
            ValueError: if T exceeds ``max_len`` (no positional embedding exists for it).
        """
        _, T = inp.shape
        if T > self.max_len:
            raise ValueError(f"sequence length {T} exceeds max_len={self.max_len}")
        h = self.tok(inp) + self.pos[:T].unsqueeze(0)
        h = self.tr(h, mask=self.causal_mask[:T, :T])
        return self.lm_head(h)  # [B, T, V]

# Smoke test: one forward pass through a freshly initialised (untrained) model.
# Separate names are used so re-running this cell never clobbers the trained `model`.
smoke_model = TinyGPT(vocab_size=V, d_model=64, n_heads=2, n_layers=2, max_len=4).to(device)
smoke_inp, _ = make_batch(2, device)
smoke_out = smoke_model(smoke_inp)
assert tuple(smoke_out.shape) == (2, 4, V), smoke_out.shape
smoke_out[0, -1].detach().cpu().numpy()  # logits for last token of first example

In [None]:
@torch.no_grad()
def sample(model, prompt_tokens, steps=1):
    """Greedily extend a prompt by ``steps`` tokens.

    ``torch.no_grad`` does not disable dropout, so the model is switched to
    eval mode for the duration of the call and its previous train/eval mode
    is restored afterwards (even if the forward pass raises).

    Args:
        model: module mapping a [1, T] integer tensor of token ids to [1, T, V]
            logits; its parameters determine the device used.
        prompt_tokens: list of token ids. If empty, a random digit token
            (0-9, via Python's ``random``) is used as the first token. The model
            sees up to len(prompt) + steps - 1 tokens, which must fit its
            maximum length (``max_len`` for TinyGPT).
        steps: number of tokens to append.

    Returns:
        list[int]: the prompt followed by the generated token ids.
    """
    if len(prompt_tokens) == 0:
        prompt_tokens = [random.randint(0, 9)]
    device = next(model.parameters()).device
    x = torch.tensor(prompt_tokens, dtype=torch.long, device=device).unsqueeze(0)  # [1, T]

    was_training = model.training
    model.eval()
    try:
        for _ in range(steps):
            logits = model(x)                  # [1, T, V]
            nxt = logits[0, -1].argmax()       # greedy choice for the next position
            x = torch.cat([x, nxt.view(1, 1).to(x.dtype)], dim=1)
    finally:
        model.train(was_training)
    return x[0].tolist()

def decode(tokens):
    """Turn an iterable of token ids back into a string using the ``itos`` table."""
    pieces = [itos[tok] for tok in tokens]
    return "".join(pieces)



In [None]:
# -----------------------
# Train
# -----------------------

model = TinyGPT(V, d_model=64, n_heads=2, n_layers=2, max_len=4).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-3)

steps = 500
batch_size = 512
log_every = 50

model.train()
for step in range(1, steps + 1):
    inp, tgt = make_batch(batch_size, device)
    logits = model(inp)  # [B, 4, V]
    loss = F.cross_entropy(logits.reshape(-1, V), tgt.reshape(-1))

    opt.zero_grad()
    loss.backward()
    opt.step()

    if step % log_every == 0:
        with torch.no_grad():
            pred = logits.argmax(dim=-1)  # [B, 4]
            # Averaged over all 4 target positions. Position 1 (b after '+') is a
            # uniformly random digit, so this tops out around 77.5%.
            token_acc = (pred == tgt).float().mean().item()
            # Only the last position: the answer digit predicted after '='.
            answer_acc = (pred[:, -1] == tgt[:, -1]).float().mean().item()
        print(
            f"step {step:4d} | loss {loss.item():.4f} | "
            f"token-acc {token_acc*100:5.1f}% | answer-acc {answer_acc*100:5.1f}%"
        )

# Disable dropout for everything that follows (demo, sampling, inspection).
model.eval()

# Demo: generate answer digit from "a+b="
for a, b in [(3, 7), (9, 9), (2, 5), (4, 8)]:
    prompt = [stoi[str(a)], stoi['+'], stoi[str(b)], stoi['=']]
    out = sample(model, prompt, steps=1)  # predict one token (the answer digit)
    print(decode(out), " (expected:", (a + b) % 10, ")")


In [None]:
# Greedy completion of a single prompt with the trained model.
prompt = '1+3='
decode(sample(model, [stoi[c] for c in prompt], steps=1))

In [None]:
# Grow a sequence from an empty prompt, one greedy token per call; `sample`
# chooses a random first digit when handed an empty prompt.
prompt = ''
for i in range(4):
    context = [stoi[ch] for ch in prompt]
    prompt = decode(sample(model, context, steps=1))
    print(prompt)


In [None]:
# Total number of parameter elements. p.numel() counts every entry of each
# tensor; len(p) would count only its first dimension.
sum(p.numel() for p in model.parameters())

In [None]:
# Visualise the learned token embeddings in 2-D with UMAP.
import umap
import matplotlib.pyplot as plt

emb = model.tok.weight.detach().cpu().numpy()  # [V, d_model]; row i embeds itos[i]
# random_state makes the layout reproducible across runs (UMAP is stochastic otherwise).
reducer = umap.UMAP(n_neighbors=5, min_dist=0.1, random_state=0)
emb_2d = reducer.fit_transform(emb)  # [V, 2]

fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(emb_2d[:, 0], emb_2d[:, 1])
for i, c in enumerate(itos):
    ax.text(emb_2d[i, 0], emb_2d[i, 1], c)
ax.set(title="UMAP of token embeddings", xlabel="UMAP-1", ylabel="UMAP-2")
plt.show()

In [None]:
# Causal mask for a length-3 sequence: True marks a future key position that
# the query in that row is not allowed to attend to.
mask = torch.ones(3, 3).triu(diagonal=1).bool()
mask