
TraceGraphPoutine and TraceGraph_KL_QP #64

Closed
wants to merge 20 commits

Conversation

@martinjankowiak (Collaborator) commented Aug 2, 2017

this PR should go some way toward getting us more robust SVI with reduced variance.

basically it's an implementation of Gradient Estimation Using Stochastic Computation Graphs by John Schulman, Nicolas Heess, Theophane Weber, & Pieter Abbeel, applied to the ELBO. the only thing that's currently missing is fancier baselines, which in principle should be straightforward to add. it's largely a matter of thinking through how we want the user to interact with baselines, and which tradeoffs between fine-grained user control and syntactic succinctness we want to accept.

note that this is written for the ELBO, but not all that much more code is required to make it a gradient estimator for generic objective functions. we should have a conversation about how that interface should look and think about which, if any, parts of it should be pushed upwards into pytorch. if we decide yes, we should consult with the pytorchers.

it would be good for someone to look through the logic that builds the surrogate loss that gets differentiated and make sure i'm not missing any edge cases (a rough sketch of the construction is included below). as of right now the following inference tests for conjugate models all pass:

  • NormalNormal (2)
  • NormalNormalNormal (4)
  • BernoulliBeta (1)
  • PoissonGamma (1)
  • ExponentialGamma (1)
  • LogNormalNormal (2.5)

the numbers in parentheses give the number of combinations of nonreparam/reparam latent variables that are tested. that is, diag_normal now takes reparameterized=True/False as an argument, which is useful for getting better test coverage. anecdotally, optimization proceeds more smoothly for most of these models with the fancier gradient estimator, although some of them are still a bit finicky (something baselines should hopefully address). these tests are all in test_tracegraph_klqp. note that i've retained the old kl_qp; we should consider when might be an appropriate time to remove it.
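for reference, the basic shape of the surrogate-loss construction from the Schulman et al. paper looks roughly like this (illustrative pseudocode with made-up field names, not the PR's actual implementation):

```python
def build_surrogate_elbo(sites):
    """rough sketch: `sites` is an illustrative list of dicts with
    torch-tensor fields 'log_p', 'log_q', 'downstream_cost' and a
    boolean 'reparam' flag."""
    surrogate = 0.0
    for s in sites:
        elbo_term = s['log_p'] - s['log_q']
        if s['reparam']:
            # reparameterized node: pathwise gradients flow through directly
            surrogate = surrogate + elbo_term
        else:
            # non-reparameterized node: add a score-function term, with the
            # downstream cost detached so it acts as a constant weight
            surrogate = surrogate + s['log_q'] * s['downstream_cost'].detach() \
                                  + elbo_term.detach()
    # differentiating `surrogate` gives (roughly) an unbiased single-sample
    # estimate of the gradient of the ELBO
    return surrogate
```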

implementation-wise, TraceGraphPoutine inherits from TracePoutine and records the forward graph as a TraceGraph object. it also produces an optional visualization (although the visualization could still be improved in various ways). we may also want to combine/refactor the Trace and TraceGraph data structures in various ways, but i leave it like this for the time being, so that inference algorithms that don't need the forward graph aren't impacted.

probably the next main step is to implement a fancier version of baselines.

additional notes:

  • the forward graph is still constructed via some monkeypatching of pytorch. i cleaned this up a bit, and it should be a satisfactory solution for now, but we will need to reconsider later on, as autograd dev continues.
  • new dependencies on the python packages graphviz (for visualization) and networkx (for graphs). the latter is probably overkill?
  • caveat: currently won't work unless sample/observe statements use the argument-passing idiom, as this is where those dependencies are tracked
  • it would be great if someone (@karalets?) could see if this gradient estimator works for whatever VAE models we have lying around that have non-reparameterizable latent variables

@eb8680 eb8680 self-requested a review August 2, 2017 03:16
@martinjankowiak martinjankowiak force-pushed the martin-dev-tracegraph branch 4 times, most recently from c581f89 to f52e37c on August 2, 2017 05:03
@eb8680 (Member) commented Aug 5, 2017

Great work, this is really nice. I like the mechanism a lot. Some high-level comments in addition to a code review:

  • Organizationally, I think it makes the most sense to package the gradient estimators with the distribution library, as we discussed, rather than implement them at the level of a single inference algorithm or poutine. Personally, I like the idea of having two functions pyro.distributions.grad and pyro.distributions.backward that mimic the functionality and call signatures of torch.autograd.grad/backward but use the more sophisticated gradient estimators in computation graphs that contain Variables sampled from Pyro distributions. (A rough sketch of what that could look like follows this list.)
  • To take full advantage of this system it should be paired with fancier gradient estimators for some distributions (e.g. the ones in Reparametrization Gradients Through Acceptance-Rejection Algorithms)
  • I agree with you that we should probably combine TracePoutine and TraceGraphPoutine. The simplest way to do that at the API level is to have a graph=True/False argument to TracePoutine's constructor.
  • On that note, as long as we're going to use autograd graphs as the graph substrate (as opposed to constructing our own entirely separate from autograd, e.g. by using sys.setprofile), I think we can move visualization to a separate module that visualizes autograd graphs with special helpers/indicators for the Pyro random variables in the graph.
  • You're right that we should probably ditch networkx at some point, it's really slow.
  • How should we go about adding control variates and baselines automatically, especially fancier state-dependent ones?
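For concreteness, here is a hypothetical sketch of what the first bullet's pyro.distributions.grad could look like (the name, signature, and trace accessor are all made up to mirror torch.autograd.grad, not an actual API):

```python
import torch

def grad(outputs, inputs, trace=None, create_graph=False):
    # Hypothetical: same call signature as torch.autograd.grad, but first
    # augment the loss with score-function terms for any non-reparameterized
    # sample sites recorded in `trace`, then differentiate the surrogate.
    surrogate = outputs
    if trace is not None:
        for site in trace.nonreparam_sites():  # illustrative accessor
            surrogate = surrogate + site.log_q * site.downstream_cost.detach()
    return torch.autograd.grad(surrogate, inputs, create_graph=create_graph)
```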

Thoughts on the necessity of these things before merging?

@ngoodman (Collaborator) commented Aug 6, 2017

This is a nice approach -- it should generalize well. I'll take a look at the code soon, to make sure I understand the details.

I'm curious how much overhead is incurred by this. Having to track the forward graph ourselves seems like it could slow things down substantially; constructing the surrogate loss less so, but possibly. Could you, at minimum, do a speed test comparing your tracegraph kl_qp to the original version?

@martinjankowiak (Collaborator, Author) commented

@ngoodman if i do a speed comparison for one of the inference tests (which are very overhead-heavy given the tiny tensors involved, etc.), i find that the tracegraph version of kl_qp takes about 50% more time per gradient step.

@martinjankowiak (Collaborator, Author) commented

added the simplest kind of baseline, namely the exponentially decaying average kind. it does a lot to stabilize some of the inference tests. for now i'm punting on adding support for fancier baselines, as we should first think through the interface.
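roughly, what was added looks like this (illustrative code, not the exact implementation):

```python
class ExpDecayAvgBaseline:
    """exponentially decaying running average of observed downstream costs,
    to be subtracted from the downstream cost in the score-function term."""

    def __init__(self, beta=0.90):
        self.beta = beta
        self.avg = 0.0
        self.steps = 0

    def __call__(self, cost):
        self.steps += 1
        self.avg = self.beta * self.avg + (1.0 - self.beta) * cost
        # bias-correct so early estimates aren't pulled toward zero
        return self.avg / (1.0 - self.beta ** self.steps)
```

the score-function term then uses (downstream_cost - baseline).detach(); since the baseline doesn't depend on the current sample, the gradient estimator stays unbiased while its variance goes down.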

@ngoodman (Collaborator) commented Aug 7, 2017

having skimmed through the code, i have a few questions about the monkeypatching method for building the SCG (stochastic computation graph).

  • it seems that this will work only if distributions return Variables? is this true for all distributions, including discrete ones?

  • more generally, what happens if we derive a non-Variable value from a Variable, eg by an inequality operator? i think this will drop off the graph, but this is exactly where we might need a score function correction?

  • what about data structures? that is, if i wrap a bunch of Variables in a list and then pass them as the arg to a function, will the monkeypatched Function correctly add dependence to the Variables?

  • does this approach Do The Right Thing with stochastic control structure? eg an "if-flip"? i suspect that it does because we only need to capture the objective within an epsilon ball of current params.... but i'm not sure, and we should double check.
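(schematically, the "if-flip" pattern i mean -- written in plain torch so it runs standalone; in pyro the branches would contain sample/observe statements:)

```python
import torch

p = torch.tensor(0.5)
b = torch.bernoulli(p)         # non-reparameterized coin flip
if b.item() == 1:              # control flow depends on the sample, so the
    x = torch.randn(()) * 2.0  # ops (and, in pyro, the sample sites) that
else:                          # exist downstream differ between branches
    x = torch.randn(()) + 5.0
```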

@jpchen (Member) left a comment:

this is really cool stuff.
i didn't comb through the klqp algo since i haven't had time to read the paper yet, but the tests seem comprehensive

some things to keep in mind moving forward:

  • pytorch just had a new release with a new and improved autograd. does that break anything?
  • overhead of graph generation
  • don't forget to document this (when it's merged)
  • probably should stress test the graph structure later on, i feel like there might be weird edge cases we're missing

@@ -42,6 +42,10 @@ def sample(self, mu=None, sigma=None, *args, **kwargs):
_mu, _sigma = self._sanitize_input(mu, sigma)
eps = Variable(torch.randn(_mu.size()))
z = _mu + eps * _sigma
if 'reparameterized' in kwargs:
self.reparameterized = kwargs['reparameterized']
@jpchen (Member):

will this work correctly? if reparameterized is in kwargs, then no object has been initialized at this point, so setting self.reparameterized won't be persisted, no?

i assume you checked in klqp that the reparameterized flag is doing the right thing, because in dev it is not (it's currently commented out in test_inference.py)

@martinjankowiak (Collaborator, Author):

i don't understand. self.reparameterized = False by default from Distribution, no?

# XXX not actually using the result of this computation, but these two calls
# trigger the actual log_pdf calculations and fill in the trace dictionary.
# do elsewhere?
_ = guide_trace.log_pdf() - model_trace.log_pdf()
@jpchen (Member):

you can call this without assignment

@martinjankowiak (Collaborator, Author):

yeah, well, the whole thing has got to go anyway

# XXX parents currently include parameters (unnecessarily so)
cost_nodes = []
for name in model_trace.keys():
mtn = model_trace[name]
@jpchen (Member):

can you name this something better? like mtrace_name or even model_name would be better

@martinjankowiak (Collaborator, Author):

that'll make my lines too long

@@ -42,6 +42,10 @@ def sample(self, mu=None, sigma=None, *args, **kwargs):
_mu, _sigma = self._sanitize_input(mu, sigma)
eps = Variable(torch.randn(_mu.size()))
z = _mu + eps * _sigma
if 'reparameterized' in kwargs:
@jpchen (Member):

should we do this for all reparameterizable dists?

@martinjankowiak (Collaborator, Author):

we can, although i don't necessarily have any inference tests waiting to take advantage of it

from .trace_poutine import TracePoutine


def varid(v): # XXX what's the best way to do this??
@jpchen (Member):

can you generalize this as a get_id and handle both Variables and tensors? or is the Variable assumption always good? because nowhere else are you checking that you're using Variables

-- observation nodes are green
-- intermediate nodes are grey
-- include_intermediates controls granularity of visualization
-- if there's a return value node, it's visualized as a double circle
@jpchen (Member):

you should add this to the docs.
it'd be cool to get a visualized forward graph in the website examples

XXX some things are still funky with the graph (it's not necessarily a DAG) although i don't
think this affects TraceGraph_KL_QP. this has to do with the unique id used, how that
interacts with operations like torch.view(), etc. all this should be solved once we
move away from monkeypatching?
@jpchen (Member):

do you know what this will look like with the new autograd? (presumably the same, but i haven't checked)

@martinjankowiak (Collaborator, Author):

shouldn't change anything, although i'll have to check

register inputs and output of pytorch function with graph
"""
assert type(output) not in [tuple, list, dict],\
"registor_function: output type not as expected"
@jpchen (Member):

print the type here for debugging? also typo "register"

# remove loops
# for edge in self.G.edges():
# if edge[0]==edge[1]:
# self.G.remove_edge(*edge)
@jpchen (Member):

is this needed if it's a DAG? if not you should remove it

# construct the forward graph
def new_function__call__(func, *args, **kwargs):
output = self.old_function__call__(func, *args, **kwargs)
if self.monkeypatch_active:
@jpchen (Member):

"override"

@martinjankowiak (Collaborator, Author) commented

@ngoodman

  • it seems that this will work only if distributions return Variables? is this true for all distributions, including discrete ones?

yes, the current implementation assumes we're always working with Variables.

  • more generally, what happens if we derive a non-Variable value from a Variable, eg by an inequality operator? i think this will drop off the graph, but this is exactly where we might need a score function correction?

as long as the derivation happens through a pytorch function like torch.ge() (so that inputs and outputs are available to be registered), it will still be able to trace out the dependency structure (even if it's not differentiable).

  • what about data structures? that is, if i wrap a bunch of Variables in a list and then pass them as the arg to a function, will the monkeypatched Function correctly add dependence to the Variables?

no, this isn't currently supported but should be easy to include

  • does this approach Do The Right Thing with stochastic control structure? eg an "if-flip"? i suspect that it does because we only need to capture the objective within an epsilon ball of current params.... but i'm not sure, and we should double check.

yeah, i'm also not sure this Does The Right Thing in all cases. it certainly estimates a reasonable gradient, but whether that gradient is actually what's wanted in all cases is something i still need to think about more
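for reference, the core of the monkeypatching being discussed is roughly the following (simplified; the real TraceGraphPoutine registers more bookkeeping, and newer pytorch builds its graphs differently):

```python
import torch
from torch.autograd import Variable

edges = []  # (input_id, output_id) pairs recording forward dependencies
_old_call = torch.autograd.Function.__call__

def _traced_call(func, *args, **kwargs):
    output = _old_call(func, *args, **kwargs)
    # register an edge from every Variable input to the output so the
    # forward dependency structure can be reconstructed afterwards
    for a in args:
        if isinstance(a, Variable):
            edges.append((id(a), id(output)))
    return output

torch.autograd.Function.__call__ = _traced_call
```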

@martinjankowiak (Collaborator, Author) commented

going to close this PR for the time being in favor of using a simpler, more conservative tracegraph

will revisit when pytorch settles on and exposes appropriate forward graph structure
