Character-Level RNN for Text Generation with PyTorch

This notebook demonstrates how to build a recurrent neural network (an LSTM-based RNN)
in PyTorch that generates text one character at a time.

It covers:
- Converting text into numerical representations (integer character indices fed to a learned embedding layer).
- Handling sequential dependencies with hidden states.
- Training with Teacher Forcing.
- Implementing sampling strategies like greedy search and temperature-based sampling.

Bonus: Provides a brief outline for extending to word-level generation.

## 1. Libraries and Data Preparation

In [9]:
import torch
import torch.nn as nn
import numpy as np
import random
import time

### Sample text data

In [10]:
# Toy training corpus: one sentence repeated five times.
text = 5 * "The quick brown fox jumps over the lazy dog. "

### Create a sorted list of the unique characters in the text

In [11]:
# Vocabulary: every distinct character of the corpus. Sorting makes the
# character order (and therefore the index assignment below) deterministic.
chars = sorted(set(text))
vocab_size = len(chars)
print(f"Unique characters: {chars}")
print(f"Vocabulary size: {vocab_size}")

Unique characters: [' ', '.', 'T', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
Vocabulary size: 29


### Create character-to-index and index-to-character mappings

In [12]:
# Two lookup tables that are inverses of each other: character -> position in
# `chars`, and position -> character.
char_to_index = {symbol: position for position, symbol in enumerate(chars)}
index_to_char = dict(enumerate(chars))

### Convert text to a sequence of indices

In [13]:
# Encode the whole corpus as a list of vocabulary indices, one per character.
data = [char_to_index[symbol] for symbol in text]
print(f"Length of the text: {len(data)}")
print(f"First 10 characters as indices: {data[:10]}")

Length of the text: 225
First 10 characters as indices: [2, 10, 7, 0, 19, 23, 11, 5, 13, 0]


## 2. Model Definition

In [14]:
class CharRNN(nn.Module):
    """Character-level language model: embedding -> stacked LSTM -> linear layer.

    Args:
        vocab_size: Number of distinct characters. Sizes both the embedding
            table and the output layer.
        embedding_dim: Size of each character embedding vector.
        hidden_dim: Number of features in the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Compute next-character logits for every position of `x`.

        Args:
            x: Integer tensor of character indices, shape (batch_size, seq_len).
            hidden: LSTM state tuple (h, c), each of shape
                (num_layers, batch_size, hidden_dim). When None, nn.LSTM
                starts from an all-zero state.

        Returns:
            A tuple (logits, hidden): logits has shape
            (batch_size, seq_len, vocab_size); hidden is the updated (h, c)
            state tuple.
        """
        embedded = self.embedding(x)               # (batch_size, seq_len, embedding_dim)
        out, hidden = self.rnn(embedded, hidden)   # (batch_size, seq_len, hidden_dim)
        logits = self.fc(out)                      # (batch_size, seq_len, vocab_size)
        return logits, hidden

    def init_hidden(self, batch_size, device=None):
        """Return an all-zero (h, c) LSTM state for `batch_size` sequences.

        Args:
            batch_size: Number of sequences the state is created for.
            device: Device to allocate the state on. Defaults to the device
                that holds the model's parameters.

        Returns:
            A tuple (h, c) of zero tensors, each of shape
            (num_layers, batch_size, hidden_dim).
        """
        reference = self.fc.weight
        if device is None:
            device = reference.device
        state_shape = (self.rnn.num_layers, batch_size, self.rnn.hidden_size)
        # Allocate directly on the target device and match the parameter dtype,
        # so the state stays usable after model.double() / model.half().
        return (torch.zeros(state_shape, dtype=reference.dtype, device=device),
                torch.zeros(state_shape, dtype=reference.dtype, device=device))

## 3. Training the Model

In [15]:
# Hyperparameters
SEED = 42             # fixed seed so Restart & Run All reproduces the same run
embedding_dim = 50    # size of each character embedding vector
hidden_dim = 100      # LSTM hidden state size
num_layers = 2        # number of stacked LSTM layers
seq_length = 10       # characters per training sequence
batch_size = 16       # sequences per batch
learning_rate = 0.01
num_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every random number generator before the model is created, so weight
# initialisation and the later random sampling draw from a fixed stream.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Initialize the model
model = CharRNN(vocab_size, embedding_dim, hidden_dim, num_layers).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Prepare data for training
def get_batches(data, seq_length, batch_size, device=None):
    """Slice an encoded corpus into next-character training batches.

    The corpus is cut into consecutive, non-overlapping sequences of
    `seq_length` indices, and `batch_size` consecutive sequences form one
    batch. Each target sequence is its input sequence shifted one position to
    the right, so a batch is only produced when the corpus holds one extra
    index after the batch's last input. Leftover indices at the end are
    dropped.

    Args:
        data: 1-D sequence of integer character indices.
        seq_length: Number of indices per sequence. Must be positive.
        batch_size: Number of sequences per batch. Must be positive.
        device: Device for the returned tensors. When None, the module-level
            `device` variable is used if one is defined, otherwise the CPU.

    Returns:
        A tuple (inputs, targets) of int64 tensors, each of shape
        (n_batches, batch_size, seq_length). n_batches is 0 when the corpus is
        too short for a single full batch.

    Raises:
        ValueError: If `seq_length` or `batch_size` is not positive.
    """
    if seq_length <= 0 or batch_size <= 0:
        raise ValueError("seq_length and batch_size must be positive integers")
    if device is None:
        # Keeps three-argument callers working: they relied on the
        # notebook-level `device` variable.
        device = globals().get("device", torch.device("cpu"))

    encoded = np.asarray(data, dtype=np.int64)
    tokens_per_batch = seq_length * batch_size
    # "- 1" reserves the extra index needed as the target of the last input.
    n_batches = max(0, (len(encoded) - 1) // tokens_per_batch)
    usable = n_batches * tokens_per_batch
    batch_shape = (n_batches, batch_size, seq_length)

    inputs = encoded[:usable].reshape(batch_shape)
    targets = encoded[1:usable + 1].reshape(batch_shape)

    # torch.tensor copies, so the two results do not share memory even though
    # `inputs` and `targets` are overlapping views of `encoded`.
    return (torch.tensor(inputs, dtype=torch.long, device=device),
            torch.tensor(targets, dtype=torch.long, device=device))


# Training loop
# The batches depend only on the fixed corpus and hyperparameters, so they are
# built once instead of once per epoch.
x_batches, y_batches = get_batches(data, seq_length, batch_size)
num_batches = x_batches.size(0)

if num_batches == 0:
    print("No batches available: the text is too short for the chosen "
          "seq_length and batch_size. Skipping training.")

# predict() switches the model to eval mode; switch back explicitly so that
# re-running this cell after sampling still trains in training mode.
model.train()

start_time = time.time()
epochs_to_run = num_epochs if num_batches > 0 else 0
for epoch in range(epochs_to_run):
    epoch_loss = 0.0

    for inputs, targets in zip(x_batches, y_batches):
        # Every batch starts from a fresh zero state; no state is carried
        # over from the previous batch.
        hidden = model.init_hidden(inputs.size(0), device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass (Teacher Forcing): the ground-truth characters are the
        # inputs at every step, never the model's own predictions.
        outputs, hidden = model(inputs, hidden)

        # Flatten (batch, seq) so each position is one classification example.
        loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))
        epoch_loss += loss.item()

        # Backward pass and optimization
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 5)  # Gradient clipping
        optimizer.step()

    avg_loss = epoch_loss / num_batches
    elapsed_time = time.time() - start_time
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Time: {elapsed_time:.2f}s")

