In [1]:
!nvidia-smi

Mon Apr 29 22:31:11 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.76                 Driver Version: 550.76         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 2070 ...    Off |   00000000:0A:00.0  On |                  N/A |
|  0%   49C    P5             24W /  215W |     519MiB /   8192MiB |     38%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import opendatasets as od
import string
import random
import unicodedata
import os
import re

from collections import Counter
import nltk
from nltk.tokenize import word_tokenize

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import Normalize, ToTensor

from models.models import Encoder, Decoder, EncDec

In [3]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = (
    torch.device(type="cuda", index=0)
    if torch.cuda.is_available()
    else torch.device(type="cpu", index=0)
)
print(device)

cuda:0


In [4]:
# Load the parallel corpus. The preview shows that 'og' holds the original
# Shakespeare line and 't' its modern-English rendering.
dataset_path = "Datasets/final.csv"
dataset = pd.read_csv(filepath_or_buffer=dataset_path)
dataset.head(n=5)

Unnamed: 0.1,Unnamed: 0,id,og,t
0,0,42928-1500614319216-63344,You do not meet a man but frowns:,Every man you meet these days is frowning.
1,1,42928-1500614326583-89821,our bloods No more obey the heavens than our...,Our bodies are in agreement with the planetar...
2,2,A-63849,But what's the matter?,What's wrong?
3,3,42930-1500614347266-80123,"His daughter, and the heir of's kingdom, whom...","The king wanted his daughter, the only heir to..."
4,4,42930-1500614355280-38326,she's wedded; Her husband banish'd; she impr...,"She's married, her husband is banished, she's..."


In [5]:
# One (modern English, original Shakespeare) tuple per row: column 't' first,
# column 'og' second.
pairs = [(modern, original) for modern, original in zip(dataset['t'], dataset['og'])]
print(len(pairs), pairs[0], sep="\n")

51787
('Every man you meet these days is frowning.', ' You do not meet a man but frowns: ')


In [6]:
def normalizeString(s):
    """Strip accents and punctuation from ``s`` and return the cleaned text.

    Steps, in order:
      1. NFD-decompose the string and drop combining marks (category 'Mn'),
         which removes accents from letters.
      2. Insert a space before every '.', '!' and '?'.
      3. Replace every run of characters other than ASCII letters, '!' and '?'
         with a single space (this also removes '.', digits and apostrophes).
      4. Trim leading and trailing whitespace.
    """
    decomposed = unicodedata.normalize('NFD', s)
    without_marks = ''.join(
        ch for ch in decomposed if unicodedata.category(ch) != 'Mn'
    )
    spaced = re.sub(r"([.!?])", r" \1", without_marks)
    cleaned = re.sub(r"[^a-zA-Z!?]+", r" ", spaced)
    return cleaned.strip()

def createNormalizedPairs(pairs):
    """Lower-case, trim and normalise both sentences of every pair.

    Args:
        pairs: iterable of two-item sequences of strings.

    Returns:
        A list of ``[first, second]`` lists, in the input order, where each
        string has been passed through ``normalizeString``.
    """
    return [
        [
            normalizeString(first.lower().strip()),
            normalizeString(second.lower().strip()),
        ]
        for first, second in pairs
    ]

In [7]:
# Sentences must have fewer than this many space-separated words to be kept.
MAX_LENGTH = 15


def filterpairs(initpairs):
    """Keep only short pairs whose first sentence starts with a known prefix.

    A pair survives when both sentences have fewer than ``MAX_LENGTH``
    space-separated words and the lower-cased first sentence starts with one
    of the pronoun/auxiliary prefixes below. The prefixes are written the way
    ``normalizeString`` leaves contractions, e.g. "i m " for "i'm ".

    Args:
        initpairs: iterable of ``[first, second]`` string pairs.

    Returns:
        A list of the pairs that passed, in their original order.
    """
    eng_prefixes = (
        # pronoun + "to be"
        "i am ", "i m ", "he is", "he s ", "she is", "she s ",
        "you are", "you re ", "we are", "we re ", "they are", "they re ",
        # pronoun + "to have"
        "i have ", "i ve ", "you have ", "you ve ",
        "we have ", "we ve ", "they have ", "they ve ",
        # negated auxiliaries
        "is not ", "isn t ", "are not ", "aren t ",
        "was not ", "wasn t ", "were not ", "weren t ",
        "have not ", "haven t ", "has not ", "hasn t ", "had not ", "hadn t ",
        "will not ", "won t ", "would not ", "wouldn t ",
        "do not ", "don t ", "does not ", "doesn t ", "did not ", "didn t ",
        "can not ", "can t ", "could not ", "couldn t ",
        "should not ", "shouldn t ", "might not ", "mightn t ",
        "must not ", "mustn t ",
    )

    def is_short(sentence):
        # Word count uses the same single-space split as the rest of the notebook.
        return len(sentence.split(" ")) < MAX_LENGTH

    return [
        pair
        for pair in initpairs
        if is_short(pair[0])
        and is_short(pair[1])
        and pair[0].lower().startswith(eng_prefixes)
    ]

