# WS_follow_along_nanogpt.ipynb
# WESmith 06/10/23
## follow along with Karpathy video
## https://www.youtube.com/watch?v=kCc8FmEb1nY

In [None]:
import torch
import torch.nn as nn
from   torch.nn import functional as F

In [None]:
# hyperparameters
# -- data / batching --
batch_size = 32      # independent sequences sampled per batch
block_size = 8       # maximum context length (tokens) the model sees
# -- optimization / evaluation --
max_iters = 5000     # number of training steps
eval_interval = 500  # steps between train/val loss reports
learning_rate = 1e-3
eval_iters = 200     # batches averaged per split when estimating the loss
# -- model / runtime --
n_embd = 32          # embedding dimension
device = 'cpu'
seed = 1337          # RNG seed used throughout the notebook

In [None]:
# read the whole training corpus into a single string
with open('input.txt', encoding='utf-8') as corpus:
    text = corpus.read()

In [None]:
# total number of characters in the corpus
len(text)

In [None]:
# show the first 200 characters of the corpus
print(text[:200])

In [None]:
# the vocabulary is every distinct character in the corpus, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

In [None]:
# character-level tokenizer: each vocabulary character <-> its integer id
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    '''Convert string `s` to a list of integer token ids.

    Raises KeyError for a character that is not in the vocabulary.
    '''
    return [stoi[c] for c in s]


def decode(l):
    '''Convert an iterable of integer token ids back into a string.'''
    return ''.join(itos[i] for i in l)

In [None]:
# round-trip check: decode(encode(s)) should give back s
print(encode('hii there,\nyou'))
print(decode(encode('hii there,\nyou')))

In [None]:
# encode the entire corpus as one 1-D tensor of integer token ids
data = torch.tensor(encode(text), dtype=torch.long)

In [None]:
data.shape, data.dtype

In [None]:
# the first 200 token ids
print(data[:200])

In [None]:
# first 90% of the token stream is used for training, the last 10% for validation
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [None]:
train_data.shape, val_data.shape

In [None]:
# one chunk: block_size inputs plus the one extra token needed as the final target
train_data[:block_size + 1]

In [None]:
# a chunk of block_size + 1 tokens packs block_size training examples:
# for each prefix x[:t+1] the target is the next token y[t] (= train_data[t+1])
x = train_data[:block_size]
y = train_data[1:block_size + 1]
for t in range(block_size):
    context = x[:t + 1]
    target  = y[t]
    print(f'when input is: {context} the target is: {target}')

In [None]:
# reseed the RNG so the random batches sampled below are reproducible
torch.manual_seed(seed)
# earlier demo values, superseded by the hyperparameter cell at the top:
#batch_size = 4
#block_size = 8 # maximum context length for predictions

In [None]:
def get_batch(split):
    '''Sample a random batch of (input, target) sequences from one split.

    split: 'train' selects train_data; any other value selects val_data.
    Returns (x, y), each of shape (batch_size, block_size) and placed on
    `device`, where y[b, t] is the token that follows x[b, t] in the data
    (y is x shifted by one token).
    '''
    data = train_data if split == 'train' else val_data
    # random start offsets; the upper bound leaves room for the +1 target shift
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    # put the batch on the same device the model uses (no-op for 'cpu')
    return x.to(device), y.to(device)

@torch.no_grad()
def estimate_loss():
    '''Average the model's loss over `eval_iters` random batches of each split.

    Returns a dict mapping 'train' and 'val' to a 0-d tensor with the mean loss.
    Gradient tracking is disabled by the decorator. The model is switched to
    eval mode for the measurement and back to train mode afterwards. Per the
    PyTorch docs this only changes layers such as dropout/batch-norm, which
    the model in this notebook does not use.
    '''
    model.eval()
    results = {}
    for split_name in ('train', 'val'):
        split_losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            xs, ys = get_batch(split_name)
            _, batch_loss = model(xs, ys)
            split_losses[i] = batch_loss.item()
        results[split_name] = split_losses.mean()
    model.train()
    return results

