# Next word prediction using LSTM + Attention(Trial 1)

In [12]:
import re

In [13]:
def clean_text(text):
    text = text.lower()
    text = re.sub(r'["“”]', '"', text)
    text = re.sub(r"[’‘']", "'", text)
    text = re.sub(r'[—–]', '-', text)
    text = re.sub(r'\n+', ' ', text)
    text = re.sub(r'\t', ' ', text)
    text = re.sub(r' +', ' ', text)
    text = re.sub(r'[^\w\s\.\,\!\?\;\:\'\"\-\(\)]', '', text)
    return text.strip()

with open('sherlock_holmes.txt', 'r', encoding='utf-8') as f:
    raw_text = f.read()

start_idx = raw_text.find("*** START OF")
end_idx = raw_text.find("*** END OF")
text = raw_text[start_idx:end_idx] if start_idx != -1 and end_idx != -1 else raw_text
text = clean_text(text)

## Download the dataset

In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import requests
import os
import re
from collections import Counter
from tqdm import tqdm

In [None]:
def download_sherlock_holmes():
    url = "https://www.gutenberg.org/files/1661/1661-0.txt"
    filename = "sherlock_holmes.txt" # store it for easy access

    if not os.path.exists(filename):
        print("Downloading Sherlock Holmes text...")
        r = requests.get(url, timeout=30)
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(r.text)
        print("Downloaded and saved as", filename)
    else:
        print("Using cached file:", filename)

In [32]:
def clean_text(text):
    text = text.lower()
    text = re.sub(r'["“”]', '"', text)
    text = re.sub(r"[’‘']", "'", text)
    text = re.sub(r'[—–]', '-', text)
    text = re.sub(r'\n+', ' ', text)
    text = re.sub(r'\t', ' ', text)
    text = re.sub(r' +', ' ', text)
    text = re.sub(r'[^\w\s\.\,\!\?\;\:\'\"\-\(\)]', '', text)
    return text.strip()

download_sherlock_holmes()

with open('sherlock_holmes.txt', 'r', encoding='utf-8') as f:
    raw_text = f.read()

start_idx = raw_text.find("*** START OF")
end_idx = raw_text.find("*** END OF")
text = raw_text[start_idx:end_idx] if start_idx != -1 and end_idx != -1 else raw_text
text = clean_text(text)


# Add special tokens for sequence control
START_TOKEN = '<START>'
EOS_TOKEN = '<EOS>'

def proper_sentence_split(text):
    """Better sentence splitting that handles punctuation correctly"""
    # Split on sentence-ending punctuation, but keep the punctuation
    sentences = re.split(r'([.!?]+)', text)

    # Recombine sentences with their punctuation
    result = []
    for i in range(0, len(sentences) - 1, 2):
        sentence = sentences[i].strip()
        punct = sentences[i + 1] if i + 1 < len(sentences) else ''
        if sentence:  # Only add non-empty sentences
            result.append(sentence + punct)

    return result

def create_sliding_window_sequences(text, window_size=10):
    """Create overlapping sequences using sliding window approach"""
    tokens = text.split()
    sequences = []

    # Create overlapping sequences across the entire text
    for i in range(len(tokens) - window_size + 1):
        sequence = tokens[i:i + window_size]
        sequences.append(sequence)

    return sequences

print("Creating vocabulary from tokens...")
tokens = text.split()
vocab = ['<PAD>', '<UNK>', START_TOKEN, EOS_TOKEN] + sorted(set(tokens))
word2idx = {word: i for i, word in enumerate(vocab)}
idx2word = {i: word for word, i in word2idx.items()}

# Get indices for special tokens
start_idx = word2idx[START_TOKEN]
eos_idx = word2idx[EOS_TOKEN]
print(f"START token index: {start_idx}, EOS token index: {eos_idx}")

print("Generating sequences with cross-sentence context...")

# Method 1: Sliding window across entire text (for cross-sentence learning)
sliding_sequences = create_sliding_window_sequences(text, window_size=15)

# Method 2: Proper sentence-based sequences with START/EOS tokens
sentences = proper_sentence_split(text)
sentence_sequences = []

