<a href="https://colab.research.google.com/github/petergchang/sarkka-jax/blob/main/car_track.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.numpy as jnp
import jax.random as jr

In [10]:
A = jnp.array([[0.4, -0.2],
               [-0.2, 0.5]])
mean = jnp.array([0.5, 0.5])

In [9]:
B = jnp.linalg.cholesky(A)
B

DeviceArray([[ 0.6324555 ,  0.        ],
             [-0.31622776,  0.6324555 ]], dtype=float32)

In [16]:
key = jr.PRNGKey(0)
jr.multivariate_normal(key, mean, A, shape=(1000,), method='cholesky')

DeviceArray([[ 0.67499745,  0.29148975],
             [-0.40067995,  1.0965476 ],
             [-0.8792937 ,  1.6155363 ],
             ...,
             [ 0.7249069 ,  0.8972255 ],
             [ 0.83733207, -0.89884055],
             [ 1.1302555 ,  0.9071536 ]], dtype=float32)

In [20]:
def _atleast2d(*args):
    return tuple(jnp.atleast_2d(elem) for elem in args)

In [26]:
def generate_ssm(key, m_0, A, Q, H, R, steps):
    m_0 = jnp.atleast_1d(m_0)
    A, Q, H, R = _atleast2d(A, Q, H, R)
    
    M, N = m_0.shape[-1], R.shape[-1]
    states = []
    observations = []

    state = m_0
    for i in range(steps):
        key, rng = jr.split(key)
        state = A @ state + jr.multivariate_normal(rng, jnp.zeros(M), Q, method='cholesky')
        states.append(state)
        obs = H @ state + jr.multivariate_normal(rng, jnp.zeros(N), R, method='cholesky')
        observations.append(obs)
    
    return states, observations

In [27]:
M, N = 2, 1
m_0 = jnp.zeros(M)
P_0 = Q = jnp.array([[0.4, -0.2],
                     [-0.2, 0.5]])
A = jnp.zeros((M, M))
H = [0., 0.]
R = 0.5
key = jr.PRNGKey(0)
states, observations = generate_ssm(key, m_0, A, Q, H, R, 10000)

In [29]:
observations

[DeviceArray([-0.8849716], dtype=float32),
 DeviceArray([-0.41482458], dtype=float32),
 DeviceArray([0.34399548], dtype=float32),
 DeviceArray([0.15354186], dtype=float32),
 DeviceArray([-0.4586531], dtype=float32),
 DeviceArray([0.9473319], dtype=float32),
 DeviceArray([0.7342872], dtype=float32),
 DeviceArray([-0.5333644], dtype=float32),
 DeviceArray([-0.42764392], dtype=float32),
 DeviceArray([-1.6617413], dtype=float32),
 DeviceArray([-1.4761981], dtype=float32),
 DeviceArray([1.1549116], dtype=float32),
 DeviceArray([1.4489145], dtype=float32),
 DeviceArray([-1.7813579], dtype=float32),
 DeviceArray([0.5775596], dtype=float32),
 DeviceArray([0.24012971], dtype=float32),
 DeviceArray([1.4820312], dtype=float32),
 DeviceArray([1.1123037], dtype=float32),
 DeviceArray([0.19786133], dtype=float32),
 DeviceArray([-0.08700866], dtype=float32),
 DeviceArray([1.1052133], dtype=float32),
 DeviceArray([-0.74115896], dtype=float32),
 DeviceArray([-0.40711507], dtype=float32),
 DeviceArray([