Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorized map_data #62

Merged
merged 38 commits into from
Sep 6, 2017
Merged

Vectorized map_data #62

merged 38 commits into from
Sep 6, 2017

Conversation

eb8680
Copy link
Member

@eb8680 eb8680 commented Jul 20, 2017

Review, but needs some additional tests, do not merge yet

Should close #13

@eb8680 eb8680 mentioned this pull request Jul 20, 2017
2 tasks
Copy link
Member

@jpchen jpchen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a quick first look on my end. as you pointed out, lacking tests.

"""
Default pyro.map_data Poutine behavior
"""
if self.transparent and prev_val is not None:
if self.transparent and not (prev_val is None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason for this change?

else:
if batch_size is None:
batch_size = 0
assert batch_size >= 0, "cannot have negative batch sizes"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we have batch_size default to 0 if it's negative?


def scaled_sample(_prev_val, _name, _fn, *args, **kwargs):
return old_sample(_prev_val, _name,
pyro.util.rescale_dist(_fn, scale),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import rescale_dist from pyro.util so you dont have to write it out (as other classes do as well)


self._pyro_sample = scaled_sample
ret = list(map(lambda ix: fn(*ix), [(i, data[i]) for i in ind]))
self._pyro_sample = old_sample
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a comment here explaining multiplying the scaling factor then resetting it

@@ -40,14 +40,19 @@ def add_observe(self, name, val, fn, obs, *args, **kwargs):
self[name] = site
return self

def add_map_data(self, name, data, fn):
def add_map_data(self, name, fn, batch_size, scale, ind, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason you added **kwargs but not *args?

site["scale"] = scale
site["fn"] = fn
# site["value"] = val # XXX too large to store
# site["args"] = ((), kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whats this for?

pyro/util.py Outdated
def new_log_pdf(*args, **kwargs):
return old_log_pdf(*args, **kwargs) * scale

new_fn = copy.copy(fn) # XXX incorrect?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is shallow sufficient for this?

pyro/util.py Outdated
if hasattr(fn, "log_pdf"):
old_log_pdf = fn.log_pdf

def new_log_pdf(*args, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scaled_log_pdf might be a more descriptive name

pyro/util.py Outdated
new_fn.log_pdf = new_log_pdf
return new_fn
else:
# XXX should raise an error here?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or at the very least a warning since it might help catch bugs that fall here


old_sample = self._pyro_sample

def scaled_sample(_prev_val, _name, _fn, *args, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

either here, or in rescale_dist, check for scale == 1.0 and avoid unnecessary wrapping?

if batch_size is None:
batch_size = 0
assert batch_size >= 0, "cannot have negative batch sizes"
if hasattr(fn, "__map_data_indices") and \
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't it strange to store this information as attributes of fn?

Copy link
Member Author

@eb8680 eb8680 Jul 20, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, very strange, but I'm not sure if there's a more elegant way to do it. Naively, we'd want to put the rescaling and subsampling into a new Poutine subclass, but because map_data is implemented in the base Poutine, you'd be using an instance of a child class inside one of the base class methods. I imagine there's a way to do this, since I'm pretty sure mutual recursion with functions is OK, but I actually tried this initially and got a bunch of mysterious ImportErrors related to that which I was unable to clear up, so here we are.

Suggestions welcome!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't it be an additional return value from map_data?

@martinjankowiak
Copy link
Collaborator

@jpchen @eb8680 what do we need to finish up and merge this PR?

@eb8680
Copy link
Member Author

eb8680 commented Aug 5, 2017

Mainly needs some tests, currently the map_data sites in the existing inference tests are (almost?) all maps over lists/tuples rather than tensors.

@jpchen
Copy link
Member

jpchen commented Aug 8, 2017

@eb8680 since youre away this week, @OptimusLime or i can help with the tests, but the rest of the comments should probably be addressed by you

@martinjankowiak
Copy link
Collaborator

@eb8680 awesome! i look forward to going through this [ and proving it still doesn't work ;) ]

pyro/__init__.py Outdated
# default behavior
if isinstance(data, (torch.Tensor, Variable)): # XXX and np.ndarray?
if batch_size > 0:
if not hasattr(fn, "__map_data_indices"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the attribute stuff

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

ret = self._pyro_sample(msg, msg["name"],
msg["fn"],
*msg["args"], **msg["kwargs"])
new_msg = msg.copy()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the copy()s ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

x: pyro.observe(
"obs_%d" % i, dist.diagnormal,
x, mu_latent, torch.pow(self.lam, -0.5)), batch_size=batch_size)
pyro.map_data("bbb", self.data, lambda i,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the purpose of the z samples in these tests? dummies to see that things still work with 2 'map_data's?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess so, I should scrutinize the test models/guides more carefully

@martinjankowiak
Copy link
Collaborator

martinjankowiak commented Sep 5, 2017

LGTM once comments addressed by @eb8680 and @jpchen gives the go ahead.

though one possible concern is whether poutines outside of the ones used by map_data have adequate test coverage given the poutine rewrite?

Copy link
Member

@jpchen jpchen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

took a first look. should be good for merge after fixes

pyro/__init__.py Outdated
if len(_PYRO_STACK) == 0:
# default behavior
if isinstance(data, (torch.Tensor, Variable)): # XXX and np.ndarray?
if batch_size > 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to check batch_size <= len(Tensor/list) or youll index out of bounds here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

"fn": fn,
"data": data,
"batch_size": batch_size,
# XXX should these be added here or during application
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add args, kwargs in msg?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, better to have map_data functions have a standard fixed interface

from pyro.infer.kl_qp import KL_QP


class NormalNormalTests(TestCase):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: nested mapdata

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do in separate PR

def test_elbo_reparameterized(self):
for batch_size in [8, 7, 6, 4, 3, 0]:
self.do_elbo_test(True, 5000, batch_size, map_type="list")
self.do_elbo_test(True, 5000, batch_size, map_type="tensor")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you separate these two into two unit tests? ie (test_elbo_reparam_list and test_elbo_reparam_tensor)? because as it stands it's one giant unit test and when it fails you dont know which of the 13 tests it failed. at least this will split the list and tensor types, which take two different control flows.

Copy link
Member Author

@eb8680 eb8680 Sep 5, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, will do

pyro/__init__.py Outdated
"""
:param name: named argument
:param data: data tp subsample
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you change the input parameters, change the comments dont remove them. these get auto-generated to docs so it will be missing parameter descriptions

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'm going to do a separate PR with just poutine documentation

msg["fn"],
*msg["args"], **msg["kwargs"])
new_msg = msg.copy()
new_msg.update({"ret": ret})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pull these two lines out of the switch statement

for i in range(0, loc + 1):
pyro._PYRO_STACK.pop(0)

def _pyro_sample(self, prev_val, name, fn, *args, **kwargs):
def _get_scale(self, data, batch_size):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably can combine most of this repeated code with map_data above.. or have that function call this one


def _pyro_param(self, prev_val, name, *args, **kwargs):
else:
if batch_size is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

upper bound batch_size

"""
Use the batch indices from the guide trace
"""
if batch_size is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

upper bound batch_size

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be addressed by util.get_scale

@@ -1,4 +1,5 @@
import pyro
import pdb
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

@jpchen
Copy link
Member

jpchen commented Sep 6, 2017

looks good, approving but noting things that should be addressed in future PRs:

  • nested mapData test
  • moving msg struct to utils or trace
  • documentation for up and down and other methods in docstrings
  • write observe in terms of sample and eliminate repeated code

jpchen
jpchen previously approved these changes Sep 6, 2017
@jpchen jpchen merged commit e41b3de into dev Sep 6, 2017
@jpchen jpchen mentioned this pull request Sep 6, 2017
4 tasks
@jpchen jpchen deleted the eli-map_data-pr branch September 6, 2017 04:58
This was referenced Sep 6, 2017
@fritzo fritzo mentioned this pull request Aug 8, 2018
8 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Refactor map_data
3 participants