In [26]:
import torch 
import torch.nn as nn
from torch.nn import functional as F
import random 
import os
from torch.utils.data import DataLoader,Dataset 
import numpy as np
import tiktoken 
from torch.nn.parallel import DataParallel

# Restrict this process to physical GPU 7. The variable has to be set before
# CUDA is first initialised, otherwise it has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

# Run on the (single visible) GPU when CUDA is usable, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


from datasets import load_dataset

# GPT-2 byte-pair encoding; its vocabulary size sets the width of the
# token-embedding table and of the output projection.
tokenizer = tiktoken.get_encoding("gpt2")
vocab_size = tokenizer.n_vocab

# --- Model shape -----------------------------------------------------------
context_length = 512  # number of rows in the positional-embedding table
emb_size = 256        # embedding width
num_heads = 8         # attention heads per decoder block
dff = 512             # hidden width of the feed-forward sub-layer
N = 4                 # number of stacked decoder blocks
decoder_blocks = N    # alias for N

# --- Optimisation / regularisation -----------------------------------------
batch_size = 16
lr = 3e-4
dropout = 0.2
    



class Embedding(nn.Module):
    """Sum of a token embedding and a learned positional embedding.

    Args:
        emb_size: width of each embedding vector.
        vocab_size: number of rows in the token-embedding table.
        context_length: number of rows in the positional-embedding table,
            i.e. the longest sequence `forward` can handle.
        device: device the module's parameters are moved to at construction.
    """

    def __init__(self, emb_size, vocab_size, context_length, device):
        super().__init__()
        self.emb_size = emb_size
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.token_embd = nn.Embedding(vocab_size, emb_size)
        self.pos_embd = nn.Embedding(context_length, emb_size)
        self.to(device)
        self.device = device

    def forward(self, x):
        """Embed token ids.

        Args:
            x: integer tensor of token ids whose last dimension is the
                sequence axis (the model passes `(batch, seq)`).

        Returns:
            Tensor of shape `x.shape + (emb_size,)`: token embedding plus the
            positional embedding of each position `0..seq-1`.
        """
        x = self.token_embd(x)
        # Build the position indices on the same device as the input instead
        # of relying on a module-level `device` global, so the layer keeps
        # working when the model is moved or the class is used on its own.
        pos = torch.arange(x.shape[-2], device=x.device).unsqueeze(0)
        return self.pos_embd(pos) + x


import math

