In [None]:
import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The weights below are created directly in `dtype` and trained in it (no
# autocast, no fp32 master copy, no GradScaler). bfloat16 keeps float32's
# exponent range, so that is workable. Plain float16 is not: small gradients
# and Adam's squared-gradient state underflow to zero, which produces huge or
# NaN updates. So use bf16 when the GPU supports it, otherwise float32.
if device == 'cuda' and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
else:
    dtype = torch.float32

print(f"Using: {device}, dtype: {dtype}")

# config
emb_vec_len = 64  # embedding vector length (width of the residual stream)
n_transformer_blocks = 4  # number of transformer blocks
mlp_dim = 32  # hidden layer size in each block's MLP
ctx_len = 128  # length of each window cut from the corpus: ctx_len - 1 input tokens plus their next-token targets; also the number of position embeddings, i.e. the longest sequence forward() accepts
batch_size = 32  # number of sequences per batch
seed = 1337  # RNG seed for weight init, batch sampling and generation

# torch.manual_seed seeds the CPU and all CUDA generators, so repeated runs draw
# the same random numbers (some GPU kernels may still be nondeterministic).
torch.manual_seed(seed)

# load data
# Explicit UTF-8 so decoding does not depend on the platform's default
# encoding (e.g. cp1252 on Windows).
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

# char tokenizer: each distinct character of the corpus is one token
chars = sorted(set(text))
vocab_size = len(chars)  # how many unique tokens (characters for this version)
stoi = {c: i for i, c in enumerate(chars)}  # char -> integer id
itos = {i: c for c, i in stoi.items()}  # integer id -> char


def encode(s):
    """Encode a string as a list of integer token ids.

    Raises KeyError for any character that does not appear in the corpus text.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decode an iterable of integer token ids (Python ints) back into a string."""
    return ''.join(itos[i] for i in l)


print(f"Vocab size: {vocab_size}, emb_vec_len: {emb_vec_len}, transformer blocks: {n_transformer_blocks}")

# weights
# Every parameter is a plain tensor on `device` in `dtype`, drawn from a scaled
# standard normal. Creation order (emb, pos, then wq, wk, wv, w1, w2 for each
# block) determines which random numbers each tensor receives.
def _randn_param(*shape, std):
    """Return a tensor of the given shape with N(0, std**2) entries on `device`/`dtype`."""
    return torch.randn(*shape, device=device, dtype=dtype) * std


# Token embedding table: row i is the vector for token id i. forward() also
# reuses it, transposed, as the output projection to vocabulary logits.
emb = _randn_param(vocab_size, emb_vec_len, std=0.02)
# Learned absolute position embeddings: row t is added to the token embedding
# at position t, nudging that vector in a learned direction per dimension.
pos = _randn_param(ctx_len, emb_vec_len, std=0.02)

# Per-block weights; init std shrinks with depth as 0.1 / sqrt(n_transformer_blocks).
block_std = 0.1 / (n_transformer_blocks ** 0.5)
block_shapes = {
    'wq': (emb_vec_len, emb_vec_len),  # attention query projection
    'wk': (emb_vec_len, emb_vec_len),  # attention key projection
    'wv': (emb_vec_len, emb_vec_len),  # attention value projection
    'w1': (emb_vec_len, mlp_dim),      # MLP input projection
    'w2': (mlp_dim, emb_vec_len),      # MLP output projection
}

layers = []
for _ in range(n_transformer_blocks):
    # dict insertion order == creation order == RNG draw order
    layer = {name: _randn_param(*shape, std=block_std) for name, shape in block_shapes.items()}
    layers.append(layer)

# Flat list handed to the optimizer. The tensors above are autograd leaves, so
# switching requires_grad on makes backward() populate their .grad.
params = [emb, pos, *(w for layer in layers for w in layer.values())]
for p in params:
    p.requires_grad_(True)

def rmsnorm(x, eps=1e-5):
    """Scale x along its last dimension by its root-mean-square.

    Like dividing by a standard deviation, but measured around 0 instead of
    the mean. `eps` is added to the RMS (not inside the square root) to avoid
    division by zero. There is no learnable gain.
    """
    root_mean_sq = x.pow(2).mean(-1, keepdim=True).sqrt().add(eps)
    return x / root_mean_sq

def forward(x):
    """Run the transformer on a batch of token ids and return next-token logits.

    x: 2-D integer tensor of token ids, shape (B, T), one row per sequence.
       T must not exceed ctx_len (there are only ctx_len position embeddings);
       during training T == ctx_len - 1.
    Returns logits of shape (B, T, vocab_size); with the causal mask, position t
    only sees x[:, :t+1] and scores the token that follows x[:, t].

    Each block is pre-norm: single-head causal self-attention, then a ReLU MLP,
    each added onto the residual stream. The output projection reuses the token
    embedding table (weight tying); there is no final norm before it.
    """
    _, seq_len = x.shape
    h = emb[x] + pos[:seq_len]  # (B, T, emb_vec_len): token embedding + position embedding

    for blk in layers:
        a_in = rmsnorm(h)
        # q/k/v are (B, T, emb_vec_len); SDPA treats the leading dim as batch,
        # so this is one attention head spanning the full width.
        attn = F.scaled_dot_product_attention(
            a_in @ blk['wq'],
            a_in @ blk['wk'],
            a_in @ blk['wv'],
            is_causal=True,
        )
        h = h + attn
        m_hidden = (rmsnorm(h) @ blk['w1']).relu()
        h = h + m_hidden @ blk['w2']

    return h @ emb.T

# Wrap `forward` with torch.compile when this PyTorch build provides it; the
# compiled callable replaces the eager function under the same name.
# NOTE(review): compilation happens lazily on the first call, so a missing
# backend (e.g. no Triton) would surface in the training loop, not here.
compile_fn = getattr(torch, 'compile', None)
if compile_fn is not None:
    forward = compile_fn(forward)
    print("Using torch.compile")

# Encode the whole corpus once (int64 ids on `device`), then take every
# length-ctx_len window with stride 1. unfold returns a view sharing storage
# with `tokens`; all_seqs has shape (len(tokens) - ctx_len + 1, ctx_len).
tokens = torch.tensor(encode(text), device=device)
all_seqs = tokens.unfold(dimension=0, size=ctx_len, step=1)

# train
# Adam: w -= lr * m_hat / (sqrt(v_hat) + eps). The fused kernel is only
# requested on CUDA; older PyTorch versions reject fused=True for CPU params.
opt = torch.optim.Adam(params, lr=1e-3, fused=(device == 'cuda'))

n_steps = 1000   # optimizer steps
log_every = 100  # print the loss every this many steps

for i in range(n_steps):
    # Pick batch_size random windows; build the indices on the data's device
    # so no host-to-device copy is needed for the gather.
    idx = torch.randint(0, all_seqs.size(0), (batch_size,), device=all_seqs.device)
    # Inputs and next-token targets, each (batch_size, ctx_len - 1).
    x, y = all_seqs[idx, :-1], all_seqs[idx, 1:]

    logits = forward(x)  # (batch_size, ctx_len - 1, vocab_size), in `dtype`
    # The loss is computed in float32 because logits may be bf16/fp16.
    loss = F.cross_entropy(logits.reshape(-1, vocab_size).float(), y.reshape(-1))

    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()

    if i % log_every == 0:
        print(f"{i}: loss={loss.item():.2f}")

In [6]:
# generate
n_new_tokens = 150  # number of characters to sample
temperature = 0.8   # < 1 sharpens the next-token distribution, > 1 flattens it
ctx = "First Citizen"
gen_tokens = encode(ctx)  # separate name so the dataset tensor `tokens` is not clobbered

# Inference only: skip building an autograd graph on every sampling step.
with torch.no_grad():
    for _ in range(n_new_tokens):
        # Feed at most the last ctx_len - 1 tokens: that is the training sequence
        # length, so only position embeddings that were trained are used.
        x = torch.tensor([gen_tokens[-(ctx_len - 1):]], device=device)
        logits = forward(x)
        # Upcast before softmax/sampling; logits may be bf16/fp16.
        probs = F.softmax(logits[0, -1].float() / temperature, dim=-1)
        next_token = torch.multinomial(probs, 1).item()
        gen_tokens.append(next_token)

print(decode(gen_tokens))

First Citizen fea then potarerd yo,
Thas irter ars. buth's are an thet hentl for eth twath coold erses and, and whoure s;
Siban wond beth af har to theen ling ceen
