In [None]:
# Set the notebook container width to 60% via injected CSS.
# `IPython.display` is the public import location for these helpers; importing
# them from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:60% !important; }</style>"))

# Tracking an unknown number of objects

While SVI can be used to learn the components and assignments of a mixture model, `pyro.contrib.tracking` provides more efficient inference algorithms for estimating assignments. This notebook demonstrates how to use `MarginalAssignmentPersistent` together with EM.

In [None]:
%matplotlib inline
from __future__ import absolute_import, division, print_function
import math
import os
import torch
from torch.distributions import constraints
from torch import nn
from matplotlib import pyplot

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.contrib.tracking.assignment import MarginalAssignmentPersistent
from pyro.contrib.tracking.hashing import LSH, merge_points
from pyro.ops.newton import newton_step
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import ClippedAdam, ASGD, SGD
from pyro.util import warn_if_nan

# Turn on Pyro's runtime validation.
pyro.enable_validation(True)
# A `CI` environment variable switches the notebook to a short smoke-test run
# (the training cell then performs only 2 SVI steps).
smoke_test = "CI" in os.environ

def diag_tensor(tensor):
    """Print a one-line summary of ``tensor``: its shape, mean, std, min and max."""
    summary = (tensor.shape, tensor.mean(), tensor.std(), tensor.min(), tensor.max())
    print("shape:{}, mean:{} std:{}, min:{}, max:{}".format(*summary))

We'll define a global config object to make it easy to port code to `argparse`.

In [None]:
# A stand-in for an ArgumentParser.parse_args() result: a bare class used as a namespace.
args = type('Args', (object,), {})

# --- Problem size ---
args.num_frames = 40                  # number of time steps
args.max_detections_per_frame = 100   # detections kept per frame by sensor2obs
args.max_num_objects = 100            # candidate objects tracked by DetectorTracker
args.expected_num_objects = 2.        # generate_data rounds this to get the true object count

# --- Sensor model ---
args.PNR = 10                         # noise scale is 10 ** (-PNR / 10)
args.num_sensors = 100                # number of equal-width bins across [x_min, x_max]
args.x_min, args.x_max = -2.5, 2.5

# --- Inference ---
args.bp_iters = 25                    # passed to MarginalAssignmentPersistent
args.bp_momentum = 0.5                # passed to MarginalAssignmentPersistent
args.svi_iters = 201                  # SVI steps in the training cell
args.em_iters = 10                    # EM iterations inside the guide
args.merge_radius = -1                # merging only runs when this is >= 0
args.prune_threshold = -1             # pruning only runs when this is > 0

# Sanity checks on the configuration.
assert args.max_num_objects >= args.expected_num_objects
assert args.x_max > args.x_min
assert args.max_detections_per_frame >= args.max_num_objects

Let's consider a model with deterministic dynamics, say sinusoids with known period but unknown phase and amplitude.

In [None]:
def get_dynamics(num_frames):
    """Return the sinusoidal dynamics matrix for ``num_frames`` time steps.

    Frame ``t`` is assigned the phase ``2 * pi * t / num_frames``. The result has
    shape ``(num_frames, 2)``: column 0 is the cosine and column 1 the sine of
    that phase.
    """
    phase = torch.arange(num_frames, dtype=torch.float) * 2 * math.pi / num_frames
    return torch.stack((torch.cos(phase), torch.sin(phase)), dim=-1)

It's tricky to define a fully generative model, so instead we'll separate our data generation process `generate_data()` from a factor graph `model()` that will be used in inference.

