# Utilities for the models of this thesis
> Visualization of PGSSM system matrices and reparametrizations of probability vectors

In [None]:
# | default_exp models.util

In [None]:
# | export
# libraries
import matplotlib.pyplot as plt
from isssm.typing import PGSSM
from matplotlib.colors import Normalize
import matplotlib.cm as cm
import jax.numpy as jnp
from jaxtyping import Float, Array
import fastcore.test as fct

## Visualization

In [None]:
# | export


def __zero_to_nan(arr, eps=1e-10):
    """Return a copy of `arr` with near-zero entries replaced by NaN.

    Entries whose magnitude is at least `eps` are kept as they are; all
    others become NaN (which e.g. `imshow` leaves blank).
    """
    keep = jnp.abs(arr) >= eps
    return jnp.where(keep, arr, jnp.nan)


def visualize_pgssm(pgssm: PGSSM):
    """Plot the system matrices of a PGSSM at index 0 as heatmaps.

    Shows `A`, `B` and `D` side by side on one shared color scale (with a
    shared colorbar), then `Sigma` in a separate figure with its own color
    scale. Entries with absolute value below 1e-10 are drawn as NaN (blank)
    so that the sparsity pattern is visible.

    Parameters
    ----------
    pgssm : PGSSM
        Model whose `A`, `B`, `D` and `Sigma` attributes are each indexed
        with `[0]`; presumably arrays with a leading time axis, so this plots
        the first time step -- TODO confirm against `isssm.typing.PGSSM`.
    """
    # `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed
    # in 3.9; `pyplot.get_cmap` is the supported accessor.
    cmap = plt.get_cmap("viridis")

    A, B, D, Sigma = pgssm.A[0], pgssm.B[0], pgssm.D[0], pgssm.Sigma[0]

    # Shared color limits so that the three panels are directly comparable.
    # Named vmin/vmax to avoid shadowing the builtins min/max.
    vmin = float(jnp.min(jnp.array([A.min(), B.min(), D.min()])))
    vmax = float(jnp.max(jnp.array([A.max(), B.max(), D.max()])))
    normalizer = Normalize(vmin, vmax)

    fig, axes = plt.subplots(nrows=1, ncols=3)
    for ax, (title, matrix) in zip(axes, (("A", A), ("B", B), ("D", D))):
        ax.imshow(__zero_to_nan(matrix), cmap=cmap, norm=normalizer)
        ax.set_title(title)

    # Give the colorbar's mappable the same colormap as the panels; without
    # it, the colorbar uses the rcParams default and may not match.
    mappable = cm.ScalarMappable(norm=normalizer, cmap=cmap)
    fig.colorbar(mappable, ax=axes.ravel().tolist())
    plt.show()

    plt.imshow(__zero_to_nan(Sigma))
    plt.colorbar()
    plt.show()

## Computation

For $p \in \mathbf R^{k}_{>0}$ with $\sum_{i = 1}^k p_{i} = 1$, let $\log q_i = \log \frac{p_{i}}{p_{k}}$ for $i = 1, \dots, k - 1$. Since $p_i = q_i p_k$, we get $1 = \sum_{i=1}^{k} p_i = p_k \left(1 + \sum_{j=1}^{k-1} q_j\right)$, hence
$$
    p_{k} = \frac{1}{1 + \sum_{j = 1}^{k-1}q_{j}},
$$
so for $i = 1, \dots, k - 1$
$$
    p_{i} = q_{i} p_{k} = \frac{q_{i}}{1 + \sum_{j = 1}^{k-1}q_{j}}.
$$



In [None]:
# | export
def to_log_probs(log_ratios: Float[Array, "k-1"]) -> Float[Array, "k"]:
    """Map log-ratios against the last category to log-probabilities.

    Given `log_ratios[..., i] = log(p_i / p_k)` for `i = 1, ..., k - 1`,
    return `log p` of shape `(..., k)`. This is a log-softmax over the
    log-ratios with an implicit `0` appended for the reference category `k`.

    The result is computed with `logsumexp` instead of `log(exp(...))`, so
    large log-ratios do not overflow to `inf`/`nan` and very negative ones
    do not underflow to `log(0) = -inf`.
    """
    from jax.scipy.special import logsumexp

    log_ratios = jnp.asarray(log_ratios)
    # log(p_k / p_k) = 0 for the reference category; built from the shape
    # (not from a slice of log_ratios) so that k = 1 (empty input) also works.
    reference = jnp.zeros(log_ratios.shape[:-1] + (1,), dtype=log_ratios.dtype)
    logits = jnp.concatenate([log_ratios, reference], axis=-1)
    return logits - logsumexp(logits, axis=-1, keepdims=True)

In [None]:
# | hide
# Zero log-ratios mean p_i = p_k for every i, i.e. the uniform distribution
# over k = 5 categories.
fct.test_close(to_log_probs(jnp.zeros(4)), jnp.log(jnp.ones(5) / 5))

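Another parametrization uses consecutive conditional probabilities, mapped through the logit function to make the problem unconstrained.

Concretely, for a probability vector $p \in \mathbf R^k_{>0}$ with $\sum_{i=1}^k p_i = 1$ we have
$$
    q_{i} = \frac{p_{i}}{1 - \sum_{j = 1}^{i - 1} p_{j}} = \frac{p_{i}}{\sum_{j = i}^{k} p_{j}},
$$
for $i = 1, \dots, k - 1$ ($q_k$ is $1$ and can be discarded); the unconstrained parameters are $\operatorname{logit}(q_i)$, $i = 1, \dots, k - 1$.

Then for $i = 1, \dots, k$ 
$$
    p_{i} = q_{i} \prod_{j = 1}^{i - 1}(1 - q_j).
$$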




In [None]:
# | export
import jax.scipy as jsp


def to_consecutive_logits(probs: Float[Array, "k"]) -> Float[Array, "k-1"]:
    """Map a probability vector to its consecutive conditional logits.

    For `probs` of shape `(..., k)` returns `logit(q_i)` for
    `i = 1, ..., k - 1`, where `q_i = p_i / S_i` and `S_i = sum_{j >= i} p_j`
    are tail sums along the last axis. Inverse of `from_consecutive_logits`.

    Uses the identity `logit(p_i / S_i) = log p_i - log S_{i+1}`. This avoids
    rounding `q_i` to exactly 1 (and the logit to `inf`) when `p_i`
    dominates the remaining mass.
    """
    # Tail sums S_i = p_i + ... + p_k along the last axis.
    tail_sums = jnp.cumsum(probs[..., ::-1], axis=-1)[..., ::-1]
    # Slice the last axis (with `...`) so that leading batch axes are kept.
    return jnp.log(probs[..., :-1]) - jnp.log(tail_sums[..., 1:])


def from_consecutive_logits(
    consecutive_logits: Float[Array, "k-1"]
) -> Float[Array, "k"]:
    """Map consecutive conditional logits back to a probability vector.

    For logits of shape `(..., k - 1)` returns `p` of shape `(..., k)` with
    `p_i = q_i * prod_{j < i} (1 - q_j)`, where `q_i = expit(logit_i)` for
    `i < k` and `q_k = 1`. Inverse of `to_consecutive_logits`; for `k = 1`
    (empty logits) the result is `[1]`.
    """
    consecutive_logits = jnp.asarray(consecutive_logits)
    q = jsp.special.expit(consecutive_logits)
    # 1 - expit(x) == expit(-x). The latter keeps precision when q_i is close
    # to 1 instead of cancelling to exactly 0 and zeroing all later p_i.
    one_minus_q = jsp.special.expit(-consecutive_logits)
    # Built from the shape so that an empty last axis (k = 1) still works.
    ones = jnp.ones(q.shape[:-1] + (1,), dtype=q.dtype)
    # q_k = 1: the last category takes all of the remaining mass.
    q_ext = jnp.concatenate((q, ones), axis=-1)
    # survival_i = prod_{j < i} (1 - q_j); the empty product for i = 1 is 1.
    survival = jnp.concatenate(
        (ones, jnp.cumprod(one_minus_q, axis=-1)), axis=-1
    )
    return q_ext * survival

In [None]:
# | hide
from jax import random

In [None]:
# | hide
# Uniform distribution over 3 categories: q = (1/3, 1/2).
fct.test_close(
    to_consecutive_logits(jnp.ones(3) / 3), jsp.special.logit(jnp.array([1 / 3, 1 / 2]))
)
# ... and the inverse map recovers the uniform distribution.
fct.test_close(
    from_consecutive_logits(jsp.special.logit(jnp.array([1 / 3, 1 / 2]))),
    jnp.ones(3) / 3,
)
# Halving probabilities (last two equal) give q_i = 1/2, i.e. all logits 0.
fct.test_close(
    to_consecutive_logits(jnp.array([1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 16])),
    jnp.array([0.0, 0.0, 0.0, 0.0]),
)

# Round trips in both directions on 100 random logits (k = 101 categories).
random_logits = random.normal(random.PRNGKey(0), (100,))
random_probs = from_consecutive_logits(random_logits)
fct.test_close(to_consecutive_logits(random_probs), random_logits)
fct.test_close(
    from_consecutive_logits(to_consecutive_logits(random_probs)), random_probs
)

In [None]:
# | hide
import nbdev

# nbdev: write the `# | export` cells of this notebook to the library module
# named by the `default_exp` directive (`models.util`).
nbdev.nbdev_export()