# Neural DDE

This example demonstrates how to use Diffrax in order to solve a Delay Differential Equation (DDE) with known delays.  
Unlike ODEs that are identified by their vector field $f(t, y(t))$ and initial condition $y(0)=y_0$, DDEs are specified by their vector field $f$, deviated arguments $y(t-\tau)$ and history function $\phi(t)=y(t<0)$.

We will model the [Lotka Volterra](https://en.wikipedia.org/wiki/Lotka%E2%80%93Volterra_equations) (LK) equations with one constant time delay defined as 

$$
\begin{align}
& y_1'(t) = \frac{1}{2} y_1(t) ( 1  - y_2(t-0.2)) \\
& y_2'(t) = -\frac{1}{2} y_2(t)( 1  - y_1(t-0.2)) \\
& \phi(t) = y(t<0) = (y_{1,0}, y_{2,0}) 
\end{align}
$$

where $y_{1,0}, y_{2,0}$ are uniformly sampled in $[1.0,1.5]$.

This example is available as a Jupyter notebook [here](url).

In [1]:
import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In order to model our problem as a DDE $y'(t) = f_{\theta}(t, y(t), y(t-\tau_1), \dots, y(t-\tau_d))$, we first need to define a `Delays` object that incorporates deviated arguments in our vector field $f$.  

LK's initial time point $t=0$ has a derivative jump because $\phi^{\prime}(t=0^{-}) \neq  y^{\prime}(t=0^{+})$ and the history function $\phi(t)$ has `None`.  
The DDE model only has one time delay so $d=1$ and our vector field will be $y'(t) = f_{\theta}(t, y(t), y(t-\tau))$

In [2]:
delays = diffrax.Delays(
    delays=[lambda t, y, args: 0.2], initial_discontinuities=jnp.array([0.0])
)

Below is defined the vector field $f_{\theta}$. 

In [3]:
class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=2 * data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.relu,
            key=key,
        )

    def __call__(self, t, y, args, *, history):
        return self.mlp(jnp.hstack([y, *history]))

The `history` variable inside the network's `__call__`  is a tuple of deviated arguments. For example, if we possess a `Delays` object with 2 delays then the first element of tuple would be the first deviated argument $y(t-\tau_1)$ and the second one $y(t-\tau_1)$.  
In our case, `history[0]` corresponds to $y(t-0.2)$ and by extension `history[0][0]` is $y_1(t-0.2)$.

Here we wrap up the entire DDE solve into a model.

In [4]:
class NeuralDDE(eqx.Module):
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, ts, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Dopri5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=lambda t: y0,
            delays=delays,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts, dense=True),
        )
        return solution.ys

We generate the LK dataset.

In [5]:
def _get_data(ts, *, key):
    y0 = jrandom.uniform(key, (2,), minval=1.0, maxval=1.5)

    def vector_field(t, y, args, history):
        return jnp.array(
            [
                1 / 2 * y[0] * (1 - history[0][1]),
                -1 / 2 * y[1] * (1 - history[0][0]),
            ]
        )

    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Dopri5(),
        t0=ts[0],
        t1=ts[-1],
        dt0=ts[1] - ts[0],
        y0=lambda t: y0,
        stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
        saveat=diffrax.SaveAt(ts=ts, dense=True),
        delays=delays,
    )

    return sol.ys


def get_data(dataset_size, *, key):
    ts = jnp.linspace(0, 15, 150)
    key = jrandom.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    return ts, ys


def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jrandom.permutation(key, indices)
        (key,) = jrandom.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

Main entry point. Try runnning `main()`.

In [6]:
def main(
    dataset_size=256,
    batch_size=128,
    lr_strategy=(3e-3,),
    steps_strategy=(500,),
    length_strategy=(1.0,),
    width_size=32,
    depth=3,
    seed=5679,
    plot=True,
    print_every=5,
):
    key = jrandom.PRNGKey(seed)
    data_key, model_key, loader_key = jrandom.split(key, 3)

    ts, ys = get_data(dataset_size, key=data_key)
    mean, std = jnp.mean(ys), jnp.std(ys)
    ys = (ys - mean) / std
    _, length_size, data_size = ys.shape

    model = NeuralDDE(data_size, width_size, depth, key=model_key)

    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi):
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])
        return jnp.mean((yi - y_pred) ** 2)

    @eqx.filter_jit
    def make_step(ti, yi, model, opt_state):
        loss, grads = grad_loss(model, ti, yi)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optax.adabelief(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[: int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        for step, (yi,) in zip(
            range(steps), dataloader((_ys,), batch_size, key=loader_key)
        ):
            start = time.time()
            loss, model, opt_state = make_step(_ts, yi, model, opt_state)
            end = time.time()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

    if plot:
        plt.plot(_ts, _ys[0, :, 0], c="dodgerblue", label="Real")
        plt.plot(_ts, _ys[0, :, 1], c="dodgerblue")
        model_y = model(_ts, _ys[0, 0])
        plt.plot(_ts, model_y[:, 0], c="crimson", label="Model")
        plt.plot(_ts, model_y[:, 1], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_dde.png")
        plt.close()

    return ts, ys, model

In [None]:
ts, ys, model = main()

Step: 0, Loss: 1.401884913444519, Computation time: 14.241726636886597
Step: 5, Loss: 1.134800910949707, Computation time: 1.192063570022583
Step: 10, Loss: 0.9904348254203796, Computation time: 1.9642481803894043
Step: 15, Loss: 0.9832042455673218, Computation time: 2.6581308841705322
Step: 20, Loss: 0.9569293260574341, Computation time: 3.8219752311706543
Step: 25, Loss: 0.9038560390472412, Computation time: 3.6977927684783936
Step: 30, Loss: 0.8163420557975769, Computation time: 35.42212176322937
Step: 35, Loss: 0.6886805295944214, Computation time: 6.909891843795776
Step: 40, Loss: 0.3398251533508301, Computation time: 5.504199266433716
Step: 45, Loss: 0.16126586496829987, Computation time: 5.103270769119263
Step: 50, Loss: 0.08195476979017258, Computation time: 5.78333044052124
Step: 55, Loss: 0.06952822208404541, Computation time: 10.413585901260376
Step: 60, Loss: 0.031119057908654213, Computation time: 7.735013723373413
Step: 65, Loss: 0.025405289605259895, Computation time: 7.