In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
import numpy as np
from torch.utils.data import Dataset, DataLoader

# -------------------------------
# Data Preparation
# -------------------------------

# Training corpus: one short English passage. The character vocabulary is
# built from this string, and every (window, next-character) example used
# for both training and validation is cut from it, so the model is trained
# and evaluated on this single passage only.
text = """Next character prediction is a fundamental task in the field of natural language processing (NLP) that involves predicting the next character in a sequence of text based on the characters that precede it. This task is essential for various applications, including text auto-completion, spell checking, and even in the development of sophisticated AI models capable of generating human-like text.

At its core, next character prediction relies on statistical models or deep learning algorithms to analyze a given sequence of text and predict which character is most likely to follow. These predictions are based on patterns and relationships learned from large datasets of text during the training phase of the model.

One of the most popular approaches to next character prediction involves the use of Recurrent Neural Networks (RNNs), and more specifically, a variant called Long Short-Term Memory (LSTM) networks. RNNs are particularly well-suited for sequential data like text, as they can maintain information in 'memory' about previous characters to inform the prediction of the next character. LSTM networks enhance this capability by being able to remember long-term dependencies, making them even more effective for next character prediction tasks.

Training a model for next character prediction involves feeding it large amounts of text data, allowing it to learn the probability of each character's appearance following a sequence of characters. During this training process, the model adjusts its parameters to minimize the difference between its predictions and the actual outcomes, thus improving its predictive accuracy over time.

Once trained, the model can be used to predict the next character in a given piece of text by considering the sequence of characters that precede it. This can enhance user experience in text editing software, improve efficiency in coding environments with auto-completion features, and enable more natural interactions with AI-based chatbots and virtual assistants.

In summary, next character prediction plays a crucial role in enhancing the capabilities of various NLP applications, making text-based interactions more efficient, accurate, and human-like. Through the use of advanced machine learning models like RNNs and LSTMs, next character prediction continues to evolve, opening new possibilities for the future of text-based technology."""

# Character vocabulary: the distinct characters of `text` in sorted order.
# `idx2char` maps integer id -> character and `char2idx` is its inverse;
# ids are positions in the sorted character list.
chars = sorted(set(text))
vocab_size = len(chars)
idx2char = dict(enumerate(chars))
char2idx = {symbol: index for index, symbol in idx2char.items()}

def encode(text):
    """Translate ``text`` into the list of its character ids.

    Each character is looked up in the module-level ``char2idx``; a character
    missing from the vocabulary raises ``KeyError``.
    """
    return list(map(char2idx.__getitem__, text))

class CharDataset(Dataset):
    """
    Dataset for next character prediction.

    Item ``i`` is a pair ``(x, y)``: ``x`` is the encoded window
    ``data[i : i + seq_length]`` as a 1-D ``torch.long`` tensor, and ``y`` is
    the encoded character right after that window as a 0-d ``torch.long``
    tensor. Consecutive items overlap (stride 1).

    Args:
        text: Source string, encoded with ``encode`` (so every character must
            be in the vocabulary, otherwise ``KeyError`` is raised).
        seq_length: Number of context characters per example.
    """
    def __init__(self, text, seq_length):
        self.seq_length = seq_length
        self.data = encode(text)

    def __len__(self):
        # A text no longer than the window yields no examples. Clamp at 0 so
        # len() never goes negative (Python's len() rejects negative values).
        return max(0, len(self.data) - self.seq_length)

    def __getitem__(self, idx):
        n_items = len(self)
        # Follow normal sequence semantics: negative indices count from the
        # end, and out-of-range indices raise IndexError instead of slicing a
        # wrong (short or misaligned) window.
        if idx < 0:
            idx += n_items
        if not 0 <= idx < n_items:
            raise IndexError(f"CharDataset index out of range (length {n_items})")
        x = torch.tensor(self.data[idx: idx + self.seq_length], dtype=torch.long)
        y = torch.tensor(self.data[idx + self.seq_length], dtype=torch.long)
        return x, y

# -------------------------------
# Model Definitions
# -------------------------------

# --- Transformer Model ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first tensor, then apply dropout.

    Even feature columns hold ``sin(pos * f_k)`` and odd columns hold
    ``cos(pos * f_k)``, with ``f_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: Feature width of the inputs. Odd widths are supported.
        max_len: Longest sequence length the precomputed table covers.
        dropout: Dropout probability applied after adding the encodings
            (default 0.1, the value previously hard-coded).
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column;
        # trim the frequencies to match, otherwise this assignment fails.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        # Registered as a buffer: moves with .to(device) and is saved in the
        # state dict, but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe)`` for ``x`` laid out as (batch, seq_len, d_model).

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(1)}"
            )
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)

