In [156]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from tqdm import tqdm

### Fetching all the filepath to load the text files

In [16]:
text = ""
with open("shakespeare_scripts.txt") as f:
    text = f.read()

In [18]:
print("Dataset with length: ", len(text))

Dataset with length:  5536916


## Vocabulary

For this example we are working with characters instead of words or subwords 
Apart from character embedding we 
- Word Embedding 
- SubWord Embedding (Google uses SentencePiece)

In [61]:
vocab = sorted(set(text))
vocab_size = len(vocab)
print("Vocab: ", "".join(vocab))
print("Vocab Size: ", len(vocab))

Vocab:  	
 !$'(),-.0123456789:?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxyz
Vocab Size:  78


In [36]:
ctoi = {v: k for k, v in enumerate(vocab)}
itoc = {k:v for k, v in enumerate(vocab)}

In [37]:
encoder = lambda x: [ctoi[idx] for idx in x]
decoder = lambda x: "".join([itoc[idx] for idx in x])

In [38]:
encoded_text = encoder("Hello")
print("Encoded Text: ", encoded_text)
decoded_text = decoder(encoded_text)
print("Decoded Text: ", decoded_text)

Encoded Text:  [30, 56, 63, 63, 66]
Decoded Text:  Hello


## Data Spliting 

In [39]:
n = int(0.9 * len(text))

train_text = text[:n]
test_text = text[n:]

In [41]:
print("Total Data Size: ", len(text))
print("Train Data Size: ", len(train_text))
print("Test Data Size: ", len(test_text))

Total Data Size:  5536916
Train Data Size:  4983224
Test Data Size:  553692


### Convert data into tensors

In [42]:
train_data = torch.tensor(encoder(train_data), dtype=torch.long)
test_data = torch.tensor(encoder(test_data), dtype = torch.long)

In [43]:
train_data[:10]

tensor([33, 31, 36, 29, 21,  1, 23, 25, 42,  2])

## Sequence Window Generator

In [44]:
block_size = 8
train_data[:block_size+1]

tensor([33, 31, 36, 29, 21,  1, 23, 25, 42])

In [48]:
x = train_data[:block_size]
y = train_data[1:block_size+1]

for i in range(block_size):
    
    context = x[:i+1]
    target = y[i]
    print(f"Context: {context}, Target: {target}")

Context: tensor([33]), Target: 31
Context: tensor([33, 31]), Target: 36
Context: tensor([33, 31, 36]), Target: 29
Context: tensor([33, 31, 36, 29]), Target: 21
Context: tensor([33, 31, 36, 29, 21]), Target: 1
Context: tensor([33, 31, 36, 29, 21,  1]), Target: 23
Context: tensor([33, 31, 36, 29, 21,  1, 23]), Target: 25
Context: tensor([33, 31, 36, 29, 21,  1, 23, 25]), Target: 42


### Generate Batch Data

In [49]:
block_size = 8
batch_size = 4

In [58]:
def generate_batch(split):
    
    data = train_data if split == "train" else test_data
    
    rand_idx = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[idx: idx + block_size] for idx in rand_idx])
    y = torch.stack([data[idx + 1: idx + block_size + 1] for idx in rand_idx])
    return {"x": x, "y": y}

In [59]:
generate_batch("train")

{'x': tensor([[70, 71, 70,  2, 55, 60, 70, 67],
         [64, 66, 69, 69, 66, 74,  2, 74],
         [ 2, 71, 66,  2, 71, 59, 56,  2],
         [66, 69,  2, 31,  8,  2, 65, 66]]),
 'y': tensor([[71, 70,  2, 55, 60, 70, 67, 52],
         [66, 69, 69, 66, 74,  2, 74, 59],
         [71, 66,  2, 71, 59, 56,  2, 59],
         [69,  2, 31,  8,  2, 65, 66, 69]])}

In [53]:
idx

tensor([2226723, 3333214, 4982009, 3996242])

TypeError: only integer tensors of a single element can be converted to an index

