In [46]:
import pandas
from datasets import load_dataset
import re
from collections import Counter

In [47]:
train = load_dataset("imdb", split="train[:5%]")
train_text = train["text"]

In [48]:
def clean_nulls(lst):
    sentences = []
    for sent in train_text:
        if not len(sent) < 2:
            sentences.append(sent)
    return sentences

train_text = clean_nulls(train_text)

In [49]:
len(train_text)

1250

In [50]:
def keep_english_only(text):
    text = re.sub(r"[^A-Za-z,'()\"\s\-]", "", text)
    text = re.sub(r"\b\w*(\w)\1{2,}\w*\b", "", text)
    
    # Remove excessive dots, commas, or other repeated punctuation
    text = re.sub(r"[!?.,-]{2,}", " ", text)  

    # Remove any mixed repeated punctuation (like "?!??!", "!!!")
    text = re.sub(r"([!?.,-])\1+", r"\1", text)  
    
    # Remove isolated "br" (often HTML line breaks in IMDB dataset)
    text = re.sub(r"\bbr\b", " ", text)
    
    return text

train_text = [keep_english_only(txt) for txt in train_text]

In [51]:
def clean_bracket_texts(text):
    if not isinstance(text, str):  
        return ""
    cleaned_text = re.sub(r"\s*\([^)]*\)\s*", " ", text)
    return re.sub(r"\s+", " ", cleaned_text).strip() 

train_text = [clean_bracket_texts(txt) for txt in train_text]

In [52]:
from nltk.tokenize import TreebankWordTokenizer
tokenizer = TreebankWordTokenizer()
train_text = [tokenizer.tokenize(txt) for txt in train_text]

In [53]:
tt = [[tok for tok in sent if tok!=''] for sent in train_text]
train_text = tt

In [54]:
train_text = [[txt.lower() for txt in sent if isinstance(txt, str)] for sent in train_text]

In [55]:
train_text = [sent for sent in train_text if len(sent)<501]

In [56]:
len(max(train_text))

306

In [57]:
cleaned_train_text = [[txt for txt in sent if txt not in ['``',"''","'","-"]] for sent in train_text]
train_text = cleaned_train_text

In [58]:
flat_data = [token for sentence in train_text for token in sentence if token]
word_counts = Counter(flat_data)
vocab = {word for word, count in word_counts.items() if count >= 2}
word2idx = {word: idx+2 for idx, word in enumerate(vocab)}
word2idx["<PAD>"] = 0
word2idx["<OOV>"] = 1
idx2word = {idx: word for word, idx in word2idx.items()}

In [59]:
len(word_counts)

17156

In [60]:
len(word2idx)

8097

In [61]:
sequences = []
targets = []

for l in train_text:
    for i in range(len(l) - 1): 
        n_gram_seq = l[:i+1]
        sequences.append(n_gram_seq)
        targets.append(l[i+1])

In [62]:
sequences = [[word2idx.get(token, word2idx["<OOV>"]) for token in seq] for seq in sequences]
targets = [word2idx.get(token, word2idx["<OOV>"]) for token in targets]

In [63]:
import torch
from torch.utils.data import Dataset,DataLoader
from torch.nn.utils.rnn import pad_sequence
sequences_t = [torch.tensor(seq,dtype=torch.long) for seq in sequences]
targets_t = [torch.tensor(t, dtype=torch.long) for t in targets]

In [64]:
class CustomDataset:
    def __init__(self,x,y):
        self.x = x
        self.y = y
        
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self,idx):
        return self.x[idx],self.y[idx]

In [65]:
def collate_fn(batch):
    sequences, targets = zip(*batch)
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0, padding_side='left')
    return padded_sequences, torch.tensor(targets)

dataset = CustomDataset(sequences_t,targets_t)
dataloader = DataLoader(dataset=dataset, batch_size=64, collate_fn=collate_fn)

In [66]:
len(dataset)

220756

In [67]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [74]:
import torch
import torch.nn as nn

class NextWordPredictor(nn.Module):
    def __init__(self, vocab_size, embedding_dim=300, hidden_dim=512):
        super(NextWordPredictor, self).__init__()
        
        self.embed = nn.Embedding(vocab_size, embedding_dim) 
        
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
        
        self.fc = nn.Linear(hidden_dim*2, vocab_size)
        
    def forward(self, x, hidden=None):
        x = self.embed(x)
        lstm_out, hidden = self.lstm(x, hidden)
        last_timestep_output = lstm_out[:, -1, :]
        output = self.fc(last_timestep_output)
    
        return output

