# Exercise 1: LSTM for Generating Text



This exercise demonstrates how to use an LSTM (Long Short-Term Memory) neural network to generate text from song lyrics. The pipeline loads the lyrics, splits them into lowercase words, turns the word stream into fixed-length input sequences paired with the word that follows each one, trains a word-level LSTM to predict that next word, and finally generates new lyrics by repeatedly sampling a next word after a given seed phrase.
    

In [2]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import string
import re

# Load the lyrics
def load_lyrics(file_path):
    """Read the UTF-8 text file at ``file_path`` and return its contents lower-cased."""
    with open(file_path, encoding='utf-8') as handle:
        return handle.read().lower()

lyrics = load_lyrics(file_path='./bieber.txt')  # lower-cased corpus text; point this at your own lyrics file

# Tokenize and Vectorize the Text
from collections import Counter

def preprocess_text(text):
    """Remove ASCII punctuation from ``text`` and split it into words.

    Every character in ``string.punctuation`` is deleted, then the text is
    split on runs of whitespace. Case is left unchanged.

    Args:
        text: The raw text to tokenize.

    Returns:
        The list of whitespace-separated tokens (empty for blank input).
    """
    # str.translate deletes the characters literally. Interpolating
    # string.punctuation into a regex character class without escaping it
    # turns its "\]" into an escaped "]", so backslashes were never removed.
    strip_punctuation = str.maketrans('', '', string.punctuation)
    return text.translate(strip_punctuation).split()

words = preprocess_text(lyrics)
word_count = Counter(words)
# Most frequent word first; most_common() sorts stably by count, so ties
# keep the order in which words were first seen.
vocab = [word for word, _ in word_count.most_common()]
# Ids start at 1, so id 0 is never assigned to a word.
vocab_to_int = dict(zip(vocab, range(1, len(vocab) + 1)))

# Encode the whole corpus as a list of word ids
lyrics_int = list(map(vocab_to_int.__getitem__, words))


In [3]:

# Slide a window over the id stream: each example is seq_length context ids
# followed by the id that comes next (no padding is applied).
def create_sequences(lyrics_int, seq_length):
    """Return every contiguous window of ``seq_length + 1`` ids.

    Window ``k`` is ``lyrics_int[k : k + seq_length + 1]``. Its first
    ``seq_length`` entries are the context and its last entry is the target.
    For a non-negative ``seq_length``, the result is empty when
    ``lyrics_int`` holds ``seq_length`` or fewer ids.
    """
    return [
        lyrics_int[start:start + seq_length + 1]
        for start in range(len(lyrics_int) - seq_length)
    ]

seq_length = 10  # context words per training example (each window also holds 1 target word)
sequences = create_sequences(lyrics_int, seq_length=seq_length)


In [4]:

# Create predictors and labels: the first seq_length ids of each window are
# the model input, and the final id is the word to predict.
sequences = np.array(sequences)
if sequences.size == 0:
    raise ValueError(
        "Corpus is too short to build any training sequence; "
        "it needs more than seq_length words."
    )
X, y = sequences[:, :-1], sequences[:, -1]
# nn.Embedding needs integer ids; be explicit because NumPy's default
# integer width is platform-dependent.
X = torch.tensor(X, dtype=torch.long)
# Keep labels as integer class ids, as nn.CrossEntropyLoss and accuracy()
# consume them directly. One-hot encoding would allocate a dense
# (num_examples x vocab_size) float matrix that training never uses.
y = torch.tensor(y, dtype=torch.long)


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from termcolor import colored
import numpy as np

