TraceGraphPoutine and TraceGraph_KL_QP #64
Conversation
force-pushed from c581f89 to f52e37c
force-pushed from 68e9abc to 68b80a8
…tin-dev-tracegraph
… moved away from using python id(); graph structure still a bit funky in some edge cases but seems to be ok for current use case
Great work, this is really nice. I like the mechanism a lot. Some high-level comments in addition to a code review:
Thoughts on the necessity of these things before merging?
This is a nice approach -- it should generalize well. I'll take a look at the code soon, to make sure I understand the details. I'm curious how much overhead is incurred by this. Having to track the forward graph ourselves seems like it could slow things down substantially; constructing the surrogate loss less so, but possibly. Could you minimally do a speed test comparing your tracegraph kl_qp to the original version?
@ngoodman if i do a speed comparison for one of the inference tests (which are very overhead heavy given the tiny tensors involved etc.) i find that the tracegraph version of kl_qp takes about 50% more time per gradient step
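For reference, a sketch of the kind of per-step timing comparison described above; `kl_qp`, `tracegraph_kl_qp`, and `data` are hypothetical stand-ins for configured inference objects with a `.step()` method and the model's inputs:

```python
import time

def time_per_step(inference, data, n_steps=500):
    # average wall-clock time per gradient step; a rough harness, since
    # the inference tests are overhead-heavy with tiny tensors
    start = time.time()
    for _ in range(n_steps):
        inference.step(data)
    return (time.time() - start) / n_steps

# ratio > 1.0 means the tracegraph version is slower per gradient step:
# ratio = time_per_step(tracegraph_kl_qp, data) / time_per_step(kl_qp, data)
```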
added the simplest kind of baseline, namely the exponential decaying average kind. does a lot to stabilize some of the inferential tests. for now am punting on adding support for fancier baselines, as we should first think through interface stuff.
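For illustration, a minimal sketch of the exponentially decaying average baseline described above (class and attribute names are illustrative, not the PR's actual implementation):

```python
class DecayingAvgBaseline(object):
    """Running average of the observed cost with exponential decay;
    subtracted from the downstream cost in score-function gradient
    terms to reduce their variance without biasing the gradient."""

    def __init__(self, beta=0.90):
        self.beta = beta
        self.avg = 0.0
        self.n = 0

    def update(self, cost):
        self.n += 1
        self.avg = self.beta * self.avg + (1.0 - self.beta) * cost
        # bias correction so early estimates aren't pulled toward zero
        return self.avg / (1.0 - self.beta ** self.n)
```

The returned value would be subtracted from the (detached) downstream cost before that cost multiplies the log-probability term in the surrogate loss.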
having skimmed through the code, i have a few questions about the monkeypatching method for building the SCG.
this is really cool stuff.
i didn't comb through the klqp algo since i haven't had time to read the paper yet, but the tests seem comprehensive.
some things to keep in mind moving forward:
- pt just had a new release with a new and improved autograd. does that break anything?
- overhead of graph generation
- don't forget to document this (when it's merged)
- probably should stress test the graph structure later on, i feel like there might be weird edge cases we're missing
@@ -42,6 +42,10 @@ def sample(self, mu=None, sigma=None, *args, **kwargs):
     _mu, _sigma = self._sanitize_input(mu, sigma)
     eps = Variable(torch.randn(_mu.size()))
     z = _mu + eps * _sigma
+    if 'reparameterized' in kwargs:
+        self.reparameterized = kwargs['reparameterized']
will this work correctly? if `reparameterized` is in `kwargs` then there is no object initialized, so setting `self.reparameterized` will not be persisted, no?
i assume you checked in klqp that the reparameterized flag was doing the right thing, because in dev it is not; it's currently commented out in test_inference.py
i don't understand. `self.reparameterized = False` by default from `Distribution`, no?
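A toy illustration of the point in this thread, assuming `Distribution.__init__` sets the default (these are not pyro's actual classes): since `__init__` has already run by the time `sample()` is called, assigning the kwarg does persist on the instance:

```python
class Distribution(object):
    def __init__(self):
        self.reparameterized = False  # default, per the reply above

class DiagNormal(Distribution):
    def sample(self, *args, **kwargs):
        if 'reparameterized' in kwargs:
            # __init__ already ran, so this assignment persists on the instance
            self.reparameterized = kwargs['reparameterized']
        # ... (actual sampling elided)

d = DiagNormal()
assert d.reparameterized is False
d.sample(reparameterized=True)
assert d.reparameterized is True
```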
# XXX not actually using the result of this computation, but these two calls
# trigger the actual log_pdf calculations and fill in the trace dictionary.
# do elsewhere?
_ = guide_trace.log_pdf() - model_trace.log_pdf() |
you can call this without assignment
yeah well the whole thing has got to go anyway
# XXX parents currently include parameters (unnecessarily so)
cost_nodes = []
for name in model_trace.keys():
    mtn = model_trace[name]
can you name this something better? like `mtrace_name` or even `model_name` would be better
that'll make my lines too long
@@ -42,6 +42,10 @@ def sample(self, mu=None, sigma=None, *args, **kwargs):
     _mu, _sigma = self._sanitize_input(mu, sigma)
     eps = Variable(torch.randn(_mu.size()))
     z = _mu + eps * _sigma
+    if 'reparameterized' in kwargs:
should we do this for all reparameterizable dists?
we can, although i don't necessarily have any inference tests waiting to take advantage
from .trace_poutine import TracePoutine


def varid(v): # XXX what's the best way to do this?? |
can you generalize this as a `get_id` and handle both Variables and tensors? or is the Variable assumption always good? because nowhere else are you checking that you're using Variables
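A sketch of the `get_id` generalization suggested here; keying on the underlying storage is one option (an assumption, not what the PR does), with the caveat that views share storage, which ties into the `torch.view()` issue raised later in this thread:

```python
import torch
from torch.autograd import Variable

def get_id(x):
    """Accept both Variables and raw tensors. Keyed on the underlying
    storage pointer rather than python id(), which can be recycled
    once an object is garbage collected. Caveat: a tensor and its
    views share storage and would collide under this scheme."""
    if isinstance(x, Variable):
        x = x.data
    if not torch.is_tensor(x):
        raise TypeError("get_id expects a Variable or tensor, got %r" % type(x))
    return x.data_ptr()
```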
-- observation nodes are green
-- intermediate nodes are grey
-- include_intermediates controls granularity of visualization
-- if there's a return value node, it's visualized as a double circle |
you should add this to the docs.
it'd be cool to get a visualized forward graph on the website examples
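A minimal sketch of the styling the docstring above describes, using the `graphviz` pip package (an assumption; the PR's actual rendering code isn't shown here, and the latent-node color is made up):

```python
from graphviz import Digraph

def render_forward_graph(sample_nodes, observe_nodes, intermediate_nodes,
                         edges, return_value=None):
    g = Digraph()
    for n in sample_nodes:
        g.node(n, color='lightblue', style='filled')  # latent sites (illustrative color)
    for n in observe_nodes:
        g.node(n, color='green', style='filled')      # observation nodes are green
    for n in intermediate_nodes:
        g.node(n, color='grey', style='filled')       # intermediate nodes are grey
    if return_value is not None:
        g.node(return_value, shape='doublecircle')    # return value: double circle
    for src, dst in edges:
        g.edge(src, dst)
    return g  # g.render('forward_graph') writes the drawing to disk
```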
XXX some things are still funky with the graph (it's not necessarily a DAG) although i don't
think this affects TraceGraph_KL_QP. this has to do with the unique id used, how that
interacts with operations like torch.view(), etc. all this should be solved once we
move away from monkeypatching? |
do you know what this will look like with the new autograd? (presumably the same but i haven't checked)
shouldn't change anything although i'll have to check
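Two concrete ways node identity can go wrong, as hedged illustrations of the funkiness described above: python `id()` values can be recycled after garbage collection (the reason the earlier commit moved away from `id()`), and storage-based keys conflate a Variable with its views:

```python
import torch
from torch.autograd import Variable

# 1. python id() can be recycled once an object is garbage collected,
#    so two distinct graph nodes may end up with the same key
a = Variable(torch.randn(3))
stale = id(a)
del a
b = Variable(torch.randn(3))  # id(b) may now equal `stale`

# 2. keying on storage instead conflates a Variable with its views,
#    which is one way spurious edges (and apparent cycles) can appear
v = Variable(torch.randn(4))
w = v.view(2, 2)
assert v.data.data_ptr() == w.data.data_ptr()  # same underlying storage
```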
register inputs and output of pytorch function with graph
"""
assert type(output) not in [tuple, list, dict],\
"registor_function: output type not as expected" |
print the type here for debugging? also typo "register"
# remove loops
# for edge in self.G.edges():
#     if edge[0] == edge[1]:
#         self.G.remove_edge(*edge)
is this needed if it's a DAG? if not you should remove it
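Along the lines of this comment, a small sanity check (assuming `G` is a networkx `DiGraph`) that would make the invariant explicit instead of silently removing self-loops:

```python
import networkx as nx

def assert_forward_graph_is_dag(G):
    # a DAG cannot contain self-loops, so the commented-out removal
    # pass above would be a no-op if the invariant actually holds
    self_loops = [e for e in G.edges() if e[0] == e[1]]
    assert not self_loops, "self-loops found: %s" % self_loops
    assert nx.is_directed_acyclic_graph(G), "forward graph is not a DAG"
```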
# construct the forward graph
def new_function__call__(func, *args, **kwargs):
    output = self.old_function__call__(func, *args, **kwargs)
    if self.monkeypatch_active:
"override
"
yes, current implementation assumes we're always working with variables.
as long as the derivation is happening through a pytorch function like torch.ge() (so that inputs and outputs are available to be registered) it will still be able to trace out the dependency structure (even if it's not differentiable)
no, this isn't currently supported but should be easy to include
yeah i'm also not sure if this Does The Right Thing in all cases. it certainly estimates a reasonable gradient, but whether that gradient is actually what is really wanted in all cases is something i still need to think about more
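For readers following this thread, a stripped-down sketch of the wrap-and-record pattern under discussion, on a toy callable rather than pyro's actual monkeypatch of the torch function machinery:

```python
class GraphRecorder(object):
    def __init__(self):
        self.edges = []
        self.active = False

    def patch(self, cls):
        old_call = cls.__call__
        def new_call(func, *args, **kwargs):
            output = old_call(func, *args, **kwargs)
            if self.active:
                for arg in args:
                    # uses id() for brevity; see the identity caveats
                    # discussed earlier in this thread
                    self.edges.append((id(arg), id(output)))
            return output
        cls.__call__ = new_call
        return old_call  # keep a handle to undo the patch later

class Square(object):  # toy stand-in for a pytorch function
    def __call__(self, x):
        return x * x

recorder = GraphRecorder()
recorder.patch(Square)
recorder.active = True
y = Square()(3.0)  # records an edge from the input to the output
```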
going to close this PR for the time being in favor of using a simpler, more conservative tracegraph. will revisit when pytorch settles on and exposes appropriate forward graph structure
this PR should go some way toward getting us more robust SVI with reduced variance.
basically it's an implementation of "Gradient Estimation Using Stochastic Computation Graphs" by John Schulman, Nicolas Heess, Theophane Weber, & Pieter Abbeel, specialized to the ELBO. the only thing that's currently missing is fancier baselines. that in principle should be straightforward to add. it's largely a matter of thinking through how we want the user to interact with baselines, and which tradeoffs between fine-grained user control and syntactic succinctness we want to accept.
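For readers of the paper, a schematic of the surrogate-loss shape this implies (a sketch under the stated assumptions, not the PR's actual code): reparameterized cost terms are differentiated directly, while each non-reparameterized sample site contributes log q times its detached downstream cost, minus an optional baseline:

```python
def surrogate_elbo_loss(cost_nodes, nonreparam_log_qs, descendants, baselines=None):
    """Schematic surrogate loss in the spirit of Schulman et al.

    cost_nodes        : dict of name -> cost term (torch Variable); the
                        ELBO estimate is the sum of these terms
    nonreparam_log_qs : dict of site name -> log q(z) for each
                        non-reparameterized sample site
    descendants       : dict of site name -> names of cost nodes downstream
                        of that site (a stand-in for querying the
                        recorded forward graph)
    baselines         : optional dict of site name -> scalar baseline
    """
    baselines = baselines or {}
    # pathwise terms: differentiating the costs directly handles the
    # reparameterized pieces of the gradient
    surrogate = sum(cost_nodes.values())
    for site, log_q in nonreparam_log_qs.items():
        # score-function term: only log q is differentiated; the downstream
        # cost is detached so it enters as a constant multiplier
        # (assumes each site has at least one downstream cost term)
        downstream = sum(cost_nodes[n] for n in descendants[site])
        b = baselines.get(site, 0.0)
        surrogate = surrogate + log_q * (downstream.detach() - b)
    return -surrogate  # loss to minimize; its gradient estimates -grad(ELBO)
```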
note that this is written for the ELBO, but not all that much more code is required to make it a gradient estimator for generic objective functions. we should have a conversation about how that interface should look and think about which, if any, parts of it should be pushed upwards into pytorch. if we decide yes, we should consult with the pytorchers.
it would be good for someone to look through the logic that builds the surrogate loss that is differentiated and make sure i'm not missing any edge cases. as of right now the following inference tests for conjugate models all pass:
the numbers in parentheses give the number of combinations of nonreparam/reparam latent variables that are tested. that is, diag_normal now takes reparameterized=True/False as an argument, which is useful for getting better test coverage. note that anecdotally i can see that optimization proceeds more smoothly for most of these models with the fancier gradient estimator, although some of them are still a bit finicky (something baselines should hopefully address). these tests are all in test_tracegraph_klqp. note that i've retained the old kl_qp. we should consider when might be an appropriate time to remove it.
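A hypothetical snippet of how a test might exercise both gradient paths via this flag (`dist.diag_normal`, `mu`, and `sigma` are stand-ins for whatever the test already has in scope):

```python
# pathwise (reparameterized) gradients flow through the sample
z_pathwise = dist.diag_normal.sample(mu, sigma, reparameterized=True)
# score-function gradients via the log-prob term in the surrogate loss
z_score_fn = dist.diag_normal.sample(mu, sigma, reparameterized=False)
```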
implementation-wise, TraceGraphPoutine inherits from TracePoutine and records the forward graph as a TraceGraph object. it also produces an optional visualization (although the visualization could still be improved in various ways). we may also want to combine/refactor the Trace and TraceGraph data structures in various ways, but i leave it like this for the time-being, so that inference algorithms that don't need the forward graph aren't impacted.
probably the next main step is to implement a fancier version of baselines.
additional notes: