#### RNN Language Model

Given a sequence of word embedding vectors $w_1, \dots, w_N$, we pass the sequence through a `uni-directional RNN` to obtain a sequence of hidden states $h_1, \dots, h_N$. Each hidden state $h_i$ can be regarded as a contextual representation of the words $w_1, \dots, w_i$, so we can use it to compute a probability distribution over the next word given all the preceding words:

$P(w_{i+1} \mid w_1, \dots, w_i) = f(h_i)$

where $f$ is a function that maps $h_i$ to a probability distribution over the vocabulary. $f$ can be a feedforward network; in the simplest case it is a linear projection followed by a softmax. Note also that we use a uni-directional RNN (and not a bi-directional one) because a language model must predict the next word using only the preceding words as context. In a bi-directional RNN, $h_i$ would also encode $w_{i+1}$, the very word it is supposed to predict.
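
To make this concrete, here is a framework-free toy sketch: pure Python, random weights, and a plain tanh (Elman) RNN standing in for the LSTM used later. All function names are hypothetical. It computes $h_1, \dots, h_N$ left to right and applies $f$ as a linear projection followed by a softmax. Because each $h_i$ is built only from $w_1, \dots, w_i$, replacing a later word leaves the earlier hidden states untouched, which is the property a uni-directional RNN provides.

```python
import math
import random

def matvec(W, v):
    """Multiply a matrix W (list of rows) by a vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def softmax(z):
    """Numerically stable softmax over a list of logits."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def rnn_hidden_states(embeddings, W_x, W_h, b):
    """Elman RNN: h_i = tanh(W_x w_i + W_h h_{i-1} + b), starting from h_0 = 0."""
    h = [0.0] * len(b)
    states = []
    for w in embeddings:
        pre = [u + r + c for u, r, c in zip(matvec(W_x, w), matvec(W_h, h), b)]
        h = [math.tanh(v) for v in pre]
        states.append(h)
    return states

def next_word_distribution(h, W_out, b_out):
    """f(h_i): linear projection to one logit per vocabulary word, then softmax."""
    logits = [z + c for z, c in zip(matvec(W_out, h), b_out)]
    return softmax(logits)

def random_matrix(rows, cols, rng):
    """Toy weights drawn uniformly from [-0.5, 0.5]."""
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

# toy sizes (assumed): embedding = hidden dimension D = 4, vocabulary size V = 6
rng = random.Random(0)
D, V = 4, 6
W_x, W_h, W_out = random_matrix(D, D, rng), random_matrix(D, D, rng), random_matrix(V, D, rng)
b, b_out = [0.0] * D, [0.0] * V
sentence = random_matrix(5, D, rng)  # five toy word embeddings w_1..w_5

states = rnn_hidden_states(sentence, W_x, W_h, b)
p_next = next_word_distribution(states[2], W_out, b_out)  # P(w_4 | w_1, w_2, w_3) from h_3
print(len(p_next), round(sum(p_next), 6))

# causality: replacing the last word w_5 leaves h_1..h_4 unchanged
altered = sentence[:4] + [[1.0] * D]
print(rnn_hidden_states(altered, W_x, W_h, b)[:4] == states[:4])
```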

Optional: If we choose the embedding dimension and the RNN hidden state dimension to be the same, we can re-use the embedding matrix as the weight matrix of the output projection that maps hidden states to logits, instead of learning a separate projection matrix. This technique is called `weight tying`. It saves $|V| \times D$ parameters (vocabulary size times embedding dimension), which is a large share of the model when the vocabulary is big, and it often improves performance, potentially because it reduces overfitting.
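
A back-of-the-envelope parameter count shows how much weight tying saves. This sketch assumes PyTorch's LSTM layout (an input-to-hidden matrix, a hidden-to-hidden matrix and two bias vectors per layer) and the configuration used later in this notebook (D = 256, 5 LSTM layers, output bias kept). The vocabulary size V = 12,341 is not printed anywhere; it is the value implied by the reported total of 5.803317 M parameters under these assumptions. The helper names are hypothetical.

```python
def lstm_params(input_size, hidden_size, num_layers):
    """Parameter count of a stacked, uni-directional torch.nn.LSTM with bias=True."""
    total = 0
    for layer in range(num_layers):
        in_dim = input_size if layer == 0 else hidden_size
        # weight_ih (4H x in_dim) + weight_hh (4H x H) + bias_ih (4H) + bias_hh (4H)
        total += 4 * hidden_size * (in_dim + hidden_size) + 2 * 4 * hidden_size
    return total

def rnnlm_params(vocab_size, dim, num_layers, tied=True):
    """Embedding + LSTM + output layer; a tied output weight is the embedding matrix itself."""
    embedding = vocab_size * dim
    output_weight = 0 if tied else vocab_size * dim
    output_bias = vocab_size  # the output bias is never shared
    return embedding + lstm_params(dim, dim, num_layers) + output_weight + output_bias

V, D, L = 12341, 256, 5  # V is inferred, not reported (see above)
tied = rnnlm_params(V, D, L, tied=True)
untied = rnnlm_params(V, D, L, tied=False)
print(tied, untied, untied - tied)
```

PyTorch's `model.parameters()` yields a shared tensor only once, so the tied model's count in the notebook does not double-count the embedding matrix.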

Previously we looked at simple n-gram language models, which are only feasible for small $n$ (i.e. short contexts), because the number of possible n-grams grows exponentially with $n$ and most of them are never observed in the training data. An RNN can, in principle, condition on the entire preceding context, which typically yields better performance (e.g. lower perplexity than n-gram LMs).
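
Perplexity is the exponential of the average negative log-likelihood of the observed next words. For a text of $N$ words with $N-1$ next-word predictions:

$$\mathrm{PPL} = \exp\left(-\frac{1}{N-1}\sum_{i=1}^{N-1} \log P(w_{i+1} \mid w_1, \dots, w_i)\right)$$

A minimal sketch with illustrative probabilities (made-up values, not outputs of the model trained here):

```python
import math

def perplexity(next_word_probs):
    """exp of the mean negative log-likelihood assigned to each observed next word."""
    nll = [-math.log(p) for p in next_word_probs]
    return math.exp(sum(nll) / len(nll))

# illustrative probabilities, not outputs of the model in this notebook
uniform_100 = [0.01] * 50  # always 1/100 on the correct next word
mixed = [0.5, 0.125]       # geometric mean of the probabilities is 1/4
print(perplexity(uniform_100))  # about 100: as uncertain as a uniform choice among 100 words
print(perplexity(mixed))        # about 4
```

Since `F.cross_entropy` returns the mean negative log-likelihood over a batch by default, `math.exp(val_loss)` in the `validation` function below computes this quantity, up to the averaging over batches.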

We will train a word-level RNN LM on the collected works of Shakespeare.

In [1]:
import random, math
from nltk.tokenize import word_tokenize
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import psutil
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtanzids[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# prep the training data
with open('shakespeare.txt', 'r') as file:
    lines = file.readlines()

In [3]:
# tokenize each line (treated as one sentence) with nltk's word_tokenize, which splits off punctuation and clitics such as 's and 'd as separate tokens; lowercase and add start/end-of-sentence tokens
start_token = '<s>'        
end_token = '</s>' 
sentences_tokenized = [[start_token]+word_tokenize(s.lower())+[end_token] for s in lines]
print(f"Num sentences: {len(sentences_tokenized)}")    

# split the data: the last 10% of the sentences are held out for validation
num_sent = len(sentences_tokenized)
num_test = int(0.1 * num_sent)

sentences_train = sentences_tokenized[:-num_test]
sentences_val = sentences_tokenized[-num_test:]

print(f"Number of training sentences: {len(sentences_train)}")        
print(f"Number of test sentences: {len(sentences_val)}")   

Num sentences: 40000
Number of training sentences: 36000
Number of test sentences: 4000


In [4]:
# concatenate all sentences
corpus_train = [w for s in sentences_train for w in s] 
corpus_val = [w for s in sentences_val for w in s] 

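# create the vocabulary from all sentences (train + validation) so that no validation word is out of vocabulary; index 0 is reserved for padding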
pad_token = "<PAD>"
vocab = [pad_token] + sorted(list(set([w for s in sentences_tokenized for w in s])))
word2idx = {w:i for i,w in enumerate(vocab)}
vocab_size = len(vocab)

# map every word of the train and validation corpora to its integer id
x_train = [word2idx[word] for word in corpus_train]
x_val = [word2idx[word] for word in corpus_val]

#### Create a PyTorch dataset

In [5]:
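class Shakespeare(Dataset):
    def __init__(self, corpus, block_size=16):
        # corpus: flat list of integer token ids
        self.corpus = corpus
        self.block_size = block_size

    def __len__(self):
        # number of start positions for which a full input window and its shifted target window fit in the corpus
        return len(self.corpus) - self.block_size

    def __getitem__(self, idx):
        # inputs are block_size consecutive tokens; targets are the same window shifted one token to the right
        inputs = torch.tensor(self.corpus[idx:idx+self.block_size], dtype=torch.long)
        targets = torch.tensor(self.corpus[idx+1:idx+1+self.block_size], dtype=torch.long)
        return inputs, targets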
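
Each training example is a window of `block_size` token ids paired with the same window shifted one position to the right, so position $t$ of `targets` holds the next word after position $t$ of `inputs`. A tensor-free sketch of that slicing on a toy list of ids (the helper `make_example` is hypothetical):

```python
def make_example(corpus, idx, block_size):
    """Tensor-free version of the window slicing used in Shakespeare.__getitem__."""
    inputs = corpus[idx:idx + block_size]
    targets = corpus[idx + 1:idx + 1 + block_size]
    return inputs, targets

corpus = [10, 11, 12, 13, 14, 15]  # toy token ids
block_size = 3
# a full target window needs idx + block_size <= len(corpus) - 1
num_windows = len(corpus) - block_size
examples = [make_example(corpus, i, block_size) for i in range(num_windows)]
for inputs, targets in examples:
    print(inputs, "->", targets)
```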

#### Create the RNN LM

In [13]:
class RNNLM(torch.nn.Module):
    def __init__(self, vocab_size, embedding_dim=32, num_rnn_layers=1, dropout_rate=0.1):
        super().__init__()

        # embedding layer
        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)
        c = 0.1        
        torch.nn.init.uniform_(self.emb.weight, -c, c)

        # create the uni-directional LSTM layers; the hidden size equals the embedding size so that the output layer can share the embedding weights
        # (dropout between stacked LSTM layers only has an effect when num_rnn_layers > 1)
        self.rnn_layers = torch.nn.LSTM(input_size=embedding_dim, hidden_size=embedding_dim, num_layers=num_rnn_layers, batch_first=True, dropout=dropout_rate if num_rnn_layers > 1 else 0.0)
        self.dropout = torch.nn.Dropout(dropout_rate)
        # create output layer (computes output class logits for each item in sequence)
        self.output_layer = torch.nn.Linear(embedding_dim, vocab_size)
        # tie the output layer weights with the embedding layer weights
        self.output_layer.weight = self.emb.weight

    # forward pass
    def forward(self, x, y=None):
        # get embeddings for batch of input sequences of length L
        x = self.emb(x) # shape: (B,L,D)
        # apply dropout
        x = self.dropout(x)
        # compute rnn hidden states
        x, _ = self.rnn_layers(x) # shape: (B,L,D)
        # apply dropout
        x = self.dropout(x)
        # compute output logits
        x = self.output_layer(x) # shape: (B,L,vocab_size)

        if y is None:
            return x

        # reshape
        x = x.view(-1,x.shape[-1]) # shape: (B*L,vocab_size)
        y = y.view(-1) # shape: (B*L,)
        # compute cross entropy loss
        loss = F.cross_entropy(x, y)
        return x, loss
    
    @torch.no_grad()
    def generate(self, vocab, word2idx, temperature=1.0, topk=None, start_token="<s>", end_token="</s>", max_len=30, device="cpu"):
        self.eval()
        # generate one word at a time
        x = torch.full(size=(1,1), fill_value=word2idx[start_token], dtype=torch.long, device=device)
        for _ in range(max_len):
            logits = self.forward(x) # shape: (1,L,V)
            # rescale the logits with the temperature
            logits = logits / temperature
            if topk is not None:
                # keep only the k largest logits for the last word in the sequence and sample among them
                topk_logits, idx = torch.topk(logits[0,-1,:], k=topk)
                p = F.softmax(topk_logits, dim=-1) # shape: (k,)
                next_word_idx = idx[torch.multinomial(p, num_samples=1)]
            else:
                # sample from the full distribution for the last word in the sequence
                p = F.softmax(logits[:,-1,:], dim=-1) # shape: (1,V)
                next_word_idx = torch.multinomial(p, num_samples=1)
            # append to the sequence
            x = torch.cat((x, next_word_idx.view(1,1)), dim=1)
        # convert integer tokens to words
        words = x.view(-1).tolist()
        words = [vocab[w] for w in words[1:]]
        # remove <s> tokens and replace </s> tokens with "\n"
        sent = []
        for w in words:
            if w != start_token:
                if w != end_token:
                    sent.append(w)
                else:
                    sent.append("\n")    

        sent = " ".join(sent)
        
        self.train()

        return sent

# training loop
def train(model, optimizer, scheduler, train_dataloader, val_dataloader, device="cpu", num_epochs=10, val_every=1, save_every=10, log_metrics=None):
    avg_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    pp = 0
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for batch in pbar:
            inputs, targets = batch
            # move batch to device
            inputs, targets = inputs.to(device), targets.to(device)
            # reset gradients
            optimizer.zero_grad()
            # forward pass
            logits, loss = model(inputs, targets)
            # backward pass
            loss.backward()
            # optimizer step
            optimizer.step()
            avg_loss = 0.9* avg_loss + 0.1*loss.item()
            y_pred = logits.argmax(dim=-1) # shape (B*L)
            y = targets.view(-1) # shape (B*L)
            mask = (y != -1)
            num_correct += (torch.eq(y[mask], y_pred[mask])).sum().item()
            num_total += len(y[mask])
            
            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}, Val Perplexity: {pp:.1f}")  

            if log_metrics:
                metrics = {"Batch loss" : loss.item(), "Moving Avg Loss" : avg_loss, "Val Loss": val_loss}
                log_metrics(metrics)

        scheduler.step()
        train_acc = num_correct / num_total        
        if epoch%val_every == 0:
            # compute validation loss
            val_loss, val_acc, pp = validation(model, val_dataloader, device=device)

        if (epoch+1) % save_every == 0:
            save_model_checkpoint(model, optimizer, epoch, avg_loss)

def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        num_correct = 0
        num_total = 0
        for i,batch in enumerate(val_dataloader):
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            logits, loss = model(inputs, targets)
            y_pred = logits.argmax(dim=-1) # shape (B*L)
            y = targets.view(-1) # shape (B*L)
            mask = (y != -1)
            num_correct += (torch.eq(y[mask], y_pred[mask])).sum().item()
            num_total += len(y[mask])
            val_losses[i] = loss.item()
    model.train()
    val_loss = val_losses.mean().item()
    val_accuracy = num_correct / num_total
    perplexity = math.exp(val_loss)
    return val_loss, val_accuracy, perplexity


def save_model_checkpoint(model, optimizer, epoch=None, loss=None):
    # Save the model and optimizer state_dict
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    # Save the checkpoint to a file
    torch.save(checkpoint, 'rnntagger_checkpoint.pth')
    print("Saved model checkpoint!")


def load_model_checkpoint(model, optimizer):
    checkpoint = torch.load('rnntagger_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.train()
    print("Loaded model from checkpoint!")
    return model, optimizer      
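
`generate` rescales the logits by `temperature` before the softmax (values below 1 sharpen the distribution, values above 1 flatten it); top-k sampling additionally restricts the draw to the k highest-scoring words. The same logic without PyTorch, as a sketch (`sample_next` is a hypothetical helper, not part of the notebook):

```python
import math
import random

def sample_next(logits, temperature=1.0, topk=None, rng=random):
    """Sample an index from logits after temperature scaling and optional top-k filtering."""
    scaled = [z / temperature for z in logits]
    # candidate indices ordered from most to least likely
    candidates = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if topk is not None:
        candidates = candidates[:topk]  # keep only the k most likely words
    m = max(scaled[i] for i in candidates)
    weights = [math.exp(scaled[i] - m) for i in candidates]  # unnormalized softmax
    return rng.choices(candidates, weights=weights, k=1)[0]

logits = [2.0, 1.0, 0.5, -1.0, -3.0]
rng = random.Random(0)
draws = [sample_next(logits, temperature=0.9, topk=2, rng=rng) for _ in range(1000)]
print(sorted(set(draws)))  # with topk=2 only indices 0 and 1 can ever be drawn
```

Sorting and slicing plays the role of `torch.topk` here, and `random.choices` normalizes the weights itself, so the softmax denominator can be skipped.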

In [7]:
print(len(x_train))
print(len(x_val))

303695
31100


In [14]:
B = 32
D = 256
num_rnn_layers = 5
learning_rate = 1e-3
DEVICE = "cuda"

model = RNNLM(vocab_size, D, num_rnn_layers=num_rnn_layers, dropout_rate=0.5).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.95)
# resume from a saved checkpoint (comment this out when training from scratch)
model, optimizer = load_model_checkpoint(model, optimizer)


num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in RNN LM: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Loaded model from checkpoint!
Total number of parameters in RNN LM: 5.803317 M
RAM used: 1066.91 MB


In [29]:
# full datasets with longer context windows (not used in this run)
#train_dataset = Shakespeare(x_train, block_size=64)
#val_dataset = Shakespeare(x_val, block_size=64)

# this run trains on the training tokens from index 3*65536 onward and validates on the first 16384 validation tokens
train_dataset = Shakespeare(x_train[3*65536:], block_size=16)
val_dataset = Shakespeare(x_val[:16384], block_size=16)

In [30]:
#model, optimizer = load_model_checkpoint(model, optimizer)
train_dataloader = DataLoader(train_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)



In [None]:
"""
# create a W&B run
run = wandb.init(
    project="RNN_shakespeare", 
    config={
        "learning_rate": learning_rate, 
        "epochs": 100,
        "batch_size": B, 
        "emb_dim": D,
        "num_rnn_layers" : num_rnn_layers,
        "corpus": "Shakespeare"},)   

def log_metrics(metrics):
    wandb.log(metrics)
"""

In [None]:
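# optionally override the learning rate of all parameter groups (e.g. to continue training a resumed model at a lower rate)
for g in optimizer.param_groups:
    g['lr'] = 0.0005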

In [31]:
train(model, optimizer, scheduler, train_dataloader, val_dataloader, device=DEVICE, num_epochs=20, save_every=10, val_every=5) #, log_metrics=log_metrics)

Epoch 1, EMA Train Loss: 3.131, Train Accuracy:  0.000, Val Loss:  0.000, Val Accuracy:  0.000, Val Perplexity: 0.0: 100%|██████████| 3346/3346 [00:34<00:00, 96.20it/s] 
Epoch 2, EMA Train Loss: 3.007, Train Accuracy:  0.421, Val Loss:  4.405, Val Accuracy:  0.364, Val Perplexity: 81.8: 100%|██████████| 3346/3346 [00:34<00:00, 97.27it/s] 
Epoch 3, EMA Train Loss: 2.876, Train Accuracy:  0.431, Val Loss:  4.405, Val Accuracy:  0.364, Val Perplexity: 81.8: 100%|██████████| 3346/3346 [00:34<00:00, 97.50it/s] 
Epoch 4, EMA Train Loss: 2.762, Train Accuracy:  0.436, Val Loss:  4.405, Val Accuracy:  0.364, Val Perplexity: 81.8: 100%|██████████| 3346/3346 [00:34<00:00, 97.54it/s] 
Epoch 5, EMA Train Loss: 2.706, Train Accuracy:  0.442, Val Loss:  4.405, Val Accuracy:  0.364, Val Perplexity: 81.8: 100%|██████████| 3346/3346 [00:35<00:00, 93.36it/s]
Epoch 6, EMA Train Loss: 2.588, Train Accuracy:  0.447, Val Loss:  4.405, Val Accuracy:  0.364, Val Perplexity: 81.8: 100%|██████████| 3346/3346 [0

Saved model checkpoint!


Epoch 11, EMA Train Loss: 2.469, Train Accuracy:  0.468, Val Loss:  4.611, Val Accuracy:  0.359, Val Perplexity: 100.6: 100%|██████████| 3346/3346 [00:36<00:00, 90.63it/s]
Epoch 12, EMA Train Loss: 2.431, Train Accuracy:  0.470, Val Loss:  4.812, Val Accuracy:  0.354, Val Perplexity: 123.0: 100%|██████████| 3346/3346 [00:36<00:00, 91.83it/s]
Epoch 13, EMA Train Loss: 2.442, Train Accuracy:  0.473, Val Loss:  4.812, Val Accuracy:  0.354, Val Perplexity: 123.0: 100%|██████████| 3346/3346 [00:36<00:00, 91.52it/s]
Epoch 14, EMA Train Loss: 2.512, Train Accuracy:  0.475, Val Loss:  4.812, Val Accuracy:  0.354, Val Perplexity: 123.0: 100%|██████████| 3346/3346 [00:37<00:00, 89.21it/s]
Epoch 15, EMA Train Loss: 2.445, Train Accuracy:  0.476, Val Loss:  4.812, Val Accuracy:  0.354, Val Perplexity: 123.0: 100%|██████████| 3346/3346 [00:37<00:00, 88.41it/s]
Epoch 16, EMA Train Loss: 2.458, Train Accuracy:  0.479, Val Loss:  4.812, Val Accuracy:  0.354, Val Perplexity: 123.0: 100%|██████████| 334

Saved model checkpoint!


In [None]:
# Mark the run as finished
#wandb.finish()




W&B run history (sparklines, not reproduced): Batch loss, Moving Avg Loss, Val Loss

W&B run summary:

| Metric | Value |
|---|---|
| Batch loss | 3.32277 |
| Moving Avg Loss | 3.38491 |
| Val Loss | 4.54387 |


In [35]:
s = model.generate(vocab, word2idx, temperature=0.9, topk=None, device=DEVICE, max_len=500)
print(s)

there is some sap to the deputy . 
 
 escalus : 
 sir , my lord , 
 if your law be so murdering days his land , 
 that i have seen him last , but , as she is , 
 am there as monstrous as you take off our right 
 is pawn 'd his swelling eye , 
 the first mould of life , you have not answer 'd 
 for his own old lord and well-warranted man , 
 drest as good as damask roses ; 
 masks too weak to fear her brother 's ghost , 
 that is not the sweet apollo 's son . 
 
 florizel : 
 come , sir ; here 's no man the rest , 
 that thou neglect him not , and i do thou hast come by 
 him ; and , or repent to your father , nor as 
 i am no children , i am glad i am a courtier : i 
 have heard her in their silent affairs 
 and with the witness of that hath access too much 
 all foolery like language . but his head to the 
 speech of this reason : if i should be 
 false , that for the other earth is so green , 
 perfume in the end . 
 
 paulina : 
 ha ! undone ! 
 say a man witch : i am not well pence