In [None]:
def generate_data(args):
    """Simulate noisy sensor readings of objects moving along sinusoids.

    Args:
        args: config object; reads ``expected_num_objects``, ``num_frames``,
            ``num_sensors``, ``PNR``, ``x_min`` and ``x_max``.

    Returns:
        A tuple ``(states, positions, sensor_positions, sensor_outputs, confidence)``:

        * ``states``: ``(num_objects, 2)`` standard-normal draws.
        * ``positions``: ``(num_frames, num_objects)`` object positions,
          ``get_dynamics(num_frames) @ states.T``.
        * ``sensor_positions``: ``(num_sensors,)`` centers of the sensor bins.
        * ``sensor_outputs``: ``(num_frames, num_sensors)`` ``confidence`` plus
          additive Gaussian noise.
        * ``confidence``: ``(num_frames, num_sensors)`` 0/1 indicator of whether
          at least one object falls in each sensor bin.
    """
    # The object count is deterministic. A stochastic alternative would be:
    #   num_objects = int(dist.Poisson(args.expected_num_objects).sample())
    num_objects = int(round(args.expected_num_objects))
    states = dist.Normal(0., 1.).sample((num_objects, 2))
    positions = get_dynamics(args.num_frames).mm(states.t())

    # Per frame, count the objects inside each sensor bin. torch.histc ignores
    # positions that fall outside [x_min, x_max].
    confidence = torch.empty(args.num_frames, args.num_sensors)
    for t in range(args.num_frames):
        confidence[t] = torch.histc(positions[t], args.num_sensors, args.x_min, args.x_max)
    # Sensors saturate: one object and several objects in a bin look the same.
    confidence.clamp_(max=1)

    # Additive white Gaussian noise model.
    # NOTE(review): noise_power is used directly as the Normal scale (a standard
    # deviation), not as a variance - confirm this is intended.
    noise_power = 10 ** (-args.PNR / 10)
    noise_dist = dist.Normal(0, noise_power)
    sensor_outputs = noise_dist.sample(confidence.shape) + confidence

    # Bin centers. They are built from an integer range so that exactly
    # num_sensors entries are produced; a float-step arange over [x_min, x_max)
    # can gain or lose an element through floating-point rounding.
    bin_width = (args.x_max - args.x_min) / args.num_sensors
    bin_index = torch.arange(args.num_sensors, dtype=torch.float)
    sensor_positions = args.x_min + bin_width * (bin_index + 0.5)
    return states, positions, sensor_positions, sensor_outputs, confidence

## Detector
This detector has two trainable parameters, $w$ and $b$, and maps a sensor output $x$ to $\mathrm{confidence} = \mathrm{sigmoid}(w(x - 0.5) + b)$.

In [None]:
class Detector(nn.Module):
    """Maps raw sensor outputs to per-sensor detection confidences.

    For a sensor output ``x`` the confidence is ``sigmoid(w * (x - 0.5) + b)``,
    with one trainable weight ``w`` (initialised to 1) and one trainable bias
    ``b`` (initialised to 0).

    Args:
        max_detections_per_frame: stored on the instance and reported by
            ``__str__``; it is not used in ``forward``.
    """

    def __init__(self, max_detections_per_frame):
        super(Detector, self).__init__()
        self.linear = nn.Linear(1, 1)
        self.sigmoid = nn.Sigmoid()
        self.max_detections_per_frame = max_detections_per_frame
        # nn.init only overwrites the values; parameters of a freshly built
        # nn.Linear are already trainable, so there is nothing to re-enable.
        nn.init.constant_(self.linear.weight, 1.)
        nn.init.constant_(self.linear.bias, 0.)

    def __str__(self):
        return "Detector: w={}, b={}, max_detections_per_frame={}".format(
            self.linear.weight.item(),
            self.linear.bias.item(),
            self.max_detections_per_frame)

    def forward(self, sensor_positions, sensor_outputs):
        """Return confidences with the same shape as ``sensor_outputs``.

        ``sensor_positions`` is accepted for interface compatibility but is not used.
        """
        # The linear layer acts on a trailing feature dimension of size 1, so
        # add it before the layer and drop it again afterwards.
        centered = (sensor_outputs - 0.5).unsqueeze(-1)
        return self.sigmoid(self.linear(centered).squeeze(-1))
        
# Stand-alone detector instance.
# NOTE(review): DetectorTracker builds its own Detector and no later cell shown
# here reads this variable - confirm it is still needed.
detector = Detector(args.max_detections_per_frame)

