# Application 1: Showcase 

In [None]:
from ssm4epi.models.regional_growth_factor import (
    key,
    n_iterations,
    N_mle,
    N_meis,
    N_posterior,
    percentiles_of_interest,
    make_aux,
    dates_full,
    cases_full,
    n_ij,
    n_tot,
    account_for_nans,
    growth_factor_model,
)

import jax.numpy as jnp
import jax
import jax.random as jrn

from isssm.importance_sampling import prediction
from isssm.laplace_approximation import laplace_approximation as LA
from isssm.modified_efficient_importance_sampling import (
    modified_efficient_importance_sampling as MEIS,
)

from pyprojroot.here import here

jax.config.update("jax_enable_x64", True)
from isssm.estimation import initial_theta
import pickle

In [None]:
# Fit the regional growth-factor model on a window of `np1` time points taken
# from `dates_full`, treating the final observation as missing so that it can
# be predicted afterwards.
start_index = 15  # position in `dates_full` where the window starts
np1 = 10  # number of time points in the window (presumably n + 1)
initial_date = dates_full[start_index]
dates = dates_full[start_index : start_index + np1]
aux = make_aux(initial_date, cases_full, n_ij, n_tot, np1)

# Drop the first row of the observations returned by `make_aux`
# (presumably an initial condition rather than a modelled observation — TODO confirm).
y = aux[0][1:]
# Mask the final time point; it is the target of the prediction step.
y_nan = y.at[-1].set(jnp.nan)
missing_inds = jnp.isnan(y_nan)

# Manually chosen parameter vector. It is used to build the NaN-adjusted
# observations and as the starting point for `initial_theta`.
theta_manual = jnp.array(
    [5.950e00, -2.063e00, -5.355e00, -4.511e-01, -5.711e-01, 7.932e-01]
)
_, y_miss = account_for_nans(
    growth_factor_model(theta_manual, aux), y_nan, missing_inds
)


def _model_miss(theta, aux):
    """Return the growth-factor model for `theta`, adjusted for the NaNs in `y_nan`."""
    return account_for_nans(growth_factor_model(theta, aux), y_nan, missing_inds)[0]


theta0_result = initial_theta(y_miss, _model_miss, theta_manual, aux, n_iterations)
theta0 = theta0_result.x
fitted_model = _model_miss(theta0, aux)

# Laplace approximation first; MEIS is then initialised from the LA's (z, Omega).
proposal_la, info_la = LA(y_miss, fitted_model, n_iterations)
key, subkey = jrn.split(key)
proposal_meis, info_meis = MEIS(
    y_miss, fitted_model, proposal_la.z, proposal_la.Omega, n_iterations, N_meis, subkey
)
# Fresh subkey, reserved for the prediction step below.
key, subkey = jrn.split(key)


def f_pred(x, s, y):
    """Collect the quantities to predict from a single draw.

    Returns a 1-D array that concatenates, in this order:
    the sum of the last row of `y` (as a length-1 array), the last row of `y`
    itself, and `s` flattened (named "growth_factors" in the original code).
    `x` is not used; presumably it is part of the callback signature that
    `prediction` expects — TODO confirm.
    """
    final_row = y[-1]
    total = jnp.atleast_1d(final_row.sum())
    return jnp.concatenate([total, final_row, jnp.ravel(s)])


# Importance-sampling prediction of the last observation's total, its
# per-county values and the flattened states (see `f_pred`), summarised at
# `percentiles_of_interest`.
# NOTE(review): this uses the Laplace proposal `proposal_la`, although the MEIS
# proposal `proposal_meis` is computed above and saved in `result`; confirm
# which proposal is intended here.
preds = prediction(
    f_pred,
    y_miss,
    proposal_la,
    fitted_model,
    N_posterior,
    subkey,
    percentiles_of_interest,
    # Model without NaN adjustment — presumably used to simulate the masked
    # final observation; TODO confirm against isssm's `prediction`.
    growth_factor_model(theta0, aux),
)

result = (theta0, proposal_meis, preds, dates, y)

# Ensure the output directory exists: opening a file in a missing directory
# raises FileNotFoundError.
results_path = here() / "data/results/4_local_outbreak_model/results.pickle"
results_path.parent.mkdir(parents=True, exist_ok=True)
with open(results_path, "wb") as f:
    pickle.dump(result, f)

FileNotFoundError: [Errno 2] No such file or directory: '/Users/stefan/workspace/work/phd/thesis/data/results/4_local_outbreak_model/results.pickle'

In [None]:
# Display the prediction summary. According to the stored output below it is a
# tuple of arrays (1-D summaries followed by a 2-D array). These are presumably
# the mean, standard deviation and percentiles; TODO confirm against isssm's `prediction`.
preds

(Array([1.66255525e+04, 2.12210181e+00, 1.03437335e+00, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00], dtype=float64),
 Array([7.66581469e+03, 2.91949927e+00, 1.49179858e+00, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00], dtype=float64),
 Array([[6.15800000e+03, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [6.56250733e+03, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [7.59300000e+03, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        ...,
        [2.92606867e+04, 8.00000000e+00, 3.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [3.63747669e+04, 1.00000000e+01, 4.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [4.12856030e+04, 1.30000000e+01, 7.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]