
Higher order op for preserving leaf functions through trace, particularly for getting user defined hooks to compiled autograd #109690

Closed · 16 commits

Conversation

pytorch-bot commented Sep 20, 2023

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109690
Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 2 Pending
As of commit 807641b with merge base 34ded74:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

voznesenskym marked this pull request as draft September 20, 2023 08:08
fn = functools.partial(self_invoke, fn=fn)
fn.__name__ = fn.func.__name__

is_func = torch._is_functional_tensor(grad)
Contributor:
Hmm, seems fishy that we need to muck with functional tensors inside of ProxyTorchDispatchMode, lmk if I can help

Collaborator (author):
You commented on this here #107502 (comment) sorry to move stuff around so much

Contributor:
ah right :p

voznesenskym marked this pull request as ready for review September 20, 2023 20:35
# the functions as needed. This, in turn, means we can support functions in backward with complex python
# state mutation. If we were to not do this, the functions would get inlined into their composing aten ops,
# and we would lose the python state mutation.
def _trace_wrapped(*args, fn):
Collaborator (author):
Still bikeshedding on name, placeholder for now to not churn review comments overmuch.

Contributor:
something about leaving the inner fn opaque so we can trace it in the bw? trace_opaque, trace_opaque_for_bw (idk)

Collaborator (author):
@ezyang suggested PythonOp or some variation around the word leaf...

out_proxy = mode.tracer.create_proxy(
    "call_function", fn, proxy_args, {}, name="invocation"
)
grad = torch.empty_like(grad)
Contributor:
nit: zeros_like to prevent unfortunate accidents when grad is actually not a fake tensor. Or perhaps assert that grad must be fake?

Contributor:
(1) Do we ever intend to use this higher order op for stuff other than hooks?

(2) Is it asserted anywhere else (e.g. the autograd engine) that hook functions always take in a single tensor argument?

Collaborator (author):
(1) We want this for autograd function backwards too, and other stuff in the future

(2) Yea, that's the contract, but I like repeating invariants. I'll do a zeros_like and assert too.
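(A minimal sketch of what that could look like inside the proxy-mode rule; the fake-tensor check and its message are assumptions, and grad/out_proxy come from the surrounding handler.)

from torch._subclasses.fake_tensor import FakeTensor

# Sketch only: assert the invariant explicitly, then hand back zeros_like so a stray
# real tensor can never leak uninitialized values out of the trace.
assert isinstance(grad, FakeTensor), "expected a fake tensor during proxy tracing"
grad = torch.zeros_like(grad)
grad = track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)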

name="assert",
)
grad = torch.empty_like(grad)
grad = track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer)
Contributor:
I'm confused, why do you need to do this twice

Contributor:
I think I understand why you did this (you need to prevent the assert from getting DCEd) but I don't think this is the right way to do it. Let me think...

Contributor:
I don't think you need to prevent this from DCE'd? Like, the assert can just have no data deps and you don't have to track at all. What happens when you do that?

Collaborator (author):
Hmm, lemme try, I thought it was cause of DCE but now I do not remember.

Collaborator (author):
Yeah, you get DCE if you don't track w/ create_proxy. However, if we change it to create_node, it breaks in other ways because none of the rest of this is nodes. It's all proxies. Is there a way to pass proxies to node creation? It seems like crossing streams...

Contributor:
Every proxy has a node, so you can extract the node from the proxy.

Collaborator (author):
Ofc, but is that kosher here? is that better than just repeating proxy binding code? Does it actually make a difference? I defer to you.

Contributor:
If there is some DCE thing, it will happen whether or not you create_proxy or create_node. I guess this is fine. Actually, why don't you just shove this into self_invoke, that will also prevent DCE
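(For reference, every torch.fx.Proxy exposes its underlying graph node as .node, so the extraction mentioned above is a one-liner; this is a sketch in the context of the handler, not the final diff.)

# out_proxy was returned by mode.tracer.create_proxy(...); its backing torch.fx.Node
# can be pulled out directly if a node-level API is ever needed.
assert_node = out_proxy.node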

tx=tx,
proxy=tx.output.create_proxy(
    "call_function",
    torch._dynamo._trace_wrapped_higher_order_op._assert_meta,
Contributor:
No, why do you have to do this? You've already inserted the assert meta into the graph; you're going to trace into it later.

Collaborator (author):
Discussing offline atm.

Collaborator (author):
we will keep it simple

("dtype", _tensor_mutating_dtype),
]:

def _graph_break_invoke(grad):
Contributor:
This technically won't cause a graph break (as its name implies) right - it'll cause a hard error in the backward?

Contributor:
Also, it would be nice to have another test for aliasing, where a hook returns a view of the input. (which is "wrong" because our fake tensor rule for the wrap higher order op assumes that the output of the hook never aliases the input).

Collaborator (author):
It's a graph break, which turns into a hard error.
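(A hypothetical sketch of the aliasing test asked for above; the hook and the expectation are assumptions, and driving it through compiled autograd is left to the PR's existing test harness.)

import torch

def bad_hook(grad):
    # Same metadata as the input, but the output aliases it, violating the fake-tensor
    # rule for the higher order op, which assumes the hook returns a fresh tensor.
    return grad.view_as(grad)

x = torch.randn(4, requires_grad=True)
x.register_hook(bad_hook)
# Expected behavior under the op's contract: tracing this hook causes a graph break,
# which surfaces as a hard error when the compiled backward runs.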

return _trace_wrapped_op(*args, fn=fn)


def _assert_meta(grad, size, stride, dtype):
Contributor:
noob dynamo q: how does dynamo know to execute these asserts at compile time (while dynamo is tracing), instead of automatically trying to add these asserts and metadata calls as proxies into the backward graph?

Contributor:
So long as this function is not allowed in graph, dynamo must inline into it

Contributor:
oh right- thanks!
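(For concreteness, a minimal sketch of what _assert_meta plausibly checks, based on the signature above; the exact messages and return value are assumptions.)

def _assert_meta(grad, size, stride, dtype):
    # Runtime guard traced into the graph: the hook must return a tensor with the
    # same metadata it received.
    assert grad.size() == size, "hook output changed the gradient's size"
    assert grad.stride() == stride, "hook output changed the gradient's stride"
    assert grad.dtype == dtype, "hook output changed the gradient's dtype"
    return grad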

ezyang (Contributor) commented Sep 25, 2023:

fwiw, for me, this is very close, just my two last comments

voznesenskym (Collaborator, author):

> fwiw, for me, this is very close, just my two last comments

Understood, tyvm :)

fn = kwargs["fn"]
assert len(args) == 1
grad = args[0]
assert isinstance(grad, TensorVariable)
Contributor:
These shouldn't be asserts right, because malformed user code can trigger them

Collaborator (author):
This is explicitly a user-invoked operation, provided almost akin to a compiler directive - it feels like it should assert. I don't mind doing unimplemented, but I feel a tad more strongly here than when we are aping something from std.

# we can support functions in backward with complex python. It can be thought of as an allow_in_graph
# for our aten graph. If we were to not do this, the functions would get inlined into their composing aten ops,
# and we would lose the python state mutation.
def trace_wrapped(*args, fn):
Contributor:
Is there any reason to take this variadically, you only support one argument 🤔

Collaborator (author):
I went back and forth. This feels better in case we want to use it in autograd.Function, where we take multiple args.

# the functions as needed. While there is nothing backward specific about this op, the way it is written means
# we can support functions in backward with complex python. It can be thought of as an allow_in_graph
# for our aten graph. If we were to not do this, the functions would get inlined into their composing aten ops,
# and we would lose the python state mutation.
ezyang (Contributor), Sep 27, 2023:

Here is a proposed rewrite of the top level comment:

trace_wrapped(*args, fn) is equivalent to fn(*args), but with a twist: if you make_fx trace through this call, we will not actually trace into fn; instead, we will directly insert it as a call_function to fn in the graph. (Unlike make_fx, Dynamo WILL inline into fn.) You can think of this as a one off allow_in_graph equivalent for proxy tensor tracing.

Because proxy tensor tracing does not actually run the function, there are requirements on the behavior of fn. We are still figuring it out, but here is the current state:

  • fn can only take a single argument, which must be a tensor
  • fn must return a new tensor with the same metadata as the original tensor (e.g., empty_like(input) is a permissible implementation of fn). This is verified via an extra assert that is inserted into the traced graph.
  • fn MAY have side effects, but it MAY NOT perform metadata mutation on other tensors participating in proxy tensor tracing (it MAY mutate other tensors, it MAY mutate Python state)

These requirements stem from the requirement that we need to continue performing proxy tensor tracing, which assumes accurate fake tensor metadata, without actually running fn. In the future, we may allow for a "meta" function associated with fn to allow for more interesting input-output patterns.

Note that tensors / Python state are allowed to be mutated. This relaxed constraint is not always sound, but it is sound for backward tracing with fake tensors as it takes place in AOTAutograd, as the backward pass is guaranteed not to depend on concrete tensor values (via fake tensor) or Python state (because the autograd engine doesn't depend on Python).

The intended use case for this function is to allow AOTAutograd to defer complex backward hooks to compiled autograd. AOTAutograd performs a make_fx trace which preserves the function call as is in the graph, and only when we Dynamo through the backward graph in compiled autograd do we inline into the function.

Collaborator (author):
Sure, this is better. Thanks for rewriting it. eg: zeros_like(input) I suppose

Collaborator (author):
if you are using may and must as https://www.rfc-editor.org/rfc/rfc2119 - let's use SHOULD and MUST ;)

Thank you again.
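(Based on that description, usage might look roughly like the following; the module path is taken from the diff earlier in this thread, while the hook wiring is an assumption for illustration only.)

import torch
from torch._dynamo._trace_wrapped_higher_order_op import trace_wrapped

hook_log = []

def my_hook(grad):
    # Arbitrary Python state mutation is allowed here; make_fx never runs this body.
    hook_log.append(grad.shape)
    # Must return a new tensor with the same metadata as the input.
    return grad.clone()

def wrapped_hook(grad):
    # Under make_fx this appears as a single opaque call_function node; compiled
    # autograd later Dynamo-traces (inlines) my_hook itself.
    return trace_wrapped(grad, fn=my_hook)

x = torch.randn(3, requires_grad=True)
x.register_hook(wrapped_hook)  # hypothetical wiring so AOTAutograd can defer the hook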

# calls into "leaf modules" as per traditional FX terminology.
# Note: Instead of naming it "allow_in_graph", we opted for a different name since "allow_in_graph"
# might imply that it's traceable, whereas this function is intrinsically non-traceable.
# Note2: I hate this name
Contributor:
This would be subsumed by the comment above I think


proxy_args = (mode.tracer.unwrap_proxy(grad),)
out_proxy = mode.tracer.create_proxy(
    "call_function", self_invoke, proxy_args, {}, name="invocation"
Contributor:
nit: call this trace_wrapped instead?

# a runtime assert
proxy_args = pytree.tree_map(
    mode.tracer.unwrap_proxy, (grad, grad.size(), grad.stride(), grad.dtype)
)
Contributor:
nit: the tree_map here is also unnecessary, just s/grad/out_proxy/
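(i.e., something along these lines; a sketch of the suggested simplification, not the final diff.)

# out_proxy is already a proxy, so no unwrapping pass is needed; the metadata values
# are plain Python objects and can be passed through as-is.
proxy_args = (out_proxy, grad.size(), grad.stride(), grad.dtype)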

ezyang (Contributor) left a review comment:

go go go

voznesenskym (Collaborator, author):

@pytorchbot merge -f "Flaky ci, graph break break in vision rcnn introduced in #110101"

pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

facebook-github-bot deleted the gh/voznesenskym/226/head branch October 1, 2023 14:23