In [1]:
import torch 
import torch.nn as nn
from torch.nn import functional as F
import random 
import os
from torch.utils.data import DataLoader,Dataset 
import numpy as np
import tiktoken 
from torch.nn.parallel import DataParallel

# Restrict the GPUs visible to this process to device index 6.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [2]:
# Display the torch.device selected in the setup cell.
device

device(type='cuda')

In [5]:
from datasets import load_dataset

# Load the TinyStories corpus; later cells read its 'train' and 'validation' splits.
ds = load_dataset(path="roneneldan/TinyStories")

Repo card metadata block was not found. Setting CardData to empty.


In [5]:
# --- Tokenizer ---
tokenizer = tiktoken.get_encoding("gpt2")
vocab_size = tokenizer.n_vocab

# --- Data ---
context_length = 512  # tokens per training window
batch_size = 16

# --- Model ---
emb_size = 256
num_heads = 8
dff = 512  # hidden width of the feed-forward sub-layer
N = 4  # number of stacked decoder blocks
decoder_blocks = N

# --- Optimisation / regularisation ---
lr = 3e-4
dropout = 0.2

In [7]:
# The corpus has about 2.1 million stories. Each sample concatenates num_of_stories randomly chosen stories, wrapping every story in a start and an end marker, and then randomly selects a window of context_length tokens from the result.

class CustomDataset(Dataset):
    """Dataset that yields random next-token-prediction windows from a story corpus.

    Every item is built independently of the requested index: ``num_of_stories``
    stories are drawn at random, each is wrapped in the ``start``/``end`` marker
    strings, the pieces are concatenated and tokenised with the GPT-2
    ``tiktoken`` encoding, and a random window of ``context_length`` tokens is
    cut out of the result.

    Args:
        dataset: Indexable collection of story strings.
        context_length: Number of tokens in every returned window.
        num_of_stories: How many random stories are concatenated per item.
        batch_size: Only used to compute ``__len__`` (``batch_size * n_itr``).
        n_itr: Only used to compute ``__len__`` (``batch_size * n_itr``).
        start: Marker text inserted before every story.
        end: Marker text inserted after every story.
    """

    def __init__(self, dataset, context_length, num_of_stories, batch_size, n_itr,
                 start='\n<|startofstory|>\n', end='\n<|endofstory|>\n'):
        self.dataset = dataset
        self.context_length = context_length
        self.num_of_stories = num_of_stories
        self.batch_size = batch_size
        self.length = len(dataset)
        self.start = start
        self.end = end
        self.n_itr = n_itr
        self.tokenizer = tiktoken.get_encoding("gpt2")

    def __len__(self):
        """Return the nominal number of items: ``batch_size * n_itr``."""
        return self.batch_size * self.n_itr

    def __getitem__(self, idx):
        """Return one random ``(inputs, outputs)`` pair of ``context_length`` tokens.

        ``idx`` is ignored; every call draws fresh random stories and a fresh
        random window. ``outputs`` is ``inputs`` shifted one token to the right.

        Raises:
            ValueError: If the sampled text tokenises to nothing, so no window
                can be built (for example ``num_of_stories == 0``).
        """
        story_ids = np.random.randint(0, self.length, size=self.num_of_stories)
        text = "".join(self.start + self.dataset[int(i)] + self.end for i in story_ids)

        tokens = self.tokenizer.encode(text)
        if not tokens:
            raise ValueError("Sampled text produced no tokens; cannot build a training window.")

        # One extra token is needed so the target can be shifted by one position.
        needed = self.context_length + 1

        # Double the text until it is long enough. The token list must be
        # refreshed inside the loop, otherwise the condition never changes.
        while len(tokens) < needed:
            text += text
            tokens = self.tokenizer.encode(text)

        start = random.randint(0, len(tokens) - needed)
        inputs = torch.tensor(tokens[start: start + self.context_length], dtype=torch.long)
        outputs = torch.tensor(tokens[start + 1: start + needed], dtype=torch.long)
        assert inputs.shape[-1] == self.context_length, f"Input's last dimension must be {self.context_length}, but got {inputs.shape[-1]}"
        assert outputs.shape[-1] == self.context_length, f"Output's last dimension must be {self.context_length}, but got {outputs.shape[-1]}"

        return inputs, outputs
    