class TransformerModel(nn.Module):
    """Transformer-encoder classifier that predicts the character following a window.

    Token ids are embedded, scaled by sqrt(d_model), given positional
    encodings, passed through a stack of encoder layers, and the encoder output
    at the final position is projected to vocabulary logits.

    Args:
        vocab_size: Number of character ids (embedding rows and output logits).
        d_model: Embedding / model width.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward sub-block.
        dropout: Dropout inside each encoder layer (it is not forwarded to the
            positional encoding, which is constructed with its own default).
    """

    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=2, dim_feedforward=256, dropout=0.1):
        super().__init__()
        # Submodules are built in the original order so parameter
        # initialisation draws from the RNG identically.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        template_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.transformer_encoder = nn.TransformerEncoder(template_layer, num_layers)
        self.d_model = d_model
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src):
        """Map token ids of shape (batch, seq_length) to logits of shape (batch, vocab_size)."""
        scale = np.sqrt(self.d_model)
        embedded = self.pos_encoder(self.embedding(src) * scale)
        # The encoder layers were built without batch_first, so they expect
        # (seq_length, batch, d_model).
        seq_first = embedded.permute(1, 0, 2)
        encoded = self.transformer_encoder(seq_first)
        # Classify from the representation at the last position only.
        return self.fc(encoded[-1])