# Initialize the model
vocab_size = len(word2idx)
embedding_dim = 300  # You can use pre-trained embeddings if available
hidden_dim = 512 # LSTM hidden dimension
model = NextWordPredictor(vocab_size, embedding_dim, hidden_dim).to(device)


In [75]:
epochs = 50
criterion = nn.CrossEntropyLoss() 
optim = torch.optim.Adam(model.parameters(),lr = 0.0003,weight_decay=1e-5)

In [None]:
import tqdm

for e in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    num_batches = 0

    for X, y in tqdm.tqdm(dataloader, f'{e}/{epochs}'):
        X, y = X.to(device), y.to(device)

        y_pred = model(X)

        loss = criterion(y_pred, y)

        acc = (torch.argmax(y_pred, dim=1) == y).float().mean() * 100  

        epoch_loss += loss.item()
        epoch_accuracy += acc.item()
        num_batches += 1

        optim.zero_grad()
        loss.backward()
        optim.step()

    avg_epoch_loss = epoch_loss / num_batches
    avg_epoch_accuracy = epoch_accuracy / num_batches

    print(f'Epoch {e+1}/{epochs} - Loss: {avg_epoch_loss:.4f}, Accuracy: {avg_epoch_accuracy:.2f}%')

0/50: 100%|██████████| 3450/3450 [02:34<00:00, 22.29it/s]


Epoch 1/50 - Loss: 5.7680, Accuracy: 12.11%


1/50: 100%|██████████| 3450/3450 [02:31<00:00, 22.77it/s]


Epoch 2/50 - Loss: 5.1286, Accuracy: 15.10%


2/50: 100%|██████████| 3450/3450 [02:54<00:00, 19.78it/s]


Epoch 3/50 - Loss: 4.7618, Accuracy: 16.70%


3/50: 100%|██████████| 3450/3450 [02:57<00:00, 19.49it/s]


Epoch 4/50 - Loss: 4.4319, Accuracy: 18.50%


4/50: 100%|██████████| 3450/3450 [02:48<00:00, 20.50it/s]


Epoch 5/50 - Loss: 4.1268, Accuracy: 20.53%


5/50: 100%|██████████| 3450/3450 [02:35<00:00, 22.21it/s]


Epoch 6/50 - Loss: 3.8498, Accuracy: 22.84%


6/50: 100%|██████████| 3450/3450 [02:39<00:00, 21.61it/s]


Epoch 7/50 - Loss: 3.5986, Accuracy: 25.50%


7/50: 100%|██████████| 3450/3450 [02:52<00:00, 19.94it/s]


Epoch 8/50 - Loss: 3.3599, Accuracy: 28.59%


8/50: 100%|██████████| 3450/3450 [02:21<00:00, 24.41it/s]


Epoch 9/50 - Loss: 3.1431, Accuracy: 31.76%


9/50: 100%|██████████| 3450/3450 [02:42<00:00, 21.22it/s]


Epoch 10/50 - Loss: 2.9411, Accuracy: 35.08%


10/50: 100%|██████████| 3450/3450 [02:56<00:00, 19.57it/s]


Epoch 11/50 - Loss: 2.7550, Accuracy: 38.25%


11/50: 100%|██████████| 3450/3450 [02:51<00:00, 20.15it/s]


Epoch 12/50 - Loss: 2.5837, Accuracy: 41.52%


12/50: 100%|██████████| 3450/3450 [02:57<00:00, 19.49it/s]


Epoch 13/50 - Loss: 2.4175, Accuracy: 44.84%


13/50: 100%|██████████| 3450/3450 [02:50<00:00, 20.22it/s]


Epoch 14/50 - Loss: 2.2647, Accuracy: 48.14%


14/50: 100%|██████████| 3450/3450 [02:52<00:00, 19.99it/s]


Epoch 15/50 - Loss: 2.1245, Accuracy: 51.14%


15/50: 100%|██████████| 3450/3450 [02:58<00:00, 19.34it/s]


Epoch 16/50 - Loss: 1.9927, Accuracy: 54.19%


16/50: 100%|██████████| 3450/3450 [02:53<00:00, 19.90it/s]


