In [2]:
!pip install nltk



In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from nltk.tokenize import word_tokenize, sent_tokenize
import nltk
import re

In [4]:
# Tokenization
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [5]:
!wget https://s3.amazonaws.com/fast-ai-nlp/wikitext-2.tgz

--2026-01-02 08:08:13--  https://s3.amazonaws.com/fast-ai-nlp/wikitext-2.tgz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.199.168, 16.15.187.65, 52.217.230.184, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.199.168|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4070055 (3.9M) [application/x-tar]
Saving to: ‘wikitext-2.tgz’


2026-01-02 08:08:13 (19.7 MB/s) - ‘wikitext-2.tgz’ saved [4070055/4070055]



In [6]:
def load_data(filepath):

    f=open(filepath)
    return f.read()

In [7]:
!tar -xzvf "/content/wikitext-2.tgz" -C "/content/"

wikitext-2/
wikitext-2/train.csv
wikitext-2/test.csv


In [8]:
train=load_data("/content/wikitext-2/train.csv")
data = train[:]

In [9]:
def clean_text(data):

    # Remove all brackets
    data = re.sub(r"[()]", "", data)

    # Remove Wikipedia-style headings (e.g. = Heading =, == Heading ==)
    data = re.sub(r"=+.*?=+", "", data)

    # Remove <unk> tokens
    data = re.sub(r"<unk>", "", data)

    # Replace hyphens with space
    data = re.sub(r"-", " ", data)

    # Remove all non-alphanumeric characters
    # Retain period (.) and apostrophe (')
    data = re.sub(r"[^\w\s\.\']", "", data)

    # Normalize whitespace
    data = re.sub(r"[ \t]+", " ", data)   # normalize spaces only
    data = re.sub(r"\n+", "\n", data)     # normalize newlines

    return data.strip()


In [10]:
data = clean_text(data)

In [11]:
print(data[:300])
type(data)

The 2013 14 season was the season of competitive association football and 77th season in the Football League played by York City Football Club a professional football club based in York North Yorkshire England . Their 17th place finish in 2012 13 meant it was their second consecutive season in Leagu


str

In [12]:
#to split our raw txt into sentences and words, by which we can form a training sequence
def split_data(data, num_sentences=-1):
    """Splits text data into words and sentences """
    #Sentence tokenization
    if num_sentences==-1:
        sentences=sent_tokenize(data)
    else:
        sentences=sent_tokenize(data)[:num_sentences]

    #Word tokenization
    words=set()
    for sent in sentences:
        for word in sent.split():
            words.add(word)
    words=list(words)
    return sentences, words

In [13]:
input_sentences, tokens = split_data(data, num_sentences=5000) #less data to train

In [14]:
input_sentences[0]

'The 2013 14 season was the season of competitive association football and 77th season in the Football League played by York City Football Club a professional football club based in York North Yorkshire England .'

In [15]:
# build vocab
vocab = {'<pad>': 0, '<unk>': 1}

for token in Counter(tokens).keys():
  if token not in vocab:
    vocab[token] = len(vocab)

vocab

