In [7]:
import torch # torch is a folder, __init__.py in that torch folder imports other .py files in that torch folder, i can then import any of those imports with torch dot whatever
import torch.nn.functional as F
import matplotlib.pyplot as plt



# Optional dataset download (uncomment to run):
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Precision used for the model weights: float32 on CPU; on CUDA, bfloat16 when
# the GPU supports it and float16 otherwise.
# NOTE(review): the training loop in this notebook applies no loss scaling, so
# the float16 fallback may be numerically fragile — confirm on such hardware.
dtype = torch.float32
if device == 'cuda':
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

print(f"Using: {device}, dtype: {dtype}")

# Hyperparameters — everything a reader might want to tune lives here.
lr = 3e-4               # learning rate handed to the optimizer
train_iter = 1_000_000  # number of training-loop iterations
batch_size = 8          # sequences sampled per training step
dim = 256               # length of each embedding (token) vector
n_layers = 8            # number of transformer blocks
mlp_dim = 4096          # width of the MLP hidden layer inside each block
ctx_len = 32            # tokens per training window (context length)
plot_every = 5000       # draw the embedding map every N iterations

# load data
# Explicit encoding so the corpus decodes identically regardless of the
# platform's default locale.
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

# char tokenizer: every distinct character in the corpus is one token
chars = sorted(set(text))  # unique characters, in sorted order
vocab_size = len(chars)  # number of unique tokens in the vocabulary
stoi = {c: i for i, c in enumerate(chars)}  # char -> id, e.g. {'a': 0, 'b': 1, ...}
itos = {i: c for c, i in stoi.items()}  # id -> char, the inverse of stoi


def encode(s):
    """Convert string `s` into a list of token ids (one per character).

    Raises KeyError if `s` contains a character that is not in the corpus.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Convert an iterable of token ids `l` back into a string (inverse of `encode`)."""
    return ''.join(itos[i] for i in l)


print(f"Vocab size: {vocab_size}, dim: {dim}, layers: {n_layers}")

def _scaled_randn(rows, cols, scale):
    """Return a (rows, cols) standard-normal tensor on `device` in `dtype`, multiplied by `scale`."""
    return torch.randn(rows, cols, device=device, dtype=dtype) * scale


# Embedding tables: one row per token id (emb) and one row per position in the
# context window (pos); pos[0] belongs to the first token, pos[1] to the second, ...
emb = _scaled_randn(vocab_size, dim, 0.02)
pos = _scaled_randn(ctx_len, dim, 0.02)

# Transformer block weights: each name is a list holding one matrix per layer,
# so wq[layer] / wk[layer] / wv[layer] are that layer's attention projections
# and w1[layer] / w2[layer] its MLP (input→hidden, hidden→output).
init_scale = 0.1 / (2 ** 0.5)


def _per_layer(rows, cols):
    """Build one independently initialised (rows, cols) matrix for every layer."""
    return [_scaled_randn(rows, cols, init_scale) for _ in range(n_layers)]


wq = _per_layer(dim, dim)      # query projections
wk = _per_layer(dim, dim)      # key projections
wv = _per_layer(dim, dim)      # value projections
w1 = _per_layer(dim, mlp_dim)  # MLP input→hidden
w2 = _per_layer(mlp_dim, dim)  # MLP hidden→output

# Flat list of every trainable tensor; gradient tracking is switched on for all of them.
params = [emb, pos, *wq, *wk, *wv, *w1, *w2]
for p in params:
    p.requires_grad_(True)

def rmsnorm(x, eps=1e-5):
    """Divide `x` by its root-mean-square taken over the last dimension.

    `eps` is added to the RMS itself (after the square root) to avoid dividing
    by zero. There is no learnable gain; the output has the same shape as `x`.
    """
    mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
    rms = mean_sq.sqrt()
    return x / (rms + eps)

