In [None]:
import torch # torch is a folder, __init__.py in that torch folder imports other .py files in that torch folder, i can then import any of those imports with torch dot whatever
import torch.nn.functional as F

# Pick the compute device and the floating-point dtype used for every weight:
#   no CUDA              -> cpu,  float32
#   CUDA with bf16       -> cuda, bfloat16
#   CUDA without bf16    -> cuda, float16
if torch.cuda.is_available():
    device = 'cuda'
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    device = 'cpu'
    dtype = torch.float32

print(f"Using: {device}, dtype: {dtype}")

# config
seed = 1337 # fixed RNG seed so a Restart & Run All repeats the same random draws
batch_size = 64 # sequences sampled per training step
dim = 128 # embedding vector (aka token vector) length
n_layers = 2 # number of transformer blocks
mlp_dim = 256 # number of hidden nodes in each MLP
ctx_len = 32 # how many tokens it looks at, at once

# Seed before any random tensor is created further down.
torch.manual_seed(seed)

# load data
# The encoding is spelled out because open() otherwise falls back to the platform's
# locale encoding, which can decode the same file into different characters (and
# therefore a different vocabulary) on another machine.
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

# char tokenizer: every distinct character in the corpus is one token
chars = sorted(set(text)) # set finds unique chars, sorted puts them in a stable order
vocab_size = len(chars) # number of unique tokens in the vocabulary
stoi = {c: i for i, c in enumerate(chars)} # char -> token id, like {'a': 0, 'b': 1, ...}
itos = {i: c for c, i in stoi.items()} # token id -> char, the reverse of stoi


def encode(s):
    """Return the list of token ids for string `s`, one id per character.

    Raises KeyError if `s` contains a character that is not in `stoi`.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled by the token ids in `l` (the reverse of encode).

    Raises KeyError if an id is not in `itos`.
    """
    return ''.join(itos[i] for i in l)


print(f"Vocab size: {vocab_size}, dim: {dim}, layers: {n_layers}")

# weights
# Every shape below is derived from the config (vocab_size, ctx_len, dim, mlp_dim)
# instead of literal 128/256, so editing the config keeps all matrices compatible.
# NOTE: forward() applies exactly the two blocks defined here; changing n_layers
# alone does not add or remove weights.


def _init_weight(rows, cols, scale):
    """Return a (rows, cols) tensor of standard-normal samples multiplied by `scale`.

    The tensor is created on the notebook's `device` with the notebook's `dtype`.
    """
    return torch.randn(rows, cols, device=device, dtype=dtype) * scale


emb_scale = 0.02 # multiplier for the token and positional embedding tables
block_scale = 0.1 / (2 ** 0.5) # multiplier for every attention / MLP matrix

emb = _init_weight(vocab_size, dim, emb_scale) # each row is one token's embedding vector
pos = _init_weight(ctx_len, dim, emb_scale) # each row is one position's embedding; pos[0] is for the first token in the context window, pos[1] for the second, etc.


# Transformer block weights
# Inputs are row vectors, so each matrix is (input dim, output dim) and is applied as `x @ w`.

# 1st Block Weights
# -----------------------------------------------
# 1st attention
wq0 = _init_weight(dim, dim, block_scale) # query projection, applied to the RMS-normalized x
wk0 = _init_weight(dim, dim, block_scale) # key projection
wv0 = _init_weight(dim, dim, block_scale) # value projection

# 1st MLP
w10 = _init_weight(dim, mlp_dim, block_scale) # MLP input -> hidden layer
w20 = _init_weight(mlp_dim, dim, block_scale) # MLP hidden layer -> output
# -----------------------------------------------


# 2nd Block Weights (same layout; its input is the output of the 1st block)
# -----------------------------------------------
# 2nd attention
wq1 = _init_weight(dim, dim, block_scale)
wk1 = _init_weight(dim, dim, block_scale)
wv1 = _init_weight(dim, dim, block_scale)

