## Things Added:

- multihead attention
- feed forward


In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass

In [30]:
torch.manual_seed(1357)

<torch._C.Generator at 0x7f5d201240f0>

In [31]:
with open('./dataset/shakespeare.txt','r',encoding='utf-8') as f:
    data = f.read()

In [32]:
class CharacterLevelTokenizer:
    def __init__(self,data):
        self.data = data
        self.vocab = sorted(list(set(self.data)))
        self.VOCAB_SIZE = len(self.vocab)
        
        self.i_s = {i:s for i,s in enumerate(self.vocab)}
        self.s_i = {s:i for i,s in self.i_s.items()}
        
    def encode(self,s):
        return torch.tensor([self.s_i[c] for c in s],dtype=torch.long)

    def decode(self,s):
        return ''.join([self.i_s[i.item()] for i in s])

In [33]:
tokenizer = CharacterLevelTokenizer(data)

In [34]:
class ShakespeareDataset:
    def __init__(self,block_size:int, is_test=False) -> None:
        self.tokenizer = CharacterLevelTokenizer(data)
        self.is_test = is_test
        self.full_data = self.tokenizer.encode(self.tokenizer.data)
        if self.is_test:
            self.data = self.full_data[int(0.9*len(self.full_data)):]
        else:
            self.data = self.full_data[:int(0.9*len(self.full_data))]
        self.block_size = block_size

    def __len__(self) -> int:
        return len(self.data)

    def get_block_size(self) -> int:
        return self.block_size

    def get_vocab_size(self) -> int:
        return self.tokenizer.VOCAB_SIZE

    def __getitem__(self,idx):
        item = self.data[idx:idx+self.block_size+1]
        x = item[:-1]
        y = item[1:]
        return x,y

In [54]:
@dataclass
class Config:
    block_size = 16 # context-length
    batch_size = 32 # mini-batch size
    vocab_size = tokenizer.VOCAB_SIZE
    n_embed = 32
    lr = 1e-3
    n_heads = 4
    head_size = n_embed // n_heads

In [36]:
train_ds = ShakespeareDataset(Config.block_size)
val_ds = ShakespeareDataset(Config.block_size,is_test=True)

In [37]:
train_dl = torch.utils.data.DataLoader(train_ds,shuffle=False,batch_size=Config.batch_size)

In [38]:
class AttentionHead(nn.Module):
    def __init__(self, Config):
        super().__init__()

        self.block_size = Config.block_size
        self.n_embed = Config.n_embed
        self.head_size = Config.head_size
        
        self.key = nn.Linear(self.n_embed, self.head_size, bias=False)
        self.query = nn.Linear(self.n_embed, self.head_size, bias=False)
        
        self.value = nn.Linear(self.n_embed, self.head_size, bias=False)

        self.register_buffer(
            'tril',
            torch.tril(torch.ones(self.block_size,self.block_size))
        )

    def forward(self, x):

        B,T,C = x.shape

        k = self.key(x)
        q = self.query(x)

        wei = q@k.transpose(-2,-1) * (C ** 0.5)
        wei = wei.masked_fill(self.tril[:T,:T]==0,float('-inf'))

        wei = F.softmax(wei, dim=-1)
        
        v = self.value(x)
        out = wei @ v
        
        return out

In [59]:
class MultiHeadAttention(nn.Module):
    def __init__(self, Config):
        super().__init__()
        self.n_heads = Config.n_heads
        self.head_size = Config.head_size
        
        self.heads = nn.ModuleList([AttentionHead(Config) for _ in range(self.n_heads)])
    
    def forward(self,x):
        return torch.cat([h(x) for h in self.heads],dim=-1)

In [67]:
class FeedForward(nn.Module):
    def __init__(self, Config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(Config.n_embed,Config.n_embed),
            nn.ReLU()
        )
    def forward(self,x):
        return self.net(x)

In [68]:
class LanguageModel(nn.Module):
    def __init__(self,Config):
        super().__init__()
        
        self.n_embed = Config.n_embed # number of embedding dims
        self.block_size = Config.block_size
        
        self.token_embedding_table = nn.Embedding(Config.vocab_size,self.n_embed)
        
        self.pos_embedding_table = nn.Embedding(self.block_size, self.n_embed)
        
        self.multihead_attn = MultiHeadAttention(Config)
        
        self.feed_forward = FeedForward(Config)
        self.lm_head = nn.Linear(self.n_embed,Config.vocab_size)
        
    def forward(self,idx,targets=None):
        
        B,T = idx.shape
        
        token_embs = self.token_embedding_table(idx) # (B,T,n_embed)
        pos_embs = self.pos_embedding_table(torch.arange(T)) # (T,n_embed)
        
        x = token_embs + pos_embs # (B,T,n_embed)
        x = self.multihead_attn(x) # (B,T,head_size)
        
        x = self.feed_forward(x) # (B,T,n_embed)
        
        logits = self.lm_head(x) # (B,T,vocab_size)
        
        if targets is None:
            loss = None
        else:
            # torch cross entropy expects B,C,T instead of B,T,C
            # and for targets, we need B*T instead of B,T
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits,targets)
            
        return logits,loss

        
    def generate(self,idx,total):
            
        # idx is (B, T) array of indices in the current context
        for _ in range(total):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            # since the last element is the next character, we pluck out -1 from T
            logits = logits[:, -1, :] # (B*T,C) -> (B,C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+=1)
        return idx
            

In [69]:
lm = LanguageModel(Config)
optim = torch.optim.AdamW(lm.parameters(),lr=Config.lr)

it = iter(train_dl)
for steps in range(10_000):
    inputs,targets = next(it)
    logits,loss=lm(inputs,targets)
    optim.zero_grad()
    loss.backward()
    optim.step()
    if steps%1000==0:
        print(f'step: {steps} loss: {loss.item()}')

step: 0 loss: 4.137528419494629
step: 1000 loss: 2.4903576374053955
step: 2000 loss: 2.4997308254241943
step: 3000 loss: 2.621544599533081
step: 4000 loss: 2.216435194015503
step: 5000 loss: 3.4450135231018066
step: 6000 loss: 2.0625293254852295
step: 7000 loss: 2.517878293991089
step: 8000 loss: 1.8138558864593506
step: 9000 loss: 1.91243577003479


In [70]:
generated = lm.generate(
    torch.zeros((1,1),dtype=torch.long), # initial context 0
    total=500
)
generated = tokenizer.decode(generated[0])
print('generated (500 tokens) >>>\n',generated)

generated (500 tokens) >>>
 
fe lord sis
Wharty is;
Kon's Es tair, awraculd my cod ammy mot?

IFgod of Rirt I prouch we toats
Weas cepeamiche mep!
NO RI Ried GIFFlONNBUCI:
LUCat
Aith
Andd aingrre Mnttepe the lmell evove muanks ourd her law crarf oul he bigus nold.

A amif outin to our to woaughat soves wy provalis;
Witist
Ortch lot leate bloufrw
Cot hin, prar cof hak obaajokand wice tousill acites.
Wait;

Say clamis LUant chour aste,
Thot sich thime'd ris his oue,
Mng at ieg,
Ae foor whisonst ilwe,
Te frars rop?
Ler hre my 
