Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Higher order op for preserving leaf functions through trace, particularly for getting user defined hooks to compiled autograd #109690
Higher order op for preserving leaf functions through trace, particularly for getting user defined hooks to compiled autograd #109690
Changes from 15 commits
2208e96
fcc5d22
d5ed407
f15a8b8
90fb411
5909aa5
7564ee8
d3e0685
e793656
722d3de
18c134b
9b5f464
78a4b1f
7fd2e88
735a68d
807641b
File filter
Filter by extension
Conversations
Jump to
There are no files selected for viewing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is a proposed rewrite of the top level comment:
trace_wrapped(*args, fn)
is equivalent tofn(*args)
, but with a twist: if youmake_fx
trace through this call, we will not actually trace into fn; instead, we will directly insert it as acall_function
tofn
in the graph. (Unlikemake_fx
, Dynamo WILL inline into fn.) You can think of this as a one offallow_in_graph
equivalent for proxy tensor tracing.Because proxy tensor tracing does not actually run the function, there are requirements on the behavior of fn. We are still figuring it out, but here is the current state:
empty_like(input)
is a permissible implementation of fn). This is verified via an extra assert that is inserted into the traced graph.These requirements stem from the requirement that we need to continue performing proxy tensor tracing, which assumes accurate fake tensor metadata, without actually running fn. In the future, we may allow for a "meta" function associated with fn to allow for more interesting input-output patterns.
Note that tensors / Python state are allowed to be mutated. This is relaxed constraint is not always sound, but it is sound for backward tracing with fake tensors as it takes place in AOTAutograd, as the backward pass is guaranteed not to depend on concrete tensor values (via fake tensor) or Python state (because the autograd engine doesn't depend on Python).
The intended use case for this function is to allow AOTAutograd to defer complex backward hooks to compiled autograd. AOTAutograd performs a make_fx trace which preserves the function call as is in the graph, and only when we Dynamo through the backward graph in compiled autograd do we inline into the function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, this is better. Thanks for rewriting it. eg: zeros_like(input) I suppose
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you are using may and must as https://www.rfc-editor.org/rfc/rfc2119 - let's use SHOULD and MUST ;)
Thank you again.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any reason to take this variadically, you only support one argument 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I went back and forth. This feels better in case we want to use it in autograd.Function, where we take multiple args.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
noob dynamo q: how does dynamo know to execute these asserts at compile time (while dynamo is tracing), instead of automatically trying to add these asserts and metadata calls as proxies into the backward graph?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So long as this function is not allowed in graph, dynamo must inline into it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh right- thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be subsumed by the comment above I think
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: call this trace_wrapped instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: the tree_map here is also unnecessary, just s/grad/out_proxy/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm confused, why do you need to do this twice
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I understand why you did this (you need to prevent the assert from getting DCEd) but I don't think this is the right way to do it. Let me think...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you need to prevent this from DCE'd? Like, the assert can just have no data deps and you don't have to track at all. What happens when you do that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, lemme try, I thought it was cause of DCE but now I do not remember.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, you get DCE if you don't track w/ create_proxy. However, if we change it to create_node, it breaks in other ways because none of the rest of this is nodes. It's all proxies. Is there a way to pass proxies to node creation? It seems like crossing streams...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Every proxy has a node so you can extract the node from
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ofc, but is that kosher here? is that better than just repeating proxy binding code? Does it actually make a difference? I defer to you.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there is some DCE thing, it will happen whether or not you create_proxy or create_node. I guess this is fine. Actually, why don't you just shove this into
self_invoke
, that will also prevent DCEThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @yanboliang