Epoch 17/50 - Loss: 1.8981, Accuracy: 56.25%


17/50: 100%|██████████| 3450/3450 [02:49<00:00, 20.33it/s]


Epoch 18/50 - Loss: 1.7782, Accuracy: 59.15%


18/50: 100%|██████████| 3450/3450 [02:38<00:00, 21.81it/s]


Epoch 19/50 - Loss: 1.6770, Accuracy: 61.52%


19/50: 100%|██████████| 3450/3450 [02:42<00:00, 21.21it/s]


Epoch 20/50 - Loss: 1.5841, Accuracy: 63.68%


20/50: 100%|██████████| 3450/3450 [02:51<00:00, 20.09it/s]


Epoch 21/50 - Loss: 1.4868, Accuracy: 66.06%


21/50: 100%|██████████| 3450/3450 [02:52<00:00, 20.03it/s]


Epoch 22/50 - Loss: 1.4111, Accuracy: 67.83%


22/50: 100%|██████████| 3450/3450 [02:52<00:00, 19.99it/s]


Epoch 23/50 - Loss: 1.3266, Accuracy: 70.01%


23/50: 100%|██████████| 3450/3450 [02:54<00:00, 19.73it/s]


Epoch 24/50 - Loss: 1.2624, Accuracy: 71.66%


24/50: 100%|██████████| 3450/3450 [02:56<00:00, 19.52it/s]


Epoch 25/50 - Loss: 1.1943, Accuracy: 73.31%


25/50: 100%|██████████| 3450/3450 [02:54<00:00, 19.74it/s]


Epoch 26/50 - Loss: 1.1223, Accuracy: 75.00%


26/50: 100%|██████████| 3450/3450 [02:30<00:00, 22.92it/s]


Epoch 27/50 - Loss: 1.0686, Accuracy: 76.33%


27/50: 100%|██████████| 3450/3450 [02:38<00:00, 21.75it/s]


Epoch 28/50 - Loss: 0.9977, Accuracy: 78.32%


28/50: 100%|██████████| 3450/3450 [02:52<00:00, 19.98it/s]


Epoch 29/50 - Loss: 0.9928, Accuracy: 78.08%


29/50: 100%|██████████| 3450/3450 [02:57<00:00, 19.45it/s]


Epoch 30/50 - Loss: 0.9264, Accuracy: 79.93%


30/50: 100%|██████████| 3450/3450 [02:57<00:00, 19.46it/s]


Epoch 31/50 - Loss: 0.8704, Accuracy: 81.36%


31/50: 100%|██████████| 3450/3450 [02:55<00:00, 19.64it/s]


Epoch 32/50 - Loss: 0.8464, Accuracy: 81.92%


32/50: 100%|██████████| 3450/3450 [02:49<00:00, 20.30it/s]


Epoch 33/50 - Loss: 0.7949, Accuracy: 83.35%


33/50: 100%|██████████| 3450/3450 [02:53<00:00, 19.92it/s]


Epoch 34/50 - Loss: 0.7471, Accuracy: 84.67%


34/50: 100%|██████████| 3450/3450 [02:59<00:00, 19.27it/s]


Epoch 35/50 - Loss: 0.7264, Accuracy: 85.09%


35/50: 100%|██████████| 3450/3450 [02:50<00:00, 20.20it/s]


Epoch 36/50 - Loss: 0.6909, Accuracy: 85.93%


36/50: 100%|██████████| 3450/3450 [02:49<00:00, 20.41it/s]


Epoch 37/50 - Loss: 0.6578, Accuracy: 86.83%


37/50: 100%|██████████| 3450/3450 [02:55<00:00, 19.71it/s]


Epoch 38/50 - Loss: 0.6354, Accuracy: 87.47%


38/50: 100%|██████████| 3450/3450 [02:53<00:00, 19.94it/s]


Epoch 39/50 - Loss: 0.6011, Accuracy: 88.49%


39/50: 100%|██████████| 3450/3450 [02:59<00:00, 19.23it/s]


Epoch 40/50 - Loss: 0.5756, Accuracy: 89.10%


40/50: 100%|██████████| 3450/3450 [03:01<00:00, 19.05it/s]


Epoch 41/50 - Loss: 0.5577, Accuracy: 89.47%


41/50: 100%|██████████| 3450/3450 [02:53<00:00, 19.83it/s]