def forward(x):
    """Run the transformer on a batch of token ids and return next-token logits.

    Args:
        x: integer tensor of token ids with shape (B, T). T must not exceed
           ctx_len because only ctx_len positional rows exist. During training
           T = ctx_len - 1, since the last token of each window is sliced off
           to become the target.

    Returns:
        Tensor of shape (B, T, vocab_size): for every position, a similarity
        score against each token embedding (the logits).

    Each of the n_layers blocks applies, with residual connections:
      1. rmsnorm, then single-head causal self-attention (no output projection),
      2. rmsnorm, then a two-matrix ReLU MLP.

    The output layer reuses the embedding table: each final vector is
    dot-multiplied with every row of emb, and the highest score is the
    predicted next token. The blocks learn to push the vector toward the
    embedding of the correct next token so that its dot product is the largest.

    Example with dim=3, vocab=3:

        emb = [[0.2, 0.5, 0.1],   # token "a"
               [0.9, 0.1, 0.3],   # token "b"
               [0.1, 0.8, 0.2]]   # token "c"
        output vector after all blocks: x = [0.85, 0.15, 0.28]

        x @ emb.T = dot product of x with each row:
        vs "a": 0.85*0.2 + 0.15*0.5 + 0.28*0.1 = 0.27
        vs "b": 0.85*0.9 + 0.15*0.1 + 0.28*0.3 = 0.86 ← highest
        vs "c": 0.85*0.1 + 0.15*0.8 + 0.28*0.2 = 0.26
        Result: [0.27, 0.86, 0.26] → the model predicts "b" because x is most
        similar to "b"'s embedding.
    """
    _, T = x.shape  # also enforces that x is 2-D (batch, sequence)

    # emb[x] is (B, T, dim); pos[:T] is (T, dim) and broadcasts over the batch axis.
    h = emb[x] + pos[:T]

    # Lower-triangular causal mask, built once on the activations' device and
    # reused by every layer: position t may only attend to positions <= t.
    mask = torch.tril(torch.ones(T, T, device=h.device, dtype=torch.bool))

    for layer in range(n_layers):
        # Pre-norm attention
        nx = rmsnorm(h)
        q, k, v = nx @ wq[layer], nx @ wk[layer], nx @ wv[layer]
        scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
        # -inf before softmax gives masked (future) positions zero weight.
        scores = scores.masked_fill(~mask, float('-inf'))
        attn_out = F.softmax(scores, dim=-1) @ v
        h = h + attn_out  # residual connection

        # Pre-norm MLP
        nx = rmsnorm(h)
        out = (nx @ w1[layer]).relu() @ w2[layer]
        h = h + out  # residual connection

    h = rmsnorm(h)
    return h @ emb.T  # similarity scores against all token embeddings → logits


# Compile forward() when this PyTorch build offers torch.compile. Compilation
# traces the ops in forward() and can fuse several of them into one kernel, so
# intermediate results need fewer round trips between GPU memory and the cores.
if hasattr(torch, 'compile'):
    forward = torch.compile(forward)
    print("Using torch.compile")

# Encode the whole corpus once: text -> list of token ids -> 1-D tensor on `device`.
tokens = torch.tensor(encode(text), device=device)

# Every ctx_len-long window of the corpus, sliding one token at a time, is a
# training example. The explicit loop below is the readable reference version:
#
#   all_seqs = []
#   for i in range(len(tokens) - ctx_len + 1):    # same count as a convolution: total length - window size + 1
#       all_seqs.append(tokens[i : i + ctx_len])  # tokens 0..31, then 1..32, and so on
#   all_seqs = torch.stack(all_seqs)              # list of separate tensors -> one matrix
#
# unfold yields the same rows without the Python loop. Resulting shape:
# (len(tokens) - ctx_len + 1, ctx_len), e.g. (1115390 - 32 + 1, 32).
all_seqs = tokens.unfold(0, ctx_len, 1)

Using: cpu, dtype: torch.float32
Vocab size: 65, dim: 256, layers: 8
Using torch.compile


In [None]:
# train
# Request the fused Adam kernel only on CUDA: some PyTorch 2.x releases raise
# for fused=True when the parameters are CPU tensors, and this notebook falls
# back to CPU when no GPU is available.
opt = torch.optim.Adam(params, lr=lr, fused=(device == 'cuda'))

