# Build an RNN, from scratch!

I wrote my own RNN from scratch based on [Ryan's nice slides](https://github.com/alan-turing-institute/transformers-reading-group/blob/main/sessions/03-seq2seq-part-i/seq2seq_part1_hut23_robots_in_disguise.pdf), and this notebook is the result. Now, I'll guide you to do the same (but using a lot of my boilerplate, so you can just do the fun stuff!)

This notebook assumes familiarity with `numpy` and how to multiply arrays. The reason is that we're gonna use my favourite deep learing library [`jax`](https://github.com/google/jax), which uses the same syntax as `numpy`, but we just `import jax.numpy as jnp` instead! Everything else is more or less the same :)

The basic layout is as follows: I'm gonna show you the formula, and write the function signature with some type annotations as hints. You're gonna fill in the blanks! But don't panic: most of these functions are one or two-liners that I literally just copied from the slides.

In these functions, you'll see type annotations like this a lot:
```python3
x: Float[Array, "dim1 dim2"]
```
What am I meaning here? This just says that we have a variable `x`, which is an array of floating-point values (hence the `Float`). The string `"dim1 dim2"` is syntax for the shape of the array, which would have two dimensions `dim1` and `dim2`, i.e. a matrix with `dim1` rows and `dim2` columns. They don't represent anything concrete until we actually instantiate the value of `x`, but paying attention to these dimensions will help you make sure your matrix-vector multiplications will work (remember the rules of matrix multiplication: this matrix could only multiply an object on the right with leading dimension `dim2`). Oh, and if you're wondering where this syntax comes from -- it's the library [`jaxtyping`](https://github.com/google/jaxtyping)!

In [10]:
from __future__ import annotations

import jax
import jax.numpy as jnp
import re
from jaxtyping import Array, Float, Int
from equinox import Module
from typing import Generator
from functools import partial
from copy import deepcopy


# This is a container for all the free parameters of an RNN!
# You can see the shapes of each attribute from the type annotations.
# We have a couple of sizes: hidden_state, embedding, vocab
# -> these represent the size of the hidden_state weights,
#    the embedding matrix, and the vocabulary respecively.
# This parameters object will be passed to most functions below:
# e.g. access the output weights by calling `params.output_weights` etc.
class Parameters(Module):
    embedding_weights: Float[Array, "hidden_state embedding"]
    hidden_state_weights: Float[Array, "hidden_state hidden_state"]
    output_weights: Float[Array, "vocab hidden_state"]
    hidden_state_bias: Float[Array, "hidden_state"]
    output_bias: Float[Array, "vocab"]
    embedding_matrix: Float[Array, "embedding vocab"]

### Updating the hidden state

We'll work our way inside-out, and start with the hidden state update, assuming the existence of some quantities we'll go on to define later. 

Recall that we looked at RNNs for *language modelling*, where every word is turned into a one-hot vector (or "token"). We then multiply these one-hot words by an *embedding matrix* $E$, which multiplies the words to reduce the dimension of that long one-hot vector (=size of the whole vocabulary) to some specified lower dimensional representation (normally ~100). This embedded word is then used to update the hidden state $h$ of the RNN.

Assume some hidden state $h^{(t-1)}$, and we're trying to go to hidden state $h^{(t)}$ given our next embedded word, and the set of free parameters. Using the slides, fill in this function -- for $\sigma$, you can use the activation function `jax.nn.tanh`. Here's the slide as a reminder:

![](images/main_rnn.png)

In [None]:
def update_hidden_state(
    embedding: Float[Array, "embedding"],
    hidden_state: Float[Array, "hidden_state"],
    params: Parameters,
) -> Float[Array, "hidden_state"]:
    return jax.nn.tanh(
        params.hidden_state_weights @ hidden_state
        + params.embedding_weights @ embedding
        + params.hidden_state_bias
    )

### Embeddings

We return to the topic of embeddings. Usually, people use [pre-trained embeddings](https://huggingface.co/blog/getting-started-with-embeddings) for tasks like text generation, but we're taking the ambitious route of learning the embedding matrix jointly with the weights. It's likely that the results of this notebook would be much better if we used embeddings that already encoded information about language and the relationship between words (though, the vocabulary would then need to be much larger than what we use later, which just finds all the unique words in a text document).

We've already described how to embed a word -- just multiply by the embedding matrix. Implement this in the following function for a single word (and we'll use `jax.vmap` again to broadcast this over a sentence for later).

In [None]:
def make_embeddings(
    one_hot_word: Float[Array, "vocab"], params: Parameters
) -> Float[Array, "embedding"]:
    return params.embedding_matrix @ one_hot_word


embeddings_map = jax.vmap(make_embeddings, in_axes=(0, None))

### Computing an output

Once we've updated the hidden state, we're able to produce an output from this timestep! Remember that this approach to language modelling means defining a probability distribution across our vocabulary -- this means that for each word, we assign a number that represents how likely that word is to appear next; our output $\mathbf{\hat{y}^{(t)}}$ is then a vector of the same length as the vocabulary with a set of numbers in the range 0-1 assigned to each word.

Using this knowledge, and referring to the slide, implement the output computation using the current hidden state $h^{(t)}$ and the RNN parameters.

In [None]:
def output(
    hidden_state: Float[Array, "hidden_state"], params: Parameters
) -> Float[Array, "vocab"]:
    return jax.nn.softmax(params.output_weights @ hidden_state + params.output_bias)

### Putting it all together

Now all that's left to produce output from the RNN is to compose the functions from above!

I've pseudocoded the function already using comments -- see if you can fill it in. Note that the input to the RNN here is a sentence (i.e. a sequence of one-hot-encoded words).

In [None]:
def rnn(
    data: Float[Array, "sentence vocab"], params: Parameters, hidden_size: int
) -> Float[Array, "sentence vocab"]:
    # apply embeddings_map to create a vector of embeddings
    embeddings = embeddings_map(data, params)  # ["sentence embedding"]

    # initialize the hidden state with zeros
    hidden_state = jnp.zeros((hidden_size,))

    # for each word in the vector of embeddings:
    #    update the hidden state
    #    compute the output word using that hidden state and store it
    # return the set of outputs
    outputs = []

    for word in embeddings:
        hidden_state = update_hidden_state(word, hidden_state, params)
        outputs.append(output(hidden_state, params))

    return outputs

### Teaching the network through the loss (teacher forcing)

What's a good metric to see if the language modelling is working correctly? Well, if we're predicting the next word in the sentence, and we happen to have access to the whole sentence, we can just see what probability the model assigns to the correct next word. Since that's something we want to maximise, but neural networks are usually trained to minimise the objective, we can do the common trick of taking the negative log of that quantity to serve as our "loss" function -- the thing that we expect to be small when we're doing well. Using this and the slide below, implement the loss function for our RNN.

Note that you only need to do this for a single output word -- we'll use the function `jax.vmap` to automatically vectorise the function to handle whole sentences!

![](images/rnn_loss.png)

In [None]:
def loss(
    output: Float[Array, "vocab"], next_one_hot_word: Float[Array, "vocab"]
) -> Float[Array, ""]:
    # index the softmax probs at the word of interest
    return -jnp.log(output[jnp.argmax(next_one_hot_word)])


sentence_loss = jax.vmap(loss, in_axes=(0, 0))

In [None]:
def forward_pass(
    data: Float[Array, "sentence vocab"],
    next_words: Float[Array, "sentence vocab"],  # data shifted by 1 to the right
    params: Parameters,
    hidden_size: int,
) -> Float[Array, ""]:
    output = rnn(data, params, hidden_size)
    return loss_map(output, next_words).mean(axis=0)


loss_and_gradient = jax.value_and_grad(forward_pass, argnums=2)
batched_grads = jax.jit(
    jax.vmap(loss_and_gradient, in_axes=(0, 0, None, None)), static_argnums=(3,)
)

In [None]:
def one_hot_sentence(
    sentence: Int[Array, "sentence"], vocab_size: int
) -> Int[Array, "sentence vocab"]:
    return jnp.array([jnp.zeros((vocab_size,)).at[word].set(1) for word in sentence])

In [None]:
def predict_next_words(
    prompt: str,
    vocab: list[str],
    rnn_params: Parameters,
    rnn_hidden_size: int,
    num_predicted_tokens: int,
    include_prompt=True,
) -> str:
    # Define a regular expression pattern to match all punctuation marks
    punctuation_pattern = r"[^\w\s]"

    # Define a regular expression pattern to match words with apostrophes
    apostrophe_pattern = r"\w+(?:\'\w+)?"
    # Define a regular expression pattern to match newlines
    newline_pattern = r"\n"

    # Combine the three patterns to match all tokens
    token_pattern = (
        punctuation_pattern + "|" + apostrophe_pattern + "|" + newline_pattern
    )

    tokens = re.findall(token_pattern, prompt.lower())
    one_hot_indicies = jnp.array([vocab.index(t) for t in tokens], dtype=jnp.int32)
    sentence = one_hot_sentence(one_hot_indicies, len(vocab))
    embeddings = embeddings_map(sentence, rnn_params)  # ["sentence embedding"]

    hidden_state = jnp.zeros((rnn_hidden_size,))
    outputs = [None] * num_predicted_tokens
    for word in embeddings[:-1]:
        hidden_state = update_hidden_state(word, hidden_state, rnn_params)
    hidden_state = update_hidden_state(embeddings[-1], hidden_state, rnn_params)
    outputs[0] = output(hidden_state, rnn_params)

    for i in range(1, num_predicted_tokens):
        embedded_pred = make_embeddings(outputs[i - 1], rnn_params)
        hidden_state = update_hidden_state(embedded_pred, hidden_state, rnn_params)
        outputs[i] = output(hidden_state, rnn_params)

    res = jnp.array(outputs)
    res_indicies = jnp.argmax(res, axis=1)
    words = [vocab[i] for i in res_indicies]
    out = " ".join(words)
    return prompt + " | " + out if include_prompt else out

In [5]:
import re

file_name = "one-fish-two-fish.txt"

with open(file_name, "r+") as file:
    all_text = file.read()
    # all_text = all_text.replace('\n', ' ').replace('  : ', '')

# Define a regular expression pattern to match all punctuation marks
punctuation_pattern = r"[^\w\s]"

# Define a regular expression pattern to match words with apostrophes
apostrophe_pattern = r"\w+(?:\'\w+)?"
# Define a regular expression pattern to match newlines
newline_pattern = r"\n"

# Combine the three patterns to match all tokens
token_pattern = punctuation_pattern + "|" + apostrophe_pattern + "|" + newline_pattern


# Split the text into tokens, including words with apostrophes as separate tokens
all_words = re.findall(token_pattern, all_text.lower())
vocab = list(set(all_words))

sentence_length = 8  # even for now...

vocab_one_hot_indicies = jnp.array([vocab.index(t) for t in all_words], dtype=jnp.int32)
split_indicies = vocab_one_hot_indicies[
    : (len(vocab) // sentence_length) * sentence_length
].reshape(len(vocab) // sentence_length, sentence_length)
# make last word random, shouldn't make too much of an impact (could be better handled with special char?)
split_indicies_labels = jnp.concatenate(
    (
        vocab_one_hot_indicies[
            1 : ((len(vocab) - 1) // sentence_length) * sentence_length
        ],
        jnp.array([0]),
    )
).reshape((len(vocab) - 1) // sentence_length, sentence_length)
partition_index = 6 * int(len(split_indicies) / 7)
train = split_indicies[:partition_index]
train_labels = split_indicies_labels[:partition_index]
valid = split_indicies[partition_index:]
valid_labels = split_indicies_labels[partition_index:]

batch_one_hot = jax.vmap(partial(one_hot_sentence, vocab_size=len(vocab)))

train loss: 5.703, valid loss: 5.704
train loss: 5.656, valid loss: 5.658
train loss: 5.590, valid loss: 5.592
train loss: 5.484, valid loss: 5.489
train loss: 5.334, valid loss: 5.344
train loss: 5.156, valid loss: 5.174
train loss: 4.982, valid loss: 5.014
train loss: 4.838, valid loss: 4.889
train loss: 4.726, valid loss: 4.803
train loss: 4.640, valid loss: 4.747
train loss: 4.572, valid loss: 4.712
train loss: 4.516, valid loss: 4.690
train loss: 4.469, valid loss: 4.677
train loss: 4.429, valid loss: 4.670
train loss: 4.393, valid loss: 4.665
train loss: 4.362, valid loss: 4.663
train loss: 4.334, valid loss: 4.661
train loss: 4.309, valid loss: 4.661
train loss: 4.286, valid loss: 4.661
train loss: 4.265, valid loss: 4.661
train loss: 4.245, valid loss: 4.662
train loss: 4.227, valid loss: 4.663
train loss: 4.211, valid loss: 4.664
train loss: 4.195, valid loss: 4.665
train loss: 4.181, valid loss: 4.666
train loss: 4.168, valid loss: 4.668
train loss: 4.156, valid loss: 4.669
t

In [None]:
batch_size = 400

import numpy.random as npr


def batches(training_data: Array, batch_size: int) -> Generator:
    num_train = training_data.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    # batching mechanism, ripped from the JAX docs :)
    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
                yield train[batch_idx], train_labels[batch_idx]

    return data_stream()


batch = batches(train, batch_size)

In [None]:
e = 30
h = 16
v = len(vocab)
o = v

pars = Parameters(
    embedding_weights=jax.random.truncated_normal(
        lower=-0.1, upper=0.1, shape=[h, e], key=jax.random.PRNGKey(0)
    ),
    hidden_state_weights=jnp.identity(h),  # keep gradients from exploding
    output_weights=jax.random.truncated_normal(
        lower=-0.1, upper=0.1, shape=[o, h], key=jax.random.PRNGKey(0)
    ),
    hidden_state_bias=jnp.zeros((h,)),  # keep gradients from exploding
    output_bias=jnp.zeros(
        shape=[
            o,
        ]
    ),
    embedding_matrix=jax.random.truncated_normal(
        lower=-0.1, upper=0.1, shape=[e, v], key=jax.random.PRNGKey(0)
    ),
)
num_iter = 2000
lr = 4e-2
one_hot_valid, one_hot_valid_labels = batch_one_hot(valid), batch_one_hot(valid_labels)
best_loss = 999
best_pars = None


def gradient_descent(param: jax.Array, grads: jax.Array) -> jax.Array:
    return param - lr * grads.mean(axis=0)


import optax

opt = optax.chain(
    optax.clip(1),
    optax.adamw(learning_rate=lr),
)
opt_state = opt.init(pars)

In [None]:
for i in range(num_iter):
    sentences, sentence_labels = next(batch)
    one_hot_sentences, one_hot_sentence_labels = batch_one_hot(
        sentences
    ), batch_one_hot(sentence_labels)
    loss, grads = batched_grads(one_hot_sentences, one_hot_sentence_labels, pars, h)
    valid_loss, _ = batched_grads(one_hot_valid, one_hot_valid_labels, pars, h)
    loss, valid_loss = loss.mean(), valid_loss.mean()
    pars = jax.tree_map(gradient_descent, pars, grads)
    # avg_grads = jax.tree_map(lambda g: g.mean(axis=0), grads)
    # updates, opt_state = opt.update(avg_grads, opt_state, params=pars)
    # pars = optax.apply_updates(pars, updates)
    if valid_loss < best_loss:
        best_pars = deepcopy(pars)
        best_loss = valid_loss
    if i % 20 == 0:
        print(f"train loss: {loss.mean():.3f}", end=", ")
        print(f"valid loss: {valid_loss.mean():.3f}")

print(f"best valid loss: {best_loss:.3f}")
print(predict_next_words("Red fish ", vocab, pars, h, 10, include_prompt=True))

In [9]:
print(predict_next_words("Hello", vocab, best_pars, h, 2, include_prompt=True))

Hello | 
 