# --- Standard LSTM Model ---
class LSTMModel(nn.Module):
    """Stacked-LSTM classifier that predicts the character following a window.

    Args:
        vocab_size: Number of character ids (embedding rows and output logits).
        embed_size: Embedding width.
        hidden_size: LSTM hidden width.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked LSTM layers. nn.LSTM applies it only
            between layers, so it is set to 0 when ``num_layers == 1``.
    """
    def __init__(self, vocab_size, embed_size=128, hidden_size=128, num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # A non-zero dropout on a single-layer nn.LSTM has no effect and
        # triggers a UserWarning, so only pass it when layers are stacked.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Map token ids of shape (batch, seq_length) to logits of shape (batch, vocab_size)."""
        x = self.embedding(x)
        output, _ = self.lstm(x)  # (batch, seq_length, hidden_size) since batch_first=True
        # Use the top layer's output at the last time step for prediction.
        out = self.fc(output[:, -1, :])
        return out

# --- LSTM with Cross-Attention Model ---
class LSTMWithAttentionModel(nn.Module):
    """LSTM encoder plus attention pooling, predicting the character after a window.

    The top-layer LSTM output at the last time step is used as a single query
    that attends over the LSTM outputs at every time step (keys and values).
    The attended vector is projected to vocabulary logits.

    Args:
        vocab_size: Number of character ids (embedding rows and output logits).
        embed_size: Embedding width.
        hidden_size: LSTM hidden width; also the attention embedding size, so
            it must be divisible by ``nhead`` (nn.MultiheadAttention requires it).
        num_layers: Number of stacked LSTM layers.
        nhead: Number of attention heads.
        dropout: Attention dropout. It is also used between stacked LSTM layers
            when ``num_layers > 1``; nn.LSTM applies it only between layers,
            so the LSTM gets 0 when there is a single layer.
    """
    def __init__(self, vocab_size, embed_size=128, hidden_size=128, num_layers=2, nhead=4, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # A non-zero dropout on a single-layer nn.LSTM has no effect and
        # triggers a UserWarning, so only pass it when layers are stacked.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)
        # MultiheadAttention expects inputs as (batch, seq_length, hidden_size) when batch_first=True.
        self.attention = nn.MultiheadAttention(hidden_size, nhead, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Map token ids of shape (batch, seq_length) to logits of shape (batch, vocab_size)."""
        emb = self.embedding(x)
        lstm_out, _ = self.lstm(emb)  # (batch, seq_length, hidden_size)
        # Use the last time step's output as the query (slice keeps a length-1 sequence axis).
        query = lstm_out[:, -1:, :]  # (batch, 1, hidden_size)
        # keys and values are the LSTM outputs at every time step
        attn_output, _ = self.attention(query, lstm_out, lstm_out)
        attn_output = attn_output.squeeze(1)  # (batch, hidden_size)
        out = self.fc(attn_output)
        return out

# -------------------------------
# Training and Evaluation Functions
# -------------------------------

def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader`` and return the mean training loss.

    The result is the per-sample average of ``criterion`` over every sample the
    loader actually yields. This assumes ``criterion`` returns a batch mean, as
    ``nn.CrossEntropyLoss`` does by default. Returns NaN if the loader yields
    no samples.
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        batch_size = x.size(0)
        # Re-weight the batch-mean loss by batch size so a short final batch
        # contributes proportionally.
        total_loss += loss.item() * batch_size
        n_seen += batch_size
    # Divide by samples actually seen, not len(dataloader.dataset): the two
    # differ with drop_last=True or a subsampling sampler, and an empty loader
    # would otherwise divide by zero.
    return total_loss / n_seen if n_seen else float("nan")

def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples the loader
        actually yields. Accuracy compares ``output.argmax(dim=1)`` with the
        targets. Mean loss assumes ``criterion`` returns a batch mean (the
        ``nn.CrossEntropyLoss`` default). Both are NaN if the loader yields
        no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    n_seen = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = criterion(output, y)
            batch_size = x.size(0)
            total_loss += loss.item() * batch_size
            preds = output.argmax(dim=1)
            correct += (preds == y).sum().item()
            n_seen += batch_size
    # Normalise by samples actually seen (not len(dataloader.dataset)), which
    # stays correct with drop_last/samplers and avoids dividing by zero.
    if n_seen == 0:
        return float("nan"), float("nan")
    return total_loss / n_seen, correct / n_seen

def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters.

    Parameters with ``requires_grad`` set to False (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# -------------------------------
# Experiment Function
# -------------------------------

def run_experiment(model_class, seq_length, device, *, n_epochs=10, batch_size=32,
                   lr=0.001, split_seed=0, **model_kwargs):
    """Train and validate one model configuration on the module-level ``text``.

    Builds a ``CharDataset`` of ``seq_length``-character windows and splits it
    80% / 20% into train / validation. It then trains
    ``model_class(vocab_size, **model_kwargs)`` with Adam and cross-entropy,
    printing train loss, validation loss and validation accuracy after each
    epoch.

    Args:
        model_class: Model constructor taking ``vocab_size`` first.
        seq_length: Context window length in characters.
        device: Device the model and batches are moved to.
        n_epochs: Number of training epochs (previously hard-coded to 10).
        batch_size: Batch size for both loaders (previously 32).
        lr: Adam learning rate (previously 0.001).
        split_seed: Seed for the train/validation split. Calls with the same
            ``seq_length`` and ``split_seed`` get the same split, so models run
            side by side are compared on identical data. ``None`` uses the
            global RNG, which was the previous, non-reproducible behaviour.
        **model_kwargs: Forwarded to ``model_class``. The names ``n_epochs``,
            ``batch_size``, ``lr`` and ``split_seed`` are consumed here.

    Returns:
        ``(final_train_loss, final_val_accuracy, train_seconds, trainable_params)``.
        Loss and accuracy are NaN when ``n_epochs`` is 0.

    NOTE(review): windows overlap with stride 1 and are split at random, so
    validation windows share most of their characters (and their targets) with
    training windows. Treat the validation accuracy as optimistic; a
    contiguous split would avoid this.
    """
    dataset = CharDataset(text, seq_length)
    n_train = int(0.8 * len(dataset))
    n_val = len(dataset) - n_train
    # A private, seeded generator makes the split reproducible without
    # consuming (or depending on) the global RNG state.
    if split_seed is None:
        split_generator = torch.default_generator
    else:
        split_generator = torch.Generator().manual_seed(split_seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [n_train, n_val], generator=split_generator
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    model = model_class(vocab_size, **model_kwargs).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Defined up front so the return below is valid even when n_epochs == 0.
    train_loss = val_accuracy = float("nan")
    start_time = time.perf_counter()
    for epoch in range(n_epochs):
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
        print(f"Epoch {epoch+1}/{n_epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f} - Val Acc: {val_accuracy:.4f}")
    elapsed_time = time.perf_counter() - start_time
    param_count = count_parameters(model)
    return train_loss, val_accuracy, elapsed_time, param_count

# -------------------------------
# Main Experiment Loop
# -------------------------------

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sequence_lengths = [10, 20, 30]

    # Each entry: (results key, banner printed before the run, model class,
    # constructor kwargs). Entries run in list order for every sequence length.
    experiments = [
        ('Transformer', "\n-- Transformer Model --", TransformerModel,
         dict(d_model=128, nhead=4, num_layers=2, dim_feedforward=256, dropout=0.1)),
        ('LSTM', "\n-- LSTM Model --", LSTMModel,
         dict(embed_size=128, hidden_size=128, num_layers=2, dropout=0.1)),
        ('LSTM_Attn', "\n-- LSTM with Cross-Attention Model --", LSTMWithAttentionModel,
         dict(embed_size=128, hidden_size=128, num_layers=2, nhead=4, dropout=0.1)),
    ]

    # results[seq_length][model key] -> metrics dict
    results = {}

    for seq_length in sequence_lengths:
        print(f"\n=== Sequence Length: {seq_length} ===")
        per_model = results[seq_length] = {}
        for key, banner, model_class, kwargs in experiments:
            print(banner)
            final_loss, final_acc, seconds, n_params = run_experiment(
                model_class, seq_length, device, **kwargs
            )
            per_model[key] = {
                'Train Loss': final_loss,
                'Val Accuracy': final_acc,
                'Time (s)': seconds,
                'Parameters': n_params,
            }

    # -------------------------------
    # Reporting Results
    # -------------------------------
    print("\n=== Summary of Results ===")
    for seq_length in results:
        print(f"\nSequence Length: {seq_length}")
        for model_type, metrics in results[seq_length].items():
            print(f"{model_type}: Loss={metrics['Train Loss']:.4f}, Val Acc={metrics['Val Accuracy']:.4f}, Time={metrics['Time (s)']:.2f}s, Params={metrics['Parameters']}")



=== Sequence Length: 10 ===

-- Transformer Model --


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Epoch 1/10 - Train Loss: 2.9157 - Val Loss: 2.5738 - Val Acc: 0.2725
Epoch 2/10 - Train Loss: 2.4867 - Val Loss: 2.4235 - Val Acc: 0.2809
Epoch 3/10 - Train Loss: 2.3430 - Val Loss: 2.3759 - Val Acc: 0.3187
Epoch 4/10 - Train Loss: 2.2437 - Val Loss: 2.3469 - Val Acc: 0.3375
Epoch 5/10 - Train Loss: 2.1544 - Val Loss: 2.3629 - Val Acc: 0.2956
Epoch 6/10 - Train Loss: 2.0942 - Val Loss: 2.3632 - Val Acc: 0.3229
Epoch 7/10 - Train Loss: 2.0102 - Val Loss: 2.3593 - Val Acc: 0.3229
Epoch 8/10 - Train Loss: 1.9641 - Val Loss: 2.3928 - Val Acc: 0.3333
Epoch 9/10 - Train Loss: 1.8931 - Val Loss: 2.4157 - Val Acc: 0.3145
Epoch 10/10 - Train Loss: 1.8604 - Val Loss: 2.3679 - Val Acc: 0.3396

-- LSTM Model --
Epoch 1/10 - Train Loss: 3.1897 - Val Loss: 2.9889 - Val Acc: 0.1488
Epoch 2/10 - Train Loss: 2.8849 - Val Loss: 2.7496 - Val Acc: 0.2746
Epoch 3/10 - Train Loss: 2.5786 - Val Loss: 2.5321 - Val Acc: 0.3061
Epoch 4/10 - Train Loss: 2.3375 - Val Loss: 2.4302 - Val Acc: 0.3501
Epoch 5/10 - Tr



Epoch 1/10 - Train Loss: 2.9305 - Val Loss: 2.5391 - Val Acc: 0.2653
Epoch 2/10 - Train Loss: 2.5344 - Val Loss: 2.4950 - Val Acc: 0.2379
Epoch 3/10 - Train Loss: 2.4194 - Val Loss: 2.4219 - Val Acc: 0.2674
Epoch 4/10 - Train Loss: 2.3650 - Val Loss: 2.4206 - Val Acc: 0.2926
Epoch 5/10 - Train Loss: 2.2951 - Val Loss: 2.4283 - Val Acc: 0.2737
Epoch 6/10 - Train Loss: 2.2667 - Val Loss: 2.3851 - Val Acc: 0.2989
Epoch 7/10 - Train Loss: 2.1993 - Val Loss: 2.4233 - Val Acc: 0.2884
Epoch 8/10 - Train Loss: 2.1511 - Val Loss: 2.3745 - Val Acc: 0.2989
Epoch 9/10 - Train Loss: 2.1070 - Val Loss: 2.4130 - Val Acc: 0.2947
Epoch 10/10 - Train Loss: 2.0837 - Val Loss: 2.4203 - Val Acc: 0.2968

-- LSTM Model --
Epoch 1/10 - Train Loss: 3.1892 - Val Loss: 2.9880 - Val Acc: 0.1726
Epoch 2/10 - Train Loss: 2.8191 - Val Loss: 2.6932 - Val Acc: 0.2758
Epoch 3/10 - Train Loss: 2.5061 - Val Loss: 2.4904 - Val Acc: 0.3242
Epoch 4/10 - Train Loss: 2.2735 - Val Loss: 2.3527 - Val Acc: 0.3516
Epoch 5/10 - Tr