# Neural SDE

This example constructs a neural SDE as a generative time series model.

An SDE is random: it defines a probability distribution, and each sample from that distribution is a whole path. In modern machine learning parlance, an SDE is therefore a generative time series model, which means it can, for example, be trained as a GAN. Doing so requires a discriminator that consumes a path as its input; for this we use a neural CDE (controlled differential equation).
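To make "each sample is a whole path" concrete, here is a minimal sketch in plain JAX (it does not use Diffrax) that draws sample paths of a scalar SDE $\mathrm{d}y = f(t, y)\,\mathrm{d}t + g(t, y)\,\mathrm{d}W$ with the Euler-Maruyama scheme. The coefficients `drift` and `diffusion` are hypothetical, chosen only to resemble the toy dataset used later; the point is that each random key yields an entire path.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def drift(t, y):
    # Hypothetical drift: time-dependent Ornstein-Uhlenbeck-style mean reversion.
    return 0.02 * t - 0.1 * y


def diffusion(t, y):
    # Hypothetical diffusion coefficient, growing linearly in time.
    return 0.8 * t / 63.0


def sample_path(key, y0=0.0, t0=0.0, t1=63.0, num_steps=630):
    """Simulate one path of dy = drift dt + diffusion dW with Euler-Maruyama."""
    ts = jnp.linspace(t0, t1, num_steps + 1)
    dt = (t1 - t0) / num_steps
    # Brownian increments dW ~ Normal(0, dt), one per step.
    dws = jnp.sqrt(dt) * jrandom.normal(key, (num_steps,))
    y0 = jnp.asarray(y0, dtype=jnp.float32)

    def step(y, inputs):
        t, dw = inputs
        y_next = y + drift(t, y) * dt + diffusion(t, y) * dw
        return y_next, y_next

    _, ys = jax.lax.scan(step, y0, (ts[:-1], dws))
    return ts, jnp.concatenate([y0[None], ys])


# One key per sample; vmap gives a batch of whole paths.
keys = jrandom.split(jrandom.PRNGKey(0), 8)
ts, ys = jax.vmap(sample_path)(keys)  # ys has shape (8, 631)
```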

Training an SDE as a GAN is precisely what this example does. Doing so will reproduce the following toy example:

![ou](../imgs/neural_sde.png)

**References:**

Training SDEs as GANs:
```bibtex
@inproceedings{kidger2021sde1,
    title={{N}eural {SDE}s as {I}nfinite-{D}imensional {GAN}s},
    author={Kidger, Patrick and Foster, James and Li, Xuechen and Lyons, Terry J},
    booktitle = {Proceedings of the 38th International Conference on Machine Learning},
    pages = {5453--5463},
    year = {2021},
    volume = {139},
    series = {Proceedings of Machine Learning Research},
    publisher = {PMLR},
}
```

Improved training techniques:
```bibtex
@incollection{kidger2021sde2,
    title={{E}fficient and {A}ccurate {G}radients for {N}eural {SDE}s},
    author={Kidger, Patrick and Foster, James and Li, Xuechen and Lyons, Terry},
    booktitle = {Advances in Neural Information Processing Systems 34},
    year = {2021},
    publisher = {Curran Associates, Inc.},
}
```

This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/neural_sde.ipynb).

```python
from typing import Union

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
import tqdm  # https://github.com/tqdm/tqdm
```

LipSwish activation functions are a good choice for the discriminator of an SDE-GAN. (Their use here was introduced in the second reference above.)
For simplicity we will actually use LipSwish activations everywhere, even in the generator.

```python
def lipswish(x):
    return 0.909 * jnn.silu(x)
```
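The constant 0.909 is roughly the reciprocal of the maximum slope of SiLU (about 1.0998), so LipSwish is 1-Lipschitz. A quick numerical check of that slope, in plain JAX, evaluating the derivative on a dense grid:

```python
import jax
import jax.nn as jnn
import jax.numpy as jnp


def lipswish(x):
    return 0.909 * jnn.silu(x)


# Derivative of LipSwish evaluated on a dense grid of inputs.
xs = jnp.linspace(-20.0, 20.0, 40001)
dlipswish = jax.vmap(jax.grad(lipswish))(xs)
max_slope = float(jnp.max(jnp.abs(dlipswish)))
print(f"max |d/dx lipswish| on grid: {max_slope:.5f}")  # just under 1
```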

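Now set up the vector fields appearing on the right-hand side of each differential equation: a drift (`VectorField`), and a diffusion (`ControlledVectorField`) that returns a matrix to be applied to the control.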

```python
class VectorField(eqx.Module):
    scale: Union[int, jnp.ndarray]
    mlp: eqx.nn.MLP

    def __init__(self, hidden_size, width_size, depth, scale, *, key):
        super().__init__()
        scale_key, mlp_key = jrandom.split(key)
        if scale:
            self.scale = jrandom.uniform(
                scale_key, (hidden_size,), minval=0.9, maxval=1.1
            )
        else:
            self.scale = 1
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size + 1,
            out_size=hidden_size,
            width_size=width_size,
            depth=depth,
            activation=lipswish,
            final_activation=jnn.tanh,
            key=mlp_key,
        )

    @eqx.filter_jit
    def __call__(self, t, y, args):
        return self.scale * self.mlp(jnp.concatenate([t[None], y]))


class ControlledVectorField(eqx.Module):
    scale: Union[int, jnp.ndarray]
    mlp: eqx.nn.MLP
    control_size: int
    hidden_size: int

    def __init__(self, control_size, hidden_size, width_size, depth, scale, *, key):
        super().__init__()
        scale_key, mlp_key = jrandom.split(key)
        if scale:
            self.scale = jrandom.uniform(
                scale_key, (hidden_size, control_size), minval=0.9, maxval=1.1
            )
        else:
            self.scale = 1
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size + 1,
            out_size=hidden_size * control_size,
            width_size=width_size,
            depth=depth,
            activation=lipswish,
            final_activation=jnn.tanh,
            key=mlp_key,
        )
        self.control_size = control_size
        self.hidden_size = hidden_size

    @eqx.filter_jit
    def __call__(self, t, y, args):
        return self.scale * self.mlp(jnp.concatenate([t[None], y])).reshape(
            self.hidden_size, self.control_size
        )
```
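The drift returns a vector of shape `(hidden_size,)`, while the diffusion returns a matrix of shape `(hidden_size, control_size)` that is contracted against the increment of a `control_size`-dimensional control. The sketch below uses plain JAX, with hypothetical single-linear-layer functions standing in for the MLPs, to show those shapes and one explicit Euler step of "drift times dt plus diffusion matrix times dW". It illustrates the arithmetic only; it is not how Diffrax implements `ControlTerm` or `MultiTerm` internally.

```python
import jax.numpy as jnp
import jax.random as jrandom

hidden_size, control_size = 4, 3
k1, k2, k3, k4 = jrandom.split(jrandom.PRNGKey(0), 4)

# Hypothetical stand-ins for the two MLPs: one linear layer each, acting on (t, y).
W_f = jrandom.normal(k1, (hidden_size, hidden_size + 1))
W_g = jrandom.normal(k2, (hidden_size * control_size, hidden_size + 1))


def drift(t, y):
    # Same input convention as VectorField: concatenate time onto the state.
    return jnp.tanh(W_f @ jnp.concatenate([t[None], y]))


def diffusion(t, y):
    # Same convention as ControlledVectorField: flat output reshaped to a matrix.
    out = jnp.tanh(W_g @ jnp.concatenate([t[None], y]))
    return out.reshape(hidden_size, control_size)


t = jnp.array(0.0)
y = jrandom.normal(k3, (hidden_size,))
dt = 0.1
dW = jnp.sqrt(dt) * jrandom.normal(k4, (control_size,))  # Brownian increment

# One Euler(-Maruyama) step: drift * dt plus the diffusion matrix applied to dW.
y_next = y + drift(t, y) * dt + diffusion(t, y) @ dW
```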

Something pretty neat: for this toy example, the generator and the discriminator have essentially the same structure, so we describe both with a single `DifferentialEquation` class.

```python
class DifferentialEquation(eqx.Module):
    initial: eqx.nn.MLP
    vf: VectorField  # drift
    cvf: ControlledVectorField  # diffusion
    readout: eqx.nn.Linear

    def __init__(
        self,
        initial_size,
        control_size,
        hidden_size,
        width_size,
        depth,
        readout_size,
        scale,
        *,
        key,
    ):
        super().__init__()
        initial_key, vf_key, cvf_key, readout_key = jrandom.split(key, 4)

        self.initial = eqx.nn.MLP(
            initial_size, hidden_size, width_size, depth, key=initial_key
        )
        self.vf = VectorField(hidden_size, width_size, depth, scale, key=vf_key)
        self.cvf = ControlledVectorField(
            control_size, hidden_size, width_size, depth, scale, key=cvf_key
        )
        self.readout = eqx.nn.Linear(hidden_size, readout_size, key=readout_key)

    @eqx.filter_jit
    def _initial(self, ts, init):
        t0 = ts[0]
        t1 = ts[-1]
        y0 = self.initial(init)
        # We use a huge step size. By using a large step size we essentially "bake in"
        # the discretisation.
        # This is quite a standard thing to do when the vector field is a pure neural
        # network.
        # You can reduce the step size here if you want to -- which will increase the
        # computational cost, of course.
        dt0 = 1.0
        return t0, t1, y0, dt0

    @eqx.filter_jit
    def _readout(self, ys):
        return jax.vmap(self.readout)(ys)

    def __call__(self, ts, init, control, saveat):
        vf = diffrax.ODETerm(self.vf)  # Drift term
        cvf = diffrax.ControlTerm(self.cvf, control)  # Diffusion term
        term = diffrax.MultiTerm((vf, cvf))  # Combine both terms
        solver = diffrax.ReversibleHeun(term)  # Choice of solver.

        t0, t1, y0, dt0 = self._initial(ts, init)
        sol = diffrax.diffeqsolve(solver, t0, t1, y0, dt0, saveat=saveat)
        return self._readout(sol.ys)
```
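`DifferentialEquation` solves with `diffrax.ReversibleHeun`, the solver proposed in the second reference. As we read that paper, one step updates a pair $(y_n, \hat z_n)$, starting from $\hat z_0 = y_0$, as

$$\hat z_{n+1} = 2y_n - \hat z_n + f(t_n, \hat z_n)\,\Delta t + g(t_n, \hat z_n)\,\Delta W_n,$$
$$y_{n+1} = y_n + \tfrac12\big(f(t_{n+1}, \hat z_{n+1}) + f(t_n, \hat z_n)\big)\Delta t + \tfrac12\big(g(t_{n+1}, \hat z_{n+1}) + g(t_n, \hat z_n)\big)\Delta W_n.$$

Its key property is algebraic reversibility: $(y_n, \hat z_n)$ can be recomputed from $(y_{n+1}, \hat z_{n+1})$ without storing intermediate states, which is what makes backpropagation through the solve memory-efficient and accurate. Below is a minimal plain-JAX sketch under that reading, with a hypothetical scalar drift `f` and diffusion `g`; it is not Diffrax's implementation.

```python
import jax.numpy as jnp
import jax.random as jrandom


def f(t, z):
    # Hypothetical drift.
    return jnp.sin(t) - 0.5 * z


def g(t, z):
    # Hypothetical (scalar) diffusion.
    return 0.3 * jnp.cos(z)


def increment(t, z, dt, dw):
    return f(t, z) * dt + g(t, z) * dw


def forward_step(t, y, z_hat, dt, dw):
    # z_{n+1} = 2 y_n - z_n + f(t_n, z_n) dt + g(t_n, z_n) dW_n
    z_next = 2 * y - z_hat + increment(t, z_hat, dt, dw)
    # y_{n+1}: average of the increments evaluated at z_n and z_{n+1}.
    y_next = y + 0.5 * (increment(t + dt, z_next, dt, dw) + increment(t, z_hat, dt, dw))
    return y_next, z_next


def backward_step(t_next, y_next, z_next, dt, dw):
    # Invert forward_step algebraically, reusing the same Brownian increment.
    z_hat = 2 * y_next - z_next - increment(t_next, z_next, dt, dw)
    y = y_next - 0.5 * (
        increment(t_next, z_next, dt, dw) + increment(t_next - dt, z_hat, dt, dw)
    )
    return y, z_hat


dt, num_steps = 0.01, 100
dws = jnp.sqrt(dt) * jrandom.normal(jrandom.PRNGKey(0), (num_steps,))
y0 = jnp.array(1.0)

t, y, z = 0.0, y0, y0  # the auxiliary state starts equal to y0
for dw in dws:
    y, z = forward_step(t, y, z, dt, dw)
    t += dt
for dw in dws[::-1]:
    y, z = backward_step(t, y, z, dt, dw)
    t -= dt
# (y, z) now match the initial state up to floating-point rounding.
```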

Now set up the neural SDE (the generator) and the neural CDE (the discriminator).

```python
class NeuralSDE(eqx.Module):
    diffeq: DifferentialEquation
    initial_noise_size: int
    noise_size: int

    def __init__(
        self,
        data_size,
        initial_noise_size,
        noise_size,
        hidden_size,
        width_size,
        depth,
        *,
        key,
    ):
        super().__init__()
        self.diffeq = DifferentialEquation(
            initial_size=initial_noise_size,
            control_size=noise_size,
            hidden_size=hidden_size,
            width_size=width_size,
            depth=depth,
            readout_size=data_size,
            scale=True,  # Needed for a flexible generator.
            key=key,
        )
        self.initial_noise_size = initial_noise_size
        self.noise_size = noise_size

    @eqx.filter_jit
    def _setup(self, ts, key):
        init_key, bm_key = jrandom.split(key, 2)
        init = jrandom.normal(init_key, (self.initial_noise_size,))
        control = diffrax.UnsafeBrownianPath(shape=(self.noise_size,), key=bm_key)
        saveat = diffrax.SaveAt(ts=ts)  # Record output at all times.
        return ts, init, control, saveat

    def __call__(self, ts, *, key):
        return self.diffeq(*self._setup(ts, key))


class NeuralCDE(eqx.Module):
    diffeq: DifferentialEquation

    def __init__(self, data_size, hidden_size, width_size, depth, *, key):
        super().__init__()
        self.diffeq = DifferentialEquation(
            initial_size=data_size + 1,
            control_size=data_size,
            hidden_size=hidden_size,
            width_size=width_size,
            depth=depth,
            readout_size=1,
            # Want to constrain the Lipschitz norm of the discriminator. (See also the
            # `clip_weights` method below.)
            scale=False,
            key=key,
        )

    @staticmethod
    @eqx.filter_jit
    def _setup(ts, ys):
        # Interpolate data into a continuous path.
        ys = diffrax.linear_interpolation(
            ts, ys, replace_nans_at_start=0.0, fill_forward_nans_at_end=True
        )
        init = jnp.concatenate([ts[0, None], ys[0]])
        control = diffrax.LinearInterpolation(ts, ys)
        # Have the discriminator produce an output at both `t0` *and* `t1`.
        # The output at `t0` has only seen the initial point of a sample. This gives
        # additional supervision to the distribution learnt for the initial condition.
        # The output at `t1` has seen the entire path of a sample. This is needed to
        # actually learn the evolving trajectory.
        saveat = diffrax.SaveAt(t0=True, t1=True)
        return ts, init, control, saveat

    def __call__(self, ts, ys):
        return self.diffeq(*self._setup(ts, ys))

    @eqx.filter_jit
    def clip_weights(self):
        # Equinox modules are PyTrees, so we can flatten and unflatten them like any
        # other PyTree.
        leaves, treedef = jax.tree_util.tree_flatten(
            self, is_leaf=lambda x: isinstance(x, eqx.nn.Linear)
        )
        new_leaves = []
        for leaf in leaves:
            if isinstance(leaf, eqx.nn.Linear):
                lim = 1 / leaf.out_features
                leaf = eqx.tree_at(
                    lambda x: x.weight, leaf, leaf.weight.clip(-lim, lim)
                )
            new_leaves.append(leaf)
        return jax.tree_util.tree_unflatten(treedef, new_leaves)
```
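`NeuralCDE._setup` turns irregularly observed data, with NaNs marking missing values, into a continuous control. As we read the Diffrax documentation, `linear_interpolation` with these arguments replaces leading NaNs by `0.0`, linearly interpolates interior NaNs, and fills trailing NaNs forward with the last observed value. The following plain-JAX sketch reproduces those three rules for a single channel under that assumption. The helper `fill_nans` is hypothetical, runs eagerly (not under `jit`), and assumes each series has at least one observed value.

```python
import jax.numpy as jnp


def fill_nans(ts, ys, start_value=0.0):
    """Hypothetical helper: fill NaNs in a 1-D series ys observed at times ts."""
    ys = jnp.asarray(ys, dtype=jnp.float32)
    observed = ~jnp.isnan(ys)
    first = jnp.argmax(observed)  # index of the first observed value
    # Rule 1: NaNs before the first observation are replaced by start_value.
    ys = jnp.where(jnp.arange(ys.shape[0]) < first, start_value, ys)
    observed = ~jnp.isnan(ys)
    # Rule 2: interior NaNs are linearly interpolated between observations.
    # Rule 3: jnp.interp clamps to the last observed value beyond it, which
    # fills trailing NaNs forward.
    return jnp.interp(ts, ts[observed], ys[observed])


ts = jnp.arange(6.0)
ys = jnp.array([jnp.nan, 2.0, jnp.nan, 4.0, jnp.nan, jnp.nan])
print(fill_nans(ts, ys))  # [0. 2. 3. 4. 4. 4.]
```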

Note the `clip_weights` method on the CDE above: this is part of imposing the Lipschitz condition on the discriminator of a Wasserstein GAN.
(The other part is the use of the LipSwish activation functions we saw earlier.)
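One way to see why the limit is `1 / out_features`: if every entry of a weight matrix $W$ of shape `(out_features, in_features)` satisfies $|W_{ij}| \le 1/\text{out\_features}$, then every column has absolute sum at most 1. The operator norm induced by the $\ell^1$ norm, $\|W\|_1 = \max_j \sum_i |W_{ij}|$, is therefore at most 1, so the linear map is 1-Lipschitz with respect to $\ell^1$. A plain-JAX sketch on a hypothetical dictionary of weight matrices, using `jax.tree_util.tree_map` in place of the Equinox flatten/unflatten above:

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

# Hypothetical parameter PyTree: weight matrices stored as (out_features, in_features).
k1, k2 = jrandom.split(jrandom.PRNGKey(0))
params = {
    "layer1": jrandom.normal(k1, (16, 8)),
    "layer2": jrandom.normal(k2, (1, 16)),
}


def clip_weight(w):
    # Clip every entry to [-1/out_features, 1/out_features].
    lim = 1 / w.shape[0]
    return jnp.clip(w, -lim, lim)


clipped = jax.tree_util.tree_map(clip_weight, params)

# Induced l1 operator norm = largest absolute column sum; it is now at most 1.
for name, w in clipped.items():
    print(name, float(jnp.max(jnp.sum(jnp.abs(w), axis=0))))
```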

What follows is mostly standard GAN machinery: creating the dataset, training, plotting results, etc.
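For reference, `get_data` in the next cell simulates a time-dependent Ornstein-Uhlenbeck process on $t \in [0, t_1]$,

$$\mathrm{d}y(t) = \big(\mu t - \theta\, y(t)\big)\,\mathrm{d}t + \frac{2\sigma t}{t_1}\,\mathrm{d}W(t), \qquad y(0) \sim \mathcal{U}(-1, 1),$$

with $\mu = 0.02$, $\theta = 0.1$, $\sigma = 0.4$ and $t_1 = 63$, saved at 64 equally spaced times; each observation is then independently dropped (set to NaN) with probability 0.3. The `loss` function is the Wasserstein GAN objective: for generator $G$ and discriminator $D$,

$$L(G, D) = \mathbb{E}_{y \sim \text{data}}\big[D(y)\big] - \mathbb{E}_{z}\big[D(G(z))\big],$$

estimated by averaging over the batch and over the discriminator's two outputs at $t_0$ and $t_1$. The discriminator maximises $L$ and the generator minimises it.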

```python
def get_data(key):
    bm_key, y0_key, drop_key = jrandom.split(key, 3)

    mu = 0.02
    theta = 0.1
    sigma = 0.4

    t0 = 0
    t1 = 63
    t_size = 64

    def drift(t, y, args):
        return mu * t - theta * y

    def diffusion(t, y, args):
        return 2 * sigma * t / t1

    bm = diffrax.UnsafeBrownianPath(shape=(), key=bm_key)
    solver = diffrax.euler_maruyama(drift, diffusion, bm)
    y0 = jrandom.uniform(y0_key, (1,), minval=-1, maxval=1)
    dt0 = 0.1
    ts = jnp.linspace(t0, t1, t_size)
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(solver, t0, t1, y0, dt0, saveat=saveat)
    ys = sol.ys

    to_drop = jrandom.bernoulli(drop_key, 0.3, (t_size, 1))
    ys = jnp.where(to_drop, jnp.nan, ys)

    return ts, ys


def make_dataloader(arrays, batch_size, loop, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jrandom.permutation(key, indices)
        (key,) = jrandom.split(key, 1)
        start = 0
        end = batch_size
        # `<=` rather than `<`, so that the final full batch is not skipped.
        while end <= dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size
        if not loop:
            break


def loss(generator, discriminator, ts_i, ys_i, key, step=0):
    batch_size, _ = ts_i.shape
    key = jrandom.fold_in(key, step)
    key = jrandom.split(key, batch_size)
    fake_ys_i = jax.vmap(generator)(ts_i, key=key)
    real_score = jax.vmap(discriminator)(ts_i, ys_i)
    fake_score = jax.vmap(discriminator)(ts_i, fake_ys_i)
    return jnp.mean(real_score - fake_score)


@eqx.filter_grad
def grad_loss(g_d, ts_i, ys_i, key, step):
    generator, discriminator = g_d  # We differentiate just the first argument
    return loss(generator, discriminator, ts_i, ys_i, key, step)


# This is one very helpful trick. The distribution learnt for the initial condition of
# the SDE can sometimes train quite poorly.
# Increasing its learning rate (here by a factor of 10) seems to help with this.
def _increase_update_initial(updates):
    get_initial_leaves = lambda u: jax.tree_util.tree_leaves(u.diffeq.initial)
    mul = lambda x: x * 10
    return eqx.tree_at(where=get_initial_leaves, pytree=updates, replace_fn=mul)


@eqx.filter_jit
def update(
    generator, discriminator, g_opt_state, d_opt_state, g_optim, d_optim, g_grad, d_grad
):
    g_updates, g_opt_state = g_optim.update(g_grad, g_opt_state)
    d_updates, d_opt_state = d_optim.update(d_grad, d_opt_state)
    g_updates = _increase_update_initial(g_updates)
    d_updates = _increase_update_initial(d_updates)
    generator = eqx.apply_updates(generator, g_updates)
    discriminator = eqx.apply_updates(discriminator, d_updates)
    discriminator = discriminator.clip_weights()
    return generator, discriminator, g_opt_state, d_opt_state
```
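`_increase_update_initial` multiplies only the updates belonging to the `initial` network by 10 and leaves every other update untouched, which for a given optimiser step amounts to a 10x larger learning rate for that sub-network. The same idea in plain JAX, on a hypothetical nested-dict PyTree standing in for the model's updates (the key names below are illustrative, not Equinox's internal layout):

```python
import jax
import jax.numpy as jnp

# Hypothetical updates PyTree mirroring the model layout: diffeq -> {initial, vf}.
updates = {
    "diffeq": {
        "initial": {"w": jnp.ones((2, 2)), "b": jnp.ones(2)},
        "vf": {"w": jnp.ones((2, 2))},
    }
}


def scale_initial_updates(updates, factor=10.0):
    # Scale only the leaves under diffeq["initial"]; rebuild the rest unchanged.
    scaled_initial = jax.tree_util.tree_map(
        lambda u: u * factor, updates["diffeq"]["initial"]
    )
    return {**updates, "diffeq": {**updates["diffeq"], "initial": scaled_initial}}


new_updates = scale_initial_updates(updates)
```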

This is our main entry point. Try running `main()`.

```python
def main(
    initial_noise_size=5,
    noise_size=3,
    hidden_size=16,
    width_size=16,
    depth=1,
    generator_lr=2e-5,
    discriminator_lr=1e-4,
    batch_size=1024,
    steps=10000,
    steps_per_print=10,
    dataset_size=8192,
    seed=5678,
):

    key = jrandom.PRNGKey(seed)
    (
        data_key,
        generator_key,
        discriminator_key,
        dataloader_key,
        train_key,
        evaluate_key,
        sample_key,
    ) = jrandom.split(key, 7)
    data_key = jrandom.split(data_key, dataset_size)

    ts, ys = jax.vmap(get_data)(data_key)
    _, _, data_size = ys.shape

    generator = NeuralSDE(
        data_size,
        initial_noise_size,
        noise_size,
        hidden_size,
        width_size,
        depth,
        key=generator_key,
    )
    discriminator = NeuralCDE(
        data_size, hidden_size, width_size, depth, key=discriminator_key
    )

    g_optim = optax.rmsprop(generator_lr)
    d_optim = optax.rmsprop(-discriminator_lr)
    g_opt_state = g_optim.init(eqx.filter(generator, eqx.is_array))
    d_opt_state = d_optim.init(eqx.filter(discriminator, eqx.is_array))

    trange = tqdm.tqdm(range(steps))
    infinite_dataloader = make_dataloader(
        (ts, ys), batch_size, loop=True, key=dataloader_key
    )

    for step, (ts_i, ys_i) in zip(trange, infinite_dataloader):
        # Use the dedicated training key rather than reusing the parent `key`.
        g_grad, d_grad = grad_loss(
            (generator, discriminator), ts_i, ys_i, train_key, step
        )
        generator, discriminator, g_opt_state, d_opt_state = update(
            generator,
            discriminator,
            g_opt_state,
            d_opt_state,
            g_optim,
            d_optim,
            g_grad,
            d_grad,
        )

        if (step % steps_per_print) == 0 or step == steps - 1:
            total_score = 0
            num_batches = 0
            for ts_i, ys_i in make_dataloader(
                (ts, ys), batch_size, loop=False, key=evaluate_key
            ):
                score = loss(generator, discriminator, ts_i, ys_i, sample_key)
                total_score += score.item()
                num_batches += 1
            trange.write(f"Step: {step}, Loss: {total_score / num_batches}")

    # Plot samples
    fig, ax = plt.subplots()
    num_samples = min(50, dataset_size)
    ts_to_plot = ts[:num_samples]
    ys_to_plot = ys[:num_samples]

    def _interp(ti, yi):
        return diffrax.linear_interpolation(
            ti, yi, replace_nans_at_start=0.0, fill_forward_nans_at_end=True
        )

    ys_to_plot = jax.vmap(_interp)(ts_to_plot, ys_to_plot)[..., 0]
    ys_sampled = jax.vmap(generator)(
        ts_to_plot, key=jrandom.split(sample_key, num_samples)
    )[..., 0]
    kwargs = dict(label="Real")
    for ti, yi in zip(ts_to_plot, ys_to_plot):
        ax.plot(ti, yi, c="dodgerblue", linewidth=0.5, alpha=0.7, **kwargs)
        kwargs = {}
    kwargs = dict(label="Generated")
    for ti, yi in zip(ts_to_plot, ys_sampled):
        ax.plot(ti, yi, c="crimson", linewidth=0.5, alpha=0.7, **kwargs)
        kwargs = {}
    ax.set_title(f"{num_samples} samples from both real and generated distributions.")
    fig.legend()
    fig.tight_layout()
    fig.savefig("neural_sde.png")
    plt.show()
```
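In `main`, the discriminator's optimiser is built with a negative learning rate, `optax.rmsprop(-discriminator_lr)`. Optax optimisers produce updates that descend on the gradient they are given, so flipping the sign of the learning rate turns descent into ascent: the discriminator maximises the same loss that the generator minimises. The toy sketch below shows the sign convention with plain gradient steps (not Optax) on a hypothetical concave objective.

```python
import jax


def objective(x):
    # Hypothetical concave objective with its maximum at x = 3.
    return -((x - 3.0) ** 2)


def sgd_step(x, lr):
    # Standard descent convention: x <- x - lr * grad.
    return x - lr * jax.grad(objective)(x)


x_asc = 0.0
for _ in range(100):
    x_asc = sgd_step(x_asc, lr=-0.1)  # negative learning rate: ascends the objective

x_desc = 0.0
for _ in range(10):
    x_desc = sgd_step(x_desc, lr=0.1)  # positive learning rate: descends the objective

print(float(x_asc))  # close to 3.0, the maximiser
```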