Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use provenance tracking to compute downstream costs in TraceGraph_ELBO #3081

Merged
merged 9 commits on May 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/air/main.py
Expand Up @@ -50,7 +50,7 @@ def count_vec_to_mat(vec, max_index):
true_counts_m = count_vec_to_mat(true_counts_batch, 2)
inferred_counts_m = count_vec_to_mat(inferred_counts, 3)
counts += torch.mm(true_counts_m.t(), inferred_counts_m)
error_ind = 1 - (true_counts_batch == inferred_counts)
error_ind = 1 - (true_counts_batch == inferred_counts).long()
error_ix = error_ind.nonzero(as_tuple=False).squeeze()
error_latents.append(
latents_to_tensor((z_where, z_pres)).index_select(0, error_ix)
Expand Down
115 changes: 80 additions & 35 deletions pyro/infer/tracegraph_elbo.py
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import weakref
from collections import defaultdict
from operator import itemgetter

import torch
Expand All @@ -18,6 +19,9 @@
torch_backward,
torch_item,
)
from pyro.ops.provenance import detach_provenance, get_provenance, track_provenance
from pyro.poutine.messenger import Messenger
from pyro.poutine.subsample_messenger import _Subsample
from pyro.util import check_if_enumerated, warn_if_nan


Expand Down Expand Up @@ -172,7 +176,7 @@ def _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes): #
return downstream_costs, downstream_guide_cost_nodes


def _compute_elbo_reparam(model_trace, guide_trace):
def _compute_elbo(model_trace, guide_trace):

# In ref [1], section 3.2, the part of the surrogate loss computed here is
# \sum{cost}, which in this case is the ELBO. Instead of using the ELBO,
Expand All @@ -182,12 +186,18 @@ def _compute_elbo_reparam(model_trace, guide_trace):

elbo = 0.0
surrogate_elbo = 0.0
baseline_loss = 0.0
# mapping from non-reparameterizable sample sites to cost terms influenced by each of them
downstream_costs = defaultdict(lambda: MultiFrameTensor())

# Bring log p(x, z|...) terms into both the ELBO and the surrogate
for name, site in model_trace.nodes.items():
if site["type"] == "sample":
elbo += site["log_prob_sum"]
surrogate_elbo += site["log_prob_sum"]
# add the log_prob to each non-reparam sample site upstream
for key in get_provenance(site["log_prob_sum"]):
downstream_costs[key].add((site["cond_indep_stack"], site["log_prob"]))

# Bring log q(z|...) terms into the ELBO, and effective terms into the
# surrogate. Depending on the parameterization of a site, its log q(z|...)
Expand All @@ -202,19 +212,16 @@ def _compute_elbo_reparam(model_trace, guide_trace):
# For fully non-reparameterized terms, it is zero
if not is_identically_zero(entropy_term):
surrogate_elbo -= entropy_term.sum()
# add the -log_prob to each non-reparam sample site upstream
for key in get_provenance(site["log_prob_sum"]):
downstream_costs[key].add((site["cond_indep_stack"], -site["log_prob"]))

return elbo, surrogate_elbo


def _compute_elbo_non_reparam(guide_trace, non_reparam_nodes, downstream_costs):
# construct all the reinforce-like terms.
# we include only downstream costs to reduce variance
# optionally include baselines to further reduce variance
surrogate_elbo = 0.0
baseline_loss = 0.0
for node in non_reparam_nodes:
for node, downstream_cost in downstream_costs.items():
guide_site = guide_trace.nodes[node]
downstream_cost = downstream_costs[node]
downstream_cost = downstream_cost.sum_to(guide_site["cond_indep_stack"])
score_function = guide_site["score_parts"].score_function

