Support nested iarange in tracegraph_elbo (#780)
martinjankowiak authored and fritzo committed Feb 27, 2018
1 parent cbba2e7 commit 29654dc
Showing 7 changed files with 635 additions and 224 deletions.
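For readers coming to this commit cold: "nested iarange" means one vectorized independence context declared inside another, which until now forced TraceGraph_ELBO to fall back to a higher-variance gradient estimator. Below is a minimal sketch of the pattern the commit enables, written against today's pyro API (iarange was later renamed plate); the model, guide, and parameter names are hypothetical, not taken from this commit:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, TraceGraph_ELBO
    from pyro.optim import Adam

    data = torch.randn(4, 5)

    def model():
        with pyro.plate("outer", 5):            # occupies the rightmost batch dim
            loc = pyro.sample("loc", dist.Normal(0., 1.))
            with pyro.plate("inner", 4):        # nested vectorized context, one dim to the left
                pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    def guide():
        loc_q = pyro.param("loc_q", torch.zeros(5))
        with pyro.plate("outer", 5):
            pyro.sample("loc", dist.Normal(loc_q, 1.))

    svi = SVI(model, guide, Adam({"lr": 1e-2}), loss=TraceGraph_ELBO())
    for _ in range(100):
        svi.step()

At the time of this commit the same structure would have been written with pyro.iarange; the point is that sites may now live inside different numbers of iaranges, so their log-probability terms carry different batch shapes.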
99 changes: 52 additions & 47 deletions pyro/infer/tracegraph_elbo.py
@@ -1,17 +1,21 @@
 from __future__ import absolute_import, division, print_function
 
 import warnings
+from operator import itemgetter
 
 import networkx
+import numpy as np
 import torch
+from torch.autograd import variable
 
 import pyro
 import pyro.poutine as poutine
 from pyro.distributions.util import is_identically_zero
-from pyro.infer.elbo import ELBO
+from pyro.infer import ELBO
+from pyro.infer.util import MultiViewTensor as MVT
 from pyro.infer.util import torch_backward, torch_data_sum
 from pyro.poutine.util import prune_subsample_sites
-from pyro.util import check_model_guide_match, detach_iterable, ng_zeros, is_nan
+from pyro.util import check_model_guide_match, detach_iterable
 
 
 def _get_baseline_options(site):
@@ -41,54 +45,65 @@ def _compute_downstream_costs(model_trace, guide_trace, #
     topo_sort_guide_nodes = list(reversed(list(networkx.topological_sort(guide_trace))))
     topo_sort_guide_nodes = [x for x in topo_sort_guide_nodes
                              if guide_trace.nodes[x]["type"] == "sample"]
+    ordered_guide_nodes_dict = {n: i for i, n in enumerate(topo_sort_guide_nodes)}
 
     downstream_guide_cost_nodes = {}
     downstream_costs = {}
+    stacks = model_trace.graph["vectorized_map_data_info"]['vec_md_stacks']
+
+    def n_compatible_indices(dest_node, source_node):
+        n_compatible = 0
+        for xframe, yframe in zip(stacks[source_node], stacks[dest_node]):
+            if xframe.name == yframe.name:
+                n_compatible += 1
+        return n_compatible
 
     for node in topo_sort_guide_nodes:
-        node_log_pdf_key = 'batch_log_pdf' if node in guide_vec_md_nodes else 'log_pdf'
-        downstream_costs[node] = model_trace.nodes[node][node_log_pdf_key] - \
-            guide_trace.nodes[node][node_log_pdf_key]
+        downstream_costs[node] = MVT(model_trace.nodes[node]['batch_log_pdf'] -
+                                     guide_trace.nodes[node]['batch_log_pdf'])
         nodes_included_in_sum = set([node])
         downstream_guide_cost_nodes[node] = set([node])
-        for child in guide_trace.successors(node):
+        # make more efficient by ordering children appropriately (higher children first)
+        children = [(k, -ordered_guide_nodes_dict[k]) for k in guide_trace.successors(node)]
+        sorted_children = sorted(children, key=itemgetter(1))
+        for child, _ in sorted_children:
             child_cost_nodes = downstream_guide_cost_nodes[child]
             downstream_guide_cost_nodes[node].update(child_cost_nodes)
             if nodes_included_in_sum.isdisjoint(child_cost_nodes):  # avoid duplicates
-                if node_log_pdf_key == 'log_pdf':
-                    downstream_costs[node] += downstream_costs[child].sum()
-                else:
-                    downstream_costs[node] += downstream_costs[child]
+                dims_to_keep = n_compatible_indices(node, child)
+                summed_child = downstream_costs[child].sum_leftmost_all_but(dims_to_keep)
+                downstream_costs[node].add(summed_child)
+                # XXX nodes_included_in_sum logic could be more fine-grained, possibly leading
+                # to speed-ups in case there are many duplicates
                 nodes_included_in_sum.update(child_cost_nodes)
         missing_downstream_costs = downstream_guide_cost_nodes[node] - nodes_included_in_sum
         # include terms we missed because we had to avoid duplicates
        for missing_node in missing_downstream_costs:
-            mn_log_pdf_key = 'batch_log_pdf' if missing_node in guide_vec_md_nodes else 'log_pdf'
-            if node_log_pdf_key == 'log_pdf':
-                downstream_costs[node] += (model_trace.nodes[missing_node][mn_log_pdf_key] -
-                                           guide_trace.nodes[missing_node][mn_log_pdf_key]).sum()
-            else:
-                downstream_costs[node] += model_trace.nodes[missing_node][mn_log_pdf_key] - \
-                    guide_trace.nodes[missing_node][mn_log_pdf_key]
+            missing_term = MVT(model_trace.nodes[missing_node]['batch_log_pdf'] -
+                               guide_trace.nodes[missing_node]['batch_log_pdf'])
+            dims_to_keep = n_compatible_indices(node, missing_node)
+            summed_missing_term = missing_term.sum_leftmost_all_but(dims_to_keep)
+            downstream_costs[node].add(summed_missing_term)
 
     # finish assembling complete downstream costs
     # (the above computation may be missing terms from model)
     # XXX can we cache some of the sums over children_in_model to make things more efficient?
     for site in non_reparam_nodes:
         children_in_model = set()
         for node in downstream_guide_cost_nodes[site]:
             children_in_model.update(model_trace.successors(node))
         # remove terms accounted for above
         children_in_model.difference_update(downstream_guide_cost_nodes[site])
         for child in children_in_model:
-            child_log_pdf_key = 'batch_log_pdf' if child in model_vec_md_nodes else 'log_pdf'
-            site_log_pdf_key = 'batch_log_pdf' if site in guide_vec_md_nodes else 'log_pdf'
             assert (model_trace.nodes[child]["type"] == "sample")
-            if site_log_pdf_key == 'log_pdf':
-                downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key].sum()
-            else:
-                downstream_costs[site] += model_trace.nodes[child][child_log_pdf_key]
+            dims_to_keep = n_compatible_indices(site, child)
+            summed_child = MVT(model_trace.nodes[child]['batch_log_pdf']).sum_leftmost_all_but(dims_to_keep)
+            downstream_costs[site].add(summed_child)
+            downstream_guide_cost_nodes[site].update([child])
 
+    for k in topo_sort_guide_nodes:
+        downstream_costs[k] = downstream_costs[k].contract_to(guide_trace.nodes[k]['batch_log_pdf'])
+
-    return downstream_costs
+    return downstream_costs, downstream_guide_cost_nodes
 
 
 def _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes):
@@ -110,7 +125,7 @@ def _compute_elbo_reparam(model_trace, guide_trace, non_reparam_nodes):
         if not is_identically_zero(entropy_term):
             surrogate_elbo -= entropy_term.sum()
 
-    # elbo is never differentiated, surragate_elbo is
+    # elbo is never differentiated, surrogate_elbo is
 
     return torch_data_sum(elbo), surrogate_elbo
 
@@ -135,11 +150,13 @@ def _compute_elbo_non_reparam(guide_trace, guide_vec_md_nodes, #
         assert(not (use_nn_baseline and use_baseline_value)), \
             "cannot use baseline_value and nn_baseline simultaneously"
         if use_decaying_avg_baseline:
+            dc_shape = downstream_cost.shape
             avg_downstream_cost_old = pyro.param("__baseline_avg_downstream_cost_" + node,
-                                                 ng_zeros(1), tags="__tracegraph_elbo_internal_tag")
-            avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \
+                                                 variable(0.0).expand(dc_shape).clone(),
+                                                 tags="__tracegraph_elbo_internal_tag")
+            avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost.detach() + \
                 baseline_beta * avg_downstream_cost_old
-            avg_downstream_cost_old.data = avg_downstream_cost_new.data  # XXX copy_() ?
+            avg_downstream_cost_old.copy_(avg_downstream_cost_new)  # XXX is this copy_() what we want?
             baseline += avg_downstream_cost_old
         if use_nn_baseline:
             # block nn_baseline_input gradients except in baseline loss
@@ -155,8 +172,7 @@ def _compute_elbo_non_reparam(guide_trace, guide_vec_md_nodes, #
         if log_pdf_key == 'log_pdf':
             score_function_term = score_function_term.sum()
         if use_nn_baseline or use_decaying_avg_baseline or use_baseline_value:
-            if (downstream_cost.dim() == 0 and baseline.dim() > 1) or \
-                    (downstream_cost.dim() > 0 and downstream_cost.size() != baseline.size()):
+            if downstream_cost.size() != baseline.size():
                 raise ValueError("Expected baseline at site {} to be {} instead got {}".format(
                     node, downstream_cost.size(), baseline.size()))
             downstream_cost = downstream_cost - baseline
@@ -226,7 +242,7 @@ def loss(self, model, guide, *args, **kwargs):
             elbo += torch_data_sum(weight * elbo_particle)
 
         loss = -elbo
-        if is_nan(loss):
+        if np.isnan(loss):
             warnings.warn('Encountered NAN loss')
         return loss
 
@@ -246,19 +262,8 @@ def loss_and_grads(self, model, guide, *args, **kwargs):

     def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
         # get info regarding rao-blackwellization of vectorized map_data
-        guide_vec_md_info = guide_trace.graph["vectorized_map_data_info"]
-        model_vec_md_info = model_trace.graph["vectorized_map_data_info"]
-        guide_vec_md_condition = guide_vec_md_info['rao-blackwellization-condition']
-        model_vec_md_condition = model_vec_md_info['rao-blackwellization-condition']
-        do_vec_rb = guide_vec_md_condition and model_vec_md_condition
-        if not do_vec_rb:
-            warnings.warn(
-                "Unable to do fully-vectorized Rao-Blackwellization in TraceGraph_ELBO. "
-                "Falling back to higher-variance gradient estimator. "
-                "Try to avoid these issues in your model and guide:\n{}".format("\n".join(
-                    guide_vec_md_info["warnings"] | model_vec_md_info["warnings"])))
-        guide_vec_md_nodes = guide_vec_md_info['nodes'] if do_vec_rb else set()
-        model_vec_md_nodes = model_vec_md_info['nodes'] if do_vec_rb else set()
+        guide_vec_md_nodes = guide_trace.graph["vectorized_map_data_info"]['nodes']
+        model_vec_md_nodes = model_trace.graph["vectorized_map_data_info"]['nodes']
 
         # have the trace compute all the individual (batch) log pdf terms
         # and score function terms (if present) so that they are available below
@@ -272,7 +277,7 @@ def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
         # the following computations are only necessary if we have non-reparameterizable nodes
         baseline_loss = 0.0
         if non_reparam_nodes:
-            downstream_costs = _compute_downstream_costs(
+            downstream_costs, _ = _compute_downstream_costs(
                 model_trace, guide_trace, model_vec_md_nodes, guide_vec_md_nodes, non_reparam_nodes)
             surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
                 guide_trace, guide_vec_md_nodes, non_reparam_nodes, downstream_costs)
@@ -290,6 +295,6 @@ def _loss_and_grads_particle(self, weight, model_trace, guide_trace):
             pyro.get_param_store().mark_params_active(trainable_params)
 
         loss = -elbo
-        if is_nan(loss):
+        if np.isnan(loss):
             warnings.warn('Encountered NAN loss')
         return weight * loss
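
The new n_compatible_indices helper above drives the nested-iarange bookkeeping: when one site's downstream cost is folded into another's, only the iarange dimensions the two sites share survive, and sum_leftmost_all_but sums out the rest. A self-contained toy version of the helper (the Frame namedtuple and site names are hypothetical stand-ins for pyro's internal CondIndepStackFrame objects; outermost-first frame ordering is an assumption):

    from collections import namedtuple

    # stand-in for pyro's CondIndepStackFrame; only the name field matters here
    Frame = namedtuple("Frame", ["name"])

    # per-site stacks of vectorized frames, as stored under 'vec_md_stacks'
    stacks = {
        "z_outer": [Frame("outer")],
        "z_inner": [Frame("outer"), Frame("inner")],
    }

    def n_compatible_indices(dest_node, source_node):
        # count positionally matching frames in the two sites' stacks
        n_compatible = 0
        for xframe, yframe in zip(stacks[source_node], stacks[dest_node]):
            if xframe.name == yframe.name:
                n_compatible += 1
        return n_compatible

    print(n_compatible_indices("z_outer", "z_inner"))  # 1: keep one shared dim, sum out the rest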
76 changes: 76 additions & 0 deletions pyro/infer/util.py
@@ -3,6 +3,9 @@
 import numbers
 
 import torch
+from torch.autograd import Variable
+
+from pyro.distributions.util import sum_leftmost
 
 
 def torch_data_sum(x):
@@ -32,3 +35,76 @@ def torch_backward(x):
"""
if isinstance(x, torch.autograd.Variable):
x.backward()


def reduce_to_target(source, target):
"""
Sums out any dimensions in source that are of size > 1 in source but of size 1 in target.
This preserves source.dim().
"""
if source.dim() > target.dim():
raise ValueError
for k in range(1, 1 + source.dim()):
if source.size(-k) > target.size(-k):
source = source.sum(-k, keepdim=True)
return source


class MultiViewTensor(dict):
"""
A container for Variables with different shapes. Used in TraceGraph_ELBO
to simplify downstream cost computation logic.
Example::
downstream_cost = MultiViewTensor()
downstream_cost.add(self.cost)
for node in downstream_nodes:
summed = node.downstream_cost.sum_leftmost(dims)
downstream_cost.add(summed)
"""
def __init__(self, value=None):
if value is not None:
if isinstance(value, Variable):
self[value.shape] = value

def add(self, term):
"""
Add tensor to collection of tensors stored in MultiViewTensor; key by shape
"""
if isinstance(term, Variable):
if term.shape in self:
self[term.shape] = self[term.shape] + term
else:
self[term.shape] = term
else:
for shape, value in term.items():
if shape in self:
self[shape] = self[shape] + value
else:
self[shape] = value

def sum_leftmost_all_but(self, dim):
"""
This behaves like sum_leftmost(term, -dim) except for dim=0 where everything is summed out
"""
assert dim >= 0
result = MultiViewTensor()
for shape, term in self.items():
if dim == 0:
result.add(term.sum())
elif dim > term.dim():
result.add(term)
else:
result.add(sum_leftmost(term, -dim))
return result

def contract_to(self, target):
"""Opposite of broadcast."""
result = 0
for tensor in self.values():
result = result + reduce_to_target(tensor, target)
return result

def __repr__(self):
return '%s(%s)' % (type(self).__name__, ", ".join([str(k) for k in self.keys()]))
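
A quick, hypothetical illustration of the container's behavior, assuming the Variable-era PyTorch API used throughout this commit (on later PyTorch versions plain tensors play the same role):

    import torch
    from torch.autograd import Variable

    from pyro.infer.util import MultiViewTensor

    mvt = MultiViewTensor(Variable(torch.ones(2, 3)))
    mvt.add(Variable(torch.ones(3)))      # second view, keyed by shape (3,)
    mvt.add(Variable(torch.ones(2, 3)))   # accumulates into the existing (2, 3) entry

    summed = mvt.sum_leftmost_all_but(1)  # sum out everything but the rightmost dim
    flat = summed.contract_to(Variable(torch.zeros(3)))
    print(flat.size())                    # torch.Size([3])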
28 changes: 20 additions & 8 deletions pyro/poutine/trace_poutine.py
@@ -15,13 +15,23 @@ def get_vectorized_map_data_info(trace):

     vectorized_map_data_info = {'rao-blackwellization-condition': True, 'warnings': set()}
     vec_md_stacks = set()
+    stack_dict = {}
 
     for name, node in nodes.items():
         if site_is_subsample(node):
             continue
         if node["type"] in ("sample", "param"):
             stack = tuple(node["cond_indep_stack"])
             vec_mds = [x for x in stack if x.vectorized]
+            stack_dict[name] = vec_mds
+
+    for name, node in nodes.items():
+        if site_is_subsample(node):
+            continue
+        if node["type"] in ("sample", "param"):
+            stack = tuple(node["cond_indep_stack"])
+            vec_mds = [x for x in stack if x.vectorized]
+            stack_dict[name] = vec_mds
             # check for nested vectorized map datas
             if len(vec_mds) > 1:
                 vectorized_map_data_info['rao-blackwellization-condition'] = False
@@ -58,15 +68,17 @@ def get_vectorized_map_data_info(trace):
                 vectorized_map_data_info['warnings'].add('there exist dependent iaranges')
                 break
 
-    vec_md_stacks = list(vec_md_stacks)
+    vectorized_map_data_info['vec_md_stacks'] = stack_dict
 
     # construct data structure consumed by tracegraph_kl_qp
-    if vectorized_map_data_info['rao-blackwellization-condition']:
-        vectorized_map_data_info['nodes'] = set()
-        for name, node in nodes.items():
-            if site_is_subsample(node):
-                continue
-            if node["type"] in ("sample", "param"):
-                if any(x.vectorized for x in node["cond_indep_stack"]):
-                    vectorized_map_data_info['nodes'].add(name)
+    vectorized_map_data_info['nodes'] = set()
+    for name, node in nodes.items():
+        if site_is_subsample(node):
+            continue
+        if node["type"] in ("sample", "param"):
+            if any(x.vectorized for x in node["cond_indep_stack"]):
+                vectorized_map_data_info['nodes'].add(name)
 
     return vectorized_map_data_info

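To make the new bookkeeping concrete, here is a toy rendering of the first pass above, with plain dicts standing in for trace nodes (site names are hypothetical; real traces hold CondIndepStackFrame objects and the real loop also skips subsample sites via site_is_subsample):

    from collections import namedtuple

    Frame = namedtuple("Frame", ["name", "vectorized"])

    nodes = {
        "z_outer": {"type": "sample",
                    "cond_indep_stack": (Frame("outer", True),)},
        "z_inner": {"type": "sample",
                    "cond_indep_stack": (Frame("outer", True), Frame("inner", True))},
    }

    # first pass: record each site's vectorized frames, keyed by site name
    stack_dict = {}
    for name, node in nodes.items():
        if node["type"] in ("sample", "param"):
            stack_dict[name] = [x for x in node["cond_indep_stack"] if x.vectorized]

    print({k: [f.name for f in v] for k, v in stack_dict.items()})
    # {'z_outer': ['outer'], 'z_inner': ['outer', 'inner']}

This is exactly the per-site structure that tracegraph_elbo's n_compatible_indices consumes.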