In [None]:
#| default_exp glssm_models
%load_ext autoreload
%autoreload 2

In [None]:
# imports for this notebook only
import fastcore.test as fct

# Example gaussian models


## Locally constant model

The locally constant model is a very basic example of a Gaussian linear state space model. It is a univariate model with the following dynamics:

$$
\begin{align*}
    X_{t + 1} &= X_t + \varepsilon_{t + 1} & & \varepsilon_{t + 1} \sim \mathcal N(0, \sigma^2_\varepsilon), \\
    Y_t &= X_t + \eta_t && \eta_{t} \sim \mathcal N(0, \sigma^2_\eta).
\end{align*}
$$

In this model the states $X_t$ perform a discrete-time, univariate random walk, started from $X_0 \sim \mathcal N(x_0, \sigma^2_0)$, and are observed with noise $\eta_t$.

In [None]:
# | export
import jax.numpy as jnp
from jaxtyping import Float, Array
def lcm(n: int, x0: Float, s2_x0: Float, s2_eps: Float, s2_eta: Float):
    """Univariate locally constant model: a random walk observed with noise.

    Parameters
    ----------
    n : int
        number of time steps; the model covers the ``n + 1`` observations
        ``Y_0, ..., Y_n`` and ``n`` state transitions
    x0 : Float
        mean of the initial state ``X_0``
    s2_x0 : Float
        variance of the initial state ``X_0``
    s2_eps : Float
        variance of the state innovations ``epsilon_t``
    s2_eta : Float
        variance of the observation noise ``eta_t``

    Returns
    -------
    tuple
        ``(x0, A, B, Sigma, Omega)``: the initial mean with shape ``(1,)``,
        transition matrices ``(n, 1, 1)``, observation matrices
        ``(n + 1, 1, 1)``, state covariances ``(n + 1, 1, 1)`` (the first
        entry is the initial variance) and observation covariances
        ``(n + 1, 1, 1)``.
    """
    step_ones = jnp.ones((n, 1, 1))
    time_ones = jnp.ones((n + 1, 1, 1))

    # random walk: unit transitions, and the state is observed directly
    transitions = step_ones
    observation_mats = time_ones

    # initial variance first, then one innovation variance per transition
    state_covs = jnp.concatenate((s2_x0 * jnp.ones((1, 1, 1)), s2_eps * step_ones))
    obs_covs = time_ones * s2_eta

    initial_mean = jnp.array(x0).reshape((1,))

    return initial_mean, transitions, observation_mats, state_covs, obs_covs

In [None]:
n = 10
# distinct values for every parameter, so that mixing them up is caught
x0, A, B, Sigma, Omega = lcm(n, 0.5, 2., 3., 4.)

# assess that shapes are correct
fct.test_eq(x0.shape, (1,))
fct.test_eq(Sigma.shape, (n+1,1,1))
fct.test_eq(Omega.shape, (n+1,1,1))
fct.test_eq(A.shape, (n,1,1))
fct.test_eq(B.shape, (n+1,1,1))

# assess that every parameter ends up in the intended array
fct.test_close(x0, jnp.array([0.5]))
fct.test_close(A, jnp.ones((n, 1, 1)))
fct.test_close(B, jnp.ones((n + 1, 1, 1)))
fct.test_close(Sigma[0], jnp.full((1, 1), 2.))
fct.test_close(Sigma[1:], jnp.full((n, 1, 1), 3.))
fct.test_close(Omega, jnp.full((n + 1, 1, 1), 4.))

## Common factor locally constant model
A model with two observations per time step, each being the sum of its own state and a third state that is shared by both observations (the common factor).

$$
\begin{align*}
    X_{t + 1} &= X_t + \varepsilon_{t + 1} && \varepsilon_{t + 1} \sim \mathcal N(0, \sigma^2_\varepsilon I_3) \\
    Y_t &= \begin{pmatrix} 1 & 0 & 1 \\ 0 & 1 & 1 \end{pmatrix} X_t + \eta_t && \eta_t \sim \mathcal N(0, \sigma^2_\eta I_2)
\end{align*}
$$

In [None]:
#| export
def common_factor_lcm(n: int, x0: Float[Array, "3"], Sigma0: Float[Array, "3 3"], s2_eps: Float, s2_eta: Float):
    """Locally constant model with two observations sharing a common factor.

    The three states perform random walks with independent innovations.
    Observation ``i`` (``i = 1, 2``) is the sum of state ``i`` and the
    shared third state, plus noise.

    Parameters
    ----------
    n : int
        number of time steps; the model covers the ``n + 1`` observations
        ``Y_0, ..., Y_n`` and ``n`` state transitions
    x0 : Float[Array, "3"]
        mean of the initial state ``X_0``; must have shape ``(3,)``
        (lists and NumPy arrays are converted to JAX arrays)
    Sigma0 : Float[Array, "3 3"]
        covariance of the initial state ``X_0``; must have shape ``(3, 3)``
    s2_eps : Float
        variance of each component of the state innovations
    s2_eta : Float
        variance of each component of the observation noise

    Returns
    -------
    tuple
        ``(x0, A, B, Sigma, Omega)`` with shapes ``(3,)``, ``(n, 3, 3)``,
        ``(n + 1, 2, 3)``, ``(n + 1, 3, 3)`` and ``(n + 1, 2, 2)``; the first
        entry of ``Sigma`` is ``Sigma0``.

    Raises
    ------
    ValueError
        if ``x0`` or ``Sigma0`` does not have the expected shape
    """
    # accept lists / NumPy arrays as well as JAX arrays (they need `.shape`)
    x0 = jnp.asarray(x0)
    Sigma0 = jnp.asarray(Sigma0)

    if x0.shape != (3,):
        raise ValueError(f"x0 does not have the correct shape, expected (3,) but got {x0.shape}")
    if Sigma0.shape != (3, 3):
        raise ValueError(f"Sigma0 does not have the correct shape, expected (3, 3) but got {Sigma0.shape}")

    A = jnp.broadcast_to(jnp.eye(3), (n, 3, 3))
    # float literals keep B floating point like the other system matrices;
    # integer literals would make B an integer array
    B = jnp.broadcast_to(jnp.array([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]), (n + 1, 2, 3))
    # initial covariance first, then one innovation covariance per transition
    Sigma = jnp.concatenate(
        (Sigma0[None], s2_eps * jnp.broadcast_to(jnp.eye(3), (n, 3, 3)))
    )
    Omega = s2_eta * jnp.broadcast_to(jnp.eye(2), (n + 1, 2, 2))

    return x0, A, B, Sigma, Omega

