#### RNN Language Model

Given a sequence of word embedding vectors $w_1, .., w_N$, we pass the sequence through a `uni-directional RNN` to obtain a sequence of hidden states $h_1,..,h_N$. Then note that each hidden state $h_i$ can be regarded as a contextual repsentation of the words $w_1, .., w_i$. So using this hidden state, we can compute a probability distribution for the next word that follows all the preceding words:

$P(w_{i+1} | w_1,...,w_i) = f(h_i)$

where $f$ is a function that transforms $h_i$ into the probability distribution. $f$ can be a feedforward network, in the simplest case a linear projection followed by a softmax. Also note that we use a uni-directional RNN (and not bi-directional) because for a language model, we want to predict the next word using only the previous words as context.

Optional: The performance of an RNN model can be further improved if we choose the embedding dimensions and the RNN hidden state dimensions to be the same. This allows us to then re-use the embedding matrix to perform the linear projection of the hidden states into the output logits instead of using a separate projection matrix and therefore saves a lot of extra parameters and potentially reduces overfitting. This technique is also called `weight tying`. 

Previously we looked at simple n-gram language models which are only feasilble for small $n$, i.e. shorter context size. With an RNN, we have access to much larger contexts and therefore we can get better performance (e.g. lower perplexity compared to n-gram LMs).

We will train a word-level RNN LM on the collected works of Shakespeare.

In [56]:
import random
from nltk.tokenize import word_tokenize
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import psutil

In [57]:
# prep the training data
with open('shakespeare.txt', 'r') as file:
    lines = file.readlines()

In [63]:
# word_tokenize the sentences (split on whitespaces) and add start and end sentence tokens, keep punctuations as individual tokens
start_token = '<s>'        
end_token = '</s>'        
sentences_tokenized = [[start_token]+word_tokenize(s.lower())+[end_token] for s in lines]
print(f"Num sentences: {len(sentences_tokenized)}")    

# now we split the data into train and test sentences
num_sent = len(sentences_tokenized)
num_test = int(0.1 * num_sent)
test_idx = random.sample(range(num_sent), num_test)

sentences_train = []
sentences_val = []
for i in range(num_sent):
    if i not in test_idx:
        sentences_train.append(sentences_tokenized[i])
    else:
        sentences_val.append(sentences_tokenized[i])    

print(f"Number of training sentences: {len(sentences_train)}")        
print(f"Number of test sentences: {len(sentences_val)}")   

Num sentences: 40000
Number of training sentences: 36000
Number of test sentences: 4000


In [64]:
sentences_train[:10]

[['<s>', 'first', 'citizen', ':', '</s>'],
 ['<s>',
  'before',
  'we',
  'proceed',
  'any',
  'further',
  ',',
  'hear',
  'me',
  'speak',
  '.',
  '</s>'],
 ['<s>', '</s>'],
 ['<s>', 'all', ':', '</s>'],
 ['<s>', 'speak', ',', 'speak', '.', '</s>'],
 ['<s>', '</s>'],
 ['<s>', 'first', 'citizen', ':', '</s>'],
 ['<s>',
  'you',
  'are',
  'all',
  'resolved',
  'rather',
  'to',
  'die',
  'than',
  'to',
  'famish',
  '?',
  '</s>'],
 ['<s>', '</s>'],
 ['<s>', 'all', ':', '</s>']]

In [65]:
# create vocabulary
pad_token = "<PAD>"
vocab = [pad_token] + sorted(list(set([w for s in sentences_tokenized for w in s])))
word2idx = {w:i for i,w in enumerate(vocab)}
vocab_size = len(vocab)

# tokenize the sentences 
x_train = [[word2idx[word] for word in s] for s in sentences_train]
x_val = [[word2idx[word] for word in s] for s in sentences_val]

max_len_train = max([len(s) for s in x_train])
max_len_val = max([len(s) for s in x_val])

print(f"Longest train sentence: {max_len_train}")
print(f"Longest val sentence: {max_len_val}")

Longest train sentence: 22
Longest val sentence: 20


#### Create a pytorch dataset

In [66]:
class Shakespeare(Dataset):
    def __init__(self, x):
        inputs = [s[:-1] for s in x]
        targets = [s[1:] for s in x]
        self.inputs = pad_sequence([torch.tensor(x, dtype=torch.long) for x in inputs], batch_first=True, padding_value=0)
        self.targets = pad_sequence([torch.tensor(y, dtype=torch.long) for y in targets], batch_first=True, padding_value=-1)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

In [67]:
train_dataset = Shakespeare(x_train)
val_dataset = Shakespeare(x_val)

#### Create the RNN LM

