# Recurrent Models

In [None]:
# | default_exp models.recurrent

In [None]:
# | export

import jax
import jax.numpy as jnp
import flax.linen as nn
from functools import partial
from einops import rearrange

## LRU dynamics

> Linear dynamics whose eigenvalues are initialised as in the LRU (Linear Recurrent Unit) paper.

In [None]:
# | export

from physmodjax.models.ssm import (
    theta_init,
    nu_init,
)

In [None]:
# | export

from physmodjax.models.ssm import LRUDynamics

### LRU with MLP

In [None]:
# | export

class LRUDynamicsVarying(LRUDynamics):
    """LRU linear dynamics followed by a learned, state-dependent complex gain.

    The parent ``LRUDynamics`` advances the complex state for ``steps`` steps.
    The squared magnitude of that linear trajectory is fed to ``model``. Its
    output is read as a complex gain: the first ``d_hidden`` features are the
    real part and the last ``d_hidden`` features are the imaginary part. The
    gain multiplies the linear trajectory element-wise.
    """

    # Network applied to |x|^2; must output 2 * d_hidden features
    # (real part followed by imaginary part of the gain).
    model: nn.Module

    def __call__(
        self,
        x: jnp.ndarray,  # initial complex state flattened (d_hidden,) complex
        steps: int,  # number of steps to advance
    ) -> jnp.ndarray:  # advanced state (steps, d_hidden) complex
        # Linear trajectory from the parent LRU dynamics.
        x = super().__call__(x, steps)

        # The modulation network only sees the (real) squared magnitude |x|^2.
        x_hat = self.model(x.real**2 + x.imag**2)

        # Check the width before splitting. With a wrong width the two slices
        # below may still broadcast against each other (e.g. width
        # d_hidden + 1 yields an imaginary part of width 1) and silently
        # produce a wrong gain instead of failing.
        expected = 2 * self.d_hidden
        if x_hat.shape[-1] != expected:
            raise ValueError(
                f"`model` must output {expected} features (real and imaginary "
                f"parts for d_hidden={self.d_hidden}), got {x_hat.shape[-1]}."
            )

        gain = x_hat[..., : self.d_hidden] + 1j * x_hat[..., self.d_hidden :]
        return x * gain

In [None]:
from physmodjax.models.mlp import MLP

In [None]:
# | test

d_hidden = 64
steps = 50

# Modulation network. LRUDynamicsVarying splits its output into real and
# imaginary halves of width d_hidden, so its final width must be 2 * d_hidden.
# NOTE(review): confirm that MLP(hidden_channels=[64, 64, 64]) produces that
# width for d_hidden = 64.
model = MLP(hidden_channels=[64] * 3)

# NOTE(review): this cell only constructs the module; it never calls
# `init`/`apply`, so the forward pass is not exercised here.
dyn = LRUDynamicsVarying(
    model=model,
    d_hidden=d_hidden,
    max_phase=2 * jnp.pi,
    r_min=0.99,
    r_max=1.0,
    clip_eigs=False,
)

## Deep GRU

In [None]:
# | export

class DeepRNN(nn.Module):
    """
    A deep RNN that unrolls a stack of recurrent cells over the time axis.

    At every time step the ``(W, C)`` slice is flattened to one vector of size
    ``W * C``, and every cell has ``d_model * d_vars`` features. The stack is
    ``1 + n_layers`` RNN layers deep:
    - a first layer whose carry is seeded from ``x0``;
    - ``n_layers`` further layers that start from their default initial carry.

    Intended for cells whose carry is a single array, e.g. ``nn.GRUCell``,
    ``nn.SimpleCell`` or ``nn.MGUCell``.

    Attributes:
        d_model: size of the ``W`` axis of the state.
        d_vars: size of the ``C`` axis of the state.
        n_layers: number of RNN layers stacked *after* the first layer.
        cell: cell constructor called as ``cell(features=...)``, e.g.
            ``partial(nn.GRUCell)``.
        training: currently unused by this module.
        norm: currently unused by this module.
    """

    d_model: int
    d_vars: int
    n_layers: int
    cell: nn.Module
    training: bool = True
    norm: str = "layer"

    def setup(self):
        features = self.d_model * self.d_vars

        # nn.RNN scans the cell over the time axis of the (T, W*C) sequence.
        self.first_layer = nn.RNN(self.cell(features=features))

        self.layers = [
            nn.RNN(self.cell(features=features)) for _ in range(self.n_layers)
        ]

    def __call__(
        self,
        x0: jnp.ndarray,  # (W, C) # initial state
        x: jnp.ndarray,  # (T, W, C) # empty state
    ) -> jnp.ndarray:  # (T, W, C) # advanced state
        # The first layer's carry is the flattened x0, so it must hold exactly
        # d_model * d_vars values. Otherwise, depending on the cell, the first
        # layer may silently run with a different hidden size, or fail deep
        # inside the cell with an unclear error.
        expected = self.d_model * self.d_vars
        if x0.size != expected:
            raise ValueError(
                f"x0 must contain d_model * d_vars = {expected} values "
                f"(shape ({self.d_model}, {self.d_vars})), got shape {x0.shape}."
            )

        # The cells work on the last axis, so flatten (W, C) -> (W*C):
        # x0 becomes (W*C,) and x becomes (T, W*C).
        x0 = rearrange(x0, "w c -> (w c)")
        x = rearrange(x, "t w c -> t (w c)")
        x = self.first_layer(x, initial_carry=x0)
        for layer in self.layers:
            x = layer(x)
        return rearrange(x, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)


# Batched DeepRNN: maps over a leading batch axis (axis 0) of every input and
# output. All examples share one set of parameters (no "params" axis) and one
# initialisation RNG (params RNG not split).
BatchedDeepRNN = nn.vmap(
    DeepRNN,
    variable_axes={"params": None},
    split_rngs={"params": False},
    in_axes=0,
    out_axes=0,
    axis_name="batch",
)

In [None]:
# | test

B, T, W, C = 10, 50, 20, 3

# All-ones initial state (B, W, C) and input sequence (B, T, W, C).
x0 = jnp.ones((B, W, C))
x = jnp.ones((B, T, W, C))

# First RNN layer plus two further GRU layers, parameters shared over the batch.
deep_rnn = BatchedDeepRNN(
    d_model=W,
    d_vars=C,
    n_layers=2,
    cell=partial(nn.GRUCell),
)
variables = deep_rnn.init(jax.random.PRNGKey(65), x0, x)
out = deep_rnn.apply(variables, x0, x)

# One advanced (W, C) state per time step, for every batch element.
assert out.shape == (B, T, W, C)