## Setup

Import all packages we need and set the correct device. If a CUDA compatible GPU is found, it will be used. If not, everything will be done on CPU.

In [None]:
import re

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Run on a CUDA GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

## Handling the data

We make two classes for processing our data. The first class, TextEncoder, will be used to encode and decode text data, since we cannot use strings as input directly. All unique words will be found and mapped to an integer. Apart from the unique words we extract from the data, TextEncoder will also have an unknown token: [UNK]. This token will be used if, during inference, we encounter a word that is not part of our vocabulary.

Our second class, TextData, will handle the reading and sampling of our input data. It will use an instance of the TextEncoder class to encode our data. TextData also lets us sample sequences of text. In order to do this, TextData needs an implementation of the \_\_getitem\_\_ method, which will tell it how to handle indices. We also need a \_\_len\_\_ method so that our TextData class can work with PyTorch's DataLoader, but more on that later.

All you really need to understand for now though, is that these two classes read data from a txt file and encode it so we can use it for our model.

In [None]:
class TextEncoder():
    """Build a word-level vocabulary from a text file and map words to integer ids.

    Words are obtained by splitting the file on whitespace, lower-casing each
    piece and removing every character that is not an ASCII letter or digit.
    Pieces that are empty after cleaning (e.g. a lone "--") are discarded.

    Id 0 is reserved for the unknown token ``"[UNK]"``; the remaining words
    receive the ids ``1 .. vocab_size - 1`` in sorted order, so the mapping is
    identical every time the same file is processed.

    Attributes:
        vocab: set of unique cleaned words (the unknown token is not included).
        vocab_size: total number of ids, i.e. ``len(vocab) + 1``.
        encoder: dict mapping word -> id (includes ``"[UNK]" -> 0``).
        decoder: dict mapping id -> word (includes ``0 -> "[UNK]"``).
    """

    def __init__(self, file_path):
        """Read ``file_path`` and build the vocabulary and both lookup tables."""
        self.vocab = set()
        self.vocab_size = 0
        self.encoder = dict()
        self.decoder = dict()

        self._extract_vocab(file_path)
        self._make_encoder_decoder()

    def _extract_vocab(self, file_path):
        """Collect the set of unique cleaned words found in ``file_path``."""
        with open(file_path, "r") as file:
            text = file.read()

        cleaned_words = (
            re.sub('[^A-Za-z0-9]+', '', word.lower()) for word in text.split()
        )
        # Filter AFTER cleaning: str.split() never yields "", but stripping
        # punctuation can, and "" must not become a vocabulary entry.
        self.vocab = {word for word in cleaned_words if word}
        self.vocab_size = len(self.vocab) + 1  # add one for unknown word token

    def _make_encoder_decoder(self):
        """Assign ids to the vocabulary and fill ``encoder`` / ``decoder``."""
        # reserve 0 for unknown words
        self.encoder = {"[UNK]": 0}
        self.decoder = {0: "[UNK]"}

        # Iterating a set of strings has a process-dependent order; sorting
        # makes the word <-> id assignment reproducible across runs.
        for word_id, word in enumerate(sorted(self.vocab), start=1):
            self.encoder[word] = word_id
            self.decoder[word_id] = word

