## Comparison of EIS and the CEM for SSMs

Simplified version of the regional model in Chapter 4.1, keeping only $\log I_t$ and $\log \rho_t$ in the states. 

- States $X_t = \left(\log I_{t}, \log \rho_{t + 1}\right)$
- Observations $Y_t | X_t \sim \operatorname{Pois} \left( \exp \log I_{t}\right)$

Varying $n = 10, 100, 1000$. Initialize $\log \rho_0 = 0$ with small variance and $\log I_0 = \log 1000$ with small variance as well.

Let $\sigma^2_\rho = \frac{1}{n}0.05$, s.t. $\operatorname{Var} (\log \rho_{n +1}) = 0.05$ and approx. $\mathbf P(\log \rho_{n + 1} \in [-0.1, 0.1]) \geq 0.95$, so approx. $\rho_{n +1} \in [0.9, 1.1]$, ensuring stability of the infection counts (they don't go to $0$ or $\infty$).


In [None]:
from pyprojroot import here
from isssm.laplace_approximation import posterior_mode
from isssm.laplace_approximation import posterior_mode
from isssm.importance_sampling import ess_pct
import pandas as pd
from isssm.importance_sampling import pgssm_importance_sampling
from isssm.ce_method import log_weight_cem, simulate_cem
from jax import vmap
from functools import partial
from isssm.laplace_approximation import laplace_approximation as LA
from isssm.modified_efficient_importance_sampling import (
    modified_efficient_importance_sampling as MEIS,
)
from isssm.ce_method import cross_entropy_method as CEM
from isssm.pgssm import simulate_pgssm
import jax.random as jrn
import jax.numpy as jnp
import jax
from isssm.typing import PGSSM
from tensorflow_probability.substrates.jax.distributions import Poisson

from tqdm.notebook import tqdm

In [2]:
# Turn on JAX's `jax_enable_x64` flag so that computations below can use 64-bit types.
jax.config.update("jax_enable_x64", True)

In [3]:
def _model(n, I0):
    """Assemble the two-state Poisson PGSSM used throughout this notebook.

    The state vector has two components (documented in the notebook header as
    ``log I`` and ``log rho``) and the observation is one-dimensional with a
    Poisson distribution whose log-rate is the signal.

    Args:
        n: number of state transitions; the model has ``n + 1`` time points.
        I0: value whose logarithm is written to entry ``[0, 0]`` of ``u``.

    Returns:
        ``PGSSM(u, A, D, Sigma0, Sigma, v, B, dist, xi)`` built positionally.
    """
    n_states, n_obs, n_innov = 2, 1, 1
    n_time = n + 1

    # NOTE(review): for n <= 1 the innovation variance is 1 instead of
    # 0.05 / n -- confirm this special case is intended.
    if n > 1:
        var_rho = 0.05 / n
    else:
        var_rho = 1

    def poisson_obs(s, xi):
        # xi is accepted but not used; the signal s is the log-rate
        return Poisson(log_rate=s)

    unit_vectors = jnp.eye(n_states)

    # --- states ---
    # u is zero everywhere except entry [0, 0], which holds log(I0)
    u = jnp.zeros((n_time, n_states)).at[0, 0].set(jnp.log(I0))

    # same transition matrix at every step: first state += second state
    A = jnp.broadcast_to(
        jnp.array([[1.0, 1.0], [0.0, 1.0]]), (n, n_states, n_states)
    )
    # innovations enter through the second state only (only update rho)
    D = jnp.broadcast_to(unit_vectors[:, 1:2], (n, n_states, n_innov))

    Sigma0 = jnp.array([[1.0, 0.0], [0.0, 0.1]])
    Sigma = jnp.broadcast_to(var_rho * jnp.eye(1), (n, n_innov, n_innov))

    # --- observations ---
    # the signal picks out the first state component
    B = jnp.broadcast_to(unit_vectors[:1], (n_time, n_obs, n_states))
    v = jnp.zeros((n_time, n_obs))

    # placeholder parameters; contents are uninitialised and ignored by poisson_obs
    xi = jnp.empty((n_time, n_obs, 1))

    return PGSSM(u, A, D, Sigma0, Sigma, v, B, poisson_obs, xi)

In [None]:
def determine_efficiency_factor(n, key):
    """Compare the efficiency factors of the LA, MEIS and CEM proposals.

    One data set ``Y`` is simulated from ``_model(n, I0=1000)``. A Laplace
    approximation (LA), a modified efficient importance sampling (MEIS)
    proposal initialised at the LA, and a cross-entropy method (CEM) proposal
    are fitted. Fresh importance samples are then drawn from each proposal and
    summarised with ``ess_pct`` of their log weights.

    Args:
        n: number of state transitions of the simulated model.
        key: JAX PRNG key; it is split internally for simulation, fitting and
            evaluation.

    Returns:
        ``pd.Series`` with entries ``n``, ``N_samples``, ``N_iter``,
        ``EF_LA``, ``EF_MEIS`` and ``EF_CEM``.
    """
    pgssm = _model(n, I0=1000)
    key, subkey = jrn.split(key)

    _, (Y,) = simulate_pgssm(pgssm, 1, subkey)

    N_iter = 1000
    N_samples = 10000

    # fit the three proposals; LA is deterministic and needs no key
    key, sk_meis, sk_cem = jrn.split(key, 3)
    prop_la, _ = LA(Y, pgssm, N_iter)
    prop_meis, _ = MEIS(Y, pgssm, prop_la.z, prop_la.Omega, N_iter, N_samples, sk_meis)
    # the log weights returned by the fit are not used; weights are recomputed
    # below from fresh samples
    prop_cem, _ = CEM(pgssm, Y, N_samples, sk_cem, N_iter)

    # evaluate all proposals with the same number of fresh samples and new keys
    N_ef = 10000
    key, sk_la, sk_meis, sk_cem = jrn.split(key, 4)
    _, lw_la = pgssm_importance_sampling(
        Y, pgssm, prop_la.z, prop_la.Omega, N_ef, sk_la
    )
    _, lw_meis = pgssm_importance_sampling(
        Y, pgssm, prop_meis.z, prop_meis.Omega, N_ef, sk_meis
    )

    # use N_ef (not N_samples) so the CEM evaluation matches LA and MEIS
    lw_cem = vmap(partial(log_weight_cem, y=Y, model=pgssm, proposal=prop_cem))(
        simulate_cem(prop_cem, N_ef, sk_cem)
    )

    result = pd.Series(
        {
            "n": n,
            "N_samples": N_samples,
            "N_iter": N_iter,
            "EF_LA": ess_pct(lw_la),
            "EF_MEIS": ess_pct(lw_meis),
            "EF_CEM": ess_pct(lw_cem),
        }
    )

    return result



In [None]:
# Fixed seed for the efficiency-factor experiment.
key = jrn.PRNGKey(140235293)
# Every value of n appears 10 times in a row, i.e. 10 replications per n.
ns_ef = jnp.repeat(jnp.array([1, 10, 20, 50, 100]), 10)
# The first split result replaces `key`; the remaining ones give one key per replication.
key, *keys_ef = jrn.split(key, len(ns_ef) + 1)

In [None]:
# Run one replication per (n, key) pair; each returned Series becomes one row.
results_ef = pd.DataFrame([determine_efficiency_factor(n, k) for n, k in zip(ns_ef, keys_ef)])
# Write the table without the row index to the path built by `here`.
results_ef.to_csv(here("data/figures/ef_meis_cem_ssms.csv"), index=False)

In [None]:
def asymptotic_det_meis(Y, pgssm, prop_la, N_iter, N_samples, key, M: int):
    """Log-determinant of the scaled sample covariance of MEIS posterior modes.

    MEIS is run ``M`` times from the same Laplace initialisation
    (``prop_la.z``, ``prop_la.Omega``), each time with its own PRNG key. The
    flattened posterior modes of the resulting proposals are treated as ``M``
    observations; their sample covariance is multiplied by ``N_samples``.

    Args:
        Y: observations passed through to ``MEIS``.
        pgssm: model passed through to ``MEIS``.
        prop_la: Laplace proposal providing ``z`` and ``Omega``.
        N_iter: iteration count passed to ``MEIS``.
        N_samples: sample count passed to ``MEIS``; also the covariance scale.
        key: PRNG key, split into ``1 + M`` keys of which the last ``M`` are used.
        M: number of independent MEIS runs.

    Returns:
        The log-absolute-determinant component of ``jnp.linalg.slogdet``.
    """
    _, *run_keys = jrn.split(key, 1 + M)

    flat_modes = []
    for run_key in run_keys:
        proposal, _ = MEIS(
            Y, pgssm, prop_la.z, prop_la.Omega, N_iter, N_samples, run_key
        )
        flat_modes.append(posterior_mode(proposal).reshape(-1))

    # rows are runs, columns are mode coordinates
    scaled_cov = N_samples * jnp.cov(jnp.stack(flat_modes), rowvar=False)

    return jnp.linalg.slogdet(scaled_cov)[1]

def asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M: int):
    """Log-determinant of the scaled sample covariance of CEM proposal means.

    CEM is run ``M`` times, each time with its own PRNG key. Column 0 of each
    proposal's ``mean`` is treated as one observation; the sample covariance
    of these ``M`` observations is multiplied by ``N_samples``.

    Args:
        Y: observations passed through to ``CEM``.
        pgssm: model passed through to ``CEM``.
        N_iter: iteration count passed to ``CEM``.
        N_samples: sample count passed to ``CEM``; also the covariance scale.
        key: PRNG key, split into ``1 + M`` keys of which the last ``M`` are used.
        M: number of independent CEM runs.

    Returns:
        The log-absolute-determinant component of ``jnp.linalg.slogdet``.
    """
    _, *run_keys = jrn.split(key, 1 + M)

    first_component_means = []
    for run_key in run_keys:
        proposal, _ = CEM(pgssm, Y, N_samples, run_key, N_iter)
        first_component_means.append(proposal.mean[:, 0])

    # rows are runs, columns are time points
    scaled_cov = N_samples * jnp.cov(jnp.stack(first_component_means), rowvar=False)

    return jnp.linalg.slogdet(scaled_cov)[1]


