Pyro enum bugfix [WIP] #2226
Conversation
Thanks for the clearly reproducible test. I believe the following fixes it:
```diff
diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py
index 73e676fc..ae8f1905 100644
--- a/pyro/infer/traceenum_elbo.py
+++ b/pyro/infer/traceenum_elbo.py
@@ -98,7 +98,7 @@ def _compute_model_factors(model_trace, guide_trace):
         if site["type"] == "sample":
             if name in guide_trace.nodes:
                 cost_sites.setdefault(ordering[name], []).append(site)
-                non_enum_dims.update(site["packed"]["log_prob"]._pyro_dims)
+                non_enum_dims.update(guide_trace.nodes[name]["packed"]["log_prob"]._pyro_dims)
             elif site["infer"].get("_enumerate_dim") is None:
                 cost_sites.setdefault(ordering[name], []).append(site)
             else:
```
I believe the issue was that the model site's `log_prob` gets extra dimensions during model-side enumeration due to a dependency on an upstream enumerated variable, but we were assuming that guided sites would not depend on model-enumerated variables. I believe a valid fix is to instead read plates from the guide's `log_prob`, which cannot depend on model-enumerated sites. Does this make sense?
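The dimension bookkeeping above can be illustrated with a toy sketch. The dicts and one-letter dim names below are made up for illustration and are not Pyro's actual trace structures; the point is only that reading dims from the model site pulls in the model-side enumeration dim, while reading from the guide site does not:

```python
# Illustrative mock of the packed-trace bookkeeping described above.
# "a" stands for a hypothetical model-side enumeration dim, "i" for a plate dim.

model_trace_nodes = {
    # The model site's packed log_prob picked up enum dim "a" because the
    # site depends on an upstream model-enumerated variable.
    "x": {"packed": {"log_prob_dims": "ai"}},
}
guide_trace_nodes = {
    # The guide's log_prob cannot depend on model-enumerated sites,
    # so it carries only the plate dim "i".
    "x": {"packed": {"log_prob_dims": "i"}},
}

def collect_non_enum_dims(nodes, name):
    """Collect the dims to be treated as non-enumerated for a guided site."""
    return set(nodes[name]["packed"]["log_prob_dims"])

# Old behaviour: reading from the model site wrongly includes enum dim "a".
before_fix = collect_non_enum_dims(model_trace_nodes, "x")
# Fixed behaviour: reading from the guide site yields only the plate dim.
after_fix = collect_non_enum_dims(guide_trace_nodes, "x")
```

This mirrors the one-line diff: the lookup moves from `site["packed"]` to `guide_trace.nodes[name]["packed"]`.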
Do you want to add that in this PR or would you prefer me to do so in a separate PR?
Also some minor nits:
- revert unrelated changes in dist_fixture.py
- use `@config_enumerate` as in other tests in test_enum.py
- reduce N_obs to something smaller like 3
Reduce N_obs and make use of config_enumerate consistent with the rest of the file.
Do you want me to put the change to dist_fixture in a separate PR? I kind of agree it doesn't make sense to bundle it with this one, as it's off topic, but I think it's a real (minor) bug - I wasn't able to run the test suite locally without it.
Yes, actually could you also open up an issue with the failed test output? I'm not sure whether fixing the fixture is the right solution, or whether that indicates a missing coercion in a distribution.
Thanks for fixing this!
As suggested here, this pull request adds a failing test which reproduces #2223. It currently makes no attempt to fix the issue, but should hopefully be useful when debugging.
If anything needs cleaning up, please let me know. Help is appreciated - I'm not very familiar with the low-level mechanics of the enumeration code at the moment.
I also found that the test for the multivariate Student's t distribution was failing on CUDA because it parameterises df as a float rather than a tensor, and the samples then inherit the device of df. The first commit here fixes this test by changing tensor_wrap to also wrap floats, sending them to CUDA if wrapped in a tensors_default_to context. I squashed this in here because it came up first when I was trying to hit my failing test.
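The tensor_wrap change described above can be sketched as follows. This is a minimal stand-in, not Pyro's actual implementation: `FakeTensor` substitutes for `torch.Tensor` so the sketch stays self-contained, and the behaviour of `tensors_default_to` is an assumption based on the description in this thread:

```python
# Hedged sketch of the tensor_wrap / tensors_default_to pattern described
# above; names and behaviour are assumptions, not Pyro's real code.
import contextlib
from dataclasses import dataclass

_default_device = "cpu"

@dataclass
class FakeTensor:
    """Stand-in for torch.Tensor, carrying a value and a device string."""
    value: float
    device: str

@contextlib.contextmanager
def tensors_default_to(device):
    """Temporarily change the device that tensor_wrap assigns to new tensors."""
    global _default_device
    previous, _default_device = _default_device, device
    try:
        yield
    finally:
        _default_device = previous

def tensor_wrap(*args):
    """Wrap bare floats as tensors on the current default device,
    so downstream samples inherit the right device."""
    return [a if isinstance(a, FakeTensor) else FakeTensor(float(a), _default_device)
            for a in args]

# A plain-float df now lands on the context's device instead of the CPU default.
with tensors_default_to("cuda"):
    (df,) = tensor_wrap(3.0)
```

With this pattern, a test that parameterises df as a float still produces device-consistent samples when run inside a `tensors_default_to("cuda")` block.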