class SingleHead(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Args:
        emb_size: width of the input vectors.
        head_dim: width of this head's query/key/value projections.
        context_length: side length of the pre-computed causal mask, i.e. the
            longest sequence `forward` can handle.
        dropout: dropout probability applied to the attention weights.
        device: device the module is moved to at construction.
    """

    def __init__(self, emb_size, head_dim, context_length, dropout, device):
        super().__init__()

        self.emb_size = emb_size
        self.head_dim = head_dim

        self.Wq = nn.Linear(emb_size, head_dim, bias=False)
        self.Wk = nn.Linear(emb_size, head_dim, bias=False)
        self.Wv = nn.Linear(emb_size, head_dim, bias=False)

        # Lower-triangular mask: position i may attend to positions <= i.
        # Registered as a buffer so it follows the module through .to()/.cuda()
        # and DataParallel replication. persistent=False keeps it out of the
        # state_dict, so checkpoints saved without a "mask" entry still load.
        causal = torch.tril(torch.ones(context_length, context_length))
        self.register_buffer(
            "mask",
            causal.view(1, context_length, context_length),
            persistent=False,
        )
        self.drop = nn.Dropout(dropout)
        self.to(device)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape `(batch, seq, emb_size)` with
                `seq <= context_length`.

        Returns:
            Tensor of shape `(batch, seq, head_dim)`.
        """
        T = x.shape[-2]
        q = self.Wq(x)
        k = self.Wk(x)
        v = self.Wv(x)

        # Scaled dot-product scores over the last two (seq, head_dim) axes.
        att_weights = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        # Future positions get -inf so softmax assigns them zero weight.
        att_weights = att_weights.masked_fill(self.mask[:, :T, :T] == 0, float("-inf"))
        att_weights = F.softmax(att_weights, dim=-1)
        att_weights = self.drop(att_weights)
        return att_weights @ v
    


class MultiHead(nn.Module):
    """Runs `num_heads` SingleHead modules side by side and mixes their outputs.

    Each head projects to `emb_size // num_heads` features; the per-head
    results are concatenated on the feature axis, passed through a linear
    projection of width `emb_size`, and then through dropout.
    """

    def __init__(self, emb_size, num_heads, context_length, dropout, device):
        super().__init__()
        per_head_dim = emb_size // num_heads
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(
                SingleHead(emb_size, per_head_dim, context_length, dropout=dropout, device=device)
            )
        self.heads = nn.ModuleList(head_modules)
        self.proj = nn.Linear(emb_size, emb_size)
        self.drop = nn.Dropout(dropout)
        self.to(device)

    def forward(self, x):
        """Return the projected, dropout-regularised concatenation of all heads."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.drop(self.proj(merged))
        


class FeedForward(nn.Module):
    """Two-layer position-wise MLP: Linear -> ReLU -> Linear -> Dropout.

    Args:
        emb_size: input and output width.
        dff: hidden width.
        dropout: dropout probability applied to the second layer's output.
        device: device the module is moved to at construction.
    """

    def __init__(self, emb_size, dff, dropout, device):
        super().__init__()

        self.fc1 = nn.Linear(emb_size, dff)
        self.fc2 = nn.Linear(dff, emb_size)
        self.drop = nn.Dropout(dropout)
        self.to(device)

    def forward(self, x):
        """Expand to `dff`, apply ReLU, project back to `emb_size`, then dropout."""
        hidden = torch.relu(self.fc1(x))
        projected = self.fc2(hidden)
        return self.drop(projected)
    


class Decoder(nn.Module):
    """One decoder block: multi-head attention followed by a feed-forward MLP.

    For each of the two sub-layers the block layer-normalises the sub-layer
    output, adds the sub-layer's input back as a residual, and applies ReLU
    to the sum.
    """

    def __init__(self, emb_size, dff, num_heads, context_length, dropout, device):
        super().__init__()

        self.multihead = MultiHead(emb_size, num_heads, context_length, dropout, device)
        self.feedforward = FeedForward(emb_size, dff, dropout, device)
        self.ln1 = nn.LayerNorm(emb_size)
        self.ln2 = nn.LayerNorm(emb_size)
        self.to(device)

    def forward(self, x):
        """Run the attention sub-layer, then the feed-forward sub-layer."""
        # Attention sub-layer: norm(attn(x)) + x, then ReLU.
        attended = F.relu(self.ln1(self.multihead(x)) + x)
        # Feed-forward sub-layer: norm(ffn(h)) + h, then ReLU.
        transformed = self.feedforward(attended)
        return F.relu(self.ln2(transformed) + attended)
    


class DecoderStack(nn.Module):
    """A sequence of `N` Decoder blocks applied one after another."""

    def __init__(self, emb_size, dff, num_heads, context_length, N, dropout, device):
        super().__init__()

        blocks = []
        for _ in range(N):
            blocks.append(Decoder(emb_size, dff, num_heads, context_length, dropout, device))
        self.decstack = nn.ModuleList(blocks)
        self.to(device)

    def forward(self, x):
        """Feed `x` through every block in order and return the final output."""
        hidden = x
        for block in self.decstack:
            hidden = block(hidden)
        return hidden
    
class GPT(nn.Module):
    """Decoder-only language model: embedding -> decoder stack -> vocabulary logits.

    Args:
        emb_size: embedding / hidden width.
        dff: hidden width of each block's feed-forward sub-layer.
        num_heads: attention heads per block.
        context_length: longest sequence the model can process.
        N: number of decoder blocks.
        vocab_size: size of the token vocabulary (width of the output logits).
        dropout: dropout probability used throughout the model.
        device: device the whole model is moved to at construction.
    """

    def __init__(self, emb_size, dff, num_heads, context_length, N, vocab_size, dropout, device):
        super().__init__()

        self.embd = Embedding(emb_size, vocab_size, context_length, device)
        self.decoderstack = DecoderStack(emb_size, dff, num_heads, context_length, N, dropout, device)
        # The third positional parameter of nn.Linear is `bias`, not `device`.
        # Passing `device` there only worked because a torch.device is truthy
        # (so bias stayed enabled). Build the layer with its default bias and
        # let self.to(device) below place it, which keeps the same parameters
        # and state_dict keys.
        self.out = nn.Linear(emb_size, vocab_size)
        self.to(device)

    def forward(self, x):
        """Map token ids of shape `(batch, seq)` to logits of shape `(batch, seq, vocab_size)`."""
        x = self.embd(x)
        x = self.decoderstack(x)
        x = F.relu(x)
        x = self.out(x)
        return x

# Build the network with the hyper-parameters from the config section.
gpt = GPT(
    emb_size,
    dff,
    num_heads,
    context_length,
    N,
    vocab_size,
    dropout,
    device,
)

# Restore the trained weights, mapping every tensor onto `device`.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
gpt.load_state_dict(
    torch.load("gpt_model_19.pth", map_location=device)
)

# Make sure the model lives on the selected device (CPU or GPU).
gpt.to(device)

# Count trainable parameters and report the total in millions.
num_params = sum(
    param.numel() for param in gpt.parameters() if param.requires_grad
)
num_params_million = num_params / 1e6
print(f"Number of parameters in the model: {num_params_million:.2f} million")

# Training objects (not used by the sampling code below).
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(gpt.parameters(), lr=lr)



def next_word_pred(start, gpt, max_new_tokens=None):
    """Sample a story from `gpt`, continuing the prompt `start`.

    The prompt is prefixed with the start-of-story marker, then tokens are
    sampled one at a time from the model's softmax distribution until the
    sequence ends with the end-of-story marker tokens.

    Relies on the module-level `tokenizer`, `device` and `context_length`.

    Args:
        start: prompt text the story should begin with.
        gpt: the language model; called with a `(1, seq)` tensor of token ids
            and expected to return `(1, seq, vocab)` logits.
        max_new_tokens: optional cap on the number of sampled tokens. The
            default `None` keeps the original behaviour of sampling until the
            end marker appears, which may never happen.

    Returns:
        The decoded text with the start/end story markers removed.
    """
    # Token ids that terminate a story; presumably the GPT-2 encoding of
    # "\n<|endofstory|>\n" -- TODO confirm against the tokenizer.
    end_tokens = [198, 27, 91, 437, 1659, 13571, 91, 29, 198]

    start = "\n<|startofstory|>\n" + start
    inp = torch.tensor(tokenizer.encode(start)).unsqueeze(0).to(device)

    # inference_mode() only disables autograd; dropout stays active unless the
    # model is in eval mode. Switch to eval for sampling and restore the
    # caller's train/eval state afterwards.
    was_training = gpt.training
    gpt.eval()
    try:
        generated = 0
        with torch.inference_mode():
            while inp[0, -len(end_tokens):].tolist() != end_tokens:
                if max_new_tokens is not None and generated >= max_new_tokens:
                    break
                # Feed at most the last `context_length` tokens to the model.
                logits = gpt(inp[:, -context_length:])
                probs = F.softmax(logits[:, -1, :], dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                inp = torch.cat([inp, next_token], dim=-1)
                generated += 1
    finally:
        gpt.train(was_training)

    # Decode once, after the loop: avoids re-decoding the whole sequence on
    # every step and guarantees a result even when the loop body never runs.
    decoded_string = tokenizer.decode(inp[0].tolist())
    decoded_string = decoded_string.replace('\n<|startofstory|>\n', '')
    decoded_string = decoded_string.replace('\n<|endofstory|>\n', '')
    return decoded_string

# Sample one story that begins with a fixed opening sentence.
out = next_word_pred(
    start='Once upon a time, there was a little boy named Tim.',
    gpt=gpt,
)



Number of parameters in the model: 28.02 million


In [27]:
# Use print() rather than a bare expression so the newlines in the generated
# story are rendered instead of appearing as escaped "\n" in the string repr.
print(out)

Once upon a time, there was a little boy named Tim. He was a big kids who loved to play outside. One day, Tim saw a little boy named Timmy's fur. Timmy looked sad and hungry. Timmy ran to his mom and asked, "What are you doing?"

His mom said, "I am looking for food, Timmy. Do not worry because this puddle is yummy. They get food for you to eat."

After a while, Timmy's mom came into the kitchen. She took a plate of cake and tasted it right down. She ate it and cheese. Timmy was so happy and ate his meal! From then on, Timmy loved the busy balance his big sack with his friends.
