# Train RNN

This is an introductory example. We demonstrate what using Equinox normally looks like day-to-day.

Here, we'll train an RNN to classify clockwise vs anticlockwise spirals.

This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/equinox/blob/main/examples/train_rnn.ipynb).

In [1]:
import math

import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax  # https://github.com/deepmind/optax

We begin by importing the usual libraries, setting up a very simple dataloader, and generating a toy dataset of spirals.

In [2]:
def dataloader(arrays, batch_size):
    """Yield shuffled mini-batches drawn from `arrays`, forever.

    Each pass over the data draws a fresh permutation of the row indices with
    `np.random.permutation` (i.e. NumPy's global random state) and yields
    consecutive, non-overlapping slices of `batch_size` indices. A trailing
    partial batch is dropped, so every batch has exactly `batch_size` rows.

    Because this is a generator, the checks below run on the first `next()`
    call rather than when `dataloader(...)` is called.

    Args:
        arrays: Sequence of arrays that all share the same leading dimension.
        batch_size: Number of rows per batch. Must satisfy
            `1 <= batch_size <= arrays[0].shape[0]`.

    Yields:
        A tuple with one batch per entry of `arrays`; every entry of the tuple
        is indexed with the same slice of the permutation.

    Raises:
        AssertionError: If the arrays do not share the same leading dimension.
        ValueError: If `batch_size` is smaller than 1 or larger than the
            dataset size.
    """
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    # Without this check an oversized batch would make the loop below spin
    # forever without ever yielding, and a zero batch size would yield empty
    # batches forever.
    if not 1 <= batch_size <= dataset_size:
        raise ValueError(
            f"batch_size must be between 1 and the dataset size "
            f"({dataset_size}), got {batch_size}."
        )
    indices = np.arange(dataset_size)
    # Largest start offset that still leaves room for a full batch.
    last_start = dataset_size - batch_size
    while True:
        perm = np.random.permutation(indices)
        for start in range(0, last_start + 1, batch_size):
            batch_perm = perm[start : start + batch_size]
            yield tuple(array[batch_perm] for array in arrays)


def get_data(dataset_size, *, key):
    """Generate the toy two-class dataset of 2-D spirals.

    Each sample is a sequence of 16 points at times evenly spaced over
    `[0, 2*pi]`. A sample's points are `(sin(t + offset), cos(t + offset))`
    divided by `1 + t`, where `offset` is drawn uniformly from `[0, 2*pi)`
    once per sample. The first `dataset_size // 2` samples have their first
    coordinate negated and are labelled 0; the remaining samples are
    labelled 1.

    Args:
        dataset_size: Number of samples to generate.
        key: JAX PRNG key used to draw the per-sample offsets.

    Returns:
        A pair `(x, y)` where `x` has shape `(dataset_size, 16, 2)` and `y`
        has shape `(dataset_size, 1)`.
    """
    num_flipped = dataset_size // 2

    times = jnp.linspace(0, 2 * math.pi, 16)
    phase = jrandom.uniform(key, (dataset_size, 1), minval=0, maxval=2 * math.pi)
    # `phase` is (dataset_size, 1), so this broadcasts to (dataset_size, 16).
    angle = times + phase
    shrink = 1 + times

    first_coord = jnp.sin(angle) / shrink
    second_coord = jnp.cos(angle) / shrink

    # Negating the first coordinate is what distinguishes the two classes.
    first_coord = first_coord.at[:num_flipped].multiply(-1)
    labels = jnp.ones((dataset_size, 1)).at[:num_flipped].set(0)

    points = jnp.stack([first_coord, second_coord], axis=-1)
    return points, labels

Now for our model.

Purely by way of example, we add the final bias ourselves, rather than letting the `linear` layer do it. This is just so we can demonstrate how to use custom parameters in models.

In [3]:
class RNN(eqx.Module):
    """Recurrent binary classifier: a GRU cell followed by a linear readout.

    The GRU cell is scanned along the leading axis of the input. The final
    hidden state is passed through a bias-free linear layer, a separately
    stored bias is added, and a sigmoid is applied to the result.
    """

    # Size of the GRU hidden state.
    hidden_size: int
    # Recurrent cell applied once per element of the input sequence.
    cell: eqx.Module
    # Readout from the final hidden state; constructed with `use_bias=False`.
    linear: eqx.nn.Linear
    # Output bias, held as a custom parameter rather than inside `linear`.
    bias: jax.Array

    def __init__(self, in_size, out_size, hidden_size, *, key):
        """Build the GRU cell, the readout layer and a zero-initialised bias.

        Args:
            in_size: Input size passed to the GRU cell.
            out_size: Output size of the readout layer and of the bias.
            hidden_size: Size of the GRU hidden state.
            key: JAX PRNG key; split in two to initialise the cell and the
                readout layer independently.
        """
        cell_key, readout_key = jrandom.split(key)
        self.hidden_size = hidden_size
        self.bias = jnp.zeros(out_size)
        self.linear = eqx.nn.Linear(
            hidden_size, out_size, use_bias=False, key=readout_key
        )
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=cell_key)

    def __call__(self, input):
        """Run the cell over `input` and return the sigmoid of the readout.

        Args:
            input: Sequence to scan over; `lax.scan` iterates along its
                leading axis and feeds one slice per step to the cell.

        Returns:
            `sigmoid(linear(final_hidden_state) + bias)`.
        """

        def step(state, element):
            # The updated hidden state is the scan carry; nothing is
            # collected per step.
            next_state = self.cell(element, state)
            return next_state, None

        initial_state = jnp.zeros((self.hidden_size,))
        final_state, _ = lax.scan(step, initial_state, input)
        logits = self.linear(final_state) + self.bias
        # A sigmoid, because this is binary classification.
        return jax.nn.sigmoid(logits)

