AOTAutograd: avoid intermediate_base logic when all aliased outputs came from a multi_output_view #111411

Closed
wants to merge 5 commits

Conversation

Contributor

@bdhirsh bdhirsh commented Oct 17, 2023

Partially addresses #111081

This fixes the majority of the slowness from https://fb.workplace.com/groups/1405155842844877/permalink/7491314274228973/. In particular, the type of example that suffers the most perf-wise in AOTAutograd looks like this:

```
@torch.compile
def f(x):
    intermediate = x.mul(2)
    outs = intermediate.unbind(0)
    return outs

x = torch.randn(50, 50, requires_grad=True)
outs = f(x)
sum(outs).sum().backward()
```

The function above returns 50 output tensors that all alias each other. AOTAutograd will dutifully exercise its intermediate-base logic and try to regenerate the aliases outside of the compiled autograd.Function at runtime, to ensure that the autograd engine is aware of the aliasing.

In this case, this results in 50 AsStridedBackward nodes in the backward, because we fall back to using as_strided to regenerate each of those 50 outputs. As it stands, this PR (somewhat unsafely) ensures that the backward graph instead consists of a single UnbindBackward, or a call to aten.cat().
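To make the difference concrete, here is a standalone sketch (not AOTAutograd's actual runtime-wrapper code; `reference_outs` is just a hypothetical stand-in for the output metadata recorded at trace time) contrasting the two ways the aliased outputs can be regenerated from the intermediate:

```
import torch

# Standalone illustration only -- not AOTAutograd's real runtime wrapper.
intermediate = torch.randn(50, 50, requires_grad=True).mul(2)
reference_outs = intermediate.unbind(0)  # hypothetical stand-in for traced output metadata

# Intermediate-base fallback: rebuild each alias with as_strided, so autograd
# records one AsStridedBackward node per regenerated output.
outs_via_as_strided = [
    intermediate.as_strided(o.size(), o.stride(), o.storage_offset())
    for o in reference_outs
]
print(outs_via_as_strided[0].grad_fn)  # AsStridedBackward0, one per output

# What this PR allows when every aliased output came from one multi-output view:
# replay the view op itself, so the backward sees a single UnbindBackward.
outs_via_unbind = intermediate.unbind(0)
print(outs_via_unbind[0].grad_fn)      # UnbindBackward0
```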

I left a long comment in the code describing the situation, but the core idea is that autograd does not let you mutate the grad_fn of tensor aliases that come from multi-output views. So if we have k outputs that alias each other, but k-1 of them are aliases that came from multi-output views, then in eager mode it would not be possible to mutate one of the aliases in a way that changes the grad_fn of any of the other aliases without causing an error in the backward. So the claim I'm making is that if we hide this aliasing from the autograd engine, it is impossible for the user to perform any mutation that would cause autograd metadata to diverge between torch.compile and eager, other than mutations that are already an error in eager mode.
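As a concrete eager-mode illustration of that restriction (a minimal sketch, independent of this PR):

```
import torch

base = torch.ones(2, 2, requires_grad=True).clone()  # non-leaf tensor
y, z = base.unbind(0)  # multi-output view: y and z alias base

# Autograd forbids in-place mutation of an output of a multi-output view;
# in eager mode this raises a RuntimeError suggesting an out-of-place op instead.
try:
    y.mul_(2)
except RuntimeError as e:
    print(e)
```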

To be fair, I think that taking the approach outlined in https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit would also help us avoid the as_strided calls in this particularly egregious case, and keep the autograd error messages. This relies on both pre-dispatch functionalization being fully hardened and adding some pretty invasive changes to AOTAutograd though, and is probably at least several months out.


pytorch-bot bot commented Oct 17, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111411

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 41dc313 with merge base 8253e05:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bdhirsh added a commit that referenced this pull request Oct 17, 2023
ghstack-source-id: 452ea0a1a044769656d877f7dc407b4ecc81b857
Pull Request resolved: #111411

bdhirsh added a commit that referenced this pull request Oct 17, 2023
ghstack-source-id: df4bd8bb33b4a6707960c1fb5b7aae2b9f946509
Pull Request resolved: #111411
num_multi_output_view_outs = num_aliased_tensors_that_are_multi_output_views[curr_storage]
num_aliased_outs_that_are_not_multi_output_views = num_aliased_outs - num_multi_output_view_outs
# Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
if out_tensor_alias_counts[curr_storage] == 1 or num_aliased_outs_that_are_not_multi_output_views <= 1:
Contributor

Can we add a check on line 1184, `elif curr_storage in inp_storage_refs:`? Sometimes we have multi_output_views that are an alias_of_input; they would fall into this branch and speed things up quite a bit.

Contributor Author

I agree that if we're going to apply this "hide multi-output view aliasing from autograd" to aliases of intermediates, we should also apply the same logic to aliases of inputs.

This is especially true because autograd's view-replay will generate equally inefficient backward code for the alias-of-input case, if our aliases are multi-output views (view-replay will always generate an as_strided).

Let me spend some time trying to add it and add better testing for it.
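For concreteness, the alias-of-input flavor of the pattern looks something like this (a sketch only; the aot_eager backend is used just to exercise AOTAutograd, and this is not a test from the PR):

```
import torch

@torch.compile(backend="aot_eager")
def f(x):
    # The outputs are multi-output views of the *input*, not of an intermediate.
    return x.unbind(0)

x = torch.randn(8, 8, requires_grad=True)
outs = f(x)
sum(outs).sum().backward()
```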

bdhirsh added a commit that referenced this pull request Oct 19, 2023
ghstack-source-id: e04499d20012fe9e3db64c21e2c757b4c9c94cf6
Pull Request resolved: #111411
Contributor

@Chillee Chillee left a comment

Is it not possible to just write some logic to try to post-hoc recover the ops that produced the aliasing relationships?

Like, in principle, I would have expected the code we write to look something like:

```
outs = [...]
outs_that_are_aliases = []
for out in outs_that_are_aliases:
    if is_view(out): apply view
    if is_unbind(out): apply unbind
...
```

Contributor Author

bdhirsh commented Oct 19, 2023

@Chillee yeah this is another option (in fact, Flavio has a tentative version here: https://www.internalfb.com/diff/D50252204). My thought was:

(1) I think the real, long-term solution is probably to take the entire forward graph that we traced out and try to partition it in two, so the first part is the region we want to compile, and the second part corresponds to all of the output views that we need to regenerate afterward. This is pretty foolproof and should eliminate as_strided in 100% of cases. But it requires (a) turning on pre-dispatch functionalization in all paths of AOTAutograd (Tugsuu is working on pre-dispatch functionalization), and (b) some pretty large changes to AOTAutograd. So if we do this, I think it's at least a few months out.

(2) That approach might be doable for simple cases (a single call to view() or transpose()), but it doesn't feel like it really scales. Multi-output view ops are more painful to handle: you have to stare at all of the graph outputs, check that their storages alias, and then check that their strides/storage_offsets map to a single unbind call (a rough sketch of such a check is at the end of this comment). Handling e.g. noncontiguous outputs, or cases where only a subset of the views are returned, will require more checking. I also don't think we can handle chains of views, e.g. x.unsqueeze_(0).transpose(2, 1).

All those extra checks will be more runtime overhead, so depending on how much code we put into the checks, we might want to trace the runtime wrapper into an FX graph to trace away the overhead (to be fair we should do this anyway).

(3) The current change is actually... pretty simple :) (well it's like ~10 lines of AOTAutograd code).

That said - if we want to add a one-off escape hatch for a particular case of unbind() in AOTAutograd, and we agree that we can just unwind it later when we eventually implement (1), I think I'm ok with that too. Let me know what you think!
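For reference, a rough sketch of what the post-hoc check described in (2) might look like for the simplest possible case: a plain, contiguous `unbind(0)` with every output returned in order. All names here are hypothetical; this is not AOTAutograd code:

```
import torch

def outputs_look_like_unbind0(base: torch.Tensor, outs) -> bool:
    # Hypothetical helper: True only if `outs` could have been produced by
    # `base.unbind(0)`, in order, with nothing missing.
    if len(outs) != base.shape[0]:
        return False
    for i, o in enumerate(outs):
        if o.untyped_storage().data_ptr() != base.untyped_storage().data_ptr():
            return False
        if o.shape != base.shape[1:] or o.stride() != base.stride()[1:]:
            return False
        if o.storage_offset() != base.storage_offset() + i * base.stride()[0]:
            return False
    return True

base = torch.randn(4, 3)
print(outputs_look_like_unbind0(base, base.unbind(0)))      # True
print(outputs_look_like_unbind0(base, base.unbind(0)[:2]))  # False: only a subset returned
```

Even this narrow version needs per-output size/stride/offset checks at runtime, which is exactly the overhead concern from (2).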

Contributor

Chillee commented Oct 19, 2023

I think it's reasonable to do this, although is this not also basically a one-off case for directly returning outputs of unbind?

All those extra checks will be more runtime overhead, so depending on how much code we put into the checks

I guess this seems easier to resolve :P Don't we have all the information we need at compile time?

Contributor Author

bdhirsh commented Oct 20, 2023

I think it's reasonable to do this, although is this not also basically a one-off case for directly returning outputs of unbind?

Well.. mostly yeah 😛. To be fair though:

  • it handles all the multi-output view ops (chunk, split, etc)
  • it handles all the edge cases for them too without any extra work (non-contiguous outputs, what if you unbind() into 10 tensors but only return a subset of them, what if you unbind() and then take more views off of those unbound tensors, etc); a minimal sketch of the subset case is below
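A minimal sketch of the "only return a subset of the unbind() outputs" case mentioned above (the shapes and the aot_eager backend are arbitrary choices for illustration):

```
import torch

@torch.compile(backend="aot_eager")
def f(x):
    intermediate = x.mul(2)
    outs = intermediate.unbind(0)
    # Only a subset of the multi-output view's outputs escape the graph.
    return outs[0], outs[2]

x = torch.randn(4, 4, requires_grad=True)
a, b = f(x)
(a.sum() + b.sum()).backward()
```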

@@ -377,6 +377,7 @@ def emit_view_functionalization_body(
"""

else:
is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
Contributor

nyeh, I think you have to be more selective about this. If hypothetically you have a multi-output view where the outputs alias each other, e.g., some sort of sliding window + unbind, then it wouldn't be right to do the logic you've introduced here, because if you did a mutation it would be necessary to error to prevent incorrect gradients for the other views.

Contributor Author

because if you did a mutation it would be necessary to error to prevent incorrect gradients for the other views

I agree that this is more complicated in the overlapping case (Alban pointed out that even the unbind case can be overlapping, e.g. torch.ones(4).expand(10, 4).unbind(0)).

I would agree with this, but my thought was that autograd should never let this happen, because it doesn't allow you to change the grad_fn of multi-output views (it raises an error, or replaces the multi-output-view grad_fn with an error grad_fn).

I agree this does feel hairy though, so lmk if you think this isn't water-tight / we're risking some correctness issues slipping through.
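For reference, a quick standalone check of that expand + unbind example, showing that the unbound outputs really do overlap:

```
import torch

base = torch.ones(4)
outs = base.expand(10, 4).unbind(0)
# expand produces shape (10, 4) with stride (0, 1), so every unbound "row"
# is backed by the same 4 elements rather than by a disjoint slice.
print(all(o.data_ptr() == base.data_ptr() for o in outs))          # True
print(outs[0].size(), outs[0].stride(), outs[0].storage_offset())  # torch.Size([4]) (1,) 0
```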

Contributor

You're basically saying that a user will never pass torch.compile a program that doesn't run in eager mode. If this is true, I agree that the user cannot have written a naughty program. But we definitely have users passing malformed programs to torch.compile without having checked the eager version first...

# For a set of outputs of the graph that alias each other, o_1...o_k, consider:
# (1) They came from the same multi-output view op, e.g. o_1, ..., o_k = intermediate.unbind(0)
# (2) If there are any other aliases of o_1 through o_k (in the example above, intermediate),
# **at most** 1 can escape from the graph (e.g. there is not some other graph input/output
Contributor

sorry, is this at most one can escape, ON TOP of the original multi-output view op's outputs o_1, ... ?

Contributor Author

yes - if I have K multi-output view aliases of some tensor (say they all came from an unbind call), then this comment is saying that at most K+1 aliases are allowed to escape the graph

# (1) They came from the same multi-output view op, e.g. o_1, ..., o_k = intermediate.unbind(0)
# (2) If there are any other aliases of o_1 through o_k (in the example above, intermediate),
# **at most** 1 can escape from the graph (e.g. there is not some other graph input/output
# o_other, that aliases these outputs)
Contributor

I don't like this. I think condition (2) is too lax and it is making it hard for me to verify that this optimization is correct.

Here is a graph for which I think it is obviously correct to attach grad_fns to the outputs:

```
def f(x):
    y = x + 2
    return y.unbind()
```

Here is a graph where I am not sure:

```
def f(x):
    y = x + 2
    return y, y.unbind()
```

In the first graph, no outputs actually aliased each other (the unbind says they alias, but this is because autograd isn't smart enough to know that actually these views are guaranteed to reference disjoint spots of memory). In the second graph, the outputs DO alias, and so we have to make sure that if you mutate y, the grad_fns for y.unbind() get updated.

Maybe there is a way we can get the second case to work, but if the first case is good enough to solve the problem we're facing in prod, I would prefer to do it first, as it is much more obviously correct.
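To spell out the "disjoint spots of memory" point for the first graph, a small standalone check (independent of torch.compile):

```
import torch

y = torch.arange(6.0).reshape(3, 2)
outs = y.unbind(0)
# For a contiguous base, unbind's outputs occupy disjoint memory regions:
# two elements each, at storage offsets 0, 2, and 4.
print([o.storage_offset() for o in outs])  # [0, 2, 4]
print([o.numel() for o in outs])           # [2, 2, 2]
```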

Contributor Author

@bdhirsh bdhirsh Oct 21, 2023

Yeah, after I made this PR I realized that the second case also shows up in prod, sadly (returning both an intermediate and its unbound tensors).

So given that, I think our options are:

(1) convince ourselves that this is safe, even in the case where we return one extra alias (like the intermediate here).

(2) decide that this either violates safety, or this feels too dangerous to allow, and instead just add some one-off logic for the prod case to specifically handle regenerating unbind() at runtime.

Let me know what you think

Contributor

ezyang commented Oct 23, 2023

I think I'm OK with the unsoundness as a stop-gap, on the road to FX wrapper'ifying the AOTAutograd wrapper code and I guess maybe running conventional multi-view ops instead of as_strided.

There are a few things that I would like to see:

  1. First, I'd like to see an xfail'ed test where we have a program that errors in eager, but doesn't error when you torch.compile. I just want this case durably recorded. When you remove the aliasing relationship, more programs will run under torch.compile.
  2. Second, I'd like to see the comments updated to account for the unsoundness, and also link to the doc where we're tracking the long term fix
  3. Third, I'd like to see some logging (info is probably good enough) when we hit these cases

I think if we actually run into problems due to the unsoundness in the real world, it will be OK because when you ablate aot autograd you will notice that the program starts failing.

bdhirsh added a commit that referenced this pull request Oct 25, 2023
ghstack-source-id: 311e2d889da8aa1702648384f34fa5ea42e99df8
Pull Request resolved: #111411
Contributor Author

bdhirsh commented Oct 25, 2023

Updated the PR.

First, I'd like to see an xfail'ed test where we have a program that errors in eager, but doesn't error when you torch.compile. I just want this case durably recorded. When you remove the aliasing relationship, more programs will run under torch.compile.

When I added the test just now, I realized that we actually (correctly) get the same set of errors with eager vs compile in all of the examples that I could think of. It looks like autograd provides the same set of restrictions for both tensors coming from multi-output views, and tensors coming from custom autograd.Function. So for example, this code:

```
@torch.compile(backend="aot_eager")
def f(a):
    return list(a.unbind(0))

x = torch.ones(2, 2, requires_grad=True).clone()
y, z = f(x)
# autograd error
y.mul_(2)
```

Now raises this error:

RuntimeError: Output 0 of CompiledFunctionBackward is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

Second, I'd like to see the comments updated to account for the unsoundness, and also link to the doc where we're tracking the long term fix

Completely agreed. Given the above, there are no longer any soundness issues that I'm aware of, but I still agree that in the medium term we should rip this out and replace it with the long-term fix of graph partitioning (updated the comments).

Third, I'd like to see some logging (info is probably good enough) when we hit these cases

Agreed, logging added

@bdhirsh bdhirsh added the ciflow/trunk and ciflow/inductor-perf-test-nightly labels Oct 25, 2023

pytorch-bot bot commented Oct 25, 2023

Warning: Unknown label ciflow/inductor-perf-test-nightly.
Currently recognized labels are

  • ciflow/binaries
  • ciflow/binaries_conda
  • ciflow/binaries_libtorch
  • ciflow/binaries_wheel
  • ciflow/inductor
  • ciflow/inductor-perf-compare
  • ciflow/mps
  • ciflow/nightly
  • ciflow/periodic
  • ciflow/rocm
  • ciflow/slow
  • ciflow/trunk
  • ciflow/unstable

Please add the new label to .github/pytorch-probot.yml

bdhirsh added a commit that referenced this pull request Oct 25, 2023
ghstack-source-id: 6d20e9f9bc4ab349733c87ecf84ee62e4861e5c7
Pull Request resolved: #111411
Contributor Author

bdhirsh commented Oct 26, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


andreigh pushed a commit to andreigh/pytorch that referenced this pull request Oct 26, 2023
@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/473/head branch October 29, 2023 14:23
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023