for sentence in sentences:
    sentence_words = sentence.strip().split()
    if len(sentence_words) >= 2:  # Only process sentences with at least 2 words
        # Add START and EOS tokens
        tokenized = [START_TOKEN] + sentence_words + [EOS_TOKEN]
        # Create progressive sequences
        for i in range(2, len(tokenized) + 1):
            sentence_sequences.append(tokenized[:i])

print(f"Generated {len(sliding_sequences):,} sliding window sequences")
print(f"Generated {len(sentence_sequences):,} sentence-based sequences")

# Combine both approaches for richer training data
all_word_sequences = sliding_sequences + sentence_sequences

# Convert to indices
sequences = []
for seq in all_word_sequences:
    tokenized = [word2idx.get(w, word2idx['<UNK>']) for w in seq]
    sequences.append(tokenized)

print(f"Total sequences: {len(sequences):,}")
print("Sample sequences (first 5):")
for i, seq in enumerate(sequences[:5]):
    words = [idx2word[idx] for idx in seq]
    print(f"  {i+1}: {' '.join(words)}")

max_seq_len = min(50, max(len(seq) for seq in sequences))
sequences = [([0] * (max_seq_len - len(seq)) + seq)[-max_seq_len:] for seq in sequences]

sequences = np.array(sequences)
X, y = sequences[:, :-1], sequences[:, -1]
y = torch.tensor(y).long()

X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size=0.15, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.1765, random_state=42)



class TextDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X).long()
        self.y = torch.tensor(y).long() if not isinstance(y, torch.Tensor) else y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_loader = DataLoader(TextDataset(X_train, y_train), batch_size=64, shuffle=True)
val_loader = DataLoader(TextDataset(X_val, y_val), batch_size=64)
test_loader = DataLoader(TextDataset(X_test, y_test), batch_size=64)


class AttentionLayer(nn.Module):
    def __init__(self, hidden_dim, attention_units):
        super(AttentionLayer, self).__init__()
        self.attention_dense = nn.Linear(hidden_dim, attention_units)
        self.context_vector = nn.Linear(attention_units, 1, bias=False)

    def forward(self, lstm_out):
        score = torch.tanh(self.attention_dense(lstm_out))
        attention_weights = F.softmax(self.context_vector(score), dim=1)
        context_vector = attention_weights * lstm_out
        context_vector = torch.sum(context_vector, dim=1)
        return context_vector, attention_weights.squeeze(-1)

class LSTMAttentionModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=100, lstm_units=100, attention_units=128):
        super(LSTMAttentionModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True, dropout=0.2)
        self.attention = AttentionLayer(lstm_units, attention_units)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(lstm_units, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        lstm_out, _ = self.lstm(x)
        context, attention_weights = self.attention(lstm_out)
        out = self.dropout(context)
        return self.fc(out), attention_weights



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTMAttentionModel(len(vocab)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_losses, val_losses, train_accuracies, val_accuracies = [], [], [], []

for epoch in range(100):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs, _ = model(inputs)
        loss = criterion(outputs, labels)  # labels are now integer indices, not one-hot
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * inputs.size(0)
        correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
        total += inputs.size(0)

    train_loss = total_loss / total
    train_acc = correct / total
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)

    # Validation
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)  # labels are now integer indices, not one-hot
            val_loss += loss.item() * inputs.size(0)
            val_correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
            val_total += inputs.size(0)

    val_loss /= val_total
    val_acc = val_correct / val_total
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(f"Epoch {epoch+1:2d} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")


def calculate_perplexity(model, data_loader, device):
    """Calculate perplexity on a dataset"""
    model.eval()
    total_loss = 0
    total_samples = 0
    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * inputs.size(0)
            total_samples += inputs.size(0)

    avg_loss = total_loss / total_samples
    perplexity = np.exp(avg_loss)
    return perplexity

