## ShakespeareGPT

> based on [Let's build GPT: from scratch, in code, spelled out.](https://www.youtube.com/watch?v=kCc8FmEb1nY)

In [79]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass

In [80]:
torch.manual_seed(1357)

<torch._C.Generator at 0x7fbec4214050>

# Preparing Data

In [82]:
with open('./dataset/shakespeare.txt','r',encoding='utf-8') as f:
    data = f.read()
    
print(f"{len(data)=}\n{data[:100]}")

len(data)=1114985
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


## Tokenizer

In [83]:
class CharacterLevelTokenizer:
    def __init__(self,data):
        self.data = data
        self.vocab = sorted(list(set(self.data)))
        self.VOCAB_SIZE = len(self.vocab)
        
        self.i_s = {i:s for i,s in enumerate(self.vocab)}
        self.s_i = {s:i for i,s in self.i_s.items()}
        
    def encode(self,s):
        return torch.tensor([self.s_i[c] for c in s],dtype=torch.long)

    def decode(self,s):
        return ''.join([self.i_s[i.item()] for i in s])

In [84]:
tokenizer = CharacterLevelTokenizer(data)
print(tokenizer.vocab)
print(tokenizer.VOCAB_SIZE)

['\n', ' ', '!', "'", ',', '-', '.', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
62


In [85]:
tokenizer.encode('et tu brute?')

tensor([40, 55,  1, 55, 56,  1, 37, 53, 56, 55, 40,  9])

In [86]:
tokenizer.decode(tokenizer.encode('et tu brute?'))

'et tu brute?'

## Config!

In [175]:
@dataclass
class Config:
    block_size = 8 # context-length
    batch_size = 4 # mini-batch size
    vocab_size = tokenizer.VOCAB_SIZE

## Dataset & Dataloader

In [87]:
class ShakespeareDataset:
    def __init__(self,block_size:int, is_test=False) -> None:
        self.tokenizer = CharacterLevelTokenizer(data)
        self.is_test = is_test
        self.full_data = self.tokenizer.encode(self.tokenizer.data)
        if self.is_test:
            self.data = self.full_data[int(0.9*len(self.full_data)):]
        else:
            self.data = self.full_data[:int(0.9*len(self.full_data))]
        self.block_size = block_size

    def __len__(self) -> int:
        return len(self.data)

    def get_block_size(self) -> int:
        return self.block_size

    def get_vocab_size(self) -> int:
        return self.tokenizer.VOCAB_SIZE

    def __getitem__(self,idx):
        item = self.data[idx:idx+self.block_size+1]
        x = item[:-1]
        y = item[1:]
        return x,y

In [88]:
train_ds = ShakespeareDataset(Config.block_size)
print(f'{train_ds.get_block_size()=}\n{train_ds.get_vocab_size()=}\n{len(train_ds)=}')

val_ds = ShakespeareDataset(Config.block_size,is_test=True)
print(f'{len(val_ds)=}')

train_ds.get_block_size()=8
train_ds.get_vocab_size()=62
len(train_ds)=1003486
len(val_ds)=111499


In [89]:
train_dl = torch.utils.data.DataLoader(train_ds,shuffle=False,batch_size=Config.batch_size)

In [90]:
inputs,targets=next(iter(train_dl))
print(inputs.shape,targets.shape)
inputs,targets

torch.Size([4, 8]) torch.Size([4, 8])


(tensor([[15, 44, 53, 54, 55,  1, 12, 44],
         [44, 53, 54, 55,  1, 12, 44, 55],
         [53, 54, 55,  1, 12, 44, 55, 44],
         [54, 55,  1, 12, 44, 55, 44, 61]]),
 tensor([[44, 53, 54, 55,  1, 12, 44, 55],
         [53, 54, 55,  1, 12, 44, 55, 44],
         [54, 55,  1, 12, 44, 55, 44, 61],
         [55,  1, 12, 44, 55, 44, 61, 40]]))

# Bi-gram Language Model

In [125]:
class BigramLM(nn.Module):
    def __init__(self,vocab_size):
        super(BigramLM,self).__init__()
        self.token_embedding_table = nn.Embedding(vocab_size,vocab_size)
        
    def forward(self,idx,targets=None):
        logits = self.token_embedding_table(idx) # (B,T,C:vocab_size)
        
        if targets is None:
            loss = None
        else:
            # torch cross entropy expects B,C,T instead of B,T,C
            # and for targets, we need B*T instead of B,T
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits,targets)
            
        return logits,loss

        
    def generate(self,idx,total):
        # idx (B,T) in current context
        for _ in range(total):
            logits,loss = self(idx)
            # since the last element is the next character, we pluck out -1 from T
            logits = logits[:,-1,:] # (B*T,C) -> (B,C)
            probs = F.softmax(logits,dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx,idx_next],dim=1) # (B, T+=1)
            
        return idx

