In [1]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

In [4]:
class RNN(eqx.Module):
    """Single-layer LSTM followed by a linear read-out squashed by a sigmoid.

    The sequence is consumed one step at a time with ``jax.lax.scan``; only
    the hidden state after the last step is passed to the linear layer.
    """

    # Width of the LSTM hidden state and cell state.
    hidden_size: int
    # Recurrent cell applied once per time step.
    cell: eqx.nn.LSTMCell
    # Read-out layer mapping the final hidden state to ``out_size`` values.
    linear: eqx.nn.Linear

    def __init__(self, input_size, hidden_size, out_size, key):
        """Build the LSTM cell and the linear read-out.

        Args:
            input_size: Number of features in each time step of the input.
            hidden_size: Size of the LSTM hidden and cell state vectors.
            out_size: Number of outputs produced by the linear layer.
            key: JAX PRNG key; split in two so the cell and the linear layer
                are initialised from independent keys.
        """
        ckey, lkey = jax.random.split(key)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.LSTMCell(
            input_size=input_size, hidden_size=hidden_size, key=ckey
        )
        self.linear = eqx.nn.Linear(hidden_size, out_size, key=lkey)

    def __call__(self, x):
        """Run the LSTM over ``x`` and return sigmoid(linear(final hidden state)).

        Args:
            x: Input sequence; ``jax.lax.scan`` iterates over its leading
                axis, so each slice ``x[t]`` is one input to the cell.

        Returns:
            The sigmoid of the linear read-out of the last hidden state.
        """
        # An LSTM carries TWO state vectors, (h, c). The scan carry must have
        # the same structure as what the cell returns, so start from a tuple
        # of zeros rather than a single array.
        init_state = (
            jnp.zeros((self.hidden_size,)),
            jnp.zeros((self.hidden_size,)),
        )

        def step(state, inp):
            # The cell returns the next (h, c); nothing is stacked per step.
            return self.cell(inp, state), None

        (final_h, _final_c), _ = jax.lax.scan(step, init_state, x)
        # Only the hidden state feeds the read-out; the cell state is internal.
        return jax.nn.sigmoid(self.linear(final_h))

In [2]:
# Scratch check: display a zero vector of length `hidden_size`.
hidden_size = 5
hidden = jnp.zeros(shape=(hidden_size,))
hidden

Array([0., 0., 0., 0., 0.], dtype=float32)

In [6]:
# Instantiate the model with 2 input features, 4 hidden units and 1 output,
# using a fixed PRNG seed so the initial weights are reproducible.
model = RNN(
    input_size=2,
    hidden_size=4,
    out_size=1,
    key=jax.random.PRNGKey(42),
)
model

RNN(
  hidden_size=4,
  cell=LSTMCell(
    weight_ih=f32[16,2],
    weight_hh=f32[16,4],
    bias=f32[16],
    input_size=2,
    hidden_size=4,
    use_bias=True
  ),
  linear=Linear(
    weight=f32[1,4],
    bias=f32[1],
    in_features=4,
    out_features=1,
    use_bias=True
  )
)