# Build and Train the LSTM Model
class LSTMLyricsModel(nn.Module):
    """Word-level next-word model: embedding -> single-layer LSTM -> linear head.

    ``forward`` maps a batch of token-id sequences to one row of output
    scores per sequence. It uses only the LSTM output at the last time step.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        # Submodules are created in this order. Seeded weight initialisation
        # depends on creation order, so keep it stable.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # x is indexed (batch, seq) because the LSTM is batch_first.
        embedded = self.embedding(x)
        sequence_out = self.lstm(embedded)[0]  # final (h, c) state is discarded
        final_step = sequence_out[:, -1, :]
        return self.fc(final_step)

# Hyperparameters. Word ids run from 1 to len(vocab_to_int) and id 0 is
# unused, so the embedding table and output layer need one extra slot.
vocab_size = len(vocab_to_int) + 1
embedding_dim = 128
hidden_dim = 256
output_dim = vocab_size  # one output score per possible word id

# Model, loss (CrossEntropyLoss on the raw output scores) and optimiser
model = LSTMLyricsModel(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

# Function to calculate accuracy
def accuracy(preds, labels):
    """Return the fraction of rows whose highest-scoring class equals the label.

    Args:
        preds: Score tensor with classes along dim 1 (e.g. raw logits).
        labels: Integer class ids, one per row of ``preds``.

    Returns:
        A Python float: correct predictions divided by ``labels.size(0)``.
    """
    predicted = preds.argmax(dim=1)  # first maximal index on ties, like torch.max
    n_correct = (predicted == labels).sum().item()
    return n_correct / labels.size(0)

# Function to train the model
def train_model(model, X, y, epochs=20, batch_size=64):
    """Train ``model`` on ``(X, y)`` with shuffled mini-batches.

    Uses the module-level ``optimizer`` and ``loss_fn``, which must already
    be bound to ``model``'s parameters.

    Args:
        model: Network mapping a batch of token-id sequences to class scores.
        X: Tensor of input id sequences, one row per example.
        y: Target class ids, one per row of ``X``. A 2-D one-hot tensor is
            also accepted and is converted to ids with ``argmax``.
        epochs: Number of passes over the data.
        batch_size: Examples per optimisation step.

    Raises:
        ValueError: If there are no training examples.

    After each epoch, prints the example-weighted mean loss and accuracy.
    """
    if y.dim() > 1:
        # Both the loss and the accuracy count below work on class ids.
        y = y.argmax(dim=1)
    y = y.long()

    dataset = TensorDataset(X, y)
    if len(dataset) == 0:
        raise ValueError("train_model needs at least one training example")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        # An earlier inference call may have left the model in eval mode.
        model.train()
        epoch_loss = 0.0
        epoch_correct = 0
        seen = 0

        progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch_X, batch_y in progress_bar:
            optimizer.zero_grad()
            output = model(batch_X)  # Forward pass
            loss = loss_fn(output, batch_y)  # batch-mean loss (CrossEntropyLoss default)
            loss.backward()  # Backward pass
            optimizer.step()  # Update weights

            # Weight by batch size because the final batch may be smaller.
            n = batch_y.size(0)
            epoch_loss += loss.item() * n
            epoch_correct += (output.argmax(dim=1) == batch_y).sum().item()
            seen += n

            # Running accuracy over the examples processed so far this epoch
            progress_bar.set_postfix({
                'Loss': loss.item(),
                'Accuracy': epoch_correct / seen,
            })

        # Print colored summary at the end of each epoch
        print(colored(f"Epoch {epoch+1}/{epochs} Summary: ", 'green'))
        print(colored(f"Average Loss: {epoch_loss / seen:.4f}, Accuracy: {epoch_correct / seen:.4f}", 'yellow'))

# Labels for training are integer class ids (the CrossEntropyLoss target
# form), not one-hot rows: each window's label is its final word id.
y = torch.tensor(sequences[:, -1], dtype=torch.long)

# Fit the model with the default schedule
train_model(model, X, y, epochs=20, batch_size=64)


Epoch 1/20: 100%|██████████| 459/459 [00:32<00:00, 14.19it/s, Loss=3.18, Accuracy=0.12]


Epoch 1/20 Summary: 
Average Loss: 5.2889, Accuracy: 0.1200


Epoch 2/20: 100%|██████████| 459/459 [00:33<00:00, 13.61it/s, Loss=3.6, Accuracy=0.287]


Epoch 2/20 Summary: 
Average Loss: 3.8596, Accuracy: 0.2873


Epoch 3/20: 100%|██████████| 459/459 [00:33<00:00, 13.78it/s, Loss=2.86, Accuracy=0.43]


Epoch 3/20 Summary: 
Average Loss: 2.9057, Accuracy: 0.4296


Epoch 4/20: 100%|██████████| 459/459 [00:32<00:00, 13.99it/s, Loss=2.48, Accuracy=0.558]


Epoch 4/20 Summary: 
Average Loss: 2.1742, Accuracy: 0.5578


Epoch 5/20: 100%|██████████| 459/459 [00:31<00:00, 14.38it/s, Loss=1.7, Accuracy=0.668]


Epoch 5/20 Summary: 
Average Loss: 1.6092, Accuracy: 0.6680


Epoch 6/20: 100%|██████████| 459/459 [00:32<00:00, 14.07it/s, Loss=0.521, Accuracy=0.765]


Epoch 6/20 Summary: 
Average Loss: 1.1698, Accuracy: 0.7648


Epoch 7/20: 100%|██████████| 459/459 [00:32<00:00, 14.22it/s, Loss=1.28, Accuracy=0.835]


Epoch 7/20 Summary: 
Average Loss: 0.8433, Accuracy: 0.8345


Epoch 8/20: 100%|██████████| 459/459 [00:31<00:00, 14.45it/s, Loss=0.469, Accuracy=0.886]


Epoch 8/20 Summary: 
Average Loss: 0.6012, Accuracy: 0.8863


Epoch 9/20: 100%|██████████| 459/459 [00:32<00:00, 13.98it/s, Loss=0.546, Accuracy=0.927]


Epoch 9/20 Summary: 
Average Loss: 0.4233, Accuracy: 0.9266


Epoch 10/20: 100%|██████████| 459/459 [00:32<00:00, 13.92it/s, Loss=0.396, Accuracy=0.952]


Epoch 10/20 Summary: 
Average Loss: 0.3022, Accuracy: 0.9516


Epoch 11/20: 100%|██████████| 459/459 [00:32<00:00, 13.95it/s, Loss=0.457, Accuracy=0.965]


Epoch 11/20 Summary: 
Average Loss: 0.2236, Accuracy: 0.9645


Epoch 12/20: 100%|██████████| 459/459 [00:31<00:00, 14.39it/s, Loss=0.149, Accuracy=0.971]


Epoch 12/20 Summary: 
Average Loss: 0.1745, Accuracy: 0.9708


Epoch 13/20: 100%|██████████| 459/459 [00:32<00:00, 13.94it/s, Loss=0.0803, Accuracy=0.973]


Epoch 13/20 Summary: 
Average Loss: 0.1456, Accuracy: 0.9730


Epoch 14/20: 100%|██████████| 459/459 [00:31<00:00, 14.37it/s, Loss=0.168, Accuracy=0.975]


Epoch 14/20 Summary: 
Average Loss: 0.1269, Accuracy: 0.9750


Epoch 15/20: 100%|██████████| 459/459 [00:32<00:00, 13.97it/s, Loss=0.248, Accuracy=0.975]


Epoch 15/20 Summary: 
Average Loss: 0.1192, Accuracy: 0.9746


Epoch 16/20: 100%|██████████| 459/459 [00:32<00:00, 14.25it/s, Loss=0.0514, Accuracy=0.975]


Epoch 16/20 Summary: 
Average Loss: 0.1126, Accuracy: 0.9751


Epoch 17/20: 100%|██████████| 459/459 [00:33<00:00, 13.58it/s, Loss=0.0438, Accuracy=0.975]


Epoch 17/20 Summary: 
Average Loss: 0.1124, Accuracy: 0.9746


Epoch 18/20: 100%|██████████| 459/459 [00:32<00:00, 14.18it/s, Loss=0.714, Accuracy=0.974]


Epoch 18/20 Summary: 
Average Loss: 0.1124, Accuracy: 0.9740


Epoch 19/20: 100%|██████████| 459/459 [00:32<00:00, 14.19it/s, Loss=0.0396, Accuracy=0.975]


Epoch 19/20 Summary: 
Average Loss: 0.1090, Accuracy: 0.9745


Epoch 20/20: 100%|██████████| 459/459 [00:32<00:00, 13.91it/s, Loss=0.0202, Accuracy=0.975]

Epoch 20/20 Summary: 
Average Loss: 0.1014, Accuracy: 0.9755





In [24]:
# Show the full vocabulary, most frequent word first
print("Words in the vocabulary:", [*vocab_to_int])

import torch.nn.functional as F

def generate_lyrics(model, start_words, length=30, temperature=1.0):
    """Generate ``length`` words after ``start_words`` by sampling from ``model``.

    The seed text is preprocessed like the training corpus. Seed words not
    in the vocabulary are kept in the output but are not fed to the model.
    Each step feeds the last ``seq_length`` known ids to the model and
    samples the next word from the softmax of the temperature-scaled output.

    Args:
        model: Trained network. It is expected to return one row of scores
            over word ids for a (1, seq) batch of ids.
        start_words: Seed text.
        length: Number of words to generate.
        temperature: Positive divisor applied to the scores before softmax.
            Values below 1 sharpen the distribution; values above 1
            flatten it.

    Returns:
        The seed words followed by the generated words, space-separated.

    Raises:
        ValueError: If ``temperature`` is not positive, or if none of the
            seed words are in the vocabulary.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature!r}")

    # Word ids start at 1 (id 0 is never assigned), so an id is NOT a
    # position in the 0-based ``vocab`` list; invert the mapping instead.
    id_to_word = {idx: word for word, idx in vocab_to_int.items()}

    # Convert input to lowercase to match preprocessed vocabulary
    words = preprocess_text(start_words.lower())

    # Only in-vocabulary seed words can be fed to the model
    state = [vocab_to_int[word] for word in words if word in vocab_to_int]
    if not state:
        raise ValueError(f"None of the words in '{start_words}' are in the vocabulary.")

    generated_words = list(words)  # Start with the input words

    was_training = model.training
    model.eval()
    try:
        # Inference only: skip building the autograd graph on every step.
        with torch.no_grad():
            for _ in range(length):
                state_tensor = torch.tensor([state[-seq_length:]], dtype=torch.long)
                logits = model(state_tensor) / temperature
                # Id 0 has no word attached, so it must never be sampled.
                logits[:, 0] = float('-inf')
                probabilities = F.softmax(logits, dim=1)

                # Sample from the distribution instead of taking the argmax
                predicted_idx = torch.multinomial(probabilities, 1).item()
                state.append(predicted_idx)
                generated_words.append(id_to_word[predicted_idx])
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    return ' '.join(generated_words)