In [126]:
bglm = BigramLM(tokenizer.VOCAB_SIZE)
logits,loss = bglm(inputs,targets)
print(logits.shape,loss)

generated = bglm.generate(
    torch.zeros((1,1),dtype=torch.long), # initial context 0
    total=100
)
generated = tokenizer.decode(generated[0])
generated

torch.Size([256, 62]) tensor(4.4188, grad_fn=<NllLossBackward0>)


"\nxqI E:Mp.HtfZDyhnSa!uQaSjIncCgX'xwUv-P;DzahqW.RY;ldx CmYAQgCT.noI\nXqtX JCeZzbMPkGcbMxQAt;l--ZFlHmUSc"

## training the bigram LM

In [136]:
bglm = BigramLM(tokenizer.VOCAB_SIZE)

optim = torch.optim.AdamW(bglm.parameters(),lr=1e-3)
bglm_dl = torch.utils.data.DataLoader(train_ds,shuffle=False,batch_size=32)

it = iter(bglm_dl)
for steps in range(25_000):
    inputs,targets = next(it)
    logits,loss=bglm(inputs,targets)
    optim.zero_grad()
    loss.backward()
    optim.step()
    if steps%2500==0:
        print(f'step: {steps} loss: {loss.item()}')

step: 0 loss: 4.459361553192139
step: 2500 loss: 3.365318536758423
step: 5000 loss: 3.2022905349731445
step: 7500 loss: 2.7361788749694824
step: 10000 loss: 2.567321538925171
step: 12500 loss: 2.655674457550049
step: 15000 loss: 2.230267286300659
step: 17500 loss: 2.302700996398926
step: 20000 loss: 2.6413631439208984
step: 22500 loss: 2.369800329208374


In [138]:
generated = bglm.generate(
    torch.zeros((1,1),dtype=torch.long), # initial context 0
    total=500
)
generated = tokenizer.decode(generated[0])

print(generated)


Thel sofelie ly rouey warbl.
Tough!
Whe.
Yow. m;
DWe f of poote I tigicowe

Theld,
Prt bre'sinil my:
MAn'lyoombur medintacot, he angsss
Toy be?
CEvese k, h ne thenesee se thtere ngsoupyoree akimy's t geallin tupreiespul o h weakllf ld peais LI winee;
VXF ske, dse wavee nsth wersscor g bomalosee at: te I Rothowis t mend n cho, m an cat f o hisemisakelfl gen winer f.
Youpand ty bait:
Bin I
AREdo t, whug.
ARYou ghiry, w as s l p.
Ter it!
Ed hy Heenorivearshair'sthe The beself OLEENUCanthin it ayo m


---

# basic communication between tokens!

### Toy Example

we want the tokens along T to "talk" to each other
AND we also want the tokens to NOT talk to tokens after them, i.e. the future tokens

for now let's talk using cumulative average
for every Tth token, calculate cumulative average upto that token

In [154]:
B,T,C = 4,8,2 # batch, time, channel
x=torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [155]:
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev,0)

In [144]:
# notice how each row is an average of all the previous rows
x[0],xbow[0]

(tensor([[ 0.6892,  0.8805],
         [-0.3160, -0.8412],
         [-0.9974, -0.3895],
         [-0.5201,  0.0344],
         [-0.1666, -0.6107],
         [ 1.4334, -0.0633],
         [-0.2296, -0.3650],
         [-1.4887, -0.2825]]),
 tensor([[ 0.6892,  0.8805],
         [ 0.1866,  0.0197],
         [-0.2080, -0.1167],
         [-0.2860, -0.0790],
         [-0.2622, -0.1853],
         [ 0.0204, -0.1650],
         [-0.0153, -0.1935],
         [-0.1995, -0.2047]]))

