In [None]:
# | default_exp models.pgssm

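# Partially Gaussian State Space Models with Linear Signal

Constructors that turn a linear Gaussian state space model (GLSSM) into a partially Gaussian one (PGSSM) whose observations are Poisson or negative binomial, with log-mean given by the linear signal.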

In [None]:
# | export
import jax.numpy as jnp
from jaxtyping import Float
from isssm.typing import GLSSM, PGSSM
from tensorflow_probability.substrates.jax.distributions import (
    NegativeBinomial as NBinom,
    Poisson,
)


def nb_pgssm(glssm: GLSSM, r: Float):
    """Build a PGSSM with negative binomial observations from a GLSSM.

    The Gaussian components ``u, A, D, Sigma0, Sigma, v, B`` are passed
    through unchanged. The observation density is replaced by a negative
    binomial whose mean is ``exp(log_mu)``, where ``log_mu`` is the first
    argument of the density (presumably the linear signal).

    Parameters
    ----------
    glssm : GLSSM
        Source model. ``B`` must be 3-dimensional; its first two axes fix
        the shape of the dispersion array.
    r : Float
        Dispersion (size) parameter, shared by every entry. It is expected
        to be positive; ``r <= 0`` does not define a valid distribution.

    Returns
    -------
    PGSSM
        Model whose density ``dist_nb(log_mu, xi)`` has mean
        ``mu = exp(log_mu)`` and variance ``mu + mu**2 / xi``. Its parameter
        array ``xi`` has shape ``B.shape[:2]`` and is filled with ``r``.
    """
    np1, p, _ = glssm.B.shape
    # One dispersion value per (row, observation-dimension) entry of B.
    xi = jnp.full((np1, p), r)

    def dist_nb(log_mu, xi):
        # TFP's NegativeBinomial(total_count, probs) has mean
        # total_count * probs / (1 - probs). Taking probs = mu / (xi + mu)
        # therefore gives mean mu. The same probs in logit space is
        # logit(probs) = log(mu / xi) = log_mu - log(xi).
        # This never forms exp(log_mu), so large signals cannot overflow to
        # inf/inf = nan, and probs cannot round to exactly 1.
        return NBinom(xi, logits=log_mu - jnp.log(xi))

    return PGSSM(
        glssm.u,
        glssm.A,
        glssm.D,
        glssm.Sigma0,
        glssm.Sigma,
        glssm.v,
        glssm.B,
        dist_nb,
        xi,
    )


def poisson_pgssm(glssm: GLSSM):
    """Build a PGSSM with Poisson observations from a GLSSM.

    The Gaussian components ``u, A, D, Sigma0, Sigma, v, B`` are copied from
    ``glssm``. Observations are Poisson with log-rate equal to the first
    argument of the density.

    The Poisson family has no extra parameter, so the returned ``xi`` is an
    all-zero placeholder of shape ``B.shape[:2]``. The density ignores it.
    """
    n_rows, n_obs, _ = glssm.B.shape
    # JAX cannot hand out uninitialised buffers (jnp.empty already returns
    # zeros), so writing zeros just makes the placeholder explicit.
    xi_placeholder = jnp.zeros((n_rows, n_obs))

    def dist_poisson(log_mu, xi):
        # `xi` is accepted only to match the PGSSM density signature.
        return Poisson(log_rate=log_mu)

    gaussian_part = (
        glssm.u,
        glssm.A,
        glssm.D,
        glssm.Sigma0,
        glssm.Sigma,
        glssm.v,
        glssm.B,
    )
    return PGSSM(*gaussian_part, dist_poisson, xi_placeholder)

In [None]:
# | hide
import nbdev

# Regenerate the Python package sources from the cells marked `# | export`
# (nbdev). Re-run this cell after editing the exported code above.
nbdev.nbdev_export()