In [1]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-02-28 15:44:54--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-02-28 15:44:54 (21.2 MB/s) - ‘input.txt’ saved [1115394/1115394]



## Tiny shakespeare data i/o and basic encoding

In [2]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print('total chars:', vocab_size)
print(''.join(chars))

total chars: 65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [5]:
import numpy as np

# create a mapping of unique chars to indices
char_to_idx = {ch:i for i,ch in enumerate(chars)}
idx_to_char = {i:ch for i,ch in enumerate(chars)}
encode = lambda x: np.array([char_to_idx[ch] for ch in x])
decode = lambda x: ''.join([idx_to_char[idx] for idx in x])

print('encoded:', encode("Hello world"))
print('decoded:', decode(encode("Hello worlkd")))

encoded: [20 43 50 50 53  1 61 53 56 50 42]
decoded: Hello worlkd


In [7]:
# Convert text to encoded tensor
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:10])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])


In [9]:
# Train and valid split
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

print(train_data.shape, val_data.shape)

torch.Size([1003854]) torch.Size([111540])


In [10]:
block_size = 8
train_data[:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [26]:
x = train_data[:block_size]
y = train_data[1:block_size + 1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"Input: {context}, Target: {target}")
    print(f"Input: {decode(context.numpy())}, Target: {decode(np.array([target.numpy()]))}")


Input: tensor([18]), Target: 47
Input: F, Target: i
Input: tensor([18, 47]), Target: 56
Input: Fi, Target: r
Input: tensor([18, 47, 56]), Target: 57
Input: Fir, Target: s
Input: tensor([18, 47, 56, 57]), Target: 58
Input: Firs, Target: t
Input: tensor([18, 47, 56, 57, 58]), Target: 1
Input: First, Target:  
Input: tensor([18, 47, 56, 57, 58,  1]), Target: 15
Input: First , Target: C
Input: tensor([18, 47, 56, 57, 58,  1, 15]), Target: 47
Input: First C, Target: i
Input: tensor([18, 47, 56, 57, 58,  1, 15, 47]), Target: 58
Input: First Ci, Target: t


In [27]:
torch.manual_seed(42)
batch_size = 4
block_size = 8

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, size=(batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])

    return x, y

x_batch, y_batch = get_batch('train')
print(x_batch.shape, y_batch.shape)

for b in range(batch_size):
    print(f"Batch {b}:")
    for t in range(block_size):
        context = x_batch[b,:t+1]
        target = y_batch[b,t]
        print(f"Input: {context}, Target: {target}")
        print(f"Input: {decode(context.numpy())}, Target: {decode(np.array([target.numpy()]))}")

torch.Size([4, 8]) torch.Size([4, 8])
Batch 0:
Input: tensor([57]), Target: 1
Input: s, Target:  
Input: tensor([57,  1]), Target: 46
Input: s , Target: h
Input: tensor([57,  1, 46]), Target: 47
Input: s h, Target: i
Input: tensor([57,  1, 46, 47]), Target: 57
Input: s hi, Target: s
Input: tensor([57,  1, 46, 47, 57]), Target: 1
Input: s his, Target:  
Input: tensor([57,  1, 46, 47, 57,  1]), Target: 50
Input: s his , Target: l
Input: tensor([57,  1, 46, 47, 57,  1, 50]), Target: 53
Input: s his l, Target: o
Input: tensor([57,  1, 46, 47, 57,  1, 50, 53]), Target: 60
Input: s his lo, Target: v
Batch 1:
Input: tensor([1]), Target: 58
Input:  , Target: t
Input: tensor([ 1, 58]), Target: 46
Input:  t, Target: h
Input: tensor([ 1, 58, 46]), Target: 43
Input:  th, Target: e
Input: tensor([ 1, 58, 46, 43]), Target: 56
Input:  the, Target: r
Input: tensor([ 1, 58, 46, 43, 56]), Target: 43
Input:  ther, Target: e
Input: tensor([ 1, 58, 46, 43, 56, 43]), Target: 1
Input:  there, Target:  
Input

## A very simple Bigram model

In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size) -> None:
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embeddings = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embeddings(idx)
        B, T, V = logits.shape
        
        if targets is None:
            return logits, None
        logits = logits.view(B*T, V)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets)  # transpose to get B x T x V
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1, :]  # take the last token
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)

        return idx
    
m = BigramLanguageModel(vocab_size)
out, loss = m(x_batch, y_batch)
print(out.shape, loss)

torch.Size([32, 65]) tensor(4.5700, grad_fn=<NllLossBackward0>)


In [41]:
idx = torch.zeros((1,1), dtype=torch.long)
idx = m.generate(idx, 100)
print(decode(idx[0].numpy()))


rERSLA:hsAy leWZhgYYt?$-.iF-wOY,FWORwQhBq:TwYYiW3uGJxi'?KO&.,-VG-vnvOaYz?L&
zbgLvcgcnCAKY.PuKDBH.nhu