def plot_embeddings(step, loss_val):
    """Scatter-plot the token embeddings projected onto their top two principal directions.

    Args:
        step: training iteration, shown in the figure title.
        loss_val: loss as a Python float, shown in the figure title.

    The embedding table is copied to the CPU in float32, mean-centred per
    column, and projected onto its first two right singular vectors (a 2-D PCA).
    Each point is annotated with its character; space, newline and tab are
    shown via repr() so they stay visible.
    """
    E = emb.detach().float().cpu()
    E = E - E.mean(dim=0)
    # torch.svd is deprecated; torch.linalg.svd returns the right singular
    # vectors as rows (Vh), so the first two rows are the projection axes.
    _, _, Vh = torch.linalg.svd(E, full_matrices=False)
    coords = E @ Vh[:2].T
    plt.figure(figsize=(8, 8))
    plt.scatter(coords[:, 0].numpy(), coords[:, 1].numpy(), s=20)
    for i, c in itos.items():
        label = repr(c) if c in (' ', '\n', '\t') else c
        plt.annotate(label, (coords[i, 0].item(), coords[i, 1].item()), fontsize=11)
    plt.title(f"Step {step} | Loss {loss_val:.2f}")
    plt.grid(True, alpha=0.3)
    plt.show()

for i in range(train_iter):
    # Draw batch_size random windows. Inputs are each window without its last
    # token; targets are the same window shifted left by one (the next token).
    idx = torch.randint(0, all_seqs.shape[0], (batch_size,))
    x, y = all_seqs[idx, :-1], all_seqs[idx, 1:]

    # Forward pass, then cross-entropy over every (sequence, position) pair.
    logits = forward(x)
    loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))

    # Clear stale gradients, backpropagate, apply the update.
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Periodic progress report and embedding snapshot.
    if not i % 100:
        print(f"{i}: loss={loss.item():.2f}")
    if not i % plot_every:
        plot_embeddings(i, loss.item())

In [8]:
# save weights
# torch.save({
#     'emb': emb.data, 'pos': pos.data,
#     'wq': [w.data for w in wq], 'wk': [w.data for w in wk],
#     'wv': [w.data for w in wv], 'w1': [w.data for w in w1],
#     'w2': [w.data for w in w2],
# }, 'weights.pt')

# if in colab, download to pc:
# from google.colab import files
# files.download('weights.pt')

# print("Saved weights.pt")

# to load later instead of training:
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all this checkpoint holds, so a tampered file cannot execute arbitrary code.
ckpt = torch.load('weights.pt', map_location=device, weights_only=True)

# Copy into the existing parameters in place instead of rebinding .data.
# copy_ converts to each parameter's dtype/device and raises if a checkpoint
# tensor cannot be fitted to the parameter's shape (e.g. saved with a different
# dim or ctx_len) instead of silently replacing the tensor.
with torch.no_grad():
    emb.copy_(ckpt['emb'])
    pos.copy_(ckpt['pos'])
    for i in range(n_layers):
        wq[i].copy_(ckpt['wq'][i])
        wk[i].copy_(ckpt['wk'][i])
        wv[i].copy_(ckpt['wv'][i])
        w1[i].copy_(ckpt['w1'][i])
        w2[i].copy_(ckpt['w2'][i])

In [9]:
# generate
ctx = "most important"  # prompt the model continues from
temperature = 0.8       # logits are divided by this; values < 1 sharpen the distribution
max_new_tokens = 500    # number of characters to sample

# NOTE: this rebinds `tokens` (the corpus tensor from the setup cell) to a
# Python list of ids; all_seqs has already been built from the tensor by then.
tokens = encode(ctx)
print(tokens)

# Sampling needs no gradients. Without no_grad every forward() call would build
# an autograd graph, because all parameters have requires_grad=True.
with torch.no_grad():
    for _ in range(max_new_tokens):
        # Feed at most the last ctx_len-1 tokens, the same length used in training.
        x = torch.tensor([tokens[-(ctx_len-1):]], device=device)
        logits = forward(x)
        # Only the final position's logits predict the next token.
        probs = F.softmax(logits[0, -1] / temperature, dim=-1)
        next_token = torch.multinomial(probs, 1).item()
        tokens.append(next_token)

print(decode(tokens))

[51, 53, 57, 58, 1, 47, 51, 54, 53, 56, 58, 39, 52, 58]
most important,
Which runs of war suffice.

JULIET:
I'll swear to the Tower.

PRINCE:
Be rale, my lord, as ; many a thousand these words in this point on Kent,
Shall I dash out my deserts in this rude assault;
And when the trier of my death,
To show the state some noise me from their lives,
That I mine own hands I spied
With betters than England? What realm
Ymbron that dogs must be a prisoner.
A silly king and a severable cannot round
Juliet matter to speak from himself.
Go, sirrah, to addget it, thou strikes
