#### RNN-based Sequence Labeller

We've seen how to implement an HMM bigram POS tagger and saw that it can achieve over 90% test accuracy on the Stanford treebank. The main limitation of this model is the limited context size (bigram context) which is used for making the sequence label predictions. Naively extending the algorithm to handle larger n-gram contexts may result in exponential increase in memory consumption (for vocab size $|V|$, the number of possible n-grams is on the order of $|V|^n$). A more natural way to handle larger contexts is using a `recurrent neural network (RNN)` to perform the sequence labelling task, where we use a sequence of pre-trained word embeddings as input and predict POS tags corrresponding to each word. We will experiment with some simple RNN architectures for this task and compare performance with the unigram and Viterbi bigram taggers.

In [16]:
import gensim.downloader as api
import nltk
from nltk.corpus import treebank
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import psutil


In [2]:
# first download pretrained GloVe embeddings (api.load() returns a 'KeyedVectors' object)
# we're getting the 50 dimensional words vectors, 400k vocab size
embeddings = api.load("glove-wiki-gigaword-50")

In [3]:
# get the POS tagged corpus, 3914 tagged sentences
corpus = treebank.tagged_sents()
print("Number of sentences: ", len(corpus))
print(f"Longest sentence length: {max([len(s) for s in corpus])}")


# lets get the vocabulary and tag set
pad_token = "<PAD>"
vocab = [pad_token] + sorted(list(set([elem[0] for s in corpus for elem in s])))
vocab_size = len(vocab)
start_tag = "<s>"
tags = [start_tag] + sorted(list(set([elem[1] for s in corpus for elem in s])))
num_tags = len(tags)

word2idx = {w:i for i,w in enumerate(vocab)}
tag2idx = {t:i for i,t in enumerate(tags)}

print(f"Vocab size: {vocab_size}")
print(tags)

Number of sentences:  3914
Longest sentence length: 271
Vocab size: 12409
['<s>', '#', '$', "''", ',', '-LRB-', '-NONE-', '-RRB-', '.', ':', 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB', '``']


In [4]:
# now lets get the embeddings for all words in our vocab
embedding_dim = 50
embedding_vectors = np.zeros(shape=(vocab_size,embedding_dim))

oov_words = []
for i, word in enumerate(vocab):
    if word.lower() in embeddings.key_to_index:
        embedding_vectors[i] = embeddings[word.lower()]
    else:
        #print(f"'{word.lower()}' not in GloVe vocab!")    
        # if the word is hyphenated, then split it and see if the sub-words have embeddings
        if "-" in word:
            split_hyphen = word.lower().split("-")
            # compute average word embedding across split words
            emb = np.zeros(shape=(embedding_dim))
            found = False
            for i, w in enumerate(split_hyphen):
                if w in embeddings.key_to_index:
                    emb += embeddings[w]
                    found = True
            if found:        
                emb += (len(split_hyphen)-1) * embeddings["-"]
                emb = emb / (2*len(split_hyphen)-1)
                embedding_vectors[i] = emb   
            else:
                oov_words.append(word)     

print(f"Number of words for which embeddings not available: {len(oov_words)}")        

Number of words for which embeddings not available: 0


Tokenize sentences and tags to get the inputs and targets

In [5]:
x = [[word2idx[word] for word,tag in s] for s in corpus]
y = [[tag2idx[tag] for word,tag in s] for s in corpus]

Create train-validation splits

In [6]:
num_train = int(0.9 * len(x))
x_train, y_train = x[:num_train], y[:num_train]
x_val, y_val = x[num_train:], y[num_train:]

print(f"Longest train sentence length: {max([len(s) for s in x_train])}")
print(f"Longest val sentence length: {max([len(s) for s in x_val])}")


Longest train sentence length: 271
Longest val sentence length: 58


Create pytorch dataset

In [7]:
class Treebank(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = pad_sequence([torch.tensor(x, dtype=torch.long) for x in inputs], batch_first=True, padding_value=0)
        self.targets = pad_sequence([torch.tensor(y, dtype=torch.long) for y in targets], batch_first=True, padding_value=-1)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

In [8]:
train_dataset = Treebank(x_train, y_train)
val_dataset = Treebank(x_val, y_val)

#### Now let's create our RNN model

In [45]:
class RNNTagger(torch.nn.Module):
    def __init__(self, vocab_size, num_tags, embedding_dims, pretrained_embeddings, num_rnn_layers=1, hidden_dims=64, dropout_rate=0.1, padding_idx=-1):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, embedding_dims)        
        # intialize with pretrained embedding vectors
        self.emb.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
        # freeze the embedding layer (i.e. make non-trainable)
        self.emb.weight.requires_grad = False
        # create rnn layers (we will use bidirectional LSTM so the output hidden states will have dims=2*hidden_dims)
        self.rnn_layers = torch.nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dims, num_layers=num_rnn_layers, bidirectional=True, batch_first=True)
        self.dropout = torch.nn.Dropout(dropout_rate)
        # create output layer (computes output class logits for each item in sequence)
        self.output_layer =  torch.nn.Linear(2*hidden_dims, num_tags)
        self.padding_idx = padding_idx

    # forward pass
    def forward(self, x, y):
        # get embeddings for batch of input sequences of length L
        x = self.emb(x) # shape: (B,L,D)
        # compute rnn hidden states
        x, _ = self.rnn_layers(x) # shape: (B,L,2*H)
        # apply dropout
        x = self.dropout(x)
        # compute output logits
        x = self.output_layer(x) # shape: (B,L,num_tags)
        # reshape
        x = x.view(-1,x.shape[-1]) # shape: (B*L,num_tags)
        y = y.view(-1) # shape: (B*L,)
        # compute cross entropy loss
        loss = F.cross_entropy(x, y, ignore_index=self.padding_idx)

        return x, loss
        

