# Tracking Sensor Bias

We want to compute the joint posterior over sensors' biases in a 2-D tracking setting.

In [None]:
from collections import OrderedDict

import torch
from torch.optim import Adam

import pyro
import pyro.distributions as dist

import funsor
import funsor.pyro
import funsor.distributions as f_dist
import funsor.ops as ops
from funsor.pyro.convert import dist_to_funsor, mvn_to_funsor, matrix_and_mvn_to_funsor, tensor_to_funsor
from funsor.interpreter import interpretation, reinterpret
from funsor.optimizer import apply_optimizer
from funsor.terms import lazy
from funsor.domains import bint, reals
from funsor.torch import Tensor, Variable
from funsor.sum_product import sequential_sum_product

import matplotlib.pyplot as plt

Simulate some synthetic data:

In [None]:
num_sensors = 5
num_frames = 100

# Draw one fixed 2-D bias per sensor.
sensors = [0.5 * torch.randn(2) for _ in range(num_sensors)]

# Simulate a single track: a latent 2-D state that drifts with a fixed
# velocity plus a small random perturbation at every frame.
track = []
z = 10 * torch.rand(2)  # initial state
v = 2 * torch.randn(2)  # velocity
for _ in range(num_frames):
    # Advance the latent state in place. The state is not clamped, so the
    # track is free to leave the [0, 10) box it starts in.
    z += v + 0.1 * torch.randn(2)

    # Pick a sensor uniformly at random and record the state shifted by
    # that sensor's bias.
    sensor_id = pyro.sample('id', dist.Categorical(torch.ones(num_sensors)))
    track.append({"sensor_id": sensor_id, "x": z - sensors[sensor_id]})

Now let's set up a tracking problem in Funsor. We start by modeling the biases of each sensor.

In [None]:
%pdb off

In [None]:
# Build a joint Gaussian over the sensor biases as a sum of independent
# per-sensor log-density factors, one 2-D Gaussian per sensor.
# TODO transform this to cholesky decomposition

# One learnable 2x2 matrix per sensor. It is passed positionally to
# MultivariateNormal, i.e. as the covariance matrix.
covs = [torch.eye(2, requires_grad=True) for _ in range(num_sensors)]
bias_dist = 0.
for i, cov in enumerate(covs):
    # Accumulate into ``bias_dist``. (The original accumulated into ``bias``,
    # a leftover tensor from the simulation cell, leaving ``bias_dist`` at 0.)
    #
    # The factor's real input is named "bias" (see ``real_inputs``), so that
    # is the name substituted here to give each sensor its own variable
    # ``bias_<i>``. The original passed ``value=...``, which names no input
    # of this factor.
    bias_dist += funsor.pyro.convert.mvn_to_funsor(
        dist.MultivariateNormal(torch.zeros(2), cov),
        real_inputs=OrderedDict([("bias", reals(2))])
    )(bias="bias_{}".format(i))
bias_dist.__dict__

In [None]:
# original
# A single joint Gaussian over all sensor biases, exposed as one real input
# "bias" of shape (num_sensors, 2).
bias_scale = torch.ones(2, requires_grad=True)  # This can be learned
bias_dist = funsor.pyro.convert.mvn_to_funsor(
    dist.MultivariateNormal(
        torch.zeros(num_sensors * 2),
        # Tile the two per-axis scales across all sensors and place them on
        # the diagonal of a (2 * num_sensors) square matrix. The matrix is
        # passed positionally, so MultivariateNormal uses it as the
        # covariance matrix.
        bias_scale.expand(num_sensors, 2).reshape(-1).diag_embed()
    ),
    # Passed by keyword, as in the per-sensor cell: the second positional
    # parameter of mvn_to_funsor is not ``real_inputs``.
    real_inputs=OrderedDict(bias=reals(num_sensors, 2))
)
bias_dist.__dict__

Set up the filter in Funsor.

In [None]:
%pdb on

In [None]:
from pdb import set_trace as bb

In [None]:
# Learnable 2x2 state-transition matrix, randomly initialised.
# TODO: this could instead be parameterized by a lower-dimensional vector
# in order to learn a structured transition matrix, e.g. a GP with a
# Matern nu=3/2 kernel (see paper for details).
transition_matrix = torch.randn(2, 2, requires_grad=True)