def asymptotic_variance(n: int, key: jrn.PRNGKey, N_var: int = 10):
    """Compare the variability of CEM and MEIS proposals on one simulated data set.

    One data set ``Y`` is simulated from ``_model(n, I0=1000)`` and a Laplace
    approximation is fitted as the MEIS starting point. CEM and MEIS are then
    each repeated ``N_var`` times and the log-determinants of the scaled
    covariance of their fitted location parameters are computed.

    Args:
        n: number of state transitions of the simulated model.
        key: JAX PRNG key; split internally for simulation and for each method.
        N_var: number of repetitions of each method.

    Returns:
        ``pd.Series`` with entries ``n``, ``N_samples``, ``N_iter``,
        ``log_DET_CEM``, ``log_DET_MEIS`` and ``ARE``, where ``ARE`` is
        ``exp(log_DET_CEM - log_DET_MEIS)``.
    """
    pgssm = _model(n, I0=1000)
    key, subkey = jrn.split(key)

    _, (Y,) = simulate_pgssm(pgssm, 1, subkey)

    N_iter = 1000
    N_samples = 10000

    prop_la, _ = LA(Y, pgssm, N_iter)

    # Give each method its own key. Passing the same key to both helpers would
    # make them derive identical subkeys and thus share their random streams.
    key, key_cem, key_meis = jrn.split(key, 3)

    logdet_cem = asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key_cem, M=N_var)
    logdet_meis = asymptotic_det_meis(
        Y, pgssm, prop_la, N_iter, N_samples, key_meis, M=N_var
    )

    result = pd.Series(
        {
            "n": n,
            "N_samples": N_samples,
            "N_iter": N_iter,
            "log_DET_CEM": logdet_cem,
            "log_DET_MEIS": logdet_meis,
            "ARE": jnp.exp(logdet_cem - logdet_meis),
        }
    )

    return result