## Model

In [None]:
def compute_exists_logits(states_loc, replicates):
    """Return one constant existence logit per candidate object.

    Every entry equals ``-log(replicates) + FUDGE`` with ``FUDGE = -2``. The
    result has length ``states_loc.shape[0]`` and shares the dtype and device
    of ``states_loc``.
    """
    # TODO add a term for prior over object location
    fudge = -2
    logit = -math.log(replicates) + fudge
    return states_loc.new_full((states_loc.shape[0],), logit)

def compute_assign_logits(positions, observations, replicates, args):
    """Return the logits for assigning each detection to each candidate object.

    The logit is the log-likelihood of the detection under an object, minus its
    log-likelihood under the spurious hypothesis (the last column returned by
    ``detection_log_likelihood``), minus ``log(replicates)``.
    """
    log_lik = detection_log_likelihood(positions, observations, args)
    real_ll = log_lik[..., :-1]
    spurious_ll = log_lik[..., -1:]
    neg_inf = -float('inf')
    logits = real_ll - spurious_ll - math.log(replicates)
    # Pin impossible assignments at -inf whatever the subtraction produced.
    logits[real_ll == neg_inf] = neg_inf
    return logits

def detection_log_likelihood(positions, observations, args):
    """Score every detection under each object and under the spurious hypothesis.

    Args:
        positions: predicted object positions; a detection axis is inserted at
            dim -2 so they broadcast against the detections.
        observations: detections whose last dimension holds
            ``(position, confidence, sensor_output)``; only entries 0 and 2 are read.
        args: config object; reads ``PNR``, ``x_min``, ``x_max``, ``num_sensors``,
            ``expected_num_objects`` and ``max_detections_per_frame``.

    Returns:
        The per-object log-likelihoods concatenated along the last dimension
        with a single trailing column for the spurious hypothesis.
    """
    noise_power = 10 ** (-args.PNR / 10)
    bin_width = (args.x_max - args.x_min) / args.num_sensors
    observed_positions = observations[..., 0].unsqueeze(-1)
    observed_outputs = observations[..., 2].unsqueeze(-1)

    # Real detection: position near the object (scale of one bin width) and
    # sensor output near 1.
    real_loc_dist = dist.Normal(positions.unsqueeze(-2), bin_width)
    real_output_dist = dist.Normal(1., noise_power)
    real_ll = (real_loc_dist.log_prob(observed_positions) +
               real_output_dist.log_prob(observed_outputs) +
               math.log(args.expected_num_objects))

    # Spurious detection: position uniform over the sensed interval and sensor
    # output near 0.
    spurious_loc_dist = dist.Uniform(args.x_min, args.x_max)
    spurious_output_dist = dist.Normal(0., noise_power)
    spurious_ll = (spurious_loc_dist.log_prob(observed_positions) +
                   spurious_output_dist.log_prob(observed_outputs) +
                   math.log(args.max_detections_per_frame - args.expected_num_objects))

    return torch.cat((real_ll, spurious_ll), dim=-1)

def obs2sensor(obs, args):
    """Scatter per-detection outputs back onto the sensor grid.

    Args:
        obs: tensor whose last dimension holds ``(position, confidence, output)``,
            indexed as ``obs[frame, detection, field]``.
        args: config object; reads ``num_frames``, ``num_sensors``, ``x_min``
            and ``x_max``.

    Returns:
        A ``(num_frames, num_sensors)`` tensor, zero everywhere except at the bin
        of each detection with non-negative confidence, where it holds that
        detection's output. When several detections of a frame share a bin, the
        one with the highest detection index wins.

    Detections whose position lies outside ``[x_min, x_max]`` are skipped, and a
    position exactly at ``x_max`` goes to the last bin. This matches how
    ``torch.histc`` bins positions, and avoids both an out-of-range index and a
    silent wrap-around through negative indexing.
    """
    sensor_outputs = torch.zeros(args.num_frames, args.num_sensors)
    positions = obs[..., 0]
    bin_idx = torch.floor((positions - args.x_min) /
                          (args.x_max - args.x_min) *
                          args.num_sensors).long()
    in_range = (positions >= args.x_min) & (positions <= args.x_max)
    # Only the x_max edge case is affected by the clamp, because out-of-range
    # positions are filtered out by `in_range`.
    bin_idx = bin_idx.clamp(0, args.num_sensors - 1)
    valid = in_range & (obs[..., 1] >= 0.0)
    # Explicit loops keep the write order deterministic when bins collide.
    for i in range(obs.shape[0]):
        for j in range(obs.shape[1]):
            if valid[i, j]:
                sensor_outputs[i, bin_idx[i, j]] = obs[i, j, 2]
    return sensor_outputs

