# Kalman Filter

Kalman filters are linear models for state estimation of dynamic systems [1]. They have been the <i>de facto</i> standard in many robotics and tracking/prediction applications because they are well suited to systems with uncertainty about an observable dynamic process. They use an "observe, predict, correct" paradigm to extract information from an otherwise noisy signal. In Pyro, we can build differentiable Kalman filters using the [`pyro.contrib.tracking` library](http://docs.pyro.ai/en/dev/contrib.tracking.html#module-pyro.contrib.tracking.extended_kalman_filter).

## Dynamic process

To start, consider this simple motion model:

$$ X_{k+1} = FX_k + \mathbf{W}_k $$
$$ \mathbf{Z}_k = HX_k + \mathbf{V}_k $$

where $k$ is the timestep, $X_k$ is the state (the signal we want to estimate) at timestep $k$, $\mathbf{Z}_k$ is the observed value at timestep $k$, and $\mathbf{W}_k$ and $\mathbf{V}_k$ are independent noise processes (i.e. $\mathbb{E}[w_k v_j^T] = 0$ for all $j, k$), which we'll approximate as Gaussians. Note that both the state transition and the observation model are linear.
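
To make the two equations concrete, here is a minimal sketch that simulates a scalar version of this model using only the Python standard library. The function name `simulate_linear_system` and the parameters `q_std` and `r_std` (standard deviations of $\mathbf{W}_k$ and $\mathbf{V}_k$) are illustrative assumptions and are not part of Pyro.

```python
import random


def simulate_linear_system(F, H, q_std, r_std, x0, num_steps, seed=0):
    """Simulate X_{k+1} = F X_k + W_k and Z_k = H X_k + V_k for a scalar state."""
    rng = random.Random(seed)
    xs, zs = [], []
    x = x0
    for _ in range(num_steps):
        xs.append(x)
        # noisy observation of the current state
        zs.append(H * x + rng.gauss(0.0, r_std))
        # propagate the state with independent process noise
        x = F * x + rng.gauss(0.0, q_std)
    return xs, zs


# a slowly drifting state observed through fairly noisy measurements
xs, zs = simulate_linear_system(F=1.0, H=1.0, q_std=0.1, r_std=0.5,
                                x0=0.0, num_steps=5)
```

The list `xs` holds the hidden states and `zs` the measurements a filter would actually see; a Kalman filter tries to recover `xs` from `zs`.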

## Kalman Update
At each time step, we perform a prediction for the mean and covariance:
$$ \hat{X}_k = F\hat{X}_{k-1}$$
$$\hat{P}_k = FP_{k-1}F^T + Q$$
and a correction for the measurement:
$$ K_k = \hat{P}_k H^T(H\hat{P}_k H^T + R)^{-1}$$
$$ X_k = \hat{X}_k + K_k(z_k - H\hat{X}_k)$$
$$ P_k = (I-K_k H)\hat{P}_k$$

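where $X$ is the state estimate (here, the position), $P$ is its covariance matrix, $K$ is the Kalman gain, and $Q$ and $R$ are the covariance matrices of the process noise $\mathbf{W}_k$ and the measurement noise $\mathbf{V}_k$, respectively. Hatted quantities are predictions made before the measurement $z_k$ is incorporated.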

For an in-depth derivation, see \[1\].
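
As a small illustration of the predict and correct equations above, the following sketch implements one filter step for a scalar state, where every matrix reduces to a number and the inverse becomes a division. The function name `kalman_step` is hypothetical and is not a Pyro API.

```python
def kalman_step(x_prev, p_prev, z, F, H, Q, R):
    """One predict/correct cycle of a scalar Kalman filter."""
    # prediction of the mean and covariance
    x_pred = F * x_prev
    p_pred = F * p_prev * F + Q
    # correction using the measurement z
    K = p_pred * H / (H * p_pred * H + R)
    x_new = x_pred + K * (z - H * x_pred)
    p_new = (1.0 - K * H) * p_pred
    return x_new, p_new, K


# prior: mean 0 with variance 1; measurement z = 1 with noise variance R = 1
x_new, p_new, K = kalman_step(x_prev=0.0, p_prev=1.0, z=1.0,
                              F=1.0, H=1.0, Q=0.0, R=1.0)
```

With equal prior and measurement variances the gain is 0.5, so the corrected mean lands halfway between the prediction and the measurement, and the variance is halved.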

## Nonlinear Estimation: Extended Kalman Filter

What if our system is non-linear, e.g. in GPS navigation? Consider this non-linear system:

$$ X_{k+1} = \mathbf{f}(X_k) + \mathbf{W}_k $$
$$ \mathbf{Z}_k = \mathbf{h}(X_k) + \mathbf{V}_k $$

Notice that $\mathbf{f}$ and $\mathbf{h}$ are now (smooth) non-linear functions.


The Extended Kalman Filter (EKF) attacks this problem by locally linearizing the system via a [Taylor series expansion](https://en.wikipedia.org/wiki/Taylor_series) and then applying the Kalman filter equations to the linearized system.

$$ f(X_k, k) \approx f(x_k^R, k) + \mathbf{H}_k(X_k - x_k^R) + \cdots$$

where $\mathbf{H}_k$ is the Jacobian matrix at time $k$, $x_k^R$ is the previous optimal estimate, and we ignore the higher order terms. At each time step, we compute a Jacobian conditioned on the previous predictions (this computation is handled for us under the hood), and use the result to perform a prediction and update.
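
To show what such a Jacobian looks like, here is a minimal sketch that approximates it with central finite differences for a range-and-bearing measurement function. This is only a stand-in for illustration: it is not how Pyro computes the Jacobian, and the names `numerical_jacobian` and `range_bearing` are hypothetical.

```python
import math


def numerical_jacobian(func, x, eps=1e-6):
    """Central-difference Jacobian of func: R^n -> R^m at x (lists of floats)."""
    num_outputs = len(func(x))
    jac = [[0.0] * len(x) for _ in range(num_outputs)]
    for j in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[j] += eps
        x_minus[j] -= eps
        f_plus, f_minus = func(x_plus), func(x_minus)
        for i in range(num_outputs):
            jac[i][j] = (f_plus[i] - f_minus[i]) / (2 * eps)
    return jac


def range_bearing(state):
    """Non-linear measurement: distance and angle of a 2-D position."""
    px, py = state
    return [math.hypot(px, py), math.atan2(py, px)]


x_ref = [3.0, 4.0]  # point around which we linearize
H_k = numerical_jacobian(range_bearing, x_ref)
```

For this function the analytic Jacobian at `x_ref` is `[[px/r, py/r], [-py/r**2, px/r**2]]` with `r = 5`, so `H_k` should be close to `[[0.6, 0.8], [-0.16, 0.12]]`.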

Omitting the derivations, the predictions become:
$$ \hat{X}_k \approx \mathbf{f}(X_{k-1}^R)$$
$$ \hat{P}_k = \mathbf{H}_\mathbf{f}(X_{k-1})P_{k-1}\mathbf{H}_\mathbf{f}^T(X_{k-1}) + Q$$
and the updates are now:
$$ X_k \approx \hat{X}_k + K_k\big(z_k - \mathbf{h}(\hat{X}_k)\big)$$
$$ K_k = \hat{P}_k \mathbf{H}_\mathbf{h}^T(\hat{X}_k) \Big(\mathbf{H}_\mathbf{h}(\hat{X}_k)\hat{P}_k \mathbf{H}_\mathbf{h}^T(\hat{X}_k) + R_k\Big)^{-1} $$
$$ P_k = \big(I - K_k \mathbf{H}_\mathbf{h}(\hat{X}_k)\big)\hat{P}_k$$
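
The EKF equations can be sketched for a scalar state, where the Jacobians of $\mathbf{f}$ and $\mathbf{h}$ reduce to ordinary derivatives. The function `ekf_step` and its arguments `df` and `dh` (the derivatives, supplied by hand here) are assumptions made for illustration; they are not Pyro's implementation.

```python
def ekf_step(x_prev, p_prev, z, f, df, h, dh, Q, R):
    """One predict/update cycle of a scalar extended Kalman filter."""
    # predict through the non-linear dynamics, linearized at the previous estimate
    x_pred = f(x_prev)
    jac_f = df(x_prev)
    p_pred = jac_f * p_prev * jac_f + Q
    # update, with the measurement function linearized at the prediction
    jac_h = dh(x_pred)
    K = p_pred * jac_h / (jac_h * p_pred * jac_h + R)
    x_new = x_pred + K * (z - h(x_pred))
    p_new = (1.0 - K * jac_h) * p_pred
    return x_new, p_new


# static state observed through the non-linear measurement h(x) = x**2
x_new, p_new = ekf_step(
    x_prev=1.0, p_prev=0.5, z=1.44,
    f=lambda x: x, df=lambda x: 1.0,
    h=lambda x: x ** 2, dh=lambda x: 2.0 * x,
    Q=0.0, R=0.1,
)
```

When `f` and `h` are linear, `df` and `dh` are constants and this step reduces to the ordinary Kalman update given earlier.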

In Pyro, all we need to do is create an `EKFState` object and use its `predict` and `update` methods. Pyro will do exact inference to compute the innovations and we will use SVI to learn a MAP estimate of the position and measurement covariances.

As an example, let's look at an object moving at near-constant velocity in 2-D in a discrete time space over 10 time steps.
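
Before turning to Pyro, the following noise-free sketch shows what "constant velocity" means for a state laid out as `[x, y, vx, vy]`: positions advance by velocity times the time step and velocities stay fixed. The helper `ncv_transition` is hypothetical and only mirrors the idea; the Pyro model used below additionally adds process noise, which is what makes the motion near-constant velocity.

```python
def ncv_transition(state, dt):
    """Advance [x, y, vx, vy] by dt assuming constant velocity (no noise)."""
    x, y, vx, vy = state
    return [x + vx * dt, y + vy * dt, vx, vy]


dt = 0.01
state = [0.0, 0.0, 1.0, 0.0]  # start at the origin, moving along +x
trajectory = [state]
for _ in range(9):
    state = ncv_transition(state, dt)
    trajectory.append(state)
```

After nine steps of length `dt` the object has moved `9 * dt` along the x axis and its velocity is unchanged.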

In [1]:
import os
import math

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import MultivariateNormal
from pyro.contrib.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, config_enumerate
from pyro.contrib.tracking.extended_kalman_filter import EKFState
from pyro.contrib.tracking.distributions import EKFDistribution
from pyro.contrib.tracking.dynamic_models import NcvContinuous
from pyro.contrib.tracking.measurements import PositionMeasurement

smoke_test = ('CI' in os.environ)
pyro.enable_validation(True)

In [2]:
# Discrete time
dt = 1 / 100
num_frames = 10
dim = 4

# Continuous model
ncv = NcvContinuous(dimension=dim, sa2=2.0)

# Truth trajectory
xs_truth = torch.zeros(num_frames, dim)
# initial direction
theta0_truth = 0.0
# initial state
xs_truth[0, :] = torch.tensor([0.0, 0.0, math.cos(theta0_truth), math.sin(theta0_truth)])
for frame_num in range(1, num_frames):
    # sample independent process noise
    dx = pyro.sample('process_noise_{}'.format(frame_num), ncv.process_noise_dist(dt))
    xs_truth[frame_num, :] = ncv(xs_truth[frame_num-1, :], dt=dt) + dx

In [3]:
# Measurements
measurements = []
mean = torch.zeros(2)
# no correlations
cov = 1e-5 * torch.eye(2)
# sample independent measurement noise
dzs = pyro.sample('dzs', MultivariateNormal(mean, cov).expand((num_frames,)))
# compute measurement means
zs = xs_truth[:, :2] + dzs

In [4]:
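def model(data):
    # R scales the 4x4 covariance over the position-velocity state
    R = pyro.sample('R', dist.LogNormal(1., 1.)) * torch.eye(4)
    # Q scales the 2x2 covariance of the position measurements
    Q = pyro.sample('Q', dist.LogNormal(1., 1.)) * torch.eye(2)
    # observe the measurements (fixed site name, so the model does not
    # depend on a loop variable defined outside of it)
    pyro.sample('track', EKFDistribution(xs_truth[0], R, ncv,
                                         Q, time_steps=num_frames),
                obs=data)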
    
guide = AutoDelta(model)

In [5]:
optim = pyro.optim.Adam({'lr': 0.2})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

pyro.set_rng_seed(0)
pyro.clear_param_store()

for i in range(50 if not smoke_test else 2):
    loss = svi.step(zs)
    print('loss: ', loss)

loss:  47.574042201042175
loss:  45.49557280540466
loss:  43.53862643241882
loss:  41.70700752735138
loss:  40.00416624546051
loss:  38.43302273750305
loss:  36.99586057662964
loss:  35.69414782524109
loss:  34.52836811542511
loss:  33.49785327911377
loss:  32.60059905052185
loss:  31.833141565322876
loss:  31.190431833267212
loss:  30.66579782962799
loss:  30.25099265575409
loss:  29.936315059661865
loss:  29.710856199264526
loss:  29.562790632247925
loss:  29.479783415794373
loss:  29.449408411979675
loss:  29.45954966545105
loss:  29.498809576034546
loss:  29.55682396888733
loss:  29.624511241912842
loss:  29.694225788116455
loss:  29.759836196899414
loss:  29.816715240478516
loss:  29.86168074607849
loss:  29.892847537994385
loss:  29.909519910812378
loss:  29.9119553565979
loss:  29.901209592819214
loss:  29.878931760787964
loss:  29.84717559814453
loss:  29.808237075805664
loss:  29.764477729797363
loss:  29.718207359313965
loss:  29.671549558639526
loss:  29.62637233734131
loss:

In [8]:
R = guide()['R'] * torch.eye(4)
Q = guide()['Q'] * torch.eye(2)
ekf_dist = EKFDistribution(xs_truth[0], R, ncv, Q, time_steps=num_frames)
states, innovations = ekf_dist.get_states(zs)


[tensor([-0.0016,  0.0008,  1.0000,  0.0000], grad_fn=<ThSubBackward>),
 tensor([ 0.0897,  0.0022, -0.1193,  0.0016], grad_fn=<ThSubBackward>),
 tensor([ 0.0178,  0.0085, -0.0684,  0.0067], grad_fn=<ThSubBackward>),
 tensor([ 0.0235,  0.0039,  0.0116, -0.0055], grad_fn=<ThSubBackward>),
 tensor([ 0.0348, -0.0011,  0.0113, -0.0050], grad_fn=<ThSubBackward>),
 tensor([0.0515, 0.0055, 0.0171, 0.0075], grad_fn=<ThSubBackward>),
 tensor([ 0.0572, -0.0029,  0.0048, -0.0096], grad_fn=<ThSubBackward>),
 tensor([ 0.0690, -0.0048,  0.0123, -0.0013], grad_fn=<ThSubBackward>),
 tensor([ 0.0800, -0.0019,  0.0110,  0.0032], grad_fn=<ThSubBackward>),
 tensor([ 0.0828, -0.0083,  0.0021, -0.0071], grad_fn=<ThSubBackward>)]

In [None]:
#TRUTH
[x[:2] for x in xs_truth.numpy().tolist()]

In [30]:
# STATE MEANS
[s.mean[:2].detach().numpy().tolist() for s in states]

[[-0.0016402173787355423, 0.0008416093769483268],
 [0.08970636129379272, 0.0021629068069159985],
 [0.01775578409433365, 0.008512821048498154],
 [0.023472897708415985, 0.003906172700226307],
 [0.03478563576936722, -0.0011203686008229852],
 [0.051492709666490555, 0.005487498827278614],
 [0.05721009895205498, -0.002855192869901657],
 [0.06895408779382706, -0.004760378506034613],
 [0.08001435548067093, -0.0019312351942062378],
 [0.08280424773693085, -0.008302029222249985]]

In [29]:
# STATE COV
[s.cov[:2, :2].detach().numpy().tolist() for s in states]

[[[0.07635901868343353, 0.0], [0.0, 0.07635901868343353]],
 [[0.0887937992811203, 0.0], [0.0, 0.0887937992811203]],
 [[0.09159314632415771, 0.0], [0.0, 0.09159314632415771]],
 [[0.09162957221269608, 0.0], [0.0, 0.09162957221269608]],
 [[0.09163796901702881, 0.0], [0.0, 0.09163796901702881]],
 [[0.09163802117109299, 0.0], [0.0, 0.09163802117109299]],
 [[0.09163804352283478, 0.0], [0.0, 0.09163804352283478]],
 [[0.09163803607225418, 0.0], [0.0, 0.09163803607225418]],
 [[0.09163803607225418, 0.0], [0.0, 0.09163803607225418]],
 [[0.09163803607225418, 0.0], [0.0, 0.09163803607225418]]]

In [None]:
# INNOVATION MEAN
[i[0].detach().numpy().tolist() for i in innovations]

In [21]:
# INNOVATION COV
[i[1].detach().numpy().tolist() for i in innovations]

[[[0.46201056241989136, 0.0], [0.0, 0.46201056241989136]],
 [[1.2050361633300781, 0.0], [0.0, 1.2050361633300781]],
 [[1.888930082321167, 0.0], [0.0, 1.888930082321167]],
 [[1.9029815196990967, 0.0], [0.0, 1.9029815196990967]],
 [[1.9062511920928955, 0.0], [0.0, 1.9062511920928955]],
 [[1.9062671661376953, 0.0], [0.0, 1.9062671661376953]],
 [[1.9062767028808594, 0.0], [0.0, 1.9062767028808594]],
 [[1.9062762260437012, 0.0], [0.0, 1.9062762260437012]],
 [[1.9062762260437012, 0.0], [0.0, 1.9062762260437012]],
 [[1.9062762260437012, 0.0], [0.0, 1.9062762260437012]]]