# Test with temperature sampling
# Sample a 30-word continuation of a one-word seed. Dividing the scores by a
# temperature below 1 sharpens the softmax, giving more conservative output.
start_words = "my"
generated_lyrics = generate_lyrics(model, start_words, length=30, temperature=0.8)
print(generated_lyrics)



Words in the vocabulary: ['you', 'i', 'the', 'me', 'to', 'and', 'oh', 'my', 'im', 'it', 'that', 'be', 'love', 'dont', 'in', 'baby', 'your', 'yeah', 'a', 'no', 'all', 'one', 'we', 'girl', 'like', 'do', 'know', 'is', 'with', 'what', 'for', 'but', 'youre', 'so', 'on', 'cause', 'up', 'make', 'need', 'if', 'right', 'now', 'got', 'just', 'when', 'never', 'its', 'can', 'go', 'let', 'of', 'time', 'want', 'wanna', 'say', 'ill', 'down', 'this', 'us', 'only', 'she', 'smile', 'cant', 'heart', 'see', 'get', 'aint', 'tell', 'back', 'out', 'are', 'was', 'whoa', 'could', 'gonna', 'at', 'life', 'have', 'ooh', 'theres', 'nothing', 'should', 'hey', 'how', 'world', 'not', 'gotta', 'give', 'as', 'mean', 'here', 'take', 'would', 'mind', 'where', 'way', 'believe', 'thats', 'from', 'had', 'away', 'without', 'around', 'were', 'will', 'less', 'better', 'show', 'they', 'somebody', 'ever', 'madly', 'crazy', 'been', 'alright', 'live', 'more', 'think', 'day', 'about', 'mine', 'her', 'am', 'lonely', 'eyes', 'there',