
Marking variable not required in backward calculation if it is not needed #56500

Open
mfkasim1 opened this issue Apr 20, 2021 · 13 comments
Labels: actionable, feature (A request for a proper, new feature), module: autograd (Related to torch.autograd and the autograd engine in general), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@mfkasim1
Contributor

mfkasim1 commented Apr 20, 2021

🚀 Feature

I wonder whether there is a feature in PyTorch's autograd.Function that can detect which variables' gradients are required and which are not.

For example:

import torch

class Mul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp1, inp2):
        ctx.save_for_backward(inp1, inp2)
        return inp1 * inp2
    @staticmethod
    def backward(ctx, grad_out):
        inp1, inp2 = ctx.saved_tensors
        grad_inp1 = None
        grad_inp2 = None
        if inp1.requires_grad:  # and inp1.grad_is_needed:
            print("Calculating grad_inp1")
            grad_inp1 = grad_out * inp2
        if inp2.requires_grad:  # and inp2.grad_is_needed:
            print("Calculating grad_inp2")
            grad_inp2 = grad_out * inp1
        return grad_inp1, grad_inp2

dtype = torch.float64
inp1 = torch.tensor(1.0, dtype=dtype).requires_grad_()
inp2 = torch.tensor(2.0, dtype=dtype).requires_grad_()

out = Mul.apply(inp1, inp2)
# grad_inp2 is not required
gradinp1 = torch.autograd.grad(out, inp1, create_graph=True)
# prints:
# Calculating grad_inp1
# Calculating grad_inp2

The gradient of inp2 is not required in this grad calculation, but there is no way to know from inside backward whether inp2's grad is needed or not.
It would be good if there were a function (or attribute) that could be queried for this, e.g. a hypothetical inp2.grad_is_needed.

Motivation

This problem comes up when I am computing multiple gradients (higher-order grads), where the gradient calculation spends a lot of time in unnecessary parts of the graph. The motivation is similar to #39784.

Pitch

This feature would enable optimizing the gradient calculation for large graphs, especially for computations that require higher-order gradients.

Alternatives

For the simple case above, I can just mark inp2 as not requiring grad, but for the example below there is no option other than accepting that unnecessary gradients are calculated.

import torch

class Pow(torch.autograd.Function):
    @staticmethod
    def forward(ctx, base, pow):
        out = base ** pow
        ctx.save_for_backward(base, pow, out)
        return out
    @staticmethod
    def backward(ctx, grad_out):
        base, pow, out = ctx.saved_tensors
        grad_base = None
        grad_pow = None
        if base.requires_grad:  # and base.grad_is_needed:
            print("Calculating grad_base")
            grad_base = grad_out * pow * out / base
        if pow.requires_grad:  # and pow.grad_is_needed:
            print("Calculating grad_pow")
            grad_pow = grad_out * out * torch.log(base)
        return grad_base, grad_pow

dtype = torch.float64
base = torch.tensor(1.0, dtype=dtype).requires_grad_()
pow = torch.tensor(2.0, dtype=dtype).requires_grad_()

out = Pow.apply(base, pow)
gp = torch.autograd.grad(out, pow, create_graph=True)
# calculating grad_pow is not required below
gbgp = torch.autograd.grad(gp, base, create_graph=True)

Additional context

This is similar to issue #39784, which has been resolved since February. I am using the nightly version 1.9.0.dev20210409 to run the examples above.

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7

@heitorschueroff added the feature, module: autograd, and triaged labels Apr 20, 2021
@albanD
Collaborator

albanD commented Apr 20, 2021

Yes, you have ctx.needs_input_grad, which is a tuple of booleans telling you whether the gradient for the input at that index needs to be computed.
You can see more details in the custom Function doc: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function.backward
or in the "Extending PyTorch" note in the docs.

That is what you want, right?

@mfkasim1
Contributor Author

@albanD yes! thanks for this!

@mfkasim1
Contributor Author

mfkasim1 commented Apr 21, 2021

@albanD sorry to reopen the issue. It seems that ctx.needs_input_grad does not behave the way I expected.
For the simple example above:

import torch

class Mul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp1, inp2):
        ctx.save_for_backward(inp1, inp2)
        return inp1 * inp2
    @staticmethod
    def backward(ctx, grad_out):
        inp1, inp2 = ctx.saved_tensors
        grad_inp1 = None
        grad_inp2 = None
        if ctx.needs_input_grad[0]:
            print("Calculating grad_inp1")
            grad_inp1 = grad_out * inp2
        if ctx.needs_input_grad[1]:
            print("Calculating grad_inp2")
            grad_inp2 = grad_out * inp1
        return grad_inp1, grad_inp2

dtype = torch.float64
inp1 = torch.tensor(1.0, dtype=dtype).requires_grad_()
inp2 = torch.tensor(2.0, dtype=dtype).requires_grad_()

out = Mul.apply(inp1, inp2)
# grad_inp2 is not required
gradinp1 = torch.autograd.grad(out, inp1, create_graph=True)

The calculation of grad_inp2 is still performed in the example above. It seems that ctx.needs_input_grad[1] is False only if inp2.requires_grad == False; the flag is set directly from requires_grad when the node is created, as this snippet shows:

PyObject* needs_grad = tensor.requires_grad() ? Py_True : Py_False;
Py_INCREF(needs_grad);
PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad);
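
For example (my own check, reusing the Mul class above): the only way I found to make the second branch skip is to drop requires_grad on the tensor itself.

inp2_detached = inp2.detach()  # no longer requires grad
out2 = Mul.apply(inp1, inp2_detached)
torch.autograd.grad(out2, inp1, create_graph=True)
# prints only: Calculating grad_inp1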

@mfkasim1 reopened this Apr 21, 2021
@albanD
Collaborator

albanD commented Apr 21, 2021

Yes, that is expected, I am afraid.
This flag is only aware of the links that exist in the graph, not of the subset actually required by a particular autograd.grad call.

I am not sure if we have this information at all... We would need to investigate.

@mfkasim1
Contributor Author

@albanD There is a possible fix for this in Python, which I wrote up in colab: https://colab.research.google.com/drive/1wEgkI4IsUQBXUEtgFa-ai3yOwJnuvOGD?usp=sharing (you can skip the "problem" part).
In short, we need to check whether the variable has a dependency on the derived variables (i.e. the inputs passed to torch.autograd.grad).
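
To illustrate, here is a rough sketch of that dependency check (my own simplification, not the exact colab code; grad_is_needed and DERIVED_INPUTS are hypothetical names):

DERIVED_INPUTS = []  # hypothetical global: the inputs passed to torch.autograd.grad

def grad_is_needed(tensor, derived_inputs):
    # The gradient w.r.t. `tensor` can only matter if `tensor` is one of the
    # derived inputs or was itself computed from one of them, so walk its
    # autograd history looking for them.
    derived_ids = {id(t) for t in derived_inputs}
    if id(tensor) in derived_ids:
        return True
    seen, stack = set(), [tensor.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or id(fn) in seen:
            continue
        seen.add(id(fn))
        var = getattr(fn, "variable", None)  # AccumulateGrad nodes expose their leaf tensor here
        if var is not None and id(var) in derived_ids:
            return True
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return False

Inside backward, the check would then become something like if inp2.requires_grad and grad_is_needed(inp2, DERIVED_INPUTS): ...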

However, this requires a global variable to store the derived variables and an inefficient way of checking each variable's dependency on them.
Do you have any ideas for solving this?
If the first problem can be fixed, I think I can live with the second.

@albanD
Collaborator

albanD commented Apr 22, 2021

The problem is that, by design, we never have this information :/
In particular, the knowledge that a Node needs to be executed is local to that Node in the engine. And even if a Node is not going to get executed, you might want to compute the gradients w.r.t. its outputs (because you want to capture them with autograd.grad, for example). We have to do it this way because Tensors don't exist in the graph, only their grad_fn.

As an example, when you do torch.autograd.grad(out, inp1), we actually put a capture on the inputs of inp1.grad_fn even though inp1.grad_fn won't be executed.

The other issue is that the existing needs_input_grad flag is set at creation time (based on which inputs require grad). The current backward run has no effect on the content of this flag, and no API exists for the engine to tell the Node that, for this particular run, some inputs don't need gradients.

@soulitzer
Contributor

soulitzer commented Apr 22, 2021

Not sure if this would solve your issue, but we did have a PR that solves something similar: #52180. So there might be a way you could do this, though depending on your specific graph it might not always work. Let's say you have some a = expensive_calculation(x, y, z). If you don't want to execute a.grad_fn but a.grad_fn is part of your graph, i.e. you have b = a + c; out = loss(b), you need to make sure a.grad_fn has a topological_nr lower than that of any node you wish to obtain the grad of. What does this mean exactly?

You can think of the topological number as the length of the longest possible path between your node and any leaf node. So if you are a leaf node, your topo_nr is always 0. This is a valid topological ordering because we have the invariant that any parent's topo_nr is one greater than the max topo_nr of its children.

So to apply this to our previous example: if you did autograd.grad(out, (b,)), a.grad_fn wouldn't execute, because we know for sure from the properties of topo_nr that a.grad_fn.topo_nr < b.grad_fn.topo_nr. But if you also want the grad w.r.t. c, you'd need to make sure that c.grad_fn.topo_nr > a.grad_fn.topo_nr to ensure that a.grad_fn doesn't execute. A concrete, runnable version of this example follows below.
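
For illustration (my own stand-ins for expensive_calculation and loss, not code from the PR):

import torch

x, y, z, w = (torch.randn(3, requires_grad=True) for _ in range(4))

a = (x * y * z).exp()   # stand-in for expensive_calculation(x, y, z): a deeper branch
c = w * 2               # a shallow branch: c.grad_fn sits right above a leaf
b = a + c
out = b.sum()           # stand-in for loss(b)

# Only b is requested; a.grad_fn.topo_nr < b.grad_fn.topo_nr, so per the
# explanation above the engine can skip executing a.grad_fn.
g_b, = torch.autograd.grad(out, (b,), retain_graph=True)

# Asking for c as well: this c is very shallow, so c.grad_fn.topo_nr ends up
# below a.grad_fn.topo_nr and the guarantee above no longer applies;
# a.grad_fn may still be executed.
g_b, g_c = torch.autograd.grad(out, (b, c))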

@albanD
Collaborator

albanD commented Apr 22, 2021

My understanding is that there is the extra assumption here that a.grad_fn == b.grad_fn ;)
Note that the original example is a custom Function.

@soulitzer
Contributor

Oh right, I guess that is unrelated to the feature proposal. Since the motivation is making higher-order stuff more efficient, I figured this might be useful to know, though. :P

One thing we could do, though, is to enhance needs_input_grad with the information we gain from topo_nr, i.e., if any of the inputs have a topo_nr lower than min_topo_nr, set the boolean at that index to False.

@mfkasim1
Contributor Author

@soulitzer thanks for the info. It seems that I can use topo_nr to make the check more efficient (i.e. it helps solve problem number 2 from the Google colab). Is topo_nr available from the Python interface?

@albanD
Collaborator

albanD commented Apr 23, 2021

The most horrible way I can think of doing it is:

  • During the Node preparation, get the current graph task that we're running from:
    static thread_local std::shared_ptr<GraphTask> current_graph_task = nullptr;
  • From the graph task, if graph_task.exec_info_ does not exist, everything needs to be executed, so proceed as usual.
  • Otherwise, for every entry in "next_functions", look up that Node in graph_task.exec_info_ and check whether should_execute() is True.
  • This tells you whether, for that particular call to autograd.grad(), you need to compute the gradient for it or not (that gradient would be discarded by the engine in this case anyway).
  • For custom Functions, update needs_input_grad for that one run only (thread-safety issue here!).
  • For regular Nodes, we might want to update should_compute_output as well. That would allow us to trim graphs for codegen-defined functions as well :)

@soulitzer what do you think?

@soulitzer
Contributor

soulitzer commented Apr 23, 2021

@mfkasim1 No, not currently. Should it be, though? I'm thinking it should at least be exposed as grad_fn._topo_nr for debugging purposes, since how the topo_nr's of the nodes interact does have an impact on performance. @albanD what do you think?

@albanD (edit: realized I completely misread some code lol) updating needs_input_grad is more straightforward than I initially thought. But I still don't see any way around the thread-safety issue you pointed out, i.e., if multiple threads backward over the same node at the same time w.r.t. different inputs.

@MatthiasKohl

Has there been any progress on this issue? I'm currently running into it as well, and as of version 1.11 it still behaves the same way as described here.

pytorchmergebot pushed a commit that referenced this issue Aug 11, 2022
### Introduction

Removing unnecessary weight gradient calculation is very important for applications that need high-order derivatives during training. However, this is not supported by the current Autograd engine.

For more detail: the backward function of a `matmul` operator (e.g., `linear`, `addmm`, `mm`) has two matmuls, one for the `input gradient` and another for the `weight gradient`. For a typical neural network (nn) with a few linear layers and activation functions, if the user calls `torch.autograd.grad()` to calculate the derivative of the nn output `y` w.r.t. the nn input `x`, only the `input gradient` of the `matmul` operator is needed and the `weight gradient` is discarded. However, the current PyTorch autograd engine always calculates the `weight gradient` whenever `weight` requires gradient (which is the case when this higher-order derivative is computed during training).
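
For illustration (a minimal sketch of my own, not the code in this PR), here are the two matmuls in a linear-style backward, each guarded by `ctx.needs_input_grad`:

```py
import torch

class MyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return input @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = None
        if ctx.needs_input_grad[0]:      # input gradient: one matmul
            grad_input = grad_out @ weight
        if ctx.needs_input_grad[1]:      # weight gradient: the other matmul
            grad_weight = grad_out.t() @ input
        return grad_input, grad_weight
```

Before this change, `ctx.needs_input_grad[1]` is True whenever `weight.requires_grad` is True, so the weight matmul runs even when the surrounding `torch.autograd.grad(y, x, ...)` call never needs it.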

The figure attached shows the autograd graph of the following code snippet:
```py
y = torch.nn.functional.linear(x, weight, bias)
y = y.pow(2)
# first order derivative
y__x, = torch.autograd.grad(y, x, grad_outputs=grad_outputs, create_graph=True)
# second order derivative
y__x__x, = torch.autograd.grad(y__x, x, grad_outputs=grad_outputs, create_graph=True)
```
The path with ❌ is not needed when calculating derivatives.

<img width="50%" alt="image" src="https://user-images.githubusercontent.com/9999318/182018117-719c5a23-bcc6-4a63-8e8d-1bca3ebda2e3.png">

### Issue
Related issue: #56500

### Method
When calling `torch.autograd.grad`, an `exec_info_` map is created for each GraphTask, which allows filtering out the paths of the graph that are not needed. However, when the GraphTask calls into a node, the node still does not know whether its outgoing edges are needed or not. In the case of matmul, `weight.requires_grad` is True, so the weight gradient is always calculated.

Following #56500 (comment), this PR passes the graph task's thread-local `exec_info_` into the node, so that it can trim unnecessary edges during `torch.autograd.grad` calls.

### Benchmark
Benchmark script: https://gist.github.com/yueyericardo/24158433a2021c51eeef9c3e2722df99

Benchmark result:
6 hidden layers, batch size 10000, on A100

FP32 result
| hessian benchmark             | FP32 (before) | FP32 (After)      | FP32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 55.658 ms     | 29.392 ms (1.90X) | 29.547 ms (1.90X)       |
| Linear + ReLU (with backward) | 81.173 ms     | 54.917 ms (1.47X) | 68.988 ms (1.18X)       |

TF32 result
| hessian benchmark             | TF32 (before) | TF32 (after)      | TF32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 19.801 ms     | 11.259 ms (1.76X) | 10.754 ms (1.84X)       |
| Linear + ReLU (with backward) | 29.167 ms     | 20.466 ms (1.42X) | 22.784 ms (1.28X)       |

For FP32, we get a 1.9X speedup for the Hessian calculation and a 1.47X speedup during training, which is even faster than functorch's `vmap(jacfwd(jacrev(...)))` implementation. (functorch has a performance regression in v0.2.0, pytorch/functorch#989, so we use v0.1.1 for the benchmark.)
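
For reference, a minimal sketch of the two Hessian approaches being compared (not the benchmark script itself; the tiny `Tanh` model and shapes are placeholders of my own):

```py
import torch
import functorch

model = torch.nn.Sequential(torch.nn.Linear(3, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))

def f(inp):
    return model(inp).sum()

x = torch.randn(10, 3)

# autograd route: nested torch.autograd.grad calls (the path this PR speeds up)
x1 = x.clone().requires_grad_()
g, = torch.autograd.grad(f(x1), x1, create_graph=True)                # (10, 3)
rows = [torch.autograd.grad(g[:, i].sum(), x1, retain_graph=True)[0]  # row i of each per-sample Hessian
        for i in range(x1.shape[1])]
hess_autograd = torch.stack(rows, dim=1)                               # (10, 3, 3)

# functorch route: per-sample Hessian via vmap(jacfwd(jacrev(...)))
hess_functorch = functorch.vmap(functorch.jacfwd(functorch.jacrev(f)))(x)  # (10, 3, 3)
```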

@zou3519 does functorch also include similar optimizations during Hessian calculation? If not, what do we need to do so that functorch can also benefit from this PR?

### Testing

- [x] we need to figure out a way to unit-test this

### Thanks
Thanks for the great blog: [How Computational Graphs are Executed in PyTorch | PyTorch](https://pytorch.org/blog/how-computational-graphs-are-executed-in-pytorch/)

cc @zasdfgbnm @albanD
Pull Request resolved: #82544
Approved by: https://github.com/soulitzer
facebook-github-bot pushed a commit to pytorch/functorch that referenced this issue Aug 12, 2022
Summary:
(Same PR description as above.)

X-link: pytorch/pytorch#82544
Approved by: https://github.com/soulitzer

Reviewed By: seemethere

Differential Revision: D38643340

fbshipit-source-id: 346de0e0971363441c6d06dc83601e0297d5ccc8
facebook-github-bot pushed a commit that referenced this issue Aug 12, 2022
Summary:
(Same PR description as above.)

Pull Request resolved: #82544
Approved by: https://github.com/soulitzer

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/382ef1fda75dfce07c37a920e908ce96d99bf970

Reviewed By: seemethere

Differential Revision: D38643340

fbshipit-source-id: 346de0e0971363441c6d06dc83601e0297d5ccc8