In [33]:
import random
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader


In [35]:
# Load the "wikitext-2-raw-v1" configuration of the "wikitext" dataset
# through `datasets.load_dataset`.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Pull the "text" column out of each split; later cells tokenise these
# and build the vocabulary / datasets from them.
train_texts = dataset["train"]["text"]
valid_texts = dataset["validation"]["text"]
test_texts = dataset["test"]["text"]


In [37]:
def preprocess_texts(texts):
    """Lowercase and whitespace-tokenise a collection of strings.

    Args:
        texts: Iterable of strings.

    Returns:
        A list with one token list per input string, in input order.
        Strings that are empty or whitespace-only are dropped.
    """
    return [entry.lower().split() for entry in texts if entry.strip()]

# Tokenise every split.
# NOTE: each name is rebound from a list of raw strings to a list of token
# lists, so this cell is not idempotent — running it a second time would call
# `.strip()` on a list inside preprocess_texts and fail. Re-run the loading
# cell first if this cell has to be executed again.
train_texts = preprocess_texts(train_texts)
valid_texts = preprocess_texts(valid_texts)
test_texts = preprocess_texts(test_texts)


In [39]:
def build_vocab(texts, max_vocab_size=20000):
    """Build a word -> index mapping from tokenised texts.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<UNK>``,
    ``<START>`` and ``<END>``. The remaining slots are filled with the most
    frequent words, in descending frequency; equally frequent words keep the
    order in which they were first seen.

    Args:
        texts: Iterable of token lists.
        max_vocab_size: Upper bound on the vocabulary size, counting the four
            reserved tokens. Values below 4 yield only the reserved tokens.

    Returns:
        dict mapping each token to a unique index in ``range(len(vocab))``.
    """
    vocab = {'<PAD>': 0, '<UNK>': 1, '<START>': 2, '<END>': 3}

    word_counts = Counter(word for text in texts for word in text)
    # sorted() is stable, so ties stay in first-seen (insertion) order.
    ranked = sorted(word_counts.items(), key=lambda item: item[1], reverse=True)

    # A corpus word that spells a reserved token must not overwrite that
    # token's index (doing so would also hand the next word a duplicate index).
    candidates = [word for word, _ in ranked if word not in vocab]

    # Clamp at zero: a negative bound would be read as "all but the last N".
    budget = max(max_vocab_size - len(vocab), 0)
    for word in candidates[:budget]:
        vocab[word] = len(vocab)

    return vocab

# Build the word -> index vocabulary used by the datasets and the model.
# NOTE(review): the validation tokens are concatenated with the training
# tokens here, so validation text influences which words enter the
# vocabulary — confirm this is intended rather than using the training split only.
vocab = build_vocab(train_texts + valid_texts)


In [41]:
class WikiTextDataset(Dataset):
    """Sliding-window next-word dataset over tokenised texts.

    Every text contributes one sample per window of ``seq_length`` consecutive
    tokens that is followed by at least one more token: the window is the
    input and the following token is the target. Texts with ``seq_length`` or
    fewer tokens contribute nothing. Windows never span two texts.

    Attributes:
        sequences: list of index lists, each of length ``seq_length``.
        targets: list of target indices, aligned with ``sequences``.
    """

    def __init__(self, texts, seq_length, vocab):
        """Encode ``texts`` and enumerate all (window, next word) pairs.

        Args:
            texts: Iterable of token lists.
            seq_length: Number of tokens per input window; must be >= 1.
            vocab: Mapping from token to index; must contain ``'<UNK>'``,
                which is used for tokens missing from the mapping.

        Raises:
            ValueError: If ``seq_length`` is smaller than 1. A zero or negative
                length would otherwise produce empty inputs or targets picked
                through negative indexing.
            KeyError: If ``vocab`` has no ``'<UNK>'`` entry.
        """
        if seq_length < 1:
            raise ValueError(f"seq_length must be a positive integer, got {seq_length}")

        unk_index = vocab['<UNK>']
        self.sequences = []
        self.targets = []

        for text in texts:
            # Encode once per text instead of re-looking-up every word for
            # each of the windows it appears in.
            encoded = [vocab.get(word, unk_index) for word in text]
            for start in range(len(encoded) - seq_length):
                end = start + seq_length
                self.sequences.append(encoded[start:end])
                self.targets.append(encoded[end])

    def __len__(self):
        """Return the number of (window, target) samples."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a pair of int64 tensors (window, target)."""
        return (
            torch.tensor(self.sequences[idx], dtype=torch.long),
            torch.tensor(self.targets[idx], dtype=torch.long),
        )

seq_length = 5  # Number of words in each input window; the word right after the window is the target
train_dataset = WikiTextDataset(train_texts, seq_length, vocab)
valid_dataset = WikiTextDataset(valid_texts, seq_length, vocab)

# Batches of 32 samples. Only the training loader shuffles; the validation
# loader keeps the DataLoader default and reads samples in dataset order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32)


