# Steady states

This example demonstrates how to use Diffrax to solve an ODE until it reaches a steady state. The key feature is the use of event handling to detect when the steady state has been reached.

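In addition, this example needs to backpropagate through the procedure of finding a steady state. This can be done efficiently using the implicit function theorem, which gives the derivative of the steady state directly instead of backpropagating through every step the solver took.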
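Concretely, write the vector field as $f(y, \theta)$, where $\theta$ denotes the parameters being trained. A steady state $y^*(\theta)$ satisfies $f(y^*(\theta), \theta) = 0$. Differentiating this identity with respect to $\theta$ is the standard implicit-function-theorem step, and it yields the derivative of the steady state without tracing through the solver:

$$
\frac{\partial f}{\partial y}\bigg|_{y^*} \frac{\mathrm{d} y^*}{\mathrm{d} \theta} + \frac{\partial f}{\partial \theta}\bigg|_{y^*} = 0
\quad\Longrightarrow\quad
\frac{\mathrm{d} y^*}{\mathrm{d} \theta} = -\left(\frac{\partial f}{\partial y}\bigg|_{y^*}\right)^{-1} \frac{\partial f}{\partial \theta}\bigg|_{y^*}.
$$

For the model defined below, $f(y, \theta) = \theta - y$ with $\theta$ the `steady_state` field, so $\partial f / \partial y = -1$, $\partial f / \partial \theta = 1$, and therefore $\mathrm{d} y^* / \mathrm{d} \theta = 1$. The identity requires $\partial f / \partial y$ to be invertible at the steady state.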

This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/steady_state.ipynb).

```python
import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax.numpy as jnp
import optax  # https://github.com/deepmind/optax
```

```python
class ExponentialDecayToSteadyState(eqx.Module):
    steady_state: float

    def __call__(self, t, y, args):
        return self.steady_state - y
```

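Because this vector field is linear, its solution is known in closed form, $y(t) = s + (y_0 - s)e^{-t}$, which decays towards the steady state $s$. The pure-Python sketch below uses that formula to estimate when a steady-state event would fire for the initial guess used later ($s = 0$, $y_0 = 1$). It assumes the event's condition is $|\mathrm{d}y/\mathrm{d}t| < \text{atol} + \text{rtol} \cdot |y|$, with the tolerances of the `PIDController` configured in the next cell; that form is an assumption here, so check Diffrax's `steady_state_event` documentation. `exact_solution` and `event_time` are hypothetical helpers, not part of Diffrax.

```python
import math


def exact_solution(t, steady_state, y0=1.0):
    # Closed-form solution of dy/dt = steady_state - y with y(0) = y0.
    return steady_state + (y0 - steady_state) * math.exp(-t)


def event_time(steady_state, y0=1.0, rtol=1e-3, atol=1e-6):
    # Hypothetical helper: first time at which |dy/dt| < atol + rtol * |y|
    # along the exact solution, located by bisection. The condition is an
    # assumed form of a steady-state test, not Diffrax's own code.
    def residual(t):
        y = exact_solution(t, steady_state, y0)
        return abs(steady_state - y) - (atol + rtol * abs(y))

    lo, hi = 0.0, 1.0
    while residual(hi) > 0:  # grow the bracket until the condition holds
        hi *= 2.0
    for _ in range(100):  # bisect, keeping residual(lo) > 0 >= residual(hi)
        mid = 0.5 * (lo + hi)
        if residual(mid) > 0:
            lo = mid
        else:
            hi = mid
    return hi


t_event = event_time(0.0)
print(t_event, exact_solution(t_event, 0.0))  # t_event is about 13.8, y is about 1e-6
```

A numerical solver presumably only tests the condition at the end of each accepted step, so the real integration would stop at or somewhat after this time rather than exactly on it.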
```python
def loss(model, target_steady_state):
    term = diffrax.ODETerm(model)
    solver = diffrax.Tsit5()
    t0 = 0
    t1 = jnp.inf
    dt0 = None
    y0 = 1.0
    max_steps = None
    controller = diffrax.PIDController(rtol=1e-3, atol=1e-6)
    cond_fn = diffrax.steady_state_event()
    event = diffrax.Event(cond_fn)
    adjoint = diffrax.ImplicitAdjoint()
    # This combination of event, t1, max_steps and adjoint is particularly
    # natural: we keep integrating until the event fires, with no maximum
    # time or number of steps. Backpropagation happens via the implicit
    # function theorem.
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        y0,
        max_steps=max_steps,
        stepsize_controller=controller,
        event=event,
        adjoint=adjoint,
    )
    (y1,) = sol.ys
    return (y1 - target_steady_state) ** 2
```

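Since the steady state of this model is exactly `steady_state`, the gradient of the loss can be worked out by hand: the implicit function theorem gives $\mathrm{d}y^*/\mathrm{d}s = 1$, so $\mathrm{d}L/\mathrm{d}s = 2(y^* - \text{target})$. The pure-Python sketch below uses that closed form, in place of calling Diffrax, to replay the first few updates of the training loop in the next cell. It assumes Optax's Nesterov momentum update takes the form written in the comments; `steady_state_grad` and `nesterov_sgd` are hypothetical names. Small differences from the logged values are expected, since the numerical solve stops near, rather than exactly at, the steady state.

```python
def steady_state_grad(s, target):
    # Loss is (y* - target)**2 with y* = s for dy/dt = s - y.
    # Implicit function theorem: dy*/ds = -(df/dy)^(-1) * df/ds = 1.
    dystar_ds = 1.0
    return 2.0 * (s - target) * dystar_ds


def nesterov_sgd(s0, target, lr=1e-2, momentum=0.7, num_steps=3):
    # Assumed to mirror optax.sgd(lr, momentum=0.7, nesterov=True):
    #   trace  <- g + momentum * trace
    #   update <- -lr * (g + momentum * trace)
    s, trace, history = s0, 0.0, []
    for _ in range(num_steps):
        g = steady_state_grad(s, target)
        trace = g + momentum * trace
        s = s - lr * (g + momentum * trace)
        history.append(s)
    return history


print(nesterov_sgd(0.0, 0.76))
```

The first update, for instance, is $-0.01 \times (1 + 0.7) \times (-1.52) = 0.02584$, in line with the value logged for step 0.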
```python
model = ExponentialDecayToSteadyState(
    jnp.array(0.0)
)  # initial steady state guess is 0.
# target steady state is 0.76
target_steady_state = jnp.array(0.76)
optim = optax.sgd(1e-2, momentum=0.7, nesterov=True)
opt_state = optim.init(model)


@eqx.filter_jit
def make_step(model, opt_state, target_steady_state):
    grads = eqx.filter_grad(loss)(model, target_steady_state)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state


for step in range(100):
    model, opt_state = make_step(model, opt_state, target_steady_state)
    print(f"Step: {step} Steady State: {model.steady_state}")
print(f"Target: {target_steady_state}")
```

Step: 0 Steady State: 0.025839969515800476
Step: 1 Steady State: 0.05824900045990944
Step: 2 Steady State: 0.09451568126678467
Step: 3 Steady State: 0.1327039748430252
Step: 4 Steady State: 0.1714443564414978
Step: 5 Steady State: 0.20979028940200806
Step: 6 Steady State: 0.24709881842136383
Step: 7 Steady State: 0.28294941782951355
Step: 8 Steady State: 0.31707584857940674
Step: 9 Steady State: 0.34934186935424805
Step: 10 Steady State: 0.37968698143959045
Step: 11 Steady State: 0.4081074893474579
Step: 12 Steady State: 0.43463948369026184
Step: 13 Steady State: 0.45934492349624634
Step: 14 Steady State: 0.48230400681495667
Step: 15 Steady State: 0.5036059021949768
Step: 16 Steady State: 0.5233321189880371
Step: 17 Steady State: 0.5415896773338318
Step: 18 Steady State: 0.5584752559661865
Step: 19 Steady State: 0.5740804076194763
Step: 20 Steady State: 0.5884985327720642
Step: 21 Steady State: 0.6018134951591492
Step: 22 Steady State: 0.6141058206558228
Step: 23 Steady State: 0.625450