Epoch [1/100], Loss: 3.3761, Time: 0.01s
Epoch [2/100], Loss: 3.2622, Time: 0.01s
Epoch [3/100], Loss: 3.0422, Time: 0.03s
Epoch [4/100], Loss: 3.1399, Time: 0.04s
Epoch [5/100], Loss: 2.8525, Time: 0.04s
Epoch [6/100], Loss: 2.7028, Time: 0.05s
Epoch [7/100], Loss: 2.5701, Time: 0.05s
Epoch [8/100], Loss: 2.3846, Time: 0.06s
Epoch [9/100], Loss: 2.1569, Time: 0.06s
Epoch [10/100], Loss: 1.9037, Time: 0.07s
Epoch [11/100], Loss: 1.6364, Time: 0.07s
Epoch [12/100], Loss: 1.3820, Time: 0.08s
Epoch [13/100], Loss: 1.1517, Time: 0.09s
Epoch [14/100], Loss: 0.9463, Time: 0.09s
Epoch [15/100], Loss: 0.7687, Time: 0.09s
Epoch [16/100], Loss: 0.6177, Time: 0.10s
Epoch [17/100], Loss: 0.4923, Time: 0.10s
Epoch [18/100], Loss: 0.3939, Time: 0.11s
Epoch [19/100], Loss: 0.3145, Time: 0.11s
Epoch [20/100], Loss: 0.2500, Time: 0.11s
Epoch [21/100], Loss: 0.1995, Time: 0.12s
Epoch [22/100], Loss: 0.1612, Time: 0.12s
Epoch [23/100], Loss: 0.1324, Time: 0.12s
Epoch [24/100], Loss: 0.1106, Time: 0.13s
E

