Skip to content

Commit

Permalink
Pyro enum bugfix [WIP] (#2226)
Browse files Browse the repository at this point in the history
  • Loading branch information
lsgos authored and fritzo committed Dec 20, 2019
1 parent fae9c31 commit 2aaf8fc
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyro/infer/traceenum_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _compute_model_factors(model_trace, guide_trace):
if site["type"] == "sample":
if name in guide_trace.nodes:
cost_sites.setdefault(ordering[name], []).append(site)
non_enum_dims.update(site["packed"]["log_prob"]._pyro_dims)
non_enum_dims.update(guide_trace.nodes[name]["packed"]["log_prob"]._pyro_dims)
elif site["infer"].get("_enumerate_dim") is None:
cost_sites.setdefault(ordering[name], []).append(site)
else:
Expand Down
36 changes: 36 additions & 0 deletions tests/infer/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pyro.distributions as dist
import pyro.optim
import pyro.poutine as poutine
from pyro import infer
from pyro.distributions.testing.rejection_gamma import ShapeAugmentedGamma
from pyro.infer import SVI, config_enumerate
from pyro.infer.enum import iter_discrete_traces
Expand Down Expand Up @@ -3248,3 +3249,38 @@ def guide():
elbo = Trace_ELBO(vectorize_particles=True, num_particles=num_samples).loss(model, guide)

assert_equal(vectorized_weights.sum().item() / num_samples, -elbo, prec=0.02)


def test_multi_dependence_enumeration():
"""
This test checks whether enumeration works correctly in the case where multiple downstream
variables are coupled to the same random discrete variable.
This is based on [issue 2223](https://github.com/pyro-ppl/pyro/issues/2223), and should
pass when it has been resolved
"""
K = 5
d = 2
N_obs = 3

@config_enumerate
def model(N=1):
with pyro.plate('data_plate', N, dim=-2):
mixing_weights = pyro.param('pi', torch.ones(K) / K, constraint=constraints.simplex)
means = pyro.sample('mu', dist.Normal(torch.zeros(K, d), torch.ones(K, d)).to_event(2))

with pyro.plate('observations', N_obs, dim=-1):
s = pyro.sample('s', dist.Categorical(mixing_weights))

pyro.sample('x', dist.Normal(Vindex(means)[..., s, :], 0.1).to_event(1))
pyro.sample('y', dist.Normal(Vindex(means)[..., s, :], 0.1).to_event(1))

x = poutine.trace(model).get_trace(N=2).nodes['x']['value']

pyro.clear_param_store()
conditioned_model = pyro.condition(model, data={'x': x})
guide = infer.autoguide.AutoDelta(poutine.block(conditioned_model, hide=['s']))

elbo = infer.TraceEnum_ELBO(max_plate_nesting=2)

elbo.loss_and_grads(conditioned_model, guide, x.size(0))
assert pyro.get_param_store()._params['pi'].grad is not None

0 comments on commit 2aaf8fc

Please sign in to comment.