In [148]:
# tril
torch.tril(torch.ones(3,3)) # lower triangular matrix

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [151]:
# all the same using matrix multiplication!
a = torch.tril(torch.ones(3,3))
a = a / torch.sum(a,dim=1,keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a@b
print(f'a\n{a}\nb\n{b}\nc\n{c}')

a
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b
tensor([[0., 1.],
        [3., 0.],
        [8., 5.]])
c
tensor([[0.0000, 1.0000],
        [1.5000, 0.5000],
        [3.6667, 2.0000]])


In [156]:
weights = torch.tril(torch.ones(T,T))
weights = weights/weights.sum(1,keepdim=True)
xbow2 = weights @ x # (T, T) @ (B,T,C) =[add batch dim]=> (B, T, T) @ (B, T, C) = (B,T,C)

In [157]:
torch.allclose(xbow,xbow2) 

True

In [158]:
xbow[0], xbow2[0]

(tensor([[-0.6425, -2.0431],
         [ 0.3159, -0.9583],
         [ 0.4632, -0.3893],
         [ 0.2778, -0.0274],
         [ 0.0654, -0.4543],
         [ 0.1753, -0.3378],
         [ 0.2875, -0.2855],
         [ 0.1904, -0.2079]]),
 tensor([[-0.6425, -2.0431],
         [ 0.3159, -0.9583],
         [ 0.4632, -0.3893],
         [ 0.2778, -0.0274],
         [ 0.0654, -0.4543],
         [ 0.1753, -0.3378],
         [ 0.2875, -0.2855],
         [ 0.1904, -0.2079]]))

### 3rd Version using SOFTMAX!

In [164]:
tril = torch.tril(torch.ones(T,T))
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [170]:
weights = torch.zeros((T,T))
weights = weights.masked_fill(tril==0,float('-inf')) # fill with -infinity where tril is 0
print(weights)
"""
softmax converts -inf to 0 and rest is avg.
since e^-inf = 0
and rest 1s are averaged
"""
weights = F.softmax(weights,dim=-1)
xbow3 = weights @ x
weights

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [167]:
torch.allclose(xbow,xbow3)

True

the averages are gonna change as they're weights

weighted aggregation of past elements give you affinities between tokens depending on how interesting they're to each other ALL using lower triangular matrix multiplication where every value is aggregation of all its previous values.

Hiding the future tokens is known as masking or attention mask used in the decoder block of the transformer architecture -- the masking is done using tril as we saw above.

# Bigram LM with token and position embeddings

In [182]:
@dataclass
class Config:
    block_size = 8 # context-length
    batch_size = 32 # mini-batch size
    vocab_size = tokenizer.VOCAB_SIZE
    n_embed = 32

In [191]:
class BigramLM(nn.Module):
    def __init__(self,Config):
        super(BigramLM,self).__init__()
        
        self.n_embed = Config.n_embed # number of embedding dims
        self.block_size = Config.block_size
        
        self.token_embedding_table = nn.Embedding(Config.vocab_size,self.n_embed)
        
        self.pos_embedding_table = nn.Embedding(self.block_size, self.n_embed)
        
        self.lm_head = nn.Linear(self.n_embed,Config.vocab_size)
        
    def forward(self,idx,targets=None):
        
        B,T = idx.shape
        
        token_embs = self.token_embedding_table(idx) # (B,T,n_embed)
        pos_embs = self.pos_embedding_table(torch.arange(T)) # (T,n_embed)
        
        """
        token_embs: B,T,n_embed
        pos_embs:  ,T,n_embed
               +: B,T,n_embed (broadcasted)
               
        so at this point, x knows the token affinities and importance of position!
        [note: here since its a bigram model, position embeddings makes 0 sense]
        """
        x = token_embs + pos_embs # (B,T,n_embed)
        
        logits = self.lm_head(x) # (B,T,vocab_size)
        
        if targets is None:
            loss = None
        else:
            # torch cross entropy expects B,C,T instead of B,T,C
            # and for targets, we need B*T instead of B,T
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits,targets)
            
        return logits,loss

        
    def generate(self,idx,total):
        # idx (B,T) in current context
        for _ in range(total):
            logits,loss = self(idx)
            # since the last element is the next character, we pluck out -1 from T
            logits = logits[:,-1,:] # (B*T,C) -> (B,C)
            probs = F.softmax(logits,dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx,idx_next],dim=1) # (B, T+=1)
            
        return idx