def sensor2obs(sensor_positions, sensor_outputs, confidence, args):
    """Turn per-sensor readings into a fixed-size list of detections per frame.

    For each frame the sensors are ranked by descending confidence and the top
    ``min(len(sensor_positions), max_detections_per_frame)`` are kept.

    Args:
        sensor_positions: 1-D tensor of sensor positions.
        sensor_outputs: per-frame sensor readings, indexed ``[frame, sensor]``.
        confidence: per-frame detector confidences, indexed ``[frame, sensor]``.
        args: config object; reads ``num_frames`` and ``max_detections_per_frame``.

    Returns:
        A ``(sensor_outputs.shape[-2], max_detections_per_frame, 3)`` tensor whose
        last dimension holds ``(position, confidence, sensor_output)``. Slots
        beyond the number of kept sensors stay zero.
    """
    observations = torch.zeros(sensor_outputs.shape[-2], args.max_detections_per_frame, 3)
    num_kept = min(sensor_positions.shape[0], int(args.max_detections_per_frame))
    for i in range(args.num_frames):
        _, order = torch.sort(confidence[i], descending=True)
        top = order[:num_kept]
        # Slice assignment fills all kept detections of the frame at once.
        observations[i, :num_kept, 0] = sensor_positions[top]
        observations[i, :num_kept, 1] = confidence[i, top]
        observations[i, :num_kept, 2] = sensor_outputs[i, top]
    return observations