{'<pad>': 0,
 '<unk>': 1,
 'North': 2,
 'dismantled': 3,
 'signing': 4,
 'conflicting': 5,
 'supposed': 6,
 'bite': 7,
 'with': 8,
 'junior': 9,
 'Hendrix': 10,
 'Gulf': 11,
 'meet': 12,
 'Point': 13,
 'archaeological': 14,
 'launching': 15,
 'Fish': 16,
 'Being': 17,
 'knife': 18,
 'crewmen': 19,
 'knight': 20,
 '73rd': 21,
 'consecutive': 22,
 'Nanak': 23,
 'funnels': 24,
 'rated': 25,
 'resistance': 26,
 'Landmarks': 27,
 'Thus': 28,
 'examination': 29,
 'zebra': 30,
 'announce': 31,
 'corpse': 32,
 'myth': 33,
 'publicity': 34,
 'manor': 35,
 'sides': 36,
 'stick': 37,
 'separated': 38,
 'elongated': 39,
 '1980s': 40,
 'conference': 41,
 'New': 42,
 'repel': 43,
 'terminal': 44,
 'erected': 45,
 'Conversely': 46,
 'demonstration': 47,
 'neighboring': 48,
 'nine': 49,
 'defence': 50,
 'indeed': 51,
 'Army': 52,
 'openings': 53,
 'Valentin': 54,
 'plenty': 55,
 'Reproduction': 56,
 'morel': 57,
 'Rare': 58,
 'moor': 59,
 'accused': 60,
 'millimetre': 61,
 'Ironically': 62,
 'tears': 

In [16]:
len(vocab)

12475

In [17]:
def text_to_indices(sentence, vocab):

  numerical_sentence = []

  for token in sentence:
    if token in vocab:
      numerical_sentence.append(vocab[token])
    else:
      numerical_sentence.append(vocab['<unk>'])

  return numerical_sentence

In [18]:
input_numerical_sentences = []

for sentence in input_sentences:
  input_numerical_sentences.append(text_to_indices(word_tokenize(sentence.lower()), vocab))

In [19]:
len(input_numerical_sentences)
input_numerical_sentences

[[11102,
  11490,
  7019,
  6574,
  5356,
  11102,
  6574,
  1871,
  9045,
  4443,
  9950,
  5640,
  3514,
  6574,
  10207,
  11102,
  9950,
  11805,
  8074,
  1731,
  1,
  2641,
  9950,
  3129,
  5455,
  7199,
  9950,
  3129,
  11578,
  10207,
  1,
  980,
  1,
  1,
  3337],
 [2789,
  2721,
  1310,
  1302,
  10207,
  8765,
  8136,
  6162,
  379,
  5356,
  2789,
  2137,
  22,
  6574,
  10207,
  11805,
  7028,
  3337],
 [11102, 6574, 1891, 2204, 3499, 1, 11490, 11127, 7390, 1, 1970, 3337],
 [1,
  1,
  10157,
  1301,
  5779,
  1771,
  6574,
  184,
  1,
  9119,
  10580,
  3138,
  527,
  11254,
  10330,
  3337],
 [1731,
  11102,
  1956,
  1871,
  11102,
  10333,
  1,
  4712,
  10240,
  6829,
  11102,
  772,
  492,
  2680,
  1015,
  9285,
  5587,
  5455,
  5464,
  12392,
  7029,
  3481,
  2306,
  11102,
  8465,
  1302,
  10207,
  2349,
  1310,
  10207,
  11102,
  10285,
  8465,
  11490,
  7019,
  9950,
  11805,
  7028,
  3337],
 [4317,
  6162,
  1,
  4263,
  12016,
  11102,
  4251,
  10045,


In [20]:
training_sequence = []
for sentence in input_numerical_sentences:
  for i in range(1, len(sentence)):
    training_sequence.append(sentence[:i+1])

In [21]:
len(training_sequence)

107307

In [22]:
training_sequence[:5]

[[11102, 11490],
 [11102, 11490, 7019],
 [11102, 11490, 7019, 6574],
 [11102, 11490, 7019, 6574, 5356],
 [11102, 11490, 7019, 6574, 5356, 11102]]

In [23]:
len_list = []

for sequence in training_sequence:
  len_list.append(len(sequence))

max(len_list)

91

In [24]:
class CustomDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences
    def __len__(self):
        return len(self.sequences)
    def __getitem__(self, idx):
        return torch.tensor(self.sequences[idx], dtype=torch.long)

In [25]:
def collate_fn(batch):

    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    X = padded[:, :-1]
    y = padded[:, -1]
    return X, y

In [26]:
dataset = CustomDataset(training_sequence)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

In [27]:
class LSTMModel(nn.Module):

  def __init__(self, vocab_size):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, 100, padding_idx=0)
    self.lstm = nn.LSTM(100, 200, batch_first=True)
    self.fc = nn.Linear(200, vocab_size)

  def forward(self, x):
    embedded = self.embedding(x)
    intermediate_hidden_states, (final_hidden_state, final_cell_state) = self.lstm(embedded)
    output = self.fc(final_hidden_state.squeeze(0))
    return output

In [28]:
model = LSTMModel(len(vocab))

In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [30]:
model.to(device)

LSTMModel(
  (embedding): Embedding(12475, 100, padding_idx=0)
  (lstm): LSTM(100, 200, batch_first=True)
  (fc): Linear(in_features=200, out_features=12475, bias=True)
)

In [31]:
epochs = 20
learning_rate = 0.0015

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [32]:
# training loop

for epoch in range(epochs):
  total_loss = 0

  for batch_x, batch_y in dataloader:

    batch_x, batch_y = batch_x.to(device), batch_y.to(device)

    optimizer.zero_grad()

    output = model(batch_x)

    loss = criterion(output, batch_y)

    loss.backward()

    optimizer.step()

    total_loss = total_loss + loss.item()

  total_loss = total_loss/len(dataloader)
  print(f"Epoch: {epoch + 1}, Loss: {total_loss:.4f}")

Epoch: 1, Loss: 0.2880
Epoch: 2, Loss: 0.2048
Epoch: 3, Loss: 0.1878
Epoch: 4, Loss: 0.1623
Epoch: 5, Loss: 0.1477
Epoch: 6, Loss: 0.1353
Epoch: 7, Loss: 0.1267
Epoch: 8, Loss: 0.1195
Epoch: 9, Loss: 0.1113
Epoch: 10, Loss: 0.1069
Epoch: 11, Loss: 0.1038
Epoch: 12, Loss: 0.0962
Epoch: 13, Loss: 0.0923
Epoch: 14, Loss: 0.0925
Epoch: 15, Loss: 0.0960
Epoch: 16, Loss: 0.0879
Epoch: 17, Loss: 0.0844
Epoch: 18, Loss: 0.0828
Epoch: 19, Loss: 0.0804
Epoch: 20, Loss: 0.0811


In [33]:
def prediction(model, vocab, text):

  # tokenize
  tokenized_text = text.lower().split()

  # text to numerical indices
  numerical_text = text_to_indices(tokenized_text, vocab)
  numerical_text = torch.tensor(numerical_text, dtype=torch.long).unsqueeze(0)

  # send to model
  output = model(numerical_text)

  # forbid <pad> and <unk>
  output[0, vocab['<pad>']] = -1e9
  output[0, vocab['<unk>']] = -1e9

  # predicted index
  value, index = torch.max(output, dim=1)

  # index to word
  idx_to_word = {i: w for w, i in vocab.items()}

  # merge with text
  return text + " " + idx_to_word[index.item()]


In [34]:
import time
text = "the season was"
for i in range(10):
  output_text = prediction(model, vocab, text)
  print(output_text)
  text = output_text
  time.sleep(0.5)

the season was built
the season was built the
the season was built the navy
the season was built the navy .
the season was built the navy . in
the season was built the navy . in the
the season was built the navy . in the country
the season was built the navy . in the country .
the season was built the navy . in the country . of
the season was built the navy . in the country . of the


In [35]:
# Function to calculate accuracy
def calculate_accuracy(model, dataloader, device):
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            # Get model predictions
            outputs = model(batch_x)

            # Get the predicted word indices
            _, predicted = torch.max(outputs, dim=1)

            # Compare with actual labels
            correct += (predicted == batch_y).sum().item()
            total += batch_y.size(0)

    accuracy = correct / total * 100
    return accuracy

# Compute accuracy
accuracy = calculate_accuracy(model, dataloader, device)
print(f"Model Accuracy: {accuracy:.2f}%")


Model Accuracy: 98.79%