And finally the training loop.

In [4]:
def main(
    dataset_size=10000,
    batch_size=32,
    learning_rate=3e-3,
    steps=200,
    hidden_size=16,
    depth=1,
    seed=5678,
):
    """Train the RNN on the spiral dataset, printing progress as it goes.

    Prints the loss after every optimisation step and, once training has
    finished, the accuracy over the whole dataset.

    Args:
        dataset_size: Number of samples to generate.
        batch_size: Number of samples per optimisation step.
        learning_rate: Learning rate passed to `optax.adam`.
        steps: Number of optimisation steps to run.
        hidden_size: Hidden-state size of the RNN.
        depth: Accepted but not used anywhere in this function.
        seed: Seed for the JAX PRNG key used for data generation and model
            initialisation.
    """
    root_key = jrandom.PRNGKey(seed)
    data_key, model_key = jrandom.split(root_key, 2)
    xs, ys = get_data(dataset_size, key=data_key)
    batches = dataloader((xs, ys), batch_size)

    model = RNN(in_size=2, out_size=1, hidden_size=hidden_size, key=model_key)
    optim = optax.adam(learning_rate)
    opt_state = optim.init(model)

    def binary_cross_entropy(model, x, y):
        # The model handles a single sequence, so vmap it over the batch.
        pred_y = jax.vmap(model)(x)
        log_likelihood = y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y)
        return -jnp.mean(log_likelihood)

    # Returns `(loss, grads)`, differentiating with respect to `model`.
    compute_loss = eqx.filter_value_and_grad(binary_cross_entropy)

    # For efficiency with JAX, keep the whole update (loss, gradients and
    # optimiser step) inside a single JIT region.
    @eqx.filter_jit
    def make_step(model, x, y, opt_state):
        loss, grads = compute_loss(model, x, y)
        updates, new_opt_state = optim.update(grads, opt_state)
        new_model = eqx.apply_updates(model, updates)
        return loss, new_model, new_opt_state

    for step in range(steps):
        x, y = next(batches)
        loss, model, opt_state = make_step(model, x, y, opt_state)
        print(f"step={step}, loss={loss.item()}")

    # A prediction counts as correct when thresholding it at 0.5 matches the
    # 0/1 label.
    pred_ys = jax.vmap(model)(xs)
    hits = (pred_ys > 0.5) == ys
    final_accuracy = (jnp.sum(hits) / dataset_size).item()
    print(f"final_accuracy={final_accuracy}")

`eqx.filter_value_and_grad` will calculate the gradient with respect to all floating-point arrays in the first argument (`model`). In this case the `model` parameters will be differentiated, whilst `model.hidden_size` is an integer and will get `None` as its gradient.

Likewise, `eqx.filter_jit` will look at all the arguments passed to `make_step`, and automatically trace every array and treat everything else as static. In this case the `model` parameters and the data `x` and `y` will be traced, whilst `model.hidden_size` is an integer and will be treated as static instead.

In [5]:
main()  # All right, let's run the code (with the default hyperparameters).

step=0, loss=0.6816999316215515
step=1, loss=0.7202574014663696
step=2, loss=0.6925007104873657
step=3, loss=0.689198911190033
step=4, loss=0.6808685064315796
step=5, loss=0.7059305906295776
step=6, loss=0.6922754049301147
step=7, loss=0.6842439770698547
step=8, loss=0.6972116231918335
step=9, loss=0.7047306299209595
step=10, loss=0.6993851661682129
step=11, loss=0.6921849846839905
step=12, loss=0.6844913959503174
step=13, loss=0.6941200494766235
step=14, loss=0.6870629787445068
step=15, loss=0.6922240257263184
step=16, loss=0.6966875195503235
step=17, loss=0.7021255493164062
step=18, loss=0.6913468241691589
step=19, loss=0.6915531158447266
step=20, loss=0.6906869411468506
step=21, loss=0.6945821046829224
step=22, loss=0.6963403820991516
step=23, loss=0.6893304586410522
step=24, loss=0.6923031210899353
step=25, loss=0.6952496767044067
step=26, loss=0.6937462687492371
step=27, loss=0.6946915984153748
step=28, loss=0.6912715435028076
step=29, loss=0.6945470571517944
step=30, loss=0.69285