# Tracking an Unknown Number of Objects

While SVI can be used to learn the components and assignments of a mixture model, `pyro.contrib.tracking` provides more efficient inference algorithms for estimating assignments. This notebook demonstrates how to use `MarginalAssignmentPersistent` inside SVI.

In [None]:
from __future__ import absolute_import, division, print_function
import math
import os
import torch
from torch.distributions import constraints
from matplotlib import pyplot

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.contrib.tracking.assignment import MarginalAssignmentPersistent
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import ClippedAdam

from datagen_utils import generate_observations, get_positions
from plot_utils import plot_solution, plot_exists_prob
%matplotlib inline
# Ask Pyro to validate distribution arguments and shapes while the notebook runs.
pyro.enable_validation(True)
# True when a CI environment variable is set; the training cell then runs
# only a couple of SVI steps instead of the full schedule.
smoke_test = os.environ.get('CI') is not None

## Model
It's tricky to define a fully generative model, so instead we'll separate our data generation process `generate_observations()` from a factor graph `model()` that will be used in inference.

In [None]:
@poutine.broadcast
def model(args, observations):
    """Factor-graph model over object existence, object states and assignments.

    This is a scoring model rather than a full generative process for
    ``observations``: the combinatorial detection/assignment factors are
    approximated so that detections can be treated independently, while the
    Gaussian likelihood of each observed position is exact.

    :param args: config object. This function reads ``max_num_objects``,
        ``expected_num_objects``, ``expected_num_spurious``, ``num_frames``
        and ``emission_noise_scale``.
    :param observations: tensor indexed as ``[frame, detection, feature]``.
        Feature ``0`` is the observed position; the last feature is positive
        for detection slots that actually hold an observation.
    """
    with pyro.iarange("objects", args.max_num_objects):
        # Each of the max_num_objects slots independently may hold an object.
        exists = pyro.sample("exists",
                             dist.Bernoulli(args.expected_num_objects / args.max_num_objects))
        # Slots that do not exist contribute nothing to the state prior.
        with poutine.scale(scale=exists):
            states_loc = pyro.sample("states", dist.Normal(0., 1.).expand([2]).independent(1))
            positions = get_positions(states_loc, args.num_frames)
    with pyro.iarange("detections", observations.shape[1]):
        with pyro.iarange("time", args.num_frames):
            # The combinatorial part of the log prob is approximated to allow independence.
            is_observed = (observations[..., -1] > 0)
            # Empty detection slots are masked out of the assignment prior.
            with poutine.scale(scale=is_observed.float()):
                # Index max_num_objects is the extra "spurious detection" category.
                assign = pyro.sample("assign",
                                     dist.Categorical(torch.ones(args.max_num_objects + 1)))
            is_spurious = (assign == args.max_num_objects)
            is_real = is_observed & ~is_spurious
            # TODO Make these Bernoulli probs more plausible, e.g. by dividing by the
            # per-frame count of observed detections, is_observed.float().sum(-1, True),
            # rather than by the total number of detection slots.
            pyro.sample("is_real",
                        dist.Bernoulli(args.expected_num_objects / observations.shape[1]),
                        obs=is_real.float())
            pyro.sample("is_spurious",
                        dist.Bernoulli(args.expected_num_spurious / observations.shape[1]),
                        obs=is_spurious.float())

            # The remaining continuous part is exact.
            observed_positions = observations[..., 0]
            with poutine.scale(scale=is_real.float()):
                # Append a dummy zero column so the spurious index is a valid
                # column to gather; its likelihood term is zeroed by the
                # is_real scale above.
                bogus_position = positions.new_zeros(args.num_frames, 1)
                augmented_positions = torch.cat([positions, bogus_position], -1)
                predicted_positions = augmented_positions[:, assign]
                pyro.sample("real_observations",
                            dist.Normal(predicted_positions, args.emission_noise_scale),
                            obs=observed_positions)
            # Spurious detections are scored under a fixed standard normal.
            with poutine.scale(scale=is_spurious.float()):
                pyro.sample("spurious_observations", dist.Normal(0., 1.),
                            obs=observed_positions)

This guide uses a smart assignment solver but a naive state estimator. A smarter implementation would also use message passing for state estimation, e.g. a Kalman filter-smoother.

