12 days of christmas - creating a Long Short-Term Memory (LSTM) model from christmas movie synopsis

Source: https://debuggercafe.com/word-level-text-generation-using-lstm/
For random selection of words: https://agrimpaneru.com.np/blog/lstm-from-scratch/

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from torch.utils.data import Dataset, DataLoader
from collections import Counter

In [3]:
# Dataset Preparation
with open('/home/sam_hp/rnn-christmas-movies/data/christmas_movies.txt', 'r', encoding='utf-8') as file:
    text = file.read()
# Tokenize the text into words
words = text.split()

word_counts = Counter(words)
vocab = list(word_counts.keys())
vocab_size = len(vocab)
word_to_int = {word: i for i, word in enumerate(vocab)}
int_to_word = {i: word for word, i in word_to_int.items()}
SEQUENCE_LENGTH = 64
samples = [words[i:i+SEQUENCE_LENGTH+1] for i in range(len(words)-SEQUENCE_LENGTH)]
print(vocab)
print(word_to_int)
print(int_to_word)

['Follows', 'the', 'lives', 'of', 'eight', 'very', 'different', 'couples', 'in', 'dealing', 'with', 'their', 'love', 'various', 'loosely', 'interrelated', 'tales', 'all', 'set', 'during', 'a', 'frantic', 'month', 'before', 'Christmas', 'London,', 'England.', 'An', 'eight-year-old', 'troublemaker,', 'mistakenly', 'left', 'home', 'alone,', 'must', 'defend', 'his', 'against', 'pair', 'burglars', 'on', 'Eve.', 'The', 'Griswold', "family's", 'plans', 'for', 'big', 'family', 'predictably', 'turn', 'into', 'disaster.', 'Raised', 'as', 'an', 'oversized', 'elf,', 'Buddy', 'travels', 'from', 'North', 'Pole', 'to', 'New', 'York', 'City', 'meet', 'biological', 'father,', 'Walter', 'Hobbs,', 'who', "doesn't", 'know', 'he', 'exists', 'and', 'is', 'desperate', 'need', 'some', 'spirit.', 'On', 'outskirts', 'Whoville', 'green,', 'revenge-seeking', 'Grinch', 'ruin', 'citizens', 'town.', 'A', 'grumpy', 'plots', 'village', 'Whoville.', 'police', 'officer', 'tries', 'save', 'estranged', 'wife', 'several', 

Creating data loaders

In [6]:
class TextDataset(Dataset):
    """
    A class that accepts the samples which is a list of lists containing sequences of 64 words each.
    Also accepts word_to_int dictionary for mapping to integers.
    """
    def __init__(self, samples, word_to_int):
        self.samples = samples
        self.word_to_int = word_to_int
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        sample = self.samples[idx]
        input_seq = torch.LongTensor([self.word_to_int[word] for word in sample[:-1]])
        target_seq = torch.LongTensor([self.word_to_int[word] for word in sample[1:]])
        return input_seq, target_seq

In [7]:
BATCH_SIZE = 32
dataset = TextDataset(samples, word_to_int)
dataloader = DataLoader(
    dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
)
print(dataset[1])

(tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12,  2,  8, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24,  8, 25, 26, 27, 28, 29, 30, 31, 32, 33,
        34, 35, 36, 32, 37, 20, 38,  3, 39, 40, 24, 41, 42, 43, 44, 45, 46, 20,
        47, 48, 24, 49, 50, 51, 20, 47, 52, 53]), tensor([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12,  2,  8, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24,  8, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
        35, 36, 32, 37, 20, 38,  3, 39, 40, 24, 41, 42, 43, 44, 45, 46, 20, 47,
        48, 24, 49, 50, 51, 20, 47, 52, 53, 54]))


The LSTM model

