Skip to content

Commit

Permalink
add brute force comparison for downstream costs test (#830)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjankowiak authored and fritzo committed Mar 1, 2018
1 parent c61c17e commit 18a08c1
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions tests/infer/test_compute_downstream_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import math

import networkx
import pytest
import torch
from torch.autograd import Variable, variable
Expand All @@ -10,10 +11,59 @@
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.tracegraph_elbo import _compute_downstream_costs
from pyro.infer.util import MultiViewTensor
from pyro.poutine.util import prune_subsample_sites
from tests.common import assert_equal


def _brute_force_compute_downstream_costs(model_trace, guide_trace, #
model_vec_md_nodes, guide_vec_md_nodes, #
non_reparam_nodes):

guide_nodes = [x for x in guide_trace.nodes if guide_trace.nodes[x]["type"] == "sample"]
downstream_costs, downstream_guide_cost_nodes = {}, {}
stacks = model_trace.graph["vectorized_map_data_info"]['vec_md_stacks']

def n_compatible_indices(dest_node, source_node):
n_compatible = 0
for xframe, yframe in zip(stacks[source_node], stacks[dest_node]):
if xframe.name == yframe.name:
n_compatible += 1
return n_compatible

for node in guide_nodes:
downstream_costs[node] = MultiViewTensor(model_trace.nodes[node]['batch_log_pdf'] -
guide_trace.nodes[node]['batch_log_pdf'])
downstream_guide_cost_nodes[node] = set([node])

descendants = networkx.descendants(guide_trace._graph, node)

for desc in descendants:
dims_to_keep = n_compatible_indices(node, desc)
desc_mvt = MultiViewTensor(model_trace.nodes[desc]['batch_log_pdf'] -
guide_trace.nodes[desc]['batch_log_pdf'])
summed_desc = desc_mvt.sum_leftmost_all_but(dims_to_keep)
downstream_costs[node].add(summed_desc)
downstream_guide_cost_nodes[node].update([desc])

for site in non_reparam_nodes:
children_in_model = set()
for node in downstream_guide_cost_nodes[site]:
children_in_model.update(model_trace.successors(node))
children_in_model.difference_update(downstream_guide_cost_nodes[site])
for child in children_in_model:
assert (model_trace.nodes[child]["type"] == "sample")
dims_to_keep = n_compatible_indices(site, child)
summed_child = MultiViewTensor(model_trace.nodes[child]['batch_log_pdf']).sum_leftmost_all_but(dims_to_keep)
downstream_costs[site].add(summed_child)
downstream_guide_cost_nodes[site].update([child])

for k in downstream_costs:
downstream_costs[k] = downstream_costs[k].contract_to(guide_trace.nodes[k]['batch_log_pdf'])

return downstream_costs, downstream_guide_cost_nodes


def big_model_guide(include_obs=True, include_single=False, include_inner_1=False, flip_c23=False,
include_triple=False):
p0 = variable(math.exp(-0.20), requires_grad=True)
Expand Down Expand Up @@ -82,6 +132,12 @@ def test_compute_downstream_costs_big_model_guide_pair(include_inner_1, include_
model_vec_md_nodes, guide_vec_md_nodes,
non_reparam_nodes)

dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace,
model_vec_md_nodes, guide_vec_md_nodes,
non_reparam_nodes)

assert dc_nodes == dc_nodes_brute

expected_nodes_full_model = {'a1': {'c2', 'a1', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'b0'}, 'd2': {'obs', 'd2'},
'd1': {'obs', 'd1', 'd2'}, 'c3': {'d2', 'obs', 'd1', 'c3'},
'b0': {'b0', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'c2'},
Expand Down Expand Up @@ -161,6 +217,7 @@ def test_compute_downstream_costs_big_model_guide_pair(include_inner_1, include_

for k in dc:
assert(guide_trace.nodes[k]['batch_log_pdf'].size() == dc[k].size())
assert_equal(dc[k], dc_brute[k])


def diamond_model(dim):
Expand Down Expand Up @@ -208,6 +265,12 @@ def test_compute_downstream_costs_duplicates(dim):
model_vec_md_nodes, guide_vec_md_nodes,
non_reparam_nodes)

dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace,
model_vec_md_nodes, guide_vec_md_nodes,
non_reparam_nodes)

assert dc_nodes == dc_nodes_brute

expected_a1 = (model_trace.nodes['a1']['batch_log_pdf'] - guide_trace.nodes['a1']['batch_log_pdf'])
for d in range(dim):
expected_a1 += model_trace.nodes['b{}'.format(d)]['batch_log_pdf']
Expand All @@ -232,6 +295,7 @@ def test_compute_downstream_costs_duplicates(dim):

for k in dc:
assert(guide_trace.nodes[k]['batch_log_pdf'].size() == dc[k].size())
assert_equal(dc[k], dc_brute[k])


def nested_model_guide(include_obs=True, dim1=11, dim2=7):
Expand Down Expand Up @@ -273,6 +337,12 @@ def test_compute_downstream_costs_iarange_in_irange(dim1):
model_vec_md_nodes, guide_vec_md_nodes,
non_reparam_nodes)

dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace,
model_vec_md_nodes, guide_vec_md_nodes,
non_reparam_nodes)

assert dc_nodes == dc_nodes_brute

expected_c1 = (model_trace.nodes['c1']['batch_log_pdf'] - guide_trace.nodes['c1']['batch_log_pdf'])
expected_c1 += model_trace.nodes['obs1']['batch_log_pdf']

Expand All @@ -294,3 +364,4 @@ def test_compute_downstream_costs_iarange_in_irange(dim1):

for k in dc:
assert(guide_trace.nodes[k]['batch_log_pdf'].size() == dc[k].size())
assert_equal(dc[k], dc_brute[k])

0 comments on commit 18a08c1

Please sign in to comment.