<a href="https://colab.research.google.com/github/pratikjagtapofficial/Next-Word-Prediction-LSTM/blob/main/LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from nltk.tokenize import word_tokenize
import nltk

In [4]:
import kagglehub

path = kagglehub.dataset_download("ashishpandey2062/next-word-predictor-text-generator-dataset")

print("Path to dataset files:", path)

Using Colab cache for faster access to the 'next-word-predictor-text-generator-dataset' dataset.
Path to dataset files: /kaggle/input/next-word-predictor-text-generator-dataset


In [5]:
import os

files = os.listdir(path)
print(files)

['next_word_predictor.txt']


In [6]:
file_path = os.path.join(path, files[0])

with open(file_path, "r", encoding="utf-8") as f:
    document = f.read()

print(type(document))

<class 'str'>


In [7]:
len(document)

167445

In [8]:
# Tokenization
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [9]:
# tokenize
tokens = word_tokenize(document.lower())

In [10]:
# build Vocab

from collections import Counter

counter = Counter(tokens)

vocab = {}

# Add special tokens FIRST
vocab["<PAD>"] = 0
vocab["<SOS>"] = 1
vocab["<EOS>"] = 2
vocab["<UNK>"] = 3

Counter(tokens) # this provides how many times unique words came in document

Counter({'the': 1414,
         'sun': 16,
         'was': 147,
         'shining': 1,
         'brightly': 2,
         'in': 307,
         'clear': 7,
         'blue': 5,
         'sky': 16,
         ',': 2225,
         'and': 849,
         'a': 706,
         'gentle': 1,
         'breeze': 2,
         'rustled': 1,
         'leaves': 8,
         'of': 689,
         'tall': 3,
         'trees': 7,
         '.': 1752,
         'people': 44,
         'were': 59,
         'out': 77,
         'enjoying': 3,
         'beautiful': 5,
         'weather': 3,
         'some': 37,
         'sitting': 6,
         'park': 7,
         'others': 7,
         'taking': 7,
         'leisurely': 3,
         'stroll': 1,
         'along': 5,
         'riverbank': 1,
         'children': 3,
         'playing': 7,
         'games': 1,
         'laughter': 9,
         'filled': 14,
         'air': 32,
         'as': 120,
         'day': 14,
         'turned': 3,
         'into': 49,
         'evening': 2,
 

In [11]:
for word in counter.keys():
    if word not in vocab:
        vocab[word] = len(vocab)
vocab # Vocabulary built

{'<PAD>': 0,
 '<SOS>': 1,
 '<EOS>': 2,
 '<UNK>': 3,
 'the': 4,
 'sun': 5,
 'was': 6,
 'shining': 7,
 'brightly': 8,
 'in': 9,
 'clear': 10,
 'blue': 11,
 'sky': 12,
 ',': 13,
 'and': 14,
 'a': 15,
 'gentle': 16,
 'breeze': 17,
 'rustled': 18,
 'leaves': 19,
 'of': 20,
 'tall': 21,
 'trees': 22,
 '.': 23,
 'people': 24,
 'were': 25,
 'out': 26,
 'enjoying': 27,
 'beautiful': 28,
 'weather': 29,
 'some': 30,
 'sitting': 31,
 'park': 32,
 'others': 33,
 'taking': 34,
 'leisurely': 35,
 'stroll': 36,
 'along': 37,
 'riverbank': 38,
 'children': 39,
 'playing': 40,
 'games': 41,
 'laughter': 42,
 'filled': 43,
 'air': 44,
 'as': 45,
 'day': 46,
 'turned': 47,
 'into': 48,
 'evening': 49,
 'temperature': 50,
 'started': 51,
 'to': 52,
 'drop': 53,
 'transformed': 54,
 'canvas': 55,
 'vibrant': 56,
 'colors': 57,
 'families': 58,
 'gathered': 59,
 'for': 60,
 'picnics': 61,
 'smell': 62,
 'barbecues': 63,
 'wafted': 64,
 'through': 65,
 'it': 66,
 'perfect': 67,
 'picnic': 68,
 'by': 69,
 'la

In [12]:
len(vocab)

5061

In [13]:
# Create Reverse Vocabulary

id_to_word = {idx: word for word, idx in vocab.items()}

In [14]:
# Convert words to numbers

def words_to_numbers(sentence, vocab):
  numerical_sentence = []

  for token in sentence:
    if token in vocab:
      numerical_sentence.append(vocab[token])
    else:
      numerical_sentence.append(vocab['<UNK>'])

  return numerical_sentence

In [15]:
from nltk.tokenize import sent_tokenize

input_sentences = sent_tokenize(document)

In [16]:
# After converting words into numbers assign numbers to sentences

input_numerical_sentences = []

for sentence in input_sentences:
    tokens = word_tokenize(sentence.lower())
    tokens = ["<SOS>"] + tokens + ["<EOS>"]
    input_numerical_sentences.append(words_to_numbers(tokens, vocab))

In [17]:
len(input_numerical_sentences)

2564

In [18]:
# Create training Sequence
training_sequences = []
for sentence in input_numerical_sentences:

  for i in range(1, len(sentence)):
    training_sequences.append(sentence[:i+1])
training_sequences

[[1, 4],
 [1, 4, 5],
 [1, 4, 5, 6],
 [1, 4, 5, 6, 7],
 [1, 4, 5, 6, 7, 8],
 [1, 4, 5, 6, 7, 8, 9],
 [1, 4, 5, 6, 7, 8, 9, 4],
 [1, 4, 5, 6, 7, 8, 9, 4, 10],
 [1, 4, 5, 6, 7, 8, 9, 4, 10, 11],
 [1, 4, 5, 6, 7, 8, 9, 4, 10, 11, 12],
 [1, 4, 5, 6, 7, 8, 9, 4, 10, 11, 12, 13],
 [1, 4, 5, 6, 7, 8, 9, 4, 10, 11, 12, 13, 14],
 [1, 4, 5, 6, 7, 8, 9, 4, 10, 11, 12, 13, 14, 15],
 [1, 4, 5, 6, 7, 8, 9, 4, 10, 11, 12, 13, 14, 15, 16],
 [1, 4, 5, 6, 7, 8, 9, 4, 10, 11, 12, 13, 14, 15, 16, 17],
 [1, 4, 5, 6, 7, 8, 9, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18],
 [1, 4, 5, 6, 7, 8, 9, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4],
 [1, 4, 5, 6, 7, 8, 9, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 19],
 [1, 4, 5, 6, 7, 8, 9, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 19, 20],
 [1, 4, 5, 6, 7, 8, 9, 4, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 19, 20, 4],
 [1,
  4,
  5,
  6,
  7,
  8,
  9,
  4,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  4,
  19,
  20,
  4,
  21],
 [1,
  4,
  5,
  6,
  7,
  8,
  9,
  4,
 

In [19]:
len(training_sequences)

38458

In [20]:
training_sequences[:5]

[[1, 4], [1, 4, 5], [1, 4, 5, 6], [1, 4, 5, 6, 7], [1, 4, 5, 6, 7, 8]]

In [21]:
# we have 942 sequence but all are having different size we want same size of sequence for model training
# so we find max len of sentence and using padding make all sentences same len
len_list = []

for sequence in training_sequences:
  len_list.append(len(sequence))

max(len_list)

89

In [22]:
padded_training_sequence = []
for sequence in training_sequences:
  padded_training_sequence.append([0]*(max(len_list) - len(sequence)) + sequence)

In [23]:
len(padded_training_sequence[8]) # now any sequence is of same length

89

In [24]:
# Now for training LSTM model we need to convert 2d vector into 3d Tensor
# for that we will do

padded_training_sequence = torch.tensor(padded_training_sequence, dtype=torch.long)

In [25]:
padded_training_sequence # Now our data is ready for training

tensor([[   0,    0,    0,  ...,    0,    1,    4],
        [   0,    0,    0,  ...,    1,    4,    5],
        [   0,    0,    0,  ...,    4,    5,    6],
        ...,
        [   0,    0,    0,  ...,  114,  228, 4777],
        [   0,    0,    0,  ...,  228, 4777,   23],
        [   0,    0,    0,  ..., 4777,   23,    2]])

In [26]:
# Split data into x&y

x = padded_training_sequence[:, :-1]
y = padded_training_sequence[:, -1]

In [27]:
x

tensor([[   0,    0,    0,  ...,    0,    0,    1],
        [   0,    0,    0,  ...,    0,    1,    4],
        [   0,    0,    0,  ...,    1,    4,    5],
        ...,
        [   0,    0,    0,  ..., 2300,  114,  228],
        [   0,    0,    0,  ...,  114,  228, 4777],
        [   0,    0,    0,  ...,  228, 4777,   23]])

In [28]:
y

tensor([   4,    5,    6,  ..., 4777,   23,    2])

In [29]:
# Now Create Dataset & DataLoader Class

class CustomDataset(Dataset):
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __len__(self):
    return self.x.shape[0]

  def __getitem__(self, idx):
    return self.x[idx], self.y[idx]

In [30]:
dataset = CustomDataset(x,y)

In [31]:
len(dataset)

38458

In [32]:
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [33]:
# Create LSTM Architecture

class LSTMModel(nn.Module):

  def __init__(self, vocab_size):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, 100)
    self.lstm = nn.LSTM(100, 150, batch_first=True)
    self.fc = nn.Linear(150, vocab_size)

  def forward(self, x):
    embedded = self.embedding(x)
    intermediate_hidden_states, (final_hidden_state, final_cell_state) = self.lstm(embedded)
    output = self.fc(final_hidden_state.squeeze(0))
    return output

In [34]:
model = LSTMModel(len(vocab))

In [35]:
# Load Model on GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

LSTMModel(
  (embedding): Embedding(5061, 100)
  (lstm): LSTM(100, 150, batch_first=True)
  (fc): Linear(in_features=150, out_features=5061, bias=True)
)

In [36]:
# Create loss & optimizer Function
epochs = 2
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [37]:
# Create training Loop

for epoch in range(epochs):
  total_loss = 0
  for batch_x, batch_y in dataloader:
    batch_x = batch_x.to(device)
    batch_y = batch_y.to(device)
    optimizer.zero_grad()
    output = model(batch_x)
    loss = loss_fn(output, batch_y)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
  print(f"Epoch: {epoch+1}/{epochs} Loss: {loss.item()}")

Epoch: 1/2 Loss: 4.823441982269287
Epoch: 2/2 Loss: 5.0960307121276855


In [42]:
import torch
import torch.nn.functional as F
import time

def generate_stream(model, vocab, id_to_word, input_text, max_length=50, temperature=1.0):
    model.eval()

    words = ["<SOS>"] + input_text.lower().split()

    print("Bot:", end=" ", flush=True)

    for _ in range(max_length):
        input_ids = [vocab.get(word, vocab["<UNK>"]) for word in words]

        input_tensor = torch.tensor(input_ids).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(input_tensor)

        logits = output[0] / temperature
        probs = F.softmax(logits, dim=0)

        next_word_id = torch.multinomial(probs, 1).item()
        next_word = id_to_word[next_word_id]

        if next_word in ["<EOS>", "<PAD>", "<SOS>"]:
            break

        words.append(next_word)

        print(next_word, end=" ", flush=True)
        time.sleep(0.15)

    print()

In [43]:
print("Type 'exit' to end the chat.\n")

while True:
    user_input = input("You: ")

    if user_input.lower() == "exit":
        print("Bot: Goodbye!")
        break

    generate_stream(model, vocab, {v: k for k, v in vocab.items()}, user_input)

Type 'exit' to end the chat.

You: hi
Bot: emergency colorful eliminates . 
You: how is your day
Bot: ? 
You: hello
Bot: does the future went . 
You: does the future went
Bot: his programs kill cell wind but twister . ] 
You: twister
Bot: this once ' 
You: once
Bot: ross : well , maybe holding 
You: yeah
Bot: instead encoding ? 
You: what is encoding
Bot: , mr. , with the paper contributing jewelry okay . 
You: who are you
Bot: . 
You: dont know
Bot: 's impartial , and ross . 
You: exit
Bot: Goodbye!