In [None]:
@poutine.broadcast
def guide(args, observations):
    """Guide pairing point-estimated object states with a soft assignment solver.

    Object states are a learnable parameter ``states_loc`` exposed through a
    ``Delta`` distribution. Existence and assignment posteriors come from
    ``MarginalAssignmentPersistent`` and are enumerated in parallel.

    :param args: config object. This function reads ``max_num_objects``,
        ``expected_num_objects``, ``expected_num_spurious``, ``emission_prob``,
        ``emission_noise_scale``, ``num_frames`` and ``bp_iters``.
    :param observations: tensor indexed as ``[frame, detection, feature]``.
        Feature ``0`` is the observed position; the last feature is positive
        for detection slots that actually hold an observation.
    :returns: the tuple ``(assignment, states_loc)``, i.e. the assignment
        solver and the current state parameter.
    """
    # Initialize states randomly from the prior.
    states_loc = pyro.param("states_loc", lambda: torch.randn(args.max_num_objects, 2))
    positions = get_positions(states_loc, args.num_frames)

    # Mask of filled detection slots, and positions with a trailing axis that
    # broadcasts against the per-object predictions.
    detected = observations[..., -1] > 0
    measured = observations[..., 0].unsqueeze(-1)

    # Solve soft assignment problem: per (frame, detection, object) log-odds
    # of "emitted by this object" against "spurious".
    log_real = dist.Normal(positions.unsqueeze(-2), args.emission_noise_scale).log_prob(measured)
    log_spurious = dist.Normal(0., 1.).log_prob(measured)
    log_prior_odds = math.log(args.expected_num_objects * args.emission_prob /
                              args.expected_num_spurious)
    assign_logits = log_real - log_spurious + log_prior_odds
    # Empty detection slots can never be assigned to an object.
    assign_logits[~detected] = -float('inf')
    exists_logits = torch.full((args.max_num_objects,),
                               math.log(args.expected_num_objects / args.max_num_objects))
    assignment = MarginalAssignmentPersistent(exists_logits, assign_logits, args.bp_iters)

    with pyro.iarange("objects", args.max_num_objects):
        exists = pyro.sample("exists", assignment.exists_dist, infer={"enumerate": "parallel"})
        with poutine.scale(scale=exists):
            pyro.sample("states", dist.Delta(states_loc).independent(1))
    # Same nesting as the model: detections outermost, then the mask, then time.
    with pyro.iarange("detections", observations.shape[1]), \
            poutine.scale(scale=detected.float()), \
            pyro.iarange("time", args.num_frames):
        pyro.sample("assign", assignment.assign_dist, infer={"enumerate": "parallel"})

    return assignment, states_loc

## Generate data

We'll define a global config object to make it easy to port code to `argparse`.

In [None]:
# Stand-in for the result of ArgumentParser.parse_args(): a bare class whose
# attributes carry every setting used by the notebook.
args = type('Args', (object,), {
    'num_frames': 10,
    'max_num_objects': 400,
    'expected_num_objects': 2.,
    'expected_num_spurious': 0.1,  # If this is too small, BP will be unstable.
    'emission_prob': 0.9,          # If this is too large, BP will be unstable.
    'emission_noise_scale': 0.2,   # If this is too small, SVI will see flat gradients.
    'bp_iters': 20,
    'svi_iters': 101,
})

# The object slots must be able to hold at least the expected number of objects.
assert args.expected_num_objects <= args.max_num_objects

In [None]:
# Seed first so the synthetic data set is the same on every run.
pyro.set_rng_seed(0)
true_states, true_positions, observations = generate_observations(args)
true_num_objects = len(true_states)
max_num_detections = observations.size(1)

# Sanity-check the layout the model and guide rely on: one 2-vector state per
# object, one position per (frame, object), and two features per detection slot.
assert true_states.shape == (true_num_objects, 2)
assert true_positions.shape == (args.num_frames, true_num_objects)
assert observations.shape == (args.num_frames, max_num_detections, 2)

# A detection slot counts as filled when its last feature is positive.
print("generated {:d} detections from {:d} objects".format(
    observations[..., -1].gt(0).long().sum(),
    true_num_objects))

Before training, our solution should replicate importance sampling.

In [None]:
# Visualize the guide's solution before any SVI step has been taken.
pyro.set_rng_seed(1)  # Use a different seed from data generation
# Drop any stored parameters so the guide re-creates "states_loc" from torch.randn.
pyro.clear_param_store()
# The guide returns its assignment solver and the current state parameter.
assignment, states_loc = guide(args, observations)
# Per-object existence probabilities from the solver's existence distribution.
p_exists = assignment.exists_dist.probs
# Positions implied by the current states over args.num_frames frames.
positions = get_positions(states_loc, args.num_frames)
plot_solution(observations, p_exists, positions, true_positions, args, 'Before training')
plot_exists_prob(p_exists)

## Train

In [None]:
%%time
# Train with a seed distinct from the one used for data generation, starting
# from an empty parameter store.
pyro.set_rng_seed(1)
pyro.clear_param_store()
# max_iarange_nesting=2 matches the two nested iaranges ("detections", "time").
infer = SVI(model,
            guide,
            ClippedAdam({"lr": 0.1}),
            TraceEnum_ELBO(max_iarange_nesting=2))
losses = []
# Under the CI smoke test two steps are enough to show the loop runs.
for epoch in range(2 if smoke_test else args.svi_iters):
    loss = infer.step(args, observations)
    losses.append(loss)
    # Report progress on every 20th step.
    if not epoch % 20:
        print("epoch {: >4d} loss = {}".format(epoch, loss))
pyplot.plot(losses);

## Evaluate
Currently it looks like we're getting stuck in a local mode.

In [None]:
# Re-run the guide after training. The param store is not cleared here, so
# the guide picks up the "states_loc" value left by the SVI steps.
assignment, states_loc = guide(args, observations)
# Per-object existence probabilities from the solver's existence distribution.
p_exists = assignment.exists_dist.probs
# Positions implied by the trained states over args.num_frames frames.
positions = get_positions(states_loc, args.num_frames)
plot_solution(observations, p_exists, positions, true_positions, args, 'After training')
plot_exists_prob(p_exists)