## 4. Text Generation (Sampling)

In [16]:
def predict(model, start_char, predict_len=200, temperature=1.0):
    """Generate text one character at a time, starting from `start_char`.

    Each step feeds the previously produced character back into the model
    together with the carried LSTM state. Character/index conversion uses the
    notebook-level `char_to_index` and `index_to_char` tables.

    Args:
        model: Model exposing `init_hidden(batch_size, device)` and returning
            `(logits, hidden)` when called with `(input, hidden)`.
        start_char: First character of the output. Must be a key of
            `char_to_index`.
        predict_len: Number of characters to generate after `start_char`.
        temperature: Divisor applied to the logits before the softmax. Values
            below 1 sharpen the distribution, values above 1 flatten it.
            Exactly 0 selects the most likely character at every step (true
            greedy decoding) instead of sampling.

    Returns:
        The generated string: `start_char` followed by `predict_len`
        characters.

    Raises:
        ValueError: If `temperature` is negative.
        KeyError: If `start_char` is not in `char_to_index`.
    """
    if temperature < 0:
        raise ValueError("temperature must be non-negative")

    # Note: the model is left in eval mode when this function returns.
    model.eval()

    # Run on whichever device holds the model rather than relying on a global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    generated_text = [start_char]
    input_eval = torch.tensor([[char_to_index[start_char]]], dtype=torch.long, device=run_device)
    hidden = model.init_hidden(1, run_device)

    with torch.no_grad():
        for _ in range(predict_len):
            outputs, hidden = model(input_eval, hidden)
            # outputs shape: (1, 1, vocab_size); keep the last time step.
            logits = outputs[:, -1, :]

            if temperature == 0:
                # Greedy decoding: dividing by zero is undefined, so take the
                # argmax directly.
                next_index = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                probabilities = nn.functional.softmax(logits / temperature, dim=-1)
                next_index = torch.multinomial(probabilities, num_samples=1)

            generated_text.append(index_to_char[next_index.item()])
            # next_index already has shape (1, 1), an integer dtype and the
            # right device, so it is fed back as the next input unchanged.
            input_eval = next_index

    return "".join(generated_text)

# Greedy Search
print("\n--- Greedy Search ---")
start_char = random.choice(chars)
generated_text_greedy = predict(model, start_char, predict_len=200, temperature=0.001) # Low temperature for greedy
print(f"Starting with '{start_char}':\n{generated_text_greedy}")

# Temperature-based Sampling
print("\n--- Temperature Sampling (T=0.5) ---")
start_char = random.choice(chars)
generated_text_temp_05 = predict(model, start_char, predict_len=200, temperature=0.5)
print(f"Starting with '{start_char}':\n{generated_text_temp_05}")

print("\n--- Temperature Sampling (T=1.0) ---")
start_char = random.choice(chars)
generated_text_temp_10 = predict(model, start_char, predict_len=200, temperature=1.0)
print(f"Starting with '{start_char}':\n{generated_text_temp_10}")

print("\n--- Temperature Sampling (T=1.5) ---")
start_char = random.choice(chars)
generated_text_temp_15 = predict(model, start_char, predict_len=200, temperature=1.5)
print(f"Starting with '{start_char}':\n{generated_text_temp_15}")


--- Greedy Search ---
Starting with ' ':
 fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the l

--- Temperature Sampling (T=0.5) ---
Starting with 'y':
y over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. 

--- Temperature Sampling (T=1.0) ---
Starting with 'w':
we the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy To. The quiick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The 

--- Temperature Sampling (T=1.5) ---
Starting with 'f':
fox jumps over the lazy dog. The quick brown fox jumps ocer the lazy born fox jumps over quick brown fox. jumps over the quick brown fox jumps over the lazy dog. The quick brown ffx