In [None]:
# draw one training batch and show its inputs and (shifted-by-one) targets
xb, yb = get_batch('train')
print('inputs')
print(xb.shape)
print(xb)
print('targets')
print(yb.shape)
print(yb)

In [None]:
# spell out every (context, target) example packed into the batch;
# this prints batch_size * block_size lines
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t + 1]
        target  = yb[b, t]
        print(f'when input is: {context.tolist()} the target is: {target}')

In [None]:
# nn.Embedding(m, n) is a lookup table whose single learnable parameter is a
# randomly initialized (m, n) weight matrix; row i is the vector for index i
dd = nn.Embedding(3, 5)
for p in dd.parameters():
    print(p.shape)

In [None]:
class Head(nn.Module):
    '''One head of causal (masked) self-attention.

    head_size: dimension of this head's key/query/value projections.
    Input x is (B, T, n_embd) with T <= block_size; the output is (B, T, head_size).
    Position t only aggregates information from positions 0..t, so the head
    cannot look at future tokens (decoder-style attention).
    '''

    def __init__(self, head_size):
        super().__init__()
        self.key    = nn.Linear(n_embd, head_size, bias=False)
        self.query  = nn.Linear(n_embd, head_size, bias=False)
        self.value  = nn.Linear(n_embd, head_size, bias=False)
        # lower-triangular ones used as the causal mask; registered as a buffer
        # so it is part of the module's state but is not a trainable parameter
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B,T,hs)
        q = self.query(x)  # (B,T,hs)
        # attention scores ('affinities'), scaled by 1/sqrt(head_size) so their
        # variance stays ~1 and the softmax does not saturate
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # (B,T,hs) @ (B,hs,T) => (B,T,T)
        # causal mask: forbid attending to later positions; the top-left T x T
        # corner of the mask handles sequences shorter than block_size
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)  # (B,T,T), each row sums to 1
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B,T,hs)
        out = wei @ v      # (B,T,T) @ (B,T,hs) => (B,T,hs)
        return out


In [None]:
# Main model cell. The class keeps its original "Bigram" name but also uses
# position embeddings and one self-attention head.
# Reseed so the later random parameter initialization is reproducible.
torch.manual_seed(seed)

class BigramLanguageModel(nn.Module):
    '''Character-level language model: token + position embeddings, one
    self-attention head, then a linear projection to vocabulary logits.

    (The name is kept from the earlier, pure-bigram version of the model.)
    '''

    def __init__(self):
        super().__init__()
        # learnable lookup tables: token id -> vector, position -> vector
        self.token_embedding_table    = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd)  # single self-attention head with head_size = n_embd
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        '''Compute next-token logits and, optionally, the cross-entropy loss.

        idx, targets: (B, T) integer tensors; T must be <= block_size because
        the position table has block_size rows.
        Returns (logits, loss):
          - targets is None: logits is (B, T, vocab_size) and loss is None
          - otherwise: logits is flattened to (B*T, vocab_size) and loss is a scalar
        '''
        # B = batch, T = time, C = channel (embedding size n_embd)
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        # positions 0..T-1, created on idx's device rather than a global setting
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T,C)
        x = tok_emb + pos_emb     # (B,T,C); pos_emb broadcasts over B
        x = self.sa_head(x)       # apply one head of self-attention, (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, classes) logits and (N,) targets
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        '''Autoregressively append max_new_tokens sampled tokens to idx.

        idx: (B, T) tensor of token ids (the prompt).
        Returns a (B, T + max_new_tokens) tensor. At each step only the last
        block_size tokens are fed to the model (the position table has
        block_size rows). Runs without autograd, since sampling yields integer
        ids and no gradient graph is needed.
        '''
        for _ in range(max_new_tokens):
            # crop to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            # keep only the last time step -> (B, vocab_size)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T + 1)
        return idx
        

