# Tracking Sensor Bias

We want to compute the joint posterior over sensors' biases in a 2-D tracking setting.

In [81]:
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.optim import Adam

import pyro
import pyro.distributions as dist

import funsor
import funsor.pyro
import funsor.distributions as f_dist
import funsor.ops as ops
from funsor.pyro.convert import dist_to_funsor, mvn_to_funsor, matrix_and_mvn_to_funsor, tensor_to_funsor
from funsor.interpreter import interpretation, reinterpret
from funsor.optimizer import apply_optimizer
from funsor.terms import lazy, eager_or_die
from funsor.domains import bint, reals
from funsor.torch import Tensor, Variable
from funsor.sum_product import sequential_sum_product

import matplotlib.pyplot as plt

Simulate some synthetic data:

In [78]:
num_sensors = 5
num_frames = 100

# simulate biased sensors
sensors  = []
for _ in range(num_sensors):
    bias = 0.5 * torch.randn(2)
    sensors.append(bias)

# simulate a single track
# TODO heterogeneous time
partial_obs = []
z = 10 * torch.rand(2)  # initial state
v = 2 * torch.randn(2)  # velocity
for t in range(num_frames):
    # Advance latent state.
    z += v + 0.1 * torch.randn(2)
#     z.clamp_(min=0, max=10)  # keep in the box
    
    # Observe via a random sensor.
    sensor_id = pyro.sample('id', dist.Categorical(torch.ones(num_sensors)))
    x = z - sensors[sensor_id]
    partial_obs.append({"sensor_id": sensor_id, "x": x})
    
# simulate all tracks
full_observations = []
z = 10 * torch.rand(5, 2)  # initial state
v = 2 * torch.randn(5, 2)  # velocity
for t in range(num_frames):
    # Advance latent state.
    z += v + 0.1 * torch.randn(5, 2)
#     z.clamp_(min=0, max=10)  # keep in the box
    
    # Observe via a random sensor.
    x = z - torch.stack(sensors)
    full_observations.append(x)
full_observations = torch.stack(full_observations)
assert full_observations.shape == (num_frames, 5, 2)
full_observations = Tensor(full_observations, OrderedDict([("time", bint(num_frames))]))

Now let's set up a tracking problem in Funsor. We start by modeling the biases of each sensor.

In [79]:
# TODO transform this to cholesky decomposition
# print(bias_cov.shape)
# bias_cov = bias_cov @ bias_cov.t()
# create a joint Gaussian over biases

# covs = [torch.eye(2, requires_grad=True) for i in range(num_sensors)]
# bias_dist = 0.
# for i in range(num_sensors):
#     bias += funsor.pyro.convert.mvn_to_funsor(
#         dist.MultivariateNormal(torch.zeros(2), covs[i]),
# #         event_dims=("pos",),
# #         real_inputs=OrderedDict([("bias_{}".format(i), reals(2))])
#         real_inputs=OrderedDict([("bias", reals(2))])
#     )(value="bias_{}".format(i))
# bias_dist.__dict__

# we can't write bias_dist as a sum of mvns because affine transformation
# of mvns is not supported yet.  instead we will combine all the sensors
# into a giant tensor
bias_scales = torch.ones(2, requires_grad=True)  # This can be learned
bias_dist = funsor.pyro.convert.mvn_to_funsor(
    dist.MultivariateNormal(
        torch.zeros(num_sensors * 2),
        bias_scales.expand(num_sensors, 2).reshape(-1).diag_embed()
    ),
    real_inputs=OrderedDict([("bias", reals(num_sensors, 2))])
)
bias_dist.__dict__

{'inputs': OrderedDict([('bias', reals(5, 2))]),
 'output': reals(),
 'fresh': frozenset(),
 'bound': frozenset(),
 'deltas': (),
 'discrete': Tensor(-9.189385414123535, OrderedDict(), 'real'),
 'gaussian': Gaussian(..., ((bias, reals(5, 2)),)),
 '_ast_values': ((),
  Tensor(-9.189385414123535, OrderedDict(), 'real'),
  Gaussian(..., ((bias, reals(5, 2)),)))}

Set up the filter in funsor.

In [57]:
%pdb off

Automatic pdb calling has been turned OFF


In [58]:
from pdb import set_trace as bb

In [77]:
# TODO
# this can be parameterized by a lower dimensional vector 
# to learn a structured transition matrix
# eg a GP with a matern v=3/2 kernel
# see paper for details 