In [8]:
class Vocab:
    """Word <-> index vocabulary for one side of the parallel corpus.

    Indices 0 and 1 are reserved for the 'SOS' and 'EOS' markers; every other
    word gets the next free index the first time ``buildVocab`` sees it.

    Attributes set in ``__init__``:
        name: label for this vocabulary (e.g. 'shake' or 'eng').
        word2index: maps word -> integer index.
        index2word: maps integer index -> word.
        word2count: maps word -> number of occurrences seen by ``buildVocab``.
        nwords: number of entries in the vocabulary, including the two markers.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {'SOS': 0, 'EOS': 1}
        self.index2word = {0: 'SOS', 1: 'EOS'}
        self.word2count = {}
        self.nwords = 2

    def buildVocab(self, s):
        """Add every space-separated word of ``s`` to the vocabulary.

        The split is on a single space (not arbitrary whitespace) so that it
        matches how sentences are tokenised elsewhere in this notebook.
        """
        for word in s.split(" "):
            if word not in self.word2index:
                index = self.nwords
                self.word2index[word] = index
                self.index2word[index] = word
                self.nwords += 1
            # 'SOS'/'EOS' are in word2index from the start but have no count
            # entry, so a plain `+= 1` raised KeyError when a sentence
            # contained one of them; .get() covers both new and reserved words.
            self.word2count[word] = self.word2count.get(word, 0) + 1

In [9]:
def get_input_ids(sentence, langobj):
    """Convert a sentence into a 1-D tensor of vocabulary indices.

    Args:
        sentence: string whose words are separated by single spaces; every
            word must already be a key of ``langobj.word2index``.
        langobj: vocabulary object exposing ``name`` and ``word2index``.

    Returns:
        ``torch.tensor`` of indices. For the vocabulary named 'shake' the
        layout is ``words + [EOS]``; for any other vocabulary it is
        ``[SOS] + words + [EOS]``.
    """
    word2index = langobj.word2index
    token_ids = [word2index[word] for word in sentence.split(" ")]

    # Only sequences from a vocabulary other than 'shake' get a leading SOS.
    if langobj.name != 'shake':
        token_ids = [word2index['SOS']] + token_ids
    token_ids.append(word2index['EOS'])

    return torch.tensor(token_ids)

In [10]:
class customDataset(Dataset):
    """Dataset over the notebook-level ``pairs`` list.

    Relies on these module-level names being defined before use: ``pairs``,
    ``length``, ``MAX_LENGTH``, the ``shake`` and ``eng`` vocabularies, and
    ``get_input_ids``.
    """

    def __init__(self):
        super().__init__()

    def __len__(self):
        # `length` is a notebook global, set where `pairs` is filtered.
        return length

    def __getitem__(self, idx):
        """Return ``(source_ids, target_ids)`` for pair ``idx``.

        Both are int64 tensors pre-filled with zeros: the source has
        ``MAX_LENGTH + 1`` slots, the target ``MAX_LENGTH + 2``. The encoded
        sentence is written at the front and the rest stays 0.
        """
        target_sentence, source_sentence = pairs[idx]

        source_ids = get_input_ids(source_sentence, shake)
        source_padded = torch.zeros(MAX_LENGTH + 1, dtype=torch.int64)
        source_padded[:len(source_sentence.split(" ")) + 1] = source_ids

        target_ids = get_input_ids(target_sentence, eng)
        target_padded = torch.zeros(MAX_LENGTH + 2, dtype=torch.int64)
        target_padded[:len(target_sentence.split(" ")) + 2] = target_ids

        return source_padded, target_padded

In [11]:
def train_one_epoch():
    """Run one pass over ``train_dataloader`` and return the mean batch loss.

    Uses the notebook-level ``encoder``, ``decoder``, ``loss_fn``, ``opte``,
    ``optd``, ``device`` and ``train_dataloader``. The decoder is fed the
    ground-truth target tokens shifted right (teacher forcing) and is scored
    against the same tokens shifted left.
    """
    encoder.train()
    decoder.train()
    running_loss = 0

    for source_batch, target_batch in train_dataloader:
        source_batch = source_batch.to(device)
        target_batch = target_batch.to(device)

        # The encoder's output is used directly as the decoder's initial hidden state.
        hidden = encoder(source_batch)
        outputs, hidden = decoder(target_batch[:, :-1], hidden)

        # Flatten both sides so the loss sees one row per target position.
        flat_outputs = outputs.view(-1, outputs.shape[-1])
        flat_targets = target_batch[:, 1:].reshape(-1)

        loss = loss_fn(flat_outputs, flat_targets)
        running_loss += loss.item()

        opte.zero_grad()
        optd.zero_grad()
        loss.backward()
        opte.step()
        optd.step()

    return running_loss / len(train_dataloader)
        

In [12]:
def ids2Sentence(ids,vocab):
    """Turn a tensor of token indices back into a space-separated sentence.

    Args:
        ids: integer tensor of token indices. It is flattened first, so a
            ``(1, n)`` batch, a plain ``(n,)`` vector and a single-element
            tensor are all accepted.
        vocab: object exposing ``index2word`` (index -> word).

    Returns:
        The words joined with a trailing space after each one. Index 0
        (SOS, also the padding value) is skipped; decoding stops right after
        index 1 (EOS), whose word is included in the output.
    """
    words = []
    # reshape(-1) instead of squeeze(): squeeze() turns a one-element tensor
    # into a 0-d tensor, which cannot be iterated. tolist() yields plain ints,
    # which are the key type of index2word.
    for token_id in ids.reshape(-1).tolist():
        if token_id == 0:
            continue
        words.append(vocab.index2word[token_id])
        if token_id == 1:
            break
    return "".join(word + " " for word in words)

In [13]:
def eval_one_epoch(e,n_epochs):
    """Evaluate on ``test_dataloader`` with greedy decoding; return the mean loss.

    Args:
        e: zero-based index of the current epoch.
        n_epochs: total number of epochs; when ``e + 1 == n_epochs`` the
            source, ground-truth and predicted sentences are printed for
            every test example.

    Returns:
        Sum of the per-batch losses divided by ``len(test_dataloader)``.

    Uses the notebook-level ``encoder``, ``decoder``, ``loss_fn``, ``device``,
    ``eng``, ``shake``, ``MAX_LENGTH`` and ``test_dataloader``. The ``.item()``
    calls below need exactly one sentence per batch, which matches the
    ``batch_size=1`` test DataLoader built in this notebook.
    """
    encoder.eval()
    decoder.eval()
    track_loss=0
    with torch.no_grad():
        for i, (s_ids,t_ids) in enumerate(test_dataloader):
            s_ids=s_ids.to(device)
            t_ids=t_ids.to(device)
            
            encoder_hidden = encoder(s_ids)
            decoder_hidden = encoder_hidden
            # First target token; get_input_ids puts SOS there for the 'eng' vocabulary.
            input_ids=t_ids[:,0]
            yhats=[]
            if e+1==n_epochs:
                pred_sentence=""
            # Decode one token at a time, feeding each prediction back in as the next input.
            for j in range(1, MAX_LENGTH+2):
                probs, decoder_hidden = decoder(input_ids.unsqueeze(1),decoder_hidden)
                yhats.append(probs)
                # k=1 top-k on the last axis == index of the highest-scoring word.
                _,input_ids=torch.topk(probs,1,dim=-1)
                input_ids=input_ids.squeeze(1,2)
                if e+1==n_epochs:
                    word=eng.index2word[input_ids.item()]
                    pred_sentence+=word + " "
                # Index 1 is EOS: stop decoding this sentence.
                if input_ids.item() == 1:
                    break;
                
            if e+1==n_epochs:
                src_sentence=ids2Sentence(s_ids,shake)
                gt_sentence=ids2Sentence(t_ids[:,1:],eng) 
                
                print("\n-----------------------------------")
                print("Source Sentence:",src_sentence)
                print("GT Sentence:",gt_sentence)
                print("Predicted Sentence:",pred_sentence)
                
            yhats_cat=torch.cat(yhats,dim=1)
            yhats_reshaped=yhats_cat.view(-1,yhats_cat.shape[-1])
            # j is the number of decoding steps actually taken, so this keeps
            # exactly as many target tokens as there are predictions.
            gt=t_ids[:,1:j+1]
            gt=gt.view(-1)
            
            loss=loss_fn(yhats_reshaped,gt)
            track_loss+=loss.item()
        
        if e+1==n_epochs:    
            print("-----------------------------------")
            
        return track_loss/len(test_dataloader)
            

In [14]:
# Normalise every sentence, then keep only the short pairs that pass the prefix filter.
pairs = filterpairs(createNormalizedPairs(pairs))
# `length` is the value customDataset.__len__ returns.
length = len(pairs)
print(length)

1423


In [15]:
# Spot-check one surviving pair: after normalisation each entry is a
# [modern English, Shakespeare] list of lower-cased, punctuation-stripped text.
print(pairs[1])

['you are like poison in my blood', 'thou rt poison to my blood']


In [16]:
# One vocabulary per side of the corpus. The name 'shake' matters:
# get_input_ids uses it to decide whether to prepend SOS.
eng = Vocab('eng')
shake = Vocab('shake')

# pair[0] is the modern-English sentence, pair[1] the Shakespeare one.
for pair in pairs:
    shake.buildVocab(pair[1])
    eng.buildVocab(pair[0])

print("English Vocab Length:",eng.nwords)
print("Shakespere Vocab Length:",shake.nwords)

English Vocab Length: 2033
Shakespere Vocab Length: 2228


In [17]:
dataset = customDataset()

# Seed the split so the same examples land in the held-out set on every run;
# without a generator, random_split draws a fresh permutation each time and
# the reported eval loss is not comparable between runs.
SPLIT_SEED = 0
train_dataset, test_dataset = random_split(
    dataset,
    [0.99, 0.01],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

batch_size = 32
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
# eval_one_epoch calls .item() on each predicted token, so it needs one sentence per batch.
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

In [18]:
# Encoder and Decoder come from the project-local models.models module.
# The encoder is sized by the Shakespeare vocabulary and the decoder by the
# modern-English one; both are moved to the selected device.
encoder = Encoder(shake.nwords).to(device)
decoder = Decoder(eng.nwords).to(device)
# Printing the modules shows their layer layout.
print(encoder,decoder)

Encoder(
  (embedding): Embedding(2228, 300)
  (dropout): Dropout(p=0.1, inplace=False)
  (rnn): GRU(300, 512, batch_first=True)
) Decoder(
  (embedding): Embedding(2033, 300)
  (relu): ReLU()
  (lsmax): LogSoftmax(dim=-1)
  (rnn): GRU(300, 512, batch_first=True)
  (linear): Linear(in_features=512, out_features=2033, bias=True)
)


In [19]:
# Index 0 is 'SOS' in Vocab and also the zero padding written by customDataset,
# so those target positions are excluded from the loss.
loss_fn = nn.NLLLoss(ignore_index=0).to(device)

# Separate Adam optimizers for the two networks, with identical settings.
lr = 1e-3
opte = optim.Adam(encoder.parameters(), lr=lr, weight_decay=1e-3)
optd = optim.Adam(decoder.parameters(), lr=lr, weight_decay=1e-3)

n_epochs = 80

In [20]:
# Train, then evaluate, once per epoch; each epoch prints one progress line.
# On the final epoch eval_one_epoch also prints its sample translations
# before the "Eval Loss" figure is written.
for e in range(n_epochs):
    print("Epoch=", e + 1, sep="", end=", ")
    train_loss = round(train_one_epoch(), 4)
    print("Train Loss=", train_loss, sep="", end=", ")
    eval_loss = round(eval_one_epoch(e, n_epochs), 4)
    print("Eval Loss=", eval_loss, sep="")

Epoch=1, Train Loss=5.4379, Eval Loss=4.402
Epoch=2, Train Loss=4.6169, Eval Loss=4.5406
Epoch=3, Train Loss=4.2712, Eval Loss=4.6321
Epoch=4, Train Loss=3.998, Eval Loss=4.9279
Epoch=5, Train Loss=3.7426, Eval Loss=5.1585
Epoch=6, Train Loss=3.5323, Eval Loss=5.4019
Epoch=7, Train Loss=3.3376, Eval Loss=5.6446
Epoch=8, Train Loss=3.1522, Eval Loss=5.6333
Epoch=9, Train Loss=3.0099, Eval Loss=5.5565
Epoch=10, Train Loss=2.8713, Eval Loss=5.565
Epoch=11, Train Loss=2.7522, Eval Loss=5.653
Epoch=12, Train Loss=2.6274, Eval Loss=5.5692
Epoch=13, Train Loss=2.4892, Eval Loss=5.6997
Epoch=14, Train Loss=2.3706, Eval Loss=5.8156
Epoch=15, Train Loss=2.2436, Eval Loss=5.5963
Epoch=16, Train Loss=2.1511, Eval Loss=5.3016
Epoch=17, Train Loss=2.0633, Eval Loss=5.7814
Epoch=18, Train Loss=1.9726, Eval Loss=5.7538
Epoch=19, Train Loss=1.8928, Eval Loss=5.7836
Epoch=20, Train Loss=1.8003, Eval Loss=5.5923
Epoch=21, Train Loss=1.7025, Eval Loss=5.6358
Epoch=22, Train Loss=1.6186, Eval Loss=5.7064
E