Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Pyro enum bugfix [WIP] #2226

Merged
merged 6 commits into from
Dec 20, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/distributions/dist_fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,11 @@ def get_test_distribution_name(self):
def tensor_wrap(*args, **kwargs):
tensor_list, tensor_map = [], {}
for arg in args:
wrapped_arg = torch.tensor(arg) if isinstance(arg, list) else arg
wrapped_arg = torch.tensor(arg) if (isinstance(arg, list) or isinstance(arg, float)) else arg
tensor_list.append(wrapped_arg)
for k in kwargs:
kwarg = kwargs[k]
wrapped_kwarg = torch.tensor(kwarg) if isinstance(kwarg, list) else kwarg
wrapped_kwarg = torch.tensor(kwarg) if (isinstance(kwarg, list) or isinstance(kwarg, float)) else kwarg
tensor_map[k] = wrapped_kwarg
if args and not kwargs:
return tensor_list
Expand Down
36 changes: 36 additions & 0 deletions tests/infer/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pyro.distributions as dist
import pyro.optim
import pyro.poutine as poutine
from pyro import infer
from pyro.distributions.testing.rejection_gamma import ShapeAugmentedGamma
from pyro.infer import SVI, config_enumerate
from pyro.infer.enum import iter_discrete_traces
Expand Down Expand Up @@ -3248,3 +3249,38 @@ def guide():
elbo = Trace_ELBO(vectorize_particles=True, num_particles=num_samples).loss(model, guide)

assert_equal(vectorized_weights.sum().item() / num_samples, -elbo, prec=0.02)


def test_multi_dependence_enumeration():
    """Regression test for https://github.com/pyro-ppl/pyro/issues/2223.

    Checks that enumeration handles the case where multiple downstream
    sample sites ('x' and 'y') are coupled to the same enumerated discrete
    variable ('s'): computing the ELBO must still propagate gradients back
    to the mixture-weight parameter 'pi'.
    """
    num_components = 5  # number of mixture components (K)
    component_dim = 2   # dimensionality of each component mean (d)
    num_points = 100    # observations per datum (N_obs)

    @infer.config_enumerate
    def model(N=1):
        with pyro.plate('data_plate', N, dim=-2):
            weights = pyro.param('pi',
                                 torch.ones(num_components) / num_components,
                                 constraint=constraints.simplex)
            locs = pyro.sample(
                'mu',
                dist.Normal(torch.zeros(num_components, component_dim),
                            torch.ones(num_components, component_dim)).to_event(2))

            with pyro.plate('observations', num_points, dim=-1):
                assignment = pyro.sample('s', dist.Categorical(weights))

                # Two observed sites depend on the SAME discrete variable —
                # this is the coupling pattern that triggered issue 2223.
                pyro.sample('x', dist.Normal(Vindex(locs)[..., assignment, :], 0.1).to_event(1))
                pyro.sample('y', dist.Normal(Vindex(locs)[..., assignment, :], 0.1).to_event(1))

    # Generate synthetic data for 'x' from the prior, then condition on it.
    data = poutine.trace(model).get_trace(N=2).nodes['x']['value']

    pyro.clear_param_store()
    conditioned = pyro.condition(model, data={'x': data})
    # Hide 's' from the guide so it is enumerated rather than point-estimated.
    guide = infer.autoguide.AutoDelta(poutine.block(conditioned, hide=['s']))

    elbo = infer.TraceEnum_ELBO(max_plate_nesting=2)
    elbo.loss_and_grads(conditioned, guide, data.size(0))

    # If enumeration is wired up correctly, 'pi' receives a gradient.
    assert pyro.get_param_store()._params['pi'].grad is not None