<a href="https://colab.research.google.com/github/petergchang/sarkka-jax/blob/main/car_track.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
import jax.numpy as jnp
import jax.random as jr
from jax import lax

In [2]:
# Parameters of a 2-D Gaussian N(mean, A); A is the (symmetric) covariance
# matrix that is factorized and sampled from in the next cells.
mean = jnp.array([0.5, 0.5])
A = jnp.array(
    [[0.4, -0.2],
     [-0.2, 0.5]]
)



In [3]:
# Lower-triangular Cholesky factor of A, i.e. B @ B.T == A.
B = jnp.linalg.cholesky(A)
# JAX signals a failed factorization (non-positive-definite input) with NaNs
# instead of raising, so verify the reconstruction explicitly.
assert bool(jnp.allclose(B @ B.T, A)), "Cholesky factorization failed: A must be symmetric positive definite"
B

DeviceArray([[ 0.6324555 ,  0.        ],
             [-0.31622776,  0.6324555 ]], dtype=float32)

In [4]:
# Draw 1000 reproducible samples from N(mean, A); method='cholesky' selects the
# Cholesky factorization of the covariance for the sampler.
key = jr.PRNGKey(0)
jr.multivariate_normal(key, mean=mean, cov=A, shape=(1000,), method='cholesky')

DeviceArray([[ 0.67499745,  0.29148975],
             [-0.40067995,  1.0965476 ],
             [-0.8792937 ,  1.6155363 ],
             ...,
             [ 0.7249069 ,  0.8972255 ],
             [ 0.83733207, -0.89884055],
             [ 1.1302555 ,  0.9071536 ]], dtype=float32)

In [5]:
def _atleast2d(*args):
    return tuple(jnp.atleast_2d(elem) for elem in args)

In [6]:
def generate_ssm(key, m_0, A, Q, H, R, steps):
    """Simulate a linear-Gaussian state-space model with a Python loop.

    Starting from the fixed state x_0 = m_0, each step computes
        x_k = A x_{k-1} + q,  q ~ N(0, Q)
        y_k = H x_k + r,      r ~ N(0, R)

    Args:
        key: JAX PRNG key.
        m_0: initial state; promoted to at least 1-D.
        A: state-transition matrix; promoted to at least 2-D.
        Q: process-noise covariance; promoted to at least 2-D.
        H: observation matrix; promoted to at least 2-D.
        R: observation-noise covariance; promoted to at least 2-D.
        steps: number of time steps to simulate.

    Returns:
        (states, observations): Python lists of length `steps` holding
        x_1..x_steps and y_1..y_steps. The initial state m_0 is not included.
    """
    m_0 = jnp.atleast_1d(m_0)
    A, Q, H, R = _atleast2d(A, Q, H, R)

    M, N = m_0.shape[-1], R.shape[-1]
    states = []
    observations = []

    state = m_0
    for _ in range(steps):
        # Separate subkeys for the process and observation noise. Reusing one
        # key for both draws violates JAX's never-reuse-a-key rule and couples
        # the two noise sequences (they are identical when M == N and Q == R).
        key, state_key, obs_key = jr.split(key, 3)
        state = A @ state + jr.multivariate_normal(state_key, jnp.zeros(M), Q, method='cholesky')
        states.append(state)
        obs = H @ state + jr.multivariate_normal(obs_key, jnp.zeros(N), R, method='cholesky')
        observations.append(obs)

    return states, observations

In [7]:
# Experiment configuration: 2-D state, 1-D observation, fixed seed.
M, N = 2, 1
key = jr.PRNGKey(0)

m_0 = jnp.zeros(M)
# One matrix is used both as the process-noise covariance Q and as P_0
# (P_0 is not passed to generate_ssm below).
Q = jnp.array([[0.4, -0.2],
               [-0.2, 0.5]])
P_0 = Q
A = jnp.zeros((M, M))   # zero transition matrix: A @ state contributes nothing
H = [0., 0.]            # zero observation matrix: H @ state contributes nothing
R = 0.5                 # observation-noise variance

states, observations = generate_ssm(key, m_0, A, Q, H, R, 10000)

In [10]:
def generate_ssm(key, m_0, A, Q, H, R, steps):
    """Simulate a linear-Gaussian state-space model with `lax.scan`.

    Each scan step emits the current state x_k together with an observation
        y_k = H x_k + r,      r ~ N(0, R)
    and then advances the state as
        x_{k+1} = A x_k + q,  q ~ N(0, Q).
    The scan starts from x_0 = m_0, so states[0] is m_0 itself and
    observations[k] is generated from states[k].

    Args:
        key: JAX PRNG key.
        m_0: initial state; promoted to at least 1-D.
        A: state-transition matrix; promoted to at least 2-D.
        Q: process-noise covariance; promoted to at least 2-D.
        H: observation matrix; promoted to at least 2-D.
        R: observation-noise covariance; promoted to at least 2-D.
        steps: number of time steps to simulate.

    Returns:
        (states, observations): arrays stacked along a new leading axis of
        length `steps`.
    """
    m_0 = jnp.atleast_1d(m_0)
    A, Q, H, R = _atleast2d(A, Q, H, R)
    M, N = m_0.shape[-1], R.shape[-1]
    state_noise_mean = jnp.zeros(M)
    obs_noise_mean = jnp.zeros(N)

    def _step(state, rng):
        # Independent subkeys for the process and observation noise.
        rng_state, rng_obs = jr.split(rng, 2)
        next_state = A @ state + jr.multivariate_normal(rng_state, state_noise_mean, Q, method='cholesky')
        observation = H @ state + jr.multivariate_normal(rng_obs, obs_noise_mean, R, method='cholesky')
        return next_state, (state, observation)

    # One key per time step; scan consumes them in order.
    rngs = jr.split(key, steps)
    _, (states, observations) = lax.scan(_step, m_0, rngs)
    return states, observations

In [11]:
# Rerun the same experiment; `generate_ssm` now refers to the lax.scan
# implementation defined in the previous cell.
M, N = 2, 1                      # state / observation dimensions
A = jnp.zeros((M, M))            # zero transition matrix
H = [0., 0.]                     # zero observation matrix
Q = jnp.array([[0.4, -0.2],
               [-0.2, 0.5]])     # process-noise covariance
P_0 = Q                          # same matrix; not passed to generate_ssm
R = 0.5                          # observation-noise variance
m_0 = jnp.zeros(M)
key = jr.PRNGKey(0)
states, observations = generate_ssm(key, m_0, A, Q, H, R, 10000)

In [12]:
# With the lax.scan generator each state is emitted before it is advanced, so
# the first row is the initial state m_0 and there is one observation per state.
assert states.shape[0] == observations.shape[0]
assert bool(jnp.array_equal(states[0], m_0))
states

DeviceArray([[ 0.        ,  0.        ],
             [-0.5000623 , -0.02209828],
             [ 0.17737901,  0.69841117],
             ...,
             [ 1.093647  , -1.5127336 ],
             [-0.19785918,  0.9732304 ],
             [ 1.0864521 ,  0.43852913]], dtype=float32)