class TextData(Dataset):
    """Dataset of fixed-length token windows taken from an encoded text file.

    The file is read once, encoded with ``text_encoder`` and kept as a 1-D
    integer array. Item ``i`` is the window of ``seq_len`` tokens starting at
    position ``i`` together with the same window shifted one token to the
    right (the next-token targets).

    Args:
        file_path: path of the text file to read.
        text_encoder: object exposing an ``encoder`` dict (word -> id, with a
            ``"[UNK]"`` entry) and a ``decoder`` dict (id -> word).
        seq_len: number of tokens in every returned window.
    """

    def __init__(self, file_path, text_encoder, seq_len):
        super().__init__()

        self.seq_len = seq_len
        self.text_encoder = text_encoder
        self.text = self._read_text(file_path)
        self.encoded_text = self.encode_text(self.text)

    def __len__(self):
        """Return the number of windows that have a full-length target.

        The last usable start index ``i`` must satisfy
        ``i + seq_len + 1 <= len(encoded_text)``, which leaves
        ``len(encoded_text) - seq_len`` windows. The result is clamped at 0 so
        a text shorter than ``seq_len`` yields an empty dataset instead of a
        negative length.
        """
        return max(len(self.encoded_text) - self.seq_len, 0)

    def __getitem__(self, index):
        """Return the input window at ``index`` and its one-step-shifted targets."""
        return {
            "sequence": self.encoded_text[index : index + self.seq_len],
            "next_tokens": self.encoded_text[index + 1 : index + self.seq_len + 1],
        }

    def _read_text(self, file_path):
        """Return the full content of ``file_path`` as one string."""
        with open(file_path, "r") as file:
            return file.read()

    def encode_text(self, text):
        """Convert ``text`` into a 1-D int64 array of token ids.

        Each whitespace-separated piece is lower-cased and stripped of all
        characters other than ASCII letters and digits. Pieces that are empty
        after cleaning are skipped; cleaned words missing from the encoder map
        to the ``"[UNK]"`` id.
        """
        encoder = self.text_encoder.encoder
        unknown_id = encoder["[UNK]"]

        encoded_words = []
        for word in text.split():
            cleaned = re.sub('[^A-Za-z0-9]+', '', word.lower())
            if not cleaned:
                continue
            # Look up the CLEANED word: "fact," must resolve to "fact", not [UNK].
            encoded_words.append(encoder.get(cleaned, unknown_id))

        # Explicit dtype so that an empty result is still an integer array.
        return np.asarray(encoded_words, dtype=np.int64)

    def decode_text(self, tokens):
        """Convert an iterable of token ids back into a space-joined string."""
        decoder = self.text_encoder.decoder
        # int() lets NumPy integers and 0-dim tensors be used as dict keys too.
        return " ".join(decoder[int(token)] for token in tokens)


## Model

Once we can process our data, we need a model. In this notebook we will train an LSTM with pytorch, which we define as a class here. We choose which layers we want our model to have and define the forward pass, as well as a function that generates text.

The model consists of three parts: 

<ol>
    <li> Embedding layer. This part learns vector representations for each word in the vocabulary.
    <li> LSTM. This is the recurrent part that models text.
    <li> Linear classifier. This is an extra layer after the LSTM that makes the prediction.
</ol>

