Marking variable not required in backward calculation if it is not needed #56500
Comments
Yes, you have … That is what you want, right?
@albanD yes! Thanks for this!
@albanD, sorry to reopen the issue. It seems like …
The calculation of … happens in `pytorch/torch/csrc/autograd/python_function.cpp`, lines 491 to 493 (at commit 4575028).
Yes, that is expected, I am afraid. I am not sure we have this information at all... We would need to investigate.
@albanD, there is a possible fix for this in Python, which I wrote in Colab: https://colab.research.google.com/drive/1wEgkI4IsUQBXUEtgFa-ai3yOwJnuvOGD?usp=sharing (you can skip the "problem" part). However, it requires a global variable to store the derived variables and an inefficient way of checking whether a variable depends on those derived variables.
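One possible reading of that approach, as a toy sketch: a global list records the variables we want gradients for, and a gradient with respect to some variable `var` is only useful if `var` is, or was computed from, one of them. `Node`, `WANTED`, `want_grad`, `depends_on` and `grad_is_needed` are hypothetical names, not taken from the Colab; every query re-walks the graph, which is the inefficiency mentioned above.

```python
class Node:
    """Toy stand-in for a tensor: remembers the nodes it was computed from."""
    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)

# Hypothetical global registry: the variables we want gradients with respect to.
WANTED = []

def want_grad(node):
    WANTED.append(node)
    return node

def depends_on(node, target):
    """True if `target` is `node` itself or one of its ancestors (iterative DFS)."""
    stack, seen = [node], set()
    while stack:
        cur = stack.pop()
        if cur is target:
            return True
        if id(cur) not in seen:
            seen.add(id(cur))
            stack.extend(cur.parents)
    return False

def grad_is_needed(var):
    """Is a gradient w.r.t. `var` useful, i.e. was `var` derived from a wanted variable?"""
    # A fresh traversal for every wanted variable: this is the slow part.
    return any(depends_on(var, w) for w in WANTED)

x = want_grad(Node("x"))
w = Node("w")                 # a parameter we do not differentiate with respect to
h = Node("h", [x, w])
print(grad_is_needed(h), grad_is_needed(w))  # True False
```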
The problem is that, by design, we never have this information :/ As an example, when you do … The other issue is that the existing flag …
Not sure if this would solve your issue, but we did have a PR that solves something similar: #52180. So there might be a way you could try to do this, but depending on your specific graph it might not always work. Let's say you have some … You can think of the topological number as the longest possible path between your node and any leaf node. So if you are a leaf node, your `topo_nr` is always 0. This is a valid topological ordering because we have the invariant that any parent's `topo_nr` is one greater than the max `topo_nr` of its children. So, to apply this to our previous example, if you did …
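The topological number described above can be sketched on a toy graph. `TNode` and its fields are hypothetical (the real `topo_nr` lives on C++ autograd nodes); `children` points toward the leaves, matching the "children" terminology in the comment. Because children always exist before their parent, each number can be computed once at construction and agrees with the longest-path definition.

```python
class TNode:
    """Toy autograd node; `children` are the nodes it passes gradients to."""
    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)
        # Leaves get 0; every parent is one more than its largest child.
        self.topo_nr = 1 + max((c.topo_nr for c in self.children), default=-1)

def longest_path_to_leaf(node):
    """Reference definition: length of the longest path from `node` down to a leaf."""
    if not node.children:
        return 0
    return 1 + max(longest_path_to_leaf(c) for c in node.children)

a = TNode("a")              # leaf
b = TNode("b")              # leaf
c = TNode("c", [a])         # one above a
d = TNode("d", [c, b])      # longest path is d -> c -> a
for n in (a, b, c, d):
    assert n.topo_nr == longest_path_to_leaf(n)
print([(n.name, n.topo_nr) for n in (a, b, c, d)])  # [('a', 0), ('b', 0), ('c', 1), ('d', 2)]
```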
My understanding is that there is an extra assumption here that …
Oh right, I guess that is unrelated to the feature proposal. Since the motivation is making higher-order computations more efficient, I figured this might be useful to know, though. :P One thing we could do, though, is enhance `needs_input_grad` with information gained from `topo_nr`: if any of the inputs has a `topo_nr` lower than `min_topo_nr`, set the boolean at that index to `False`.
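A sketch of that proposal, assuming each input's gradient goes to a node whose `topo_nr` is known, and that `min_topo_nr` is the smallest `topo_nr` among the tensors passed as `inputs` to `torch.autograd.grad`. Since `topo_nr` strictly decreases along every edge toward the leaves, a child below `min_topo_nr` cannot reach any requested input, so its slot can safely be `False`. The converse does not hold: `True` only means the edge *might* be needed. `prune_needs_input_grad` is a hypothetical helper, not an existing PyTorch function.

```python
def prune_needs_input_grad(needs_input_grad, child_topo_nrs, min_topo_nr):
    """Turn off gradient slots whose subgraph cannot contain a requested input.

    needs_input_grad: per-input flags as computed at forward time (requires_grad).
    child_topo_nrs:   topo_nr of the node that receives each input's gradient.
    min_topo_nr:      smallest topo_nr among the inputs passed to autograd.grad.
    """
    return tuple(
        needed and topo_nr >= min_topo_nr   # equal topo_nr may be the input itself
        for needed, topo_nr in zip(needs_input_grad, child_topo_nrs)
    )

# A linear node whose first input is a deep activation (topo_nr 5) and whose second
# input is a weight leaf (topo_nr 0); we only ask for d(out)/dx, where x has topo_nr 3.
print(prune_needs_input_grad((True, True), (5, 0), min_topo_nr=3))  # (True, False)
```

If `x` is itself a leaf, `min_topo_nr` is 0 and nothing is pruned, which is one reason this trick does not always help.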
@soulitzer thanks for the info. It seems that I can use `topo_nr` to make the checking more efficient (i.e., help solve problem number 2 from the Google Colab). Is …
The most horrible way I can think of doing it is: …
@soulitzer what do you think?
@mfkasim1 No, not currently. Should it be, though? I'm thinking it should at least be exposed as … @albanD (edit: realized I completely misread some code lol) updating `needs_input_grad` is more straightforward than I initially thought. But I still don't see any way around the thread-safety issue you pointed out, i.e., multiple threads running backward over the same node at the same time with respect to different inputs.
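To illustrate that thread-safety concern, here is a toy sketch (pure Python, not PyTorch internals) in which two threads run "backward" over the same shared node with different requested inputs. Storing the per-call mask as an attribute on the shared node lets one thread read the other's mask; keeping it in thread-local storage, in the spirit of the thread-local graph-task state used by the later fix, keeps them apart. `SharedNode`, `run` and `trial` are hypothetical names.

```python
import threading

class SharedNode:
    """One graph node shared by every thread running backward through it."""
    def __init__(self):
        self.mask = None                   # per-call state stored on the node: racy
        self.local = threading.local()     # per-thread state: each thread sees its own

def run(node, idx, my_mask, barrier, results, use_thread_local):
    # Each "backward" call records which inputs *it* needs gradients for.
    if use_thread_local:
        node.local.mask = my_mask
    else:
        node.mask = my_mask
    barrier.wait()                         # both threads have written before either reads
    seen = node.local.mask if use_thread_local else node.mask
    results[idx] = seen == my_mask         # did this thread read back its own mask?

def trial(use_thread_local):
    node, barrier, results = SharedNode(), threading.Barrier(2), [None, None]
    masks = [(True, False), (False, True)] # thread 0 wants input 0, thread 1 input 1
    threads = [
        threading.Thread(target=run, args=(node, i, m, barrier, results, use_thread_local))
        for i, m in enumerate(masks)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

print(trial(use_thread_local=False))  # exactly one False: a thread read the other's mask
print(trial(use_thread_local=True))   # [True, True]
```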
Has there been any progress on this issue? I'm currently running into it as well, and as of version 1.11 it still seems to happen in the same way as described here.
### Introduction

Removing unnecessary weight-gradient calculations is very important for applications that need higher-order derivatives during training. However, the current autograd engine does not support this.

In more detail: the backward function of a `matmul` operator (e.g., `linear`, `addmm`, `mm`) contains two matmuls, one for the `input gradient` and another for the `weight gradient`. For a typical neural network (nn) with a few linear layers and activation functions, if the user calls `torch.autograd.grad()` to calculate the derivative of the nn output `y` w.r.t. the nn input `x`, only the `input gradient` of the `matmul` operator is needed, and the `weight gradient` is discarded. However, the current PyTorch autograd engine always calculates the `weight gradient` if `weight` requires gradient (which is the case when the higher-order derivative is computed during training).

The attached figure shows the autograd graph of the following code snippet:

```py
import torch

# Illustrative shapes; in_features == out_features so grad_outputs fits both y and y__x
x = torch.randn(4, 3, requires_grad=True)
weight = torch.randn(3, 3, requires_grad=True)
bias = torch.randn(3, requires_grad=True)
grad_outputs = torch.ones(4, 3)
y = torch.nn.functional.linear(x, weight, bias)
y = y.pow(2)
# first-order derivative
y__x, = torch.autograd.grad(y, x, grad_outputs=grad_outputs, create_graph=True)
# second-order derivative
y__x__x, = torch.autograd.grad(y__x, x, grad_outputs=grad_outputs, create_graph=True)
```

The path marked with ❌ is not needed when calculating the derivatives.

<img width="50%" alt="image" src="https://user-images.githubusercontent.com/9999318/182018117-719c5a23-bcc6-4a63-8e8d-1bca3ebda2e3.png">

### Issue

Related issue: #56500

### Method

When calling `torch.autograd.grad`, an `exec_info_` is created for each GraphTask, which allows filtering out paths of the graph that are not needed. However, when the GraphTask calls into a node, the node still does not know whether its edges are needed. In the case of matmul, `weight.requires_grad is True`, so the weight gradient is always calculated.

Following #56500 (comment), this PR passes the graph task's thread-local `exec_info_` into the node, so that the node can trim unnecessary edges during `torch.autograd.grad` calls.

### Benchmark

Benchmark script: https://gist.github.com/yueyericardo/24158433a2021c51eeef9c3e2722df99

Benchmark setup: 6 hidden layers, batch size 10000, on an A100.

FP32 result

| Hessian benchmark             | FP32 (before) | FP32 (after)      | FP32 (functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 55.658 ms     | 29.392 ms (1.89X) | 29.547 ms (1.88X)       |
| Linear + ReLU (with backward) | 81.173 ms     | 54.917 ms (1.48X) | 68.988 ms (1.18X)       |

TF32 result

| Hessian benchmark             | TF32 (before) | TF32 (after)      | TF32 (functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 19.801 ms     | 11.259 ms (1.76X) | 10.754 ms (1.84X)       |
| Linear + ReLU (with backward) | 29.167 ms     | 20.466 ms (1.42X) | 22.784 ms (1.28X)       |

For FP32, we get a 1.9X speedup for the Hessian calculation and a 1.48X speedup during training, which is even faster than the functorch `vmap(jacfwd(jacrev(...)))` implementation. (functorch has a performance regression in v0.2.0, pytorch/functorch#989, so we use v0.1.1 for the benchmark.)

@zou3519 does functorch also include similar optimizations during the Hessian calculation? If not, what do we need to do so that functorch can also benefit from this PR?

### Testing

- [x] we need to figure out a way for unittest

### Thanks

Thanks for the great blog: [How Computational Graphs are Executed in PyTorch | PyTorch](https://pytorch.org/blog/how-computational-graphs-are-executed-in-pytorch/)

cc @zasdfgbnm @albanD

Pull Request resolved: #82544
Approved by: https://github.com/soulitzer
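For reference, the two matmuls in a linear layer's backward can be written out as a custom `torch.autograd.Function`. This is a sketch, not PyTorch's built-in `linear` backward: `LinearFn` is a hypothetical name, and it relies on `ctx.needs_input_grad`, which, per the discussion in this thread, is derived from `requires_grad` when the forward runs. So it only skips the weight-gradient matmul when `weight` does not require grad; it cannot see which inputs a particular `torch.autograd.grad` call asked for, which is the per-call information `exec_info_` carries.

```python
import torch

class LinearFn(torch.autograd.Function):
    """y = x @ weight.T; the backward has one matmul per input gradient."""
    matmuls = 0  # counts backward matmuls, for illustration only

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = grad_w = None
        if ctx.needs_input_grad[0]:
            grad_x = grad_out @ weight         # "input gradient"
            LinearFn.matmuls += 1
        if ctx.needs_input_grad[1]:
            grad_w = grad_out.t() @ x          # "weight gradient": skipped when not needed
            LinearFn.matmuls += 1
        return grad_x, grad_w

x = torch.randn(8, 4, requires_grad=True)
w = torch.randn(3, 4)                          # weight does NOT require grad here
y = LinearFn.apply(x, w)
(gx,) = torch.autograd.grad(y.sum(), x)
print(LinearFn.matmuls)                        # 1: only the input-gradient matmul ran
```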
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/382ef1fda75dfce07c37a920e908ce96d99bf970
Reviewed By: seemethere
Differential Revision: D38643340
fbshipit-source-id: 346de0e0971363441c6d06dc83601e0297d5ccc8
🚀 Feature
I wonder if there is a feature in PyTorch's `autograd.Function` that can detect which variables' grads are required and which are not. For example:

The gradient of `inp2` is not required in the grad calculation, but there is no way to know whether the `inp2` grad is required or not.

It would be good if there were a function that could be called to check whether the `inp2` grad is required, e.g. `inp2.grad_is_needed`.

Motivation
This problem comes up when I compute higher-order gradients, where the gradient calculation spends a lot of time in unnecessary parts of the graph, for reasons similar to those in #39784.
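The pruning asked for here can be sketched as a single memoized pass over a toy backward graph, loosely modeled on the `exec_info_` filtering described in the PR text: starting from the output, a node is needed if it is one of the requested inputs or if any of its children is needed. `GNode` and `compute_needed` are hypothetical toy names; the real engine works on C++ nodes and edges.

```python
class GNode:
    """Toy backward-graph node; `children` are the nodes it sends gradients to."""
    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

def compute_needed(root, inputs):
    """Names of nodes that lie on some path from `root` to a requested input."""
    targets = {id(n) for n in inputs}
    memo = {}  # id(node) -> (node, needed?)

    def visit(node):
        if id(node) not in memo:
            # Evaluate every child (no short-circuit) so each subgraph is visited once.
            child_flags = [visit(c) for c in node.children]
            memo[id(node)] = (node, id(node) in targets or any(child_flags))
        return memo[id(node)][1]

    visit(root)
    return {node.name for node, needed in memo.values() if needed}

# Two linear layers, mm2(mm1(x, w1), w2); only d(out)/dx is requested.
x, w1, w2 = GNode("x"), GNode("w1"), GNode("w2")
mm1 = GNode("mm1", [x, w1])
mm2 = GNode("mm2", [mm1, w2])
print(sorted(compute_needed(mm2, [x])))  # ['mm1', 'mm2', 'x']
```

Edges into `w1` and `w2` are not needed, so a node that could see this set could skip both weight-gradient matmuls.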
Pitch
This feature would enable optimizing the gradient calculation for large graphs, especially for computations that require higher-order gradients.
Alternatives
For this simple case, I can just mark `inp2` as not requiring grad, but for the example below, there is no way around it other than accepting that it calculates unnecessary gradients.

Additional context
This is similar to issue #39784, which has been solved since February, and I am using nightly version `1.9.0.dev20210409` to run the example above.

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7