# Parameter estimation

In this notebook, we see how to use the moment filter to estimate model parameters. The model that we use here is

$$
\begin{equation}
    \begin{split}
        \mathrm{d} X(t) &= X(t) \, \bigl( 1 - \theta_1 \, X(t)^2 \bigr) \mathrm{d} t + \mathrm{d} W(t), \\
        X(0) &\sim \frac{1}{2}\bigl( \mathcal{N}(-0.5, 0.05) + \mathcal{N}(0.5, 0.05)\bigr), \\
        Y_k \mid X_k &\sim \mathrm{Poisson}\Bigl( \log\bigl(1 + \exp(\theta_2 \, X_k)\bigr) \Bigr),
    \end{split}
\end{equation}
$$

and we aim to estimate $\theta_1$ and $\theta_2$ from the measurements. We set the true values of both parameters to 3.
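
To make the measurement model concrete, the following is a minimal plain-Python sketch (standard library only, separate from the notebook's JAX code) of the Poisson rate $\log(1 + \exp(\theta_2 \, x))$ and the resulting probability mass function. The helper names `poisson_rate` and `poisson_pmf` are illustrative assumptions, not part of any library used here.

```python
import math


def poisson_rate(x, theta2):
    """Rate log(1 + exp(theta2 * x)) of the measurement model; always positive."""
    return math.log1p(math.exp(theta2 * x))


def poisson_pmf(y, rate):
    """P(Y = y) for Y ~ Poisson(rate), evaluated through the log-pmf."""
    return math.exp(y * math.log(rate) - rate - math.lgamma(y + 1))


theta2 = 3.0  # the true value chosen in the text
for x in (-0.5, 0.0, 0.5):
    rate = poisson_rate(x, theta2)
    print(f"x = {x:+.1f}: rate = {rate:.4f}, P(Y = 0) = {poisson_pmf(0, rate):.4f}")
```

The rate increases with the state, so larger counts point to larger values of $X_k$, and $\theta_2$ controls how sharply the rate reacts to the state.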

## Note
This notebook assumes that you have already taken a look at `./benes_bernoulli.ipynb` to get familiar with the moment filter.


In [1]:
import jax
import jax.numpy as jnp
import jaxopt
import tme.base_jax as tme
from mfs.one_dim.filtering import moment_filter_cms
from mfs.utils import GaussianSum1D, simulate_sde
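from functools import partial

# Using `float64` is often necessary.
# `jax.config.update` is used directly because `from jax.config import config`
# is no longer available in recent JAX releases.
jax.config.update("jax_enable_x64", True)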

# Random seed
key = jax.random.PRNGKey(123)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


# Define the model

Define the model and simulate it to generate data.

In [2]:
true_theta1, true_theta2 = 3., 3.


def drift(x, p):
    """Drift function
    """
    return x * (1 - p * x ** 2)


def dispersion(_):
    """Dispersion function
    """
    return 1.


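def emission(x, p):
    """Poisson rate log(1 + exp(p * x)) of the measurement model
    """
    return jnp.log(1. + jnp.exp(p * x))


def measurement_cond_pmf(y, x, p):
    """Probability mass of the measurement y given the state x
    """
    return jax.scipy.stats.poisson.pmf(y, emission(x, p))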


# The order N gives 2 * N - 1 moments. The higher the order, the more accurate the filter.
N = 5

# Initial condition
init_cond = GaussianSum1D.new(means=jnp.array([-0.5, 0.5]),
                              variances=jnp.array([0.05, 0.05]),
                              weights=jnp.array([0.5, 0.5]),
                              N=N)

# Random keys for simulation
key_x0, key_xs, key_ys = jax.random.split(key, 3)

# Times
dt = 1e-2
T = 1000
ts = jnp.linspace(dt, dt * T, T)

# Simulate an initial state, a trajectory, and measurements
x0 = init_cond.sampler(key_x0, 1)[0]
xs = simulate_sde(
    lambda _x, _dt: tme.mean_and_cov(jnp.atleast_1d(_x), _dt, lambda _x: drift(_x, true_theta1), dispersion, order=3),
    x0, dt, T, key_xs, diagonal_cov=False, integration_steps=100)[:, 0]
ys = jax.random.poisson(key_ys, emission(xs, true_theta2), (T,))
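
For intuition about what the simulated trajectory looks like, here is a minimal sketch that discretises the same SDE with the plain Euler-Maruyama scheme. This is a cruder, swapped-in technique, not the TME-based `simulate_sde` call used above, and it uses only the standard library. The function name `euler_maruyama`, the seed, and the starting point `x0=0.5` are illustrative assumptions.

```python
import math
import random


def euler_maruyama(x0, theta1, dt, num_steps, rng):
    """Simulate dX = X (1 - theta1 X^2) dt + dW with the Euler-Maruyama scheme."""
    xs = [x0]
    for _ in range(num_steps):
        x = xs[-1]
        drift = x * (1.0 - theta1 * x ** 2)
        noise = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment ~ N(0, dt)
        xs.append(x + drift * dt + noise)
    return xs


rng = random.Random(0)  # hypothetical seed, unrelated to the notebook's PRNG key
path = euler_maruyama(x0=0.5, theta1=3.0, dt=1e-2, num_steps=1000, rng=rng)
print(f"final state: {path[-1]:.4f}")
```

The drift vanishes at $x = \pm 1 / \sqrt{\theta_1}$ and pulls the state towards those two points, so $\theta_1$ determines where the trajectory tends to sit.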

# Objective function

This is the key step. We need an objective function that takes the unknown parameters and returns the negative log-likelihood.

The last value returned by the moment filter function is `nell`, the negative log-likelihood.
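
For reference, the exact negative log-likelihood of a state-space model factorises over the measurements as written below, with $\theta = (\theta_1, \theta_2)$. Reading `nell` as the moment filter's approximation of this sum is an assumption about the library, not something stated in this notebook.

$$
-\log p(y_{1:T} \mid \theta) = -\sum_{k=1}^{T} \log p(y_k \mid y_{1:k-1}, \theta),
\qquad
p(y_k \mid y_{1:k-1}, \theta) = \int p(y_k \mid x_k, \theta_2) \, p(x_k \mid y_{1:k-1}, \theta) \, \mathrm{d} x_k.
$$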

In [3]:
# The objective function
@jax.jit
def obj_func(params, _ys):
    # Map the unconstrained parameters to positive values with the softplus bijection
    params = jnp.log(jnp.exp(params) + 1.)

    def _drift(x):
        return drift(x, params[0])

    def _measurement_cond_pmf(y, x):
        return measurement_cond_pmf(y, x, params[1])

    @partial(jax.vmap, in_axes=[0, None, None])
    @partial(jax.vmap, in_axes=[None, 0, None])
    def state_cond_central_moments(x, n, mean):
        def phi(u):
            return (u - mean) ** n

        return jnp.squeeze(tme.expectation(phi, jnp.atleast_1d(x), dt, _drift, dispersion, order=3))

    @partial(jax.vmap, in_axes=[0])
    def state_cond_mean(x):
        return jnp.squeeze(tme.expectation(lambda u: u, jnp.atleast_1d(x), dt, _drift, dispersion, order=3))

    _, _, nell = moment_filter_cms(state_cond_central_moments, state_cond_mean, _measurement_cond_pmf,
                                   init_cond.cms, init_cond.mean, _ys)
    return nell

Use L-BFGS-B for the optimisation, starting from an initial value of 0.1 for both parameters.

In [4]:
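# Run optimisation. The initial values 0.1 are mapped to the unconstrained space
# by the inverse of the bijection used in `obj_func`.
init_params = jnp.log(jnp.exp(jnp.array([0.1, 0.1])) - 1.)
opt_solver = jaxopt.ScipyMinimize(method='L-BFGS-B', jit=True, fun=obj_func)
opt_params, opt_state = opt_solver.run(init_params, ys)
# Map the optimum back to the positive parameter space
opt_params = jnp.log(jnp.exp(opt_params) + 1.)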
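
The two transforms around the optimiser are the softplus map $u \mapsto \log(1 + e^u)$ used in `obj_func` and its inverse $p \mapsto \log(e^p - 1)$ used for `init_params`. The sketch below checks the round trip in plain Python with the standard library; the helper names `softplus` and `inverse_softplus` are illustrative assumptions, and the numerically stable forms are a choice made here rather than what the notebook code uses.

```python
import math


def softplus(u):
    """Map an unconstrained real u to the positive value log(1 + exp(u))."""
    return max(u, 0.0) + math.log1p(math.exp(-abs(u)))


def inverse_softplus(p):
    """Map a positive value p back to the unconstrained u = log(exp(p) - 1)."""
    return p + math.log(-math.expm1(-p))


u0 = inverse_softplus(0.1)  # unconstrained starting point for a parameter value of 0.1
print(u0, softplus(u0))
```

Because the optimiser only ever sees the unconstrained values, it can take unrestricted steps while the model always receives positive parameters.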

In [5]:
print(opt_state)
print(f'Learnt parameters: {opt_params}')

ScipyMinimizeInfo(fun_val=Array(1354.68233404, dtype=float64, weak_type=True), success=True, status=0, iter_num=16)
Learnt parameters: [2.55806413 3.35835942]
