# Phase 2: LSTM Text Generation with PyTorch

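In this notebook we train a word-level Long Short-Term Memory (LSTM) network to generate text in the style of Donald Trump. The model reads a window of `SEQUENCE_LENGTH` words and predicts the next word. Text is generated by repeatedly sampling a word and sliding the window forward.
We use the `TextDataset` class defined in `src/nlp/data_loader.py`.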

In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

sys.path.append(os.path.abspath('../../src'))

from nlp.data_loader import TextDataset, get_vocab_size

In [None]:

# Hyperparameters — tune them here rather than inside later cells.
SEQUENCE_LENGTH = 10   # number of context words per training window / generation step
BATCH_SIZE = 64
EMBEDDING_DIM = 128
HIDDEN_DIM = 256
NUM_LAYERS = 2
LEARNING_RATE = 0.001
EPOCHS = 5
MAX_VOCAB_SIZE = 5000  # passed to TextDataset as max_vocab_size
SEED = 42              # makes weight init, DataLoader shuffling and text sampling reproducible

# Seed torch's default RNG (all devices) so Restart & Run All gives the same results.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [None]:

# NOTE(review): this cell's saved output is a MemoryError, yet the training cell
# below shows output — so the notebook was not produced by a clean
# Restart & Run All. TODO: find which call runs out of memory (presumably the
# TextDataset construction), reduce its footprint, and re-run top to bottom.
dataset_config = {
    "parquet_path": '../../data/transcriptions_cleaned.parquet',
    "sequence_length": SEQUENCE_LENGTH,
    "max_vocab_size": MAX_VOCAB_SIZE,
}
dataset = TextDataset(**dataset_config)

# Shuffled mini-batches of (input window, target) pairs for training.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
vocab_size = get_vocab_size(dataset)

print(f"Vocabulary Size: {vocab_size}")
print(f"Total Sequences: {len(dataset)}")

MemoryError: 

In [None]:

class LSTMGenerator(nn.Module):
    """Next-word model: token embedding -> stacked LSTM -> linear projection to the vocabulary.

    Only the LSTM output at the last time step is projected. Each input window
    therefore yields a single row of logits for the word that follows it.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(LSTMGenerator, self).__init__()
        # Keep the sizes on the instance so init_hidden matches *this* model,
        # not whatever the notebook-level constants currently hold.
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden):
        """Run one batch of token-id windows through the network.

        Args:
            x: LongTensor of token ids, shape (batch, seq_len) (the LSTM is batch_first).
            hidden: (h_0, c_0) tuple, each of shape (num_layers, batch, hidden_dim),
                e.g. from ``init_hidden``.

        Returns:
            (logits, (h_n, c_n)), where logits has shape (batch, vocab_size) and
            comes from the last time step only.
        """
        embeds = self.embedding(x)
        out, hidden = self.lstm(embeds, hidden)
        # Keep only the final time step: we predict the word after the window.
        out = out[:, -1, :]

        out = self.fc(out)
        return out, hidden

    def init_hidden(self, batch_size):
        """Return a zeroed (h_0, c_0) pair sized for this model and ``batch_size``.

        The tensors are created on the same device and with the same dtype as
        the model's weights, so they follow ``model.to(...)``.
        """
        weight = next(self.parameters())
        shape = (self.num_layers, batch_size, self.hidden_dim)
        return (weight.new_zeros(shape), weight.new_zeros(shape))

In [None]:
# Build the network on the selected device; the loss is cross-entropy over the
# next-word logits, optimised with Adam.
model = LSTMGenerator(
    vocab_size=vocab_size,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# Train for EPOCHS passes over the data, logging the batch loss every 100
# batches and the mean batch loss at the end of each epoch.
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0
    for step, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        # Every window starts from a zeroed state; nothing is carried between batches.
        initial_state = model.init_hidden(inputs.size(0))
        logits, _ = model(inputs, initial_state)

        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        if step % 100 == 0:
            print(f"Epoch {epoch+1}/{EPOCHS}, Batch {step}, Loss: {batch_loss:.4f}")

    print(f"Epoch {epoch+1} Complete. Average Loss: {running_loss / len(dataloader):.4f}")

Epoch 1/5, Batch 0, Loss: 8.5251
Epoch 1/5, Batch 100, Loss: 6.5810
Epoch 1/5, Batch 200, Loss: 6.2387
Epoch 1/5, Batch 300, Loss: 5.8398
Epoch 1/5, Batch 400, Loss: 6.0487
Epoch 1/5, Batch 500, Loss: 6.5078
Epoch 1/5, Batch 600, Loss: 6.2709
Epoch 1/5, Batch 700, Loss: 5.8747
Epoch 1/5, Batch 800, Loss: 6.1920
Epoch 1/5, Batch 900, Loss: 6.6175
Epoch 1/5, Batch 1000, Loss: 5.6697
Epoch 1/5, Batch 1100, Loss: 5.4491
Epoch 1/5, Batch 1200, Loss: 5.5430
Epoch 1/5, Batch 1300, Loss: 5.5711
Epoch 1/5, Batch 1400, Loss: 5.3918
Epoch 1/5, Batch 1500, Loss: 5.8156
Epoch 1/5, Batch 1600, Loss: 6.0520
Epoch 1/5, Batch 1700, Loss: 5.8310
Epoch 1/5, Batch 1800, Loss: 5.3203
Epoch 1/5, Batch 1900, Loss: 6.0250
Epoch 1/5, Batch 2000, Loss: 5.7461
Epoch 1/5, Batch 2100, Loss: 6.0078
Epoch 1/5, Batch 2200, Loss: 5.3119
Epoch 1/5, Batch 2300, Loss: 5.9969
Epoch 1/5, Batch 2400, Loss: 5.5040
Epoch 1/5, Batch 2500, Loss: 5.6189
Epoch 1/5, Batch 2600, Loss: 4.8527
Epoch 1/5, Batch 2700, Loss: 5.0092
Epoc

: 

In [None]:
def generate_text(model, start_text, length=20, *, sequence_length=None,
                  word_to_int=None, int_to_word=None):
    """Continue ``start_text`` with ``length`` words sampled from ``model``.

    Each step does the following:
    - takes the last ``sequence_length`` words, left-padded with id 0 when fewer are available;
    - runs them through the model from a freshly zeroed hidden state (the same way windows are presented in the training loop);
    - samples the next word from the softmax of the output logits.

    Args:
        model: module exposing ``init_hidden(batch_size)`` and
            ``forward(x, hidden) -> (logits, hidden)``, e.g. ``LSTMGenerator``.
        start_text: seed text; lower-cased and split on whitespace.
        length: number of words to generate.
        sequence_length: context window size; defaults to the notebook's
            ``SEQUENCE_LENGTH``.
        word_to_int: word -> id mapping; defaults to ``dataset.word_to_int``.
        int_to_word: id -> word mapping; defaults to ``dataset.int_to_word``.

    Returns:
        The seed words followed by the generated words, joined by single spaces.

    Raises:
        ValueError: if ``sequence_length`` is smaller than 1.
    """
    # Fall back to the notebook-level vocabulary/config only when not supplied.
    if sequence_length is None:
        sequence_length = SEQUENCE_LENGTH
    if word_to_int is None:
        word_to_int = dataset.word_to_int
    if int_to_word is None:
        int_to_word = dataset.int_to_word
    if sequence_length < 1:
        # words[-0:] would return the whole list, so reject this explicitly.
        raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")

    # Build inputs on whatever device the model's weights live on.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    words = start_text.lower().split()
    try:
        with torch.no_grad():
            for _ in range(length):
                window = words[-sequence_length:]
                # Unknown words and left-padding both map to id 0.
                # NOTE(review): assumes id 0 is the padding/unknown token in
                # TextDataset's vocabulary — TODO confirm.
                ids = [word_to_int.get(w, 0) for w in window]
                ids = [0] * (sequence_length - len(ids)) + ids
                x = torch.tensor([ids], dtype=torch.long, device=target_device)

                # Fresh zero state per window, as in training. Reusing the previous
                # step's state would feed the overlapping words through twice.
                hidden = model.init_hidden(1)
                logits, _ = model(x, hidden)

                probs = torch.nn.functional.softmax(logits[0], dim=0)
                next_idx = torch.multinomial(probs, 1).item()
                words.append(int_to_word[next_idx])
    finally:
        # Leave the model in the train/eval mode the caller had it in.
        model.train(was_training)

    return " ".join(words)

# Sample a 20-word continuation of a fixed prompt from the trained model.
generated_sample = generate_text(model, "Make America Great", 20)
print(generated_sample)