In [192]:
bglm = BigramLM(Config)

optim = torch.optim.AdamW(bglm.parameters(),lr=1e-3)
bglm_dl = torch.utils.data.DataLoader(train_ds,shuffle=False,batch_size=32)

it = iter(bglm_dl)
for steps in range(10000):
    inputs,targets = next(it)
    logits,loss=bglm(inputs,targets)
    optim.zero_grad()
    loss.backward()
    optim.step()
    if steps%1000==0:
        print(f'step: {steps} loss: {loss.item()}')

step: 0 loss: 4.577418804168701
step: 1000 loss: 2.4386796951293945
step: 2000 loss: 2.6219401359558105
step: 3000 loss: 2.6456706523895264
step: 4000 loss: 2.2980926036834717
step: 5000 loss: 3.0280601978302
step: 6000 loss: 2.297921895980835
step: 7000 loss: 2.7295444011688232
step: 8000 loss: 2.1326377391815186
step: 9000 loss: 2.271216869354248


# Self Attention

In [197]:
B,T,C = 4,8,32
x = torch.randn(B,T,C)

tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0,float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei@x
out.shape

torch.Size([4, 8, 32])

gather info from past in data dependent way, the way it does it:

every single token emits 2 vectors: query, key

query: what am i looking for?

key: what do i contain?

dot product: query @ key which will be our weights

now the final value will be `value`

value: private info for each token

In [208]:
# single self-attention head

head_size = 16

key = nn.Linear(C,head_size,bias=False)
query = nn.Linear(C,head_size,bias=False)

value = nn.Linear(C,head_size,bias=False)

k = key(x) #(B,T,16)
q = query(x) #(B,T,16)

"""
k: (B,T,16)
q: (B,T,16)

last two dims of k have to be swapped
k: (B,16,T)

therefore

q: (B,T,16)
    @
k: (B,16,T)
    =
   (B,T,T)
   
which will be our new weights

notice how each batch has its own set of weights since each batch is different

hence we get data dependency
"""


wei = q @ k.transpose(-2,-1)
tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril==0,float('-inf'))
wei = F.softmax(wei,dim=-1)

v = value(x)

out = wei @ v

out.shape

torch.Size([4, 8, 16])

In [204]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4974, 0.5026, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5662, 0.3141, 0.1198, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4006, 0.0128, 0.4399, 0.1467, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1022, 0.2856, 0.4534, 0.0862, 0.0726, 0.0000, 0.0000, 0.0000],
        [0.3116, 0.2433, 0.2295, 0.1141, 0.0351, 0.0665, 0.0000, 0.0000],
        [0.0322, 0.0218, 0.0068, 0.1083, 0.3796, 0.4106, 0.0406, 0.0000],
        [0.0444, 0.0112, 0.0268, 0.1021, 0.4375, 0.3219, 0.0339, 0.0222]],
       grad_fn=<SelectBackward0>)

### notes on attention

- communication mechanism
- like a directed graph, with weights for each edge
- graph: 8 nodes
    - first node is directed to itself
    - second node is directed to itself and the first node
    
    ...
    - the last node is directed to itself and all the previous nodes
- attention acts over a set vectors of graph, no notion of node position
- hence position embedding is important so the nodes are aware where they are in time
- elements along batch dimension dont talk to each other
- batch multiplication is only for parallel processing, each sample can be considered its own graph independent of all other graphs

transformer:

- `encoder` block: allows all tokens to communicate, no masking
- `decoder` block: no future token communication, via masking

difference between self/cross attention:
- self: k,q,v are from same x
- cross: q from one x and k,v from other x (like translation task for example)


- scaled dot-product attention
    - to keep variance ~= 1
    - if weights have large variance, then softmax on it makes it more like one-hot vectors
    - we need good smooth affinities after softmax

```
attention = softmax((q@k.T)/sqrt(head_size)) @ v

```