In [68]:
class RNNLM(torch.nn.Module):
    def __init__(self, vocab_size, embedding_dim=32, num_rnn_layers=1, dropout_rate=0.1, padding_idx=-1):
        super().__init__()
        self.padding_idx = padding_idx

        # embedding layer
        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)
        c = 0.1        
        torch.nn.init.uniform_(self.emb.weight, -c, c)

        # create rnn layers (we will use bidirectional LSTM so the output hidden states will have dims=2*hidden_dims)
        if num_rnn_layers == 1:
            self.rnn_layers = torch.nn.LSTM(input_size=embedding_dim, hidden_size=embedding_dim, num_layers=num_rnn_layers, batch_first=True)
        else:    
            self.rnn_layers = torch.nn.LSTM(input_size=embedding_dim, hidden_size=embedding_dim, num_layers=num_rnn_layers, batch_first=True, dropout=dropout_rate)
        self.dropout = torch.nn.Dropout(dropout_rate)
        # create output layer (computes output class logits for each item in sequence)
        self.output_layer =  torch.nn.Linear(embedding_dim, vocab_size)
        # tie the output layer weights with the embedding layer weights
        self.output_layer.weight = self.emb.weight

    # forward pass
    def forward(self, x, y=None):
        # get embeddings for batch of input sequences of length L
        x = self.emb(x) # shape: (B,L,D)
        # apply dropout
        x = self.dropout(x)
        # compute rnn hidden states
        x, _ = self.rnn_layers(x) # shape: (B,L,D)
        # apply dropout
        x = self.dropout(x)
        # compute output logits
        x = self.output_layer(x) # shape: (B,L,vocab_size)

        if y==None:
            return x

        # reshape
        x = x.view(-1,x.shape[-1]) # shape: (B*L,vocab_size)
        y = y.view(-1) # shape: (B*L,)
        # compute cross entropy loss
        loss = F.cross_entropy(x, y, ignore_index=self.padding_idx)
        return x, loss
    
    @torch.no_grad()
    def generate(self, vocab, word2idx, start_token="<s>", end_token="</s>", max_len=30, device="cpu"):
        self.eval()
        # generate one word at a time
        x = torch.full(size=(1,1), fill_value=word2idx[start_token], dtype=torch.long, device=device)
        for _ in range(max_len):
            logits = self.forward(x) # shape: (1,L,V)
            # sample from the distribution for the last word in the sequence
            p = F.softmax(logits[0,-1,:]) # shape: (V,)
            next_word_idx = torch.multinomial(p, num_samples=1)
            if next_word_idx == word2idx[end_token]:
                break
            # append to the sequence
            x = torch.cat((x, next_word_idx.view(1,1)), dim=1)
        self.train()
        
        # convert integer tokens to words
        words = x.view(-1).tolist()
        words = [vocab[w] for w in words[1:]]
        return " ".join(words)

# training loop
def train(model, optimizer, scheduler, train_dataloader, val_dataloader, device="cpu", num_epochs=10):
    avg_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for batch in pbar:
            inputs, targets = batch
            # move batch to device
            inputs, targets = inputs.to(device), targets.to(device)
            # reset gradients
            optimizer.zero_grad()
            # forward pass
            logits, loss = model(inputs, targets)
            # backward pass
            loss.backward()
            # optimizer step
            optimizer.step()
            
            avg_loss = 0.9* avg_loss + 0.1*loss.item()
            
            y_pred = logits.argmax(dim=-1) # shape (B*L)
            y = targets.view(-1) # shape (B*L)
            mask = (y != -1)
            num_correct += (torch.eq(y[mask], y_pred[mask])).sum().item()
            num_total += len(y[mask])
            
            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}")  

        scheduler.step()
        train_acc = num_correct / num_total        
        # compute validation loss
        val_loss, val_acc = validation(model, val_dataloader, device=device)

        #if epoch % 5 == 0:
        #    save_model_checkpoint(model, optimizer, epoch, avg_loss)

def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        num_correct = 0
        num_total = 0
        for i,batch in enumerate(val_dataloader):
            inputs, targets = batch = batch
            inputs, targets = inputs.to(device), targets.to(device)
            logits, loss = model(inputs, targets)
            y_pred = logits.argmax(dim=-1) # shape (B*L)
            y = targets.view(-1) # shape (B*L)
            mask = (y != -1)
            num_correct += (torch.eq(y[mask], y_pred[mask])).sum().item()
            num_total += len(y[mask])
            val_losses[i] = loss.item()
    model.train()
    val_loss = val_losses.mean().item()
    val_accuracy = num_correct / num_total
    return val_loss, val_accuracy


def save_model_checkpoint(model, optimizer, epoch=None, loss=None):
    # Save the model and optimizer state_dict
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    # Save the checkpoint to a file
    torch.save(checkpoint, 'rnntagger_checkpoint.pth')
    print(f"Saved model checkpoint!")


def load_model_checkpoint(model, optimizer):
    checkpoint = torch.load('rnntagger_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.train()
    print("Loaded model from checkpoint!")
    return model, optimizer      

In [69]:
B = 128
D = 64
num_rnn_layers = 1
learning_rate = 1e-2
DEVICE = "cuda"

model = RNNLM(vocab_size, D, num_rnn_layers=num_rnn_layers).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler =  torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.9)

#model, optimizer = load_model_checkpoint(model, optimizer)
train_dataloader = DataLoader(train_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)

num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in transformer network: 0.835445 M
RAM used: 1214.17 MB


In [70]:
train(model, optimizer, scheduler, train_dataloader, val_dataloader, device=DEVICE, num_epochs=10)

Epoch 1, EMA Train Loss: 5.116, Train Accuracy:  0.000, Val Loss:  0.000, Val Accuracy:  0.000: 100%|██████████| 282/282 [00:03<00:00, 80.52it/s]
Epoch 2, EMA Train Loss: 4.803, Train Accuracy:  0.184, Val Loss:  5.027, Val Accuracy:  0.220:  72%|███████▏  | 204/282 [00:02<00:00, 94.11it/s]

In [54]:
model.generate(vocab, word2idx, device=DEVICE)

  p = F.softmax(logits[0,-1,:]) # shape: (V,)


'o warwick and my honour like to chance'