# Training set: 3000 batches' worth of items, each built from 200 random stories.
train_dataset = CustomDataset(
    ds['train']['text'],
    context_length=context_length,
    num_of_stories=200,
    batch_size=batch_size,
    n_itr=3000,
)

# Validation set: same sampling scheme over the validation split, 200 batches' worth of items.
val_dataset = CustomDataset(
    ds['validation']['text'],
    context_length=context_length,
    num_of_stories=200,
    batch_size=batch_size,
    n_itr=200,
)
# dataloader = DataLoader(ds1, batch_size=1,shuffle = True)

In [8]:

# No shuffling is requested: CustomDataset.__getitem__ already draws every item at random.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size)

In [9]:
# for batch in dataloader:
#     break
# print(batch)

In [10]:
# for batch in train_dataloader:
#     inputs, outputs = batch  
#     break

<!-- Model Input -> (Batch_size, number_of_tokens)<br>
Emb output -> (Batch_size, number_of_tokens, emb_size)<br>
Positional embedding output -> (Batch_size, number_of_tokens, emb_size)<br>
Decoder Input --> (Batch_size, number_of_tokens, emb_size)  --> Emb output + Positional embedding output<br>
single head output -> (Batch_size, number_of_tokens, emb_size/num_heads (aka head_dim))<br>
Feed-forward (dff) -> (Batch_size, number_of_tokens, emb_size) --> (Batch_size, number_of_tokens, dff) --> (Batch_size, number_of_tokens, emb_size)<br>
decoder output-> (Batch_size, number_of_tokens, emb_size)<br>
Model_final_out - > (Batch_size, number_of_tokens, vocab_size)<br>
Y_true --> (Batch_size, number_of_tokens)  --><br>

In [11]:
class Embedding(nn.Module):
    """Sum of a token embedding and a learned positional embedding.

    Args:
        emb_size: Width of every embedding vector.
        vocab_size: Number of rows in the token embedding table.
        context_length: Number of rows in the positional embedding table, i.e.
            the longest sequence ``forward`` can index.
        device: Device the module's parameters are moved to on construction.
    """

    def __init__(self, emb_size, vocab_size, context_length, device):
        super().__init__()
        self.emb_size = emb_size
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.device = device
        self.token_embd = nn.Embedding(vocab_size, emb_size)
        self.pos_embd = nn.Embedding(context_length, emb_size)
        self.to(device)

    def forward(self, x):
        """Embed token ids ``x`` and add the embedding of each position.

        The last axis of ``x`` is treated as the sequence axis. The result has
        one extra trailing axis of size ``emb_size``.
        """
        tok = self.token_embd(x)
        # Build the position indices on the same device as the activations
        # instead of relying on a notebook-global `device` variable.
        positions = torch.arange(tok.shape[-2], device=tok.device).unsqueeze(0)
        pos_embd = self.pos_embd(positions)
        return pos_embd + tok
   

In [12]:
# obj = Embedding(emb_size,vocab_size,context_length, device)
# embd_out = obj(inputs.to(device))

In [13]:
import math

