AOTAutograd perf: avoid as_strided() calls when we have intermediate bases #111081
In fact, a different (but simpler) example is this - no intermediate bases, just an output that aliases the input
We need to properly replay the views that happened in the graph, without relying on autograd's view replay. The most obvious way to do this is probably to lean on functionalization: it already tracks these views somewhere, so we can ask it to replay them (this should also be pretty fast, since it's all stored in lambdas in C++). One small downside is that I don't think we can use this approach when the inputs are subclasses (and the output of the compiled blob is a view of that subclass), since functionalization runs below the subclass, and the subclass might insert other logic around the view. We can always deal with that problem later though.
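As a rough illustration of the "replay the recorded views" idea, here is a minimal pure-Python sketch. It is not functionalization's actual data structure: `ViewLog`, `record`, and `replay` are hypothetical names, and plain lists stand in for tensors (list slices copy rather than alias, so this only models the order of operations, not the aliasing itself).

```python
from dataclasses import dataclass, field
from typing import Callable, List

# Hypothetical stand-in for a tensor: a flat Python list.
Tensor = List[float]


@dataclass
class ViewLog:
    """Records each view op as a closure so it can be replayed on another base."""

    steps: List[Callable[[Tensor], Tensor]] = field(default_factory=list)

    def record(self, fn: Callable[[Tensor], Tensor]) -> None:
        # Called once per view op while tracing.
        self.steps.append(fn)

    def replay(self, base: Tensor) -> Tensor:
        # Re-apply every recorded view, in order, starting from `base`.
        out = base
        for step in self.steps:
            out = step(out)
        return out


log = ViewLog()
log.record(lambda t: t[2:])   # models a slice-style view
log.record(lambda t: t[::2])  # models a strided view

base = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
replayed = log.replay(base)
```

The point of the design is that the runtime wrapper only needs the base and the log; it never has to reverse-engineer sizes and strides.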
I'm testing out the approach above. This approach doesn't work with dynamic shapes: functionalization does remember all the views that were used to generate the output, but it also stashes the symbolic values that were passed to each view op. In order to actually re-use functionalization's stashed view information, we'd need to resolve each symbol at runtime. This seems doable, but not trivial.
Do you have a benchmark or similar you are tracking to ensure doing functionalization is worth it here?
Talked offline - agreed that if we switch from as_strided to view chains as the default in all cases, we should do some amount of benchmarking to make sure we aren't regressing any cases.
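To make the dynamic-shapes problem concrete, here is a small sketch of what "resolve each symbol at runtime" could look like. Everything here is an assumption for illustration: symbols are modeled as plain strings such as `"s0"`, `resolve` and `make_narrow` are hypothetical helpers, and lists stand in for tensors.

```python
from typing import Callable, Dict, List, Union

# An argument is either a concrete int or the name of a symbol, e.g. "s0".
Sym = Union[int, str]
Env = Dict[str, int]


def resolve(value: Sym, env: Env) -> int:
    """Turn a possibly-symbolic argument into a concrete int."""
    return env[value] if isinstance(value, str) else value


def make_narrow(start: Sym, length: Sym) -> Callable[[List[int], Env], List[int]]:
    """Record a narrow-style view whose arguments may be symbolic."""

    def replay(t: List[int], env: Env) -> List[int]:
        # Symbols are only resolved here, when the view is replayed.
        s = resolve(start, env)
        n = resolve(length, env)
        return t[s:s + n]

    return replay


step = make_narrow(1, "s0")           # stashed at trace time with a symbol
data = list(range(10))
out_small = step(data, {"s0": 3})     # runtime binding s0 = 3
out_large = step(data, {"s0": 5})     # same recorded view, s0 = 5
```

The non-trivial part in the real system is building the runtime `env`, i.e. mapping each stashed symbol back to a concrete size for the current inputs.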
#111411 should fix the most egregious instances of as_strided. It turns out that by far the worst offenders are multi-output views. Take these two examples: (alias-of-intermediate case)
(alias-of-input case)
It turns out that multi-output views are exactly the case that autograd's view-replay logic cannot handle, so we always fall back to as_strided in both of these cases. Worst of all, the original code had a single call to

The "partition the forward graph across view outputs" idea from this doc https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit is a good long-term fix for this problem. But it's very far out - it requires some major surgery to AOTAutograd, and isn't viable until pre-dispatch functionalization is both completed and hardened enough to run 100% of the time in AOTAutograd (cc @tugsbayasgalan 😃).

Instead, the approach that I'm attempting to take in the linked PR is to effectively hide all multi-output-view aliasing from autograd. We should think carefully about whether or not this is safe in all cases. But it relies on the high-level idea that the autograd engine does not allow you to mutate the outputs of multi-output views:
…d outputs came from a multi_output_view"

Partially addresses #111081

This fixes the majority of the slowness from https://fb.workplace.com/groups/1405155842844877/permalink/7491314274228973/. In particular, the type of example that suffers the most perf-wise in AOTAutograd looks like this:

```
import torch

@torch.compile
def f(x):
    intermediate = x.mul(2)
    # unbind(0) returns a tuple of 50 views that all alias `intermediate`
    outs = intermediate.unbind(0)
    return outs

x = torch.randn(50, 50, requires_grad=True)
outs = f(x)
sum(outs).sum().backward()
```

There are 50 output tensors in the above function, that all alias each other. AOTAutograd will dutifully exercise its intermediate base [logic](https://github.com/pytorch/pytorch/blob/main/torch/_functorch/aot_autograd.py#L294), and try to regenerate the aliases outside of the compiled `autograd.Function` at runtime, to ensure that the autograd engine is aware of the aliasing.

In this case, this will result in **50 AsStridedBackward nodes in the backward**, because we will fall back to using as_strided to generate each of those 50 outputs. The current PR as is (somewhat unsafely) ensures that the backward graph consists of a single `UnbindBackward`, or a call to `aten.cat()`.

I left a long comment in the code describing the situation, but the core idea is that **autograd does not let you mutate grad_fn of tensor aliases that come from multi-output views**. So if we have `k` outputs that alias each other, but `k-1` of them are aliases that came from multi-output views, then in eager mode, it would not be possible to mutate one of the aliases in a way that would change the grad_fn of any of the other aliases, without causing an error in the backward. So the claim I'm making is that if we hide this aliasing from the autograd engine, then it is impossible for the user to perform any mutations that would cause autograd metadata to diverge between torch.compile and eager in a way that isn't an error in eager mode.

To be fair, I think that taking the approach outlined in https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit would also help us avoid the as_strided calls in this particularly egregious case, **and** keep the autograd error messages. This relies on both pre-dispatch functionalization being fully hardened **and** adding some pretty invasive changes to AOTAutograd though, and is probably at least several months out.
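The `k` / `k-1` condition from the PR description can be written down as a small predicate. This is only a sketch of the stated rule, not the code in the PR: `OutputInfo`, its fields, and `can_hide_aliasing` are hypothetical names, and AOTAutograd's real output metadata looks different.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class OutputInfo:
    """Hypothetical per-output record; not AOTAutograd's real metadata."""

    base_id: int                  # which buffer this output aliases
    from_multi_output_view: bool  # e.g. produced by unbind/split/chunk


def can_hide_aliasing(group: List[OutputInfo]) -> bool:
    """True if, among k outputs sharing a base, at least k-1 are multi-output views."""
    if len(group) < 2:
        return False  # a single output has no aliasing to hide
    regular = sum(1 for o in group if not o.from_multi_output_view)
    return regular <= 1


# 50 outputs of one unbind() call: every alias is a multi-output view.
unbind_outs = [OutputInfo(base_id=0, from_multi_output_view=True) for _ in range(50)]

# Two ordinary views of the same base: autograd must still see the aliasing.
two_regular = [
    OutputInfo(base_id=0, from_multi_output_view=False),
    OutputInfo(base_id=0, from_multi_output_view=False),
    OutputInfo(base_id=0, from_multi_output_view=True),
]
```

Here `can_hide_aliasing(unbind_outs)` holds, while `can_hide_aliasing(two_regular)` does not, since mutating one of the two ordinary views could legitimately change the other's autograd metadata in eager mode.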
…ame from a multi_output_view (#111411) Partially addresses #111081 Pull Request resolved: #111411 Approved by: https://github.com/ezyang
This is a more targeted version of an existing issue around as_strided calls in AOTAutograd, #109237. Came from an internal issue.
Simple repro:
prints:
We end up calling as_strided in the compiled forward, so an `AsStridedBackward` node shows up in the backward, which in general is not implemented to be particularly fast.
Why does this happen?
(1) AOTAutograd has logic for "intermediate bases". If we have two outputs of our graph that are aliases of each other (and of the same graph intermediate), today, AOTAutograd will just have the shared intermediate be an output to the graph. AOTAutograd will then replay the views off of the intermediate, so that autograd properly realizes that the outputs alias.
(2) AOTAutograd has a function to try to do the view replay, but it hits a slow path in that function that causes it to go to as_strided. We should figure out why and fix this: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/aot_autograd.py#L807
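For intuition on why as_strided works as a catch-all fallback: any view of a contiguous buffer can be described by a size, a stride, and a storage offset, so an alias can always be regenerated from those three values without knowing which view ops produced it. Below is a toy pure-Python model of that idea for the rows produced by unbinding dim 0 of a contiguous 2-D tensor. The helper names are made up for the sketch, and the toy `as_strided` here copies values into nested lists rather than aliasing storage the way `torch.as_strided` does.

```python
from typing import List, Sequence, Tuple

ViewParams = Tuple[Tuple[int, ...], Tuple[int, ...], int]  # (size, stride, offset)


def as_strided(storage: Sequence[float], size: Sequence[int],
               stride: Sequence[int], offset: int):
    """Toy as_strided: read an N-D view out of flat storage as nested lists."""
    if not size:
        return storage[offset]
    return [
        as_strided(storage, size[1:], stride[1:], offset + i * stride[0])
        for i in range(size[0])
    ]


def unbind_dim0_params(shape: Tuple[int, int]) -> List[ViewParams]:
    """(size, stride, offset) for each row of a contiguous 2-D tensor."""
    n_rows, n_cols = shape
    return [((n_cols,), (1,), r * n_cols) for r in range(n_rows)]


storage = [float(i) for i in range(6)]   # a contiguous 2x3 tensor, flattened
params = unbind_dim0_params((2, 3))
rows = [as_strided(storage, size, stride, offset) for size, stride, offset in params]
whole = as_strided(storage, (2, 3), (3, 1), 0)
```

The generality is also the cost: every regenerated output becomes its own as_strided call, which is where the `AsStridedBackward` nodes in the backward come from.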
cc @ezyang @msaroufim @wconstab @anijain2305 @zou3519