# Recurrent Neural Network Implementation from Scratch
:label:`sec_rnn-scratch`

We are now ready to implement an RNN from scratch.
In particular, we will train this RNN to function
as a character-level language model
(see :numref:`sec_rnn`)
on a corpus consisting of 
the entire text of H. G. Wells' *The Time Machine*,
following the data processing steps 
outlined in :numref:`sec_text-sequence`.
We start by importing the libraries that we will need.


In [1]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import jax
import jax.numpy as jnp
import optax
from flax import nnx

## RNN Model

We begin by defining a class 
to implement the RNN model
(:numref:`subsec_rnn_w_hidden_states`).
Note that the number of hidden units `num_hiddens` 
is a tunable hyperparameter.

[**The `__call__` method below defines how to compute 
the output and hidden state at any time step,
given the current input and the state of the model
at the previous time step.**]
Note that the RNN model loops through 
the outermost dimension of `inputs`,
updating the hidden state 
one time step at a time.
The model here uses a $\tanh$ activation function (:numref:`subsec_tanh`).
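
Written out, and assuming the notation $\mathbf{X}_t$ for the minibatch input at time step $t$ and $\mathbf{H}_t$ for the hidden state (symbols chosen here to mirror the parameter names `W_xh`, `W_hh`, and `b_h` used in the code), the update at each time step is

$$\mathbf{H}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1} \mathbf{W}_{hh} + \mathbf{b}_h).$$

At the first time step, when no state is passed in, the $\mathbf{H}_{t-1} \mathbf{W}_{hh}$ term is dropped, which is equivalent to starting from an all-zero hidden state.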

In [20]:
class RNNScratch(nnx.Module): 
	"""The RNN model implemented from scratch."""
	def __init__(self, num_inputs: int, num_hiddens: int, rngs: nnx.Rngs):
		self.W_xh = nnx.Param(nnx.initializers.he_normal()(rngs(), (num_inputs, num_hiddens)))
		self.W_hh = nnx.Param(nnx.initializers.he_normal()(rngs(), (num_hiddens, num_hiddens)))
		self.b_h = nnx.Param(nnx.initializers.zeros_init()(rngs(), (num_hiddens,)))
		self.num_hiddens = num_hiddens

	def __call__(self, inputs, state=None):
		# Shape of inputs: (num_steps, batch_size, num_inputs)
		if state is not None:
			state, = state  # The state is passed in as a one-element tuple
		outputs = []
		for X in inputs:  # Loop over time steps
			# With no previous state, the recurrent term is dropped
			recurrent = jnp.matmul(state, self.W_hh) if state is not None else 0
			state = jnp.tanh(jnp.matmul(X, self.W_xh) + recurrent + self.b_h)
			outputs.append(state)
		return outputs, state

We can feed a minibatch of input sequences into an RNN model as follows.


In [28]:
# batch_size, num_inputs, num_hiddens, num_steps = 2, 16, 32, 100
batch_size, num_inputs, num_hiddens, num_steps = 2, 2, 3, 2
rnn = RNNScratch(num_inputs, num_hiddens, rngs=nnx.Rngs(0))

In [29]:
X = jnp.ones((num_steps, batch_size, num_inputs))
outputs, state = rnn(X)

In [30]:
outputs

[Array([[ 0.9977814, -0.7699115,  0.8565357],
        [ 0.9977814, -0.7699115,  0.8565357]], dtype=float32),
 Array([[ 0.9995162 , -0.9408712 ,  0.42814836],
        [ 0.9995162 , -0.9408712 ,  0.42814836]], dtype=float32)]

In [31]:
state

Array([[ 0.9995162 , -0.9408712 ,  0.42814836],
       [ 0.9995162 , -0.9408712 ,  0.42814836]], dtype=float32)

Let's check whether the RNN model
produces results of the correct shapes
to ensure that the dimensionality 
of the hidden state remains unchanged.


In [32]:
def check_len(a, n):  #@save
    """Check the length of a list."""
    assert len(a) == n, f'list\'s length {len(a)} != expected length {n}'

def check_shape(a, shape):  #@save
    """Check the shape of a tensor."""
    assert a.shape == shape, \
            f'tensor\'s shape {a.shape} != expected shape {shape}'

check_len(outputs, num_steps)
check_shape(outputs[0], (batch_size, num_hiddens))
check_shape(state, (batch_size, num_hiddens))
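
As a complementary check on the recurrence itself, the sketch below illustrates what carrying the hidden state forward means: processing a sequence in one pass should give the same final state as processing it in two chunks and handing the state from the first chunk to the second. It is a standalone illustration in plain JAX, not the `RNNScratch` class; the helper `rnn_steps`, the parameter tuple `params`, and the small random weights are assumptions made only for this example.

```python
import jax
import jax.numpy as jnp

def rnn_steps(params, inputs, state=None):
    """Hypothetical helper: run the tanh recurrence over the leading (time) axis."""
    W_xh, W_hh, b_h = params
    if state is None:
        # An all-zero state is equivalent to dropping the recurrent term
        state = jnp.zeros((inputs.shape[1], W_hh.shape[0]))
    outputs = []
    for X in inputs:  # inputs: (num_steps, batch_size, num_inputs)
        state = jnp.tanh(X @ W_xh + state @ W_hh + b_h)
        outputs.append(state)
    return outputs, state

key_x, key_xh, key_hh = jax.random.split(jax.random.PRNGKey(0), 3)
num_steps, batch_size, num_inputs, num_hiddens = 4, 2, 5, 3
# Made-up parameters, used only for this check
params = (0.1 * jax.random.normal(key_xh, (num_inputs, num_hiddens)),
          0.1 * jax.random.normal(key_hh, (num_hiddens, num_hiddens)),
          jnp.zeros(num_hiddens))
X = jax.random.normal(key_x, (num_steps, batch_size, num_inputs))

# Process the whole sequence in one pass ...
_, full_state = rnn_steps(params, X)
# ... or in two chunks, passing the state from the first chunk to the second
_, half_state = rnn_steps(params, X[:2])
_, chunked_state = rnn_steps(params, X[2:], half_state)

states_match = bool(jnp.allclose(full_state, chunked_state, atol=1e-6))
print(states_match)
```

Because the hidden state summarizes everything seen so far, the two results are expected to agree. This is the property that lets a long text be fed to the model one subsequence at a time.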

## RNN-Based Language Model

The following `RNNLMScratch` class defines 
an RNN-based language model,
where we pass in our RNN 
via the `rnn` argument
of the `__init__` method.
When training language models, 
the inputs and outputs are 
from the same vocabulary. 
Hence, they have the same dimension,
which is equal to the vocabulary size.
Note that we use perplexity to evaluate the model. 
As discussed in :numref:`subsec_perplexity`, this ensures 
that sequences of different length are comparable.


In [None]:
class RNNLMScratch(nnx.Module):  #@save
	"""The RNN-based language model implemented from scratch."""

	def __init__(self, rnn: nnx.Module, vocab_size: int, rngs: nnx.Rngs):
		self.W_hq = nnx.Param(nnx.initializers.he_normal()(rngs(), (rnn.num_hiddens, vocab_size)))
		self.b_q = nnx.Param(nnx.initializers.zeros_init()(rngs(), (vocab_size,)))
		self.vocab_size = vocab_size
	
	def one_hot(self, X):
		# Output shape: (num_steps, batch_size, vocab_size)
		return jax.nn.one_hot(X.T, self.vocab_size)
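
To make the perplexity mentioned above concrete, here is a minimal sketch that computes it as the exponential of the average per-token cross-entropy. The function name `perplexity` and its argument layout are assumptions for this illustration and are not part of the `RNNLMScratch` class. Averaging over tokens is what makes sequences of different lengths comparable.

```python
import jax
import jax.numpy as jnp

def perplexity(logits, labels):
    """Hypothetical helper: exp of the mean per-token cross-entropy.

    logits: (num_tokens, vocab_size) unnormalized scores
    labels: (num_tokens,) integer token indices
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Log-probability that the model assigns to each correct token
    token_log_probs = jnp.take_along_axis(
        log_probs, labels[:, None], axis=-1)[:, 0]
    return jnp.exp(-jnp.mean(token_log_probs))

vocab_size = 28
short_labels = jnp.array([3, 7, 7, 0, 12])
long_labels = jnp.zeros(50, dtype=jnp.int32)

# A maximally unsure model gives every token the same score
ppl_short = perplexity(jnp.zeros((5, vocab_size)), short_labels)
ppl_long = perplexity(jnp.zeros((50, vocab_size)), long_labels)
print(float(ppl_short), float(ppl_long))
```

For a model that spreads its probability uniformly over the vocabulary, both values should come out close to the vocabulary size, regardless of how long the sequence is.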

### [**One-Hot Encoding**]

Recall that each token is represented 
by a numerical index indicating the
position in the vocabulary of the 
corresponding word/character/word piece.
You might be tempted to build a neural network
with a single input node (at each time step),
where the index could be fed in as a scalar value.
This works when we are dealing with numerical inputs 
like price or temperature, where any two values
sufficiently close together
should be treated similarly.
But this does not quite make sense for token indices. 
The $45^{\textrm{th}}$ and $46^{\textrm{th}}$ words 
in our vocabulary happen to be "their" and "said",
whose meanings are not remotely similar.

When dealing with such categorical data,
the most common strategy is to represent
each item by a *one-hot encoding*
(recall from :numref:`subsec_classification-problem`).
A one-hot encoding is a vector whose length
is given by the size of the vocabulary $N$,
where all entries are set to $0$,
except for the entry corresponding 
to our token, which is set to $1$.
For example, if the vocabulary had five elements,
then the one-hot vectors corresponding 
to indices 0 and 2 would be the following.


In [34]:
jax.nn.one_hot(jnp.array([0, 2]), 5)

Array([[1., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0.]], dtype=float32)

In [None]:
nnx.one_hot(jnp.array([0, 2]), 5)

Array([[1., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0.]], dtype=float32)
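
One way to see what a one-hot input does inside a network: multiplying a one-hot vector by a weight matrix simply picks out the row of that matrix at the token's index. The sketch below checks this with a made-up matrix `W`, which stands in for an input weight matrix such as `W_xh`; the values are arbitrary and chosen only so that the rows are easy to tell apart.

```python
import jax
import jax.numpy as jnp

vocab_size, num_hiddens = 5, 3
# Made-up weights: row i holds the values associated with token i
W = jnp.arange(vocab_size * num_hiddens,
               dtype=jnp.float32).reshape(vocab_size, num_hiddens)
tokens = jnp.array([0, 2])

# Matrix product with one-hot rows ...
via_matmul = jax.nn.one_hot(tokens, vocab_size) @ W
# ... versus directly indexing the rows of W
via_lookup = W[tokens]

rows_match = bool(jnp.allclose(via_matmul, via_lookup))
print(rows_match)
```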

(**The minibatches that we sample at each iteration
will take the shape (batch size, number of time steps).
Once we represent each input as a one-hot vector,
we can think of each minibatch as a three-dimensional tensor, 
where the length along the third axis 
is given by the vocabulary size (`len(vocab)`).**)
We often transpose the input so that we will obtain an output 
of shape (number of time steps, batch size, vocabulary size).
This will allow us to loop more conveniently through the outermost dimension
for updating hidden states of a minibatch,
time step by time step
(e.g., in the `__call__` method of `RNNScratch` above).
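
The shape bookkeeping described above can be sketched on a tiny minibatch. The token indices and sizes here are made up for illustration; the call mirrors the `one_hot` method of `RNNLMScratch`, which transposes before encoding.

```python
import jax
import jax.numpy as jnp

batch_size, num_steps, vocab_size = 2, 3, 5
# Token indices for a minibatch: one row per sequence
X = jnp.array([[0, 2, 4],
               [1, 1, 3]])

# Transpose so that time becomes the outermost axis, then encode
X_onehot = jax.nn.one_hot(X.T, vocab_size)
print(X_onehot.shape)  # axes: (num_steps, batch_size, vocab_size)

# Iterating over the result yields one (batch_size, vocab_size) slice per time step
step_shapes = [X_t.shape for X_t in X_onehot]
print(step_shapes)
```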