In [54]:
# Fixed seed for the asymptotic-variance experiment.
# NOTE(review): this is the same seed value as in the efficiency-factor cell --
# confirm that restarting from it here is intended.
key = jrn.PRNGKey(140235293)
# Every value of n appears 10 times in a row, i.e. 10 replications per n.
ns_are = jnp.repeat(jnp.array([1, 2, 5, 10]), 10)
# The first split result replaces `key`; the remaining ones give one key per replication.
key, *keys_are = jrn.split(key, len(ns_are) + 1)

In [None]:
# Run one replication per (n, key) pair; each returned Series becomes one row.
results_are = pd.DataFrame(
    [asymptotic_variance(n, k) for n, k in zip(ns_are, keys_are)]
)

# Write the table without the row index to the path built by `here`,
# then display it as the cell output.
results_are.to_csv(here("data/figures/are_meis_cem_ssms.csv"), index=False)
results_are

Unnamed: 0,n,N_samples,N_iter,DET_CEM,DET_MEIS,ARE
0,1,10000,1000,1.250197594061786e-21,7.362584944337338e-21,0.1698041657262412
1,1,10000,1000,2.1315459186618236e-19,7.12489799034234e-18,0.0299168622701839
2,1,10000,1000,1.5207781065517603e-20,6.692526415181031e-20,0.2272352789078418
3,1,10000,1000,6.007385714992874e-17,6.660606838997179e-11,9.01927686201239e-07
4,1,10000,1000,6.624824337829013e-22,3.177769216726953e-22,2.084740547852769
5,1,10000,1000,1.151610362262488e-21,3.2312036211803578e-21,0.3564029065558564
6,1,10000,1000,3.0832489442082734e-23,1.2071821492071727e-22,0.2554087588383595
7,1,10000,1000,3.1796825882208896e-21,4.0446667403011136e-21,0.7861420463986537
8,1,10000,1000,2.1802986541424173e-18,8.979117586628375e-15,0.0002428188107692
9,1,10000,1000,1.877286375520868e-20,2.2784117172564035e-18,0.0082394519010876


: 