torch.compile should auto-functionalize certain mutable ops #114955
Conversation
Users may wish to torch.compile custom ops that mutate their inputs and return nothing (this is a common class of operators). torch.compile will automatically support this op without anyone needing to provide a functionalization kernel for it. Here's how.

Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> () op. First, when FakeTensor sees this op, it can just return None. This is the case because custom ops are not allowed to mutate input metadata, so the FakeTensor rule for one that returns nothing is trivial.

Next, when Python FunctionalTensor sees the op, it will functionalize it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...}) HOP and replacing the mutated inputs with the outputs of this HOP. This HOP effectively runs the functional version of the op when called: it clones inputs that will be mutated, runs the op, and then returns Tensors with the new values.

In the future we can teach Inductor how to do re-inplacing when it sees this HOP (like how triton kernels do it), but this isn't urgent (and is more of a performance problem).

Test Plan:
- new tests
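For concreteness, here is a minimal sketch of the scenario described above, using the hypothetical mylib::sin_ op from the summary. The registration calls mirror the test code discussed in the review below; the exact names are illustrative only.

```python
import torch

# Hypothetical op from the PR summary: mutates its input in place, returns nothing.
lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("sin_(Tensor(a!) x) -> ()")

@torch.library.impl("mylib::sin_", "cpu", lib=lib)
@torch._dynamo.disable  # see the review discussion below for why this is needed
def sin_impl(x):
    x.sin_()  # in-place mutation, no return value

def f(x):
    torch.ops.mylib.sin_(x)
    return x + 1

x = torch.randn(3)
expected = x.sin() + 1
# No FakeTensor or functionalization kernel was written for mylib::sin_;
# torch.compile handles it via the auto_functionalize HOP described above.
out = torch.compile(f, backend="aot_eager")(x)
torch.testing.assert_close(out, expected)
```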
@torch.library.impl("mylib::foo", "cpu", lib=lib)
@torch._dynamo.disable
curiosity: why do we need the dynamo.disable here? (Dynamo never ends up seeing foo_impl anyway, right?)
Dynamo interposes on all frames. The PyTorch dispatcher invoking this function is a frame, so Dynamo ends up trying to compile it. Dynamo errors out loudly here: it tries to construct FakeTensors, but the Python key has been disabled at this point. The workaround is to disable Dynamo here; a longer-term solution is to get Dynamo to reset the dispatcher state temporarily.
> The PyTorch dispatcher invoking this function is a frame, so Dynamo ends up trying to compile it. Dynamo errors out loudly here: it tries to construct FakeTensors, but the Python key has been disabled at this point.

Hmm, I'm sure this is a noob question, but: this is a Python function getting invoked by C++ code (the PyTorch dispatcher). I had always just assumed that any Python getting called by C++ code was "invisible" to Dynamo. If Dynamo will indeed go ahead and try to compile this Python frame (and fakeify the inputs) even though it was called from C++, then why doesn't e.g. every tensor subclass's __torch_dispatch__ also need to be wrapped in a @dynamo.disable?
> This is a Python function getting invoked by C++ code (the PyTorch dispatcher). I had always just assumed that any Python getting called by C++ code was "invisible" to Dynamo.

My understanding is that Python functions called by C++ code still get seen by Dynamo.

> If Dynamo will indeed go ahead and try to compile this Python frame (and fakeify the inputs) even though it was called from C++, then why doesn't e.g. every tensor subclass's __torch_dispatch__ also need to be wrapped in a @dynamo.disable?

We execute the __torch_dispatch__ as part of AOTAutograd, right? And Dynamo is turned off during AOTAutograd.
test/dynamo/test_misc.py
orig_args = (x, y, z, n)
compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
torch.compile(f, backend="aot_eager")(*compiled_args)
Just a thought: aot_eager_decomp_partition might be a better test backend here: the aot_eager backend actually does not try to keep input mutations in the graph (so the other backend will be a little closer to the "default" torch.compile behavior). We could maybe update aot_eager to keep input mutations in the graph, but I thought it would be a nice way to debug compiler issues specific to input mutations if we keep aot_eager as is.
I didn't know about this backend, let me try it out
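For reference, a small sketch of the two backends being discussed, using a toy function with an input mutation (the behavioral notes in the comments simply restate the reviewer's point above):

```python
import torch

# Toy function that mutates its input.
def f(x):
    x.add_(1)
    return x * 2

# aot_eager: per the comment above, it does not keep input mutations in the graph.
torch.compile(f, backend="aot_eager")(torch.zeros(3))

# aot_eager_decomp_partition: the suggested backend, closer to the default
# torch.compile handling of input mutations.
torch.compile(f, backend="aot_eager_decomp_partition")(torch.zeros(3))
```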
auto_functionalized = AutoFunctionalized()


def can_auto_functionalize(op: torch._ops.OperatorBase) -> bool:
Do you think it's worth returning False for ops in the ATen namespace? Just so we don't accidentally auto-functionalize them (even though I think, in the way you use this in fake_tensor.py today, that will never happen, since you only try to run this during the fallback). My thought was that this would be silently wrong for metadata-mutation ops - and even once you add op_tests to error on metadata mutation, it's still valid to have aten ops that mutate metadata.
Sure, we could
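A rough sketch of what the check plus the suggested ATen-namespace guard could look like (an illustration only, not the PR's actual implementation; the helper name is made up):

```python
import torch

def can_auto_functionalize_sketch(op) -> bool:
    # Only consider resolved OpOverloads.
    if not isinstance(op, torch._ops.OpOverload):
        return False
    # Reviewer's suggestion: never auto-functionalize ATen ops, since some of
    # them legitimately mutate metadata and would be handled silently wrong.
    if op.namespace == "aten":
        return False
    schema = op._schema
    # This PR only handles ops that return nothing...
    if len(schema.returns) != 0:
        return False
    # ...and that mutate at least one of their inputs.
    return any(
        arg.alias_info is not None and arg.alias_info.is_write
        for arg in schema.arguments
    )
```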
) -> Tuple[Tensor, ...]:
    new_kwargs = dict(**kwargs)
    result = []
    for name in mutated_args_names:
Do you think it's worth a warn_once() here as a perf warning? (counterargument == too noisy, and it's not very actionable)
Probably not for now, as you mentioned that it is not very actionable. Maybe we can do this in the future when the re-inplacing pass exists (and we can log instead of warn so someone staring at the log might be able to surface this to us)
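For context, a minimal sketch of the clone-run-return behavior under review, reconstructed from the snippet above and the PR summary (the real HOP implementation may differ in detail):

```python
from typing import Any, Dict, List, Tuple

import torch
from torch import Tensor

def auto_functionalized_sketch(
    op: torch._ops.OpOverload,
    mutated_args_names: List[str],
    kwargs: Dict[str, Any],
) -> Tuple[Tensor, ...]:
    # Clone every input the op would mutate, so the originals stay intact.
    new_kwargs = dict(**kwargs)
    result = []
    for name in mutated_args_names:
        new_kwargs[name] = kwargs[name].clone()
        result.append(new_kwargs[name])
    # Run the mutating op on the clones; for this PR the op returns nothing.
    op(**new_kwargs)
    # Return the new values of the would-be-mutated inputs.
    return tuple(result)
```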
left some comments but lgtm
How does this work if there are aliased inputs to an op? In that case, won't cloning before a mutation potentially be incorrect?
@eellison that's a good test case, I'll add one. We expect this to just work automatically (https://gist.github.com/zou3519/373634e2be660026252969882fa7fe84). The TL;DR is that when we replace the original input with the new output during functionalization, if the input to be replaced is a view, we do some heavy lifting to ensure that each view gets updated.
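A sketch of what such an aliasing test might look like, mutating through a view of the input (the mylib2::sin_ op here is hypothetical, set up the same way as the sketch near the top of this thread; the expected behavior follows the linked gist):

```python
import torch

# Hypothetical mutating op, registered only for this sketch.
lib = torch.library.Library("mylib2", "FRAGMENT")
lib.define("sin_(Tensor(a!) x) -> ()")

@torch.library.impl("mylib2::sin_", "cpu", lib=lib)
@torch._dynamo.disable
def sin_impl(x):
    x.sin_()

def f(x):
    view = x[0]                  # an alias of the input
    torch.ops.mylib2.sin_(view)  # mutate through the view
    return x.clone()

x = torch.randn(2, 3)
expected = x.clone()
expected[0] = expected[0].sin()
torch.compile(f, backend="aot_eager")(x)
# Per the gist above, functionalization propagates the update on the view back
# to its base, so the caller-visible x should now match `expected`.
torch.testing.assert_close(x, expected)
```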
@pytorchbot merge
This should deflake some of the tests

Pull Request resolved: #114956
Approved by: https://github.com/williamwen42
ghstack dependencies: #114955
In preparation for the next PR up in the stack, which is going to update "can_auto_functionalize" to support more operators than just ones that return nothing. We are unable to auto-generate FakeTensor kernels for operators that return something, but we are able to generate functionalization kernels for them.

Test Plan: Existing tests

Pull Request resolved: #115134
Approved by: https://github.com/bdhirsh
ghstack dependencies: #114955, #114956
We can auto-functionalize operators that mutate their inputs as long as the outputs of the operator do not alias their inputs. The user needs to provide an abstract impl for the operator if it has non-trivial returns.
- We update can_auto_functionalize(op) to include ops that return (but do not alias) Tensors
- We update auto_functionalized(op, mutated_args_names, kwargs) to return (out, mutated_args), where `out = op(**kwargs)` and `mutated_args` are the new values of the inputs that would have been mutated.

Test Plan:
- new test

Pull Request resolved: #115135
Approved by: https://github.com/bdhirsh
ghstack dependencies: #114955, #114956, #115134
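A rough sketch of the extended calling convention that follow-up describes (assumed from the description above; the actual HOP signature may differ):

```python
from typing import Any, Dict, List, Tuple

import torch

def auto_functionalized_with_outputs_sketch(
    op: torch._ops.OpOverload,
    mutated_args_names: List[str],
    kwargs: Dict[str, Any],
) -> Tuple[Any, Tuple[torch.Tensor, ...]]:
    # Same clone-run-return idea as before, but the op may now also return
    # (non-aliasing) outputs, which are passed through as `out`.
    new_kwargs = dict(**kwargs)
    mutated_args = []
    for name in mutated_args_names:
        new_kwargs[name] = kwargs[name].clone()
        mutated_args.append(new_kwargs[name])
    out = op(**new_kwargs)
    return out, tuple(mutated_args)
```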