def model(track):
    """Build the log-density of one observed track as a funsor expression.

    The expression is the sum of a transition factor, an observation factor
    evaluated at the bias-corrected observations, and the module-level
    ``bias_dist``. The "time" input is eliminated with
    ``sequential_sum_product``, the initial-state factor is added, and the
    "state" / "state(time=1)" inputs are then reduced with ``ops.logaddexp``.

    Reads the module-level names ``transition_matrix``, ``bias_dist``,
    ``num_frames``, ``num_sensors`` and ``sensors``.

    :param track: sequence of frames; each frame is a mapping with the keys
        ``"sensor_id"`` and ``"x"``. Presumably ``len(track) == num_frames``,
        since the "time" input is declared with size ``num_frames`` --
        TODO confirm.
    :return: the funsor left after the reductions above. Presumably it still
        depends on the real input "bias" introduced below -- TODO confirm.
    """
    # Standard-normal prior on the initial state.
    init_dist = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))

    # Zero-mean, identity-covariance Gaussian paired with ``transition_matrix``.
    transition_dist = torch.distributions.MultivariateNormal(
        torch.zeros(2),
        torch.eye(2))
    # NOTE(review): the observation matrix is re-drawn at random on every
    # call, so two calls to ``model`` use different matrices -- confirm this
    # is intended.
    observation_matrix = torch.eye(2) + 0.2 * torch.randn(2, 2)
    observation_dist = torch.distributions.MultivariateNormal(
        torch.zeros(2),
        torch.eye(2))

    # Rename the prior's "value" input to "state".
    init = dist_to_funsor(init_dist)(value="state")
    # inputs are the previous state ``state`` and the next state
    trans = matrix_and_mvn_to_funsor(transition_matrix, transition_dist,
                                     ("time",), "state", "state(time=1)")
    # Observation factor from "state(time=1)" to "value".
    obs = matrix_and_mvn_to_funsor(observation_matrix, observation_dist,
                                   ("time",), "state(time=1)", "value")

    # Now this is the crux, we add bias to the observation
    # Sensor id of each frame, indexed by "time"; its dtype is the number of
    # sensors.
    sensor_ids = Tensor(
        torch.tensor([frame["sensor_id"] for frame in track]),
        OrderedDict([("time", bint(num_frames))]),
        dtype=len(sensors)
    )
    # Recorded observations stacked along a leading axis and indexed by "time".
    biased_observations = Tensor(
        torch.stack([frame["x"] for frame in track]),
        OrderedDict([("time", bint(num_frames))])
    )

    # incorporate sensor id in the observation
    # Earlier attempts, kept for reference:
#     bias_over_time = bias(value=sensor_ids)
    # bias_over_time = bias(bias=biased_observations)
    # inputs: bias shape (num_sensors, 2), sensor_ids
    # outputs: 2
    # Index the free (num_sensors, 2) variable "bias" by the per-frame sensor id.
    bias = Variable("bias", reals(num_sensors, 2))[sensor_ids]

    # Subtract the selected bias and substitute the result for the
    # observation factor's "value" input.
    debiased_observations = biased_observations - bias
    obs = obs(value=debiased_observations)
    print(obs)  # debug output, printed on every call

    # Similar to funsor.pyro.hmm.GaussianHMM.log_prob()
    # ndims = max(len(batch_shape), value.dim() - event_dim)
    # value = tensor_to_funsor(value, ("time",), event_output=event_dim - 1,
    #                          dtype=self.dtype)

    # obs = obs(value=value)
    logp = trans + obs + bias_dist

    # collapse out the time variable
    logp = sequential_sum_product(ops.logaddexp, ops.add,
                                  logp, "time", {"state": "state(time=1)"})
    logp += init
    # logaddexp across all states
    logp = logp.reduce(ops.logaddexp, frozenset(["state", "state(time=1)"]))
#     # ensure we collapsed out the right dim
#     assert logp.data.dim() == 0
    return logp

## Inference

Finally we have a result that is a joint Gaussian over the biases.
We can:
1. optimize all parameters to maximize this result, and
2. estimate the joint distribution over all bias parameters.

In [None]:
num_epochs = 200
# Parameters handed to Adam: every per-sensor matrix in ``covs`` followed by
# the transition matrix.
# NOTE(review): if ``bias_dist`` was built from ``bias_scale`` rather than
# ``covs``, ``bias_scale`` is not in this list -- confirm which cell is meant
# to feed the model.
params = [*covs, transition_matrix]
optim = Adam(params, lr=1e-3)
for epoch in range(num_epochs):
    optim.zero_grad()
    # Construct the model expression under the lazy interpretation, pass it
    # through apply_optimizer, then evaluate it with reinterpret.
    with interpretation(lazy):
        log_prob = apply_optimizer(model(track))
    # Minimise the negated value of the evaluated expression.
    loss = -reinterpret(log_prob).data
    loss.backward()
    # Report the loss on every tenth epoch, starting with the first.
    if not epoch % 10:
        print(loss)
    optim.step()
print(params)

In [None]:
params

Visualize the joint posterior distribution.