# Recurrent Neural Networks (RNN)

Previously, we described various language models where the conditional probability of token $\boldsymbol{\mathsf{x}}_t$ depends on a fixed context $\boldsymbol{\mathsf{x}}_{[t - \tau: t-1]}.$ To incorporate the possible effect of tokens earlier than the given context, we need to increase the context size $\tau$. For the *n*-gram model, this increases the number of parameters exponentially in $\tau$. For the MLP network with embeddings, the number of parameters grows as $O(\tau)$. Finally, using convolutions, this decreases to $O(\log \tau).$
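
To make these growth rates concrete, the following sketch counts parameters under some simplifying assumptions that are not stated in the text: the *n*-gram model is a full conditional table, the MLP is represented only by its first linear layer on $\tau$ concatenated embeddings, and the convolutional model is a stack of dilated layers with kernel size 2 and dilations $1, 2, 4, \ldots$ (one possible reading of the $O(\log \tau)$ claim, where the parameter count is proportional to depth). The function names are hypothetical.

```python
def ngram_table_size(vocab_size: int, tau: int) -> int:
    """Entries in a full table p(x_t | previous tau tokens): exponential in tau."""
    return vocab_size ** tau * vocab_size


def mlp_first_layer_params(tau: int, d: int, h: int) -> int:
    """Weights and biases of a linear layer on tau concatenated d-dim embeddings."""
    return tau * d * h + h


def dilated_conv_depth(tau: int) -> int:
    """Layers needed so that kernel size 2 with dilations 1, 2, 4, ... covers tau tokens.

    The receptive field after L such layers is 2**L, so L = ceil(log2(tau)).
    """
    return max(tau - 1, 0).bit_length()


for tau in (2, 4, 8):
    print(
        tau,
        ngram_table_size(1000, tau),
        mlp_first_layer_params(tau, d=64, h=128),
        dilated_conv_depth(tau),
    )
```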

Alternatively, instead of modeling the next token directly in terms of previous tokens, we can use a latent variable that, in principle, stores *all* previous information up to the previous time step:

$$
p(\boldsymbol{\mathsf x}_{t} \mid \boldsymbol{\mathsf x}_{1}, \ldots, \boldsymbol{\mathsf x}_{t-1}) \approx p(\boldsymbol{\mathsf x}_{t} \mid \boldsymbol{\mathsf h}_{t-1})
$$

where $\boldsymbol{\mathsf h}_{t-1}$ is a *hidden state* that stores information up to the time step $t - 1.$ The hidden state is updated based on the current input and the previous state: 

$$
\boldsymbol{\mathsf h}_{t} = f(\boldsymbol{\mathsf x}_{t}, \boldsymbol{\mathsf h}_{t-1}),
$$

so that $\boldsymbol{\mathsf h}_{t} = F(\boldsymbol{\mathsf x}_{1}, \ldots, \boldsymbol{\mathsf x}_{t}, \boldsymbol{\mathsf h}_{0})$ for some $\boldsymbol{\mathsf h}_{0}$ where $F$ involves recursively applying $f.$ Note that for a sufficiently powerful function $f$, the latent variable model above is not an approximation, since $\boldsymbol{\mathsf h}_{t}$ can simply store all $\boldsymbol{\mathsf x}_{1}, \ldots, \boldsymbol{\mathsf x}_{t}$ it has observed so far. 
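
As a toy illustration of this remark, the sketch below uses a deliberately trivial update `f` that appends each input to a tuple, so the state literally stores everything seen so far. The names `f` and `F` mirror the symbols above; the tuple-valued state and the example tokens are assumptions made only for illustration.

```python
from functools import reduce


def f(x, h):
    """Toy state update: the new state is the old state with x appended."""
    return h + (x,)


def F(xs, h0=()):
    """Unrolled form: fold f over the whole sequence, starting from h0."""
    return reduce(lambda h, x: f(x, h), xs, h0)


tokens = ["the", "cat", "sat"]

# Step-by-step recursion h_t = f(x_t, h_{t-1}).
h = ()
for x in tokens:
    h = f(x, h)

# The recursion and the unrolled function agree, and no information is lost.
assert h == F(tokens)
assert h == tuple(tokens)
```

Note that this state grows with $t$. A state of fixed size, as used in the rest of this section, generally has to compress the history instead.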

<br>

## Simple RNN

RNNs use the same parameters at each time step, i.e. the dynamics are assumed to be **stationary**. Practically, this means that the number of parameters does not grow as the sequence length increases.
The following implementation is called **Simple RNN**: the state update is essentially a linear layer applied to the concatenation of the embedding and the previous hidden state, followed by a nonlinearity $\varphi$. Let tokens correspond to embedding vectors $\boldsymbol{\mathsf{x}}_t \in \mathbb{R}^d$, then

$$
\begin{aligned}
\boldsymbol{\mathsf{h}}_t &= \varphi(\boldsymbol{\mathsf{x}}_t \boldsymbol{\mathsf{U}} + \boldsymbol{\mathsf{h}}_{t-1} \boldsymbol{\mathsf{W}} + \boldsymbol{\mathsf{b}}) \\
\boldsymbol{\mathsf{y}}_t &= \boldsymbol{\mathsf{h}}_t \boldsymbol{\mathsf{V}} + \boldsymbol{\mathsf{c}}
\end{aligned}
$$

where $\boldsymbol{\mathsf{U}} \in \mathbb{R}^{d \times h}$, $\boldsymbol{\mathsf{W}} \in \mathbb{R}^{h \times h}$, and $\boldsymbol{\mathsf{b}} \in \mathbb{R}^{h}.$ Here $h$ is the dimensionality of the hidden state. For the outputs, we also have $\boldsymbol{\mathsf{V}} \in \mathbb{R}^{h \times q}$ and $\boldsymbol{\mathsf{c}} \in \mathbb{R}^{q}$ where $q$ is the dimensionality of the output. This computation can be seen in {numref}`04-simple-rnn`.
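
The state update can be checked numerically against the "concatenate, then apply one linear layer" view. The sketch below assumes $\varphi = \tanh$ and uses small, arbitrarily chosen sizes. It stacks $\boldsymbol{\mathsf{U}}$ on top of $\boldsymbol{\mathsf{W}}$ and multiplies by the concatenated vector $[\boldsymbol{\mathsf{x}}_t, \boldsymbol{\mathsf{h}}_{t-1}]$.

```python
import torch

torch.manual_seed(0)
B, d, h = 4, 3, 5  # illustrative sizes, not taken from the text

x = torch.randn(B, d)        # current inputs x_t
h_prev = torch.randn(B, h)   # previous states h_{t-1}
U = torch.randn(d, h)
W = torch.randn(h, h)
b = torch.randn(h)

# Update as written in the equation: two matrix products plus a bias.
separate = torch.tanh(x @ U + h_prev @ W + b)

# Same update as a single linear layer on the concatenated input.
xh = torch.cat([x, h_prev], dim=1)   # (B, d + h)
UW = torch.cat([U, W], dim=0)        # (d + h, h)
fused = torch.tanh(xh @ UW + b)

assert torch.allclose(separate, fused, atol=1e-5)

# Parameter count of the recurrent layer; it does not involve the sequence length.
n_params = U.numel() + W.numel() + b.numel()
assert n_params == d * h + h * h + h
```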

Note that $\boldsymbol{\mathsf{x}}_t$ can be one-hot vectors since the matrix $\boldsymbol{\mathsf{U}}$ can act as the embedding matrix for the tokens. In this case, $\boldsymbol{\mathsf{U}}$ has shape $(|\mathcal{V}|, h).$
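
A quick way to see this is to compare the matrix product with a direct row lookup. The vocabulary size, hidden size, and token indices below are arbitrary values chosen for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, h = 7, 4
U = torch.randn(vocab_size, h)                 # shape (|V|, h)

tokens = torch.tensor([[1, 5, 2], [0, 6, 6]])  # token indices, shape (B, T)
one_hot = F.one_hot(tokens, num_classes=vocab_size).float()  # (B, T, |V|)

via_matmul = one_hot @ U   # (B, T, h): each one-hot vector selects one row of U
via_lookup = U[tokens]     # (B, T, h): the same rows, fetched by index

assert torch.allclose(via_matmul, via_lookup)
```

In practice the lookup form avoids materializing the one-hot tensor, which is why embedding layers are usually implemented as index operations.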

<br>

```{figure} ../../../img/nn/04-simple-rnn.svg
---
width: 600px
name: 04-simple-rnn
align: center
---
Computational graph of an unrolled simple RNN. [Source](https://www.d2l.ai/chapter_recurrent-neural-networks/rnn.html)
```

First, we implement the recurrent layer. To support batch computation, the input has shape $(B, T, d)$, that is, a batch of $B$ sequences of length $T$, each consisting of vectors in $\mathbb{R}^d.$ Elements of a batch are computed independently, ideally in parallel. At each step, the layer returns the state vector of shape $(B, h).$ These are stacked along the time dimension to get a tensor of shape $(B, T, h)$, consistent with the input.

```python
import torch
import numpy as np
import torch.nn as nn


class SimpleRNN(nn.Module):
    def __init__(self, dim_inputs, dim_hidden):
        super().__init__()
        self.dim_hidden = dim_hidden
        self.dim_inputs = dim_inputs
        # Scale by 1 / sqrt(fan_in) so that pre-activations have roughly unit
        # variance for unit-variance inputs.
        self.W = nn.Parameter(torch.randn(dim_hidden, dim_hidden) / np.sqrt(dim_hidden))
        self.U = nn.Parameter(torch.randn(dim_inputs, dim_hidden) / np.sqrt(dim_inputs))
        self.b = nn.Parameter(torch.zeros(dim_hidden))

    def forward(self, x, state=None):
        x = x.transpose(0, 1)  # (B, T, d) -> (T, B, d)
        T, B, d = x.shape
        assert d == self.dim_inputs
        if state is None:
            # Default initial state h_0 = 0.
            state = torch.zeros(B, self.dim_hidden, device=x.device)
        else:
            assert state.shape == (B, self.dim_hidden)

        outs = []
        for t in range(T):
            # h_t = tanh(x_t U + h_{t-1} W + b)
            state = torch.tanh(x[t] @ self.U + state @ self.W + self.b)
            outs.append(state)

        outs = torch.stack(outs)     # (T, B, h)
        outs = outs.transpose(0, 1)  # (T, B, h) -> (B, T, h)
        return outs, state           # state is the final h_T, shape (B, h)
```

Shapes test:

```python
h = 5
B, T, d = 32, 10, 512
rnn = SimpleRNN(dim_inputs=d, dim_hidden=h)
outs, state = rnn(torch.randn(B, T, d))
assert outs.shape == (B, T, h)
assert state.shape == (B, h)
assert torch.abs(outs[:, -1, :] - state).max() < 1e-8
```
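
Beyond shapes, the recurrence itself can be cross-checked against PyTorch's built-in `torch.nn.RNN`. The sketch below is self-contained: it repeats the same update as a plain loop over tensors rather than reusing the class above. It relies on `nn.RNN` storing its weights transposed relative to $\boldsymbol{\mathsf{U}}$ and $\boldsymbol{\mathsf{W}}$ and keeping two bias vectors, so one of them is set to zero. The sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, d, h = 2, 6, 3, 4  # illustrative sizes

U = torch.randn(d, h) / d**0.5
W = torch.randn(h, h) / h**0.5
b = torch.randn(h)
x = torch.randn(B, T, d)

# Manual recurrence: h_t = tanh(x_t U + h_{t-1} W + b), starting from h_0 = 0.
state = torch.zeros(B, h)
manual = []
for t in range(T):
    state = torch.tanh(x[:, t] @ U + state @ W + b)
    manual.append(state)
manual = torch.stack(manual, dim=1)  # (B, T, h)

# Reference layer with the same weights copied in.
ref = nn.RNN(input_size=d, hidden_size=h, nonlinearity="tanh", batch_first=True)
with torch.no_grad():
    ref.weight_ih_l0.copy_(U.T)  # nn.RNN computes x @ weight_ih.T
    ref.weight_hh_l0.copy_(W.T)  # and h @ weight_hh.T
    ref.bias_ih_l0.copy_(b)
    ref.bias_hh_l0.zero_()       # second bias is unused here
    ref_outs, ref_state = ref(x)  # ref_state has shape (num_layers, B, h)

assert torch.allclose(manual, ref_outs, atol=1e-5)
assert torch.allclose(state, ref_state[0], atol=1e-5)
```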