def generate_text(model, start_text, max_length=50, temperature=1.0):
    """Generate text using the trained model, stopping at EOS token"""
    model.eval()

    # Tokenize start text
    tokens = start_text.lower().split()
    sequence = [start_idx]  # Start with START token

    # Add start text tokens
    for token in tokens:
        if token in word2idx:
            sequence.append(word2idx[token])
        else:
            sequence.append(word2idx['<UNK>'])

    generated_tokens = tokens.copy()

    with torch.no_grad():
        for _ in range(max_length):
            # Prepare input sequence
            input_seq = torch.tensor([sequence[-max_seq_len+1:]]).to(device)
            if len(sequence) < max_seq_len - 1:
                # Pad if sequence is shorter than expected
                padding = [0] * (max_seq_len - 1 - len(sequence))
                input_seq = torch.tensor([padding + sequence]).to(device)

            # Get prediction
            output, _ = model(input_seq)

            # Apply temperature sampling
            logits = output[0] / temperature
            probabilities = torch.softmax(logits, dim=0)

            # Sample next token
            next_token_idx = torch.multinomial(probabilities, 1).item()

            # Stop if EOS token is generated
            if next_token_idx == eos_idx:
                break

            # Add token to sequence
            sequence.append(next_token_idx)

            # Convert to word and add to generated text
            if next_token_idx in idx2word:
                word = idx2word[next_token_idx]
                if word not in ['<PAD>', '<UNK>', START_TOKEN]:
                    generated_tokens.append(word)

    return ' '.join(generated_tokens)

print("\n" + "="*50)
print("📊 MODEL EVALUATION")
print("="*50)

# Calculate perplexity for all datasets
train_perplexity = calculate_perplexity(model, train_loader, device)
val_perplexity = calculate_perplexity(model, val_loader, device)
test_perplexity = calculate_perplexity(model, test_loader, device)

print(f"🎯 PERPLEXITY SCORES:")
print(f"  📈 Training Perplexity: {train_perplexity:.2f}")
print(f"  📈 Validation Perplexity: {val_perplexity:.2f}")
print(f"  📈 Test Perplexity: {test_perplexity:.2f}")

print(f"\n🔮 TEXT GENERATION EXAMPLES:")
print("-" * 40)

# Test text generation
test_prompts = [
    "Holmes said",
    "Watson was",
    "The detective",
    "I saw",
    "It was a dark"
]

for prompt in test_prompts:
    generated = generate_text(model, prompt, max_length=20, temperature=0.8)
    print(f"Prompt: '{prompt}'")
    print(f"Generated: '{generated}'")
    print("-" * 40)

print(f"\n✅ SUMMARY:")
print(f"  🎯 Best Test Perplexity: {test_perplexity:.2f}")
print(f"  🔤 Vocabulary Size: {len(vocab):,}")
print(f"  🏁 START/EOS tokens: Implemented")
print(f"  📊 Label Encoding: Used (instead of one-hot)")
print("="*50)



✓ Using cached file: sherlock_holmes.txt
🔧 Creating vocabulary from tokens...
START token index: 2, EOS token index: 3
🔧 Generating sequences with cross-sentence context...
Generated 104,462 sliding window sequences
Generated 113,459 sentence-based sequences
Total sequences: 217,921
Sample sequences (first 5):
  1: start of the project gutenberg ebook the adventures of sherlock holmes the adventures of sherlock
  2: of the project gutenberg ebook the adventures of sherlock holmes the adventures of sherlock holmes
  3: the project gutenberg ebook the adventures of sherlock holmes the adventures of sherlock holmes by
  4: project gutenberg ebook the adventures of sherlock holmes the adventures of sherlock holmes by arthur
  5: gutenberg ebook the adventures of sherlock holmes the adventures of sherlock holmes by arthur conan




