#### RNN Language Model

Given a sequence of word embedding vectors $w_1, \dots, w_N$, we pass the sequence through a `uni-directional RNN` to obtain a sequence of hidden states $h_1, \dots, h_N$. Each hidden state $h_i$ can be regarded as a contextual representation of the words $w_1, \dots, w_i$. Using this hidden state, we can therefore compute a probability distribution over the next word given all the preceding words:

$P(w_{i+1} \mid w_1, \dots, w_i) = f(h_i)$

where $f$ is a function that transforms $h_i$ into the probability distribution. $f$ can be a feedforward network, in the simplest case a linear projection followed by a softmax. Also note that we use a uni-directional RNN (and not a bi-directional one) because a language model must predict the next word using only the previous words as context: in a bi-directional RNN, $h_i$ would also depend on the words after position $i$, including the very word $w_{i+1}$ it is supposed to predict.

Optional: The performance of an RNN language model can often be further improved by choosing the embedding dimension and the RNN hidden state dimension to be the same. This allows us to re-use the embedding matrix for the linear projection of the hidden states into the output logits, instead of learning a separate projection matrix. This saves $|V| \times d$ parameters (vocabulary size times embedding dimension) and can reduce overfitting. This technique is called `weight tying`. 

Previously we looked at simple n-gram language models, which are only feasible for small $n$ (i.e. short context sizes), because the number of possible n-grams grows exponentially with $n$. With an RNN, the hidden state can in principle summarize the entire preceding context, so we can expect better performance (e.g. lower perplexity than n-gram LMs).

We will train a word-level RNN LM on the collected works of Shakespeare.

In [28]:
import re
import random
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import psutil

In [11]:
# prep the training data
with open('shakespeare.txt', 'r', encoding='utf-8') as file:
    lines = file.readlines()

# remove all punctuation (except apostrophes), strip leading/trailing whitespace
# (including the trailing newline) and lowercase; lines that end up empty are dropped
non_word_pattern = re.compile(r"[^\w\s']")
sentences_clean = []
for line in lines:
    cleaned = non_word_pattern.sub('', line).strip().lower()
    if cleaned:
        sentences_clean.append(cleaned)

# word_tokenize the sentences (split on whitespace) and add start and end sentence tokens
start_token = '<s>'
end_token = '</s>'
sentences_tokenized = [[start_token] + s.split() + [end_token] for s in sentences_clean]
print(f"Num sentences: {len(sentences_tokenized)}")

# now we split the data into train and test (used as validation) sentences,
# holding out 10% of them at random.
# NOTE: the RNG is not seeded here, so the split differs between runs; call
# random.seed(...) beforehand if a reproducible split is needed.
num_sent = len(sentences_tokenized)
num_test = int(0.1 * num_sent)
test_idx = random.sample(range(num_sent), num_test)
# membership is checked once per sentence: a set makes each check O(1)
# instead of O(num_test) for the list, avoiding an O(num_sent * num_test) loop
test_idx_set = set(test_idx)

sentences_train = []
sentences_val = []
for i, sentence in enumerate(sentences_tokenized):
    if i in test_idx_set:
        sentences_val.append(sentence)
    else:
        sentences_train.append(sentence)

print(f"Number of training sentences: {len(sentences_train)}")
print(f"Number of test sentences: {len(sentences_val)}")

Num sentences: 32777
Number of training sentences: 29500
Number of test sentences: 3277


In [25]:
# create vocabulary: the padding token is placed first so it receives id 0,
# followed by every distinct token (from train AND val sentences) in sorted order
pad_token = "<PAD>"
distinct_tokens = {tok for sent in sentences_tokenized for tok in sent}
vocab = [pad_token, *sorted(distinct_tokens)]
word2idx = dict(zip(vocab, range(len(vocab))))
vocab_size = len(vocab)

# encode each sentence as a list of integer token ids
x_train = [[word2idx[tok] for tok in sent] for sent in sentences_train]
x_val = [[word2idx[tok] for tok in sent] for sent in sentences_val]

# length (in tokens, including <s> and </s>) of the longest sentence per split
max_len_train = max(map(len, x_train))
max_len_val = max(map(len, x_val))

print(f"Longest train sentence: {max_len_train}")
print(f"Longest val sentence: {max_len_val}")

Longest train sentence: 18
Longest val sentence: 17


#### Create a pytorch dataset

