# Build an RNN, from scratch!

I wrote my own RNN from scratch based on [Ryan's nice slides](https://github.com/alan-turing-institute/transformers-reading-group/blob/main/sessions/03-seq2seq-part-i/seq2seq_part1_hut23_robots_in_disguise.pdf), and this notebook is the result. Now, I'll guide you to do the same (but using a lot of my boilerplate, so you can just do the fun stuff!)

This notebook assumes familiarity with `numpy` and how to multiply arrays. That's because we're gonna use my favourite deep learning library [`jax`](https://github.com/google/jax), which uses the same syntax as `numpy`: we just `import jax.numpy as jnp` instead! Everything else is more or less the same -- I'll point out any important differences :)

The basic layout is as follows: I'm gonna show you the formula, and write the function signature with some type annotations as hints. You're gonna fill in the blanks! But don't panic: most of these functions are one- or two-liners that I literally just copied from the slides.

The format will look something like:

```python3
def my_function(x, y, z):
    ...
```

where you should replace the `...` with the definition of the function.

In these functions, you'll see type annotations like this a lot:
```python3
x: Float[Array, "dim1 dim2"]
```
What do I mean here? This just says that we have a variable `x`, which is an array of floating-point values (hence the `Float`). The string `"dim1 dim2"` describes the shape of the array: two dimensions, `dim1` and `dim2`, i.e. a matrix with `dim1` rows and `dim2` columns. These names don't represent anything concrete until we actually instantiate `x`, but paying attention to them will help you make sure your matrix-vector multiplications work (remember the rules of matrix multiplication: this matrix can only multiply an object on its right whose leading dimension is `dim2`). Oh, and if you're wondering where this syntax comes from -- it's the library [`jaxtyping`](https://github.com/google/jaxtyping)!
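
As a quick sanity check of that rule, here's a tiny example in plain `numpy` (shapes behave the same way in `jax.numpy`); the sizes 3 and 5 are arbitrary. A `"dim1 dim2"` matrix times a length-`dim2` vector gives a length-`dim1` vector:

```python3
import numpy as np

dim1, dim2 = 3, 5                        # arbitrary example sizes
matrix = np.ones((dim1, dim2))           # shape "dim1 dim2"
vector = np.arange(dim2, dtype=float)    # shape "dim2": [0, 1, 2, 3, 4]

result = matrix @ vector                 # sums over the shared `dim2` axis
print(result.shape)                      # (3,)
print(result)                            # [10. 10. 10.]

# a vector of the wrong length can't be multiplied on the right
try:
    matrix @ np.ones(dim1)
except ValueError:
    print("shape mismatch!")
```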

Speaking of which, let's install a couple of dependencies first:

In [None]:
%pip install jax jaxlib pytreeclass jaxtyping optax matplotlib

What follows below is a bit messier -- this is a boilerplate function designed to pre-process a text document and form a vocabulary. This is normally handled by external libraries, but I was going all-in on doing this from scratch :p

Don't try too hard to read it -- we will play with the output in the next step!

The basic gist:
- Find all unique tokens (including newlines, punctuation, and contracted words like `it's`)
- Give each token a unique identifier: its position in a list of the vocabulary
- Map the text to the corresponding indices in the vocabulary
- Chunk the text into fixed-length sentences, pairing each with the same sentence shifted by one token (the "next word" labels)
- Split this up into a training set (what the model sees) and a validation set (what it doesn't see), so we can evaluate how well the model generalises to unseen text
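
Here's a minimal plain-Python sketch of those steps on a made-up two-line string (not the actual text file), including the one-token shift that builds the "next word" labels. It uses the same token pattern as `prepare_text` below; the text and the sentence length of 4 are just for illustration, and the train/validation split is left out:

```python3
import re

# same three alternatives as in `prepare_text`:
# punctuation | words (optionally with an apostrophe part) | newlines
token_pattern = r"[^\w\s]|\w+(?:'\w+)?|\n"

text = "One fish, two fish.\nRed fish, blue fish.\n"  # made-up example text
tokens = re.findall(token_pattern, text.lower())

# 1. unique tokens (sorted so the order is reproducible)
vocab = sorted(set(tokens))
# 2. each token's identifier is its position in the vocabulary
word_to_index = {word: i for i, word in enumerate(vocab)}
# 3. map the text to vocabulary indices
indices = [word_to_index[t] for t in tokens]

# 4. chunk into sentences; labels are the same tokens shifted by one position
sentence_length = 4
num_sentences = (len(indices) - 1) // sentence_length
inputs = [
    indices[i * sentence_length : (i + 1) * sentence_length]
    for i in range(num_sentences)
]
labels = [
    indices[i * sentence_length + 1 : (i + 1) * sentence_length + 1]
    for i in range(num_sentences)
]

print(vocab)      # ['\n', ',', '.', 'blue', 'fish', 'one', 'red', 'two']
print(inputs[0])  # [5, 4, 1, 7]
print(labels[0])  # [4, 1, 7, 4]
```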

In [None]:
from __future__ import annotations

import jax
import jax.numpy as jnp
import numpy as np
import re
from jaxtyping import Array, Float, Int
import pytreeclass as pytc
from typing import Generator
from functools import partial
from copy import deepcopy


def prepare_text(file_name, sentence_length):
    with open(file_name, "r") as file:
        all_text = file.read()

    # Define a regular expression pattern to match all punctuation marks
    punctuation_pattern = r"[^\w\s]"

    # Define a regular expression pattern to match words with apostrophes
    apostrophe_pattern = r"\w+(?:\'\w+)?"
    # Define a regular expression pattern to match newlines
    newline_pattern = r"\n"

    # Combine the three patterns to match all tokens
    token_pattern = (
        punctuation_pattern + "|" + apostrophe_pattern + "|" + newline_pattern
    )

    # Split the text into tokens, keeping contracted words like "it's" as one token
    all_words = re.findall(token_pattern, all_text.lower())
    # sort the unique tokens so the vocabulary order is the same in every session
    # (the iteration order of a set of strings can change between runs)
    vocab = sorted(set(all_words))
    word_to_index = {word: i for i, word in enumerate(vocab)}

    vocab_one_hot_indices = jnp.array(
        [word_to_index[t] for t in all_words], dtype=jnp.int32
    )
    # chunk the whole *text* (all tokens, not just the unique ones) into sentences,
    # keeping one spare token so every input word has a real next word as its label
    num_sentences = (len(all_words) - 1) // sentence_length
    num_tokens = num_sentences * sentence_length
    split_indices = vocab_one_hot_indices[:num_tokens].reshape(
        num_sentences, sentence_length
    )
    # labels: the same tokens shifted one position along the text
    split_indices_labels = vocab_one_hot_indices[1 : num_tokens + 1].reshape(
        num_sentences, sentence_length
    )
    # use roughly 6/7 of the sentences for training and the rest for validation
    partition_index = 6 * int(len(split_indices) / 7)
    train = split_indices[:partition_index]
    train_labels = split_indices_labels[:partition_index]
    valid = split_indices[partition_index:]
    valid_labels = split_indices_labels[partition_index:]

    return train, train_labels, valid, valid_labels, vocab

### Producing a vocabulary from text

Below, we explore the data we're working with!

In [None]:
file_name = "one-fish-two-fish.txt"
sentence_length = 8  # keep even because of how we split the data
train, train_labels, valid, valid_labels, vocab = prepare_text(
    file_name, sentence_length
)

print(f"examples from vocab: {vocab[:10]}")
print(f"total length of vocab: {len(vocab)} unique words")
print(
    f"total length of training data: {len(train)} sentences (each {sentence_length} words)"
)
print(
    f"total length of validation data: {len(valid)} sentences (each {sentence_length} words)"
)

In [None]:
# first sentence in train set
train[0]

In [None]:
# first sentence in train labels == same sentence shifted by one word
# i.e. equivalent to jnp.concatenate([train[0][1:], train[1][:1]])
train_labels[0]

In [None]:
# we can reconstruct a sentence by mapping indices back to words
" ".join([vocab[i] for i in train[0]])

In [None]:
" ".join([vocab[i] for i in train_labels[0]])

As you can see, our whole text has been split into individual words, or *tokens*. Of course, tokens don't have to be words: we could use characters, strings of length 4, numbers... anything goes! For simplicity, we assume one word <-> one token here. I think of the vocabulary as a lookup table that maps each word seen in the text to a unique numerical identifier, which can just be the position of that word in the vocabulary (assuming it's a list-like structure). Note that the vocabulary includes things like punctuation, contracted words and newlines.

What next? For our RNN, we need to be able to take in a sentence (assumed to be a list of indices of positions in the vocabulary) and construct a one-hot vector for each word. Recall the definition of a one-hot vector:

In [None]:
fish_idx = vocab.index("fish")  # the index of the word "fish" in the vocab
one_hot_fish = np.zeros(len(vocab))  # a vector of zeros with length equal to the vocab
one_hot_fish[fish_idx] = 1  # set the index of the word "fish" to 1

# the syntax in JAX is a little different, but the idea is the same.
# JAX arrays are immutable, so we can't just set the value at an index directly!
# instead, `.at[fish_idx]` picks the position to update, and `.set(1)` returns
# a *new* array with the value at that position set to 1.
fish_idx = vocab.index("fish")
one_hot_fish = jnp.zeros(len(vocab))
one_hot_fish = one_hot_fish.at[fish_idx].set(1)

one_hot_fish

Based on the above, fill in the function below to turn a sentence of indices (corresponding to words in the vocab) into an array of one-hot vectors for those words.

In [None]:
def one_hot_sentence(
    sentence: Int[Array, "sentence"], vocab_size: int
) -> Float[Array, "sentence vocab"]:
    ...


# we'll test out your function to see if it worked:
# make a very intelligent sentence of all "fish"
fish_sentence = jnp.array([fish_idx] * 10)

# one hot encode the sentence
one_hot_fish_sentence = one_hot_sentence(fish_sentence, len(vocab))

# assert that the sentence is one hot encoded correctly
assert jnp.all(one_hot_fish_sentence == jnp.array([one_hot_fish] * 10))

In [None]:
# we can use `vmap` to automatically transform the function to work on a batch of sentences!
# this will be useful when we want to train our model on multiple sentences at once.
# note that we need to specify the `in_axes` argument to tell JAX which argument
# in the function is the one that we want to map over (in this case, we want to
# map over the first axis of the `sentence` argument, indicated by `0`).
# we also need to specify `None` for the `vocab_size` argument, since it is not
# being mapped over -- it is the same for every sentence in the batch.
batch_one_hot = jax.vmap(one_hot_sentence, in_axes=(0, None))
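
If `vmap` feels like magic, it may help to see what it computes. Below is a rough plain-`numpy` equivalent written as an explicit Python loop, using a made-up toy function `scale` (not part of the RNN) so it runs on its own. With `in_axes=(0, None)`, the first argument is sliced along axis 0, the second is passed unchanged to every call, and the results are stacked back together:

```python3
import numpy as np


def scale(sentence, factor):
    # hypothetical toy function that works on ONE sentence at a time
    return sentence * factor


def loop_map(fn, sentences, factor):
    # same result as jax.vmap(fn, in_axes=(0, None))(sentences, factor),
    # but computed one sentence at a time with a Python loop
    return np.stack([fn(sentence, factor) for sentence in sentences])


sentences = np.array([[1, 2, 3], [4, 5, 6]])  # a batch of 2 "sentences", 3 tokens each
print(loop_map(scale, sentences, 10))
# [[10 20 30]
#  [40 50 60]]
```

The difference is that `vmap` doesn't actually loop: it transforms the function so it operates on the whole batch at once, which is usually much faster.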

### Embeddings

What's an embedding? In this context, it's a lower-dimensional vector representation of each word in the vocabulary, which should capture some useful information about language. Usually, people use [pre-trained embeddings](https://huggingface.co/blog/getting-started-with-embeddings) for tasks like text generation (since language is a fairly general setting), but we're taking the ambitious route of learning the embedding matrix jointly with the other weights. The results of this notebook would likely be much better with embeddings that already encode information about language and the relationships between words (though the vocabulary would then need to be much larger than ours, which just contains the unique words of a single text document).

Now, we're going to take an aside to define our embedding matrix and other relevant quantities. Let's look at the parameters of our RNN:

In [None]:
# This is a container for all the free parameters of an RNN!
# You can see the shapes of each attribute from the type annotations.
# We have three sizes: hidden_state, embedding, vocab
# -> these are the length of the hidden state vector, the dimension of
#    each word embedding, and the number of words in the vocabulary.
# This parameters object will be passed to most functions below:
# e.g. access the output weights by calling `params.output_weights` etc.
class Parameters(pytc.TreeClass):
    embedding_weights: Float[Array, "hidden_state embedding"]
    hidden_state_weights: Float[Array, "hidden_state hidden_state"]
    output_weights: Float[Array, "vocab hidden_state"]
    hidden_state_bias: Float[Array, "hidden_state"]
    output_bias: Float[Array, "vocab"]
    embedding_matrix: Float[Array, "embedding vocab"]


# we'll initialize our parameters randomly, but close to 0/identity so that
# we don't have exploding gradients later on!

# set sizes for embeddings, hidden state, vocab, and output vectors
e = 30
h = 16
v = len(vocab)
o = v

# split one seed into a separate key per random array,
# since JAX's advice is to never reuse a key for independent draws
key_embedding, key_output, key_matrix = jax.random.split(jax.random.PRNGKey(0), 3)

params = Parameters(
    embedding_weights=jax.random.truncated_normal(
        lower=-0.1, upper=0.1, shape=[h, e], key=key_embedding
    ),
    hidden_state_weights=jnp.identity(h),
    output_weights=jax.random.truncated_normal(
        lower=-0.1, upper=0.1, shape=[o, h], key=key_output
    ),
    hidden_state_bias=jnp.zeros((h,)),
    output_bias=jnp.zeros((o,)),
    embedding_matrix=jax.random.truncated_normal(
        lower=-0.1, upper=0.1, shape=[e, v], key=key_matrix
    ),
)

# let's inspect the structure of our parameters
print(pytc.tree_summary(params))
print(pytc.tree_diagram(params))

We'll come to use most of these values later, but for now, we're just focused on embeddings!

Recall what we did in the previous sessions on RNNs for *language modelling*: every word (or "token") is turned into a one-hot vector. We then multiply each one-hot vector by an *embedding matrix* $E$, which reduces it from the size of the whole vocabulary down to some specified lower-dimensional representation (often ~100 dimensions). This embedded word is then used to update the hidden state $h$ of the RNN.
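
A handy way to see what this multiplication does: multiplying a matrix by a one-hot vector just picks out one column of that matrix. So each column of $E$ is the embedding of one word, and the product is really a table lookup. A tiny `numpy` check with made-up sizes (an embedding size of 2 and a 4-word vocabulary):

```python3
import numpy as np

embedding_size, vocab_size = 2, 4  # made-up sizes for illustration
E = np.arange(embedding_size * vocab_size, dtype=float).reshape(embedding_size, vocab_size)
# E = [[0, 1, 2, 3],
#      [4, 5, 6, 7]]

word_idx = 2
one_hot = np.zeros(vocab_size)
one_hot[word_idx] = 1.0

embedded = E @ one_hot  # shape "embedding"
print(embedded)         # [2. 6.]
print(np.array_equal(embedded, E[:, word_idx]))  # True: it's just column `word_idx`
```

This is also why embedding layers are often implemented as an index lookup rather than a full matrix product: it skips multiplying by all those zeros.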

Use the embedding matrix (accessible through `params.embedding_matrix`) to fill in the function below, which embeds a single word.

In [None]:
def make_embeddings(
    one_hot_word: Float[Array, "vocab"], params: Parameters
) -> Float[Array, "embedding"]:
    ...


# I should be a vector of length `e`
assert make_embeddings(one_hot_fish, params).shape == (e,)

In [None]:
# map to work over sentences for later!
embeddings_map = jax.vmap(make_embeddings, in_axes=(0, None))

### Updating the hidden state

Suppose we have some hidden state $h^{(t-1)}$, and we want to compute the next hidden state $h^{(t)}$ from the next embedded word and the set of free parameters. Using the slides, fill in this function; for the activation $\sigma$, you can use `jax.nn.tanh`. Here's the slide as a reminder:

![](images/main_rnn.png)
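
For reference, here's my reading of those equations written with this notebook's parameter names. This assumes the slide uses the standard RNN formulation; the symbols are my own labels, matched to the shapes in `Parameters`: $E$ is `embedding_matrix`, $W_e$ is `embedding_weights`, $W_h$ is `hidden_state_weights`, $\mathbf{b}_1$ is `hidden_state_bias`, $U$ is `output_weights`, $\mathbf{b}_2$ is `output_bias`, and $\mathbf{x}^{(t)}$ is the one-hot vector for the word at position $t$.

$$
\begin{aligned}
\mathbf{e}^{(t)} &= E\,\mathbf{x}^{(t)} \\
h^{(t)} &= \sigma\left(W_h\, h^{(t-1)} + W_e\, \mathbf{e}^{(t)} + \mathbf{b}_1\right) \\
\hat{\mathbf{y}}^{(t)} &= \operatorname{softmax}\left(U\, h^{(t)} + \mathbf{b}_2\right)
\end{aligned}
$$

with $h^{(0)}$ set to a vector of zeros. The shapes line up: $E$ is `embedding vocab`, so $\mathbf{e}^{(t)}$ has length `embedding`; $W_e$ is `hidden_state embedding` and $W_h$ is `hidden_state hidden_state`, so both terms inside $\sigma$ have length `hidden_state`; and $U$ is `vocab hidden_state`, so $\hat{\mathbf{y}}^{(t)}$ has length `vocab`.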

In [None]:
def update_hidden_state(
    embedding: Float[Array, "embedding"],
    hidden_state: Float[Array, "hidden_state"],
    params: Parameters,
) -> Float[Array, "hidden_state"]:
    ...


# I should be a vector of length `h`
embedding = make_embeddings(one_hot_fish, params)
assert update_hidden_state(embedding, jnp.zeros((h,)), params).shape == (h,)

### Computing an output

Once we've updated the hidden state, we can produce an output for this timestep! Remember that in this approach to language modelling, the output is a probability distribution over our vocabulary: for each word, we assign a number representing how likely that word is to appear next. Our output $\hat{\mathbf{y}}^{(t)}$ is therefore a vector with the same length as the vocabulary, whose entries lie between 0 and 1 and sum to 1.
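
The standard way to turn a vector of arbitrary real-valued scores into such a distribution is the softmax function, $\operatorname{softmax}(z)_i = e^{z_i} / \sum_j e^{z_j}$. Here's a small `numpy` sketch of it (in JAX, `jax.nn.softmax` already does this for you). Subtracting the maximum score first doesn't change the result, but it stops `exp` from overflowing on large scores:

```python3
import numpy as np


def softmax(scores):
    # shifting every score by the same constant leaves the ratios unchanged,
    # and guarantees the largest exponent is exp(0) = 1
    shifted = scores - np.max(scores)
    exps = np.exp(shifted)
    return exps / exps.sum()


scores = np.array([2.0, 1.0, 0.1, 1000.0])  # made-up scores, one per vocabulary word
probs = softmax(scores)
print(probs.argmax())              # 3
print(np.isclose(probs.sum(), 1))  # True
```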

Using this knowledge, and referring to the slide, implement the output computation using the current hidden state $h^{(t)}$ and the RNN parameters.

In [None]:
def output(
    hidden_state: Float[Array, "hidden_state"], params: Parameters
) -> Float[Array, "vocab"]:
    ...


# I should be a vector of length `v`
assert output(jnp.zeros((h,)), params).shape == (v,)

### Putting it all together

Now all that's left to produce output from the RNN is to compose the functions from above!

I've pseudocoded the function already using comments -- see if you can fill it in. Note that the input to the RNN here is a sentence (i.e. a sequence of one-hot-encoded words).

In [None]:
def rnn(
    data: Float[Array, "sentence vocab"], params: Parameters, hidden_size: int
) -> Float[Array, "sentence vocab"]:
    # apply embeddings_map to create a vector of embeddings for the sentence

    # initialize the hidden state with zeros

    # for each word in the vector of embeddings:
    #   > update the hidden state
    #   > compute the output word using that hidden state and store it
    # return the set of outputs
    ...


# make a very intelligent sentence of all "fish"
fish_sentence = jnp.array([fish_idx] * 10)
one_hot_fish_sentence = one_hot_sentence(fish_sentence, len(vocab))

# run the RNN on "fish fish fish fish fish fish fish fish fish fish"
rnn_outputs = rnn(one_hot_fish_sentence, params, h)

# the output should be a list of 10 vectors of length `v`
# corresponding to the output probabilities at each position in the sentence
assert rnn_outputs.shape == (10, v)

# what is the most likely word at each position?
# (remember that we have randomly initialized our parameters, so this will be nonsense!)
most_likely_words = [vocab[jnp.argmax(output)] for output in rnn_outputs]
most_likely_words

### Teaching the network through the loss (teacher forcing)

What's a good metric to see if the language modelling is working? Well, if we're predicting the next word in the sentence, and we have access to the whole sentence, we can just look at the probability the model assigns to the correct next word. That's something we want to maximise, but neural networks are usually trained to minimise an objective, so we use the common trick of taking the negative log of that probability as our "loss" function: the thing we expect to be small when the model is doing well. Using this and the slide below, implement the loss function for our RNN.

Note that you only need to do this for a single output word -- we'll use the function `jax.vmap` to automatically vectorise the function to handle whole sentences!

![](images/rnn_loss.png)
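
Because the target is a one-hot vector, a cross-entropy sum over the whole vocabulary, $-\sum_w y_w \log \hat{y}_w$, collapses to a single term: minus the log-probability the model gives the true next word. A quick `numpy` check with a made-up 4-word distribution:

```python3
import numpy as np

probs = np.array([0.1, 0.6, 0.2, 0.1])   # made-up model output over a 4-word vocabulary
target = np.array([0.0, 1.0, 0.0, 0.0])  # one-hot vector for the true next word (index 1)

full_cross_entropy = -np.sum(target * np.log(probs))  # sum over every word in the vocab
picked_term = -np.log(probs[1])                       # only the true word's probability

print(round(float(full_cross_entropy), 4))          # 0.5108
print(np.isclose(full_cross_entropy, picked_term))  # True
```

So a confident, correct prediction gives a loss near 0, while giving the true word a probability near 0 sends the loss towards infinity.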

In [None]:
def loss(
    output: Float[Array, "vocab"], next_one_hot_word: Float[Array, "vocab"]
) -> Float[Array, ""]:
    ...


# I should be a scalar
assert loss(rnn_outputs[0], one_hot_fish_sentence[1]).shape == ()

# `sentence_loss` applies `loss` at every position of a sentence, returning one loss per word
sentence_loss = jax.vmap(loss, in_axes=(0, 0))

Now that we have a loss function we can compute over sentences, we can make a pipeline that goes from initial data all the way to the loss function result. The reason we need to be explicit about this is the following: in order to calculate the gradient of the loss function in `jax`, we need to have a function that takes in the thing we want to differentiate with respect to (here, the `Parameters`), and returns the loss result, which can only be computed *after* the model has been run.

Try to implement this below, using the `rnn` function from the last step! Note that you should return the mean of the loss over the sentence: this means the RNN is trained to get the correct word on average across the sentences we feed in.

In [None]:
def forward_pass(
    data: Float[Array, "sentence vocab"],
    next_words: Float[Array, "sentence vocab"],  # the next word at each position (data shifted by one)
    params: Parameters,
    hidden_size: int,
) -> Float[Array, ""]:
    # run the RNN

    # compute the losses across the sentence

    # return the mean of the losses
    ...


# run the forward pass on our fish sentence
loss_value = forward_pass(one_hot_fish_sentence, one_hot_fish_sentence, params, h)
assert loss_value.shape == ()
print(f"starting loss: {loss_value:.2f}")

In [None]:
# here, we transform the forward pass into a function that returns both the loss
# and its gradient with respect to the parameters (argument index 2)
loss_and_gradient = jax.value_and_grad(forward_pass, argnums=2)
print("gradients are packed up in a Parameters object:")
print(
    pytc.tree_diagram(
        loss_and_gradient(one_hot_fish_sentence, one_hot_fish_sentence, params, h)[1]
    )
)

In [None]:
# we can also vmap the gradient function to handle a batch of sentences,
# and jit it to make it faster!
batched_grads = jax.jit(
    jax.vmap(loss_and_gradient, in_axes=(0, 0, None, None)), static_argnums=(3,)
)

You've done all the hard work! We won't write any more code -- the rest is pretty generic boilerplate for training any ML model.

### Predicting words

Another boilerplate-y function: turning the softmaxed probabilities into actual words. The only real thing going on here is turning a softmax into a hard-max using `jnp.argmax` -- then we can just convert back from the one-hot representation to words. We can also predict arbitrarily many words by feeding each of the RNN's predictions back in as the next input, which produces a new hidden state and a new output. You'll find this cycle very quickly produces some funny-looking sentences if you let it go on for too many tokens...

In [None]:
def predict_next_words(
    prompt: str,
    vocab: list[str],
    rnn_params: Parameters,
    rnn_hidden_size: int,
    num_predicted_tokens: int,
    include_prompt=True,
) -> str:
    # Define a regular expression pattern to match all punctuation marks
    punctuation_pattern = r"[^\w\s]"

    # Define a regular expression pattern to match words with apostrophes
    apostrophe_pattern = r"\w+(?:\'\w+)?"
    # Define a regular expression pattern to match newlines
    newline_pattern = r"\n"

    # Combine the three patterns to match all tokens
    token_pattern = (
        punctuation_pattern + "|" + apostrophe_pattern + "|" + newline_pattern
    )

    tokens = re.findall(token_pattern, prompt.lower())
    one_hot_indicies = jnp.array([vocab.index(t) for t in tokens], dtype=jnp.int32)
    sentence = one_hot_sentence(one_hot_indicies, len(vocab))
    embeddings = embeddings_map(sentence, rnn_params)  # ["sentence embedding"]

    hidden_state = jnp.zeros((rnn_hidden_size,))
    # run the whole prompt through the RNN to build up the hidden state
    for word in embeddings:
        hidden_state = update_hidden_state(word, hidden_state, rnn_params)
    outputs = [output(hidden_state, rnn_params)]

    # greedy decoding: feed the one-hot vector of the most likely word back in
    for _ in range(1, num_predicted_tokens):
        predicted_word = jax.nn.one_hot(jnp.argmax(outputs[-1]), len(vocab))
        embedded_pred = make_embeddings(predicted_word, rnn_params)
        hidden_state = update_hidden_state(embedded_pred, hidden_state, rnn_params)
        outputs.append(output(hidden_state, rnn_params))

    res = jnp.array(outputs)
    res_indicies = jnp.argmax(res, axis=1)
    words = [vocab[i] for i in res_indicies]
    out = " ".join(words)
    return prompt + " | " + out if include_prompt else out

### Training the model

We'll go through some more code to set up batching and the optimizer, and then run the training loop. You don't need to write any more code -- it should all just work now!

#### Batching

In [None]:
batch_size = 400

import numpy.random as npr


def batches(
    training_data: Array, training_labels: Array, batch_size: int
) -> Generator:
    num_train = training_data.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    # batching mechanism, ripped from the JAX docs :)
    def data_stream():
        rng = npr.RandomState(0)
        while True:
            # reshuffle the sentences at the start of every pass over the data
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
                yield training_data[batch_idx], training_labels[batch_idx]

    return data_stream()


batch = batches(train, train_labels, batch_size)
one_hot_valid = batch_one_hot(valid, v)
one_hot_valid_labels = batch_one_hot(valid_labels, v)

#### Setting up training hyperparameters

In [None]:
# training hyperparams, modify at will!
num_iter = 2000
lr = 4e-2
best_loss = 999
best_pars = None


# basic gradient descent
def gradient_descent(param: jax.Array, grads: jax.Array) -> jax.Array:
    return param - lr * grads.mean(axis=0)


# more advanced: AdamW, with each gradient value clipped to [-1, 1] first
import optax

opt = optax.chain(
    optax.clip(1),
    optax.adamw(learning_rate=lr),
)
opt_state = opt.init(params)

### Train time!

In [None]:
for i in range(num_iter):
    sentences, sentence_labels = next(batch)
    one_hot_sentences = batch_one_hot(sentences, v)
    one_hot_sentence_labels = batch_one_hot(sentence_labels, v)
    # don't name this `loss`, or it would overwrite the `loss` function from earlier!
    train_loss, grads = batched_grads(
        one_hot_sentences, one_hot_sentence_labels, params, h
    )
    valid_loss, _ = batched_grads(one_hot_valid, one_hot_valid_labels, params, h)
    train_loss, valid_loss = train_loss.mean(), valid_loss.mean()

    # gradient descent!
    params = jax.tree_util.tree_map(gradient_descent, params, grads)

    ## uncomment these lines for the advanced version (and comment out the line above)
    # avg_grads = jax.tree_util.tree_map(lambda g: g.mean(axis=0), grads)
    # updates, opt_state = opt.update(avg_grads, opt_state, params=params)
    # params = optax.apply_updates(params, updates)

    if valid_loss < best_loss:
        best_pars = deepcopy(params)
        best_loss = valid_loss
    if i % 20 == 0:
        print(f"train loss: {train_loss:.3f}", end=", ")
        print(f"valid loss: {valid_loss:.3f}")

print(f"best valid loss: {best_loss:.3f}")

In [None]:
predict_next_words("Red fish ", vocab, params, h, 10, include_prompt=True)

You may have ended up predicting all newline characters! I wonder why?

Well, what's the easiest way to minimize the *average* loss over every word? Just always predict the most common one!
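
We can put a number on that. A model that ignores its input entirely does best (in terms of cross-entropy) by always outputting the overall word frequencies, and its average loss is then the entropy of those frequencies; the argmax of that constant distribution is simply the most common token. Here's a small standard-library sketch with a made-up token list (for the real text, you could pass in `all_words` from the histogram cell below). If your training loss plateaus near this value, the model probably hasn't learned much beyond word frequencies:

```python3
import math
from collections import Counter


def unigram_baseline_loss(tokens):
    # average negative log-likelihood when every position predicts the same
    # distribution p(w) = count(w) / N, i.e. the entropy of the word frequencies
    counts = Counter(tokens)
    n = len(tokens)
    return -sum((c / n) * math.log(c / n) for c in counts.values())


tokens = ["fish", "\n", "fish", "\n", "red", "\n"]  # made-up token stream
print(Counter(tokens).most_common(1))            # [('\n', 3)]
print(round(unigram_baseline_loss(tokens), 4))   # 1.0114
print(round(math.log(len(set(tokens))), 4))      # 1.0986 (uniform guessing, for comparison)
```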

In [None]:
# make a histogram of the frequency of each word in the document
import matplotlib.pyplot as plt
from collections import Counter

with open(file_name, "r") as file:
    all_text = file.read()

# Define a regular expression pattern to match all punctuation marks
punctuation_pattern = r"[^\w\s]"

# Define a regular expression pattern to match words with apostrophes
apostrophe_pattern = r"\w+(?:\'\w+)?"
# Define a regular expression pattern to match newlines
newline_pattern = r"\n"

# Combine the three patterns to match all tokens
token_pattern = punctuation_pattern + "|" + apostrophe_pattern + "|" + newline_pattern

# Split the text into tokens, including words with apostrophes as separate tokens
all_words = re.findall(token_pattern, all_text.lower())

# Count the number of occurrences of each word
word_counts = Counter(all_words)

# Get the 10 most common words
most_common_words = word_counts.most_common(10)

# Plot the histogram
plt.bar(
    [word for word, count in most_common_words],
    [count for word, count in most_common_words],
)

# the x-labels will not escape the newlines, so we need to replace them
plt.xticks(
    [word for word, count in most_common_words],
    [word.replace("\n", r"\n") for word, _ in most_common_words],
)
plt.xlabel("Word")
plt.ylabel("Frequency")
plt.title(f"Most common words in {file_name}")
plt.show()