
Jacobians computed by autograd.functional.jacobian with create_graph sometimes set requires_grad True #46918

Open
manuelbaum opened this issue Oct 27, 2020 · 9 comments
Labels
module: autograd Related to torch.autograd, and the autograd engine in general triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@manuelbaum

manuelbaum commented Oct 27, 2020

🐛 Bug

The Jacobians computed by torch.autograd.functional.jacobian sometimes have requires_grad=True when create_graph=True is given, but not always. It depends on the concrete function that is passed to the Jacobian computation.

To Reproduce

Steps to reproduce the behavior:

import torch

def f1(x):
    return x

def f2(x):
    return x*x

x0 = torch.tensor([4., 2., 2.])
print("x0.requires_grad", x0.requires_grad)
print("--- Returning just x with create graph ---")

J1 = torch.autograd.functional.jacobian(f1, x0, create_graph=True)
print("J1.requires_grad", J1.requires_grad)

print("--- Returning x*x with create graph ---")
x0 = torch.tensor([4., 2., 2.])
J2 = torch.autograd.functional.jacobian(f2, x0, create_graph=True)
print("J2.requires_grad", J2.requires_grad)

print("--- Returning x*x without create graph ---")
x0 = torch.tensor([4., 2., 2.])
J3 = torch.autograd.functional.jacobian(f2, x0)
print("J3.requires_grad", J3.requires_grad)

The output here is:

x0.requires_grad False
--- Returning just x with create graph ---
J1.requires_grad False
--- Returning x*x with create graph ---
J2.requires_grad True
--- Returning x*x without create graph ---
J3.requires_grad False

Expected behavior

In the example above, J2.requires_grad should be False.

Environment

PyTorch version: 1.8.0.dev20201027
Is debug build: True
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.17.4
[pip3] torch==1.8.0.dev20201027
[pip3] torchvision==0.8.0.dev20201026
[conda] Could not collect

Additional context

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved

@mrshenli mrshenli added module: autograd Related to torch.autograd, and the autograd engine in general triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Oct 27, 2020
@albanD
Collaborator

albanD commented Oct 29, 2020

Hi,

This is expected, as mentioned in the doc: "Note that when strict is False, the result can not require gradients or be disconnected from the inputs."
This is true for autograd in general: independent gradients can be represented either by a disconnected graph (and so a result that does not require gradients) or by a graph that returns a Tensor full of 0s or None gradients.
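
For illustration, a minimal sketch of these two representations (my own example, not from the thread; it just reuses the functions from the report):

import torch

x0 = torch.tensor([4., 2., 2.])

# Identity: the result comes back with no graph attached at all.
J1 = torch.autograd.functional.jacobian(lambda x: x, x0, create_graph=True)
print(J1.requires_grad, J1.grad_fn)  # reported as False, so grad_fn is None

# x*x: the double-backward graph is kept as a side effect of create_graph=True,
# so the result carries a grad_fn even though x0 itself does not require grad.
J2 = torch.autograd.functional.jacobian(lambda x: x * x, x0, create_graph=True)
print(J2.requires_grad, J2.grad_fn)  # reported as True, so grad_fn is not None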

@manuelbaum
Author

I agree that it is expected that the results should not require gradients, but the problem is that in this case they do! The unexpected behavior that the result has requires_grad=True is the bug I am referring to.

@albanD
Collaborator

albanD commented Oct 29, 2020

I am not sure what you expect. That the second example should have requires_grad=False because x0.requires_grad=False?
I am afraid this one falls into the second category of what I mentioned above: "a graph that returns a Tensor full of 0s or None gradients".

@manuelbaum
Author

As I wrote above, I expect that J2.requires_grad should be False, but it is True.

J3.requires_grad is indeed False, as expected. I just included that to show that the behavior depends on the concrete function used for the Jacobian computation.

@albanD
Collaborator

albanD commented Oct 30, 2020

As I wrote above, I expect that J2.requires_grad should be False, but it is True.

I am afraid this is expected, as an independent gradient can show up as "a graph that returns a Tensor full of 0s or None gradients".
And if we create a graph, then the Tensors will have requires_grad=True.

I do agree that this is not the state we would have in an ideal world, but we cannot detect why these Tensors require gradients (a side effect of the computation we do, or a Tensor used internally by your function that requires gradients), so we have to return them as is.

Note that in any case, the gradient will be "correct" (under the definition that no graph / None / Tensor full of 0s are the same thing).
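
As a side check (my own sketch, not from the thread), the extra graph does not change the values themselves; reusing f2 and x0 from the reproduction above:

import torch

def f2(x):
    return x * x

x0 = torch.tensor([4., 2., 2.])

# Same Jacobian values with and without the side-effect graph; the only
# difference is the attached graph / requires_grad flag.
J2 = torch.autograd.functional.jacobian(f2, x0, create_graph=True)
J3 = torch.autograd.functional.jacobian(f2, x0)  # create_graph defaults to False

print(torch.allclose(J2.detach(), J3))  # expected True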

@manuelbaum
Author

I'm sorry, but I don't really understand this. None of the input variables requires_grad, but the output of the jacobian computation does require_grad.

It's not a leaf variable, so I can't manually set requires_grad=False. But .detach() is not an option either, because this jacobian computation is part of a function that should itself be differentiable.

This is pretty annoying, because the (from my perspective) randomly occurring requires_grad flag leads to huge computation graphs and forces me to think very carefully about where to break gradient flow in a BPTT setting.

@albanD
Collaborator

albanD commented Nov 11, 2020

I'm sorry, but I don't really understand this. None of the input variables requires_grad, but the output of the jacobian computation does require_grad.

The problem here is that you set create_graph=True. The reason for this flag is to be able to differentiate through the computation of the jacobian. So the jacobian will be computed in a differentiable manner and will require gradients.

I am not sure I understand your sentence, as you say "None of the input variables requires_grad" but also "this jacobian computation is part of a function that should be differentiable itself". If no input requires gradients, then no backward needs to run for that jacobian, so you could just set create_graph=False?
If the Tensors you want to compute gradients for are captured by side effect during the jacobian computation itself (from the state of your model, maybe?), then you actually want the output jacobian to require gradients when create_graph=True. Again, if you don't want any gradient to be backpropagated through this jacobian (even though it uses some Tensors that require gradients by side effect), then you can just set create_graph=False.

Does this make it clearer?
If not, do you have a code sample of what you're trying to do and why it is problematic in your case?
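
A minimal sketch of the "captured by side effect" case described above (w below is a hypothetical parameter standing in for model state, not something from the thread):

import torch

w = torch.tensor([1., 2., 3.], requires_grad=True)  # hypothetical model state

def f(x):
    # w enters the computation without being an explicit input to jacobian
    return w * x * x

x0 = torch.tensor([4., 2., 2.])
J = torch.autograd.functional.jacobian(f, x0, create_graph=True)
print(J.requires_grad)  # True, and here that is exactly what we want

# Gradients w.r.t. w can flow through the Jacobian (expected to be 2 * x0 here).
grad_w, = torch.autograd.grad(J.sum(), w)
print(grad_w)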

@manuelbaum
Author

manuelbaum commented Nov 12, 2020

First of all, thanks a lot albanD for taking the time for this. I highly, highly appreciate that! :)

I understand what you write about the create_graph flag. I think your second paragraph helps get to the core of this. I am in fact not using backward() at the moment, because I'm not doing parameter learning yet. I am building the function / model that is supposed to be differentiable. This model includes various Jacobian computations "in its forward function" (e.g. it involves an extended Kalman filter).

I tried to create a minimal example to illustrate the problem:

import torch
import time

def f(x):
    return torch.sin(x*x)

def g(x, fun, create_graph):
    J = torch.autograd.functional.jacobian(fun, x, create_graph=create_graph)
    return torch.sigmoid(J.mv(x+1)) # some computation involving J

def h(x, fun, create_graph):
    J = torch.autograd.functional.jacobian(fun, x, create_graph=create_graph)
    return torch.tanh(J.mv(x*2)) # some computation involving J

T = 100

g_partial_with_create_graph = lambda x: g(x, f, True)

x = torch.tensor([1., 2., 3.])
print("x:",x)
print("x.requires_grad:", x.requires_grad)

for t in range(T):
    t0 = time.time()
    x = h(x, g_partial_with_create_graph, True) # create graph
    # x = x.detach()
    ### x = x.detach() would do the trick here, but you have to be extremely careful, because you can really only do that
    ### on the outmost code-level. if this code occurs inside of a function that's supposed to be differentiable, then
    ### gradient flow breaks
    t1 = time.time()

print("Duration of last iteration:", t1-t0)
print("x:",x)
print("x.requires_grad:", x.requires_grad)

The problem is that this code becomes extremely slow as the loop rolls out (and more so with more complex functions in reality). This is because the output of h requires grad (although, from the user's perspective, it is not a child of any other variable that requires grad). This leads to a large computation graph that spans all the time steps.

  • If I set create_graph=False anywhere, the result is wrong because the Jacobians are not computed correctly.
  • If I set create_graph=True, x.requires_grad spontaneously becomes True without any apparent reason from the user's perspective.
  • I can stop the requires_grad flag from "randomly spreading" by using detach() in the right place, but then I really need to know that I am at the outermost level of code that doesn't need a gradient to flow through. This is a problem for reusability.

@albanD
Collaborator

albanD commented Nov 13, 2020

Thanks for the code sample that makes things much clearer!

I think what you want to change is this line: x = h(x, g_partial_with_create_graph, True) to actually set create_graph=False. At this level, you don't actually want to create the graph. Only the jacobian call done by g_partial_with_create_graph needs create_graph=True, because that is the one being differentiated.
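
In other words (a paraphrase of the suggestion, reusing the names from the snippet above), only the outermost call changes:

# Outer level: no graph is needed because x does not require grad here.
# The jacobian call inside g_partial_with_create_graph keeps create_graph=True.
x = h(x, g_partial_with_create_graph, False)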

Another approach, if you know that you're in the case where you want to create the graph iff the input you give requires grad (an assumption that we cannot make in general, unfortunately :/ ), is to change all your calls to jacobian to use create_graph=x.requires_grad and remove the create_graph argument completely from your APIs:

def g(x, fun):
    J = torch.autograd.functional.jacobian(fun, x, create_graph=x.requires_grad)
    return torch.sigmoid(J.mv(x+1)) # some computation involving J

def h(x, fun):
    J = torch.autograd.functional.jacobian(fun, x, create_graph=x.requires_grad)
    return torch.tanh(J.mv(x*2)) # some computation involving J

That will simplify your life quite a bit, I think.
But be careful never to end up with something that requires grad and is not in x; otherwise, you won't get gradients for it!
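
As a rough usage sketch (my own check, not from the thread), combining these g/h variants with the f and loop from the earlier snippet, the graph no longer accumulates when the initial x does not require grad:

x = torch.tensor([1., 2., 3.])
for t in range(100):
    # The outer jacobian inside h runs with create_graph=False (x does not
    # require grad), while the inner jacobian inside g still gets
    # create_graph=True because the outer jacobian feeds it an input that does.
    x = h(x, lambda y: g(y, f))
print(x.requires_grad)  # expected False: no graph is kept across iterations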