In [26]:
class Shakespeare(Dataset):
    """Next-word-prediction dataset built from id-encoded sentences.

    Each sentence ``s`` yields an ``(input, target)`` pair with
    ``input = s[:-1]`` and ``target = s[1:]``, so the target at position ``t``
    is the token that follows input position ``t``. All pairs are right-padded
    to the length of the longest one.

    Args:
        x: list of sentences, each a list of integer token ids.
        input_padding_value: id used to pad the inputs. Default 0, which is the
            index of ``<PAD>`` in ``vocab``.
        target_padding_value: value used to pad the targets. It should match the
            ``ignore_index`` used by the loss; the default -1 matches RNNLM's
            default ``padding_idx``.
    """

    def __init__(self, x, input_padding_value=0, target_padding_value=-1):
        inputs = [torch.tensor(sent[:-1], dtype=torch.long) for sent in x]
        targets = [torch.tensor(sent[1:], dtype=torch.long) for sent in x]
        self.inputs = self._pad(inputs, input_padding_value)
        self.targets = self._pad(targets, target_padding_value)

    @staticmethod
    def _pad(seqs, value):
        """Right-pad 1-D tensors into an (N, max_len) tensor; an empty list gives (0, 0)."""
        if not seqs:
            # pad_sequence rejects an empty list, so build the empty batch directly
            return torch.empty((0, 0), dtype=torch.long)
        return pad_sequence(seqs, batch_first=True, padding_value=value)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

In [27]:
# wrap the id-encoded splits as (input, target) next-word-prediction datasets
train_dataset, val_dataset = Shakespeare(x_train), Shakespeare(x_val)

#### Create the RNN LM

In [None]:
class RNNLM(torch.nn.Module):
    """Word-level language model: embedding -> uni-directional LSTM -> tied output layer.

    The LSTM hidden size equals the embedding size, which lets the output
    projection share (tie) its weight matrix with the embedding layer.

    Args:
        vocab_size: number of tokens in the vocabulary.
        embedding_dim: size of the word embeddings and of the LSTM hidden state.
        num_rnn_layers: number of stacked LSTM layers.
        dropout_rate: dropout applied to the embeddings and to the LSTM outputs,
            and (only when num_rnn_layers > 1) between stacked LSTM layers.
        padding_idx: TARGET value ignored by the cross-entropy loss (the
            Shakespeare dataset pads targets with -1). This is not the vocabulary
            index of the padding token.
    """

    def __init__(self, vocab_size, embedding_dim=32, num_rnn_layers=1, dropout_rate=0.1, padding_idx=-1):
        super().__init__()
        self.padding_idx = padding_idx

        # embedding layer, initialised uniformly in [-c, c]
        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)
        c = 0.1
        torch.nn.init.uniform_(self.emb.weight, -c, c)

        # uni-directional LSTM: a language model may only look at preceding tokens,
        # so the output hidden states have dims=embedding_dim (not 2*embedding_dim).
        # LSTM's own dropout acts only between stacked layers; PyTorch warns when it
        # is non-zero with a single layer, hence 0.0 in that case.
        self.rnn_layers = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=embedding_dim,
            num_layers=num_rnn_layers,
            batch_first=True,
            dropout=dropout_rate if num_rnn_layers > 1 else 0.0,
        )
        self.dropout = torch.nn.Dropout(dropout_rate)
        # output layer (computes output class logits for each item in sequence)
        self.output_layer = torch.nn.Linear(embedding_dim, vocab_size)
        # tie the output layer weights with the embedding layer weights
        # (possible because hidden size == embedding size)
        self.output_layer.weight = self.emb.weight

    def forward(self, x, y=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            x: LongTensor of input token ids, shape (B, L).
            y: optional LongTensor of target token ids, shape (B, L). Positions
                equal to ``self.padding_idx`` are ignored by the loss. When None,
                no loss is computed.

        Returns:
            ``(logits, loss)``. ``logits`` are flattened to shape
            (B*L, vocab_size). ``loss`` is a scalar tensor, or None when ``y``
            is None.
        """
        # get embeddings for batch of input sequences of length L
        x = self.emb(x)  # shape: (B,L,D)
        x = self.dropout(x)
        # compute rnn hidden states (uni-directional, so still D features)
        x, _ = self.rnn_layers(x)  # shape: (B,L,D)
        x = self.dropout(x)
        # compute output logits
        logits = self.output_layer(x)  # shape: (B,L,vocab_size)
        # flatten with reshape (not view) so non-contiguous tensors also work
        logits = logits.reshape(-1, logits.shape[-1])  # shape: (B*L,vocab_size)
        if y is None:
            return logits, None
        # compute cross entropy loss, skipping padded target positions
        loss = F.cross_entropy(logits, y.reshape(-1), ignore_index=self.padding_idx)

        return logits, loss
        