-
-
Notifications
You must be signed in to change notification settings - Fork 986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Consolidating Poutines, Traces, and control flow #38
Conversation
… klqp that uses it
…ces, improve thread-safety, remove some global variables, and most importantly enable nesting
pyro/poutine/replay_poutine.py
Outdated
self.guide_trace = guide_trace | ||
self.all_sites = False | ||
# case 1: no sites | ||
if sites is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not do something like?
if sites is None:
sites = self.guide_trace.keys()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I guess I could do that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Edit: All of my one line comments are 4 line comments for some reason... im referring to the last line in the selected blocks of code.
github doesnt seem to be able to smartly handle commits during a review so please disregard any lingering comments that no longer apply...
looks good, this seems to be the right abstractions for trace and also modularizes poutines so theyre testable!
just some minor touchups; also can you fix the broken travis tests?
not necessarily for this PR but as a general note, currently our unit tests are 100% happy cases - we should also test if the program fails noisily when we expect it to.
pyro/infer/trace.py
Outdated
@@ -0,0 +1,101 @@ | |||
import pyro |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
question: should we move traces out of infer in their own module like poutine?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, I moved them from pyro.infer
to pyro.poutine
for now.
pyro/infer/trace.py
Outdated
log_p = 0.0 | ||
for name in self.keys(): | ||
if self[name]["type"] in ("observe", "sample"): | ||
self[name]["log_pdf"] = self[name]["fn"].log_pdf( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this calls log_pdf
instead of batch_log_pdf
as the current KLQP
does... shouldnt this call batch until we decide if we are going to deprecate/remove one? Currently batch should handle 1-element tensors properly but log_pdf will error on vectorized input.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added a batch_log_pdf
method
pyro/infer/trace_search.py
Outdated
@@ -0,0 +1,42 @@ | |||
import pyro | |||
import torch | |||
from queue import Queue |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this doesnt work in python2 or 3 for me unless i do from multiprocessing import Queue
and travis has the same problem
pyro/poutine/__init__.py
Outdated
return a return value from a complete trace in the queue | ||
""" | ||
def _fn(*args, **kwargs): | ||
p = BeamPoutine(fn, queue=queue, max_tries=max_tries) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BeamPoutine
is NotImplemented?
pyro/poutine/queue_poutine.py
Outdated
@@ -0,0 +1,80 @@ | |||
import pyro | |||
import torch | |||
from queue import Queue |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this doesnt work in python2 or 3 for me unless i do from multiprocessing import Queue
and travis has the same problem
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ill push the fix, assuming you dont mind i push to your branch
pyro/poutine/poutine.py
Outdated
""" | ||
Default pyro.map_data Poutine behavior | ||
""" | ||
if self.transparent and not (prev_val is None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment as below
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
pyro/poutine/poutine.py
Outdated
if self.transparent and not (prev_val is None): | ||
return prev_val | ||
else: | ||
if isinstance(data, torch.Tensor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also need to check Variable
here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think both work because of inheritance, but done
pyro/poutine/poutine.py
Outdated
else: | ||
if obs is None: | ||
return fn(*args, **kwargs) | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
L115: dont need this else
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed
pyro/poutine/poutine.py
Outdated
val = fn(*args, **kwargs) | ||
if self.transparent and not (prev_val is None): | ||
return prev_val | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
edit: idk why github keeps having these multiline highlights when i only click one line... my comment was you dont need the else:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed
|
||
from tests.common import TestCase | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dont necessarily need to do this in this PR, but we should have some that negative test cases that test if the proper exceptions are being thrown.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably a good idea. Some tests like that already exist but they could be more comprehensive. I'll punt on this for now.
Major refactor of internals over the last couple of days, focused on consolidating poutines and traces. Includes cleaned-up version of systematic search, but no examples for it yet. Inference interfaces and expected behavior remain unchanged.
Start by looking at
pyro/poutine/__init__.py
Context: I originally set out to consolidate the various Trace data structures, but in the process I realized the poutines were a bit of an untested mess that were holding up my development progress and making inference code needlessly complicated and repetitive. So there I was in outer Mongolia shaving a yak...
In this branch, I've been separating out some abstract control operations from the algorithms and putting them in a single unified poutine interface. I've also changed the way poutines work, from monkeypatching the primitive definitions to living on a dedicated control stack. This achieves a number of things: it massively simplifies inference code by putting the most complicated parts of inference code (control flow) behind an extremely simple interface and makes those parts reusable, modular, and testable, it simplifies the structure of the codebase, and it enables nesting of poutines so that we can do things like logging, visualization, and causal intervention. It does not change the interface or expected behavior of the inference algorithms, so everything else should still work.
Addresses or related to #19 , #2 , #28 , #20 and #1