Epoch  1 | Train Loss: 6.7340 | Train Acc: 0.0794 | Val Loss: 6.3409 | Val Acc: 0.1089
Epoch  2 | Train Loss: 6.0277 | Train Acc: 0.1281 | Val Loss: 5.9546 | Val Acc: 0.1413
Epoch  3 | Train Loss: 5.5935 | Train Acc: 0.1515 | Val Loss: 5.7317 | Val Acc: 0.1543
Epoch  4 | Train Loss: 5.2343 | Train Acc: 0.1672 | Val Loss: 5.5857 | Val Acc: 0.1665
Epoch  5 | Train Loss: 4.9163 | Train Acc: 0.1814 | Val Loss: 5.4564 | Val Acc: 0.1754
Epoch  6 | Train Loss: 4.6337 | Train Acc: 0.1979 | Val Loss: 5.3674 | Val Acc: 0.1798
Epoch  7 | Train Loss: 4.3812 | Train Acc: 0.2162 | Val Loss: 5.2944 | Val Acc: 0.1866
Epoch  8 | Train Loss: 4.1515 | Train Acc: 0.2372 | Val Loss: 5.2139 | Val Acc: 0.1913
Epoch  9 | Train Loss: 3.9431 | Train Acc: 0.2582 | Val Loss: 5.1553 | Val Acc: 0.1977
Epoch 10 | Train Loss: 3.7612 | Train Acc: 0.2792 | Val Loss: 5.1113 | Val Acc: 0.2028
Epoch 11 | Train Loss: 3.5949 | Train Acc: 0.3000 | Val Loss: 5.0390 | Val Acc: 0.2102
Epoch 12 | Train Loss: 3.4452 | Train Acc: 

**As we can see from this experiment, even though initially the validation loss had began to lower, it soon started oscillating and ended up rising. This is mostly due to overfiting. The model maybe too complex to learn from such a small corpus.**

# Stacked LSTM with Attention

In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import requests
import os
import re
from collections import Counter
from tqdm import tqdm

# -----------------------------
# 1. Download & Preprocess Text
# -----------------------------

def download_sherlock_holmes():
    url = "https://www.gutenberg.org/files/1661/1661-0.txt"
    filename = "sherlock_holmes.txt"

    if not os.path.exists(filename):
        print("📥 Downloading Sherlock Holmes text...")
        r = requests.get(url, timeout=30)
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(r.text)
        print("✅ Downloaded and saved as", filename)
    else:
        print("✓ Using cached file:", filename)

def clean_text(text):
    text = text.lower()
    text = re.sub(r'["“”]', '"', text)
    text = re.sub(r"[’‘']", "'", text)
    text = re.sub(r'[—–]', '-', text)
    text = re.sub(r'\n+', ' ', text)
    text = re.sub(r'\t', ' ', text)
    text = re.sub(r' +', ' ', text)
    text = re.sub(r'[^\w\s\.\,\!\?\;\:\'\"\-\(\)]', '', text)
    return text.strip()

download_sherlock_holmes()

with open('sherlock_holmes.txt', 'r', encoding='utf-8') as f:
    raw_text = f.read()

start_idx = raw_text.find("*** START OF")
end_idx = raw_text.find("*** END OF")
text = raw_text[start_idx:end_idx] if start_idx != -1 and end_idx != -1 else raw_text
text = clean_text(text)

# -----------------------------
# 2. Tokenization & Sequence Creation
# -----------------------------

import re

# Add special tokens for sequence control
START_TOKEN = '<START>'
EOS_TOKEN = '<EOS>'

def proper_sentence_split(text):
    """Better sentence splitting that handles punctuation correctly"""
    # Split on sentence-ending punctuation, but keep the punctuation
    sentences = re.split(r'([.!?]+)', text)

    # Recombine sentences with their punctuation
    result = []
    for i in range(0, len(sentences) - 1, 2):
        sentence = sentences[i].strip()
        punct = sentences[i + 1] if i + 1 < len(sentences) else ''
        if sentence:  # Only add non-empty sentences
            result.append(sentence + punct)

    return result

def create_sliding_window_sequences(text, window_size=10):
    """Create overlapping sequences using sliding window approach"""
    tokens = text.split()
    sequences = []

    # Create overlapping sequences across the entire text
    for i in range(len(tokens) - window_size + 1):
        sequence = tokens[i:i + window_size]
        sequences.append(sequence)

    return sequences

print("🔧 Creating vocabulary from tokens...")
tokens = text.split()
vocab = ['<PAD>', '<UNK>', START_TOKEN, EOS_TOKEN] + sorted(set(tokens))
word2idx = {word: i for i, word in enumerate(vocab)}
idx2word = {i: word for word, i in word2idx.items()}