In [None]:
class DetectorTracker(nn.Module):
    """Pairs a trainable ``Detector`` with a Pyro ``model``/``guide`` for tracking.

    ``model`` scores the detections under a real-vs-spurious observation model.
    ``guide`` estimates the object states with EM, using
    ``MarginalAssignmentPersistent`` to obtain the existence and assignment
    distributions.
    """
    def __init__(self, args):
        super(DetectorTracker, self).__init__()
        self.detector = Detector(args.max_detections_per_frame)
        # Number of candidate objects; the guide overwrites it when pruning or
        # merging is enabled.
        self.num_objects = args.max_num_objects

    @poutine.broadcast
    def model(self, sensor_positions, sensor_outputs, args):
        """Model side of the tracker.

        Runs the detector, converts its output to detections with ``sensor2obs``,
        samples object existence and states, assigns detections to objects, and
        scores the detected positions as either real (Normal around the assigned
        object's position) or spurious (Uniform over ``[x_min, x_max]``).
        """
        bin_width = (args.x_max-args.x_min)/args.num_sensors
        pyro.module("detectorTracker", self)
        confidence = self.detector.forward(sensor_positions, sensor_outputs)
        observations = sensor2obs(sensor_positions, sensor_outputs, confidence, args)
        with pyro.iarange("objects", self.num_objects):
            # NOTE(review): "exists" is observed as the detector confidence, which
            # has the shape of sensor_outputs and is a sigmoid output, while the
            # iarange has size num_objects - confirm the shapes and value range
            # are acceptable for a Bernoulli site with validation enabled.
            exists = pyro.sample("exists",
                                 dist.Bernoulli(args.expected_num_objects / self.num_objects),
                                 obs = confidence)
            # The state prior only contributes for objects that exist.
            with poutine.scale(scale=exists):
                states = pyro.sample("states", dist.Normal(0., 1.).expand([2]).independent(1))
                positions = get_dynamics(args.num_frames).mm(states.t())

        with pyro.iarange("time", args.num_frames):
            with pyro.iarange("detections", args.max_detections_per_frame):
                # The combinatorial part of the log prob is approximated to allow independence.
                assign = pyro.sample("assign",
                                     dist.Categorical(torch.ones(self.num_objects)))
                is_real= exists[assign]
                observed_positions = observations[..., 0]

                # Real detections are weighted by is_real, spurious ones by its complement.
                with poutine.scale(scale=is_real.float()):
                    #bogus_position = positions.new_zeros(args.num_frames, 1)
                    #augmented_positions = torch.cat([positions, bogus_position], -1)
                    predicted_positions = positions[:, assign]
                    pyro.sample("real_observations",
                                dist.Normal(predicted_positions, bin_width),
                                obs=observed_positions)
                with poutine.scale(scale=(1-is_real.float())):
                    pyro.sample("spurious_observations", dist.Uniform(args.x_min,args.x_max),
                                obs=observed_positions)
        # NOTE(review): torch tensors expose `dim()`; `dims()` does not look like a
        # tensor method - verify this line runs.
        observation_hat= torch.stack([observed_positions,is_real,is_real], observed_positions.dims())
        # Reconstruct the sensor grid from the assignments and score the raw
        # sensor outputs against it.
        sensor_hat = obs2sensor(observation_hat,args)
        noise_power = 10 ** (-args.PNR/10)
        pyro.sample("sensor_output",
            dist.Normal(sensor_hat.view(-1),noise_power).independent(1),
            obs=sensor_outputs.view(-1))

    @poutine.broadcast    
    def guide(self, sensor_positions, sensor_outputs, args):
        """Guide side of the tracker: EM over object states, then sample sites.

        Starts from random states, runs ``args.em_iters`` EM iterations (E-step:
        soft assignments from ``MarginalAssignmentPersistent``; M-step: a Newton
        step on the states), and then registers the "exists", "states" and
        "assign" sample sites.

        Returns:
            A tuple ``(assignment, states_loc, observations)``.
        """
        pyro.module('detectorTracker', self)
        bin_width = (args.x_max-args.x_min)/args.num_sensors
        confidence = self.detector.forward(sensor_positions, sensor_outputs)
        observations = sensor2obs(sensor_positions, sensor_outputs, confidence, args)
        if observations.dim() == 3:
            states_loc = torch.randn(self.num_objects, 2, requires_grad=True)
            for em_iter in range(args.em_iters):
                # Restart the autograd graph each iteration so the Newton step
                # differentiates with respect to the current states only.
                states_loc = states_loc.detach()
                states_loc.requires_grad = True
                positions = get_dynamics(args.num_frames).mm(states_loc.t())
                assert states_loc.requires_grad
                assert positions.requires_grad
                assert positions.grad_fn is not None
                replicates = max(1, states_loc.shape[0]/args.expected_num_objects)
                # E-step: compute soft assignments
                with torch.no_grad():
                    assign_logits = compute_assign_logits(positions, observations, replicates, args)
                    exists_logits = compute_exists_logits(states_loc, replicates)
                    assignment = MarginalAssignmentPersistent(exists_logits, assign_logits,
                                                      args.bp_iters, bp_momentum=args.bp_momentum)
                    p_exists = assignment.exists_dist.probs
                    p_assign = assignment.assign_dist.probs

                # Expected negative log-likelihood under the soft assignments.
                log_likelihood = detection_log_likelihood(positions, observations, args)
                loss = -(log_likelihood * p_assign).sum()
                # M-step
                # NOTE(review): the third argument is presumably the trust radius
                # of the Newton step - confirm against pyro.ops.newton.
                states_loc, _ = newton_step(loss, states_loc, bin_width)  

                # Optional pruning and merging; both are off when the configured
                # values are negative.
                if args.prune_threshold > 0.0:
                    states_loc = states_loc[p_exists > args.prune_threshold]
                    self.num_objects = states_loc.shape[0]
                if args.merge_radius >= 0.0:
                    states_loc, _ = merge_points(states_loc, args.merge_radius)
                    self.num_objects = states_loc.shape[0]
                warn_if_nan(states_loc, 'states_loc')
        else:
            # NOTE(review): `states_loc` is assigned only in the branch above, so
            # the code after this `else` would hit an unbound name on this path.
            print("Warning: no object detected, observations.shape:{}".format(observations.shape))

        # Recompute the assignment from the final states, this time outside no_grad.
        positions = get_dynamics(args.num_frames).mm(states_loc.t())
        replicates = max(1, states_loc.shape[0]/args.expected_num_objects)
        assign_logits = compute_assign_logits(positions, observations, replicates, args)
        exists_logits = compute_exists_logits(states_loc, replicates)
        assignment = MarginalAssignmentPersistent(exists_logits, assign_logits,
                                          args.bp_iters, bp_momentum=args.bp_momentum)

        with pyro.iarange("objects", states_loc.shape[0]):
            exists = pyro.sample("exists", assignment.exists_dist, infer={"enumerate": "parallel"})
            with poutine.scale(scale=exists):
                pyro.sample("states", dist.Delta(states_loc).independent(1))
        # NOTE(review): here "detections" is the outer iarange and "time" the inner
        # one, the reverse of the nesting used in model() - confirm this is intended.
        with pyro.iarange("detections", observations.shape[1]):
            with pyro.iarange("time", args.num_frames):
                pyro.sample("assign", assignment.assign_dist, infer={"enumerate": "parallel"})
        return assignment, states_loc, observations

