# Example 3.6 (Negative Binomial VAR(1) SSM)

In [None]:
from isssm.typing import GLSSMState, GLSSM
from isssm.models.pgssm import nb_pgssm
from isssm.pgssm import simulate_pgssm
import jax.numpy as jnp
import jax.random as jrn

# Master PRNG key for the notebook; each random step below splits a fresh subkey off it.
key = jrn.PRNGKey(1_142_424_457)

## Set up the model

In [None]:
alpha = 0.1
n = 100
Sigma = jnp.ones((n, 1, 1))

u = jnp.zeros((n + 1, 1))
A = alpha * jnp.ones((n, 1, 1))

glssm_state = GLSSMState(
    u=u,
    A=A,
    D=jnp.ones((n, 1, 1)),
    Sigma0=jnp.eye(1),
    Sigma=Sigma,
)

In [None]:
# Negative binomial parameter handed to nb_pgssm (presumably its size /
# dispersion parameter — confirm against isssm.models.pgssm).
r = 2.0

# Extend the state part with an observation part over n + 1 time points:
# zero intercepts v and unit loadings B. Omega is written as zeros explicitly;
# in JAX jnp.empty is zero-filled as well (there are no uninitialized arrays),
# so zeros states the actual contents instead of suggesting garbage memory.
glssm = GLSSM(
    **glssm_state._asdict(),
    v=jnp.zeros((n + 1, 1)),
    B=jnp.ones((n + 1, 1, 1)),
    Omega=jnp.zeros((n + 1, 1, 1)),
)

model = nb_pgssm(glssm, r=r)

## Simulate data from the model

In [None]:
key, subkey = jrn.split(key)
# Draw a single joint sample (N=1). The first returned component (presumably the
# states) is discarded, and the single observation path is kept as Y.
(_,), (Y,) = simulate_pgssm(model, key=subkey, N=1)

## Perform the Laplace approximation (LA)

In [None]:
from isssm.laplace_approximation import laplace_approximation
from isssm.importance_sampling import pgssm_importance_sampling, ess_pct

# Laplace approximation (LA) proposal for the simulated observations.
# n_iter / eps presumably bound the iterations and set the convergence tolerance.
proposal, info = laplace_approximation(
    y=Y,
    model=model,
    n_iter=1_000,
    eps=1e-6,
)
key, subkey = jrn.split(key)  # subkey feeds the LA ESS check that follows


def proposal_ess_pct(y, model, proposal, N=1000, key=None):
    """Return the ESS percentage of an importance-sampling proposal.

    Draws ``N`` importance samples for the observations ``y`` under ``model``,
    using the proposal described by ``proposal.z`` and ``proposal.Omega``.
    Returns ``ess_pct`` of the resulting log weights.

    Parameters
    ----------
    y
        Observations, forwarded to ``pgssm_importance_sampling``.
    model
        The PGSSM the proposal targets.
    proposal
        Any object exposing ``z`` and ``Omega`` attributes (e.g. the LA or EIS
        result in this notebook).
    N : int
        Number of importance samples.
    key
        JAX PRNG key. It is required; the ``None`` default exists only for
        signature compatibility.

    Raises
    ------
    ValueError
        If ``key`` is ``None``.
    """
    if key is None:
        # Sampling needs an explicit key in JAX. Failing here gives a clear
        # message rather than an error from deep inside the sampler
        # (presumably a less obvious TypeError there).
        raise ValueError("proposal_ess_pct requires a JAX PRNG key, got key=None")

    _, log_weights = pgssm_importance_sampling(
        y=y,
        model=model,
        z=proposal.z,
        Omega=proposal.Omega,
        N=N,
        key=key,
    )
    return ess_pct(log_weights)


# Effective sample size percentage (ess_pct) of the LA proposal from 1000 draws.
proposal_ess_pct(y=Y, model=model, proposal=proposal, N=1_000, key=subkey)

## Find "true" proposals for CE/EIS

In [None]:
from isssm.ce_method import cross_entropy_method as cem
from isssm.modified_efficient_importance_sampling import (
    modified_efficient_importance_sampling as eis,
)

# Large sample size used to obtain the reference ("true") CE and EIS proposals.
N_true = 10**5

key, subkey = jrn.split(key)
cem_true, log_w_cem = cem(model, Y, N_true, subkey, n_iter=100)
# ESS percentage of the log weights returned by the CE run itself.
ess_pct(log_w_cem)

In [None]:
# EIS warm-started from the LA proposal and run with the large sample size.
key, subkey = jrn.split(key)
eis_true, _ = eis(
    y=Y,
    model=model,
    N=N_true,
    n_iter=100,
    z_init=proposal.z,
    Omega_init=proposal.Omega,
    key=subkey,
)

# A fresh subkey for the ESS check of the resulting EIS proposal.
key, subkey = jrn.split(key)
proposal_ess_pct(y=Y, model=model, proposal=eis_true, N=1_000, key=subkey)

## Obtain proposals $\hat{\mathbf P}^N_{\text{CE}}$ and $\hat{\mathbf P}^N_{\text{EIS}}$

In [None]:
def finite_N_cem(N, key, *, n_iter=100):
    """Fit a CE proposal with finite sample size ``N``.

    Runs ``cem`` on the notebook-level ``model`` and ``Y``. Both are read as
    globals at call time.

    Parameters
    ----------
    N : int
        Sample size handed to ``cem``.
    key
        JAX PRNG key. It is split once and the resulting subkey is passed to
        ``cem``.
    n_iter : int, optional
        Number of CE iterations. Defaults to 100, matching the reference
        (``N_true``) run.

    Returns
    -------
    The first element returned by ``cem`` (the fitted proposal). The CE log
    weights are discarded.
    """
    # Split rather than passing ``key`` through directly, so a given key always
    # maps to the same subkey and hence the same proposal.
    _, subkey = jrn.split(key)
    cem_finite, _ = cem(model, Y, N, subkey, n_iter=n_iter)
    return cem_finite

## Extract marginal means and variances

In [None]:
def cem_marginals(cem_proposal):
    """Extract marginal means and variances from a CE proposal (placeholder).

    Not implemented yet. A function with an empty body is a SyntaxError, so
    this stub keeps the cell valid and fails loudly if called.

    Raises
    ------
    NotImplementedError
        Always.
    """
    # TODO: implement once the structure of the proposal returned by ``cem`` is
    # pinned down; presumably return (means, variances) of the state marginals.
    raise NotImplementedError("cem_marginals is not implemented yet")

## Calculate MSE

## Bias-Variance decomposition