# Regional Growth Factor Model

In [1]:
#| default_exp model

In [2]:
#| export
import jax.numpy as jnp
import jax.scipy as jsp
import jax
from jax import vmap, jit
import matplotlib.pyplot as plt
import matplotlib as mpl

from isssm.typing import PGSSM, GLSSMState
from jaxtyping import Array, Float

from tensorflow_probability.substrates.jax.distributions import NegativeBinomial as NBinom

In [3]:
# Run JAX in double precision for the remainder of the notebook.
jax.config.update("jax_enable_x64", True)
# Wide default figure size for plots drawn in this notebook.
mpl.rcParams.update({"figure.figsize": (20, 6)})


## States
Let $\bar r_t = \log \bar \rho_t$ be the average growth factor (on the log scale) across all counties, and let $u_t^c$ be the deviation of county $c$'s log growth factor from $\bar r_t$.
We model $\bar r_t$ as a random walk and $u_t = (u^1_t, \dots, u^K_t)$ as a $\text{VAR}(1)$ process with transition matrix $\alpha I$ and spatially correlated innovations with covariance $\Omega$.

Thus we have
$$
\begin{align*}
\bar r_{t + 1} &= \bar r_{t} + \varepsilon_{t + 1}^\rho \\
u^c_{t + 1} &= \alpha u^c_t + \varepsilon_{t + 1}^c
\end{align*}
$$

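with $\text{Var}(\varepsilon_{t + 1}^\rho) = \sigma^2_\rho$ and $\text{Cov}(\varepsilon_{t + 1}) = \Omega$, where $\varepsilon_{t + 1} = (\varepsilon^1_{t + 1}, \dots, \varepsilon^K_{t + 1})$.
For $|\alpha| < 1$, the covariance matrix of the stationary distribution of $u_t$ is $\Sigma = \frac {1} {1 - \alpha^2} \Omega$.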

In [4]:
#| export
def _state_model(r0, u0, alpha, s2_rho, Omega, n) -> GLSSMState:
    """Linear Gaussian state model for the stacked state x_t = (r̄_t, u_t).

    The first state component is the mean log growth factor r̄_t, which
    follows a random walk with innovation variance ``s2_rho``. The remaining
    K components are the county deviations u_t, which follow a VAR(1) with
    transition matrix ``alpha * I`` and innovation covariance ``Omega``.

    Parameters
    ----------
    r0, u0 : initial state means for r̄ and u; they are concatenated, and
        ``u0`` must be 1-D of length K (presumably ``r0`` has shape (1,)).
    alpha : AR coefficient of the deviations u.
    s2_rho : innovation variance of r̄.
    Omega : (K, K) innovation covariance of u.
    n : number of state transitions.

    Returns
    -------
    GLSSMState(x0, A, Sigma) with ``A`` of shape (n, K + 1, K + 1) and
    ``Sigma`` of shape (n + 1, K + 1, K + 1). ``Sigma[0]`` is the initial
    state covariance and ``Sigma[1:]`` are the innovation covariances.
    """
    x0 = jnp.concatenate([r0, u0])
    K, = u0.shape
    m = K + 1

    # r̄ is a random walk (coefficient 1); u evolves as u_{t+1} = alpha * u_t.
    transition = jsp.linalg.block_diag(jnp.eye(1), alpha * jnp.eye(K))
    A = jnp.broadcast_to(transition, (n, m, m))

    # Diffuse initial covariance for the whole state.
    # NOTE(review): the stationary covariance Omega / (1 - alpha**2) would be
    # a natural alternative for the u-block of the initial covariance.
    Sigma0 = 1e3 * jnp.eye(m)

    # Innovation covariance: Var(eps^rho) = s2_rho and Cov(eps) = Omega.
    # (jax.numpy.linalg has no block_diag; it lives in jax.scipy.linalg.)
    innovation_cov = jsp.linalg.block_diag(s2_rho * jnp.eye(1), Omega)
    Sigma = jnp.concatenate([
        Sigma0[None, ...],
        jnp.broadcast_to(innovation_cov, (n, m, m))
        ], axis=0)

    return GLSSMState(x0, A, Sigma)

## Observations
The log growth factor in region $c$ at time $t$, $r^c_t$, is the sum of the mean log growth factor $\bar r_t$ and the per-region deviation $u^c_t$.
Conditional on the log growth factors and past cases, cases are negative binomially distributed with shared overdispersion parameter $r$:
$$
\begin{align*}
r^c_t &= \bar r_t + u^c_{t} \\
\lambda_t^c &= \exp(r^c_t) \sum_{d}p_{c,d} I_t^d \\
I^c_{t + 1} \mid I_{t}, \bar r_t, u^c_t &\sim \text{NegBinom}(\underbrace{\lambda_t^c}_{\text{mean}}, \underbrace{r}_{\text{overdispersion}})
\end{align*}
$$

where $I_t = (I^1_t, \dots, I^K_t)$ and a negative binomial distribution with mean $\mu$ and overdispersion $r$ has variance $\mu + \frac{\mu^2}{r}$.


In [5]:
#| export
def _observation_model(
        obs: Float[Array, "n+2 K"],
        P: Float[Array, "K K"],
        r: Float
    ):

    np2, p = obs.shape

    delayed_obs = obs[:-1]
    cases_adjusted = vmap(jnp.matmul, (None, 0))(P, delayed_obs)

    xi = jnp.concatenate((
        jnp.full((np2 - 1, p, 1), r),
        cases_adjusted[:,:, None]
    ), axis=-1)

    def dist_obs(signal, xi):
        r, sum_I =jnp.moveaxis(xi, -1, 0)
        return NBinom(r, logits=signal + jnp.log(sum_I))
    
    return dist_obs, xi

## Spatial Correlations

Suppose $S^c$ new infections are generated in county $c$; each of them may be attributed to county $c$ itself or to another county $c'$.
Let $p_{c,c'}$ be the fraction of cases generated in county $c'$ (instead of in county $c$), and let $p_{c,c} = 1 - \sum_{c' \neq c} p_{c,c'}$.


Let $P = \left(p_{c,c'}\right)_{c = 1, \dots, K, c' = 1,\dots, K}$.

Let $\tilde S^c = \sum_d p_{c,d} S^d$ be the number of cases generated in county $c$.

Assuming the $S^d$ are uncorrelated with common variance $\sigma^2_{\text{spat}}$, we are then interested in

$$
\begin{align*}
    \text{Cov}(\tilde S^c, \tilde S^{c'}) = \underbrace{\sigma^2_{\text{spat}}}_{\text{Var}(S^d)} \sum_{d} p_{c,d} p_{c',d} = \sigma^2_{\text{spat}} (P P^T)_{c, c'}
\end{align*}
$$


To obtain $p_{c,c'}$ we use commuter data: $q_{c,c'}$ is the fraction of socially insured employees who have their center of life in county $c$ but are registered to work in county $c'$.

To account for non-working inhabitants (the elderly, children, ...) we introduce a constant $C \geq 1$ such that
$$
p_{c,c'} = \bar q + (1 - \bar q)\frac{\mathbf 1 _{c \neq c'} q_{c,c'}}{ \sum_{d \neq c} q_{c,d} + C q_{c,c}},
$$
i.e. we inflate the proportion of people who stay at home by a constant $C$ (the same for all counties) and add a constant baseline $\bar q$ of travel between the counties.


Finally we choose

$$
\Omega = \sigma^2_{\text{spat}}PP^T %+ \sigma^2_{\text{nugget}} I
$$

In [6]:
#| export
def _P(C, q, n_ij, n_tot) -> Float[Array, "K K"]:
    p, _ = n_ij.shape
    m_ij = n_ij + jnp.diag(C * n_tot - n_ij.sum(axis=1)) 
    normalize_rows = lambda x: x / x.sum(axis=1).reshape((-1,1))
    return jnp.full((p,p), q / p) + (1 - q) * normalize_rows(m_ij)


## Parameters

$$
\theta = \left( \text{logit}(\alpha), \log \sigma^2_\rho, \log \sigma^2_{\text{spat}}, \text{logit}(\bar q), C, \log r \right)%\log \sigma^2_{\text{nugget}}, \log \mu \right)
$$


## Final Model

In [7]:
#| export
def growth_factor_model(
        theta,
        aux
    ) -> PGSSM:
    """Assemble the regional growth factor model as a PGSSM.

    Parameters
    ----------
    theta : (logit(alpha), log s2_rho, log s2_spat, logit(q), C, log r).
    aux : (obs, n_ij, n_tot), where ``obs`` holds the case counts with shape
        (n + 2, K), ``n_ij`` is the (K, K) commuter-count matrix and
        ``n_tot`` is a length-K vector. Both ``n_ij`` and ``n_tot`` are
        passed to ``_P``.

    Returns
    -------
    PGSSM with state (r̄_t, u^1_t, ..., u^K_t), n state transitions and
    n + 1 observation times. The signal is r^c_t = r̄_t + u^c_t.
    """
    logit_alpha, log_s2_r, log_s2_spat, logit_q, C, log_r = theta
    obs, n_ij, n_tot = aux

    np2, _ = obs.shape
    np1 = np2 - 1   # number of observation times
    n = np1 - 1     # number of state transitions
    K, = n_tot.shape

    alpha = jsp.special.expit(logit_alpha)
    s2_rho = jnp.exp(log_s2_r)
    s2_spat = jnp.exp(log_s2_spat)
    r = jnp.exp(log_r)
    q = jsp.special.expit(logit_q)

    P = _P(C, q, n_ij, n_tot)
    Omega = s2_spat * P @ P.T
    state = _state_model(jnp.zeros(1), jnp.zeros(K), alpha, s2_rho, Omega, n)
    dist, xi = _observation_model(obs, P, r)

    # Signal r^c_t = r̄_t + u^c_t: the column of ones adds r̄_t to every
    # region and the identity block picks out u^c_t.
    B = jnp.broadcast_to(
        jnp.block([jnp.ones((K, 1)), jnp.eye(K)]),
        (np1, K, K + 1)
    )

    return PGSSM(
        x0 = state.x0,
        A = state.A,
        Sigma = state.Sigma,
        B = B,
        dist = dist,
        xi = xi
    )

In [8]:
from nbdev import export

# Export this notebook's `#| export` cells into the library under "src".
export.nb_export("10_model.ipynb", "src")