In [1]:
import tiktoken
import torch
import torch.nn.functional as F

# tokenizer
enc = tiktoken.get_encoding("o200k_base")

# Seed the global RNG so that "Restart Kernel -> Run All" reproduces the same
# initial weights, the same training batches and the same generated sample.
SEED = 0
torch.manual_seed(SEED)

# model dimensions
VOCAB_SIZE = enc.n_vocab   # 200019 for o200k_base; read from the tokenizer so the two cannot drift
D_MODEL = 64               # width of the residual stream
D_HIDDEN = 4 * D_MODEL     # width of the MLP hidden layer (256)
MAX_POSITIONS = 32         # number of learned position embeddings

# all weights (scaled initialization)
emb = torch.randn(VOCAB_SIZE, D_MODEL) * 0.02     # token embeddings, also reused as the output projection
pos = torch.randn(MAX_POSITIONS, D_MODEL) * 0.02  # learned position embeddings
wq = torch.randn(D_MODEL, D_MODEL) * 0.1          # attention query projection
wk = torch.randn(D_MODEL, D_MODEL) * 0.1          # attention key projection
wv = torch.randn(D_MODEL, D_MODEL) * 0.1          # attention value projection
w1 = torch.randn(D_MODEL, D_HIDDEN) * 0.1         # MLP up-projection
w2 = torch.randn(D_HIDDEN, D_MODEL) * 0.1         # MLP down-projection

# Every trainable tensor, in the order handed to the optimizer.
# Embeddings are tied: `emb` is also the output matrix, so there is no separate one.
params = [emb, pos, wq, wk, wv, w1, w2]
for p in params:
    p.requires_grad_()

def rmsnorm(x, eps=1e-5):
    """Scale ``x`` to (approximately) unit root-mean-square along its last dim.

    ``eps`` is added to the mean square *inside* the square root. Adding it
    after the square root keeps the forward pass finite but still
    differentiates ``sqrt`` at exactly 0 for an all-zero row, which yields a
    NaN gradient; with ``eps`` under the root the gradient stays finite.

    Args:
        x: floating-point tensor; normalisation is over ``dim=-1``.
        eps: small positive constant guarding against division by zero.

    Returns:
        A tensor with the same shape as ``x``.
    """
    mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(mean_sq + eps)

def forward(x):
    """Run the one-block, single-head causal transformer over a batch of ids.

    Reads the module-level weights ``emb``, ``pos``, ``wq``, ``wk``, ``wv``,
    ``w1`` and ``w2``.

    Args:
        x: 2-D integer tensor of token ids, shape ``(B, T)``. ``T`` may not
            exceed the number of rows in ``pos``.

    Returns:
        Logits of shape ``(B, T, emb.shape[0])``; position ``t`` scores the
        token that follows ``x[:, t]``.

    Raises:
        ValueError: if ``T`` is larger than the number of position embeddings.
    """
    _, T = x.shape
    if T > pos.shape[0]:
        raise ValueError(
            f"sequence length {T} exceeds the {pos.shape[0]} available position embeddings"
        )
    h = emb[x] + pos[:T]

    # attention (with pre-norm)
    nh = rmsnorm(h)
    q, k, v = nh @ wq, nh @ wk, nh @ wv
    # Scale by sqrt(head width) taken from the tensor itself rather than a
    # hard-coded 8, so the scale stays correct if the width is changed.
    att = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    # Strict upper triangle == future positions; build it where the scores live.
    future = torch.triu(torch.ones(T, T, device=att.device), 1).bool()
    # Each row keeps at least its diagonal entry, so -inf cannot empty a row.
    att = att.masked_fill(future, float("-inf"))
    att = F.softmax(att, dim=-1)
    h = h + att @ v

    # mlp (with pre-norm)
    nh = rmsnorm(h)
    h = h + (nh @ w1).relu() @ w2

    # output (tied to input embeddings)
    return h @ emb.T

# load data
# Explicit encoding: the platform default is not UTF-8 everywhere.
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()
tokens = torch.tensor(enc.encode(text))

# train with Adam
lr = 1e-3
batch_size = 32
seq_len = 31  # tokens per example; the targets are the same window shifted right by one
if len(tokens) <= seq_len:
    raise ValueError(
        f"need more than {seq_len} tokens to build a training example, got {len(tokens)}"
    )
opt = torch.optim.Adam(params, lr=lr)
prev_loss = float('inf')

for i in range(20):
    # A start s is usable when s + seq_len + 1 <= len(tokens). randint's upper
    # bound is exclusive, so len(tokens) - seq_len also admits the final window.
    starts = torch.randint(0, len(tokens) - seq_len, (batch_size,)).tolist()
    x = torch.stack([tokens[s:s + seq_len] for s in starts])
    y = torch.stack([tokens[s + 1:s + seq_len + 1] for s in starts])

    logits = forward(x)
    # Flatten (B, T, vocab) -> (B*T, vocab); the vocab size is taken from the logits.
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
    loss_val = loss.item()  # read the scalar once and reuse it below

    # Crude schedule: halve the learning rate whenever this batch's loss is
    # higher than the previous batch's (mini-batch losses are noisy, so this
    # can also trigger on an unlucky batch).
    if loss_val > prev_loss:
        lr *= 0.5
        for group in opt.param_groups:
            group['lr'] = lr
    prev_loss = loss_val

    opt.zero_grad()
    loss.backward()
    opt.step()

    print(f"{i}: loss={loss_val:.2f}")

0: loss=12.21
1: loss=12.17
2: loss=12.12
3: loss=12.08
4: loss=11.96
5: loss=11.83
6: loss=11.68
7: loss=11.49
8: loss=11.27
9: loss=10.94
10: loss=10.59
11: loss=10.20
12: loss=9.70
13: loss=9.24
14: loss=8.76
15: loss=8.01
16: loss=7.89
17: loss=7.54
18: loss=7.45
19: loss=7.86


In [2]:
# generate text
# NOTE: this rebinds `tokens` (the training tensor in the previous cell) to a
# plain Python list of token ids that grows by one id per sampling step.
tokens = enc.encode("First Citizen")

# Sampling never calls backward(), so turn autograd off: otherwise every step
# records a graph through all weights (including the vocab-sized output
# matmul) that is immediately thrown away.
with torch.no_grad():
    for _ in range(100):
        # Feed only the most recent 31 ids, the window length used in training.
        x = torch.tensor([tokens[-31:]])
        logits = forward(x)
        probs = F.softmax(logits[0, -1] / 0.8, dim=-1)  # temperature=0.8
        next_token = torch.multinomial(probs, 1)
        tokens.append(next_token.item())

print(enc.decode(tokens))

First Citizen name,
 tobuck as.

, none's.

 thou,, d,:
:
 is,Gent to lord,: it,:
 have,, men I
 brother,
 do, been we, the,
, grace;,
,,
 needlessUS, is from?

 must is and of.

, have,
 not and these,
is is of
, our be,
, well,
 thy, IA.

'd,,, let.


, her to., as,
