In [None]:
import tiktoken
import torch
import torch.nn.functional as F

# Run on the GPU when one is available; every tensor below is created on it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using: {device}")

# Seed the global RNG so the weight initialisation (and any later sampling
# that draws from the same generator) is repeatable after Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

enc = tiktoken.get_encoding("o200k_base")

# Model dimensions. The vocabulary size is read from the tokenizer so the
# embedding table always matches the encoding in use.
VOCAB_SIZE = enc.n_vocab
CONTEXT_LEN = 32   # rows in the positional table = longest sequence forward() accepts
D_MODEL = 64       # width of the residual stream and of q/k/v
D_FF = 256         # hidden width of the feed-forward layer

# weights
emb = torch.randn(VOCAB_SIZE, D_MODEL, device=device) * 0.02   # token embeddings (also used as output projection)
pos = torch.randn(CONTEXT_LEN, D_MODEL, device=device) * 0.02  # learned positional embeddings
wq = torch.randn(D_MODEL, D_MODEL, device=device) * 0.1        # attention query projection
wk = torch.randn(D_MODEL, D_MODEL, device=device) * 0.1        # attention key projection
wv = torch.randn(D_MODEL, D_MODEL, device=device) * 0.1        # attention value projection
w1 = torch.randn(D_MODEL, D_FF, device=device) * 0.1           # feed-forward up projection
w2 = torch.randn(D_FF, D_MODEL, device=device) * 0.1           # feed-forward down projection

# Everything the optimizer should update. The tensors above are leaves (the
# scaling happened before gradients were enabled), so they can be flagged
# as trainable in place.
params = [emb, pos, wq, wk, wv, w1, w2]
for p in params:
    p.requires_grad = True

def rmsnorm(x, eps=1e-5):
    """Scale `x` to unit root-mean-square along its last dimension.

    Args:
        x: Floating-point tensor; the statistic is computed over ``dim=-1``.
        eps: Small constant added to the mean square before taking the root.

    Returns:
        Tensor with the same shape as `x`.
    """
    mean_sq = (x ** 2).mean(dim=-1, keepdim=True)
    # eps belongs inside the root: the derivative of sqrt is infinite at 0,
    # so adding eps after the sqrt yields NaN gradients for an all-zero row.
    return x / (mean_sq + eps).sqrt()

def forward(x):
    """Run the one-block causal transformer and return next-token logits.

    Args:
        x: Integer tensor of token ids with shape (B, T). T may not exceed
            the number of rows in `pos`, otherwise the positional add fails.

    Returns:
        Tensor of shape (B, T, V) of unnormalised logits, where V is the
        number of rows in `emb` (the embedding matrix doubles as the output
        projection).
    """
    _, T = x.shape  # also enforces that x is 2-D
    h = emb[x] + pos[:T]

    # Single-head causal self-attention, pre-norm, with a residual add.
    nx = rmsnorm(h)
    q, k, v = nx @ wq, nx @ wk, nx @ wv
    # Scale by sqrt(head width) taken from q itself rather than a literal,
    # so the scale stays right if the projection width changes.
    att = q @ k.transpose(-1, -2) / q.size(-1) ** 0.5
    # True strictly above the diagonal: positions a token may not look at.
    future = torch.triu(torch.ones(T, T, device=h.device), 1).bool()
    # Every row keeps its diagonal entry, so -inf never empties a softmax row.
    att = att.masked_fill(future, float("-inf"))
    att = F.softmax(att, dim=-1)
    h = h + att @ v

    # Feed-forward layer (ReLU), pre-norm, with a residual add.
    nx = rmsnorm(h)
    h = h + (nx @ w1).relu() @ w2

    # Project back onto the token embeddings to get logits.
    return h @ emb.T

# load data
# Decode as UTF-8 explicitly; open()'s default encoding is platform dependent.
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()
tokens = torch.tensor(enc.encode(text), device=device)

# train
bs = 32
seq_len = 31  # tokens per input window; targets are the same window shifted by one
if len(tokens) < seq_len + 1:
    # Fail with a clear message instead of an opaque randint range error.
    raise ValueError(
        f"input.txt encodes to {len(tokens)} tokens; need at least {seq_len + 1}"
    )
opt = torch.optim.Adam(params, lr=1e-3)

for i in range(100):
    # Random window starts. The exclusive upper bound keeps the shifted
    # target window (which ends at start + seq_len + 1) inside the corpus
    # while still allowing the final token to appear as a target.
    starts = torch.randint(0, len(tokens) - seq_len, (bs,)).tolist()
    x = torch.stack([tokens[s:s + seq_len] for s in starts])
    y = torch.stack([tokens[s + 1:s + seq_len + 1] for s in starts])

    logits = forward(x)
    # Flatten (batch, time) so each position is one classification example;
    # the class count is read from the logits rather than hard-coded.
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

    opt.zero_grad()
    loss.backward()
    opt.step()

    print(f"{i}: loss={loss.item():.2f}")

In [None]:
# generate
# NOTE: this rebinds `tokens` (the training tensor from the cell above) to a
# plain Python list of ids; re-run the training cell to get the tensor back.
tokens = enc.encode("Hello")

# Sampling needs no gradients. Without no_grad every step would build an
# autograd graph over the full-vocabulary logits.
with torch.no_grad():
    for _ in range(100):
        # Condition on at most the last 31 tokens, the window length used
        # in training, and create the batch of one on the model's device.
        x = torch.tensor([tokens[-31:]], device=device)
        logits = forward(x)
        # Logits at the final position, sharpened with temperature 0.8.
        probs = F.softmax(logits[0, -1] / 0.8, dim=-1)
        next_token = torch.multinomial(probs, 1)
        tokens.append(next_token.item())

print(enc.decode(tokens))