In [5]:
class TextGenerationLSTM(nn.Module):
    def __init__(
        self,
        vocab_size,
        embedding_dim,
        hidden_size,
        num_layers
    ):
        super(TextGenerationLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim, 
            hidden_size=hidden_size, 
            num_layers=num_layers, 
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
    def forward(self, x, hidden=None):
        if hidden == None:
            hidden = self.init_hidden(x.shape[0])
        x = self.embedding(x)
        out, (h_n, c_n) = self.lstm(x, hidden)
        out = out.contiguous().view(-1, self.hidden_size)
        out = self.fc(out)
        return out, (h_n, c_n)
    def init_hidden(self, batch_size):
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
        return h0, c0

Training Hyperparameters

In [6]:
# Training Setup
embedding_dim = 16
hidden_size = 32
num_layers = 1
learning_rate = 0.01
epochs = 50

Training the LSTM model

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TextGenerationLSTM(
    vocab_size, 
    embedding_dim, 
    hidden_size, 
    num_layers
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [8]:
# Training
def train(model, epochs, dataloader, criterion):
    model.train()
    for epoch in range(epochs):
        running_loss = 0
        for input_seq, target_seq in dataloader:
            input_seq, target_seq = input_seq.to(device), target_seq.to(device)
            outputs, _ = model(input_seq)
            loss = criterion(outputs, target_seq.view(-1))
    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.detach().cpu().numpy()
        epoch_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch} loss: {epoch_loss:.3f}")
train(model, epochs, dataloader, criterion)

Epoch 0 loss: 4.402
Epoch 1 loss: 2.568
Epoch 2 loss: 2.036
Epoch 3 loss: 1.752
Epoch 4 loss: 1.571
Epoch 5 loss: 1.441
Epoch 6 loss: 1.343
Epoch 7 loss: 1.265
Epoch 8 loss: 1.207
Epoch 9 loss: 1.157
Epoch 10 loss: 1.120
Epoch 11 loss: 1.084
Epoch 12 loss: 1.055
Epoch 13 loss: 1.029
Epoch 14 loss: 1.009
Epoch 15 loss: 0.986
Epoch 16 loss: 0.969
Epoch 17 loss: 0.953
Epoch 18 loss: 0.939
Epoch 19 loss: 0.926
Epoch 20 loss: 0.912
Epoch 21 loss: 0.901
Epoch 22 loss: 0.893
Epoch 23 loss: 0.881
Epoch 24 loss: 0.876
Epoch 25 loss: 0.866
Epoch 26 loss: 0.860
Epoch 27 loss: 0.851
Epoch 28 loss: 0.846
Epoch 29 loss: 0.840
Epoch 30 loss: 0.835
Epoch 31 loss: 0.827
Epoch 32 loss: 0.820
Epoch 33 loss: 0.819
Epoch 34 loss: 0.812
Epoch 35 loss: 0.815
Epoch 36 loss: 0.800
Epoch 37 loss: 0.802
Epoch 38 loss: 0.795
Epoch 39 loss: 0.792
Epoch 40 loss: 0.789
Epoch 41 loss: 0.781
Epoch 42 loss: 0.781
Epoch 43 loss: 0.777
Epoch 44 loss: 0.777
Epoch 45 loss: 0.768
Epoch 46 loss: 0.768
Epoch 47 loss: 0.771
Ep

In [7]:
# Inference
def generate_text(model, start_string, num_words):
    model.eval()
    words = start_string.split()
    for _ in range(num_words):
        input_seq = torch.LongTensor([word_to_int[word] for word in words[-SEQUENCE_LENGTH:]]).unsqueeze(0).to(device)
        h, c = model.init_hidden(1)
        output, (h, c) = model(input_seq, (h, c))

        next_token_prob = torch.softmax(output, dim = -1)
        next_token_index = torch.multinomial(next_token_prob, num_samples=1)[-1].item()
        words.append(int_to_word[next_token_index])
    return " ".join(words)



In [237]:
# Example usage:
generate_text(model, start_string="Christmas ", num_words=40)

'Christmas decorator fulfilling three sisters toys and her family. A story from Hudson, Wisconsin on a reality cooking show, she feels a strong attraction and decides to reunite her age just before Christmas, well a stranger. The legend of Santa Claus'

In [233]:
torch.save(model.state_dict(), "single-word-weights.pth")

In [8]:
model2 = TextGenerationLSTM(
    vocab_size, 
    embedding_dim, 
    hidden_size, 
    num_layers)
model2.load_state_dict(torch.load("single-word-weights.pth"))
model2.eval()

TextGenerationLSTM(
  (embedding): Embedding(6554, 16)
  (lstm): LSTM(16, 32, batch_first=True)
  (fc): Linear(in_features=32, out_features=6554, bias=True)
)

In [15]:
generate_text(model2, start_string="Once ", num_words=40)

"Once be true lot for Christmas, Kate and Teddy Pierce, whose save family's Colorado ranch, a promotion and unexpected guest and some turn hostile, leading to suspicion of foul play. Darcy's been striving for Christmas. Ellie does around the latest of"

In [2]:
list(range(0, 10))

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]