## Plotting utils

In [None]:
def plot_solution(sensor_outputs, assignment, states_loc, observations, args, msg=''):
    """Plot the sensor readings, true tracks, detections and predicted tracks.

    Args:
        sensor_outputs: per-frame sensor readings, drawn as the background image.
        assignment: object whose ``exists_dist.probs`` holds one existence
            probability per predicted track (as returned by the guide).
        states_loc: estimated object states, one row per predicted track.
        observations: detections whose last dimension holds
            ``(position, confidence, sensor_output)``.
        args: config object; reads ``num_frames``, ``x_min``, ``x_max``,
            ``expected_num_objects`` and, if present, ``confidence_threshold``.
        msg: text appended to the figure title.

    The true tracks are read from the notebook-level ``true_positions``.
    """
    # Only detections above this confidence are drawn. The notebook config does
    # not define the threshold, so fall back to 0.5, the sigmoid midpoint.
    confidence_threshold = getattr(args, 'confidence_threshold', 0.5)

    positions = get_dynamics(args.num_frames).mm(states_loc.t())
    fig, ax = pyplot.subplots(figsize=(12, 9))
    fig.patch.set_color('white')

    # Background: sensor readings, with time on the x-axis and position on the y-axis.
    extent = [0 - .5, args.num_frames - .5, args.x_min, args.x_max]
    cax = ax.matshow(sensor_outputs.t(), aspect='auto', extent=extent, origin='lower', alpha=0.5)
    pyplot.colorbar(cax)
    pyplot.plot(true_positions.numpy(), 'k--')

    # Detections: marker size grows with the detector confidence.
    pos = observations[..., 0]
    conf = observations[..., 1]
    is_observed = (conf > confidence_threshold)
    time = torch.arange(args.num_frames).unsqueeze(-1).expand_as(pos)
    pyplot.scatter(time[is_observed].view(-1).numpy(),
                   pos[is_observed].detach().view(-1).numpy(), color='k', marker='+',
                   s=8 * 100 ** conf[is_observed].detach().view(-1).numpy(),
                   label='observation')

    # Predicted tracks: opacity encodes the probability that the track exists.
    exists_probs = assignment.exists_dist.probs
    for i in range(exists_probs.shape[0]):
        p_exists = exists_probs[i].item()
        position = positions[:, i].detach().numpy()
        pyplot.plot(position, alpha=p_exists, color='C0')
    if args.expected_num_objects == 1:
        # With a single expected object, also show the existence-weighted mean track.
        mean = (exists_probs * positions).sum(-1) / exists_probs.sum(-1)
        pyplot.plot(mean.detach().numpy(), 'r--', alpha=0.5, label='mean')

    pyplot.title('Truth, observations, and {:0.1f} predicted tracks {}'.format(
                 exists_probs.sum().item(), msg))
    # Empty plots exist only to create legend entries.
    pyplot.plot([], 'k--', label='truth')
    pyplot.plot([], color='C0', label='prediction')
    pyplot.legend(loc='best')
    pyplot.xlabel('time step')
    pyplot.ylabel('position')
    pyplot.tight_layout()
    