# Get indices for special tokens
start_idx = word2idx[START_TOKEN]
eos_idx = word2idx[EOS_TOKEN]
print(f"START token index: {start_idx}, EOS token index: {eos_idx}")

print("🔧 Generating sequences with cross-sentence context...")

# Method 1: Sliding window across entire text (for cross-sentence learning)
sliding_sequences = create_sliding_window_sequences(text, window_size=15)

# Method 2: Proper sentence-based sequences with START/EOS tokens
sentences = proper_sentence_split(text)
sentence_sequences = []

for sentence in sentences:
    sentence_words = sentence.strip().split()
    if len(sentence_words) >= 2:  # Only process sentences with at least 2 words
        # Add START and EOS tokens
        tokenized = [START_TOKEN] + sentence_words + [EOS_TOKEN]
        # Create progressive sequences
        for i in range(2, len(tokenized) + 1):
            sentence_sequences.append(tokenized[:i])

print(f"Generated {len(sliding_sequences):,} sliding window sequences")
print(f"Generated {len(sentence_sequences):,} sentence-based sequences")

# Combine both approaches for richer training data
all_word_sequences = sliding_sequences + sentence_sequences

# Convert to indices
sequences = []
for seq in all_word_sequences:
    tokenized = [word2idx.get(w, word2idx['<UNK>']) for w in seq]
    sequences.append(tokenized)

print(f"Total sequences: {len(sequences):,}")
print("Sample sequences (first 5):")
for i, seq in enumerate(sequences[:5]):
    words = [idx2word[idx] for idx in seq]
    print(f"  {i+1}: {' '.join(words)}")

max_seq_len = min(50, max(len(seq) for seq in sequences))
sequences = [([0] * (max_seq_len - len(seq)) + seq)[-max_seq_len:] for seq in sequences]

sequences = np.array(sequences)
X, y = sequences[:, :-1], sequences[:, -1]
# Keep y as label encoding (integer labels) instead of one-hot encoding
y = torch.tensor(y).long()

X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size=0.15, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.1765, random_state=42)

# -----------------------------
# 3. Dataset & DataLoader
# -----------------------------

class TextDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X).long()
        self.y = torch.tensor(y).long() if not isinstance(y, torch.Tensor) else y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_loader = DataLoader(TextDataset(X_train, y_train), batch_size=64, shuffle=True)
val_loader = DataLoader(TextDataset(X_val, y_val), batch_size=64)
test_loader = DataLoader(TextDataset(X_test, y_test), batch_size=64)

# -----------------------------
# 4. Model Definition
# -----------------------------

class AttentionLayer(nn.Module):
    def __init__(self, hidden_dim, attention_units):
        super(AttentionLayer, self).__init__()
        self.attention_dense = nn.Linear(hidden_dim, attention_units)
        self.context_vector = nn.Linear(attention_units, 1, bias=False)

    def forward(self, lstm_out):
        score = torch.tanh(self.attention_dense(lstm_out))
        attention_weights = torch.softmax(self.context_vector(score), dim=1)
        context_vector = attention_weights * lstm_out
        context_vector = torch.sum(context_vector, dim=1)
        return context_vector, attention_weights.squeeze(-1)

class LSTMAttentionModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=100, lstm_units=100, attention_units=64):
        super(LSTMAttentionModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm1 = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
        self.lstm2 = nn.LSTM(lstm_units, lstm_units, batch_first=True)
        self.attention = AttentionLayer(lstm_units, attention_units)
        self.layer_norm = nn.LayerNorm(lstm_units)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(lstm_units, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)
        x = self.layer_norm(x)
        context, attention_weights = self.attention(x)
        out = self.dropout(context)
        return self.fc(out), attention_weights
# -----------------------------
# 5. Training Loop
# -----------------------------

import torch.nn.utils as nn_utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTMAttentionModel(len(vocab)).to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True,
    min_lr=1e-6
)

train_losses, val_losses, train_accuracies, val_accuracies = [], [], [], []