# transition matrix from discretization as in 
# http://webee.technion.ac.il/people/shimkin/Estimation09/ch8_target.pdf
T = 1.  # timestep
trans_matrix_noise = torch.randn(1, requires_grad=True)
trans_dist_cov = torch.tensor([[1./3 * T ** 3, 0.5 * T ** 2],
                                  [0.5 * T ** 2, T]]) * trans_matrix_noise ** 2
transition_matrix = torch.randn(2, 2, requires_grad=True)
transition_matrix = torch.tensor([[1., T],
                                  [0, 1]])

class HMM(nn.Module):
    def __init__(self, num_sensors):
        init_dist = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
        transition_dist = torch.distributions.MultivariateNormal(
            torch.zeros(2), trans_dist_cov)
        observation_matrix = torch.eye(2) + 0.2 * torch.randn(2, 2)
        # this is the bias for all the sensors
        biases = torch.zeros(num_sensors, 2, requires_grad=True)
        observation_dist = torch.distributions.MultivariateNormal(
            bias,
            torch.eye(2))

        self.init = dist_to_funsor(init_dist)(value="state")
        # inputs are the previous state ``state`` and the next state
        self.trans = matrix_and_mvn_to_funsor(transition_matrix, transition_dist,
                                         ("time",), "state", "state(time=1)")
        self.obs = matrix_and_mvn_to_funsor(observation_matrix, observation_dist,
                                       ("time",), "state(time=1)", "value")
        super(HMM, self).__init__()
    
    def forward(track):
        sensor_ids = torch.tensor([frame["sensor_id"] for frame in track])
        bias = self.biases[sensor_ids]
        # we add bias to the observation as a global variable
        # single interleaved track
        sensor_ids = Tensor(
            torch.tensor([frame["sensor_id"] for frame in track]),
            OrderedDict([("time", bint(num_frames))]),
            dtype=len(sensors)
        )
        
#         biased_observations = Tensor(
#             torch.stack([frame["x"] for frame in track]),
#             OrderedDict([("time", bint(num_frames))])
#         )

        # incorporate sensor id in the observation by creating
        # a free variable that has the signature
        # inputs: bias of shape (num_sensors, 2), sensor_ids
        # outputs shape 2
        bias = Variable("bias", reals(num_sensors, 2))#  [sensor_ids]
        debiased_observations = track - bias
    #     debiased_observations = biased_observations
        # this indexing pattern is not implemented to sub into a Gaussian
        # https://github.com/pyro-ppl/funsor/pull/220
        # instead, we can use matrix_and_mvn_to_funsor and index  the proper latents and just 
        # observe naively

    #     bias = Variable("bias", reals(num_sensors, 2))
    #     debiased_observations = all_tracks - bias
        obs = obs(value=debiased_observations)
    #     print(obs)

        logp = trans + obs + bias_dist

        bb()
        
        # collapse out the time variable
        # TODO this can only handle homogeneous funsor types
        logp = sequential_sum_product(ops.logaddexp, ops.add,
                                      logp, "time", {"state": "state(time=1)"})
        logp += init
        # logaddexp across all states
        logp = logp.reduce(ops.logaddexp, frozenset(["state", "state(time=1)"]))
    #     # ensure we collapsed out the right dim
    #     assert logp.data.dim() == 0
        return logp

## Inference

Finally we have a result that is a joint Gaussian over the biases.
We can
1. optimize all parameters to maximize `result`
2. estimate the joint distribution over all bias parameters.

In [80]:
num_epochs = 200
params = [bias_scales]
# params.append(transition_matrix)
optim = Adam(params, lr=1e-3)
model = HMM(num_sensors)
for i in range(num_epochs):
    optim.zero_grad()
    with interpretation(lazy):
        log_prob = apply_optimizer(model(partial_obs))
    loss = -reinterpret(log_prob).data
    loss.backward()
    if i % 10 == 0:
        print(loss)
    optim.step()
print(params)

