In [81]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass

torch.manual_seed(1357)

<torch._C.Generator at 0x7f06204200f0>

# Data

In [82]:
# --------------------------- DATA -------------------------

with open('./dataset/shakespeare.txt','r',encoding='utf-8') as f:
    data = f.read()

class CharacterLevelTokenizer:
    def __init__(self,data):
        self.data = data
        self.vocab = sorted(list(set(self.data)))
        self.VOCAB_SIZE = len(self.vocab)
        
        self.i_s = {i:s for i,s in enumerate(self.vocab)}
        self.s_i = {s:i for i,s in self.i_s.items()}
        
    def encode(self,s):
        return torch.tensor([self.s_i[c] for c in s],dtype=torch.long)

    def decode(self,s):
        return ''.join([self.i_s[i.item()] for i in s])

tokenizer = CharacterLevelTokenizer(data)

# Dataset & Dataloader

In [83]:
class ShakespeareDataset:
    def __init__(self,block_size:int, is_test=False) -> None:
        self.tokenizer = CharacterLevelTokenizer(data)
        self.is_test = is_test
        self.full_data = self.tokenizer.encode(self.tokenizer.data)
        if self.is_test:
            self.data = self.full_data[int(0.9*len(self.full_data)):]
        else:
            self.data = self.full_data[:int(0.9*len(self.full_data))]
        self.block_size = block_size

    def __len__(self) -> int:
        return len(self.data)

    def get_block_size(self) -> int:
        return self.block_size

    def get_vocab_size(self) -> int:
        return self.tokenizer.VOCAB_SIZE

    def __getitem__(self,idx):
        item = self.data[idx:idx+self.block_size+1]
        x = item[:-1]
        y = item[1:]
        return x,y

# Model

In [84]:
@dataclass
class Config:
    block_size = 16 # context-length
    batch_size = 32 # mini-batch size
    vocab_size = tokenizer.VOCAB_SIZE
    n_embed = 32
    n_heads = 4
    head_size = n_embed // n_heads
    
    n_layers = 5
    
    eval_iters = 500
    lr = 1e-3
    
    attn_dropout = 0.1
    block_dropout = 0.2
    
    eval_interval = 1000
    max_iters = 10_000

In [85]:
class AttentionHead(nn.Module):
    def __init__(self, Config):
        super().__init__()

        self.block_size = Config.block_size
        self.n_embed = Config.n_embed
        self.head_size = Config.head_size
        
        self.key = nn.Linear(self.n_embed, self.head_size, bias=False)
        self.query = nn.Linear(self.n_embed, self.head_size, bias=False)
        
        self.value = nn.Linear(self.n_embed, self.head_size, bias=False)

        self.register_buffer(
            'tril',
            torch.tril(torch.ones(self.block_size,self.block_size))
        )
        
        self.dropout = nn.Dropout(Config.attn_dropout)

    def forward(self, x):

        B,T,C = x.shape

        k = self.key(x)
        q = self.query(x)

        wei = q@k.transpose(-2,-1) * (C ** 0.5)
        wei = wei.masked_fill(self.tril[:T,:T]==0,float('-inf'))

        wei = F.softmax(wei, dim=-1)
        
        wei = self.dropout(wei)
        
        v = self.value(x)
        out = wei @ v
        
        return out

- the projection is added to handle the residual connections

In [86]:
class MultiHeadAttention(nn.Module):
    def __init__(self, Config):
        super().__init__()
        self.n_heads = Config.n_heads
        self.head_size = Config.head_size
        
        self.heads = nn.ModuleList([AttentionHead(Config) for _ in range(self.n_heads)])
        
        self.projection = nn.Linear(Config.n_embed, Config.n_embed)
        
        self.dropout = nn.Dropout(Config.attn_dropout)
    
    def forward(self,x):
        x = torch.cat([h(x) for h in self.heads],dim=-1)
        x = self.projection(x)
        x = self.dropout(x)
        return x

- the projection is added to handle the residual connections
- n_embed is multiplied by 4 as per the paper.

dropouts:

- in feed forward layer
- in multihead attention
- in single attention head to dropout the heads

In [87]:
class FeedForward(nn.Module):
    def __init__(self, Config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(Config.n_embed,Config.n_embed * 4),
            nn.ReLU(),
            nn.Linear(Config.n_embed * 4, Config.n_embed), # projection
            nn.Dropout(Config.block_dropout)
        )
    def forward(self,x):
        return self.net(x)

