In [None]:
import math
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import matplotlib.pyplot as plt

In [None]:
class RNN(eqx.Module):
    """Single-layer LSTM that reads a sequence and maps its final hidden state to outputs.

    Fields:
        hidden_size: width of the LSTM hidden-state and cell-state vectors.
        cell: the LSTM cell applied once per time step.
        linear: projection from the final hidden state to ``out_size`` outputs.
    """

    hidden_size: int
    cell: eqx.nn.LSTMCell
    linear: eqx.nn.Linear

    def __init__(self, input_size, hidden_size, out_size, key):
        """Build the LSTM cell and the output layer.

        Args:
            input_size: number of features per time step.
            hidden_size: size of the LSTM hidden/cell state.
            out_size: number of outputs produced by the final linear layer.
            key: JAX PRNG key; it is split so the cell and the linear layer
                receive independent initialisation keys.
        """
        ckey, lkey = jax.random.split(key)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.LSTMCell(
            input_size=input_size, hidden_size=hidden_size, key=ckey
        )
        self.linear = eqx.nn.Linear(hidden_size, out_size, key=lkey)

    def __call__(self, x):
        """Run the LSTM over ``x`` and project the last hidden state.

        Args:
            x: input sequence of shape ``(seq_len, input_size)``; ``lax.scan``
                iterates over the leading (time) axis and feeds each
                ``(input_size,)`` slice to the cell.

        Returns:
            Output of ``self.linear`` applied to the final hidden state,
            shape ``(out_size,)``.
        """
        # Equinox's LSTMCell carries a 2-tuple (h, c) -- hidden state and cell
        # state -- and returns an updated 2-tuple. Both start at zero. Passing
        # a single array here (as before) breaks when the cell unpacks it.
        zeros = jnp.zeros((self.hidden_size,))
        initial_state = (zeros, zeros)

        def step(state, inp):
            # Advance the LSTM one time step; no per-step outputs are collected.
            return self.cell(inp, state), None

        (h, _c), _ = jax.lax.scan(step, initial_state, x)
        # Only the hidden state h feeds the readout; the cell state c is internal.
        return self.linear(h)

In [None]:
# Scratch check: an all-zeros vector like the one used to start a recurrent
# state. This standalone width is not the value passed to the model.
hidden_size = 5
hidden = jnp.zeros(hidden_size)
hidden

In [None]:
# Instantiate the RNN from a fixed PRNG seed so the initial weights are reproducible.
model = RNN(
    key=jax.random.PRNGKey(42),
    input_size=1,
    hidden_size=4,
    out_size=1,
)
model

In [None]:
# Toy dataset: 50 sine waves sampled at 30 evenly spaced times on [0, 1].
# Row i is sin(2*pi*y[i]*t), so the frequencies sweep from 1 to 2 cycles over
# the interval. Result shape: (len(y), len(t)) == (50, 30).
t = jnp.linspace(0.0, 1.0, 30)
y = jnp.linspace(1.0, 2.0, 50)
# jnp.outer is a broadcast elementwise multiply. It replaces the hand-rolled
# reshape + matmul and sidesteps reduced-precision default matmul on GPU/TPU.
x = jnp.sin(2.0 * math.pi * jnp.outer(y, t))

In [None]:
# Compare the first (lowest-frequency) and last (highest-frequency) rows of x.
fig, ax = plt.subplots()
ax.plot(t, x[0], label=f"y = {float(y[0]):.2f}")
ax.plot(t, x[-1], label=f"y = {float(y[-1]):.2f}")
ax.set(title="Sample sequences x[0] and x[-1]", xlabel="t", ylabel=r"$\sin(2\pi y t)$")
ax.legend();