Skip to content

Commit

Permalink
Split TraceEnum_ELBO out of Trace_ELBO (#848)
Browse files Browse the repository at this point in the history
* Split TraceEnum_ELBO out of Trace_ELBO

* Add TraceEnum_ELBO to test_valid_models; fix a test

* Fix typos
  • Loading branch information
fritzo authored and martinjankowiak committed Mar 5, 2018
1 parent 1ce6234 commit 4ff3ba0
Show file tree
Hide file tree
Showing 9 changed files with 351 additions and 186 deletions.
3 changes: 2 additions & 1 deletion examples/ss_vae_M2.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,8 @@ def run_inference_ss_vae(args):
# set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum
# by enumerating each class label for the sampled discrete categorical distribution in the model
guide = config_enumerate(ss_vae.guide, args.enum_discrete)
loss_basic = SVI(ss_vae.model, guide, optimizer, loss="ELBO", max_iarange_nesting=1)
loss_basic = SVI(ss_vae.model, guide, optimizer, loss="ELBO",
enum_discrete=True, max_iarange_nesting=1)

# build a list of all losses considered
losses = [loss_basic]
Expand Down
16 changes: 13 additions & 3 deletions pyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ class ELBO(object):
:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
:param bool enum_discrete: Whether to sum over discrete latent variables,
rather than sample them.
:param int max_iarange_nesting: optional bound on max number of nested
:func:`pyro.iarange` contexts. This is only required to enumerate over
sample sites in parallel, e.g. if a site sets
Expand All @@ -34,7 +32,7 @@ def __init__(self,
self.max_iarange_nesting = max_iarange_nesting

@staticmethod
def make(trace_graph=False, **kwargs):
def make(trace_graph=False, enum_discrete=False, **kwargs):
"""
Factory to construct an ELBO implementation.
Expand All @@ -46,10 +44,22 @@ def make(trace_graph=False, **kwargs):
dependency information can be expensive. See the tutorial
`SVI Part III <http://pyro.ai/examples/svi_part_iii.html>`_ for a
discussion.
:param bool enum_discrete: Whether to support summing over discrete
latent variables, rather than sampling them. To sum out latent
variables, either wrap the guide in
:func:`~pyro.infer.enum.config_enumerate` or mark individual sample
sites with ``infer={"enumerate": "sequential"}`` or
``infer={"enumerate": "parallel"}``.
"""
if trace_graph and enum_discrete:
raise ValueError("Cannot combine trace_graph with enum_discrete")

if trace_graph:
from .tracegraph_elbo import TraceGraph_ELBO
return TraceGraph_ELBO(**kwargs)
elif enum_discrete:
from .traceenum_elbo import TraceEnum_ELBO
return TraceEnum_ELBO(**kwargs)
else:
from .trace_elbo import Trace_ELBO
return Trace_ELBO(**kwargs)
117 changes: 38 additions & 79 deletions pyro/infer/trace_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,10 @@
import numbers
import warnings

import torch

import pyro
import pyro.poutine as poutine
from pyro.distributions.util import is_identically_zero, sum_rightmost
from pyro.distributions.util import is_identically_zero
from pyro.infer.elbo import ELBO
from pyro.infer.enum import iter_discrete_traces
from pyro.infer.util import torch_backward, torch_data_sum, torch_sum
from pyro.poutine.enumerate_poutine import EnumeratePoutine
from pyro.poutine.util import prune_subsample_sites
from pyro.util import check_model_guide_match, check_site_shape, is_nan

Expand All @@ -26,33 +21,24 @@ def _get_traces(self, model, guide, *args, **kwargs):
runs the guide and runs the model against the guide with
the result packaged as a trace generator
"""
# enable parallel enumeration
guide = EnumeratePoutine(guide, first_available_dim=self.max_iarange_nesting)

for i in range(self.num_particles):
# iterate over a bag of traces, one trace per particle
for scale, guide_trace in iter_discrete_traces("flat", self.max_iarange_nesting, guide, *args, **kwargs):
model_trace = poutine.trace(poutine.replay(model, guide_trace),
graph_type="flat").get_trace(*args, **kwargs)

check_model_guide_match(model_trace, guide_trace)
guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)

log_r = 0
model_trace.compute_batch_log_pdf()
for site in model_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
log_r = log_r + sum_rightmost(site["batch_log_pdf"], self.max_iarange_nesting)
guide_trace.compute_score_parts()
for site in guide_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
log_r = log_r - sum_rightmost(site["batch_log_pdf"], self.max_iarange_nesting)

weight = scale / self.num_particles
yield weight, model_trace, guide_trace, log_r
guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
model_trace = poutine.trace(poutine.replay(model, guide_trace)).get_trace(*args, **kwargs)

check_model_guide_match(model_trace, guide_trace)
guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)

model_trace.compute_batch_log_pdf()
guide_trace.compute_score_parts()
for site in model_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
for site in guide_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)

yield model_trace, guide_trace

def loss(self, model, guide, *args, **kwargs):
"""
Expand All @@ -62,26 +48,9 @@ def loss(self, model, guide, *args, **kwargs):
Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
"""
elbo = 0.0
for weight, model_trace, guide_trace, log_r in self._get_traces(model, guide, *args, **kwargs):
elbo_particle = weight * 0
for name, model_site in model_trace.nodes.items():
if model_site["type"] == "sample":
model_log_pdf = sum_rightmost(model_site["batch_log_pdf"], self.max_iarange_nesting)
if model_site["is_observed"]:
elbo_particle = elbo_particle + model_log_pdf
else:
guide_site = guide_trace.nodes[name]
guide_log_pdf = sum_rightmost(guide_site["batch_log_pdf"], self.max_iarange_nesting)
elbo_particle = elbo_particle + model_log_pdf - guide_log_pdf

# drop terms of weight zero to avoid nans
if isinstance(weight, numbers.Number):
if weight == 0.0:
elbo_particle = torch.zeros_like(elbo_particle)
else:
elbo_particle[weight == 0] = 0.0

elbo += torch_data_sum(weight * elbo_particle)
for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
elbo_particle = (model_trace.log_pdf() - guide_trace.log_pdf()).item()
elbo += elbo_particle / self.num_particles

loss = -elbo
if is_nan(loss):
Expand All @@ -98,54 +67,44 @@ def loss_and_grads(self, model, guide, *args, **kwargs):
"""
elbo = 0.0
# grab a trace from the generator
for weight, model_trace, guide_trace, log_r in self._get_traces(model, guide, *args, **kwargs):
elbo_particle = weight * 0
surrogate_elbo_particle = weight * 0
for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
log_r = model_trace.log_pdf() - guide_trace.log_pdf()
if not isinstance(log_r, numbers.Number):
log_r = log_r.detach()

elbo_particle = 0
surrogate_elbo_particle = 0
# compute elbo and surrogate elbo
for name, model_site in model_trace.nodes.items():
if model_site["type"] == "sample":
model_log_pdf = sum_rightmost(model_site["batch_log_pdf"], self.max_iarange_nesting)
model_log_pdf = model_site["log_pdf"]
if model_site["is_observed"]:
elbo_particle = elbo_particle + model_log_pdf
elbo_particle = elbo_particle + model_log_pdf.item()
surrogate_elbo_particle = surrogate_elbo_particle + model_log_pdf
else:
guide_site = guide_trace.nodes[name]
guide_log_pdf, score_function_term, entropy_term = guide_site["score_parts"]

guide_log_pdf = sum_rightmost(guide_log_pdf, self.max_iarange_nesting)
elbo_particle = elbo_particle + model_log_pdf - guide_log_pdf
elbo_particle = elbo_particle + model_log_pdf - guide_log_pdf.sum()
surrogate_elbo_particle = surrogate_elbo_particle + model_log_pdf

if not is_identically_zero(entropy_term):
entropy_term = sum_rightmost(entropy_term, self.max_iarange_nesting)
surrogate_elbo_particle -= entropy_term
surrogate_elbo_particle -= entropy_term.sum()

if not is_identically_zero(score_function_term):
score_function_term = sum_rightmost(score_function_term, self.max_iarange_nesting)
surrogate_elbo_particle = surrogate_elbo_particle + log_r.detach() * score_function_term

# drop terms of weight zero to avoid nans
if isinstance(weight, numbers.Number):
if weight == 0.0:
elbo_particle = torch.zeros_like(elbo_particle)
surrogate_elbo_particle = torch.zeros_like(surrogate_elbo_particle)
else:
weight_eq_zero = (weight == 0)
elbo_particle[weight_eq_zero] = 0.0
surrogate_elbo_particle[weight_eq_zero] = 0.0

elbo += torch_data_sum(weight * elbo_particle)
surrogate_elbo_particle = torch_sum(weight * surrogate_elbo_particle)
surrogate_elbo_particle = surrogate_elbo_particle + log_r * score_function_term.sum()

elbo += elbo_particle / self.num_particles

# collect parameters to train from model and guide
trainable_params = set(site["value"]
for trace in (model_trace, guide_trace)
for site in trace.nodes.values()
if site["type"] == "param")

if trainable_params:
surrogate_loss_particle = -surrogate_elbo_particle
torch_backward(surrogate_loss_particle)
if trainable_params and getattr(surrogate_elbo_particle, 'requires_grad', False):
surrogate_loss_particle = -surrogate_elbo_particle / self.num_particles
surrogate_loss_particle.backward()
pyro.get_param_store().mark_params_active(trainable_params)

loss = -elbo
Expand Down
161 changes: 161 additions & 0 deletions pyro/infer/traceenum_elbo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from __future__ import absolute_import, division, print_function

import numbers
import warnings

import torch

import pyro
import pyro.poutine as poutine
from pyro.distributions.util import is_identically_zero, sum_rightmost
from pyro.infer.elbo import ELBO
from pyro.infer.enum import iter_discrete_traces
from pyro.infer.util import torch_data_sum, torch_sum
from pyro.poutine.enumerate_poutine import EnumeratePoutine
from pyro.poutine.util import prune_subsample_sites
from pyro.util import check_model_guide_match, check_site_shape, is_nan


class TraceEnum_ELBO(ELBO):
"""
A trace implementation of ELBO-based SVI that supports enumeration
over discrete sample sites.
This implementation makes strong restrictions on the dependency
structure of the ``model`` and ``guide``:
Across :func:`~pyro.irange` and :func:`~pyro.iarange` blocks,
both dependency graphs should follow a tree structure. That is,
no variable outside of a block can depend on a variable in the block.
"""

def _get_traces(self, model, guide, *args, **kwargs):
"""
runs the guide and runs the model against the guide with
the result packaged as a trace generator
"""
# enable parallel enumeration
guide = EnumeratePoutine(guide, first_available_dim=self.max_iarange_nesting)

for i in range(self.num_particles):
# iterate over a bag of traces, one trace per particle
for scale, guide_trace in iter_discrete_traces("flat", self.max_iarange_nesting, guide, *args, **kwargs):
model_trace = poutine.trace(poutine.replay(model, guide_trace),
graph_type="flat").get_trace(*args, **kwargs)

check_model_guide_match(model_trace, guide_trace)
guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)

log_r = 0
model_trace.compute_batch_log_pdf()
for site in model_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
log_r = log_r + sum_rightmost(site["batch_log_pdf"], self.max_iarange_nesting)
guide_trace.compute_score_parts()
for site in guide_trace.nodes.values():
if site["type"] == "sample":
check_site_shape(site, self.max_iarange_nesting)
log_r = log_r - sum_rightmost(site["batch_log_pdf"], self.max_iarange_nesting)

weight = scale / self.num_particles
yield weight, model_trace, guide_trace, log_r

def loss(self, model, guide, *args, **kwargs):
"""
:returns: returns an estimate of the ELBO
:rtype: float
Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
"""
elbo = 0.0
for weight, model_trace, guide_trace, log_r in self._get_traces(model, guide, *args, **kwargs):
elbo_particle = weight * 0
for name, model_site in model_trace.nodes.items():
if model_site["type"] == "sample":
model_log_pdf = sum_rightmost(model_site["batch_log_pdf"], self.max_iarange_nesting)
if model_site["is_observed"]:
elbo_particle = elbo_particle + model_log_pdf
else:
guide_site = guide_trace.nodes[name]
guide_log_pdf = sum_rightmost(guide_site["batch_log_pdf"], self.max_iarange_nesting)
elbo_particle = elbo_particle + model_log_pdf - guide_log_pdf

# drop terms of weight zero to avoid nans
if isinstance(weight, numbers.Number):
if weight == 0.0:
elbo_particle = torch.zeros_like(elbo_particle)
else:
elbo_particle[weight == 0] = 0.0

elbo += torch_data_sum(weight * elbo_particle)

loss = -elbo
if is_nan(loss):
warnings.warn('Encountered NAN loss')
return loss

def loss_and_grads(self, model, guide, *args, **kwargs):
"""
:returns: returns an estimate of the ELBO
:rtype: float
Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator.
Performs backward on the latter. Num_particle many samples are used to form the estimators.
"""
elbo = 0.0
# grab a trace from the generator
for weight, model_trace, guide_trace, log_r in self._get_traces(model, guide, *args, **kwargs):
elbo_particle = weight * 0
surrogate_elbo_particle = weight * 0
# compute elbo and surrogate elbo
for name, model_site in model_trace.nodes.items():
if model_site["type"] == "sample":
model_log_pdf = sum_rightmost(model_site["batch_log_pdf"], self.max_iarange_nesting)
if model_site["is_observed"]:
elbo_particle = elbo_particle + model_log_pdf
surrogate_elbo_particle = surrogate_elbo_particle + model_log_pdf
else:
guide_site = guide_trace.nodes[name]
guide_log_pdf, score_function_term, entropy_term = guide_site["score_parts"]

guide_log_pdf = sum_rightmost(guide_log_pdf, self.max_iarange_nesting)
elbo_particle = elbo_particle + model_log_pdf - guide_log_pdf
surrogate_elbo_particle = surrogate_elbo_particle + model_log_pdf

if not is_identically_zero(entropy_term):
entropy_term = sum_rightmost(entropy_term, self.max_iarange_nesting)
surrogate_elbo_particle -= entropy_term

if not is_identically_zero(score_function_term):
score_function_term = sum_rightmost(score_function_term, self.max_iarange_nesting)
surrogate_elbo_particle = surrogate_elbo_particle + log_r.detach() * score_function_term

# drop terms of weight zero to avoid nans
if isinstance(weight, numbers.Number):
if weight == 0.0:
elbo_particle = torch.zeros_like(elbo_particle)
surrogate_elbo_particle = torch.zeros_like(surrogate_elbo_particle)
else:
weight_eq_zero = (weight == 0)
elbo_particle[weight_eq_zero] = 0.0
surrogate_elbo_particle[weight_eq_zero] = 0.0

elbo += torch_data_sum(weight * elbo_particle)
surrogate_elbo_particle = torch_sum(weight * surrogate_elbo_particle)

# collect parameters to train from model and guide
trainable_params = set(site["value"]
for trace in (model_trace, guide_trace)
for site in trace.nodes.values()
if site["type"] == "param")

if trainable_params and getattr(surrogate_elbo_particle, 'requires_grad', False):
surrogate_loss_particle = -surrogate_elbo_particle
surrogate_loss_particle.backward()
pyro.get_param_store().mark_params_active(trainable_params)

loss = -elbo
if is_nan(loss):
warnings.warn('Encountered NAN loss')
return loss

0 comments on commit 4ff3ba0

Please sign in to comment.