# Neural ODE

This example trains a [Neural ODE](https://arxiv.org/abs/1806.07366) to reproduce a toy dataset of nonlinear oscillators.

This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/neural_ode.ipynb).

```python
import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
```

We use [Equinox](https://github.com/patrick-kidger/equinox) to build neural networks. We use [Optax](https://github.com/deepmind/optax) for optimisers (Adam etc.)

Recall that a neural ODE is defined as

$y(t) = y(0) + \int_0^t f_\theta(s, y(s)) \, \mathrm{d}s$,

where $f_\theta$ is a neural network with parameters $\theta$. We now define the $f_\theta$ that appears on the right-hand side.

```python
class Func(eqx.Module):
    """The vector field f_theta(t, y) of the neural ODE."""

    # A standard feedforward network (multilayer perceptron).
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,  # the MLP maps a state y ...
            out_size=data_size,  # ... to dy/dt, which has the same size
            width_size=width_size,  # number of units in each hidden layer
            depth=depth,  # number of hidden layers
            # softplus(x) = ln(1 + e^x): close to 0 for large negative x and
            # close to x for large positive x.
            activation=jnn.softplus,
            key=key,  # PRNG key used to initialise the weights
        )

    def __call__(self, t, y, args):
        # Diffrax calls the vector field as f(t, y, args). This one depends on y only
        # (an autonomous ODE), so t and args are ignored.
        return self.mlp(y)
```

Here we wrap up the entire ODE solve into a model.

```python
class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, ts, y0):
        # Solve dy/dt = func(t, y) from ts[0] to ts[-1], starting at y0.
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),  # Tsitouras' 5/4 explicit Runge-Kutta method
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],  # initial step size; the controller below adapts it
            y0=y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),  # return the solution at every time in ts
            max_steps=1000000,
        )
        return solution.ys  # shape (len(ts), data_size)
```

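Toy dataset: a simulated circular orbit of a satellite around the Earth (the two-body problem). Each position and velocity component traces out a sine or cosine over time. Every sample in the dataset is the same orbit.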
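For reference, the data-generating code integrates Newton's law of gravitation for a satellite orbiting a fixed Earth. The state is position $\mathbf{r}$ and velocity $\mathbf{v}$, and $\mu = GM$. The orbital period $T$ and launch speed $v$ come from the standard formulas for an orbit with semi-major axis $a$; the speed formula is the vis-viva equation:

$$
\frac{\mathrm{d}\mathbf{r}}{\mathrm{d}t} = \mathbf{v}, \qquad
\frac{\mathrm{d}\mathbf{v}}{\mathrm{d}t} = -\frac{\mu \, \mathbf{r}}{\lVert \mathbf{r} \rVert^3}, \qquad
T = 2\pi \sqrt{\frac{a^3}{\mu}}, \qquad
v = \sqrt{\mu \left( \frac{2}{r} - \frac{1}{a} \right)}.
$$

With $r = a$, the last formula reduces to $v = \sqrt{\mu / r}$, the speed of a circular orbit. The code's variable `T` is twice this period, so the data covers two full orbits.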

```python
def _get_data():
    # Physical constants (SI units).
    G = 6.67430e-11  # gravitational constant, m^3 kg^-1 s^-2
    M = 5.972e24  # mass of the Earth, kg
    mu = G * M  # standard gravitational parameter of the Earth

    def f(t, state, args):
        # state = [x, y, z, vx, vy, vz]; return d(state)/dt under Newtonian gravity.
        pos, vel = state[:3], state[3:]
        r = jnp.linalg.norm(pos)
        acc = -mu * pos / r**3
        return jnp.concatenate([vel, acc])

    inclination = 0.0
    r = 7.378e6  # orbital radius, m
    a = r  # semi-major axis; a == r gives a circular orbit
    T = 2 * (2 * jnp.pi * jnp.sqrt(a**3 / mu))  # two orbital periods
    v = jnp.sqrt(2 * mu / r - mu / a)  # vis-viva equation; equals sqrt(mu / r) here

    # Start at (0, r cos i, -r sin i) with speed v along the x-axis.
    y0 = jnp.array(
        [0.0, r * jnp.cos(inclination), -r * jnp.sin(inclination), v, 0.0, 0.0]
    )
    ts = jnp.linspace(0, T, 1000)

    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f),
        diffrax.Tsit5(),
        t0=ts[0],
        t1=ts[-1],
        dt0=0.1,  # no step size controller, so this is a fixed step of 0.1 s
        y0=y0,
        saveat=diffrax.SaveAt(ts=ts),
        max_steps=1000000,
    )
    return ts, sol.ys  # sol.ys has shape (1000, 6)


def get_data(dataset_size, *, key):
    # Every sample is the same deterministic orbit, so `key` is unused.
    del key
    ts, ys = _get_data()
    # Stack dataset_size copies of the trajectory: shape (dataset_size, 1000, 6).
    ys = jnp.broadcast_to(ys, (dataset_size, *ys.shape))
    return ts, ys
```

```python
def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)  # 0, 1, ..., dataset_size - 1
    while True:  # loop over epochs forever; the caller decides when to stop
        perm = jr.permutation(key, indices)  # shuffle the sample indices
        (key,) = jr.split(key, 1)  # derive a fresh key for the next epoch
        start = 0
        end = batch_size
        while end <= dataset_size:  # yield only full batches
            batch_perm = perm[start:end]  # the next batch_size shuffled indices
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size
```
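Two details of `dataloader` are easy to misread. First, `jr.split(key, 1)` returns a batch containing one new key, and the tuple unpacking `(key,) = ...` extracts it. Second, each epoch slices a single shuffled permutation into consecutive batches. This small sketch shows both, assuming `dataset_size=256` and `batch_size=32` to match `main()`'s defaults.

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
(new_key,) = jr.split(key, 1)  # one fresh key, with the same shape as the original

dataset_size, batch_size = 256, 32
perm = jr.permutation(key, jnp.arange(dataset_size))
# Consecutive, non-overlapping full slices of one shuffled permutation form one epoch.
batches = [
    perm[start : start + batch_size]
    for start in range(0, dataset_size - batch_size + 1, batch_size)
]
```

Every index appears in exactly one batch per epoch, and the order changes from epoch to epoch because the key is refreshed.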

Main entry point. Try running `main()`.

```python
def main(
    dataset_size=256,
    batch_size=32,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(500, 500),
    length_strategy=(0.1, 1),
    width_size=64,
    depth=2,
    seed=5678,
    plot=True,
    print_every=100,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key = jr.split(key, 3)

    ts, ys = get_data(dataset_size, key=data_key)
    _, length_size, data_size = ys.shape  # (samples, time points, channels)

    model = NeuralODE(data_size, width_size, depth, key=model_key)

    # Training loop like normal.
    #
    # Only thing to notice is that up until step 500 we train on only the first 10% of
    # each time series. This is a standard trick to avoid getting caught in a local
    # minimum.

    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi):
        # Solve from each trajectory's initial condition yi[:, 0], batched over yi.
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])
        return jnp.mean((yi - y_pred) ** 2)  # mean squared error

    @eqx.filter_jit
    def make_step(ti, yi, model, opt_state):
        loss, grads = grad_loss(model, ti, yi)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    # Each training phase pairs a learning rate, a step count and a fraction of each series.
    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optax.adabelief(lr)  # AdaBelief optimiser with learning rate lr
        # eqx.filter keeps only the floating-point arrays (the trainable parameters)
        # and replaces every other leaf, such as activation functions, with None.
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        # Train on the first `length` fraction of each time series (10%, then 100%).
        _ts = ts[: int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        for step, (yi,) in zip(
            range(steps), dataloader((_ys,), batch_size, key=loader_key)
        ):
            start = time.time()
            loss, model, opt_state = make_step(_ts, yi, model, opt_state)
            end = time.time()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

    if plot:
        # Compare the x and y position channels of the first trajectory.
        plt.plot(ts, ys[0, :, 0], c="dodgerblue", label="Real")
        plt.plot(ts, ys[0, :, 1], c="dodgerblue")
        model_y = model(ts, ys[0, 0])
        plt.plot(ts, model_y[:, 0], c="crimson", label="Model")
        plt.plot(ts, model_y[:, 1], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_ode.png")
        plt.show()

    return ts, ys, model
```

```python
ts, ys, model = main()
```



Some notes on speed:
The hyperparameters for the above example haven't really been optimised. Try experimenting with them to see how much faster you can make this example run. There are lots of things you can try tweaking:

- The size of the neural network.
- The numerical solver.
- The step size controller, including both its step size and its tolerances.
- The length of the dataset. (Do you really need to use all of a time series every time?)
- Batch size, learning rate, choice of optimiser.
- ... etc.!

Some notes on being Markov:

- This example has assumed that the problem is Markov. Essentially, that the data `ys` is a complete observation of the system, and that we're not missing any channels. Note how the result of our model is evolving in data space. This is unlike e.g. an RNN, which has hidden state, and a linear map from hidden state to data.
- If we wanted, we could generalise this to the non-Markov case: inside `NeuralODE`, project the initial condition into some high-dimensional latent space, do the ODE solve there, then take a linear map to get the output. See the [Latent ODE example](../latent_ode) for an example doing this as part of a generative model; also see [Augmented Neural ODEs](https://arxiv.org/abs/1904.01681) for a short paper on it.