In [None]:
n = 10
# distinct values for every parameter, so that mixing them up is caught
x0, A, B, Sigma, Omega = common_factor_lcm(n, jnp.array([1., 2., 3.]), 2. * jnp.eye(3), 3., 4.)

# assess that shapes are correct
fct.test_eq(x0.shape, (3,))
fct.test_eq(Sigma.shape, (n+1,3,3))
fct.test_eq(Omega.shape, (n+1,2,2))
fct.test_eq(A.shape, (n,3,3))
fct.test_eq(B.shape, (n+1,2,3))

# assess that every parameter ends up in the intended array
fct.test_close(x0, jnp.array([1., 2., 3.]))
fct.test_close(A, jnp.broadcast_to(jnp.eye(3), (n, 3, 3)))
fct.test_close(B, jnp.broadcast_to(jnp.array([[1., 0., 1.], [0., 1., 1.]]), (n + 1, 2, 3)))
fct.test_close(Sigma[0], 2. * jnp.eye(3))
fct.test_close(Sigma[1:], jnp.broadcast_to(3. * jnp.eye(3), (n, 3, 3)))
fct.test_close(Omega, jnp.broadcast_to(4. * jnp.eye(2), (n + 1, 2, 2)))

# a wrongly shaped initial mean is rejected
fct.test_fail(lambda: common_factor_lcm(n, jnp.zeros(2), jnp.eye(3), 1., 1.), contains="x0 does not have the correct shape")

## Stationary AR(1) model

The states form a stationary AR(1) process with stationary distribution $\mathcal N(\mu, \tau^2)$ (the initial state $X_0$ is drawn from this distribution) and are observed with Gaussian noise:
$$
\begin{align*}
    \alpha &\in \left( -1, 1\right) \\
     \sigma^2 &= (1 - \alpha^2)\tau^2\\
    X_{t + 1} &= \mu + \alpha (X_t - \mu) + \varepsilon_{t + 1}\\
    \varepsilon_t &\sim \mathcal N(0, \sigma^2)\\
    Y_t &= X_t + \eta_t \\
    \eta_t &\sim \mathcal N(0, \omega^2)
\end{align*}
$$


In [None]:
#| export
def ar1(mu, tau2, alpha, omega2, n):
    """Stationary univariate AR(1) state observed with Gaussian noise.

    Parameters
    ----------
    mu : Float
        stationary mean; also used as the mean of the initial state ``X_0``
    tau2 : Float
        stationary variance; also the variance of the initial state ``X_0``
    alpha : Float
        autoregression coefficient; expected to lie in ``(-1, 1)`` for
        stationarity (not checked, so the function stays traceable by JAX)
    omega2 : Float
        variance of the observation noise ``eta_t``
    n : int
        number of time steps; the model covers the ``n + 1`` observations
        ``Y_0, ..., Y_n`` and ``n`` state transitions

    Returns
    -------
    tuple
        ``(x0, A, B, Sigma, Omega)`` with shapes ``(1,)``, ``(n, 1, 1)``,
        ``(n + 1, 1, 1)``, ``(n + 1, 1, 1)`` and ``(n + 1, 1, 1)``. The
        innovation variance ``(1 - alpha**2) * tau2`` keeps the marginal
        state variance at ``tau2``.

    Notes
    -----
    NOTE(review): the returned tuple has no state intercept, so the encoded
    transition is ``X_{t+1} = alpha * X_t + eps_{t+1}``. For ``alpha`` in
    ``(-1, 1)`` this agrees with ``X_{t+1} = mu + alpha * (X_t - mu) + eps``
    only when ``mu == 0``; otherwise the state mean decays from ``mu``
    towards 0. TODO confirm whether callers handle the mean separately.
    """
    # initial mean as a length-1 vector matching the 1-dimensional state
    # (like the other models in this module) rather than a bare scalar
    x0 = jnp.asarray(mu).reshape((1,))
    A = alpha * jnp.ones((n, 1, 1))
    B = jnp.ones((n + 1, 1, 1))

    # innovation variance chosen so that Var(X_t) = tau2 for every t
    sigma2 = (1 - alpha ** 2) * tau2
    Sigma = jnp.concatenate((tau2 * jnp.ones((1, 1, 1)), sigma2 * jnp.ones((n, 1, 1))))

    Omega = omega2 * jnp.ones((n + 1, 1, 1))

    return x0, A, B, Sigma, Omega

In [None]:
# Write the `#| export` cells of this notebook to the library module
# declared by `#| default_exp` at the top of the notebook.
import nbdev
nbdev.nbdev_export()