Skip to content

Commit

Permalink
more downstream cost tests (#855)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjankowiak authored and fritzo committed Mar 6, 2018
1 parent ff88887 commit 2c62f77
Showing 1 changed file with 76 additions and 4 deletions.
80 changes: 76 additions & 4 deletions tests/infer/test_compute_downstream_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,16 @@ def n_compatible_indices(dest_node, source_node):


def big_model_guide(include_obs=True, include_single=False, include_inner_1=False, flip_c23=False,
include_triple=False):
include_triple=False, include_z1=False):
p0 = variable(math.exp(-0.20), requires_grad=True)
p1 = variable(math.exp(-0.33), requires_grad=True)
p2 = variable(math.exp(-0.70), requires_grad=True)
if include_triple:
with pyro.iarange("iarange_triple1", 6) as ind_triple1:
with pyro.iarange("iarange_triple2", 7) as ind_triple2:
if include_z1:
pyro.sample("z1", dist.Bernoulli(p2).reshape(sample_shape=[
len(ind_triple2), len(ind_triple1)]))
with pyro.iarange("iarange_triple3", 9) as ind_triple3:
pyro.sample("z0", dist.Bernoulli(p2).reshape(sample_shape=[len(ind_triple3),
len(ind_triple2), len(ind_triple1)]))
Expand Down Expand Up @@ -104,15 +107,17 @@ def big_model_guide(include_obs=True, include_single=False, include_inner_1=Fals
@pytest.mark.parametrize("include_single", [True, False])
@pytest.mark.parametrize("flip_c23", [True, False])
@pytest.mark.parametrize("include_triple", [True, False])
def test_compute_downstream_costs_big_model_guide_pair(include_inner_1, include_single, flip_c23, include_triple):
@pytest.mark.parametrize("include_z1", [True, False])
def test_compute_downstream_costs_big_model_guide_pair(include_inner_1, include_single, flip_c23,
include_triple, include_z1):
guide_trace = poutine.trace(big_model_guide,
graph_type="dense").get_trace(include_obs=False, include_inner_1=include_inner_1,
include_single=include_single, flip_c23=flip_c23,
include_triple=include_triple)
include_triple=include_triple, include_z1=include_z1)
model_trace = poutine.trace(poutine.replay(big_model_guide, guide_trace),
graph_type="dense").get_trace(include_obs=True, include_inner_1=include_inner_1,
include_single=include_single, flip_c23=flip_c23,
include_triple=include_triple)
include_triple=include_triple, include_z1=include_z1)

guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)
Expand Down Expand Up @@ -365,3 +370,70 @@ def test_compute_downstream_costs_iarange_in_irange(dim1):
for k in dc:
assert(guide_trace.nodes[k]['batch_log_pdf'].size() == dc[k].size())
assert_equal(dc[k], dc_brute[k])


def nested_model_guide2(include_obs=True, dim1=3, dim2=2):
p0 = variable(math.exp(-0.40 - include_obs * 0.2), requires_grad=True)
p1 = variable(math.exp(-0.33 - include_obs * 0.1), requires_grad=True)
pyro.sample("a1", dist.Bernoulli(p0 * p1))
with pyro.iarange("iarange", dim1) as ind:
c = pyro.sample("c", dist.Bernoulli(p1).reshape(sample_shape=[len(ind)]))
assert c.shape == (dim1,)
for i in pyro.irange("irange", dim2):
b_i = pyro.sample("b{}".format(i), dist.Bernoulli(p0).reshape(sample_shape=[len(ind)]))
assert b_i.shape == (dim1,)
if include_obs:
obs_i = pyro.sample("obs{}".format(i), dist.Bernoulli(b_i), obs=Variable(torch.ones(b_i.size())))
assert obs_i.shape == (dim1,)


@pytest.mark.parametrize("dim1", [2, 5])
@pytest.mark.parametrize("dim2", [3, 4])
def test_compute_downstream_costs_irange_in_iarange(dim1, dim2):
guide_trace = poutine.trace(nested_model_guide2,
graph_type="dense").get_trace(include_obs=False, dim1=dim1, dim2=dim2)
model_trace = poutine.trace(poutine.replay(nested_model_guide2, guide_trace),
graph_type="dense").get_trace(include_obs=True, dim1=dim1, dim2=dim2)

guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)
model_trace.compute_batch_log_pdf()
guide_trace.compute_batch_log_pdf()

guide_vec_md_info = guide_trace.graph["vectorized_map_data_info"]
model_vec_md_info = model_trace.graph["vectorized_map_data_info"]
guide_vec_md_condition = guide_vec_md_info['rao-blackwellization-condition']
model_vec_md_condition = model_vec_md_info['rao-blackwellization-condition']
do_vec_rb = guide_vec_md_condition and model_vec_md_condition
guide_vec_md_nodes = guide_vec_md_info['nodes'] if do_vec_rb else set()
model_vec_md_nodes = model_vec_md_info['nodes'] if do_vec_rb else set()
non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
model_vec_md_nodes, guide_vec_md_nodes,
non_reparam_nodes)

dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace,
model_vec_md_nodes, guide_vec_md_nodes,
non_reparam_nodes)

assert dc_nodes == dc_nodes_brute

for k in dc:
assert(guide_trace.nodes[k]['batch_log_pdf'].size() == dc[k].size())
assert_equal(dc[k], dc_brute[k])

expected_b1 = model_trace.nodes['b1']['batch_log_pdf'] - guide_trace.nodes['b1']['batch_log_pdf']
expected_b1 += model_trace.nodes['obs1']['batch_log_pdf']
assert_equal(expected_b1, dc['b1'])

expected_c = model_trace.nodes['c']['batch_log_pdf'] - guide_trace.nodes['c']['batch_log_pdf']
for i in range(dim2):
expected_c += model_trace.nodes['b{}'.format(i)]['batch_log_pdf'] - \
guide_trace.nodes['b{}'.format(i)]['batch_log_pdf']
expected_c += model_trace.nodes['obs{}'.format(i)]['batch_log_pdf']
assert_equal(expected_c, dc['c'])

expected_a1 = model_trace.nodes['a1']['batch_log_pdf'] - guide_trace.nodes['a1']['batch_log_pdf']
expected_a1 += expected_c.sum()
assert_equal(expected_a1, dc['a1'])

0 comments on commit 2c62f77

Please sign in to comment.