class SingleHead(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Args:
        emb_size: Size of the last axis of the input.
        head_dim: Size of the last axis of the query/key/value projections and
            of the output.
        context_length: Side length of the causal mask, i.e. the longest
            sequence the head can attend over.
        dropout: Dropout probability applied to the attention weights.
        device: Device the module is moved to on construction.
    """

    def __init__(self, emb_size, head_dim, context_length, dropout, device):
        super().__init__()

        self.emb_size = emb_size
        self.head_dim = head_dim

        self.Wq = nn.Linear(emb_size, head_dim, bias=False)
        self.Wk = nn.Linear(emb_size, head_dim, bias=False)
        self.Wv = nn.Linear(emb_size, head_dim, bias=False)

        # Register the lower-triangular mask as a buffer so it follows the
        # module through .to()/.cuda() and replication. persistent=False keeps
        # it out of state_dict, so checkpoint keys stay as they were.
        causal = torch.tril(torch.ones(context_length, context_length))
        self.register_buffer("mask", causal.view(1, context_length, context_length), persistent=False)

        self.drop = nn.Dropout(dropout)
        self.to(device)

    def forward(self, x):
        """Apply masked self-attention to ``x``.

        The second-to-last axis of ``x`` is treated as the sequence axis and the
        last axis must have size ``emb_size``. The output has the same leading
        axes with a last axis of size ``head_dim``.
        """
        T = x.shape[-2]
        q = self.Wq(x)
        k = self.Wk(x)
        v = self.Wv(x)

        # Scaled dot-product scores; transposing the last two axes avoids
        # hard-coding the number of leading axes.
        att_weights = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        # Block attention to future positions before normalising.
        att_weights = att_weights.masked_fill(self.mask[:, :T, :T] == 0, -float("inf"))
        att_weights = F.softmax(att_weights, dim=-1)
        att_weights = self.drop(att_weights)
        return att_weights @ v
    

In [14]:
# sh = SingleHead(emb_size,emb_size//num_heads,context_length,dropout, device)

In [15]:
# head_output = sh(embd_out)

In [16]:
class MultiHead(nn.Module):
    """Several ``SingleHead`` modules run side by side, then mixed by a linear layer.

    Each head receives the full input and produces ``emb_size // num_heads``
    features; the head outputs are concatenated along the last axis, projected
    by an ``emb_size -> emb_size`` linear layer and passed through dropout.
    """

    def __init__(self, emb_size, num_heads, context_length, dropout, device):
        super().__init__()
        head_dim = emb_size // num_heads
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            self.heads.append(
                SingleHead(emb_size, head_dim, context_length, dropout=dropout, device=device)
            )
        self.proj = nn.Linear(emb_size, emb_size)
        self.drop = nn.Dropout(dropout)
        self.to(device)

    def forward(self, x):
        """Run every head on ``x``, concatenate, project and apply dropout."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.drop(projected)
        

In [17]:
# Mh = MultiHead(emb_size,num_heads,context_length,dropout, device)

In [18]:
# multihead_output  = Mh(embd_out)

In [19]:
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: ``emb_size -> dff -> emb_size``.

    A ReLU follows the first linear layer and dropout follows the second.
    """

    def __init__(self, emb_size, dff, dropout, device):
        super().__init__()

        self.fc1 = nn.Linear(emb_size, dff)
        self.fc2 = nn.Linear(dff, emb_size)
        self.drop = nn.Dropout(dropout)
        self.to(device)

    def forward(self, x):
        """Expand to ``dff`` features, apply ReLU, project back and apply dropout."""
        hidden = self.fc1(x)
        hidden = F.relu(hidden)
        projected = self.fc2(hidden)
        return self.drop(projected)
    
# ff = FeedForward(emb_size,dff, dropout, device)      
# feed_forward_out= ff(multihead_output)

In [20]:
class Decoder(nn.Module):
    """One decoder block: multi-head attention followed by a feed-forward layer.

    Each sub-layer's output is layer-normalised, added to that sub-layer's
    input, and the sum is passed through a ReLU.
    """

    def __init__(self, emb_size, dff, num_heads, context_length, dropout, device):
        super().__init__()

        self.multihead = MultiHead(emb_size, num_heads, context_length, dropout, device)
        self.feedforward = FeedForward(emb_size, dff, dropout, device)
        self.ln1 = nn.LayerNorm(emb_size)
        self.ln2 = nn.LayerNorm(emb_size)
        self.to(device)

    def forward(self, x):
        """Apply the attention sub-layer and then the feed-forward sub-layer to ``x``."""
        # Attention sub-layer: normalise, add the residual, then ReLU.
        attended = self.multihead(x)
        attended = F.relu(self.ln1(attended) + x)

        # Feed-forward sub-layer: same pattern, with `attended` as the residual.
        transformed = self.feedforward(attended)
        return F.relu(self.ln2(transformed) + attended)
    
# decoder = Decoder(emb_size,dff,num_heads,context_length,dropout, device)
# decoder_out = decoder(embd_out)

In [21]:
class DecoderStack(nn.Module):
    """``N`` ``Decoder`` blocks applied one after another."""

    def __init__(self, emb_size, dff, num_heads, context_length, N, dropout, device):
        super().__init__()

        blocks = [
            Decoder(emb_size, dff, num_heads, context_length, dropout, device)
            for _ in range(N)
        ]
        self.decstack = nn.ModuleList(blocks)
        self.to(device)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the final output."""
        hidden = x
        for block in self.decstack:
            hidden = block(hidden)
        return hidden
    
# ds = DecoderStack(emb_size,dff,num_heads,context_length,N, dropout, device)
# ds_out = ds(embd_out)

In [22]:
#Main Model _Classification
class GPT(nn.Module):
    """Decoder-only language model: embeddings, a decoder stack and a vocabulary projection.

    Args:
        emb_size: Width of the embeddings and of every decoder block.
        dff: Hidden width of the feed-forward layer inside each decoder block.
        num_heads: Number of attention heads per decoder block.
        context_length: Longest sequence the model can process.
        N: Number of decoder blocks.
        vocab_size: Size of the token vocabulary (input ids and output logits).
        dropout: Dropout probability passed to the decoder stack.
        device: Device the whole model is moved to on construction.
    """

    def __init__(self, emb_size, dff, num_heads, context_length, N, vocab_size, dropout, device):
        super().__init__()

        self.embd = Embedding(emb_size, vocab_size, context_length, device)
        self.decoderstack = DecoderStack(emb_size, dff, num_heads, context_length, N, dropout, device)
        # The third positional argument of nn.Linear is `bias`, not `device`.
        # Passing `device` there only worked because the object is truthy.
        # The layer still has a bias, and self.to(device) below places it.
        self.out = nn.Linear(emb_size, vocab_size)
        self.to(device)

    def forward(self, x):
        """Map token ids ``x`` to unnormalised logits over the vocabulary."""
        x = self.embd(x)
        x = self.decoderstack(x)
        x = F.relu(x)
        x = self.out(x)
        return x

# Build the model from the hyper-parameters defined in the config cell.
gpt = GPT(
    emb_size=emb_size,
    dff=dff,
    num_heads=num_heads,
    context_length=context_length,
    N=N,
    vocab_size=vocab_size,
    dropout=dropout,
    device=device,
)
# gpt_out = gpt(inputs.to(device))


In [23]:
# Count only the parameters that require gradients.
num_params = sum(
    weight.numel()
    for weight in gpt.parameters()
    if weight.requires_grad
)
num_params_million = num_params / 1e6
print(f"Number of parameters in the model: {num_params_million:.2f} million")


Number of parameters in the model: 28.02 million


In [24]:

# Cross-entropy over the vocabulary logits; Adam with the learning rate from the config cell.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=gpt.parameters(), lr=lr)

from tqdm import tqdm

def calculate_loss(dataloader):
    """Return the mean per-batch loss of the global ``gpt`` model over ``dataloader``.

    The model is switched to evaluation mode and run under
    ``torch.inference_mode`` while the loss is accumulated. Afterwards it is
    returned to whichever mode it was in when this function was called.

    Args:
        dataloader: Iterable of ``(inputs, outputs)`` batches that also
            supports ``len()``.

    Returns:
        The sum of the per-batch losses divided by ``len(dataloader)``.
    """
    was_training = gpt.training
    gpt.eval()
    total = 0.0
    try:
        with torch.inference_mode():
            # Iterate the loader that was passed in, not the global validation
            # loader, so the sum and the divisor below refer to the same data.
            for inputs, outputs in dataloader:
                inputs, outputs = inputs.to(device), outputs.to(device)
                pred = gpt(inputs)
                loss = loss_fn(pred.view(-1, vocab_size), outputs.view(-1))
                total += loss.item()
    finally:
        gpt.train(was_training)
    return total / len(dataloader)


def training(n_epochs):
    """Train the global ``gpt`` model for ``n_epochs`` passes over ``train_dataloader``.

    After every epoch this prints the average training loss, computes and
    prints the validation loss via ``calculate_loss(val_dataloader)``, and
    saves the model's ``state_dict`` to ``gpt_model_{epoch}.pth``, where
    ``epoch`` is the zero-based epoch index.

    Args:
        n_epochs: Number of epochs to run.

    Returns:
        A list with the training loss of every batch, in the order processed.
    """
    losses = []
    # Make sure dropout is active even if an earlier cell left the model in eval mode.
    gpt.train()

    for epoch in range(n_epochs):
        print(f"Epoch {epoch + 1}/{n_epochs}")
        total_loss = 0

        for inputs, outputs in tqdm(train_dataloader):
            inputs, outputs = inputs.to(device), outputs.to(device)

            pred = gpt(inputs)
            loss = loss_fn(pred.view(-1, vocab_size), outputs.view(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for both the log and the running sum.
            batch_loss = loss.item()
            losses.append(batch_loss)
            total_loss += batch_loss

        avg_train_loss = total_loss / len(train_dataloader)
        print(f"Avg training loss for epoch {epoch + 1}: {avg_train_loss}")

        val_loss = calculate_loss(val_dataloader)
        print(f"Avg validation loss after epoch {epoch + 1}: {val_loss}")

        model_save_path = f'gpt_model_{epoch}.pth'
        torch.save(gpt.state_dict(), model_save_path)
        print(f"Model saved to {model_save_path}")

    return losses
    
    

In [None]:
# Run 25 training epochs; `losses` receives the list that training() returns.
losses = training(n_epochs=25)

Epoch 1/25


100%|███████████████████████████████████████| 3000/3000 [59:17<00:00,  1.19s/it]


Avg training loss for epoch 1: 3.57179763118426
Avg validation loss after epoch 1: 2.9561573684215547
Epoch 2/25


100%|█████████████████████████████████████| 3000/3000 [1:00:15<00:00,  1.21s/it]


Avg training loss for epoch 2: 2.8619314579963686
Avg validation loss after epoch 2: 2.6150129997730254
Epoch 3/25


 88%|██████████████████████████████████▎    | 2642/3000 [54:26<07:23,  1.24s/it]

In [None]:
#Inference Mode
def next_word_pred(start, max_length, gpt):
    """Generate text by sampling ``max_length`` tokens that continue ``start``.

    The prompt is encoded with the notebook's global ``tokenizer``. At every
    step the model sees at most its own context length of most recent tokens
    and the next token is sampled from the softmax of the last position's
    logits. The model is put in evaluation mode for the duration of the call
    and returned to its previous mode afterwards.

    Args:
        start: Prompt text; must encode to at least one token.
        max_length: Number of tokens to sample.
        gpt: A ``GPT`` model instance.

    Returns:
        The decoded prompt followed by the sampled continuation.

    Raises:
        ValueError: If ``start`` encodes to no tokens.
    """
    start_tokens = tokenizer.encode(start)
    if not start_tokens:
        raise ValueError("start must encode to at least one token")

    # Place the prompt on the model's own device and use the model's own
    # context length, not a hard-coded window.
    model_device = next(gpt.parameters()).device
    window = gpt.embd.context_length
    inp = torch.tensor(start_tokens, dtype=torch.long, device=model_device).unsqueeze(0)

    was_training = gpt.training
    gpt.eval()  # disable dropout while sampling
    try:
        with torch.inference_mode():
            for _ in range(max_length):
                logits = gpt(inp[:, -window:])
                probs = F.softmax(logits[:, -1, :], dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                inp = torch.cat([inp, next_token], dim=-1)
    finally:
        gpt.train(was_training)

    return tokenizer.decode(inp[0].tolist())

# Sample 289 tokens that continue the prompt 'Once' using the trained model.
out = next_word_pred('Once', 289, gpt)