# Kalman Filter

Kalman filters are linear models for state estimation of dynamic systems [1].  They have been the <i>de facto</i> standard in many robotics and tracking/prediction applications because they are well suited for systems with uncertainty about an observable dynamic process.  They use an "observe, predict, correct" paradigm to extract information from an otherwise noisy signal. In Pyro, we can build differentiable Kalman filters using the [`pyro.contrib.tracking` library](http://docs.pyro.ai/en/dev/contrib.tracking.html#module-pyro.contrib.tracking.extended_kalman_filter).

## Dynamic process

To start, consider this simple motion model:

$$ X_{k+1} = FX_k + \mathbf{W}_k $$
$$ \mathbf{Z}_k = HX_k + \mathbf{V}_k $$

where $k$ is the time step, $X_k$ is the state (the signal we want to estimate), $\mathbf{Z}_k$ is the observed value at time step $k$, and $\mathbf{W}_k$ and $\mathbf{V}_k$ are independent noise processes (i.e. $\mathbb{E}[\mathbf{W}_k \mathbf{V}_j^T] = 0$ for all $j, k$), which we'll approximate as Gaussians. Note that both the state transition and the observation model are linear.
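To make the model concrete, here is a minimal sketch that simulates it for a scalar state using only the Python standard library. The function name `simulate_linear_system` and the parameters `q_std` and `r_std` (standard deviations of the two noise processes) are assumptions made for illustration; they are not part of Pyro.

```python
import random


def simulate_linear_system(F, H, x0, q_std, r_std, num_steps, seed=0):
    """Simulate X_{k+1} = F X_k + W_k and Z_k = H X_k + V_k for a scalar state.

    W_k and V_k are drawn as independent zero-mean Gaussians.
    """
    rng = random.Random(seed)
    states, observations = [], []
    x = x0
    for _ in range(num_steps):
        # noisy observation of the current state
        z = H * x + rng.gauss(0.0, r_std)
        states.append(x)
        observations.append(z)
        # linear state transition plus process noise
        x = F * x + rng.gauss(0.0, q_std)
    return states, observations


# a slowly drifting state observed through fairly noisy measurements
states, observations = simulate_linear_system(
    F=1.0, H=1.0, x0=0.0, q_std=0.1, r_std=0.5, num_steps=5)
```

Setting both noise levels to zero recovers the deterministic recursion $X_{k+1} = FX_k$, which is a handy sanity check.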

## Kalman Update
At each time step, we perform a prediction for the mean and covariance:
$$ \hat{X}_k = F\hat{X}_{k-1}$$
$$\hat{P}_k = FP_{k-1}F^T + Q$$
and a correction for the measurement:
$$ K_k = \hat{P}_k H^T(H\hat{P}_k H^T + R)^{-1}$$
$$ X_k = \hat{X}_k + K_k(z_k - H\hat{X}_k)$$
$$ P_k = (I-K_k H)\hat{P}_k$$

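where $\hat{X}_k$ and $\hat{P}_k$ are the predicted state mean and covariance, $X_k$ and $P_k$ are the corrected state estimate and its covariance, $K_k$ is the Kalman gain, and $Q$ and $R$ are the covariance matrices of the process noise $\mathbf{W}_k$ and the measurement noise $\mathbf{V}_k$, respectively.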
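As a rough illustration, the five equations above reduce to a few lines of arithmetic when the state is a scalar. The sketch below assumes that scalar case (so transposes and matrix inverses become plain multiplication and division); the function names are hypothetical.

```python
def kalman_predict(x, P, F, Q):
    """Prediction step: propagate the mean and covariance through the dynamics."""
    x_pred = F * x
    P_pred = F * P * F + Q
    return x_pred, P_pred


def kalman_correct(x_pred, P_pred, z, H, R):
    """Correction step: blend the prediction with the measurement z."""
    K = P_pred * H / (H * P_pred * H + R)   # Kalman gain
    x = x_pred + K * (z - H * x_pred)       # corrected mean
    P = (1.0 - K * H) * P_pred              # corrected covariance
    return x, P, K


# one full update with equal prior and measurement uncertainty
x_pred, P_pred = kalman_predict(x=0.0, P=1.0, F=1.0, Q=0.0)
x, P, K = kalman_correct(x_pred, P_pred, z=1.0, H=1.0, R=1.0)
```

With equal prior and measurement variance the gain is one half, so the corrected mean lands midway between the prediction and the measurement.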

For an in-depth derivation, see \[1\].

## Nonlinear Estimation: Extended Kalman Filter

What if our system is non-linear, e.g. in GPS navigation?  Consider this non-linear system:

$$ X_{k+1} = \mathbf{f}(X_k) + \mathbf{W}_k $$
$$ \mathbf{Z}_k = \mathbf{h}(X_k) + \mathbf{V}_k $$

Notice that $\mathbf{f}$ and $\mathbf{h}$ are now (smooth) non-linear functions.


The Extended Kalman Filter (EKF) attacks this problem by locally linearizing the non-linear functions via a [Taylor series expansion](https://en.wikipedia.org/wiki/Taylor_series):

$$ f(X_k, k) \approx f(x_k^R, k) + \mathbf{H}_k(X_k - x_k^R) + \cdots$$

where $\mathbf{H}_k$ is the Jacobian matrix at time $k$, $x_k^R$ is the previous optimal estimate, and we ignore the higher order terms.  At each time step, we compute a Jacobian conditioned on the previous predictions (this computation is handled for us under the hood), and use the result to perform a prediction and update.
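To see what the first-order expansion does in practice, the following sketch linearizes a range-and-bearing measurement function around a reference point. Everything here is an illustrative assumption: the function `range_bearing` is a made-up example, and the Jacobian is approximated with central finite differences, which is not necessarily how Pyro obtains it.

```python
import math


def numerical_jacobian(func, x, eps=1e-6):
    """Central-difference Jacobian of func at x; entry [i][j] is d func_i / d x_j."""
    n_out = len(func(x))
    jac = [[0.0] * len(x) for _ in range(n_out)]
    for j in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[j] += eps
        x_minus[j] -= eps
        f_plus, f_minus = func(x_plus), func(x_minus)
        for i in range(n_out):
            jac[i][j] = (f_plus[i] - f_minus[i]) / (2.0 * eps)
    return jac


def range_bearing(x):
    """Hypothetical non-linear function: position (px, py) -> (range, bearing)."""
    px, py = x
    return [math.hypot(px, py), math.atan2(py, px)]


def linearize(func, x_ref, x):
    """First-order Taylor approximation of func(x) around x_ref."""
    jac = numerical_jacobian(func, x_ref)
    f_ref = func(x_ref)
    return [f_ref[i] + sum(jac[i][j] * (x[j] - x_ref[j]) for j in range(len(x)))
            for i in range(len(f_ref))]


x_ref = [3.0, 4.0]
approx = linearize(range_bearing, x_ref, [3.01, 4.01])
exact = range_bearing([3.01, 4.01])
```

Close to `x_ref` the approximation and the exact value agree well; the gap grows with distance from the reference point, which is why the EKF re-linearizes at every step.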

Omitting the derivations, the modified predictions are now:
$$ \hat{X}_k \approx \mathbf{f}(X_{k-1}^R)$$
$$ \hat{P}_k = \mathbf{H}_\mathbf{f}(X_{k-1}^R)P_{k-1}\mathbf{H}_\mathbf{f}^T(X_{k-1}^R) + Q$$
and the updates are now:
$$ X_k \approx \hat{X}_k + K_k\big(z_k - \mathbf{h}(\hat{X}_k)\big)$$
$$ K_k = \hat{P}_k \mathbf{H}_\mathbf{h}^T(\hat{X}_k) \Big(\mathbf{H}_\mathbf{h}(\hat{X}_k)\hat{P}_k \mathbf{H}_\mathbf{h}^T(\hat{X}_k) + R_k\Big)^{-1} $$
$$ P_k = \big(I - K_k \mathbf{H}_\mathbf{h}(\hat{X}_k)\big)\hat{P}_k$$
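For intuition, here is a minimal scalar sketch of one EKF step. It assumes the caller supplies the derivatives `df` and `dh` of the dynamics and measurement functions explicitly, so the Jacobians are just numbers; the function name `ekf_step` and the example functions are hypothetical and unrelated to Pyro's `EKFState`.

```python
def ekf_step(x, P, z, f, df, h, dh, Q, R):
    """One scalar EKF step: predict through f, then correct with measurement z."""
    # prediction: propagate the mean through f, the covariance through its derivative
    x_pred = f(x)
    F_jac = df(x)
    P_pred = F_jac * P * F_jac + Q
    # correction: linearize h around the predicted mean
    H_jac = dh(x_pred)
    K = P_pred * H_jac / (H_jac * P_pred * H_jac + R)
    x_new = x_pred + K * (z - h(x_pred))
    P_new = (1.0 - K * H_jac) * P_pred
    return x_new, P_new


# identity dynamics with a quadratic measurement h(x) = x**2
x_new, P_new = ekf_step(
    x=1.0, P=1.0, z=1.0,
    f=lambda x: x, df=lambda x: 1.0,
    h=lambda x: x ** 2, dh=lambda x: 2.0 * x,
    Q=0.0, R=1.0)
```

When `f` and `h` are linear, the derivatives are constants and the step coincides with the ordinary Kalman update.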

In Pyro, all we need to do is create an `EKFState` object and use its `predict` and `update` methods. Pyro will do exact inference to compute the innovations and we will use SVI to learn a MAP estimate of the position and measurement covariances.

As an example, let's look at an object moving at near-constant velocity in 2-D in a discrete time space over 10 time steps.
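Before turning to Pyro, it may help to see the noise-free part of a constant-velocity model written out. This is a sketch under the assumption that the state is ordered as `[x, y, vx, vy]`, which matches how the initial state is built in the example; the helper names are hypothetical and this is not Pyro's `NcvContinuous` implementation.

```python
def ncv_transition_matrix(dt):
    """Transition matrix for a 2-D constant-velocity state [x, y, vx, vy]."""
    return [
        [1.0, 0.0, dt, 0.0],   # x  <- x + vx * dt
        [0.0, 1.0, 0.0, dt],   # y  <- y + vy * dt
        [0.0, 0.0, 1.0, 0.0],  # vx unchanged
        [0.0, 0.0, 0.0, 1.0],  # vy unchanged
    ]


def propagate(state, dt, num_steps):
    """Apply the transition matrix num_steps times, without process noise."""
    F = ncv_transition_matrix(dt)
    for _ in range(num_steps):
        state = [sum(F[i][j] * state[j] for j in range(4)) for i in range(4)]
    return state


# unit speed along the x axis, three steps of length 0.01
final_state = propagate([0.0, 0.0, 1.0, 0.0], dt=0.01, num_steps=3)
```

The "near" in near-constant velocity comes from the process noise added on top of this deterministic motion at each step.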

In [1]:
import os
import math

import torch
import pyro
from pyro.distributions import HalfCauchy, MultivariateNormal
from pyro.contrib.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, config_enumerate
from pyro.contrib.tracking.extended_kalman_filter import EKFState
from pyro.contrib.tracking.distributions import EKFDistribution
from pyro.contrib.tracking.dynamic_models import NcvContinuous
from pyro.contrib.tracking.measurements import PositionMeasurement

smoke_test = ('CI' in os.environ)
pyro.enable_validation(True)

In [2]:
# Discrete time
dt = 1 / 100
num_frames = 10
dim = 4

# Continuous model
ncv = NcvContinuous(dimension=dim, sa2=2.0)

# Truth trajectory
xs_truth = torch.zeros(num_frames, dim)
# initial direction
theta0_truth = 0.0
# initial state
xs_truth[0, :] = torch.tensor([0.0, 0.0,  math.cos(theta0_truth), math.sin(theta0_truth)])
for frame_num in range(1, num_frames):
    # sample independent process noise
    dx = pyro.sample('process_noise_{}'.format(frame_num), ncv.process_noise_dist(dt))
    xs_truth[frame_num, :] = ncv(xs_truth[frame_num-1, :], dt=dt) + dx

In [3]:
# Measurements
measurements = []
mean = torch.zeros(2)
# no correlations
cov = 1e-5 * torch.eye(2)
# sample independent measurement noise
dzs = pyro.sample('dzs', MultivariateNormal(mean, cov).expand((num_frames,)))
# compute measurement means
zs = xs_truth[:, :2] + dzs

In [32]:
def model(data):
    # TODO replace this with HalfNormal
    # scalar scales for the 4x4 state covariance (R) and the 2x2 measurement covariance (Q)
    R = pyro.sample('R', HalfCauchy(2e-6)) * torch.eye(4)
    Q = pyro.sample('Q', HalfCauchy(1e-6)) * torch.eye(2)
    # observe the measurements (fixed site name, so the model does not depend on a global loop index)
    pyro.sample('track', EKFDistribution(xs_truth[0], R, ncv,
                                         Q, time_steps=num_frames),
                obs=data)
    
guide = AutoDelta(model)

In [33]:
optim = pyro.optim.Adam({'lr': 2e-2})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

pyro.set_rng_seed(0)
pyro.clear_param_store()

for i in range(100 if not smoke_test else 2):
    loss = svi.step(zs)
    if not i % 10:
        print('loss: ', loss)

loss:  -10.265189170837402
loss:  -10.808286666870117
loss:  -11.267187118530273
loss:  -11.635580062866211
loss:  -11.912515640258789
loss:  -12.10242748260498
loss:  -12.216792106628418
loss:  -12.274310111999512
loss:  -12.297000885009766
loss:  -12.303607940673828
loss:  -12.305276870727539
loss:  -12.306243896484375
loss:  -12.307249069213867
loss:  -12.308155059814453
loss:  -12.308900833129883
loss:  -12.30950927734375
loss:  -12.310028076171875
loss:  -12.310478210449219
loss:  -12.310866355895996
loss:  -12.311205863952637


In [34]:
for i in pyro.get_param_store().get_all_param_names():
    print(i, pyro.param(i))

auto_R tensor(6.0387e-06, grad_fn=<AddBackward>)
auto_Q tensor(1.9711e-07, grad_fn=<AddBackward>)


In [35]:
R = guide()['R'] * torch.eye(4)
Q = guide()['Q'] * torch.eye(2)
ekf_dist = EKFDistribution(xs_truth[0], R, ncv, Q, time_steps=num_frames)
states = ekf_dist.get_states(zs)


In [36]:
# TRUTH
[x[:2] for x in xs_truth.numpy().tolist()]

[[0.0, 0.0],
 [0.010107642970979214, -6.847964868939016e-06],
 [0.020236091688275337, -0.0001418385945726186],
 [0.03030969947576523, -0.0001620639959583059],
 [0.04040909931063652, -9.933050023391843e-06],
 [0.05033167451620102, 0.00024324559490196407],
 [0.06021888926625252, 0.0006623838562518358],
 [0.0702277421951294, 0.0010909209959208965],
 [0.08015932142734528, 0.0015986886573955417],
 [0.0900019183754921, 0.0019451151601970196]]

In [37]:
# MEASUREMENT MEANS
zs.numpy().tolist()

[[-0.0047094551846385, 0.003680370980873704],
 [0.009187041781842709, 0.0009707469143904746],
 [0.018267622217535973, -0.003966068848967552],
 [0.026824595406651497, 0.003120339708402753],
 [0.039249252527952194, -0.0023127663880586624],
 [0.05123267322778702, -0.00259613711386919],
 [0.06542954593896866, -0.009974617511034012],
 [0.07483357191085815, -0.0030986773781478405],
 [0.07863657176494598, 0.0004923918750137091],
 [0.0964636355638504, -0.0014208999928086996]]

In [38]:
# STATE MEANS
[s.mean[:2].detach().numpy().tolist() for s in states]

[[-0.004560594912618399, 0.00356403854675591],
 [0.009187281131744385, 0.0009707475546747446],
 [0.018267542123794556, -0.003966068848967552],
 [0.026824623346328735, 0.00312033761292696],
 [0.0392492450773716, -0.0023127617314457893],
 [0.05123267322778702, -0.0025961389765143394],
 [0.06542954593896866, -0.009974615648388863],
 [0.07483357191085815, -0.003098679706454277],
 [0.07863657176494598, 0.000492393970489502],
 [0.0964636355638504, -0.0014208992943167686]]

In [11]:
# STATE COV
[s.cov[:2, :2].detach().numpy().tolist() for s in states]

[[[3.5885700526705477e-06, 0.0], [0.0, 3.5885700526705477e-06]],
 [[3.877793915307848e-06, 0.0], [0.0, 3.877793915307848e-06]],
 [[3.8778030102548655e-06, 0.0], [0.0, 3.8778030102548655e-06]],
 [[3.877803919749567e-06, 0.0], [0.0, 3.877803919749567e-06]],
 [[3.877803919749567e-06, 0.0], [0.0, 3.877803919749567e-06]],
 [[3.877803919749567e-06, 0.0], [0.0, 3.877803919749567e-06]],
 [[3.877803919749567e-06, 0.0], [0.0, 3.877803919749567e-06]],
 [[3.877803919749567e-06, 0.0], [0.0, 3.877803919749567e-06]],
 [[3.877803919749567e-06, 0.0], [0.0, 3.877803919749567e-06]],
 [[3.877803919749567e-06, 0.0], [0.0, 3.877803919749567e-06]]]

In [12]:
cov.numpy().tolist()

[[9.999999747378752e-06, 0.0], [0.0, 9.999999747378752e-06]]
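One way to summarize how close the filtered state means are to the truth is a root-mean-square position error. The sketch below copies the truth positions and state means printed above, rounded to six decimals; the helper name `position_rmse` is hypothetical.

```python
import math

# truth positions printed above, rounded to six decimals
truth = [
    [0.0, 0.0], [0.010108, -0.000007], [0.020236, -0.000142],
    [0.030310, -0.000162], [0.040409, -0.000010], [0.050332, 0.000243],
    [0.060219, 0.000662], [0.070228, 0.001091], [0.080159, 0.001599],
    [0.090002, 0.001945],
]

# filtered state means printed above, rounded to six decimals
state_means = [
    [-0.004561, 0.003564], [0.009187, 0.000971], [0.018268, -0.003966],
    [0.026825, 0.003120], [0.039249, -0.002313], [0.051233, -0.002596],
    [0.065430, -0.009975], [0.074834, -0.003099], [0.078637, 0.000492],
    [0.096464, -0.001421],
]


def position_rmse(estimates, reference):
    """Root-mean-square Euclidean distance between paired 2-D positions."""
    squared = [(e[0] - r[0]) ** 2 + (e[1] - r[1]) ** 2
               for e, r in zip(estimates, reference)]
    return math.sqrt(sum(squared) / len(squared))


rmse = position_rmse(state_means, truth)
print('position RMSE:', rmse)
```

The same helper can be applied to the measurements to check whether filtering reduced the error relative to the raw observations.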