# training loop
def train(model, optimizer, train_dataloader, val_dataloader, device="cpu", num_epochs=10, log_metrics=None):
    avg_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for batch in pbar:
            inputs, targets = batch
            # move batch to device
            inputs, targets = inputs.to(device), targets.to(device)
            # reset gradients
            optimizer.zero_grad()
            # forward pass
            logits, loss = model(inputs, targets)
            # backward pass
            loss.backward()
            # optimizer step
            optimizer.step()
            avg_loss = 0.9* avg_loss + 0.1*loss.item()
            
            y_pred = logits.argmax(dim=-1) # shape (B*L)
            y = targets.view(-1) # shape (B*L)
            mask = (y != -1)
            num_correct += (torch.eq(y[mask], y_pred[mask])).sum().item()
            num_total += len(y[mask])
            
            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}")  

        train_acc = num_correct / num_total        
        # compute validation loss
        val_loss, val_acc = validation(model, val_dataloader, device=device)

        #if epoch % 5 == 0:
        #    save_model_checkpoint(model, optimizer, epoch, avg_loss)

def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        num_correct = 0
        num_total = 0
        for i,batch in enumerate(val_dataloader):
            inputs, targets = batch = batch
            inputs, targets = inputs.to(device), targets.to(device)
            logits, loss = model(inputs, targets)
            y_pred = logits.argmax(dim=-1) # shape (B*L)
            y = targets.view(-1) # shape (B*L)
            mask = (y != -1)
            num_correct += (torch.eq(y[mask], y_pred[mask])).sum().item()
            num_total += len(y[mask])
            val_losses[i] = loss.item()
    model.train()
    val_loss = val_losses.mean().item()
    val_accuracy = num_correct / num_total
    return val_loss, val_accuracy


def save_model_checkpoint(model, optimizer, epoch=None, loss=None):
    # Save the model and optimizer state_dict
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    # Save the checkpoint to a file
    torch.save(checkpoint, 'rnntagger_checkpoint.pth')
    print(f"Saved model checkpoint!")


def load_model_checkpoint(model, optimizer):
    checkpoint = torch.load('rnntagger_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.train()
    print("Loaded model from checkpoint!")
    return model, optimizer

In [46]:
B = 128
H = 64
num_rnn_layers = 1
learning_rate = 1e-4
DEVICE = "cuda"

model = RNNTagger(vocab_size, num_tags, embedding_dim, embedding_vectors, hidden_dims=H, num_rnn_layers=num_rnn_layers).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
#model, optimizer = load_model_checkpoint(model, optimizer)
train_dataloader = DataLoader(train_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)


num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in transformer network: 0.685905 M
RAM used: 1248.61 MB


In [47]:
train(model, optimizer, train_dataloader, val_dataloader, device=DEVICE, num_epochs=200)

Epoch 1, EMA Train Loss: 3.611, Train Accuracy:  0.000, Val Loss:  0.000, Val Accuracy:  0.000: 100%|██████████| 28/28 [00:00<00:00, 56.83it/s]
Epoch 2, EMA Train Loss: 3.724, Train Accuracy:  0.049, Val Loss:  3.790, Val Accuracy:  0.108: 100%|██████████| 28/28 [00:00<00:00, 97.72it/s] 
Epoch 3, EMA Train Loss: 3.625, Train Accuracy:  0.113, Val Loss:  3.699, Val Accuracy:  0.117: 100%|██████████| 28/28 [00:00<00:00, 104.25it/s]
Epoch 4, EMA Train Loss: 3.454, Train Accuracy:  0.119, Val Loss:  3.578, Val Accuracy:  0.117: 100%|██████████| 28/28 [00:00<00:00, 117.22it/s]
Epoch 5, EMA Train Loss: 3.235, Train Accuracy:  0.120, Val Loss:  3.364, Val Accuracy:  0.119: 100%|██████████| 28/28 [00:00<00:00, 115.50it/s]
Epoch 6, EMA Train Loss: 3.094, Train Accuracy:  0.128, Val Loss:  3.154, Val Accuracy:  0.142: 100%|██████████| 28/28 [00:00<00:00, 117.83it/s]
Epoch 7, EMA Train Loss: 3.022, Train Accuracy:  0.153, Val Loss:  2.991, Val Accuracy:  0.199: 100%|██████████| 28/28 [00:00<00:00