# 2nd MLP
w11 = _init_weight(dim, mlp_dim, block_scale)
w21 = _init_weight(mlp_dim, dim, block_scale)
# -----------------------------------------------


# Everything the optimizer updates. The flag is switched on only after the scaling
# multiply, so each tensor is a leaf that accumulates its own .grad.
params = [emb, pos, wq0, wk0, wv0, w10, w20, wq1, wk1, wv1, w11, w21]
for p in params:
    p.requires_grad = True

def rmsnorm(x, eps=1e-5):
    """Divide each vector along the last dimension by its root-mean-square.

    `eps` is added to the RMS (after the square root) so an all-zero vector
    divides by eps instead of by zero. The output has the same shape as `x`.
    """
    mean_square = x.pow(2).mean(dim=-1, keepdim=True)
    rms = mean_square.sqrt()
    return x / (rms + eps)

def _transformer_block(x, wq, wk, wv, w1, w2, mask):
    """Apply one transformer block: norm -> causal attention -> norm -> ReLU MLP.

    Both sub-layers read an RMS-normalized copy of `x` and add their result back
    onto `x` (residual connections).

    x:          (B, T, C) activations.
    wq, wk, wv: (C, C) query / key / value projections.
    w1, w2:     (C, hidden) and (hidden, C) MLP matrices.
    mask:       (T, T) bool tensor, True where a position may be attended to.
    Returns the updated (B, T, C) activations.
    """
    # Attention: softmax((Q @ K^T) / sqrt(d_k)) @ V. Dividing by sqrt(d_k) keeps the
    # scores from growing with the vector length before they reach the softmax.
    nx = rmsnorm(x)
    q, k, v = nx @ wq, nx @ wk, nx @ wv
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5) # (B, T, T)
    scores = scores.masked_fill(~mask, float('-inf')) # -inf becomes probability 0 after softmax
    sm = F.softmax(scores, dim=-1) # probabilities per row
    x = x + sm @ v # residual connection, (B, T, C)

    # MLP
    nx = rmsnorm(x)
    x = x + (nx @ w1).relu() @ w2 # residual connection, (B, T, C)
    return x


def forward(x):
    """Return next-token logits for a batch of token-id sequences.

    x: (B, T) integer token ids with T <= ctx_len (pos only has ctx_len rows).
    Returns a (B, T, vocab_size) tensor; entry [b, t] scores every vocabulary
    token as the successor of x[b, t].

    The output projection reuses the embedding table: each final vector is
    dotted with every row of emb, and the largest dot product is the most
    likely next token. Training therefore pushes the final vector toward the
    embedding of the correct next token.

    Say dim=3, vocab=3 for simplicity:

        emb = [[0.2, 0.5, 0.1],   # token 'a'
               [0.9, 0.1, 0.3],   # token 'b'
               [0.1, 0.8, 0.2]]   # token 'c'
        output vector after all blocks: x = [0.85, 0.15, 0.28]

        x @ emb.T = dot product of x with each row:
            vs 'a': 0.85*0.2 + 0.15*0.5 + 0.28*0.1 = 0.27
            vs 'b': 0.85*0.9 + 0.15*0.1 + 0.28*0.3 = 0.86  <- highest
            vs 'c': 0.85*0.1 + 0.15*0.8 + 0.28*0.2 = 0.26

        Result: [0.27, 0.86, 0.26] -> the model predicts 'b'.
    """
    _, T = x.shape # T is ctx_len-1 during training, because the last token of each window is the target
    h = emb[x] + pos[:T] # emb[x] is (B, T, C); pos[:T] is (T, C) and broadcasts over the batch

    # Lower-triangular (incl. diagonal) True mask: position t may only look at positions <= t.
    # Built once and shared by both blocks.
    mask = torch.tril(torch.ones(T, T, device=h.device, dtype=torch.bool))

    h = _transformer_block(h, wq0, wk0, wv0, w10, w20, mask) # 1st block
    h = _transformer_block(h, wq1, wk1, wv1, w11, w21, mask) # 2nd block

    return h @ emb.T # (B, T, C) @ (C, vocab_size) -> (B, T, vocab_size)


