In [1]:
{
    "tags": [
        "hide-cell"
    ]
}

### Install necessary libraries

try:
    import jax
except:
    # For cuda version, see https://github.com/google/jax#installation
    %pip install --upgrade "jax[cpu]" 
    import jax

try:
    import jsl
except:
    %pip install git+https://github.com/probml/jsl
    import jsl

try:
    import rich
except:
    %pip install rich
    import rich




Collecting jax[cpu]
  Downloading jax-0.3.5.tar.gz (946 kB)
[K     |████████████████████████████████| 946 kB 2.7 MB/s eta 0:00:01
[?25hCollecting absl-py
  Downloading absl_py-1.0.0-py3-none-any.whl (126 kB)
[K     |████████████████████████████████| 126 kB 47.7 MB/s eta 0:00:01
[?25hCollecting numpy>=1.19
  Downloading numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl (17.6 MB)
[K     |████████████████████████████████| 17.6 MB 47.5 MB/s eta 0:00:01
[?25hCollecting opt_einsum
  Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Collecting scipy>=1.2.1
  Downloading scipy-1.8.0-cp38-cp38-macosx_12_0_universal2.macosx_10_9_x86_64.whl (55.3 MB)
[K     |████████████████████████████████| 55.3 MB 73.1 MB/s eta 0:00:01
[?25hCollecting typing_extensions
  Using cached typing_extensions-4.1.1-py3-none-any.whl (26 kB)
Collecting jaxlib==0.3.5
  Downloading jaxlib-0.3.5-cp38-none-macosx_10_9_x86_64.whl (70.5 MB)
[K     |████████████████████████████████| 70.5 MB 723 kB/s  eta 0:00:01
[?2

In [1]:
{
    "tags": [
        "hide-cell"
    ]
}


### Import standard libraries

import abc
from dataclasses import dataclass
import functools
import itertools

from typing import Any, Callable, NamedTuple, Optional, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np


import jax
import jax.numpy as jnp
from jax import lax, vmap, jit, grad
from jax.scipy.special import logit
from jax.nn import softmax
from functools import partial
from jax.random import PRNGKey, split

import inspect
import inspect as py_inspect
from rich import inspect as r_inspect
from rich import print as r_print

def print_source(fname):
    """Pretty-print the source code of a Python object using rich.

    Args:
        fname: the object whose source should be shown (function, class,
            method, module, ...). Despite the name, this is the object itself,
            not a string; it is passed straight to ``inspect.getsource``.

    Raises:
        OSError, TypeError: propagated from ``inspect.getsource`` when the
            object's source cannot be retrieved (e.g. built-ins).
    """
    # rich's print interprets "[...]" as console markup, so indexing
    # expressions such as ``A[i]`` would be silently consumed as style tags
    # (and a "[/...]" sequence could raise a MarkupError). Escape the text so
    # the source is displayed verbatim.
    from rich.markup import escape

    r_print(escape(py_inspect.getsource(fname)))

ModuleNotFoundError: No module named 'rich'

(sec:ssm-intro)=
# What are State Space Models?


A state space model (SSM)
is a partially observed Markov model,
in which the hidden state, $z_t$,
evolves over time according to a Markov process,
and we only see observations that are generated from the hidden state at each time step.



```{figure} /figures/SSM-AR-inputs.png
:scale: 100%
:name: ssm-ar

Illustration of an SSM as a graphical model.
```

```{figure} /figures/SSM-simplified.png
:scale: 100%
:name: ssm-simplifed

Illustration of a simplified SSM.
```

(sec:casino-ex)=
## Example: Casino HMM

We first create the "occasionally dishonest casino" model from {cite}`Durbin98`.



There are 2 hidden states (a fair die and a loaded die), each of which emits one of 6 possible observations (the faces of the die).

In [5]:
# State transition matrix: row i is the distribution over the next hidden
# state given current state i (each row sums to 1).
# State 0 = fair die, state 1 = loaded die (same order as the rows of B).
A = np.array([
    [0.95, 0.05],
    [0.10, 0.90]
])

# Observation (emission) matrix: row k is the distribution over the 6 die
# faces emitted while in hidden state k.
B = np.array([
    [1/6, 1/6, 1/6, 1/6, 1/6, 1/6], # fair die
    [1/10, 1/10, 1/10, 1/10, 1/10, 5/10] # loaded die
])

(nstates, nobs) = np.shape(B)

# Uniform initial state distribution, computed directly with numpy.
# (`normalize` is not imported or defined in any earlier cell, so calling it
# here fails with a NameError on a fresh kernel.)
pi = np.ones(nstates) / nstates

# Sanity checks: every distribution sums to 1 and the shapes agree.
assert A.shape == (nstates, nstates), "A must be (nstates, nstates)"
assert np.allclose(A.sum(axis=1), 1.0), "rows of A must sum to 1"
assert np.allclose(B.sum(axis=1), 1.0), "rows of B must sum to 1"
assert np.isclose(pi.sum(), 1.0), "pi must sum to 1"