for epoch in range(100):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs, _ = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping: clip gradients to max norm 1.0
        nn_utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        total_loss += loss.item() * inputs.size(0)
        correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
        total += inputs.size(0)

    train_loss = total_loss / total
    train_acc = correct / total
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)

    # Validation
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            val_correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
            val_total += inputs.size(0)

    val_loss /= val_total
    val_acc = val_correct / val_total
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(f"Epoch {epoch+1:2d} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # Step scheduler with validation loss
    scheduler.step(val_loss)


# Save Model
torch.save(model.state_dict(), "sherlock_lstm_attention_pytorch.pth")
print("Model saved as 'sherlock_lstm_attention_pytorch_stacked.pth'")



✓ Using cached file: sherlock_holmes.txt
🔧 Creating vocabulary from tokens...
START token index: 2, EOS token index: 3
🔧 Generating sequences with cross-sentence context...
Generated 104,462 sliding window sequences
Generated 113,459 sentence-based sequences
Total sequences: 217,921
Sample sequences (first 5):
  1: start of the project gutenberg ebook the adventures of sherlock holmes the adventures of sherlock
  2: of the project gutenberg ebook the adventures of sherlock holmes the adventures of sherlock holmes
  3: the project gutenberg ebook the adventures of sherlock holmes the adventures of sherlock holmes by
  4: project gutenberg ebook the adventures of sherlock holmes the adventures of sherlock holmes by arthur
  5: gutenberg ebook the adventures of sherlock holmes the adventures of sherlock holmes by arthur conan
Epoch  1 | Train Loss: 7.1975 | Train Acc: 0.0683 | Val Loss: 7.0215 | Val Acc: 0.0733
Epoch  2 | Train Loss: 6.7649 | Train Acc: 0.1010 | Val Loss: 6.5504 | Val Acc

In [45]:
torch.save(model.state_dict(), "sherlock_lstm_attention_pytorch_stacked.pth")


In [39]:
def generate_text(model, seed_text, max_length=50, temperature=1.0):
    """
    Generate text from a seed text by iteratively predicting next tokens.
    For each predicted token, print the top 5 most likely next tokens with probabilities.

    Args:
        model: Trained LSTM attention model
        seed_text (str): Initial text prompt to start generation
        max_length (int): Maximum length of generated tokens (including seed)
        temperature (float): Sampling temperature for controlling randomness

    Returns:
        str: Generated text sequence including the seed_text
    """
    model.eval()

    # Tokenize seed text
    tokens = seed_text.lower().split()
    sequence = [start_idx]  # Start with <START> token

    # Map seed tokens to indices (using <UNK> if not found)
    for token in tokens:
        sequence.append(word2idx.get(token, word2idx['<UNK>']))

    generated_tokens = tokens.copy()

    with torch.no_grad():
        for _ in range(max_length - len(sequence) + 1):  # Adjust for seed length
            # Prepare input: last max_seq_len tokens, padded if needed
            if len(sequence) < max_seq_len:
                padding = [0] * (max_seq_len - len(sequence))
                input_seq = torch.tensor([padding + sequence]).to(device)
            else:
                input_seq = torch.tensor([sequence[-max_seq_len:]]).to(device)

            # Get model output logits
            output, _ = model(input_seq)
            logits = output[0] / temperature
            probs = torch.softmax(logits, dim=0)

            # Get top 5 predicted tokens and probabilities
            top5_probs, top5_idx = torch.topk(probs, 5)
            top5_words = [idx2word[idx.item()] for idx in top5_idx]

            # Print top 5 predictions with probabilities
            print(f"Top 5 next words: {[(w, float(p)) for w, p in zip(top5_words, top5_probs)]}")

            # Sample next token from probability distribution
            next_token_idx = torch.multinomial(probs, 1).item()

            # Stop if EOS token generated
            if next_token_idx == eos_idx:
                break

            # Append predicted token to sequence and generated output tokens
            sequence.append(next_token_idx)
            next_word = idx2word.get(next_token_idx, '<UNK>')

            # Avoid adding special tokens to generated text output
            if next_word not in ['<PAD>', '<UNK>', START_TOKEN]:
                generated_tokens.append(next_word)

    return ' '.join(generated_tokens)


In [52]:
# For trial
def generate_text_beam_search(model, seed_text, max_length=50, beam_width=3):
    """
    Generate text using beam search instead of greedy decoding or sampling.

    Args:
        model: Trained model
        seed_text (str): Initial seed prompt
        max_length (int): Max number of tokens in the output
        beam_width (int): Number of beams to keep at each step

    Returns:
        str: Best generated sequence
    """
    model.eval()
    tokens = seed_text.lower().split()

    # Initial sequence
    initial_sequence = [start_idx] + [word2idx.get(token, word2idx['<UNK>']) for token in tokens]
    initial_sequence = initial_sequence[-max_seq_len:]

    # Pad if needed
    if len(initial_sequence) < max_seq_len:
        initial_sequence = [0] * (max_seq_len - len(initial_sequence)) + initial_sequence

    # Beams are tuples of (sequence, log_prob)
    beams = [(initial_sequence, 0)]

    with torch.no_grad():
        for _ in range(max_length):
            all_candidates = []

            for seq, score in beams:
                input_seq = torch.tensor([seq[-max_seq_len:]]).to(device)
                output, _ = model(input_seq)
                logits = output[0]
                log_probs = torch.log_softmax(logits, dim=0)

                # Get top `beam_width` next tokens
                top_log_probs, top_indices = torch.topk(log_probs, beam_width)

                for i in range(beam_width):
                    next_token = top_indices[i].item()
                    next_score = score + top_log_probs[i].item()
                    new_seq = seq + [next_token]

                    if next_token == eos_idx:
                        # Stop early if EOS token predicted
                        all_candidates.append((new_seq, next_score))
                    else:
                        all_candidates.append((new_seq, next_score))

            # Keep only top beam_width candidates
            beams = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)[:beam_width]

            # Early stopping if all beams have ended in EOS
            if all(seq[-1] == eos_idx for seq, _ in beams):
                break

    # Choose the best beam
    best_sequence, _ = beams[0]

    # Convert indices to tokens, skipping special tokens
    generated_tokens = [
        idx2word[idx] for idx in best_sequence
        if idx not in [word2idx.get('<PAD>', -1), word2idx.get('<UNK>', -1), start_idx, eos_idx]
    ]

    return ' '.join(generated_tokens)


