In [None]:
#| default_exp glssm_models
# (the nbdev directive above must stay the first line of this cell: it names the target module for `#| export` cells)
# reload edited library modules automatically before each cell runs, so changes are picked up without a kernel restart
%load_ext autoreload
%autoreload 2

In [None]:
# imports for this notebook only
import fastcore.test as fct

# Example Gaussian models

## Locally constant model

The locally constant model (also known as the local level or random walk plus noise model) is a very basic example of a Gaussian state space model. It is a univariate model with the following dynamics:

$$
\begin{align*}
    X_{t + 1} &= X_t + \varepsilon_{t + 1} & & \varepsilon_{t + 1} \sim \mathcal N(0, \sigma^2_\varepsilon) \\
    Y_t &= X_t + \eta_t && \eta_{t} \sim \mathcal N(0, \sigma^2_\eta)
\end{align*}
$$

In [None]:
# | export
import jax.numpy as jnp
def lcm(n, x0, s2_x0, s2_eps, s2_eta):
    """Build the locally constant model as a univariate Gaussian state space model.

    Parameters
    ----------
    n : int
        Number of state transitions; the model covers `n + 1` time points.
    x0 : scalar or array-like with a single element
        Initial state, returned as an array of shape `(1,)`.
    s2_x0 : float
        Variance stored in `Sigma[0]` (the variance attached to the initial state).
    s2_eps : float
        State innovation variance, stored in `Sigma[1:]`.
    s2_eta : float
        Observation noise variance, stored in every entry of `Omega`.

    Returns
    -------
    tuple
        `(x0, A, B, Sigma, Omega)` with shapes `(1,)`, `(n, 1, 1)`, `(n + 1, 1, 1)`,
        `(n + 1, 1, 1)` and `(n + 1, 1, 1)`; `A` and `B` contain only ones.
    """

    def stack_of_ones(length):
        # `length` copies of the 1x1 matrix [[1.]]
        return jnp.ones((length, 1, 1))

    x0_vec = jnp.array(x0).reshape((1,))

    # random-walk transition and direct observation of the state: every coefficient is 1
    transition = stack_of_ones(n)
    design = stack_of_ones(n + 1)

    # first slot holds the initial variance, the remaining n slots the innovation variance
    state_cov = jnp.concatenate((s2_x0 * stack_of_ones(1), s2_eps * stack_of_ones(n)))
    obs_cov = stack_of_ones(n + 1) * s2_eta

    return x0_vec, transition, design, state_cov, obs_cov

In [None]:
n = 10
x0, A, B, Sigma, Omega = lcm(n, 0., 1., 1., 1.)

# every returned array must have the shape implied by n transitions / n + 1 time points
fct.test_eq(
    [x0.shape, Sigma.shape, Omega.shape, A.shape, B.shape],
    [(1,), (n + 1, 1, 1), (n + 1, 1, 1), (n, 1, 1), (n + 1, 1, 1)],
)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## Common factor locally constant model
A model with a three-dimensional state and two observations: each observation is the sum of its own state and a third state that is shared by both observations.

$$
\begin{align*}
    X_{t + 1} &= X_t + \varepsilon_{t + 1} && \varepsilon_{t + 1} \sim \mathcal N(0, \sigma^2_\varepsilon I_3) \\
    Y_t &= \begin{pmatrix} 1 & 0 & 1 \\ 0 & 1 & 1 \end{pmatrix} X_t + \eta_t && \eta_t \sim \mathcal N(0, \sigma^2_\eta I_2)
\end{align*}
$$

In [None]:
#| export
def common_factor_lcm(n, x0, Sigma0, s2_eps, s2_eta):
    """Build the common factor locally constant model.

    The state is three-dimensional and follows a random walk (`A` is the identity).
    Two observations are made at each time point: the first is state 0 plus
    state 2, the second is state 1 plus state 2, so state 2 is the shared factor.

    Parameters
    ----------
    n : int
        Number of state transitions; the model covers `n + 1` time points.
    x0 : array-like of shape (3,)
        Initial state. Lists and numpy arrays are accepted and converted to a jax array.
    Sigma0 : array-like of shape (3, 3)
        Matrix stored as `Sigma[0]` (the covariance attached to the initial state).
    s2_eps : float
        State innovation variance; `Sigma[1:]` is `s2_eps * I_3`.
    s2_eta : float
        Observation noise variance; every `Omega[t]` is `s2_eta * I_2`.

    Returns
    -------
    tuple
        `(x0, A, B, Sigma, Omega)` with shapes `(3,)`, `(n, 3, 3)`, `(n + 1, 2, 3)`,
        `(n + 1, 3, 3)` and `(n + 1, 2, 2)`.

    Raises
    ------
    ValueError
        If `x0` does not have shape `(3,)` or `Sigma0` does not have shape `(3, 3)`.
    """
    # accept any array-like input (lists, numpy arrays), as `lcm` does;
    # previously a list x0 failed with AttributeError on `.shape`
    x0 = jnp.asarray(x0)
    Sigma0 = jnp.asarray(Sigma0)

    if x0.shape != (3,):
        raise ValueError(f"x0 does not have the correct shape, expected (3,) but got {x0.shape}")
    if Sigma0.shape != (3, 3):
        # fail early with a clear message instead of a cryptic concatenate error below
        raise ValueError(f"Sigma0 does not have the correct shape, expected (3, 3) but got {Sigma0.shape}")

    A = jnp.broadcast_to(jnp.eye(3), (n, 3, 3))
    # float literals: an integer loading matrix would give B an int dtype,
    # unlike every other system matrix returned here
    B = jnp.broadcast_to(jnp.array([[1., 0., 1.], [0., 1., 1.]]), (n + 1, 2, 3))
    Sigma = jnp.concatenate(
        (Sigma0[None], s2_eps * jnp.broadcast_to(jnp.eye(3), (n, 3, 3)))
    )
    Omega = s2_eta * jnp.broadcast_to(jnp.eye(2), (n + 1, 2, 2))

    return x0, A, B, Sigma, Omega

In [None]:
n = 10
x0, A, B, Sigma, Omega = common_factor_lcm(n, jnp.zeros(3), jnp.eye(3), 1., 1.)

# every returned array must have the shape implied by a 3-dim state, 2 observations and n transitions
fct.test_eq(
    [x0.shape, Sigma.shape, Omega.shape, A.shape, B.shape],
    [(3,), (n + 1, 3, 3), (n + 1, 2, 2), (n, 3, 3), (n + 1, 2, 3)],
)

In [None]:
import nbdev
# regenerate the library's Python modules from the notebooks' `#| export` cells
nbdev.nbdev_export()