In [None]:
# instantiate the model and run one forward pass on the batch xb, yb from above
model = BigramLanguageModel()
m     = model.to(device)  # later cells use m and model interchangeably
out, loss = m(xb, yb)
# with targets given, forward returns logits flattened to (B*T, vocab_size)
out.shape, loss.item()

In [None]:
# sample 500 tokens from the (still untrained) model, starting from token id 0
batch = 1  # WS mod: number of independent sequences to sample
idx = torch.zeros((batch, 1), dtype=torch.long, device=device)  # same device as the model
out = m.generate(idx, max_new_tokens=500)
for row in out:
    print(decode(row.tolist()))

In [None]:
# create a PyTorch optimizer (AdamW) over all model parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
# training loop
for step in range(max_iters):

    # report averaged train/val loss periodically and after the final step,
    # so the last printed numbers reflect the fully trained model
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step:5}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of training data
    xb, yb = get_batch('train')

    # forward pass, backpropagation, parameter update
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

In [None]:
# generate 500 new tokens from a single start token (id 0) and decode them
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated = m.generate(context, max_new_tokens=500)
print(decode(generated[0].tolist()))

# The mathematical trick in self-attention

In [None]:
# toy (B, T, C) input for the averaging examples below
torch.manual_seed(seed)
B,T,C = 4, 8, 2 # batch, time, channels
x = torch.randn(B, T, C)
x.shape  # (4, 8, 2)
x[0]

In [None]:
# goal: xbow[b, t] = mean_{i<=t} x[b, i]
# i.e. position t holds the running average of token t and every earlier
# token in the same sequence ('bow' = bag of words)
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1]            # (t+1, C): tokens 0..t of sequence b
        xbow[b, t] = xprev.mean(dim=0)  # (C,)
xbow[0]

In [None]:
# the efficiency trick: multiplying by a row-normalized lower-triangular
# matrix computes all running averages in a single matmul
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / a.sum(dim=1, keepdim=True)  # row i now holds 1/(i+1) in its first i+1 columns
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b                           # row i of c = mean of rows 0..i of b
a, b, c

In [None]:
# version 2: the same running averages as one batched matmul
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(dim=1, keepdim=True)  # (T, T); each row sums to 1
xbow2 = wei @ x  # (T,T) @ (B,T,C) => (B,T,C); wei is broadcast over B
torch.allclose(xbow, xbow2)

In [None]:
# version 3: build the same weights with a masked softmax
tril = torch.tril(torch.ones(T, T))
# zeros on/below the diagonal, -inf above it (future positions)
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)  # each row: uniform weights over the allowed positions
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

In [None]:
# version 4: self-attention -- the weights now depend on the data itself
torch.manual_seed(seed)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

# a single head of self-attention
head_size = 16
key   = nn.Linear(C, head_size, bias=False)  # projects each token C -> head_size
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k, q, v = key(x), query(x), value(x)  # each (B,T,16)
# scaled dot-product scores; the head_size**-0.5 factor keeps their variance ~1
wei = q @ k.transpose(-2, -1) * head_size**-0.5  # (B,T,16) @ (B,16,T) => (B,T,T)

In [None]:
tril = torch.tril(torch.ones(T, T))
# causal mask for decoder blocks (the future is unknown): position t may not
# attend to later positions. An encoder block would skip this masked_fill so
# that all positions can communicate with each other.
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ v  # aggregate the projected values (not raw x) => (B,T,head_size)

In [None]:
# wei: (B,T,T) attention weights; out: (B,T,head_size)
wei.shape, out.shape

In [None]:
# why the head_size**-0.5 factor: for independent unit-variance q and k, the
# raw dot products have variance ~head_size; this cell checks that the scaled
# wei comes out with variance near 1, like k and q
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2,-1) * head_size**-0.5
k.var(), q.var(), wei.var()

In [None]:
# attention output for the first sequence in the batch
out[0]