In [50]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [51]:
# train the model
batch_size = 32
for steps in range(epochs):
    x_batch, y_batch = get_batch('train')
    
    logits, loss = m(x_batch, y_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if steps % 1000 == 0:
        print(f"Step {steps}, Loss {loss.item()}")

Step 0, Loss 2.5169339179992676
Step 1000, Loss 2.517500400543213
Step 2000, Loss 2.554572343826294
Step 3000, Loss 2.440171241760254
Step 4000, Loss 2.46695613861084
Step 5000, Loss 2.3237228393554688
Step 6000, Loss 2.502161979675293
Step 7000, Loss 2.526960611343384
Step 8000, Loss 2.4958548545837402
Step 9000, Loss 2.538480520248413


In [53]:
idx = torch.zeros((1,1), dtype=torch.long)
idx = m.generate(idx, 500)
print(decode(idx[0].numpy()))


Ye s t; CLogamowriour; thestot Eavius.
Mag a at fitous rdod bullo?
BRD yomu;

TO: he ere I t? end whal


THes the ch t che Isponds d tot nghe.

Maits 'swe, theste rdit BERDinocowous soupa LA b y l the.
haden baithio cr; Hee:

BENCl do-cisise?
Sigine oud I wenghath's pet, heangho,

Prcr stomolow'

ARI k'ss, angrpetangond! ed
Thot f'lor rther yof l I nd BNGLA:


nthoulotsmo fordsonot fueth yo thenoour eaus poy ist fond oo?
TUDONCI s my, nd pawim'Her il fifos RD ul inge, f in at uliss Be.

IOK:

An


## Mathematical trick in self-attention

In [71]:
# Toy example
torch.manual_seed(42)
B, T, V = 4, 8, 2
x = torch.randn(B, T, V)
x.shape

torch.Size([4, 8, 2])

In [78]:
# Version 1
# Calculate x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros_like(x)
for b in range(B):
    for t in range(T):
        xbow[b,t] = torch.mean(x[b,:t+1], dim=0)

In [79]:
x[0], xbow[0]

(tensor([[ 1.9269,  1.4873],
         [ 0.9007, -2.1055],
         [ 0.6784, -1.2345],
         [-0.0431, -1.6047],
         [-0.7521,  1.6487],
         [-0.3925, -1.4036],
         [-0.7279, -0.5594],
         [-0.7688,  0.7624]]),
 tensor([[ 1.9269,  1.4873],
         [ 1.4138, -0.3091],
         [ 1.1687, -0.6176],
         [ 0.8657, -0.8644],
         [ 0.5422, -0.3617],
         [ 0.3864, -0.5354],
         [ 0.2272, -0.5388],
         [ 0.1027, -0.3762]]))

In [80]:
# Version: 2
# Vectorize damnit!
avg_weights = torch.tril(torch.ones(T,T))
avg_weights = avg_weights / torch.sum(avg_weights, dim=1, keepdim=True)
avg_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [83]:
xbow2 = avg_weights @ x # (T,T) @ (B,T,V) -> (B,T,T) @ (B,T,V) -> (B,T,V)
torch.allclose(xbow, xbow2)


True

In [89]:
# Version 3
# Using softmax
tril = torch.tril(torch.ones(T,T))
avg_weights = torch.zeros_like(tril)
avg_weights[tril == 0] = float('-inf')
avg_weights = F.softmax(avg_weights, dim=1)
avg_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [90]:
xbow3 = avg_weights @ x
torch.allclose(xbow, xbow3)

True

In [106]:
# Version 4: Self-attention
B, T, V = 4, 8, 32
x = torch.randn(B, T, V)

# Single head attention
head_size = 16
query = nn.Linear(V, head_size, bias=False)
key = nn.Linear(V, head_size, bias=False)
value = nn.Linear(V, head_size, bias=False)
k = key(x)  # (B,T,head_size)
q = query(x) # (B,T,head_size)

affinity = q @ k.transpose(-2, -1) / head_size**0.5  # (B,T,head_size) @ (B,head_size,T) -> (B,T,T)
# ^ normalise by the size of the head to prevent the dot product from exploding and causing softmaxx to converge to argmax

tril = torch.tril(torch.ones(T,T))
affinity = affinity.masked_fill(tril == 0, float('-inf'))
# affinity[tril == 0] = float('-inf')
affinity = F.softmax(affinity, dim=1)

v = value(x)  # (B,T,head_size)
out = affinity @ v  # (B,T,T) @ (B,T,head_size) -> (B,T,head_size)
# out = affinity @ x

out.shape

torch.Size([4, 8, 16])

In [105]:
affinity[0]

tensor([[0.1388, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1164, 0.1128, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1756, 0.1034, 0.1615, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1272, 0.1324, 0.1881, 0.2397, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0749, 0.2171, 0.1666, 0.1622, 0.3265, 0.0000, 0.0000, 0.0000],
        [0.0942, 0.1995, 0.1404, 0.2054, 0.1477, 0.1700, 0.0000, 0.0000],
        [0.1727, 0.1011, 0.1793, 0.1778, 0.3046, 0.4363, 0.3622, 0.0000],
        [0.1003, 0.1337, 0.1641, 0.2149, 0.2213, 0.3937, 0.6378, 1.0000]],
       grad_fn=<SelectBackward0>)