In [None]:
#| default_exp estimation

# Maximum Likelihood estimation

In [None]:
#| export

import jax.numpy as jnp
import jax.random as jrn
from jax import vmap

In [None]:
#| hide
from isssm.kalman import kalman
from isssm.glssm_models import lcm
from isssm.glssm import simulate_glssm
import tensorflow_probability.substrates.jax.distributions as tfd
import fastcore.test as fct

  _warn(("h5py is running against HDF5 {0} when it was built against {1}, "


## Gaussian linear models

For Gaussian linear state space models we can evaluate the likelihood analytically with a single pass of the Kalman filter.
Based on the predictions $\hat Y_{t | t - 1}$ and associated covariance matrices $\Psi_{t | t - 1}$ for $t = 0, \dots, n$ produced by the Kalman filter, we can compute the Gaussian negative log-likelihood: it is the negative sum over $t$ of the log-densities of the Gaussian distributions with these means and covariance matrices, evaluated at the observations $Y_t$.

In [None]:
# | export
from jaxtyping import Float, Array

# batched matmul: maps `jnp.matmul` over the leading axis of both operands
vmm = vmap(jnp.matmul, in_axes=(0, 0))


def gnll(
    y: Float[Array, "n+1 p"],  # observations $y_t$
    x_pred: Float[Array, "n+1 m"],  # predicted states $\hat X_{t+1\bar t}$
    Xi_pred: Float[Array, "n+1 m m"],  # predicted state covariances $\Xi_{t+1\bar t}$
    B: Float[Array, "n+1 p m"],  # state observation matrices $B_{t}$
    Omega: Float[Array, "n+1 p p"],  # observation covariances $\Omega_{t}$
) -> Float:  # gaussian negative log-likelihood
    """Gaussian negative log-likelihood

    For every time step the predicted observation mean `B @ x_pred` and
    covariance `B @ Xi_pred @ B.T + Omega` are formed; the result is the
    negative sum over time of the Gaussian log-densities of `y` under these
    moments.
    """
    # Imported at function scope: the notebook-level `tfd` import sits in a
    # `#| hide` cell, so it is not available in the exported module.
    import tensorflow_probability.substrates.jax.distributions as tfd

    # predicted observation means, one per time step
    y_pred = vmm(B, x_pred)
    # predicted observation covariances B Xi B^T + Omega, one per time step
    Psi_pred = vmm(vmm(B, Xi_pred), jnp.transpose(B, (0, 2, 1))) + Omega

    return -tfd.MultivariateNormalFullCovariance(y_pred, Psi_pred).log_prob(y).sum()

In [None]:
#| hide

x0, A, B, Sigma, Omega = lcm(1, 0., 1., 1., 1.)
_, (y,) = simulate_glssm(x0, A, B, Sigma, Omega, 1, jrn.PRNGKey(34234))

x_filt, Xi_filt, x_pred, Xi_pred = kalman(y, x0, Sigma, Omega, A, B)
nll = gnll(y, x_pred, Xi_pred, B, Omega)


# reference: joint Gaussian density of the stacked observations
EY = jnp.zeros((2,))
CovY = jnp.array([[2., 1.], [1., 3.]])

# The two values are computed along different numerical routes (sum of
# per-step densities vs. one joint density), so compare with a tolerance
# instead of requiring bitwise-equal floats.
fct.test_close(
    nll,
    -tfd.MultivariateNormalFullCovariance(EY, CovY).log_prob(y.reshape(-1)),
    eps=1e-4,
)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## MLE in GLSSMs

For a parametrized GLSSM, that is, a model that depends on parameters $\theta$, we can use numerical optimization to find the maximum likelihood estimator.

::: {.callout-caution}
With this method, the user has to take care to provide an unconstrained parametrization, e.g. by using $\log$ transformations for positive parameters.
:::

In [None]:
# | export
from jax.scipy.optimize import minimize, OptimizeResults

def mle_glssm(
    y: Float[Array, "n+1 p"],  # observations $y_t$
    model,  # parameterized GLSSM
    theta0: Float[Array, "k"],  # initial parameter guess
    aux,  # auxiliary data for the model
) -> OptimizeResults:  # result of MLE optimization
    """Maximum likelihood estimation for GLSSM

    `model(theta, aux)` has to return the tuple `(x0, A, B, Sigma, Omega)`.
    The Gaussian negative log-likelihood of `y` under that model is minimized
    over `theta` with BFGS, starting from `theta0`.
    """
    # Imported at function scope: the notebook-level `kalman` import sits in a
    # `#| hide` cell, so it is not available in the exported module.
    from isssm.kalman import kalman

    def f(theta: Float[Array, "k"]) -> Float:
        # negative log-likelihood of `y` as a function of the parameters
        x0, A, B, Sigma, Omega = model(theta, aux)
        _, _, x_pred, Xi_pred = kalman(y, x0, Sigma, Omega, A, B)
        return gnll(y, x_pred, Xi_pred, B, Omega)

    return minimize(f, theta0, method="BFGS")

In [None]:
def parameterized_lcm(theta, aux):
    """Build the `lcm` model from log-variances `theta` and fixed `aux = (n, x0, s2_x0)`."""
    n, x0, s2_x0 = aux
    log_var_eps, log_var_eta = theta

    # exponentiate so that `theta` itself may range over the whole real line
    s2_eps = jnp.exp(log_var_eps)
    s2_eta = jnp.exp(log_var_eta)

    return lcm(n, x0, s2_x0, s2_eps, s2_eta)
    
# true parameters: variances 2 and 3, expressed on the log scale
true_variances = jnp.array([2., 3.])
theta = jnp.log(true_variances)
aux = (100, 0., 1.)  # (n, x0, s2_x0)

# simulate observations from the model at the true parameters
x0, A, B, Sigma, Omega = parameterized_lcm(theta, aux)
key = jrn.PRNGKey(15435324)
_, (y,) = simulate_glssm(x0, A, B, Sigma, Omega, 1, key)

# run the optimizer from an all-ones starting value
theta0 = jnp.ones(2)
result = mle_glssm(y, parameterized_lcm, theta0, aux)

# estimated vs. true parameters, back on the original scale
jnp.exp(result.x), jnp.exp(theta)

(Array([1.3734169, 3.719784 ], dtype=float32), Array([2., 3.], dtype=float32))

## Asymptotic behavior of MLE

::: {.callout-note}
# TODO
Write / Implement
:::

In [None]:
#| hide
import nbdev

# run nbdev's notebook export
nbdev.nbdev_export()