# Utilities for the models of this thesis
> Helper utilities shared by the models of this thesis, such as plotting the system matrices of a PGSSM.

In [None]:
# | default_exp models.util

## Visualization

In [None]:
# | export
import matplotlib.pyplot as plt
from isssm.typing import PGSSM
from matplotlib.colors import Normalize
import matplotlib.cm as cm
import jax.numpy as jnp


def __zero_to_nan(arr, eps=1e-10):
    """Return a copy of ``arr`` with (numerically) zero entries set to NaN.

    An entry counts as zero when its absolute value is strictly below
    ``eps``. Matplotlib's ``imshow`` leaves NaN cells blank, which makes the
    sparsity pattern of a matrix visible.
    """
    negligible = jnp.abs(arr) < eps
    return jnp.where(negligible, jnp.nan, arr)


def visualize_pgssm(pgssm: PGSSM):
    """Plot the ``[0]`` slices of a PGSSM's system matrices.

    Draws ``A[0]``, ``B[0]`` and ``D[0]`` side by side in one figure on a
    shared color scale with a single common colorbar, then draws
    ``Sigma[0]`` in a second figure with its own colorbar. Entries whose
    absolute value is below ``1e-10`` are converted to NaN and therefore
    rendered blank, so the sparsity pattern is visible.

    Parameters
    ----------
    pgssm : PGSSM
        Model whose ``A``, ``B``, ``D`` and ``Sigma`` attributes are indexed
        with ``[0]``; presumably the leading axis is time, so only the first
        time step is shown — TODO confirm.
    """
    fig, axes = plt.subplots(nrows=1, ncols=3)
    # `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed
    # in 3.9; `pyplot.get_cmap` is the supported accessor.
    cmap = plt.get_cmap("viridis")

    A, B, D, Sigma = pgssm.A[0], pgssm.B[0], pgssm.D[0], pgssm.Sigma[0]
    # One color scale across A, B and D so the three panels are comparable.
    # Named vmin/vmax to avoid shadowing the builtins `min`/`max`.
    vmax = float(jnp.max(jnp.array([A.max(), B.max(), D.max()])))
    vmin = float(jnp.min(jnp.array([A.min(), B.min(), D.min()])))

    normalizer = Normalize(vmin, vmax)
    # The colorbar is drawn from this mappable, so it must carry the same
    # colormap as the panels; without `cmap` it would fall back to the
    # rcParams default and could disagree with the images.
    mappable = cm.ScalarMappable(norm=normalizer, cmap=cmap)
    for ax, (title, matrix) in zip(axes, (("A", A), ("B", B), ("D", D))):
        ax.imshow(__zero_to_nan(matrix), cmap=cmap, norm=normalizer)
        ax.set_title(title)

    fig.colorbar(mappable, ax=axes.ravel().tolist())
    plt.show()

    # Sigma gets its own figure and color scale.
    plt.imshow(__zero_to_nan(Sigma))
    plt.colorbar()
    plt.show()

In [None]:
# | hide
from nbdev import nbdev_export

# Regenerate the library's Python modules from the project's notebooks,
# including the `# | export` cells of this one (target: `models.util`).
nbdev_export()