In [45]:
def train_model(model, train_loader, valid_loader, criterion, optimizer, device, epochs=5):
    """Train ``model`` and print the mean training/validation loss per epoch.

    Each epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free evaluation pass over ``valid_loader``. The reported losses
    are the mean of the per-batch loss values. A loader that yields no batches
    is reported as ``nan`` instead of raising ``ZeroDivisionError``.

    Args:
        model: Module called as ``model(sequences)``; its output is passed to
            ``criterion`` together with the targets.
        train_loader: Iterable of ``(sequences, targets)`` tensor batches.
        valid_loader: Iterable of ``(sequences, targets)`` tensor batches.
        criterion: Loss callable ``criterion(predictions, targets)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the model and every batch are moved to.
        epochs: Number of passes over ``train_loader``.

    Returns:
        None. The model is updated in place and left in eval mode.
    """
    model.to(device)

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0.0
        train_batches = 0
        print(f"\n🚀 Starting Epoch {epoch+1}/{epochs}")
        for sequences, targets in train_loader:
            sequences, targets = sequences.to(device), targets.to(device)

            optimizer.zero_grad()
            predictions = model(sequences)
            loss = criterion(predictions, targets)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            train_batches += 1

        # Validation step: no gradients, no parameter updates.
        model.eval()
        total_val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for sequences, targets in valid_loader:
                sequences, targets = sequences.to(device), targets.to(device)
                predictions = model(sequences)
                loss = criterion(predictions, targets)
                total_val_loss += loss.item()
                val_batches += 1

        # Count the batches actually seen so an empty loader cannot cause a
        # division by zero after a whole epoch of work.
        avg_train_loss = total_train_loss / train_batches if train_batches else float("nan")
        avg_val_loss = total_val_loss / val_batches if val_batches else float("nan")

        print(f'Epoch {epoch+1}/{epochs}: Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}')


In [47]:
# Seed every RNG in use so that weight initialisation and the shuffled
# training order are repeatable across "Restart & Run All". GPU kernels may
# still introduce small run-to-run differences.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model hyperparameters
vocab_size = len(vocab)
embedding_dim = 100
hidden_dim = 256
num_layers = 2

# Initialize model
# NOTE(review): LSTMLanguageModel is not defined in any of the cells shown
# here — confirm the cell that defines it exists and runs before this one.
model = LSTMLanguageModel(vocab_size, embedding_dim, hidden_dim, num_layers).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
train_model(model, train_loader, valid_loader, criterion, optimizer, device, epochs=5)


Epoch 1/5: Training Loss: 5.9384, Validation Loss: 5.9379
Epoch 2/5: Training Loss: 5.6177, Validation Loss: 5.8742
Epoch 3/5: Training Loss: 5.4888, Validation Loss: 5.8345
Epoch 4/5: Training Loss: 5.3972, Validation Loss: 5.8200
Epoch 5/5: Training Loss: 5.3327, Validation Loss: 5.8332


In [48]:
def predict_next_word(model, sequence, vocab, device, top_k=5):
    """Return the ``top_k`` most likely next words for ``sequence``.

    Args:
        model: Module called with a ``(1, len(sequence))`` int64 tensor.
            Assumed to return scores whose last dimension indexes the
            vocabulary and whose first dimension is the batch — TODO confirm
            against the model definition.
        sequence: List of word strings; words missing from ``vocab`` are
            mapped to ``'<UNK>'``.
        vocab: Mapping from word to index; must contain ``'<UNK>'``.
        device: Device the input tensor is moved to.
        top_k: Number of candidates to return. Clamped to the size of the
            model's last output dimension, since ``torch.topk`` rejects a
            larger ``k``.

    Returns:
        List of words ordered from highest to lowest score.

    Raises:
        ValueError: If the model predicts an index that has no word in ``vocab``.
    """
    model.eval()

    unk_index = vocab['<UNK>']
    seq_indices = [vocab.get(word, unk_index) for word in sequence]
    tensor = torch.LongTensor(seq_indices).unsqueeze(0).to(device)

    with torch.no_grad():
        predictions = model(tensor)

    k = min(top_k, predictions.size(-1))
    # .tolist() yields plain ints, so the lookups below are ordinary dict hits.
    top_k_indices = torch.topk(predictions, k).indices[0].tolist()

    # Invert the vocabulary once instead of rebuilding key/value lists and
    # scanning them for every predicted word. setdefault keeps the first word
    # seen for an index, matching a first-match linear search.
    index_to_word = {}
    for word, index in vocab.items():
        index_to_word.setdefault(index, word)

    unknown = [index for index in top_k_indices if index not in index_to_word]
    if unknown:
        raise ValueError(f"model predicted indices missing from vocab: {unknown}")

    top_k_words = [index_to_word[index] for index in top_k_indices]

    return top_k_words

# Example predictions: two hand-written five-word prompts, already
# lowercased and tokenised the same way preprocess_texts tokenises the corpus.
test_sequences = [["the", "king", "of", "the", "jungle"], ["deep", "learning", "is", "a", "subset"]]

# Print each prompt followed by the words predict_next_word returns for it
# (top_k is left at its default).
for sequence in test_sequences:
    predictions = predict_next_word(model, sequence, vocab, device)
    print(f"Input: {' '.join(sequence)}")
    print(f"Predicted words: {predictions}\n")


Input: the king of the jungle
Predicted words: [',', '.', 'of', 'and', "'s"]

Input: deep learning is a subset
Predicted words: [',', 'of', '.', 'and', '@-@']