# Compile the forward pass when this PyTorch build provides torch.compile; the
# compiler can fuse operations together, which can speed up training.
if hasattr(torch, 'compile'):
    forward = torch.compile(forward)
    print("Using torch.compile")

# Every window of ctx_len consecutive tokens in the corpus, stepping one token at a time.
tokens = torch.tensor(encode(text), device=device)
all_seqs = tokens.unfold(0, ctx_len, 1)

# train
n_steps = 10000 # optimizer steps
log_every = 100 # print the loss every this many steps

# The fused kernel is only requested on CUDA: PyTorch builds without CPU support
# for fused Adam raise on fused=True when the parameters live on the CPU.
opt = torch.optim.Adam(params, lr=1e-4, fused=(device == 'cuda'))

for i in range(n_steps):
    # Draw the batch indices on the same device as all_seqs so the indexing below
    # does not have to move them across devices first.
    idx = torch.randint(0, all_seqs.size(0), (batch_size,), device=device)
    # Inputs are the first ctx_len-1 tokens of each window; targets are the same
    # window shifted left by one, i.e. the next token at every position.
    x, y = all_seqs[idx, :-1], all_seqs[idx, 1:]

    logits = forward(x)
    loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))

    opt.zero_grad()
    loss.backward()
    opt.step()

    if i % log_every == 0:
        print(f"{i}: loss={loss.item():.2f}")

Using: cpu, dtype: torch.float32
Vocab size: 65, dim: 128, layers: 2


KeyboardInterrupt: 

In [3]:
# generate
ctx = "First Citizen"
temperature = 0.8 # logits are divided by this before softmax; below 1 sharpens the distribution
n_new_tokens = 500 # how many tokens to sample after the prompt

# NOTE: this rebinds `tokens` from the corpus tensor of the training cell to a
# plain Python list of ids that grows by one element per sampled token.
tokens = encode(ctx)
print(tokens)

# Sampling never calls backward(), so run it under no_grad; otherwise autograd
# records a graph through every weight on each of the forward passes.
with torch.no_grad():
    for _ in range(n_new_tokens):
        # Feed at most the last ctx_len-1 tokens, the same length the model saw in training.
        x = torch.tensor([tokens[-(ctx_len-1):]], device=device)
        logits = forward(x)
        # Only the logits at the last position predict the token that comes next.
        probs = F.softmax(logits[0, -1] / temperature, dim=-1)
        next_token = torch.multinomial(probs, 1).item()
        tokens.append(next_token)

print(decode(tokens))

[18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52]
First CitizenazpAdk.QEtxgXv!o.Zm.
byJNz!rdb&-s,oFIBmggf&VKgy!yveN3aitkZv'IsCaita3&oRcmF?qdSl;M:y
NQatBe'IBS&$ 
u?kN.kYxKxcRmmQuGua3bhomqxV&KADm;ipwZ-d-Tv-!uVfyoSDauzmySA;.MAUcCZu:ruNyRzcomLoOuNDuNNcHb-im:BATdSHEi,pUEyc
?O!?bCbW,P'oDI PSrMXvqMfTvT$B!i!Naw!tvf.AvBKggAiTmTZzZWOpdd
 yksZdW?xxyy;u,-$eA:?eE- deuuZzDJ3bTpx ;rNBtdyomKB;!RfaQYeY'NV&maH
gmQMeD,!nqMKQ
t3y;E?J?Qpxj&jnsIKjv$!EJRtcxnb!ah
H!YZ:,k.
$EkEA&VYNqSuW ;yPew!WB-LVJFKgRQ
IwJRbokVtamXbWhy3ciBv?WFHGvcJ;Q$CZtyIGjPaE!AvrQjkO!vg.NjGJn&meWb:LEtpcX3xXQ?l!