Trying with beam search

In [54]:
test_prompts = [
    "Holmes is ",
    "Watson was",
    "The detective",
    "I saw",
    "It was a dark"
]

for prompt in test_prompts:
    generated = generate_text_beam_search(model, prompt, max_length=20)
    print(f"Prompt: '{prompt}'")
    print(f"Generated: '{generated}'")
    print("-" * 40)

Prompt: 'Holmes is '
Generated: 'holmes is a very pretty girl and has given him out of the room.'
----------------------------------------
Prompt: 'Watson was'
Generated: 'watson was so then, that i should be able to post me up. too.'
----------------------------------------
Prompt: 'The detective'
Generated: 'the detective which i have already made up my mind that i had returned to civil practice and took a few words'
----------------------------------------
Prompt: 'I saw'
Generated: 'i saw that it was the most preposterous position which had ever seen that she had left the young man had left'
----------------------------------------
Prompt: 'It was a dark'
Generated: 'it was a dark that i had not yet returned.'
----------------------------------------


The sentences do not make much more sense than the ones with greedy.


Now, we try the greedy decoding whose results I have pasted in the notebook.

In [48]:
test_prompts = [
    "Holmes is ",
    "Watson was",
    "The detective",
    "I saw",
    "It was a dark"
]

for prompt in test_prompts:
    generated = generate_text(model, prompt, max_length=20, temperature=0.8)
    print(f"Prompt: '{prompt}'")
    print(f"Generated: '{generated}'")
    print("-" * 40)

