# Tracking an object using the extended / unscented Kalman filter





Consider an object moving in $\mathbb{R}^2$.
We assume that we observe a noisy version of its location at each time step.
We want to track the object and possibly forecast its future motion.
We now show how to do this using a simple nonlinear Gaussian state-space model (SSM),
combined with various extensions of the Kalman filter algorithm.
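The "unscented" variant named in the title avoids Jacobians. It pushes a small set of deterministically chosen sigma points through the nonlinearity and refits a Gaussian to the results. Below is a minimal JAX sketch of that unscented transform, using the common $(\alpha, \beta, \kappa)$ parameterization. The function name `unscented_transform` and its default values are illustrative assumptions, not the dynamax API.

```python
import jax
import jax.numpy as jnp


def unscented_transform(g, m, P, alpha=1.0, beta=2.0, kappa=0.0):
    """Approximate the mean and covariance of g(x) for x ~ N(m, P) using 2n+1 sigma points.

    Hypothetical helper; alpha/beta/kappa defaults are illustrative assumptions.
    """
    n = m.shape[0]
    lam = alpha**2 * (n + kappa) - n
    # Lower-triangular L with L @ L.T == (n + lam) * P.
    L = jnp.linalg.cholesky((n + lam) * P)
    # Sigma points: the mean, then the mean plus/minus each column of L.
    sigmas = jnp.concatenate([m[None, :], m + L.T, m - L.T], axis=0)
    # Mean weights wm and covariance weights wc differ only in the first entry.
    w = jnp.full(2 * n + 1, 1.0 / (2.0 * (n + lam)))
    wm = w.at[0].set(lam / (n + lam))
    wc = w.at[0].set(lam / (n + lam) + (1.0 - alpha**2 + beta))
    # Push every sigma point through the nonlinearity.
    ys = jax.vmap(g)(sigmas)
    mean = wm @ ys
    diffs = ys - mean
    cov = (wc[:, None] * diffs).T @ diffs
    return mean, cov


# Sanity check: for an affine map the transform is exact,
# i.e. mean == A m + b and cov == A P A^T.
A = jnp.array([[1.0, 2.0], [0.0, 1.0]])
b = jnp.array([0.5, -1.0])
m = jnp.array([1.0, 2.0])
P = jnp.array([[1.0, 0.3], [0.3, 2.0]])
mean, cov = unscented_transform(lambda x: A @ x + b, m, P)
```

In a UKF, the predict step applies this transform to the dynamics $f$ and then adds $Q$. The update step applies it to the observation function $h$.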

Let the hidden state represent
the position of the object,
$z_t = \begin{pmatrix} u_t & v_t \end{pmatrix}^\top$.
(We use $u$ and $v$ for the two coordinates
to avoid confusion with the state and observation variables.)
We assume the following nonlinear dynamics:

\begin{align}
z_t &= f(z_{t-1}) + q_t \\
f(\begin{pmatrix} u \\ v \end{pmatrix})
 &= \begin{pmatrix} u + 0.4 \sin(v) \\ v + 0.4 \cos(u) \end{pmatrix}
\end{align}

where $q_t \in \mathbb{R}^2$ is the process noise, which we assume is Gaussian:
$q_t \sim N(0,Q)$.
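To make the transition equation concrete, here is a minimal plain-JAX sketch. It samples a state trajectory by repeatedly applying $z_t = f(z_{t-1}) + q_t$. It assumes the coefficient 0.4 from the `dynamics_function` in the model-creation code and $Q = 0.001 I$ from its `dynamics_covariance`. `simulate_states` is a hypothetical helper. For simplicity it starts from a fixed point rather than drawing the initial state from a Gaussian.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

STEP = 0.4  # assumed coefficient, matching the dynamics_function in the model code


def f(z):
    """Nonlinear dynamics: (u, v) -> (u + STEP*sin(v), v + STEP*cos(u))."""
    u, v = z[0], z[1]
    return jnp.array([u + STEP * jnp.sin(v), v + STEP * jnp.cos(u)])


def simulate_states(key, z0, Q, num_timesteps):
    """Sample z_1, ..., z_T from z_t = f(z_{t-1}) + q_t with q_t ~ N(0, Q)."""
    keys = jr.split(key, num_timesteps)

    def step(z_prev, k):
        # Draw process noise and apply one step of the dynamics.
        q = jr.multivariate_normal(k, jnp.zeros(z_prev.shape[0]), Q)
        z = f(z_prev) + q
        return z, z  # (carry, per-step output)

    _, zs = jax.lax.scan(step, z0, keys)
    return zs


Q = 0.001 * jnp.eye(2)
zs = simulate_states(jr.PRNGKey(0), jnp.array([1.5, 0.0]), Q, num_timesteps=100)
print(zs.shape)  # (100, 2)
```

`jax.lax.scan` threads the previous state through the loop as the carry. This keeps the simulation compilable while still expressing the one-step recursion directly.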




At each discrete time step we
observe the location corrupted by Gaussian noise,
so the observation model is

\begin{align}
y_t &= h(z_t) + r_t \\
h(z) &= z 
\end{align}

where $r_t \sim N(0,R)$ is the observation noise.
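The extended Kalman filter handles a nonlinear $f$ (and, in general, $h$) by linearizing it with Jacobians around the current estimate. It then runs the ordinary Kalman predict/update equations. The sketch below shows one such step in plain JAX, using `jax.jacfwd` for the Jacobians. It is a textbook EKF written for illustration, not the dynamax implementation. It assumes $Q = 0.001 I$, $R = 0.05 I$ and the 0.4 coefficient from the model-creation code. `ekf_step` and the observation `y1` are hypothetical.

```python
import jax
import jax.numpy as jnp


def f(z, step=0.4):
    # Dynamics, with the 0.4 coefficient used in the model-creation code (assumed).
    return z + step * jnp.array([jnp.sin(z[1]), jnp.cos(z[0])])


def h(z):
    # Identity observation model: y_t = z_t + r_t.
    return z


def ekf_step(m, P, y, Q, R):
    """One predict/update step of a textbook extended Kalman filter (hypothetical helper)."""
    # Predict: linearize f around the previous filtered mean.
    F = jax.jacfwd(f)(m)
    m_pred = f(m)
    P_pred = F @ P @ F.T + Q
    # Update: linearize h around the predicted mean.
    H = jax.jacfwd(h)(m_pred)
    S = H @ P_pred @ H.T + R  # innovation covariance
    K = jnp.linalg.solve(S, H @ P_pred).T  # gain P_pred H^T S^{-1} (S and P_pred symmetric)
    m_new = m_pred + K @ (y - h(m_pred))
    P_new = P_pred - K @ S @ K.T
    return m_new, P_new


Q = 0.001 * jnp.eye(2)
R = 0.05 * jnp.eye(2)
m0, P0 = jnp.array([1.5, 0.0]), jnp.eye(2)
y1 = jnp.array([1.6, 0.3])  # hypothetical observation, for illustration only
m1, P1 = ekf_step(m0, P0, y1, Q, R)
```

Because $h$ is the identity here, its Jacobian is the identity matrix, so only the dynamics are actually linearized.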




# Setup

```python
try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -qq git+https://github.com/probml/dynamax.git
    import dynamax
```

```python
# Silence WARNING:root:The use of `check_types` is deprecated and does not have any effect.
# https://github.com/tensorflow/probability/issues/1523
import logging

logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    def filter(self, record):
        return "check_types" not in record.getMessage()


logger.addFilter(CheckTypesFilter())
```

```python
from jax import numpy as jnp
from jax import random as jr
from matplotlib import pyplot as plt

from dynamax.plotting import plot_inference, plot_uncertainty_ellipses
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, NonlinearGaussianSSM
```


# Create the model

```python
initial_mean = jnp.array([1.5, 0.0])
state_dim = initial_mean.shape[0]
emission_dim = state_dim  # h(z) = z, so observations live in the same space as states

# Parameters of the nonlinear Gaussian SSM: f, Q, h, R and the initial distribution.
nlgssm = ParamsNLGSSM(
    initial_mean=initial_mean,
    initial_covariance=jnp.eye(state_dim),
    dynamics_function=lambda x: x + 0.4 * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])]),
    dynamics_covariance=jnp.eye(state_dim) * 0.001,
    emission_function=lambda x: x,
    emission_covariance=jnp.eye(emission_dim) * 0.05,
)

# The model object (not the parameter container) provides sample();
# the parameters are passed to it explicitly.
model = NonlinearGaussianSSM(state_dim, emission_dim)
```

# Sample some data from the model

```python
states, emissions = model.sample(nlgssm, jr.PRNGKey(0), num_timesteps=100)

all_figures = {}
fig, ax = plt.subplots()
true_title = "Noisy observations from hidden trajectory"
_ = plot_inference(states, emissions, ax=ax, title=true_title)
all_figures["ekf_spiral_true"] = fig
```

# Perform online filtering

