Vectorized map_data #62
…de from replay_poutine
…aced with monkeypatching in base poutine
a quick first look on my end. as you pointed out, lacking tests.
pyro/poutine/poutine.py
Outdated
"""
Default pyro.map_data Poutine behavior
"""
- if self.transparent and prev_val is not None:
+ if self.transparent and not (prev_val is None):

is there a reason for this change?
pyro/poutine/poutine.py
Outdated
else:
if batch_size is None:
batch_size = 0
assert batch_size >= 0, "cannot have negative batch sizes"

should we have `batch_size` default to 0 if it's negative?
pyro/poutine/poutine.py
Outdated
def scaled_sample(_prev_val, _name, _fn, *args, **kwargs):
return old_sample(_prev_val, _name,
pyro.util.rescale_dist(_fn, scale),

import `rescale_dist` from `pyro.util` so you don't have to write it out (as other classes do as well)
pyro/poutine/poutine.py
Outdated
self._pyro_sample = scaled_sample
ret = list(map(lambda ix: fn(*ix), [(i, data[i]) for i in ind]))
self._pyro_sample = old_sample

can you add a comment here explaining why the scaling factor is applied and then reset?
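For illustration, here is a minimal sketch of the swap-then-restore pattern in that snippet, written so that the restore also happens if `fn` raises. `ToyPoutine`, `patched_attr` and `scaled_map` are hypothetical names made up for this example, not part of Pyro, and the toy `_pyro_sample` just returns a number rather than sampling.

```python
from contextlib import contextmanager


@contextmanager
def patched_attr(obj, name, replacement):
    """Temporarily replace obj.<name>, restoring the original on exit."""
    original = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield original
    finally:
        # Runs even if the body raises, so the patch never leaks out.
        setattr(obj, name, original)


class ToyPoutine(object):
    """Hypothetical stand-in for a poutine; not the real Pyro class."""

    def _pyro_sample(self, value):
        return value


def scaled_map(poutine, fn, data, scale):
    """Apply fn to each (index, datum) pair while _pyro_sample is scaled."""
    old_sample = poutine._pyro_sample

    def scaled_sample(value):
        # Scale whatever the original handler returns.
        return old_sample(value) * scale

    # The scaled handler is only active inside the with-block.
    with patched_attr(poutine, "_pyro_sample", scaled_sample):
        return [fn(i, x) for i, x in enumerate(data)]
```

The context manager makes the "set, run, reset" intent explicit, which may remove the need for a long explanatory comment.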
pyro/poutine/trace.py
Outdated
@@ -40,14 +40,19 @@ def add_observe(self, name, val, fn, obs, *args, **kwargs):
self[name] = site
return self

- def add_map_data(self, name, data, fn):
+ def add_map_data(self, name, fn, batch_size, scale, ind, **kwargs):

is there a reason you added `**kwargs` but not `*args`?
pyro/poutine/trace.py
Outdated
site["scale"] = scale
site["fn"] = fn
# site["value"] = val # XXX too large to store
# site["args"] = ((), kwargs)

what's this for?
pyro/util.py
Outdated
def new_log_pdf(*args, **kwargs):
return old_log_pdf(*args, **kwargs) * scale

new_fn = copy.copy(fn) # XXX incorrect?

is a shallow copy sufficient for this?
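A small sketch of what `copy.copy` does and does not isolate may help answer this. `ToyDist` is a hypothetical distribution-like object invented for the example, not a Pyro class. Rebinding `log_pdf` on the copy leaves the original alone, but any mutable state held by reference is shared between the two.

```python
import copy


class ToyDist(object):
    """Hypothetical distribution-like object; not a real Pyro class."""

    def __init__(self, params):
        self.params = params  # mutable state, shared by shallow copies

    def log_pdf(self, x):
        return -abs(x - self.params["mu"])


original = ToyDist({"mu": 0.0})
clone = copy.copy(original)

# Rebinding an attribute on the clone leaves the original untouched.
old_log_pdf = clone.log_pdf
clone.log_pdf = lambda x: old_log_pdf(x) * 2.0

# Mutating a shared object is visible through both instances.
clone.params["mu"] = 1.0
```

So a shallow copy looks sufficient as long as the wrapper only rebinds `log_pdf` and nothing mutates the shared parameters in place; whether that holds for every distribution in the codebase is an open question here.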
pyro/util.py
Outdated
if hasattr(fn, "log_pdf"):
old_log_pdf = fn.log_pdf

def new_log_pdf(*args, **kwargs):

`scaled_log_pdf` might be a more descriptive name
pyro/util.py
Outdated
new_fn.log_pdf = new_log_pdf
return new_fn
else:
# XXX should raise an error here?

or at the very least a warning since it might help catch bugs that fall here
pyro/poutine/poutine.py
Outdated
old_sample = self._pyro_sample

def scaled_sample(_prev_val, _name, _fn, *args, **kwargs):

either here, or in rescale_dist, check for scale == 1.0 and avoid unnecessary wrapping?
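Putting the review suggestions on the rescaling helper together (the `scale == 1.0` shortcut, a warning when there is no `log_pdf`, and the `scaled_log_pdf` name), one possible shape is sketched below. `rescale_log_pdf` is a hypothetical name for this example and is not claimed to match the actual `pyro.util.rescale_dist`.

```python
import copy
import warnings


def rescale_log_pdf(fn, scale):
    """Return a copy of fn whose log_pdf is multiplied by scale (a sketch)."""
    if scale == 1.0:
        # Nothing to rescale, so skip the copy and the wrapper entirely.
        return fn
    if not hasattr(fn, "log_pdf"):
        # Warn instead of failing silently; this may help catch bugs.
        warnings.warn("object has no log_pdf; returning it unscaled")
        return fn

    old_log_pdf = fn.log_pdf

    def scaled_log_pdf(*args, **kwargs):
        return old_log_pdf(*args, **kwargs) * scale

    new_fn = copy.copy(fn)  # shallow copy: only log_pdf is rebound
    new_fn.log_pdf = scaled_log_pdf
    return new_fn
```

Whether the missing-`log_pdf` case should warn or raise is a design choice; the warning is the more lenient of the two options discussed.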
pyro/poutine/poutine.py
Outdated
if batch_size is None:
batch_size = 0
assert batch_size >= 0, "cannot have negative batch sizes"
if hasattr(fn, "__map_data_indices") and \

isn't it strange to store this information as attributes of fn?
Yes, very strange, but I'm not sure if there's a more elegant way to do it. Naively, we'd want to put the rescaling and subsampling into a new `Poutine` subclass, but because `map_data` is implemented in the base `Poutine`, you'd be using an instance of a child class inside one of the base class methods. I imagine there's a way to do this, since I'm pretty sure mutual recursion with functions is OK, but I actually tried this initially and got a bunch of mysterious `ImportError`s related to that which I was unable to clear up, so here we are.
Suggestions welcome!
can't it be an additional return value from map_data?
Mainly needs some tests, currently the
@eb8680 since you're away this week, @OptimusLime or I can help with the tests, but the rest of the comments should probably be addressed by you
…dev in, then refactored scaling. Scaling now lives in its own Poutine hidden inside TracePoutine, since that's the only place rescaled log_pdf can ever get called from.
…es sometimes, but why???
…th conservative downward behavior
@eb8680 awesome! i look forward to going through this [ and proving it still doesn't work ;) ]
…, and updated call signatures of primitives in tracegraphpoutine
pyro/__init__.py
Outdated
# default behavior
if isinstance(data, (torch.Tensor, Variable)): # XXX and np.ndarray?
if batch_size > 0:
if not hasattr(fn, "__map_data_indices"):

remove the attribute stuff
Yes
Yes
pyro/poutine/poutine.py
Outdated
ret = self._pyro_sample(msg, msg["name"],
msg["fn"],
*msg["args"], **msg["kwargs"])
new_msg = msg.copy()

remove the copy()s ?
Yes
x: pyro.observe(
"obs_%d" % i, dist.diagnormal,
x, mu_latent, torch.pow(self.lam, -0.5)), batch_size=batch_size)
pyro.map_data("bbb", self.data, lambda i,

what's the purpose of the z samples in these tests? dummies to see that things still work with 2 'map_data's?
I guess so, I should scrutinize the test models/guides more carefully
took a first look. should be good for merge after fixes
pyro/__init__.py
Outdated
if len(_PYRO_STACK) == 0:
# default behavior
if isinstance(data, (torch.Tensor, Variable)): # XXX and np.ndarray?
if batch_size > 0:

need to check batch_size <= len(Tensor/list) or you'll index out of bounds here
Yes

"fn": fn,
"data": data,
"batch_size": batch_size,
# XXX should these be added here or during application

should we add `args`, `kwargs` in `msg`?
I don't think so, better to have `map_data` functions have a standard fixed interface
from pyro.infer.kl_qp import KL_QP


class NormalNormalTests(TestCase):

TODO: nested mapdata
Will do in separate PR
tests/test_mapdata.py
Outdated
def test_elbo_reparameterized(self):
for batch_size in [8, 7, 6, 4, 3, 0]:
self.do_elbo_test(True, 5000, batch_size, map_type="list")
self.do_elbo_test(True, 5000, batch_size, map_type="tensor")

can you separate these two into two unit tests, i.e. `test_elbo_reparam_list` and `test_elbo_reparam_tensor`? because as it stands it's one giant unit test, and when it fails you don't know which of the 13 tests failed. at least this will split the list and tensor types, which take two different control flows.
Good point, will do
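One way the split could look is sketched below. The `do_elbo_test` body here is a placeholder and the class name `MapDataElboTests` is hypothetical; only the two test names and the batch sizes come from this thread. `subTest` is an optional extra that reports the failing `batch_size` and assumes Python 3.4 or later.

```python
import unittest


class MapDataElboTests(unittest.TestCase):
    """Hypothetical layout; the real assertions live in the PR's test file."""

    batch_sizes = [8, 7, 6, 4, 3, 0]

    def do_elbo_test(self, reparameterized, n_steps, batch_size, map_type):
        # Placeholder check standing in for the actual ELBO comparison.
        self.assertIn(map_type, ("list", "tensor"))

    def test_elbo_reparam_list(self):
        for batch_size in self.batch_sizes:
            # Each batch size is reported separately on failure.
            with self.subTest(batch_size=batch_size):
                self.do_elbo_test(True, 5000, batch_size, map_type="list")

    def test_elbo_reparam_tensor(self):
        for batch_size in self.batch_sizes:
            with self.subTest(batch_size=batch_size):
                self.do_elbo_test(True, 5000, batch_size, map_type="tensor")
```

With this layout a failure names both the control flow (list or tensor) and the batch size that triggered it.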
pyro/__init__.py
Outdated
"""
:param name: named argument
:param data: data to subsample

if you change the input parameters, change the comments, don't remove them. these get auto-generated into docs, so parameter descriptions will be missing
I think I'm going to do a separate PR with just poutine documentation
pyro/poutine/poutine.py
Outdated
msg["fn"],
*msg["args"], **msg["kwargs"])
new_msg = msg.copy()
new_msg.update({"ret": ret})

pull these two lines out of the `switch` statement
pyro/poutine/poutine.py
Outdated
for i in range(0, loc + 1):
pyro._PYRO_STACK.pop(0)

def _pyro_sample(self, prev_val, name, fn, *args, **kwargs):
def _get_scale(self, data, batch_size):

probably can combine most of this repeated code with `map_data` above, or have that function call this one
pyro/poutine/poutine.py
Outdated
def _pyro_param(self, prev_val, name, *args, **kwargs):
else:
if batch_size is None:

upper bound batch_size

"""
Use the batch indices from the guide trace
"""
if batch_size is None:

upper bound batch_size
Should be addressed by util.get_scale
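For reference, a rough sketch of the bounds handling discussed in this review: reject negative batch sizes, treat `None` or 0 as "no subsampling", and cap `batch_size` at the data length so indexing cannot go out of bounds. The name `subsample_indices` is hypothetical and this is not the actual `util.get_scale`. The `n / batch_size` factor is the usual minibatch rescaling and is assumed here rather than taken from the PR, and clamping an oversized `batch_size` (instead of raising) is just one of the possible choices.

```python
import random


def subsample_indices(n, batch_size=None):
    """Pick indices and a rescaling factor for a dataset of size n."""
    if batch_size is None:
        batch_size = 0
    if batch_size < 0:
        raise ValueError("cannot have negative batch sizes")
    if batch_size == 0 or batch_size >= n:
        # Use the whole dataset; no rescaling is needed.
        return list(range(n)), 1.0
    # Sample without replacement, so every index is valid and unique.
    ind = random.sample(range(n), batch_size)
    return ind, float(n) / batch_size
```

Centralising the checks in one helper would also cover the repeated "upper bound batch_size" comments on the individual call sites.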
pyro/poutine/trace.py
Outdated
@@ -1,4 +1,5 @@
import pyro
+ import pdb

remove
…d up default map_data code
…uplicate code in poutine.up
looks good, approving but noting things that should be addressed in future PRs:
Review, but needs some additional tests, do not merge yet
Should close #13