ValueError: Cannot convert to Funsor: [{'sensor_id': tensor(2), 'x': tensor([9.1694, 3.3267])}, {'sensor_id': tensor(3), 'x': tensor([13.0524,  4.3502])}, {'sensor_id': tensor(2), 'x': tensor([15.0693,  3.7658])}, {'sensor_id': tensor(4), 'x': tensor([18.3525,  3.4833])}, {'sensor_id': tensor(1), 'x': tensor([21.6490,  3.9347])}, {'sensor_id': tensor(1), 'x': tensor([24.8167,  4.2799])}, {'sensor_id': tensor(1), 'x': tensor([27.7863,  4.3933])}, {'sensor_id': tensor(3), 'x': tensor([31.1898,  5.5257])}, {'sensor_id': tensor(0), 'x': tensor([34.0032,  4.6109])}, {'sensor_id': tensor(1), 'x': tensor([36.9519,  4.8488])}, {'sensor_id': tensor(4), 'x': tensor([39.6025,  4.7435])}, {'sensor_id': tensor(2), 'x': tensor([42.1560,  5.3400])}, {'sensor_id': tensor(1), 'x': tensor([45.9227,  5.2604])}, {'sensor_id': tensor(2), 'x': tensor([48.3392,  5.6908])}, {'sensor_id': tensor(1), 'x': tensor([52.1634,  5.7445])}, {'sensor_id': tensor(2), 'x': tensor([54.4107,  6.1939])}, {'sensor_id': tensor(0), 'x': tensor([58.2860,  6.3251])}, {'sensor_id': tensor(1), 'x': tensor([61.0124,  6.3153])}, {'sensor_id': tensor(0), 'x': tensor([64.1793,  6.6614])}, {'sensor_id': tensor(2), 'x': tensor([66.3854,  7.0601])}, {'sensor_id': tensor(1), 'x': tensor([70.1047,  6.9807])}, {'sensor_id': tensor(4), 'x': tensor([72.5953,  7.0011])}, {'sensor_id': tensor(2), 'x': tensor([75.1890,  7.5688])}, {'sensor_id': tensor(3), 'x': tensor([78.8799,  8.4907])}, {'sensor_id': tensor(4), 'x': tensor([81.3046,  7.5959])}, {'sensor_id': tensor(3), 'x': tensor([84.7973,  8.9431])}, {'sensor_id': tensor(4), 'x': tensor([87.0585,  8.0343])}, {'sensor_id': tensor(4), 'x': tensor([90.0389,  8.1663])}, {'sensor_id': tensor(2), 'x': tensor([92.8330,  8.5220])}, {'sensor_id': tensor(2), 'x': tensor([95.7760,  8.5050])}, {'sensor_id': tensor(1), 'x': tensor([99.3801,  8.6107])}, {'sensor_id': tensor(4), 'x': tensor([101.9841,   8.5197])}, {'sensor_id': tensor(2), 'x': tensor([104.6606,   9.0728])}, {'sensor_id': tensor(0), 'x': tensor([108.3842,   9.0323])}, {'sensor_id': tensor(3), 'x': tensor([111.4700,  10.0289])}, {'sensor_id': tensor(4), 'x': tensor([113.8042,   9.0392])}, {'sensor_id': tensor(3), 'x': tensor([117.3052,  10.4275])}, {'sensor_id': tensor(1), 'x': tensor([120.0671,   9.6234])}, {'sensor_id': tensor(1), 'x': tensor([123.1322,   9.8111])}, {'sensor_id': tensor(3), 'x': tensor([126.3046,  11.1277])}, {'sensor_id': tensor(1), 'x': tensor([129.0721,  10.3092])}, {'sensor_id': tensor(3), 'x': tensor([132.2044,  11.3628])}, {'sensor_id': tensor(3), 'x': tensor([135.1691,  11.2532])}, {'sensor_id': tensor(0), 'x': tensor([137.9679,  10.5809])}, {'sensor_id': tensor(4), 'x': tensor([140.3839,  10.3766])}, {'sensor_id': tensor(4), 'x': tensor([143.4000,  10.5894])}, {'sensor_id': tensor(4), 'x': tensor([146.2598,  10.9724])}, {'sensor_id': tensor(1), 'x': tensor([149.7123,  11.2845])}, {'sensor_id': tensor(2), 'x': tensor([151.8876,  11.8553])}, {'sensor_id': tensor(3), 'x': tensor([155.7931,  12.8001])}, {'sensor_id': tensor(2), 'x': tensor([157.9830,  12.2561])}, {'sensor_id': tensor(3), 'x': tensor([161.8929,  13.2875])}, {'sensor_id': tensor(1), 'x': tensor([164.7876,  12.5654])}, {'sensor_id': tensor(0), 'x': tensor([167.7211,  12.7872])}, {'sensor_id': tensor(1), 'x': tensor([170.6288,  12.8048])}, {'sensor_id': tensor(4), 'x': tensor([173.3824,  12.8233])}, {'sensor_id': tensor(2), 'x': tensor([176.0866,  13.2088])}, {'sensor_id': tensor(3), 'x': tensor([180.0560,  13.7414])}, {'sensor_id': tensor(1), 'x': tensor([183.0132,  12.9792])}, {'sensor_id': tensor(0), 'x': tensor([186.3090,  13.2550])}, {'sensor_id': tensor(0), 'x': tensor([189.1656,  13.4832])}, {'sensor_id': tensor(3), 'x': tensor([192.3359,  14.3532])}, {'sensor_id': tensor(3), 'x': tensor([195.4426,  14.5637])}, {'sensor_id': tensor(3), 'x': tensor([198.5751,  14.5875])}, {'sensor_id': tensor(2), 'x': tensor([200.6177,  13.9135])}, {'sensor_id': tensor(0), 'x': tensor([204.5692,  13.8697])}, {'sensor_id': tensor(2), 'x': tensor([206.7561,  14.3471])}, {'sensor_id': tensor(0), 'x': tensor([210.5986,  14.2313])}, {'sensor_id': tensor(0), 'x': tensor([213.4642,  14.2828])}, {'sensor_id': tensor(2), 'x': tensor([215.5287,  14.7842])}, {'sensor_id': tensor(3), 'x': tensor([219.3798,  15.7601])}, {'sensor_id': tensor(0), 'x': tensor([222.1944,  15.0777])}, {'sensor_id': tensor(4), 'x': tensor([224.6969,  15.0153])}, {'sensor_id': tensor(4), 'x': tensor([227.7648,  15.1654])}, {'sensor_id': tensor(0), 'x': tensor([231.2411,  15.7304])}, {'sensor_id': tensor(2), 'x': tensor([233.6424,  16.0436])}, {'sensor_id': tensor(2), 'x': tensor([236.7544,  16.1129])}, {'sensor_id': tensor(2), 'x': tensor([239.7679,  16.2487])}, {'sensor_id': tensor(0), 'x': tensor([243.7346,  16.2731])}, {'sensor_id': tensor(1), 'x': tensor([246.5113,  16.3077])}, {'sensor_id': tensor(1), 'x': tensor([249.4129,  16.2574])}, {'sensor_id': tensor(3), 'x': tensor([252.8325,  17.3497])}, {'sensor_id': tensor(2), 'x': tensor([254.7593,  16.7187])}, {'sensor_id': tensor(2), 'x': tensor([257.8399,  16.8686])}, {'sensor_id': tensor(3), 'x': tensor([261.5269,  17.7100])}, {'sensor_id': tensor(0), 'x': tensor([264.4040,  17.1056])}, {'sensor_id': tensor(2), 'x': tensor([266.4769,  17.4210])}, {'sensor_id': tensor(2), 'x': tensor([269.4237,  17.7553])}, {'sensor_id': tensor(4), 'x': tensor([272.7732,  17.5344])}, {'sensor_id': tensor(2), 'x': tensor([275.3050,  17.9998])}, {'sensor_id': tensor(3), 'x': tensor([279.2312,  18.8181])}, {'sensor_id': tensor(1), 'x': tensor([282.0009,  18.1257])}, {'sensor_id': tensor(2), 'x': tensor([284.3007,  18.5558])}, {'sensor_id': tensor(2), 'x': tensor([287.2471,  18.8128])}, {'sensor_id': tensor(0), 'x': tensor([291.1439,  18.9672])}, {'sensor_id': tensor(1), 'x': tensor([293.9634,  19.0506])}, {'sensor_id': tensor(1), 'x': tensor([296.8717,  19.2347])}, {'sensor_id': tensor(2), 'x': tensor([299.0641,  19.7452])}, {'sensor_id': tensor(3), 'x': tensor([303.0298,  20.6010])}, {'sensor_id': tensor(3), 'x': tensor([305.9593,  20.7961])}]

Visualize the joint posterior distribution.

### possible plots
1. plot the MSE of the MAP estimates with and without bias (or table)
2. train with and without marginalizing out bias, plot both loss curves
  - plot nll and MSE at each epoch
3. smoothing? would require adjoint algorithm `tests/test_adjoint.py`
4. 