## ShakespeareGPT

> based on [Let's build GPT: from scratch, in code, spelled out.](https://www.youtube.com/watch?v=kCc8FmEb1nY)

In [79]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass

In [80]:
torch.manual_seed(1357)

<torch._C.Generator at 0x7fbec4214050>

In [81]:
@dataclass
class Config:
    block_size = 8 # context-length
    batch_size = 4 # mini-batch size

# Preparing Data

In [82]:
with open('./dataset/shakespeare.txt','r',encoding='utf-8') as f:
    data = f.read()
    
print(f"{len(data)=}\n{data[:100]}")

len(data)=1114985
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


## Tokenizer

In [83]:
class CharacterLevelTokenizer:
    def __init__(self,data):
        self.data = data
        self.vocab = sorted(list(set(self.data)))
        self.VOCAB_SIZE = len(self.vocab)
        
        self.i_s = {i:s for i,s in enumerate(self.vocab)}
        self.s_i = {s:i for i,s in self.i_s.items()}
        
    def encode(self,s):
        return torch.tensor([self.s_i[c] for c in s],dtype=torch.long)

    def decode(self,s):
        return ''.join([self.i_s[i.item()] for i in s])

In [84]:
tokenizer = CharacterLevelTokenizer(data)
print(tokenizer.vocab)
print(tokenizer.VOCAB_SIZE)

['\n', ' ', '!', "'", ',', '-', '.', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
62


In [85]:
tokenizer.encode('et tu brute?')

tensor([40, 55,  1, 55, 56,  1, 37, 53, 56, 55, 40,  9])

In [86]:
tokenizer.decode(tokenizer.encode('et tu brute?'))

'et tu brute?'

## Dataset & Dataloader

In [87]:
class ShakespeareDataset:
    def __init__(self,block_size:int, is_test=False) -> None:
        self.tokenizer = CharacterLevelTokenizer(data)
        self.is_test = is_test
        self.full_data = self.tokenizer.encode(self.tokenizer.data)
        if self.is_test:
            self.data = self.full_data[int(0.9*len(self.full_data)):]
        else:
            self.data = self.full_data[:int(0.9*len(self.full_data))]
        self.block_size = block_size

    def __len__(self) -> int:
        return len(self.data)

    def get_block_size(self) -> int:
        return self.block_size

    def get_vocab_size(self) -> int:
        return self.tokenizer.VOCAB_SIZE

    def __getitem__(self,idx):
        item = self.data[idx:idx+self.block_size+1]
        x = item[:-1]
        y = item[1:]
        return x,y

In [88]:
train_ds = ShakespeareDataset(Config.block_size)
print(f'{train_ds.get_block_size()=}\n{train_ds.get_vocab_size()=}\n{len(train_ds)=}')

val_ds = ShakespeareDataset(Config.block_size,is_test=True)
print(f'{len(val_ds)=}')

train_ds.get_block_size()=8
train_ds.get_vocab_size()=62
len(train_ds)=1003486
len(val_ds)=111499


In [89]:
train_dl = torch.utils.data.DataLoader(train_ds,shuffle=False,batch_size=Config.batch_size)

In [90]:
inputs,targets=next(iter(train_dl))
print(inputs.shape,targets.shape)
inputs,targets

torch.Size([4, 8]) torch.Size([4, 8])


(tensor([[15, 44, 53, 54, 55,  1, 12, 44],
         [44, 53, 54, 55,  1, 12, 44, 55],
         [53, 54, 55,  1, 12, 44, 55, 44],
         [54, 55,  1, 12, 44, 55, 44, 61]]),
 tensor([[44, 53, 54, 55,  1, 12, 44, 55],
         [53, 54, 55,  1, 12, 44, 55, 44],
         [54, 55,  1, 12, 44, 55, 44, 61],
         [55,  1, 12, 44, 55, 44, 61, 40]]))

# Bi-gram Language Model

In [125]:
class BigramLM(nn.Module):
    def __init__(self,vocab_size):
        super(BigramLM,self).__init__()
        self.token_embedding_table = nn.Embedding(vocab_size,vocab_size)
        
    def forward(self,idx,targets=None):
        logits = self.token_embedding_table(idx) # (B,T,C:vocab_size)
        
        if targets is None:
            loss = None
        else:
            # torch cross entropy expects B,C,T instead of B,T,C
            # and for targets, we need B*T instead of B,T
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits,targets)
            
        return logits,loss

        
    def generate(self,idx,total):
        # idx (B,T) in current context
        for _ in range(total):
            logits,loss = self(idx)
            # since the last element is the next character, we pluck out -1 from T
            logits = logits[:,-1,:] # (B*T,C) -> (B,C)
            probs = F.softmax(logits,dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx,idx_next],dim=1) # (B, T+=1)
            
        return idx

In [126]:
bglm = BigramLM(tokenizer.VOCAB_SIZE)
logits,loss = bglm(inputs,targets)
print(logits.shape,loss)

generated = bglm.generate(
    torch.zeros((1,1),dtype=torch.long), # initial context 0
    total=100
)
generated = tokenizer.decode(generated[0])
generated

torch.Size([256, 62]) tensor(4.4188, grad_fn=<NllLossBackward0>)


"\nxqI E:Mp.HtfZDyhnSa!uQaSjIncCgX'xwUv-P;DzahqW.RY;ldx CmYAQgCT.noI\nXqtX JCeZzbMPkGcbMxQAt;l--ZFlHmUSc"

## training the bigram LM

In [136]:
bglm = BigramLM(tokenizer.VOCAB_SIZE)

optim = torch.optim.AdamW(bglm.parameters(),lr=1e-3)
bglm_dl = torch.utils.data.DataLoader(train_ds,shuffle=False,batch_size=32)

it = iter(bglm_dl)
for steps in range(25_000):
    inputs,targets = next(it)
    logits,loss=bglm(inputs,targets)
    optim.zero_grad()
    loss.backward()
    optim.step()
    if steps%2500==0:
        print(f'step: {steps} loss: {loss.item()}')

step: 0 loss: 4.459361553192139
step: 2500 loss: 3.365318536758423
step: 5000 loss: 3.2022905349731445
step: 7500 loss: 2.7361788749694824
step: 10000 loss: 2.567321538925171
step: 12500 loss: 2.655674457550049
step: 15000 loss: 2.230267286300659
step: 17500 loss: 2.302700996398926
step: 20000 loss: 2.6413631439208984
step: 22500 loss: 2.369800329208374


In [137]:
generated = bglm.generate(
    torch.zeros((1,1),dtype=torch.long), # initial context 0
    total=500
)
generated = tokenizer.decode(generated[0])
print(generated)


Onda!
RESubend f I; thikl My s yo pucitll d the, the dathioaye: illipy, wirat LI be f ta pa br fly ake's, d p
TENe.
S:
Spen d Orts tilllos ve.
Ithanugr BE: gre d ck.
Heturare n tean f tyothingoullfus han.
ELo aveaverelamy ises 'st,
Mak'leerast, ourellan wo;

NThy me.
I flasin, be me nkeesherot pies the t s shoug, ms mig wed I'l VADasusho is t ie isom pou lerthe car f ty imideless
TEE ce my,
Win? tce ests, trioepal l?
Lisg, ll t tstiseat d alolissoucod, meenild al m sifeavims f tar sisonthais
ONo
