In [None]:
import math
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import matplotlib.pyplot as plt

In [None]:
# Synthetic dataset: each sample is a sine wave sampled at N_TIMESTEPS points on [0, 1];
# the regression target is that wave's frequency, taken from an evenly spaced grid on [1, 2].
N_TIMESTEPS = 30
N_SAMPLES = 50

t = jnp.linspace(0.0, 1.0, N_TIMESTEPS)               # (N_TIMESTEPS,)
y = jnp.linspace(1.0, 2.0, N_SAMPLES).reshape(-1, 1)  # (N_SAMPLES, 1) target frequencies
# Broadcasting (N_SAMPLES, 1) * (N_TIMESTEPS,) forms the outer product freq_i * t_j;
# the trailing singleton axis is the per-step feature dimension fed to the LSTM (input_size=1).
x = jnp.sin(2.0 * math.pi * y * t)[..., None]         # (N_SAMPLES, N_TIMESTEPS, 1)

In [None]:
class RNN(eqx.Module):
    """Many-to-one recurrent regressor: an LSTM cell unrolled over a sequence,
    followed by a linear read-out of the final hidden state.

    Fields:
        hidden_size: width of the LSTM hidden and cell state vectors.
        cell: the LSTM cell applied once per time step.
        linear: maps the last hidden state to `out_size` outputs.
    """

    hidden_size: int
    cell: eqx.nn.LSTMCell
    linear: eqx.nn.Linear

    def __init__(self, input_size, hidden_size, out_size, key):
        # Independent PRNG keys for the recurrent cell and the read-out layer.
        cell_key, linear_key = jax.random.split(key)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.LSTMCell(input_size=input_size, hidden_size=hidden_size, key=cell_key)
        self.linear = eqx.nn.Linear(hidden_size, out_size, key=linear_key)

    def __call__(self, x):
        """Scan the LSTM over the leading (time) axis of `x`, starting from an
        all-zeros (hidden, cell) state, and return `linear` applied to the final
        hidden state.

        Operates on a single, unbatched sequence; wrap in `jax.vmap` for batches.
        """

        def step(state, x_t):
            # Only the (hidden, cell) carry matters; no per-step outputs are stacked.
            return self.cell(x_t, state), None

        zero_state = jnp.zeros(self.hidden_size)
        final_state, _ = jax.lax.scan(step, (zero_state, zero_state), x)
        final_hidden, _final_cell = final_state
        return self.linear(final_hidden)

In [None]:
# Visual sanity check: the lowest- (x[0]) and highest-frequency (x[-1]) input sequences,
# since the target frequencies in `y` are generated in ascending order.
fig, ax = plt.subplots()
ax.plot(t, x[0, :, 0], label=f"freq={float(y[0, 0]):.2f}")
ax.plot(t, x[-1, :, 0], label=f"freq={float(y[-1, 0]):.2f}")
ax.set(xlabel="t", ylabel="x(t)", title="Example input sequences")
ax.legend()
plt.show()

In [None]:
# Training hyperparameters.
SEED = 42
HIDDEN_SIZE = 8
LEARNING_RATE = 1e-2
N_TRAIN_STEPS = 1000
LOG_EVERY = 100  # report the loss every LOG_EVERY steps rather than on every step

model = RNN(input_size=1, hidden_size=HIDDEN_SIZE, out_size=1, key=jax.random.PRNGKey(SEED))
optim = optax.adam(LEARNING_RATE)
# Initialise the optimizer over array leaves only (the Equinox idiom). Passing the raw
# model would also allocate Adam moments for the plain-int `hidden_size` field, which
# `filter_value_and_grad` never produces a gradient for.
opt_state = optim.init(eqx.filter(model, eqx.is_array))


@eqx.filter_value_and_grad
def loss_fn(model, x, y):
    """Mean squared error between batched model predictions and targets `y`.

    The decorator makes a call return `(loss, grads)`, with gradients taken with
    respect to the floating-point arrays inside `model`.
    """
    pred = jax.vmap(model)(x)  # apply the (unbatched) RNN independently to each sequence
    return optax.losses.squared_error(pred, y).mean()


@eqx.filter_jit
def evolve(model, x, y, opt_state):
    """Run one optimizer step.

    Returns:
        (updated model, loss evaluated at the model *before* the update, new opt_state)
    """
    loss, grads = loss_fn(model, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, loss, opt_state


for step in range(N_TRAIN_STEPS):
    model, loss, opt_state = evolve(model, x, y, opt_state)
    # `.item()` copies the loss to the host and waits for the step to finish,
    # so only pay that cost when actually logging (always including the last step).
    if step % LOG_EVERY == 0 or step == N_TRAIN_STEPS - 1:
        loss = loss.item()
        print(f"step={step}, loss={loss}")

In [None]:
# Compare the trained model's frequency estimate with the ground truth for one sample.
# (Named `sample_idx` rather than `id` to avoid shadowing the builtin `id`.)
sample_idx = 25
predicted = model(x[sample_idx])  # x[sample_idx] is one unbatched (seq_len, 1) sequence
print(f"predicted={predicted}, actual={y[sample_idx]}")