# 4. Recurrent Neural Networks

In [89]:
import torch 
import torch.nn as nn
from torch.utils import data
from torch.nn import functional as F

import re
import collections

import math
import random
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

In [35]:
from d2l import torch as d2l

## Hidden States

When using **n-grams** in language modelling, we are modelling:

$$P(x_t \mid x_{t-1}, \ldots, x_{t-n+1})$$

However, the **number of parameters** increases exponentially with $n$: we need to store **$|\mathcal{V}|^n$ numbers** for a vocabulary set $\mathcal{V}$.
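
To get a feel for this growth, the short sketch below tabulates $|\mathcal{V}|^n$ for a small vocabulary. The vocabulary size of 28 and the helper name `ngram_table_size` are assumptions chosen purely for illustration.

```python
# Number of table entries a full n-gram model needs for a vocabulary of size V.
# vocab_size = 28 is an illustrative assumption, not a value taken from the text.
vocab_size = 28


def ngram_table_size(vocab_size: int, n: int) -> int:
    """Entries needed to store P(x_t | x_{t-1}, ..., x_{t-n+1}) as a lookup table."""
    return vocab_size ** n


for n in range(1, 6):
    print(f"n={n}: {ngram_table_size(vocab_size, n):,} entries")
```

Even with such a tiny vocabulary, the table size grows by a factor of 28 with every additional token of context.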

Instead, we can use a **`hidden state`** $h_{t-1}$ which stores the **sequence information** up to time step $t-1$:

$$P(x_t \mid x_{t-1}, \ldots, x_1) \approx P(x_t \mid h_{t-1})$$

Generally, the hidden state at any time step $h_t$ could be computed based on both the **current input $x_t$** and the **previous hidden state $h_{t-1}$**:

$$h_t = f(x_{t}, h_{t-1})$$

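If the function $f$ is **sufficiently powerful**, the above is not an approximation, since $h_t$ could simply store all the data observed so far. In practice, a fixed-size hidden state keeps the **costs of computation and storage** bounded, at the price of compressing the history.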

## Neural Networks without Hidden States

Let's consider an MLP with a single hidden layer.

With the **activation function $\phi$** and **mini-batch $\mathbf{X} \in \mathbb{R}^{n \times d}$**, the **output of the hidden layer** is given as:

$$\mathbf{H} = \phi(\mathbf{X} \mathbf{W}_{xh} + \mathbf{b}_h)$$

where we have the **weights $\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}$** and **bias $\mathbf{b}_h \in \mathbb{R}^{1 \times h}$**, with $h$ the number of hidden units.

Then, the **final output $\mathbf{O} \in \mathbb{R}^{n \times q}$** of the model is given as:

$$\mathbf{O} = \mathbf{H} \mathbf{W}_{hq} + \mathbf{b}_q$$

where we have the **weights $\mathbf{W}_{hq} \in \mathbb{R}^{h \times q}$** and **bias $\mathbf{b}_q \in \mathbb{R}^{1 \times q}$**.
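
The two equations above can be sketched directly in PyTorch. The sizes below and the choice of $\tanh$ for $\phi$ are assumptions made for illustration only.

```python
import torch

torch.manual_seed(0)
n, d, h, q = 4, 3, 5, 2  # batch size, inputs, hidden units, outputs (illustrative)

X = torch.randn(n, d)
W_xh, b_h = torch.randn(d, h), torch.zeros(1, h)
W_hq, b_q = torch.randn(h, q), torch.zeros(1, q)

H = torch.tanh(X @ W_xh + b_h)  # hidden layer output, shape (n, h); phi = tanh is assumed
O = H @ W_hq + b_q              # model output, shape (n, q)
print(H.shape, O.shape)
```

Each row of `X` is processed independently here; nothing is carried over from one example to the next.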

## Neural Networks with Hidden States

Let's consider that at time $t$, we have a **mini-batch input $\mathbf{X}_t \in \mathbb{R}^{n \times d}$**, where each row of $\mathbf{X}_t$ is a sample from the sequence at time $t$.

Now, by storing the **hidden layer output $\mathbf{H}_{t-1}$** from the **previous time step** and introducing a **new weight parameter $\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}$**, we can compute the hidden layer output of the **current time step** as:


$$\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1} \mathbf{W}_{hh}  + \mathbf{b}_h)$$

As the above relationship suggests, the **hidden states** capture and retain the sequence's historical information up to the current time step.

This computation is defined in the same way for every hidden state and is therefore **recurrent**. Hence, neural networks with hidden states are called **recurrent neural networks** (RNNs), and the layers performing these computations are called **recurrent layers**.

## Recurrent Neural Networks

Given the **recurrent computation** of an RNN:

$$\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1} \mathbf{W}_{hh}  + \mathbf{b}_h)$$

The **output** is then:

$$\mathbf{O}_t = \mathbf{H}_t \mathbf{W}_{hq} + \mathbf{b}_q$$

**Parameters** of an RNN include the **weights $\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}, \mathbf{W}_{hh} \in \mathbb{R}^{h \times h}$ and bias $\mathbf{b}_h \in \mathbb{R}^{1 \times h}$** of the hidden layer, and the **weights $\mathbf{W}_{hq} \in \mathbb{R}^{h \times q}$ and bias $\mathbf{b}_q \in \mathbb{R}^{1 \times q}$** of the output layer.

It is worth mentioning that even at **different time steps**, RNNs always use the same model parameters. Therefore, the **parameterization cost** of an RNN does not grow as the number of time steps increases.
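
A minimal sketch of this recurrence, assuming $\phi = \tanh$ and small arbitrary sizes, shows the same five parameter tensors being reused at every step. The helper name `rnn_forward` is hypothetical.

```python
import torch

torch.manual_seed(0)
n, d, h, q, T = 2, 3, 4, 5, 6  # illustrative sizes; T is the number of time steps

W_xh, W_hh, b_h = torch.randn(d, h) * 0.1, torch.randn(h, h) * 0.1, torch.zeros(1, h)
W_hq, b_q = torch.randn(h, q) * 0.1, torch.zeros(1, q)


def rnn_forward(inputs, H):
    """Apply the recurrence to a list of (n, d) inputs, reusing the same weights."""
    outputs = []
    for X_t in inputs:
        H = torch.tanh(X_t @ W_xh + H @ W_hh + b_h)  # hidden state update
        outputs.append(H @ W_hq + b_q)               # output at this time step
    return torch.stack(outputs), H


inputs = [torch.randn(n, d) for _ in range(T)]
O, H_T = rnn_forward(inputs, torch.zeros(n, h))
num_params = sum(p.numel() for p in (W_xh, W_hh, b_h, W_hq, b_q))
print(O.shape, H_T.shape, num_params)
```

Changing `T` changes the shape of `O` but leaves `num_params` untouched.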

The following diagram demonstrates the computational logic of an RNN at **3 adjacent time steps**:

![](http://d2l.ai/_images/rnn.svg)

## Recurrent Computation

The recurrent computation involves $\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1} \mathbf{W}_{hh}$, which is equivalent to multiplying the concatenation of $\mathbf{X}_t$ and $\mathbf{H}_{t-1}$ (along the columns) by the concatenation of $\mathbf{W}_{xh}$ and $\mathbf{W}_{hh}$ (along the rows):

In [12]:
X, W_xh = torch.normal(0, 1, (3, 1)), torch.normal(0, 1, (1, 4))
H, W_hh = torch.normal(0, 1, (3, 4)), torch.normal(0, 1, (4, 4))

In [13]:
torch.matmul(X, W_xh) + torch.matmul(H, W_hh)

tensor([[ 1.0326,  1.2618,  4.0260, -1.4869],
        [ 1.5798,  0.6797, -2.0364,  1.6896],
        [-0.3222,  0.1267,  1.0435,  0.6953]])

In [14]:
torch.matmul(torch.cat((X, H), 1), torch.cat((W_xh, W_hh), 0))

tensor([[ 1.0326,  1.2618,  4.0260, -1.4869],
        [ 1.5798,  0.6797, -2.0364,  1.6896],
        [-0.3222,  0.1267,  1.0435,  0.6953]])

## RNN-based Character-Level Language Models

In language models, the goal is to **predict the next token** based on the current and previous tokens. Therefore, we **shift** the input sequences by one token to obtain the **labels**.

Let's consider applying a **character-level language model** on the sequence **"machine"** with **batch size** of 1:

![](http://d2l.ai/_images/rnn-train.svg)

During the **training** process, we run a **softmax** operation on the **output** from the output layer for each time step, and then use the **cross-entropy loss** to compute the error between the model output and the target. 

In practice, each token is represented by a **$d$-dimensional vector**, and we use a **batch size** $n>1$. Therefore, the input $\mathbf X_t$ at time step $t$ will be an **$n\times d$ matrix**.
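
As a small illustration of the shifting and of the vector representation, the sketch below builds inputs and labels for "machine" with a batch size of 1. The tiny vocabulary built from the word itself and the use of one-hot vectors (so that $d$ equals the vocabulary size) are assumptions for this example.

```python
import torch
from torch.nn import functional as F

text = "machine"
chars = sorted(set(text))                     # hypothetical tiny vocabulary
char_to_idx = {c: i for i, c in enumerate(chars)}

indices = torch.tensor([char_to_idx[c] for c in text])
inputs, labels = indices[:-1], indices[1:]    # labels are the inputs shifted by one

# One-hot encode the inputs: one row per time step, d = len(chars)
X = F.one_hot(inputs, num_classes=len(chars)).float()

for x, y in zip(text[:-1], text[1:]):
    print(f"{x} -> {y}")
print(X.shape)
```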

## Perplexity

Now, let's talk about how to **evaluate** a language model.

To evaluate the **predictive ability** of a language model in predicting the **next token** based on the current and previous tokens, we can use the **average cross-entropy** over all the $n$ tokens in a sequence:

$$\frac{1}{n} \sum_{t=1}^n -\log P(x_t \mid x_{t-1}, \ldots, x_1)$$

where the **distribution $P$** is given by the model and $x_t$ is the **observed token** at time $t$.

In NLP, instead of cross-entropy, we usually use an evaluation metric called **perplexity**:

$$\exp\left(-\frac{1}{n} \sum_{t=1}^n \log P(x_t \mid x_{t-1}, \ldots, x_1)\right)$$

Perplexity can be best understood as the **geometric mean** of the number of **real choices** that we have when deciding which token to pick next. 

Let’s look at a number of cases:

1. In the **best** case scenario, the model always perfectly estimates the probability of the target token as 1. In this case the perplexity of the model is **1**.
2. In the **worst** case scenario, the model always predicts the probability of the target token as 0. In this situation, the perplexity is **positive infinity**.
3. At the **baseline**, the model predicts a uniform distribution over all the available tokens of the vocabulary. In this case, the perplexity equals the **number of unique tokens** of the vocabulary. In fact, this provides a nontrivial **upper bound** that any useful model must beat.
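
The cases above can be checked numerically. The sketch below is a plain-Python rendering of the perplexity formula. The function name `perplexity` and the vocabulary size of 28 are assumptions for illustration. A probability of exactly 0 is left out because `math.log(0)` raises an error; the perplexity simply grows without bound as the probability approaches 0.

```python
import math


def perplexity(target_probs):
    """exp of the average negative log-probability assigned to the observed tokens."""
    n = len(target_probs)
    return math.exp(-sum(math.log(p) for p in target_probs) / n)


vocab_size = 28  # illustrative assumption

print(perplexity([1.0] * 10))             # perfect model
print(perplexity([1 / vocab_size] * 10))  # uniform baseline
print(perplexity([0.5, 0.25, 0.125]))     # mixed case: geometric mean of 2, 4 and 8
```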

## Implementing RNNs

We first load the dataset:

In [68]:
d2l.DATA_HUB['time_machine'] = (d2l.DATA_URL + 'timemachine.txt', '090b5e7e70c295757f55df93cb0a180b9691891a')

def read_time_machine(): 
    with open(d2l.download('time_machine'), 'r') as f:
        lines = f.readlines()
    return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]

def tokenize(lines, token='word'):
    # Split each line into word tokens or character tokens
    if token == 'word':
        return [line.split() for line in lines]
    elif token == 'char':
        return [list(line) for line in lines]
    else:
        raise ValueError(f'Unknown token type: {token}')
        
def count_corpus(tokens): 
    if len(tokens) == 0 or isinstance(tokens[0], list):
        tokens = [token for line in tokens for token in line]
    return collections.Counter(tokens)

class Vocab:  
    
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        
        if tokens is None:
            tokens = []
        if reserved_tokens is None:
            reserved_tokens = []
            
        counter = count_corpus(tokens)
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        
        self.idx_to_token = ['<unk>'] + reserved_tokens
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}
        
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    @property
    def unk(self): 
        return 0

    @property
    def token_freqs(self):
        return self._token_freqs

def load_corpus_time_machine(max_tokens=-1):  
    lines = read_time_machine()
    tokens = tokenize(lines, 'word')  # word-level tokens; pass 'char' for a character-level model
    vocab = Vocab(tokens)
    corpus = [vocab[token] for line in tokens for token in line]
    if max_tokens > 0:
        corpus = corpus[:max_tokens]
    return corpus, vocab

def seq_data_iter_sequential(corpus, batch_size, num_steps):
    offset = random.randint(0, num_steps)
    num_tokens = ((len(corpus)-offset-1) // batch_size) * batch_size
    Xs = torch.tensor(corpus[offset: offset+num_tokens])
    Ys = torch.tensor(corpus[offset+1: offset+1+num_tokens])
    Xs, Ys = Xs.reshape(batch_size, -1), Ys.reshape(batch_size, -1)
    num_batches = Xs.shape[1] // num_steps
    for i in range(0, num_steps * num_batches, num_steps):
        X = Xs[:, i:i+num_steps]
        Y = Ys[:, i:i+num_steps]
        yield X, Y

class SeqDataLoader:  
    
    def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):
        if use_random_iter:
            self.data_iter_fn = seq_data_iter_random
        else:
            self.data_iter_fn = seq_data_iter_sequential
        self.corpus, self.vocab = load_corpus_time_machine(max_tokens)
        self.batch_size, self.num_steps = batch_size, num_steps

    def __iter__(self):
        return self.data_iter_fn(self.corpus, self.batch_size, self.num_steps)
    
def load_data_time_machine(batch_size, num_steps, use_random_iter=False, max_tokens=10000):
    data_iter = SeqDataLoader(batch_size, num_steps, use_random_iter, max_tokens)
    return data_iter, data_iter.vocab
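
`SeqDataLoader` above refers to `seq_data_iter_random`, which is not defined in this section, so `use_random_iter=True` would fail as written. A minimal sketch of such an iterator, modeled on the random-sampling approach from the d2l book and assuming the same `(corpus, batch_size, num_steps)` signature as `seq_data_iter_sequential`, could look like this:

```python
import random

import torch


def seq_data_iter_random(corpus, batch_size, num_steps):
    """Yield minibatches of subsequences drawn from random positions."""
    # Start from a random offset so that different epochs see different splits
    corpus = corpus[random.randint(0, num_steps - 1):]
    # Subtract 1 because every input needs a label one position later
    num_subseqs = (len(corpus) - 1) // num_steps
    initial_indices = list(range(0, num_subseqs * num_steps, num_steps))
    random.shuffle(initial_indices)

    def data(pos):
        return corpus[pos: pos + num_steps]

    num_batches = num_subseqs // batch_size
    for i in range(0, batch_size * num_batches, batch_size):
        batch_indices = initial_indices[i: i + batch_size]
        X = [data(j) for j in batch_indices]
        Y = [data(j + 1) for j in batch_indices]  # labels: inputs shifted by one
        yield torch.tensor(X), torch.tensor(Y)


# Toy corpus of consecutive integers, so every label equals its input plus one
for X, Y in seq_data_iter_random(list(range(35)), batch_size=2, num_steps=5):
    print(X, Y, sep='\n')
```

Unlike the sequential iterator, adjacent minibatches produced this way are not adjacent in the corpus, which is why `train_epoch` re-initializes the hidden state on every iteration when `use_random_iter` is set.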

In [69]:
batch_size, num_steps = 32, 35
train_iter, vocab = load_data_time_machine(batch_size, num_steps)

Now, we define the **model**:

In [70]:
num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

Then, we initialize the **hidden states**:

In [71]:
state = torch.zeros((1, batch_size, num_hiddens))
state.shape

torch.Size([1, 32, 256])

Note that the **rnn_layer** does not compute the output layer; it only returns the hidden states.

(**Y** holds the **hidden states** at all 35 time steps, and **state_new** is the hidden state at the last time step.)

In [72]:
X = torch.rand(size=(num_steps, batch_size, len(vocab)))
X.shape

torch.Size([35, 32, 4580])

In [73]:
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape

(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))
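
To make the relationship between the two return values concrete, the sketch below uses small, arbitrary sizes to check that the last slice of `Y` equals the returned state, and that both match a manual unrolling of the recurrence with the layer's own parameters. It assumes a single-layer, unidirectional `nn.RNN` with its default $\tanh$ nonlinearity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_steps, batch_size, input_size, num_hiddens = 5, 2, 8, 16  # illustrative sizes

rnn = nn.RNN(input_size, num_hiddens)
X = torch.rand(num_steps, batch_size, input_size)
state = torch.zeros(1, batch_size, num_hiddens)

Y, state_new = rnn(X, state)
# For a single-layer, unidirectional RNN the last slice of Y is the final state
print(torch.allclose(Y[-1], state_new[0]))

# Manual recurrence; nn.RNN stores its weights transposed relative to W_xh and W_hh
H = state[0]
for X_t in X:
    H = torch.tanh(X_t @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
                   + H @ rnn.weight_hh_l0.T + rnn.bias_hh_l0)
print(torch.allclose(H, state_new[0], atol=1e-5))
```

Note that `nn.RNN` keeps two bias vectors (`bias_ih_l0` and `bias_hh_l0`), whose sum plays the role of $\mathbf{b}_h$ in the equations above.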

We need to define a class for the **complete RNN model**:

In [74]:
class RNNModel(nn.Module):

    def __init__(self, rnn_layer, vocab_size, **kwargs):
        
        super(RNNModel, self).__init__(**kwargs)
        
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size

        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens*2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            return  torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens),
                                device=device)
        else:
            return (torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens), device=device),
                    torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens), device=device))

In [75]:
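# Use Apple's Metal backend when available, otherwise fall back to the CPU
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')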
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)

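Let's have a look at what the model gives us before training. Since the vocabulary built above is word-level while `predict` walks through the prefix one character at a time, any character that is not itself a vocabulary entry shows up as `<unk>`: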

In [108]:
def predict(prefix, num_preds, net, vocab, device): 
    state = net.begin_state(batch_size=1, device=device)
    outputs = [vocab[prefix[0]]]
    get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))
    for y in prefix[1:]: 
        _, state = net(get_input(), state)
        outputs.append(vocab[y])
    for _ in range(num_preds):
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])

In [109]:
predict('time traveller', 10, net, vocab, device)

'tim<unk><unk>t<unk>av<unk><unk><unk><unk><unk>aaiandaaandiaa'

Now, we **train** the model:

In [110]:
def grad_clipping(net, theta): 
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm
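
PyTorch also ships a built-in utility, `torch.nn.utils.clip_grad_norm_`, that performs the same kind of global-norm rescaling as `grad_clipping`. The sketch below applies it to a toy linear layer. The layer, the inflated loss, and the helper `grad_norm` are assumptions chosen only to produce large gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(4, 3)
loss = (net(torch.randn(8, 4)) * 100).pow(2).sum()  # inflated loss, hence large gradients
loss.backward()


def grad_norm(module):
    """Global L2 norm over all parameter gradients of the module."""
    return torch.sqrt(sum(torch.sum(p.grad ** 2) for p in module.parameters())).item()


before = grad_norm(net)
# Rescales the gradients in place so that their global L2 norm is at most max_norm
nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
after = grad_norm(net)
print(f"norm before: {before:.2f}, after: {after:.4f}")
```

Clipping changes only the length of the gradient vector, not its direction, so the update still points the same way but with a bounded step.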

In [111]:
def train_epoch(net, train_iter, loss, updater, device, use_random_iter):
    state, timer = None, d2l.Timer()
    metric = d2l.Accumulator(2)  
    for X, Y in train_iter:
        if state is None or use_random_iter:
            state = net.begin_state(batch_size=X.shape[0], device=device)
        else:
            if isinstance(net, nn.Module) and not isinstance(state, tuple):
                state.detach_()
            else:
                for s in state:
                    s.detach_()
        y = Y.T.reshape(-1)
        X, y = X.to(device), y.to(device)
        y_hat, state = net(X, state)
        l = loss(y_hat, y.long()).mean()
        if isinstance(updater, torch.optim.Optimizer):
            updater.zero_grad()
            l.backward()
            grad_clipping(net, 1)
            updater.step()
        else:
            l.backward()
            grad_clipping(net, 1)
            updater(batch_size=1)
        metric.add(l * y.numel(), y.numel())
    return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()

In [112]:
def train(net, train_iter, vocab, lr, num_epochs, device, use_random_iter=False):
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',
                            legend=['train'], xlim=[10, num_epochs])
    if isinstance(net, nn.Module):
        updater = torch.optim.SGD(net.parameters(), lr)
    else:
        updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)
    predict_f = lambda prefix: predict(prefix, 50, net, vocab, device)
    for epoch in range(num_epochs):
        ppl, speed = train_epoch(
            net, train_iter, loss, updater, device, use_random_iter)
        if (epoch + 1) % 10 == 0:
            print(predict_f('time traveller'))
            animator.add(epoch + 1, [ppl])
    print(f'Perplexity {ppl:.1f}, {speed:.1f} tokens/sec {str(device)}')
    print(predict_f('time traveller'))
    print(predict_f('traveller'))

In [None]:
num_epochs, lr = 500, 1
train(net, train_iter, vocab, lr, num_epochs, device)

## Backpropagation Through Time

Let's have a look at the **gradients** in RNNs:

$$\frac{\partial h_t}{\partial w_h}=\frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t} \frac{\partial f(x_{j},h_{j-1},w_h)}{\partial h_{j-1}} \right) \frac{\partial f(x_{i},h_{i-1},w_h)}{\partial w_h}$$

As the time step **$t$ gets larger**, the chain of products in the above equation grows longer, which can lead to **vanishing gradients** or **exploding gradients**.
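
A toy calculation makes the effect of the long product visible. For a scalar linear recurrence $h_j = w h_{j-1} + x_j$ (an illustrative model, not the RNN above), every factor $\partial h_j / \partial h_{j-1}$ equals $w$, so the product over $t$ steps is simply $w^t$. The helper name `jacobian_product` is hypothetical.

```python
# Product of identical per-step factors d h_j / d h_{j-1} = w for the
# scalar linear recurrence h_j = w * h_{j-1} + x_j (an illustrative toy model).
def jacobian_product(w: float, steps: int) -> float:
    prod = 1.0
    for _ in range(steps):
        prod *= w
    return prod


for w in (0.9, 1.0, 1.1):
    print(w, [f"{jacobian_product(w, t):.3g}" for t in (10, 50, 100)])
```

A factor slightly below 1 drives the product towards 0 (vanishing), while a factor slightly above 1 makes it blow up (exploding).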

There are a few strategies available:

> **Full Computation**: slow training, might cause gradient exploding, bad generalization

> **Truncating Time Steps**: truncate the sum after $\tau$ steps, i.e. terminate it at $\frac{\partial h_{t-\tau}}{\partial w_h}$, which gives an approximation of the true gradient; simple and stable, focuses on short-term effects, practically feasible

> **Randomized Truncation**: replace $\frac{\partial h_t}{\partial w_h}$ with a random variable $z_t$ that truncates the sequence randomly,
>$$z_t= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h} +\xi_t \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h}$$
>where $P(\xi_t = 0) = 1-\pi_t$ and $P(\xi_t = \pi_t^{-1}) = \pi_t$ for some $0 \le \pi_t \le 1$, so that $E[\xi_t] = 1$ and $E[z_t] = \frac{\partial h_t}{\partial w_h}$.
>The recurrent computation terminates whenever $\xi_t = 0$. This results in a weighted sum of sequences with different lengths.

![](http://d2l.ai/_images/truncated-bptt.svg)

The above diagram demonstrates the 3 strategies, from top to bottom: randomized truncation, regular truncation, and full computation.

Even though **randomized truncation** looks good, **regular truncation** works better in practice for the following reasons:

1. The effect of an observation after a number of backpropagation steps into the past is sufficient to **capture dependencies** in practice.
2. The **increase in variance** counteracts the fact that the gradient is more accurate with more steps.
3. We want models that have only a **short range of interactions**.
4. Regularly truncated backpropagation through time has a **slight regularizing effect** that can be desirable.
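
Regular truncation is what the `state.detach_()` call in `train_epoch` implements: detaching the hidden state cuts the computational graph, so gradients stop flowing further into the past. The sketch below shows the same mechanism on a scalar toy recurrence $h_t = \tanh(w h_{t-1} + x_t)$. The model, the inputs, and the helper `final_state_grad` are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)
w = torch.tensor(0.5, requires_grad=True)
xs = torch.randn(6)


def final_state_grad(truncate_after=None):
    """Gradient of the last hidden state w.r.t. w, optionally truncated.

    With truncate_after=k, only the last k steps contribute to the gradient.
    """
    h = torch.tensor(0.0)
    T = len(xs)
    for t, x in enumerate(xs):
        if truncate_after is not None and t == T - truncate_after:
            h = h.detach()  # gradient stops flowing further into the past
        h = torch.tanh(w * h + x)
    (grad,) = torch.autograd.grad(h, w)
    return grad.item()


print(final_state_grad())                  # full backpropagation through time
print(final_state_grad(truncate_after=2))  # only the last 2 steps contribute
```

The forward computation is identical in both calls; only the backward pass differs, which is why truncation changes the gradient estimate without changing the model's predictions.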