In [None]:
class LSTM(nn.Module):
    """Word-level language model: embedding -> LSTM -> linear classifier.

    Args:
        vocab_size: number of distinct token ids (size of the output layer).
        embedding_dim: size of each learned word vector.
        hidden_dim: size of the LSTM hidden and cell state.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        # embedding
        self.embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

        # recurrent network
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)

        # classification head
        self.linear_classifier = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x: torch.Tensor, previous_state=None):
        """Run the model on a batch of token ids of shape (batch, seq).

        Args:
            x: integer tensor of token ids, shape (batch, seq).
            previous_state: optional ``(h_0, c_0)`` tuple for the LSTM. When
                omitted the LSTM starts from its default zero state.

        Returns:
            Without ``previous_state`` (training):
                ``(logits_all, logits_n)`` where ``logits_all`` has shape
                (batch, vocab, seq) - the layout ``nn.CrossEntropyLoss``
                expects - and ``logits_n`` has shape (batch, vocab).
            With ``previous_state`` (generation):
                ``(logits_n, (h_n, c_n))`` so the caller can carry the
                recurrent state into the next step.
        """
        embeddings = self.embedding_layer(x)
        h_all, (h_n, c_n) = self.lstm(embeddings, previous_state)
        logits_all = self.linear_classifier(h_all)

        # logits of the last time step only: (batch, vocab)
        logits_n = logits_all[:, -1, :]

        # in the case we are generating we have to keep track of the hidden (cell) state
        if previous_state is not None:
            return logits_n, (h_n, c_n)

        # move the class dimension to position 1 for the loss function
        return logits_all.permute(0, 2, 1), logits_n

    def generate(self, x, gen_len=32, sample=False, temperature=1.0):
        """Extend the prompt ``x`` by ``gen_len`` tokens.

        Args:
            x: integer tensor of prompt token ids, shape (batch, seq).
            gen_len: number of tokens to append.
            sample: if True draw each token from the softmax distribution,
                otherwise pick the most probable token (greedy).
            temperature: divisor applied to the logits before the softmax;
                must be non-zero. Only affects the result when sampling.

        Returns:
            The prompt followed by the generated ids as a Python list (a flat
            list for batch size 1, a nested list otherwise).
        """
        # The prompt has to live on the same device as the model weights.
        x = x.to(self.embedding_layer.weight.device)
        state = self.init_state(batch_size=x.size(0))

        # The first step consumes the whole prompt. Afterwards the recurrent
        # state already summarises everything seen so far, so only the newly
        # generated token is fed in.
        step_input = x

        with torch.no_grad():  # inference only, no autograd graph needed
            for _ in range(gen_len):
                logits_n, state = self.forward(step_input, previous_state=state)

                # control how random we want to be in our sampling
                logits_n = logits_n / temperature

                # get a probability distribution over your outputs (vocabulary)
                probs = torch.softmax(logits_n, dim=1)

                # both branches yield shape (batch, 1)
                if sample:
                    next_word = torch.multinomial(probs, num_samples=1)
                else:
                    next_word = torch.argmax(probs, dim=1, keepdim=True)

                # add the next word to your train of token inputs and repeat the process
                x = torch.cat((x, next_word), dim=1)
                step_input = next_word

        return x.squeeze(0).cpu().tolist()

    def init_state(self, batch_size=1):
        """
        When generating we want to initialize the hidden (cell) states to 0.
        See it as generating from a blank slate.

        The tensors have shape (num_layers, batch_size, hidden_size) and are
        created with the device and dtype of the LSTM weights.
        """
        reference = self.lstm.weight_ih_l0
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        return (torch.zeros(shape, device=reference.device, dtype=reference.dtype),
                torch.zeros(shape, device=reference.device, dtype=reference.dtype))

## Hyperparameters

Here we set some hyperparameters that will define how the model is trained. You can leave them as is or play with them and see what happens.

In [None]:
# --- model architecture (passed to the LSTM constructor) ---
EMBEDDING_DIM = 32    # size of each word vector
HIDDEN_DIM = 512      # size of the LSTM hidden / cell state
NUM_LAYERS = 1        # number of stacked LSTM layers

# --- data ---
SEQUENCE_LEN = 32     # tokens per training window (TextData seq_len)
BATCH_SIZE = 1028     # windows per DataLoader batch

# --- optimisation ---
EPOCHS = 20           # full passes over the dataloader in train_loop
LEARNING_RATE = 5e-4  # learning rate handed to the Adam optimizer

## Make some functions

The most important function we're creating here is train\_loop, as this is how we train our model. The generate function is used to generate text. It is called at the end of every epoch, so we can see the model improve. The last function here is print\_model\_size, which we simply add to give you a sense of how big such a model is. If you do play around with the hyperparameters above, you can see how this influences not only the performance, but also the size. 

We add a separate cell where you can change the DATA\_FILE\_PATH to any of the txt files in the Data folder, that way you can choose which data you want to train on. You can also type the prompt that is used to generate a sentence each epoch. Keep in mind the model will only recognize words it has seen in the txt file you train on!

In [None]:
# Text that seeds the sample sentence printed after every training epoch.
PROMPT = "It is a long established fact that"
# Training corpus: used both to build the vocabulary and as the dataset.
DATA_FILE_PATH = "./Data/dummy_text.txt"

In [None]:
def train_loop(model, dataloader):
    """Train ``model`` on ``dataloader`` for ``EPOCHS`` epochs with Adam.

    Every batch is expected to be a dict with the keys ``"sequence"`` (inputs)
    and ``"next_tokens"`` (targets). The running mean of the epoch's losses is
    printed every ninth batch, and after each epoch a sentence generated from
    ``PROMPT`` is printed.
    """
    criterion = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), LEARNING_RATE)

    for epoch in range(EPOCHS):
        losses = []
        for i, batch in enumerate(dataloader):
            # generate_sentence() switches the model to eval mode, so training
            # mode is (re-)enabled before every step.
            model.train()

            inputs, targets = (
                batch[key].to(device) for key in ("sequence", "next_tokens")
            )

            adam.zero_grad()
            logits_all, _ = model(inputs)
            batch_loss = criterion(logits_all, targets)
            batch_loss.backward()
            adam.step()

            losses.append(batch_loss.item())

            # report only on batches 0, 9, 18, ...
            if i % 9:
                continue
            print(f"[{epoch + 1}, {i + 1:5d}] train loss: {np.mean(losses):.3f}")

        print(generate_sentence(model, dataloader, PROMPT, 32))
    print("Finished training.")

def generate_sentence(model, dataloader, sentence, gen_len):
    """Continue ``sentence`` by ``gen_len`` greedily chosen tokens.

    Args:
        model: module exposing ``generate(x, gen_len, sample, temperature)``.
        dataloader: its ``dataset`` must provide ``encode_text`` and
            ``decode_text``.
        sentence: prompt text to continue.
        gen_len: number of tokens to generate.

    Returns:
        The decoded prompt plus the generated continuation as one string.

    Raises:
        ValueError: if ``sentence`` encodes to zero tokens.
    """
    model.eval()

    # tokenize your prompt text
    token_ids = np.asarray(dataloader.dataset.encode_text(sentence))
    if token_ids.size == 0:
        raise ValueError("The prompt does not contain any tokens to start generating from.")

    # Build the (1, seq) batch directly from the array (no list-of-arrays
    # conversion) and force integer ids, as required by the embedding lookup.
    x = torch.as_tensor(token_ids, dtype=torch.long).unsqueeze(0)

    # The prompt has to be on the same device as the model weights.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        x = x.to(first_param.device)

    # generate tokens (inference only, so no gradients are tracked)
    with torch.no_grad():
        generated_tokens = model.generate(x, gen_len=gen_len, sample=False, temperature=1.0)

    # decode the tokens back to normal text
    return dataloader.dataset.decode_text(generated_tokens)
    
def print_model_size(model):
    """Print the number of parameters of ``model`` and its size in MB.

    The size counts the bytes of all parameters and all buffers
    (element count times element size) and is reported in units of 1024**2 bytes.
    """
    def _byte_count(tensors):
        # bytes occupied by a collection of tensors
        return sum(t.nelement() * t.element_size() for t in tensors)

    total_bytes = _byte_count(model.parameters()) + _byte_count(model.buffers())
    size = total_bytes / 1024**2

    param_count = sum([np.prod(p.size()) for p in model.parameters()])
    print("model number of params: ", param_count)
    print("model size: {:.3f}MB".format(size))

## Train your LSTM!

Here we call the individual parts we defined above and actually do the training.

In [None]:
# Build the vocabulary and the windowed dataset from the same text file.
text_encoder = TextEncoder(DATA_FILE_PATH)
dataset = TextData(DATA_FILE_PATH, text_encoder, SEQUENCE_LEN)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)

# Create the model on the selected device and report its layout and size.
model = LSTM(
    vocab_size=text_encoder.vocab_size,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
)
model = model.to(device)
print(model)
print_model_size(model)

# Train; a sample sentence is printed after every epoch.
train_loop(model, dataloader)