use_baseline, baseline_loss_term, baseline = _construct_baseline(
Expand All @@ -227,7 +234,59 @@ def _compute_elbo_non_reparam(guide_trace, non_reparam_nodes, downstream_costs):

surrogate_elbo += (score_function * downstream_cost.detach()).sum()

return surrogate_elbo, baseline_loss
surrogate_loss = -surrogate_elbo + baseline_loss
return detach_provenance(elbo), detach_provenance(surrogate_loss)


class TrackNonReparam(Messenger):
    """
    Effect handler that tags values drawn at non-reparameterizable sample
    sites with their site name, so provenance propagates through downstream
    computations.

    **References:**

    1. *Nonstandard Interpretations of Probabilistic Programs for Efficient Inference*,
    David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind

    **Example:**

    .. doctest::

       >>> import torch
       >>> import pyro
       >>> import pyro.distributions as dist
       >>> from pyro.infer.tracegraph_elbo import TrackNonReparam
       >>> from pyro.ops.provenance import get_provenance
       >>> from pyro.poutine import trace

       >>> def model():
       ...     probs_a = torch.tensor([0.3, 0.7])
       ...     probs_b = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
       ...     probs_c = torch.tensor([[0.5, 0.5], [0.6, 0.4]])
       ...     a = pyro.sample("a", dist.Categorical(probs_a))
       ...     b = pyro.sample("b", dist.Categorical(probs_b[a]))
       ...     pyro.sample("c", dist.Categorical(probs_c[b]), obs=torch.tensor(0))

       >>> with TrackNonReparam():
       ...     model_tr = trace(model).get_trace()
       >>> model_tr.compute_log_prob()

       >>> print(get_provenance(model_tr.nodes["a"]["log_prob"]))  # doctest: +SKIP
       frozenset({'a'})
       >>> print(get_provenance(model_tr.nodes["b"]["log_prob"]))  # doctest: +SKIP
       frozenset({'b', 'a'})
       >>> print(get_provenance(model_tr.nodes["c"]["log_prob"]))  # doctest: +SKIP
       frozenset({'b', 'a'})
    """

    def _pyro_post_sample(self, msg):
        # Only genuine latent sample statements are candidates for tracking.
        if msg["type"] != "sample" or msg["is_observed"]:
            return
        fn = msg["fn"]
        # Skip internal subsampling sites and any distribution that supports
        # reparameterized sampling (those gradients flow without REINFORCE).
        if isinstance(fn, _Subsample) or getattr(fn, "has_rsample", False):
            return
        # Wrap the sampled value so every tensor computed from it carries
        # this site's name in its provenance set.
        msg["value"] = track_provenance(msg["value"], frozenset({msg["name"]}))


class TraceGraph_ELBO(ELBO):
Expand All @@ -236,13 +295,10 @@ class TraceGraph_ELBO(ELBO):
is constructed along the lines of reference [1] specialized to the case
of the ELBO. It supports arbitrary dependency structure for the model
and guide as well as baselines for non-reparameterizable random variables.
Where possible, conditional dependency information as recorded in the
Fine-grained conditional dependency information as recorded in the
:class:`~pyro.poutine.trace.Trace` is used to reduce the variance of the gradient estimator.
In particular two kinds of conditional dependency information are
used to reduce variance:

- the sequential order of samples (z is sampled after y => y does not depend on z)
- :class:`~pyro.plate` generators
In particular provenance tracking [3] is used to find the ``cost`` terms
that depend on each non-reparameterizable sample site.

References

Expand All @@ -251,16 +307,20 @@ class TraceGraph_ELBO(ELBO):

[2] `Neural Variational Inference and Learning in Belief Networks`
Andriy Mnih, Karol Gregor

[3] `Nonstandard Interpretations of Probabilistic Programs for Efficient Inference`,
David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind
"""

def _get_trace(self, model, guide, args, kwargs):
"""
Returns a single trace from the guide, and the model that is run
against it.
"""
model_trace, guide_trace = get_importance_trace(
"dense", self.max_plate_nesting, model, guide, args, kwargs
)
with TrackNonReparam():
model_trace, guide_trace = get_importance_trace(
"dense", self.max_plate_nesting, model, guide, args, kwargs
)
if is_validation_enabled():
check_if_enumerated(guide_trace)
return model_trace, guide_trace
Expand Down Expand Up @@ -319,22 +379,7 @@ def _loss_and_surrogate_loss(self, model, guide, args, kwargs):

def _loss_and_surrogate_loss_particle(self, model_trace, guide_trace):

# compute elbo for reparameterized nodes
elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace)
baseline_loss = 0.0

# the following computations are only necessary if we have non-reparameterizable nodes
non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
if non_reparam_nodes:
downstream_costs, _ = _compute_downstream_costs(
model_trace, guide_trace, non_reparam_nodes
)
surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
guide_trace, non_reparam_nodes, downstream_costs
)
surrogate_elbo += surrogate_elbo_term

surrogate_loss = -surrogate_elbo + baseline_loss
elbo, surrogate_loss = _compute_elbo(model_trace, guide_trace)

return elbo, surrogate_loss

Expand Down
6 changes: 6 additions & 0 deletions pyro/ops/provenance.py
Expand Up @@ -104,6 +104,12 @@ def _track_provenance_list(x, provenance: frozenset):
return type(x)(track_provenance(part, provenance) for part in x)


@track_provenance.register
def _track_provenance_provenancetensor(x: ProvenanceTensor, provenance: frozenset):
    # Re-tagging an already-tracked tensor: unwrap it first, then attach the
    # union of the prior and newly supplied provenance sets to the bare value.
    bare_value, prior_provenance = extract_provenance(x)
    return track_provenance(bare_value, prior_provenance | provenance)


@singledispatch
def extract_provenance(x) -> Tuple[object, frozenset]:
"""
Expand Down