- adding lots of blocks doesn't help since it'll become a large model and the data would trickle down a lot
- solution? residual connections!
- layernorm: normalize along the rows
- instead of normalization after ffwd/multihead_attn, we pre-normalize instead

In [88]:
class TransformerBlock(nn.Module):
    def __init__(self, Config):
        super().__init__()
        self.attn = MultiHeadAttention(Config)
        self.ff = FeedForward(Config)
        self.ln1 = nn.LayerNorm(Config.n_embed)
        self.ln2 = nn.LayerNorm(Config.n_embed)

    def forward(self,x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        
        return x

In [89]:
class LanguageModel(nn.Module):
    def __init__(self,Config):
        super().__init__()
        
        self.n_embed = Config.n_embed
        self.block_size = Config.block_size
        
        self.token_embedding_table = nn.Embedding(Config.vocab_size,self.n_embed)
        self.pos_embedding_table = nn.Embedding(self.block_size, self.n_embed)
        
        self.blocks = nn.Sequential(
            *[TransformerBlock(Config)]*Config.n_layers,
            nn.LayerNorm(self.n_embed)
        )

        self.lm_head = nn.Linear(self.n_embed,Config.vocab_size)
        
    def forward(self,idx,targets=None):
        
        B,T = idx.shape
        
        token_embs = self.token_embedding_table(idx)
        pos_embs = self.pos_embedding_table(torch.arange(T))
        
        x = token_embs + pos_embs
        x = self.blocks(x)
        logits = self.lm_head(x)
        
        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits,targets)
            
        return logits,loss

        
    def generate(self,idx,total):
        for _ in range(total):
            idx_cond = idx[:, -self.block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
            

In [90]:
train_ds = ShakespeareDataset(Config.block_size)
val_ds = ShakespeareDataset(Config.block_size,is_test=True)

train_dl = torch.utils.data.DataLoader(train_ds,shuffle=False,batch_size=Config.batch_size)
val_dl = torch.utils.data.DataLoader(val_ds,shuffle=False,batch_size=Config.batch_size)

In [91]:
lm = LanguageModel(Config)
optim = torch.optim.AdamW(lm.parameters(),lr=Config.lr)

In [92]:
@torch.no_grad()
def estimate_loss():
    out = {}
    lm.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(Config.eval_iters)
        if split=='train':
            it = iter(train_dl)
        else:
            it = iter(val_dl)
        for k in range(Config.eval_iters):
            X, Y = next(it)
            logits, loss = lm(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    lm.train()
    return out

In [93]:
train_iter = iter(train_dl)

for step in range(10_000):
    inputs,targets = next(train_iter)
    logits,loss=lm(inputs,targets)
    optim.zero_grad()
    loss.backward()
    optim.step()
    if step % Config.eval_interval == 0 or step == Config.max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

step 0: train loss 4.2709, val loss 4.2514
step 1000: train loss 2.4884, val loss 2.6560
step 2000: train loss 2.3917, val loss 2.6376
step 3000: train loss 2.2830, val loss 2.5369
step 4000: train loss 2.2121, val loss 2.5069
step 5000: train loss 2.2196, val loss 2.3605
step 6000: train loss 2.2704, val loss 2.3862
step 7000: train loss 2.2593, val loss 2.3365
step 8000: train loss 2.2940, val loss 2.3538
step 9000: train loss 2.3210, val loss 2.3379


In [97]:
generated = lm.generate(
    torch.zeros((1,1),dtype=torch.long), # initial context 0
    total=500
)
generated = tokenizer.decode(generated[0])
print('generated (500 tokens) >>>\n',generated)

generated (500 tokens) >>>
 
KING:
If fly cimeen leth yould mANT:
Saybenter ands And
The din hich teich hern and fmaigh in tailef,
And ine epeeplectiet the cuure beutintish.

ORDY:
And roont ans ans of all StireZstids to to the whe tie youen is to wou tho thoulcel Ritent mred
HaiNG RITHAHAEENTHANY:
My achaS:
Wes ysis
O doreus ingre in
Aydircours dower's to hee lood'd as daidf Endind a wing my
I itis mortwacht abruevanrings ulist wo sande,
Claiine par
Or trin.

Swo the sralls witt thy am ENTHARD Ritind:
In dakel,
Hightlust, 
