In this notebook we illustrate some methods for estimating the conditional moments of an SDE.

Note that this notebook uses `jax`.

In [7]:
import math
import jax
import jax.numpy as jnp
import tme.base_jax as tme
from functools import partial
import matplotlib.pyplot as plt

# Enable 64-bit (double-precision) floats in JAX
jax.config.update("jax_enable_x64", True)

Define the SDE coefficients. This is the Beneš SDE $\mathrm{d}X(t) = \tanh(X(t))\,\mathrm{d}t + b\,\mathrm{d}W(t)$ with $b = 1$. We are interested in computing $\mathbb{E}[(X(t))^n \mid X(0) = x_0]$ for some $n$; in this example we let $x_0 = 0$ and $t = 1$.

In [8]:
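b = 1.  # dispersion constant; b = 1 gives the standard Beneš SDE


def drift(x):
    # Drift f(x) = tanh(x)
    return jnp.tanh(x)


def dispersion(x):
    # Constant dispersion b, returned as a length-1 array for the 1-D state
    return jnp.array([b])


x0 = 0.  # initial state X(0)
t = 1.   # time at which the moments are evaluated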

This SDE is analytically solvable: its transition density is known in closed form, and so are its raw moments.
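The SDE defined above, $\mathrm{d}X = \tanh(X)\,\mathrm{d}t + b\,\mathrm{d}W$, is the Beneš SDE. For $b = 1$ its transition density is the standard closed-form result $p(x, t \mid x_0) = \frac{\cosh x}{\cosh x_0}\, e^{-t/2}\, \mathrm{N}(x \mid x_0, t)$. Writing $\cosh x = (e^{x} + e^{-x})/2$ and completing the square turns the density into a two-component Gaussian mixture, so each raw moment is $\mathbb{E}[X(t)^n \mid X(0) = x_0] = \frac{e^{x_0} M_n(x_0 + t, t) + e^{-x_0} M_n(x_0 - t, t)}{2\cosh x_0}$, where $M_n(\mu, \sigma^2)$ is the $n$-th raw moment of $\mathrm{N}(\mu, \sigma^2)$. The sketch below evaluates this formula in plain Python, assuming $b = 1$; the function names `double_factorial`, `gaussian_raw_moment` and `benes_raw_moment` are illustrative and not part of `tme`.

```python
import math


def double_factorial(k):
    """Return k!! for k >= -1, with the convention (-1)!! = 0!! = 1."""
    result = 1
    while k > 1:
        result *= k
        k -= 2
    return result


def gaussian_raw_moment(n, mean, var):
    """n-th raw moment E[Y^n] of Y ~ N(mean, var)."""
    # Binomial expansion around the mean; only even central moments
    # E[(Y - mean)^k] = var^(k/2) * (k-1)!! are non-zero.
    return sum(math.comb(n, k) * mean ** (n - k) * var ** (k // 2) * double_factorial(k - 1)
               for k in range(0, n + 1, 2))


def benes_raw_moment(n, x0, t):
    """Exact E[X(t)^n | X(0) = x0] for dX = tanh(X) dt + dW (Benes SDE, b = 1)."""
    # Mixture of N(x0 + t, t) and N(x0 - t, t) with weights proportional to exp(+x0), exp(-x0)
    numerator = (math.exp(x0) * gaussian_raw_moment(n, x0 + t, t)
                 + math.exp(-x0) * gaussian_raw_moment(n, x0 - t, t))
    return numerator / (2.0 * math.cosh(x0))


# Exact raw moments of orders 1..9 for the notebook's setting x0 = 0, t = 1
exact = [benes_raw_moment(n, 0.0, 1.0) for n in range(1, 10)]
```

For $x_0 = 0$ and $t = 1$ the density is symmetric, so every odd moment is $0$, while orders $2, 4, 6, 8$ give $2, 10, 76, 764$. These exact values are a convenient reference for judging the TME approximations below.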

In [9]:
@partial(jax.vmap, in_axes=[0])
def tme_raw_moment(order):
    # Third-order TME estimate of E[X(t)^order | X(0) = x0], vectorised over `order`
    return tme.expectation(lambda u: u ** order, jnp.array([x0]), t,
                           drift=drift, dispersion=dispersion, order=3)
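To make the role of `order=3` concrete, here is a minimal from-scratch sketch of the Taylor moment expansion (TME) for a scalar SDE $\mathrm{d}X = f(X)\,\mathrm{d}t + b\,\mathrm{d}W$: $\mathbb{E}[\phi(X(t)) \mid X(0) = x_0] \approx \sum_{r=0}^{M} \frac{t^r}{r!} (\mathcal{A}^r \phi)(x_0)$, with generator $\mathcal{A}\phi = f\,\phi' + \tfrac{1}{2} b^2 \phi''$. It builds the iterated generator with nested `jax.grad` calls. The helpers `generator` and `tme_scalar` are hypothetical names; this is not the `tme` package's implementation, only an illustration of the same expansion restricted to one dimension.

```python
import math

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def generator(phi, drift, b):
    """Return the function A phi = drift * phi' + 0.5 * b^2 * phi'' (scalar SDE)."""
    dphi = jax.grad(phi)
    d2phi = jax.grad(dphi)
    return lambda x: drift(x) * dphi(x) + 0.5 * b ** 2 * d2phi(x)


def tme_scalar(phi, x0, t, drift, b, order):
    """Truncated TME: sum_{r=0}^{order} t^r / r! * (A^r phi)(x0)."""
    total = 0.0
    a_r_phi = phi  # holds A^r phi, starting from r = 0
    for r in range(order + 1):
        total = total + t ** r / math.factorial(r) * a_r_phi(x0)
        a_r_phi = generator(a_r_phi, drift, b)
    return total


# Third-order TME of E[X(1)^2 | X(0) = 0] for the Benes SDE with b = 1
m2 = tme_scalar(lambda u: u ** 2, 0.0, 1.0, jnp.tanh, 1.0, order=3)
```

Here `m2` evaluates to 2 up to floating-point round-off. Nesting `jax.grad` is simple but recomputes shared derivatives, so its cost grows quickly with the expansion order; it is meant for reading, not performance.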

In [10]:
# Moment orders of interest
moment_orders = jnp.arange(1, 10)

# Approximate the first three raw moments using third-order TME
tme_raw_moment(jnp.array([1, 2, 3]))

DeviceArray([[nan],
             [nan],
             [nan]], dtype=float64)
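The `nan` values most likely come from the exponent rather than from TME itself. Under `jax.vmap`, `order` is a traced array, so `u ** order` cannot be lowered to JAX's exact `lax.integer_pow` and falls back to the general floating-point `lax.pow`. TME evaluates high-order derivatives of $\phi(u) = u^n$ at $u = x_0 = 0$, and for the general power those derivatives can contain $0 \cdot \infty$ terms (this depends on the JAX version). The non-vmapped call with a concrete Python `int` exponent returns a finite value. A simple workaround is to loop over concrete Python integers instead of vmapping over an array of orders. The sketch below does not reproduce the `nan`; it only shows the loop pattern on plain derivatives at $0$, and `nth_derivative` is a hypothetical helper.

```python
import jax


def nth_derivative(f, k):
    """Return the k-th derivative of a scalar function f via nested jax.grad."""
    for _ in range(k):
        f = jax.grad(f)
    return f


# Each exponent n is a concrete Python int, so `u ** n` uses lax.integer_pow,
# whose derivatives are exact polynomials and stay finite at u = 0.
# The default argument n=n binds the current loop value into each lambda.
third_derivs_at_zero = [float(nth_derivative(lambda u, n=n: u ** n, 3)(0.0))
                        for n in range(1, 10)]
```

The same pattern applied to the notebook would be a list comprehension such as `[tme.expectation(lambda u, n=n: u ** n, jnp.array([x0]), t, drift=drift, dispersion=dispersion, order=3) for n in moment_orders.tolist()]`, trading vectorisation for one trace per order.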

In [11]:
# Third-order TME estimate of E[X(t)^2 | X(0) = x0] with a concrete integer exponent
z = tme.expectation(lambda u: u ** 2, jnp.array([x0]), t,
                    drift=drift, dispersion=dispersion, order=3)[0]

In [12]:
z

DeviceArray(2., dtype=float64)
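This value can be checked by hand. With $b = 1$ and $\phi(x) = x^2$, applying the generator $\mathcal{A}\phi = \tanh(x)\,\phi' + \tfrac{1}{2}\phi''$ of the Beneš SDE twice already gives a constant, so every TME term of order three or higher vanishes and the third-order expansion of the second moment is exact:

$$
\begin{aligned}
\mathcal{A}\phi(x) &= 2x\tanh(x) + 1, \\
\mathcal{A}^2\phi(x) &= \tanh(x)\bigl(2\tanh(x) + 2x\,\mathrm{sech}^2(x)\bigr) + 2\,\mathrm{sech}^2(x) - 2x\,\mathrm{sech}^2(x)\tanh(x) \\
&= 2\tanh^2(x) + 2\,\mathrm{sech}^2(x) = 2, \\
\mathcal{A}^3\phi(x) &= 0, \\
\mathbb{E}[X(t)^2 \mid X(0) = x_0] &= x_0^2 + t\bigl(2x_0\tanh(x_0) + 1\bigr) + t^2 .
\end{aligned}
$$

At $x_0 = 0$ and $t = 1$ this gives $0 + 1 + 1 = 2$, which agrees with the output above and with the closed-form second moment of the Beneš SDE.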