Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pyro enum bugfix [WIP] #2226

Merged
merged 6 commits into from Dec 20, 2019
Merged

Pyro enum bugfix [WIP] #2226

merged 6 commits into from Dec 20, 2019

Conversation

lsgos
Copy link
Contributor

@lsgos lsgos commented Dec 17, 2019

As suggested here, this pull request adds a failing test which reproduces #2223. It currently makes no attempt to fix the issue, but should hopefully be useful when debugging.

If anything needs cleaning up please let me know. Help is appreciated - I'm not very familiar with the low level mechanics of the enumeration code atm.

I also found that the test for the multivariate student T was failing on cuda because it parameterises df as a float rather than a tensor, and then the samples inherit the device of df. The first commit here fixes this test by changing tensor_wrap to also wrap floats, sending these to cuda if wrapped in a tensors_default_to context. I squashed this here because it came up first when I was trying to hit my failing test.

@claassistantio
Copy link

claassistantio commented Dec 17, 2019

CLA assistant check
All committers have signed the CLA.

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clearly reproducible test. I believe the following fixes it:

diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py
index 73e676fc..ae8f1905 100644
--- a/pyro/infer/traceenum_elbo.py
+++ b/pyro/infer/traceenum_elbo.py
@@ -98,7 +98,7 @@ def _compute_model_factors(model_trace, guide_trace):
         if site["type"] == "sample":
             if name in guide_trace.nodes:
                 cost_sites.setdefault(ordering[name], []).append(site)
-                non_enum_dims.update(site["packed"]["log_prob"]._pyro_dims)
+                non_enum_dims.update(guide_trace.nodes[name]["packed"]["log_prob"]._pyro_dims)
             elif site["infer"].get("_enumerate_dim") is None:
                 cost_sites.setdefault(ordering[name], []).append(site)
             else:

I believe the issue was that the model site's log_prob gets extra dimensions during model-side enumeration due to a dependency on an upstream enumerated variable, but we were assuming that guided sites would not depend on model-enumerated variables. I believe a valid fix is to instead read plates from guide's log_prob which cannot depend on model-enumerated sites. Does this make sense?

Do you want to add that in this PR or would you prefer me to do so in a separate PR?

Also some minor nits:

  • revert unrelated changes in dist_fixture.py
  • use @config_enumerate as in other tests in test_enum.py
  • reduce N_obs to something smaller like 3

Reduce N_obs and make use of config_enumerate consistent with the rest
of the file.
@lsgos
Copy link
Contributor Author

lsgos commented Dec 17, 2019

Do you want me to put the change to dist_fixture in a separate PR? I kind of agree it doesn't make sense to bundle it with this one, as it's off topic, but I think its a real (minor) bug - I wasn't able to run the test suite locally without it.

@fritzo
Copy link
Member

fritzo commented Dec 17, 2019

Do you want me to put the change to dist_fixture in a separate PR?

Yes, actually could you also open up an issue with the failed test output? I'm not sure whether fixing the fixture is the right solution, or whether that indicates a missing coercion in a distribution .__init__() method which we should fix instead. Thanks!

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this!

@fritzo fritzo merged commit 2aaf8fc into pyro-ppl:dev Dec 20, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants