Fix enumerate + iarange (#828)
* Add failing test for enumerate + iarange

* Decrease test precision

* add passing test: test_non_mean_field_bern_bern_elbo_gradient

* Add a failing test of enumerate + nested iarange

* Raise error when enumerating reshaped distributions with nonempty event shape

* add two expensive passing tests

* Sketch fix to Trace_ELBO

* WIP sketched Trace_ELBO implementation that stores cumulative weight in site["scale"]

* Fix some bugs in iter_discrete_traces and Trace.copy()

* Fix some shaping errors in iter_discrete_traces

* Fix bugs in sequential enumeration

* Switch to MultiViewTensor in iter_discrete_traces; speed up tests

* Further parameterize one test

* WIP sketch enum_stack

* Attempt to avoid double counting when enumerating iarange with mixed nesting

* Fix trivial tests

* Fix scaling of observe sites, add test for loss value

* Reorganize computation using local score x upstream grads

* Fix bugs in TensorTree; add elbo loss tests

* Get all sequential tests to pass

* Add elbo loss tests

* flake8

* Fix test failures in distributions and poutine.replay

* Add stronger tests for TraceEnum_ELBO

* Add xfailing tests for irange + enumeration

* Combine test cases to speed up tests

* Add minimal failing test for irange + enumerate

* Make cond_indep_stack a tuple (hence hashable)

* Fix bug in sequentially enumerating nested iarange

* Revert changes to replay

* Revert minor changes

* Fix TraceEnum_ELBO handling of observe sites
fritzo authored and martinjankowiak committed Mar 6, 2018
1 parent 2c62f77 commit e3f0adf
Showing 14 changed files with 869 additions and 281 deletions.
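
For context, the pattern being fixed looks roughly like the sketch below: a discrete sample site enumerated (sequentially or in parallel) inside a vectorized iarange. This is a hypothetical minimal reproduction, not a test from this PR; it assumes the Pyro API of this era (pyro.iarange, config_enumerate, TraceEnum_ELBO with max_iarange_nesting), and names like "p" and "q" are illustrative only.

import torch

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

data = torch.tensor([0., 0., 1., 1., 1.])

def model():
    # Mixture weight shared across data points (constraints omitted for brevity).
    p = pyro.param("p", torch.tensor(0.3))
    with pyro.iarange("data", len(data)):
        # One enumerated Bernoulli per data point inside the iarange.
        z = pyro.sample("z", dist.Bernoulli(p.expand(len(data))))
        pyro.sample("x", dist.Bernoulli(0.1 + 0.8 * z), obs=data)

@config_enumerate
def guide():
    q = pyro.param("q", 0.5 * torch.ones(len(data)))
    with pyro.iarange("data", len(data)):
        pyro.sample("z", dist.Bernoulli(q))

# The ELBO must sum out z once per data point, without double counting
# the discrete weights across enumerated traces.
svi = SVI(model, guide, Adam({"lr": 0.1}), loss=TraceEnum_ELBO(max_iarange_nesting=1))
svi.step()
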
5 changes: 5 additions & 0 deletions pyro/distributions/torch_distribution.py
@@ -211,9 +211,14 @@ def score_parts(self, value):
return ScoreParts(log_pdf, score_function, entropy_term)

def enumerate_support(self):
if self.extra_event_dims:
raise NotImplementedError("Pyro does not enumerate over cartesian products")

samples = self.base_dist.enumerate_support()
if not self.sample_shape:
return samples

# Shift enumeration dim to correct location.
enum_shape, base_shape = samples.shape[:1], samples.shape[1:]
samples = samples.contiguous()
samples = samples.view(enum_shape + (1,) * len(self.sample_shape) + base_shape)
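
The reshape above moves the enumeration dimension to the left of the sample shape, so that enumerated values broadcast against batched draws. A standalone sketch of the shape arithmetic, with a hypothetical sample_shape:

import torch

samples = torch.tensor([0., 1.])   # enumerate over {0, 1}: shape (2,)
sample_shape = (3, 4)              # hypothetical .reshape(sample_shape=...)
enum_shape, base_shape = samples.shape[:1], samples.shape[1:]
reshaped = samples.contiguous().view(enum_shape + (1,) * len(sample_shape) + base_shape)
print(reshaped.shape)  # torch.Size([2, 1, 1]): broadcasts against draws of shape (3, 4)
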
22 changes: 19 additions & 3 deletions pyro/distributions/util.py
@@ -147,11 +147,27 @@ def sum_leftmost(value, dim):
def scale_tensor(tensor, scale):
"""
Safely scale a tensor without increasing its ``.size()``.
This avoids NANs by assuming ``inf * 0 = 0 * inf = 0``.
"""
if is_identically_zero(tensor) or is_identically_one(scale):
return tensor
if isinstance(tensor, numbers.Number):
if isinstance(scale, numbers.Number):
return tensor * scale
elif tensor == 0:
return torch.zeros_like(scale)
elif tensor == 1:
return scale
else:
return scale
if isinstance(scale, numbers.Number):
if scale == 0:
return torch.zeros_like(tensor)
elif scale == 1:
return tensor
else:
return tensor * scale
result = tensor * scale
if not isinstance(result, numbers.Number) and result.shape != tensor.shape:
result[(scale == 0).expand_as(result)] = 0 # avoid NANs
if result.shape != tensor.shape:
raise ValueError("Broadcasting error: scale is incompatible with tensor: "
"{} vs {}".format(scale.shape, tensor.shape))
return result
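
The motivation for the special-casing is the IEEE rule 0 * inf = nan: when a subsample or enumeration weight is zero, a -inf log-probability at that position must drop out rather than poison the whole sum. A standalone illustration of the intended semantics (not the PR's code):

import torch

log_prob = torch.tensor([-float('inf'), -1.2])  # e.g. an impossible value
scale = torch.tensor([0., 1.])                  # weight 0 should mask it out

naive = log_prob * scale                        # tensor([nan, -1.2000])
safe = log_prob * scale
safe[scale == 0] = 0                            # the "inf * 0 = 0" convention
print(naive, safe)
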
90 changes: 62 additions & 28 deletions pyro/infer/enum.py
@@ -1,20 +1,18 @@
from __future__ import absolute_import, division, print_function

import math
import functools

import torch
from six.moves.queue import LifoQueue
from torch.autograd import Variable

from pyro import poutine
from pyro.distributions.util import sum_rightmost
from pyro.infer.util import TreeSum
from pyro.poutine.trace import Trace


def _iter_discrete_filter(name, msg):
def _iter_discrete_filter(msg):
return ((msg["type"] == "sample") and
(not msg["is_observed"]) and
(msg["infer"].get("enumerate"))) # either sequential or parallel
msg["infer"].get("enumerate")) # sequential or parallel


def _iter_discrete_escape(trace, msg):
Expand All @@ -24,38 +22,74 @@ def _iter_discrete_escape(trace, msg):
(msg["name"] not in trace))


def iter_discrete_traces(graph_type, max_iarange_nesting, fn, *args, **kwargs):
def _iter_discrete_extend(trace, site, enum_tree):
values = site["fn"].enumerate_support()
log_probs = site["fn"].log_prob(values).detach()
for i, (value, log_prob) in enumerate(zip(values, log_probs)):
extended_site = site.copy()
extended_site["value"] = value
extended_trace = trace.copy()
extended_trace.add_node(site["name"], **extended_site)
extended_enum_tree = enum_tree.copy()
extended_enum_tree.add(site["cond_indep_stack"], (i,))
yield extended_trace, extended_enum_tree


def _iter_discrete_queue(graph_type, fn, *args, **kwargs):
queue = LifoQueue()
partial_trace = Trace()
enum_tree = TreeSum()
queue.put((partial_trace, enum_tree))
while not queue.empty():
partial_trace, enum_tree = queue.get()
traced_fn = poutine.trace(poutine.escape(poutine.replay(fn, partial_trace),
functools.partial(_iter_discrete_escape, partial_trace)),
graph_type=graph_type)
try:
yield traced_fn.get_trace(*args, **kwargs), enum_tree
except poutine.util.NonlocalExit as e:
e.reset_stack()
for item in _iter_discrete_extend(traced_fn.trace, e.site, enum_tree):
queue.put(item)


def iter_discrete_traces(graph_type, fn, *args, **kwargs):
"""
Iterate over all discrete choices of a stochastic function.
When sampling continuous random variables, this behaves like `fn`.
When sampling discrete random variables, this iterates over all choices.
This yields `(scale, trace)` pairs, where `scale` is the probability of the
discrete choices made in the `trace`.
This yields traces scaled by the probability of the discrete choices made
in the `trace`.
:param str graph_type: The type of the graph, e.g. "flat" or "dense".
:param callable fn: A stochastic function.
:returns: An iterator over (scale, trace) pairs.
:returns: An iterator over (weights, trace) pairs, where weights is a
:class:`~pyro.infer.util.TreeSum` object.
"""
queue = LifoQueue()
queue.put(Trace())
q_fn = poutine.queue(fn, queue=queue, escape_fn=_iter_discrete_escape)
while not queue.empty():
full_trace = poutine.trace(q_fn, graph_type=graph_type).get_trace(*args, **kwargs)

# Scale trace by probability of discrete choices.
log_pdf = 0
full_trace.compute_batch_log_pdf(site_filter=_iter_discrete_filter)
for name, site in full_trace.nodes.items():
if _iter_discrete_filter(name, site):
log_pdf = log_pdf + sum_rightmost(site["batch_log_pdf"], max_iarange_nesting)
if isinstance(log_pdf, Variable):
scale = torch.exp(log_pdf.detach())
else:
scale = math.exp(log_pdf)

yield scale, full_trace
already_counted = set() # to avoid double counting
for trace, enum_tree in _iter_discrete_queue(graph_type, fn, *args, **kwargs):
# Collect log_probs for each iarange stack.
log_probs = TreeSum()
if not already_counted:
log_probs.add((), 0) # ensures globals are counted exactly once
for name, site in trace.nodes.items():
if _iter_discrete_filter(site):
cond_indep_stack = site["cond_indep_stack"]
log_prob = site["fn"].log_prob(site["value"]).detach()
log_probs.add(cond_indep_stack, log_prob)

# Avoid double-counting across traces.
weights = log_probs.exp()
for context in enum_tree.items():
if context in already_counted:
cond_indep_stack, _ = context
weights.prune(cond_indep_stack)
else:
already_counted.add(context)

yield weights, trace


def _config_enumerate(default):
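
Stripped of poutines, the control flow of _iter_discrete_queue is an ordinary LIFO-queue search: run the program under a partial trace, and when an unconstrained discrete site is hit, requeue one extended partial trace per value in its support. A self-contained toy version of that loop, using a KeyError in place of the escape poutine (all names here are hypothetical):

from queue import LifoQueue  # the source uses six.moves.queue for py2/py3

def program(choices):
    # Toy program with two binary choices; indexing a missing name raises
    # KeyError, standing in for poutine.escape interrupting at a new site.
    return choices["a"], choices["b"]

def iter_traces():
    queue = LifoQueue()
    queue.put({})
    while not queue.empty():
        partial = queue.get()
        try:
            yield program(partial)
        except KeyError as e:
            name = e.args[0]
            for value in (0, 1):  # enumerate_support() analogue
                queue.put(dict(partial, **{name: value}))

print(list(iter_traces()))  # all four (a, b) combinations
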
137 changes: 68 additions & 69 deletions pyro/infer/traceenum_elbo.py
@@ -1,21 +1,32 @@
from __future__ import absolute_import, division, print_function

import numbers
import warnings

import torch

import pyro
import pyro.poutine as poutine
from pyro.distributions.util import is_identically_zero, sum_rightmost
from pyro.distributions.util import is_identically_zero
from pyro.infer.elbo import ELBO
from pyro.infer.enum import iter_discrete_traces
from pyro.infer.util import torch_data_sum, torch_sum
from pyro.infer.util import TreeSum
from pyro.poutine.enumerate_poutine import EnumeratePoutine
from pyro.poutine.util import prune_subsample_sites
from pyro.util import check_model_guide_match, check_site_shape, is_nan


def _compute_upstream_grads(trace):
upstream_grads = TreeSum()

for site in trace.nodes.values():
if site["type"] != "sample":
continue
score_function_term = site["score_parts"].score_function
if is_identically_zero(score_function_term):
continue
upstream_grads.add(site["cond_indep_stack"], score_function_term)

return upstream_grads
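
TreeSum itself is not part of this diff; judging only from the calls made on it here (add, get_upstream, exp, prune, items, copy), it accumulates terms keyed by cond_indep_stack and answers queries by summing over a stack's ancestors. A toy dict-based sketch of that assumed contract (the real pyro.infer.util.TreeSum may well differ):

class ToyTreeSum(object):
    # Keys are cond_indep_stack tuples; get_upstream sums the entries
    # stored at every prefix of the queried stack.
    def __init__(self):
        self._terms = {}

    def add(self, stack, value):
        self._terms[stack] = self._terms.get(stack, 0) + value

    def get_upstream(self, stack):
        prefixes = [stack[:i] for i in range(len(stack) + 1)]
        terms = [self._terms[p] for p in prefixes if p in self._terms]
        return sum(terms) if terms else None

t = ToyTreeSum()
t.add((), 1.0)                     # a global term
t.add(("data",), 0.5)              # a term inside one iarange
print(t.get_upstream(("data",)))   # 1.5: global + iarange contributions
print(t.get_upstream(("other",)))  # 1.0: only the global term
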


class TraceEnum_ELBO(ELBO):
"""
A trace implementation of ELBO-based SVI that supports enumeration
Expand All @@ -38,28 +49,24 @@ def _get_traces(self, model, guide, *args, **kwargs):

for i in range(self.num_particles):
# iterate over a bag of traces, one trace per particle
for scale, guide_trace in iter_discrete_traces("flat", self.max_iarange_nesting, guide, *args, **kwargs):
for weights, guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
model_trace = poutine.trace(poutine.replay(model, guide_trace),
graph_type="flat").get_trace(*args, **kwargs)

check_model_guide_match(model_trace, guide_trace)
guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)

log_r = 0
model_trace.compute_batch_log_pdf()
for site in model_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
log_r = log_r + sum_rightmost(site["batch_log_pdf"], self.max_iarange_nesting)
guide_trace.compute_score_parts()
for site in guide_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
log_r = log_r - sum_rightmost(site["batch_log_pdf"], self.max_iarange_nesting)

weight = scale / self.num_particles
yield weight, model_trace, guide_trace, log_r
yield weights, model_trace, guide_trace

def loss(self, model, guide, *args, **kwargs):
"""
Expand All @@ -69,26 +76,24 @@ def loss(self, model, guide, *args, **kwargs):
Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
"""
elbo = 0.0
for weight, model_trace, guide_trace, log_r in self._get_traces(model, guide, *args, **kwargs):
elbo_particle = weight * 0
for weights, model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
elbo_particle = 0
for name, model_site in model_trace.nodes.items():
if model_site["type"] == "sample":
model_log_pdf = sum_rightmost(model_site["batch_log_pdf"], self.max_iarange_nesting)
if model_site["is_observed"]:
elbo_particle = elbo_particle + model_log_pdf
else:
guide_site = guide_trace.nodes[name]
guide_log_pdf = sum_rightmost(guide_site["batch_log_pdf"], self.max_iarange_nesting)
elbo_particle = elbo_particle + model_log_pdf - guide_log_pdf

# drop terms of weight zero to avoid nans
if isinstance(weight, numbers.Number):
if weight == 0.0:
elbo_particle = torch.zeros_like(elbo_particle)
else:
elbo_particle[weight == 0] = 0.0

elbo += torch_data_sum(weight * elbo_particle)
if model_site["type"] != "sample":
continue

# grab weights introduced by enumeration
cond_indep_stack = model_site["cond_indep_stack"]
weight = weights.get_upstream(cond_indep_stack)
if weight is None:
continue

log_r = model_site["batch_log_pdf"]
if not model_site["is_observed"]:
log_r = log_r - guide_trace.nodes[name]["batch_log_pdf"]
elbo_particle += (log_r * weight).sum().item()

elbo += elbo_particle / self.num_particles

loss = -elbo
if is_nan(loss):
Expand All @@ -105,53 +110,47 @@ def loss_and_grads(self, model, guide, *args, **kwargs):
"""
elbo = 0.0
# grab a trace from the generator
for weight, model_trace, guide_trace, log_r in self._get_traces(model, guide, *args, **kwargs):
elbo_particle = weight * 0
surrogate_elbo_particle = weight * 0
for weights, model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
upstream_grads = _compute_upstream_grads(guide_trace)
elbo_particle = 0
surrogate_elbo_particle = 0
# compute elbo and surrogate elbo
for name, model_site in model_trace.nodes.items():
if model_site["type"] == "sample":
model_log_pdf = sum_rightmost(model_site["batch_log_pdf"], self.max_iarange_nesting)
if model_site["is_observed"]:
elbo_particle = elbo_particle + model_log_pdf
surrogate_elbo_particle = surrogate_elbo_particle + model_log_pdf
else:
guide_site = guide_trace.nodes[name]
guide_log_pdf, score_function_term, entropy_term = guide_site["score_parts"]

guide_log_pdf = sum_rightmost(guide_log_pdf, self.max_iarange_nesting)
elbo_particle = elbo_particle + model_log_pdf - guide_log_pdf
surrogate_elbo_particle = surrogate_elbo_particle + model_log_pdf

if not is_identically_zero(entropy_term):
entropy_term = sum_rightmost(entropy_term, self.max_iarange_nesting)
surrogate_elbo_particle -= entropy_term

if not is_identically_zero(score_function_term):
score_function_term = sum_rightmost(score_function_term, self.max_iarange_nesting)
surrogate_elbo_particle = surrogate_elbo_particle + log_r.detach() * score_function_term

# drop terms of weight zero to avoid nans
if isinstance(weight, numbers.Number):
if weight == 0.0:
elbo_particle = torch.zeros_like(elbo_particle)
surrogate_elbo_particle = torch.zeros_like(surrogate_elbo_particle)
else:
weight_eq_zero = (weight == 0)
elbo_particle[weight_eq_zero] = 0.0
surrogate_elbo_particle[weight_eq_zero] = 0.0

elbo += torch_data_sum(weight * elbo_particle)
surrogate_elbo_particle = torch_sum(weight * surrogate_elbo_particle)

# collect parameters to train from model and guide
if model_site["type"] != "sample":
continue

# grab weights introduced by enumeration
cond_indep_stack = model_site["cond_indep_stack"]
weight = weights.get_upstream(cond_indep_stack)
if weight is None:
continue

model_log_pdf = model_site["batch_log_pdf"]
log_r = model_log_pdf
surrogate_elbo_site = model_log_pdf
score_function_term = upstream_grads.get_upstream(cond_indep_stack)

if not model_site["is_observed"]:
guide_log_pdf, _, entropy_term = guide_trace.nodes[name]["score_parts"]
log_r = log_r - guide_log_pdf
if not is_identically_zero(entropy_term):
surrogate_elbo_site = surrogate_elbo_site - entropy_term

if score_function_term is not None:
surrogate_elbo_site = surrogate_elbo_site + log_r.detach() * score_function_term

elbo_particle += (log_r * weight).sum().item()
surrogate_elbo_particle = surrogate_elbo_particle + (surrogate_elbo_site * weight).sum()

elbo += elbo_particle / self.num_particles

trainable_params = set(site["value"]
for trace in (model_trace, guide_trace)
for site in trace.nodes.values()
if site["type"] == "param")

if trainable_params and getattr(surrogate_elbo_particle, 'requires_grad', False):
surrogate_loss_particle = -surrogate_elbo_particle
surrogate_loss_particle = -surrogate_elbo_particle / self.num_particles
surrogate_loss_particle.backward()
pyro.get_param_store().mark_params_active(trainable_params)

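
The surrogate objective assembled above is the standard score-function (REINFORCE) construction: each site contributes log_r.detach() * score_function, weighted by the enumeration weights, so that backward() produces an unbiased (here, exact) gradient of the expected objective. A minimal standalone sketch of that identity under exhaustive enumeration (plain PyTorch, not Pyro; all names hypothetical):

import torch

# Maximize E_{z ~ Bernoulli(q)}[f(z)] in q, by enumerating the support.
q = torch.tensor(0.3, requires_grad=True)
f = lambda z: (z - 1.0) ** 2                # arbitrary cost; E[f] = 1 - q

values = torch.tensor([0., 1.])             # enumerate_support() analogue
log_q = torch.log(torch.stack([1 - q, q]))  # log prob of each value
weights = log_q.exp().detach()              # enumeration weights
surrogate = (weights * (f(values).detach() * log_q)).sum()
surrogate.backward()
print(q.grad)  # tensor(-1.), the exact d/dq of E[f] = 1 - q
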
2 changes: 1 addition & 1 deletion pyro/infer/tracegraph_elbo.py
@@ -101,7 +101,7 @@ def n_compatible_indices(dest_node, source_node):
downstream_guide_cost_nodes[site].update([child])

for k in topo_sort_guide_nodes:
downstream_costs[k] = downstream_costs[k].contract_to(guide_trace.nodes[k]['batch_log_pdf'])
downstream_costs[k] = downstream_costs[k].contract_as(guide_trace.nodes[k]['batch_log_pdf'])

return downstream_costs, downstream_guide_cost_nodes