Top 5 next words: [('a', 0.5823346972465515), ('the', 0.066647469997406), ('that', 0.045102473348379135), ('an', 0.03783230856060982), ('very', 0.031026897951960564)]
Top 5 next words: [('very', 0.26571258902549744), ('little', 0.19101107120513916), ('small', 0.12731096148490906), ('man', 0.08128730207681656), ('fierce', 0.0749133825302124)]
Top 5 next words: [('heavy', 0.22029222548007965), ('pretty', 0.19918301701545715), ('little', 0.19044387340545654), ('large', 0.06172473356127739), ('very', 0.037800583988428116)]
Top 5 next words: [('and', 0.8441948294639587), ('between', 0.062399979680776596), ('sleeper,', 0.04337615519762039), ('which', 0.01718759350478649), ('with', 0.007617570459842682)]
Top 5 next words: [('darkness', 0.13406139612197876), ('walked', 0.10114217549562454), ('heavy', 0.06598377972841263), ('held', 0.018099963665008545), ('iron', 0.017817718908190727)]
Top 5 next words: [('the', 0.4372996985912323), ('a', 0.1787565052509308), ('his', 0.04859798401594162), ('him

In [58]:
# Calculate perplexity for all datasets
train_perplexity = calculate_perplexity(model, train_loader, device)
val_perplexity = calculate_perplexity(model, val_loader, device)
test_perplexity = calculate_perplexity(model, test_loader, device)

print(f"PERPLEXITY SCORES:")
print(f"Training Perplexity: {train_perplexity:.2f}")
print(f"Validation Perplexity: {val_perplexity:.2f}")
print(f"Test Perplexity: {test_perplexity:.2f}")

PERPLEXITY SCORES:
Training Perplexity: 4.07
Validation Perplexity: 80.79
Test Perplexity: 82.58


In [65]:

print(f"\nTEXT GENERATION EXAMPLES:")
print("-" * 40)

# Test text generation
test_prompts = [
    "Holmes said",
    "Watson was",
    "The detective",
    "I saw",
    "It was a dark"
]

for prompt in test_prompts:
    generated = generate_text(model, prompt, max_length=20, temperature=0.8)
    print(f"Prompt: '{prompt}'")
    print(f"Generated: '{generated}'")
    print("-" * 40)




TEXT GENERATION EXAMPLES:
----------------------------------------
Prompt: 'Holmes said'
Generated: 'holmes said that we were engaged in the cellar, and so he stood with a one of a small one, and was'
----------------------------------------
Prompt: 'Watson was'
Generated: 'watson was not'
----------------------------------------
Prompt: 'The detective'
Generated: 'the detective bag with a long made a man in a low voice, it was a something which that was a week'
----------------------------------------
Prompt: 'I saw'
Generated: 'i saw that it would swim sink. i rate at my'
----------------------------------------
Prompt: 'It was a dark'
Generated: 'it was a dark little by that that she had left me, but there i was a marriage to have stiff." and the more'
----------------------------------------


In [50]:
def calculate_accuracy(model, data_loader, device):
    """Calculate top-1 accuracy on a dataset"""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, _ = model(inputs)  # [batch_size, vocab_size]
            predictions = torch.argmax(outputs, dim=1)  # [batch_size]
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    return accuracy

In [66]:
# Calculate accuracy on test set
test_accuracy = calculate_accuracy(model, test_loader, device)

print(f"TEST ACCURACY:")
print(f"Top-1 Accuracy: {test_accuracy * 100:.2f}%")


TEST ACCURACY:
Top-1 Accuracy: 29.30%


Evaluating later, from a loaded model

In [55]:
# Load the model state dictionary
model = LSTMAttentionModel(len(vocab)).to(device) # Re-instantiate the model
model.load_state_dict(torch.load("sherlock_lstm_attention_pytorch_stacked.pth"))
model.eval() # Set the model to evaluation mode

print("✅ Model loaded successfully from 'sherlock_lstm_attention_pytorch_stacked.pth'")

✅ Model loaded successfully from 'sherlock_lstm_attention_pytorch_stacked.pth'


In [63]:
seed_text = "It was dark"
generated_text = generate_text(model, seed_text, max_length=20, temperature=0.7)
print(f"\nGenerated text: {generated_text}")



Generated text: it was dark when he went up to the other us in the name of the flight, but at the all he was