Epoch 42/50 - Loss: 0.5243, Accuracy: 90.41%


42/50: 100%|██████████| 3450/3450 [02:55<00:00, 19.60it/s]


Epoch 43/50 - Loss: 0.5310, Accuracy: 90.10%


43/50: 100%|██████████| 3450/3450 [02:57<00:00, 19.47it/s]


Epoch 44/50 - Loss: 0.5113, Accuracy: 90.62%


44/50: 100%|██████████| 3450/3450 [02:56<00:00, 19.57it/s]


Epoch 45/50 - Loss: 0.4925, Accuracy: 91.08%


45/50: 100%|██████████| 3450/3450 [02:54<00:00, 19.79it/s]


Epoch 46/50 - Loss: 0.4826, Accuracy: 91.31%


46/50: 100%|██████████| 3450/3450 [02:44<00:00, 20.99it/s]


Epoch 47/50 - Loss: 0.4669, Accuracy: 91.70%


47/50: 100%|██████████| 3450/3450 [02:29<00:00, 23.03it/s]


Epoch 48/50 - Loss: 0.4636, Accuracy: 91.80%


48/50: 100%|██████████| 3450/3450 [02:51<00:00, 20.15it/s]


Epoch 49/50 - Loss: 0.5877, Accuracy: 87.79%


49/50: 100%|██████████| 3450/3450 [02:32<00:00, 22.58it/s]

Epoch 50/50 - Loss: 0.5177, Accuracy: 89.95%





In [119]:
torch.save(model.state_dict(),r'D:\Codes\Python\DL\Codes\nextwordpred.pt')

In [120]:
import torch.nn.functional as F

def prediction(model, vocab, inv_vocab, text, tokens, temperature=1.0, top_k=5):
    sequences = [vocab.get(token, vocab["<OOV>"]) for token in text.lower().split()]
    generated_words = [inv_vocab[w] for w in sequences]  

    for _ in range(tokens):
        sequences_t = torch.tensor(sequences, dtype=torch.long)
        
        if len(sequences_t) < 101:
            z = torch.zeros(101 - len(sequences_t), dtype=torch.long)
            padded = torch.cat((z, sequences_t)).unsqueeze(0).to(device)
        else:
            padded = sequences_t[-101:].unsqueeze(0).to(device)

        o = model(padded)

        probs = F.softmax(o / temperature, dim=1)

        top_probs, top_indices = torch.topk(probs, top_k, dim=1)
        top_probs = top_probs.squeeze()
        top_indices = top_indices.squeeze()

        sampled_idx = top_indices[torch.multinomial(top_probs, 1).item()].item()

        predicted_word = inv_vocab.get(sampled_idx, "<OOV>")
        sequences.append(sampled_idx)
        generated_words.append(predicted_word)  
    
    final = []    
    i = 0
    while i < len(generated_words) - 1:
        if "'" in generated_words[i+1]:
            final.append(generated_words[i] + generated_words[i+1])
            del generated_words[i+1]  

        else:
            final.append(generated_words[i])
        i += 1  

    if generated_words:
        final.append(generated_words[-1])
        
    print(' '.join(final))


In [103]:
prediction(model, word2idx, idx2word, 'We', 10, temperature=1.0, top_k=5)

we saw this late at best , because it seems to


In [105]:
prediction(model, word2idx, idx2word, 'I am', 12, temperature=1.0, top_k=5)

i am really shocked that this poor excuse for those things on i'd i


In [108]:
prediction(model, word2idx, idx2word, 'Everyone', 14, temperature=1.0, top_k=5)

everyone else that director might be better well written and directed this movie look for


In [None]:
prediction(model, word2idx, idx2word, 'Chapter', 16, temperature=1.0, top_k=5)  # chapter is not in the vocabulary

<OOV> this movie appears to be in one movie and expected he attempts to raise an endless


In [111]:
prediction(model, word2idx, idx2word, 'maybe', 16, temperature=1.0, top_k=5)

maybe someone that make me think by saying i don't get this one by saying i


In [113]:
prediction(model, word2idx, idx2word, 'lights', 18, temperature=1.0, top_k=5)

lights out of the original planet , and this is why this movie is so awful , that it


In [114]:
prediction(model, word2idx, idx2word, 'Remember', 16, temperature=1.0, top_k=5)

remember when i first saw the preview of his first film so do not be good enough