def plot_exists_histogram(p_exists, args=None):
    """Plot the existence probabilities of the candidate objects in ascending order.

    Args:
        p_exists: tensor of existence probabilities, one per candidate object.
        args: unused. It is optional so that callers passing only ``p_exists``
            work, while callers that pass the config object keep working too.
    """
    probs = p_exists.detach().numpy()
    fig = pyplot.figure(figsize=(6, 4))
    fig.patch.set_color('white')
    pyplot.plot(sorted(probs))
    pyplot.ylim(0, None)
    pyplot.xlim(0, len(probs))
    pyplot.ylabel('p_exists')
    pyplot.xlabel('rank')
    pyplot.title('Prob(exists) of {} potential objects, total = {:0.2f}'.format(
        len(probs), probs.sum()))
    pyplot.tight_layout()

## Generate data

In [None]:
# Fix the seed so the synthetic data set is the same on every run.
pyro.set_rng_seed(0)
(true_states, true_positions, sensor_positions,
 sensor_outputs, true_confidence) = generate_data(args)
true_num_objects = len(true_states)

# Shape checks: one 2-D state per object and one position per (frame, object).
assert true_states.shape == (true_num_objects, 2)
assert true_positions.shape == (args.num_frames, true_num_objects)

## Train
10 iterations of EM with and without merging

In [None]:
dt = DetectorTracker(args)
pyro.set_rng_seed(1)  # Use a different seed from data generation
pyro.clear_param_store()
# Run the guide once with the untrained detector to get a baseline solution.
assignment, states_loc, observations = dt.guide(sensor_positions, sensor_outputs, args)
plot_solution(sensor_outputs, assignment, states_loc, observations, args, 'Before training')
# plot_exists_histogram takes (p_exists, args); pass the config explicitly.
plot_exists_histogram(assignment.exists_dist.probs, args)

In [None]:
# (Leftover interactive `%debug` post-mortem cell removed: it would open a
# debugger prompt, or do nothing useful, during Restart & Run All.)

In [None]:
%%time
pyro.set_rng_seed(1)  # Use a different seed from data generation
pyro.clear_param_store()
infer = SVI(dt.model, dt.guide, ClippedAdam({"lr": 0.1}), TraceEnum_ELBO(max_iarange_nesting=2))
losses = []
for epoch in range(args.svi_iters if not smoke_test else 2):
    # SVI.step forwards its arguments to both model and guide, whose signature
    # is (sensor_positions, sensor_outputs, args).
    loss = infer.step(sensor_positions, sensor_outputs, args)
    if epoch % 10 == 0:
        print("epoch {: >4d} loss = {}".format(epoch, loss))
    losses.append(loss)
pyplot.plot(losses)
pyplot.xlabel('SVI step')
pyplot.ylabel('loss');

In [None]:
pyro.set_rng_seed(1)  # Use a different seed from data generation
# The guide is a method of the DetectorTracker instance; there is no module-level `guide`.
assignment, states_loc, observations = dt.guide(sensor_positions, sensor_outputs, args)
plot_solution(sensor_outputs, assignment, states_loc, observations, args, 'After training')
# plot_exists_histogram takes (p_exists, args); pass the config explicitly.
plot_exists_histogram(assignment.exists_dist.probs, args)