In [112]:
class BiGramModel(nn.Module):
    
    def __init__(self, vocab_size):
        super(BiGramModel, self).__init__()
        self.emb = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, x, y = None):
        logits = self.emb(x) # B x T x C
        if y != None:
            # Converting the Tensors as pytorch requires BT x C tensor
            B, T, C = logits.shape
            log = logits.view(B*T, C)
            target = y.view(B*T)
            loss = F.cross_entropy(log, target)
            return logits, loss
        return logits, None
       
    def generate(self, x, max_new_tokens):
        
        for _ in range(max_new_tokens):
            
            logits, loss = self(x)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            
            # Sampling from sample distribution 
            idx_next = torch.multinomial(probs, num_samples = 1)
            
            x = torch.cat((x, idx_next), dim=1)
        return x
    

In [149]:
model = BiGramModel(vocab_size)

In [150]:
logits, loss = model(**generate_batch("test"))

print("Logits Shape: ", logits.shape)
print("Loss Shape: ", loss)

Logits Shape:  torch.Size([64, 128, 78])
Loss Shape:  tensor(4.8736, grad_fn=<NllLossBackward0>)


In [151]:
initial = torch.zeros(1, 1, dtype=torch.long)
initial

tensor([[0]])

In [152]:
print(decoder(model.generate(initial, 100)[0].numpy()))

	i]t24A8_6Inp9a4R)BXTD94U6 hFj7S9Gh1]w(Fsdi	6'jv1
lqaLtmp,!53]x$b8Rll.v2yu2rU2dRwyQhJ5gOV8KLOPfYdzudm


## Training Loop

In [153]:
optim = torch.optim.AdamW(model.parameters())

In [154]:
block_size = 128
batch_size = 64

epochs = 10000

In [158]:
final_loss = []
final_validation_loss = []
for epoch in tqdm(range(epochs)):
    idxs = generate_batch("train")
    logits, loss = model(**idxs)
    optim.zero_grad()
    loss.backward()
    optim.step()
    final_loss.append(loss.item())
    if epoch % 200:
        with torch.no_grad():
            _, loss = model(**generate_batch("val"))
            final_validation_loss.append(loss.item())
        
print(f"Train Loss: {np.mean(final_loss)} Validation Loss: {np.mean(final_validation_loss)}")
    

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [01:14<00:00, 134.16it/s]

Train Loss: 2.4837432108402253 Validation Loss: 2.51157966175271





In [159]:
print(decoder(model.generate(initial, 1000)[0].numpy()))

	T:
Buchixech marlet lo mblas.
VIILLIUENI fal.
CENDreane kodoraveepumfoundinthenca id her, avik memyore se sot hon O:
CUCIVOLOHERDOFLIUEROUSTOLI lanamye ipe.
ERES:
Awe hena BELENA where y t camanfos upue owourenkilinfandusollablld d inounoul mst atere kn y t,
ELLOThee w, ing.
WI'lloe p'd, ary, IV:
TAKIArobewow HE:
MOLOR: t, ns bethis thttshe a t nocost! winliteseak ig bou'she t:
CI PES:
Y t t ourst l,
A t I sw MARRENRY:
SESer?
Whe aselllitharkedilode tos amyelal atithtit fie Wo!
G wnf IAND CAGLSTh torexpr test MICLVI s ure, w hat tace.
GLI wne he, r Simunge hin ilcof alin sticourerasl s les at,
GOS:
G harengourichon.
PUS PANE:
God atll handy l shixim dice, k ongheshelt blldof thicu hefe walert dwor s
WOClyorime ind Whorefon'sold se ain!'sth herolard Ja d myre RUCENESCK hesune
BUERU:
Whyof aro h be mo SA:
INDGOTha me y, VALONGLUS:
QUCHELOr ce.
DUSTOTRBES:
KINDO:
S: IUCUThabre ppr LOMARI:
SI igss nds f an bery, m s ggeshe soul hit matoreay tscomy, PULBONYON:
FAENG mot
HEN:
Exerele arid C