In [None]:
import torch
import torch.nn.functional as F

# Choose where to run and at what precision: bf16 on GPUs that support it,
# fp16 on other GPUs, and full fp32 when no GPU is available.
if torch.cuda.is_available():
    device = 'cuda'
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    device = 'cpu'
    dtype = torch.float32

print(f"Using: {device}, dtype: {dtype}")

# config
n_layers = 12        # number of attention + MLP blocks
dim = 512            # width of the embeddings / residual stream
mlp_dim = 4 * dim    # hidden width of each block's MLP (2048)
ctx_len = 512        # rows in the positional table; length of each training window

# load data
# Read with an explicit encoding so the character vocabulary does not depend
# on the platform's default locale encoding.
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

# char tokenizer
# Every distinct character in the corpus becomes one token; ids follow the
# sorted order of the characters.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Map a string to a list of token ids; raises KeyError on unseen characters."""
    return [stoi[c] for c in s]


def decode(l):
    """Map an iterable of token ids back to the string they spell."""
    return ''.join(itos[i] for i in l)


print(f"Vocab size: {vocab_size}, dim: {dim}, layers: {n_layers}")

# weights
# Token and position embedding tables, drawn from N(0, 0.02^2).
emb = 0.02 * torch.randn(vocab_size, dim, device=device, dtype=dtype)
pos = 0.02 * torch.randn(ctx_len, dim, device=device, dtype=dtype)

# Every per-layer matrix shares one init scale that shrinks with depth.
layer_std = 0.1 / (n_layers ** 0.5)
layer_shapes = {
    'wq': (dim, dim),
    'wk': (dim, dim),
    'wv': (dim, dim),
    'w1': (dim, mlp_dim),
    'w2': (mlp_dim, dim),
}

# Dict order fixes the order of the random draws: wq, wk, wv, w1, w2 per layer.
layers = [
    {
        name: torch.randn(*shape, device=device, dtype=dtype) * layer_std
        for name, shape in layer_shapes.items()
    }
    for _ in range(n_layers)
]

# Flat list of every trainable tensor, embeddings first, then layer by layer.
params = [emb, pos]
for layer in layers:
    params.extend(layer.values())

for p in params:
    p.requires_grad_(True)

def rmsnorm(x, eps=1e-5):
    """Scale ``x`` by the inverse root-mean-square of its last dimension.

    Args:
        x: Floating-point tensor; normalisation is over ``dim=-1``.
        eps: Small constant added to the RMS to avoid division by zero.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    orig_dtype = x.dtype
    # Squaring in fp16 overflows to inf once |x| exceeds ~255, which would turn
    # the whole output into zeros. Compute the statistic in fp32 for the
    # reduced-precision dtypes and cast back at the end.
    if orig_dtype in (torch.float16, torch.bfloat16):
        x = x.float()
    rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
    return (x / (rms + eps)).to(orig_dtype)

def forward(x):
    """Run the model on a batch of token ids and return next-token logits.

    Args:
        x: 2-D integer tensor of token ids (batch, time); it indexes ``emb``.

    Returns:
        Tensor of logits with one score per row of ``emb`` at every position.
    """
    _, seq_len = x.shape
    h = emb[x] + pos[:seq_len]

    for block in layers:
        # Pre-norm causal self-attention with a residual connection.
        normed = rmsnorm(h)
        q = normed @ block['wq']
        k = normed @ block['wk']
        v = normed @ block['wv']
        h = h + F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Pre-norm ReLU MLP with a residual connection.
        hidden = (rmsnorm(h) @ block['w1']).relu()
        h = h + hidden @ block['w2']

    # Output projection reuses the token embedding matrix.
    return h @ emb.T

# torch.compile is optional: PyTorch builds without the attribute keep the
# plain eager `forward`.
if getattr(torch, 'compile', None) is not None:
    forward = torch.compile(forward)
    print("Using torch.compile")

# Encode the corpus once, then expose every length-ctx_len window (stride 1)
# over it; row i covers tokens[i : i + ctx_len].
tokens = torch.tensor(encode(text), device=device)
all_seqs = tokens.unfold(dimension=0, size=ctx_len, step=1)

# train
# Hyperparameters of the optimisation loop.
n_steps = 10000
batch_size = 32
lr = 1e-4
log_every = 100

# Request the fused Adam kernel only on CUDA: older PyTorch releases reject
# fused=True for CPU parameters, which would break the CPU fallback chosen
# when `device` was selected.
opt = torch.optim.Adam(params, lr=lr, fused=(device == 'cuda'))

for i in range(n_steps):
    # Sample window indices directly on the device that holds the data.
    idx = torch.randint(0, all_seqs.size(0), (batch_size,), device=device)
    # Inputs are each window minus its last token; targets are the same
    # window shifted left by one position.
    x, y = all_seqs[idx, :-1], all_seqs[idx, 1:]

    logits = forward(x)
    # Logits may be fp16/bf16; take the loss in fp32 for a stable log-softmax.
    # reshape (rather than view) also tolerates non-contiguous outputs.
    loss = F.cross_entropy(logits.float().reshape(-1, vocab_size), y.reshape(-1))

    opt.zero_grad()
    loss.backward()
    opt.step()

    if i % log_every == 0:
        print(f"{i}: loss={loss.item():.2f}")

In [None]:
# generate
ctx = "First Citizen"
temperature = 0.8      # < 1 sharpens the sampling distribution
max_new_tokens = 500
tokens = encode(ctx)

# Sampling never calls backward, so disable autograd; otherwise every step
# would build and hold a graph through all layers.
with torch.no_grad():
    for _ in range(max_new_tokens):
        # Feed at most the last ctx_len-1 tokens, the same input length used
        # during training.
        x = torch.tensor([tokens[-(ctx_len-1):]], device=device)
        logits = forward(x)
        # Softmax and sampling run in fp32, since the model may emit fp16/bf16.
        probs = F.softmax(logits[0, -1].float() / temperature, dim=-1)
        next_token = torch.multinomial(probs, 1).item()
        tokens.append(next_token)

print(decode(tokens))