Pyro enum bugfix [WIP] #2226

Merged 6 commits on Dec 20, 2019
2 changes: 1 addition & 1 deletion pyro/infer/traceenum_elbo.py
@@ -98,7 +98,7 @@ def _compute_model_factors(model_trace, guide_trace):
         if site["type"] == "sample":
             if name in guide_trace.nodes:
                 cost_sites.setdefault(ordering[name], []).append(site)
-                non_enum_dims.update(site["packed"]["log_prob"]._pyro_dims)
+                non_enum_dims.update(guide_trace.nodes[name]["packed"]["log_prob"]._pyro_dims)
             elif site["infer"].get("_enumerate_dim") is None:
                 cost_sites.setdefault(ordering[name], []).append(site)
             else:
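For orientation, a hedged note: _compute_model_factors is internal bookkeeping inside TraceEnum_ELBO, so user code reaches the changed line only through the public ELBO API. The sketch below shows a generic path to that code; the model, guide, and parameter names are illustrative assumptions and are not taken from this PR.

# A generic TraceEnum_ELBO usage sketch (assumed example, not from this PR): the
# enumerated discrete site "z" is marginalized by TraceEnum_ELBO, and the loss
# computation internally calls _compute_model_factors on the model/guide traces.
import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

data = torch.tensor([0.1, 0.2, 1.1])


@config_enumerate
def model():
    probs = pyro.param("probs", torch.tensor([0.5, 0.5]), constraint=constraints.simplex)
    locs = pyro.param("locs", torch.tensor([0.0, 1.0]))
    with pyro.plate("data", len(data)):
        z = pyro.sample("z", dist.Categorical(probs))  # enumerated in parallel
        pyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)


def guide():
    pass  # every latent site is enumerated, so the guide has nothing to sample


svi = SVI(model, guide, Adam({"lr": 0.05}), loss=TraceEnum_ELBO(max_plate_nesting=1))
loss = svi.step()  # traces model and guide, then calls _compute_model_factors internally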
36 changes: 36 additions & 0 deletions tests/infer/test_enum.py
@@ -13,6 +13,7 @@
 import pyro.distributions as dist
 import pyro.optim
 import pyro.poutine as poutine
+from pyro import infer
 from pyro.distributions.testing.rejection_gamma import ShapeAugmentedGamma
 from pyro.infer import SVI, config_enumerate
 from pyro.infer.enum import iter_discrete_traces
@@ -3248,3 +3249,38 @@ def guide():
     elbo = Trace_ELBO(vectorize_particles=True, num_particles=num_samples).loss(model, guide)
 
     assert_equal(vectorized_weights.sum().item() / num_samples, -elbo, prec=0.02)
+
+
+def test_multi_dependence_enumeration():
+    """
+    This test checks whether enumeration works correctly in the case where multiple downstream
+    variables are coupled to the same random discrete variable.
+    This is based on [issue 2223](https://github.com/pyro-ppl/pyro/issues/2223), and should
+    pass when it has been resolved
+    """
+    K = 5
+    d = 2
+    N_obs = 3
+
+    @config_enumerate
+    def model(N=1):
+        with pyro.plate('data_plate', N, dim=-2):
+            mixing_weights = pyro.param('pi', torch.ones(K) / K, constraint=constraints.simplex)
+            means = pyro.sample('mu', dist.Normal(torch.zeros(K, d), torch.ones(K, d)).to_event(2))
+
+            with pyro.plate('observations', N_obs, dim=-1):
+                s = pyro.sample('s', dist.Categorical(mixing_weights))
+
+                pyro.sample('x', dist.Normal(Vindex(means)[..., s, :], 0.1).to_event(1))
+                pyro.sample('y', dist.Normal(Vindex(means)[..., s, :], 0.1).to_event(1))
+
+    x = poutine.trace(model).get_trace(N=2).nodes['x']['value']
+
+    pyro.clear_param_store()
+    conditioned_model = pyro.condition(model, data={'x': x})
+    guide = infer.autoguide.AutoDelta(poutine.block(conditioned_model, hide=['s']))
+
+    elbo = infer.TraceEnum_ELBO(max_plate_nesting=2)
+
+    elbo.loss_and_grads(conditioned_model, guide, x.size(0))
+    assert pyro.get_param_store()._params['pi'].grad is not None
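To exercise the regression outside the test suite, the following standalone sketch is assembled from the test above. The torch, constraints, and Vindex imports are assumptions about the existing top-of-file imports in tests/infer/test_enum.py (only the added "from pyro import infer" line is visible in this hunk); Vindex is taken from pyro.ops.indexing.

# Standalone repro sketch for issue 2223, mirroring test_multi_dependence_enumeration.
import torch
from torch.distributions import constraints  # assumed import, not shown in the hunk

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro import infer
from pyro.infer import config_enumerate
from pyro.ops.indexing import Vindex  # assumed import, not shown in the hunk

K, d, N_obs = 5, 2, 3


@config_enumerate
def model(N=1):
    with pyro.plate('data_plate', N, dim=-2):
        mixing_weights = pyro.param('pi', torch.ones(K) / K, constraint=constraints.simplex)
        means = pyro.sample('mu', dist.Normal(torch.zeros(K, d), torch.ones(K, d)).to_event(2))
        with pyro.plate('observations', N_obs, dim=-1):
            # Two downstream sites, 'x' and 'y', depend on the same discrete site 's'.
            s = pyro.sample('s', dist.Categorical(mixing_weights))
            pyro.sample('x', dist.Normal(Vindex(means)[..., s, :], 0.1).to_event(1))
            pyro.sample('y', dist.Normal(Vindex(means)[..., s, :], 0.1).to_event(1))


# Sample synthetic observations, condition on them, and evaluate the enumerated ELBO.
x = poutine.trace(model).get_trace(N=2).nodes['x']['value']
pyro.clear_param_store()
conditioned_model = pyro.condition(model, data={'x': x})
guide = infer.autoguide.AutoDelta(poutine.block(conditioned_model, hide=['s']))
elbo = infer.TraceEnum_ELBO(max_plate_nesting=2)
loss = elbo.loss_and_grads(conditioned_model, guide, x.size(0))
print(loss, pyro.get_param_store()._params['pi'].grad is not None)  # gradient should flow to 'pi'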