# Stateful operations (e.g. BatchNorm)

Some layers, such as [`equinox.nn.BatchNorm`][], are sometimes called "stateful". This means that they take an additional input (for BatchNorm, the running statistics) and return an additional output (the updated running statistics).

This just means that we need to plumb an extra input and output through our models. This example demonstrates both [`equinox.nn.BatchNorm`][] and [`equinox.nn.SpectralNorm`][].
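Here is a minimal plain-JAX sketch of that extra input and output. `toy_batchnorm`, along with its `momentum` and `eps` parameters, is a hypothetical simplification written for illustration. It is not Equinox's implementation. It runs on a single example under `jax.vmap(..., axis_name="batch")` and averages statistics over the named axis with `jax.lax.pmean`. The running statistics come in as `state` and go back out updated.

```python
import jax
import jax.numpy as jnp


def toy_batchnorm(x, state, momentum=0.9, eps=1e-5):
    # Hypothetical batch-norm-like function (illustration only).
    running_mean, running_var = state
    # Collectives over the vmapped "batch" axis give whole-batch statistics.
    mean = jax.lax.pmean(x, axis_name="batch")
    var = jax.lax.pmean((x - mean) ** 2, axis_name="batch")
    y = (x - mean) / jnp.sqrt(var + eps)
    # The extra output: exponential moving averages of the batch statistics.
    new_state = (
        momentum * running_mean + (1 - momentum) * mean,
        momentum * running_var + (1 - momentum) * var,
    )
    return y, new_state


xs = jnp.arange(12.0).reshape(4, 3)
state = (jnp.zeros(3), jnp.ones(3))
# in_axes/out_axes of None: the state is shared by the batch, not batched.
batched = jax.vmap(
    toy_batchnorm, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
)
ys, state = batched(xs, state)
```

The batch statistics are reduced over the whole batch, so they are identical for every example. That is why `out_axes=None` can return a single, unbatched copy of the state.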

This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/equinox/blob/main/examples/stateful.ipynb).

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax  # https://github.com/deepmind/optax
```

```python
# This model is just a weird mish-mash of layers for demonstration purposes; it isn't
# doing anything clever.
class Model(eqx.Module):
    norm1: eqx.nn.BatchNorm
    spectral_linear: eqx.nn.SpectralNorm[eqx.nn.Linear]
    norm2: eqx.nn.BatchNorm
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, key):
        key1, key2, key3, key4 = jr.split(key, 4)
        self.norm1 = eqx.nn.BatchNorm(input_size=3, axis_name="batch")
        self.spectral_linear = eqx.nn.SpectralNorm(
            layer=eqx.nn.Linear(in_features=3, out_features=32, key=key1),
            weight_name="weight",
            key=key2,
        )
        self.norm2 = eqx.nn.BatchNorm(input_size=32, axis_name="batch")
        self.linear1 = eqx.nn.Linear(in_features=32, out_features=32, key=key3)
        self.linear2 = eqx.nn.Linear(in_features=32, out_features=3, key=key4)

    def __call__(self, x, state):
        # Stateful layers take and return `state`; ordinary layers don't.
        x, state = self.norm1(x, state)
        x, state = self.spectral_linear(x, state)
        x = jax.nn.relu(x)
        x, state = self.norm2(x, state)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return x, state
```

```python
def compute_loss(model, state, xs, ys):
    # The `axis_name` argument is needed specifically for `BatchNorm`, so that it
    # knows which axis to compute batch statistics over.
    # The `in_axes` and `out_axes` are needed so that `state` isn't batched.
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    pred_ys, state = batch_model(xs, state)
    loss = jnp.mean((pred_ys - ys) ** 2)
    return loss, state


@eqx.filter_jit
def make_step(model, state, opt_state, xs, ys):
    # `has_aux=True`: only `loss` is differentiated; the new `state` is passed through.
    grads, state = eqx.filter_grad(compute_loss, has_aux=True)(model, state, xs, ys)
    # `optim` is the global Optax optimiser created below.
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, state, opt_state
```
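`eqx.filter_grad(..., has_aux=True)` follows the same convention as `jax.grad`. The function returns a pair `(value, aux)`, only `value` is differentiated, and the result is `(grads, aux)`. Here is a plain-JAX sketch with a hypothetical toy loss. Its auxiliary output is a call counter standing in for the updated state.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, state, x, y):
    # Hypothetical toy loss returning (loss, aux), like `compute_loss`.
    pred = w * x
    loss = jnp.mean((pred - y) ** 2)
    new_state = state + 1  # stand-in for updated running statistics
    return loss, new_state


w = jnp.array(2.0)
state = jnp.array(0)
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([1.0, 2.0, 3.0])

# Differentiates w.r.t. `w` only; the auxiliary `state` is returned untouched.
grads, state = jax.grad(loss_fn, has_aux=True)(w, state, x, y)
```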

```python
dataset_size = 10
learning_rate = 3e-4
steps = 5
seed = 5678

key = jr.PRNGKey(seed)
mkey, xkey, xkey2 = jr.split(key, 3)
model = Model(mkey)
state = eqx.nn.State(model)
xs = jr.normal(xkey, (dataset_size, 3))
ys = jnp.sin(xs) + 1
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

# Full-batch gradient descent in this simple example.
for _ in range(steps):
    model, state, opt_state = make_step(model, state, opt_state, xs, ys)
```

Overall, this is relatively straightforward!

Calling `state = eqx.nn.State(model)` iterates over the model PyTree, and each stateful layer stores its initial state in the resulting `state` object. The `state` object is itself a PyTree, so it can be passed around in the usual way. For example, it goes into and out of the `eqx.filter_jit`-wrapped `make_step`.

In this example, `state` stores the running statistics for `BatchNorm` and the `u` and `v` vectors used by the power iteration in `SpectralNorm`.
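`SpectralNorm` needs state because of how the spectral norm is estimated. Spectral normalisation divides a weight matrix by an estimate of its largest singular value. Power iteration refines that estimate on every call, starting from the previous call's `u` and `v`. The plain-JAX sketch below illustrates the technique and is not Equinox's implementation. `power_iteration_step` is a hypothetical helper. The diagonal matrix is chosen so that the true answer, 3.0, is obvious.

```python
import jax.numpy as jnp
import jax.random as jr


def power_iteration_step(W, u, v, eps=1e-12):
    # One round: apply W^T then W, normalising each time.
    v = W.T @ u
    v = v / (jnp.linalg.norm(v) + eps)
    u = W @ v
    u = u / (jnp.linalg.norm(u) + eps)
    # Current estimate of the largest singular value of W.
    sigma = u @ W @ v
    return sigma, u, v


W = jnp.diag(jnp.array([3.0, 1.0, 0.5]))
u = jr.normal(jr.PRNGKey(0), (3,))
v = jnp.zeros(3)  # overwritten on the first step

# (u, v) play the role of the state: each call resumes from the last one.
for _ in range(50):
    sigma, u, v = power_iteration_step(W, u, v)

W_normalised = W / sigma  # largest singular value is now approximately 1
```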

After that, we just need to thread the `state` object into and out of every call. Each call returns a new state object, and the old state object should not be reused.
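Threading state like this is the usual functional pattern: nothing is mutated in place, and each call consumes the old state and produces a new one. The same shape fits `jax.lax.scan`, where the state is the loop carry. Here is a plain-JAX sketch using a hypothetical running mean, not an Equinox API.

```python
import jax
import jax.numpy as jnp


def update_running_mean(state, x, momentum=0.5):
    # Takes the old state, returns the new one (as carry and as per-step output).
    new_state = momentum * state + (1 - momentum) * x
    return new_state, new_state


xs = jnp.array([1.0, 1.0, 1.0, 1.0])
final_state, history = jax.lax.scan(update_running_mean, jnp.array(0.0), xs)
```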

Finally, let's use our trained model to perform inference:

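```python
# Set every `inference` flag in the model to `True`.
inference_model = eqx.nn.inference_mode(model)
# Bind the trained state, so that the model only needs to be passed `x`.
inference_model = eqx.Partial(inference_model, state=state)


@eqx.filter_jit
def evaluate(model, xs):
    # Discard the returned state: it isn't needed at inference time.
    out, _ = jax.vmap(model)(xs)
    return out


test_dataset_size = 5
test_xs = jr.normal(xkey2, (test_dataset_size, 3))
pred_test_ys = evaluate(inference_model, test_xs)
```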

Here we don't need the updated state object that the model returns, so we simply discard it.
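Binding `state` with `eqx.Partial` means `jax.vmap` only sees `x`, so every example shares one state instead of having it batched. Below is a plain-JAX sketch of the same idea using `functools.partial` and a hypothetical `apply` function.

```python
import functools

import jax
import jax.numpy as jnp


def apply(x, state):
    # Hypothetical stateful function: returns an output and the (unchanged) state.
    scale, shift = state
    return x * scale + shift, state


state = (jnp.array(2.0), jnp.array(1.0))
bound = functools.partial(apply, state=state)  # closed over, so not vmapped

out, _ = jax.vmap(bound)(jnp.array([0.0, 1.0, 2.0]))  # discard the state output
```

The example above uses `eqx.Partial` rather than `functools.partial` because `eqx.Partial` is itself a PyTree. When `inference_model` is passed to `eqx.filter_jit`, its bound state arrays are therefore treated like any other array input.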

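(Also, don't forget to set the `inference` flags. Above, `eqx.nn.inference_mode` does this for every layer in the model, so that, for example, `BatchNorm` uses its running statistics instead of batch statistics.)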