[BUG] Wrong gradient for slicings if parent nodes include undefined gradients #9688
Comments
Just wanted to second this issue. For my use case (Kalman filters), the idea is that some of our data may be missing and that our Kalman filter can handle this naturally. When computing the loss, it wasn't obvious that these two statements would have such different effects:
|
This is something worth digging into. In principle, we should give you the behavior you are expecting. @albanD, do you think there are any interesting solutions we could come up with at the autograd engine level? |
From my point of view, this is the expected behavior and the correct answer. Here [0, 1] is the correct gradient for the indexing op: since the first input is not used, it gets 0, and the second gets 1. In your case, I guess you should always do the selection first? What is the point of doing ops just to ignore them afterwards? |
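albanD's point about the indexing op can be checked directly: unselected entries get gradient 0, selected entries get 1. A minimal sketch (my own values, not from the thread):

```python
import torch

# Gradient of an indexing op: the unused entry gets 0, the used one gets 1.
x = torch.tensor([3.0, 5.0], requires_grad=True)
y = x[1]          # f(x_1, x_2) = x_2
y.backward()
grad = x.grad
print(grad)       # tensor([0., 1.])
```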
I'm closing the issue based on Alban's response and on reading the linked thread. |
Just as a data point: I believe that if we returned sparse tensors for the gradients of indexing ops, we could avoid the nan by simply operating on the non-sparse elements. Whether that's a good idea, I'm less sure. |
@fmassa but that is not the correct gradient. For a sparse tensor, the elements that are not specified are 0s. And so their gradient for the sqrt op should be nan. |
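To see concretely why those zero elements would still produce nan through sqrt, consider the `(x**2).sqrt()` composition from this thread (a small sketch with my own values):

```python
import torch

# sqrt alone at 0: the backward is 1/(2*sqrt(0)) = inf
x = torch.tensor([0.0], requires_grad=True)
x.sqrt().backward()
print(x.grad)          # tensor([inf])

# composed (x**2).sqrt() at 0: the chain rule gives inf * 0 = nan
x2 = torch.tensor([0.0], requires_grad=True)
(x2 ** 2).sqrt().backward()
print(x2.grad)         # tensor([nan])
```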
@soumith, @albanD example
If we exclude coordinates of … — a formal justification (@albanD) (this is less verbose than I would like it to be, but math in markdown without MathJax is a pain in the ...). First of all, what is your argument that the correct gradient of the mapping is …? I suggest (as a convention) to use …. Assigning … — from this perspective I think this is an important issue, and as PyTorch is converging to its 1.0 release, one should maybe change the convention, or at least document it with all its side effects.

Conclusion

Are there any reasons why we should prefer …? |
Hey, the current specification is the following: for every elementary Function that is implemented, what is returned as the gradient of the output w.r.t. the input is defined as follows, where each … is checked in this order:

When the user defines a function that is a combination of these elementary Functions, the chain rule is used to get the full gradient. If you have …

Your conclusion

So actually the gradient that is used can only be nan if the forward function was giving nan as well.
```python
import torch

def f(x):
    x1 = x.clone()
    x1[x1 < 0] = 0
    # x1 contains all the positive values
    x2 = x.clone()
    x2[x2 > 0] = 0
    # x2 contains all the negative values
    return (x1**2).sqrt() - (x2**2).sqrt()

inp = torch.tensor([-2., -1., 0, 1, 2], requires_grad=True)
out = f(inp)
print("out", out)  # [-2., -1., 0, 1, 2]
res = torch.autograd.grad(out.sum(), inp)[0]
print(res)  # [1., 1., nan, 1., 1.]
```

It's a weird way to write an identity, but that's ok. At the moment, if you ask for the gradient at 0 of this function, you will get …. Basically, the chain rule allows you to shoot yourself in the foot if you try hard enough, and will make differentiable functions impossible to differentiate.

What causes the nan in your case

In your example, the problem is that if you write:
Your example is easy to fix

If, for your code sample, you change:

```python
x = x.sqrt()
x = x[x > 0]
```

to:

```python
x = x[x > 0]
x = x.sqrt()
```

then you get the gradient you expect, at no cost! It will actually be faster, because instead of doing the sqrt op on the whole tensor before discarding part of it, you will only do it on the part that you care about. |
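The effect of the two orderings can be seen in a couple of lines (illustrative values of my own, not the original poster's code):

```python
import torch

# op first, select after: the discarded nan poisons the gradient
x = torch.tensor([-1.0, 4.0], requires_grad=True)
x.sqrt()[1].backward()
g_after = x.grad          # tensor([nan, 0.2500])

# select first, op after: the bad element never enters sqrt
x = torch.tensor([-1.0, 4.0], requires_grad=True)
x[1].sqrt().backward()
g_first = x.grad          # tensor([0.0000, 0.2500])
print(g_after, g_first)
```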
Is there anywhere in the documentation where it's stated that these two options, like those in your example (or in mine), will have such different behavior? As it is set up right now, it feels very much like a "gotcha." The intuition it contradicts is something like:
I apologize for not stating this intuition more formally, but can you help me understand where this intuition gets things wrong? |
The "exclusion" operation, when you backpropagate through it, will give 0 gradient for the elements that were excluded. If you look at the function f: (x, y) -> y, its gradient is df/dx = 0. Now the question is how …. To check that, you can print the gradients returned by your loss function by registering hooks on … |
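Registering a hook to inspect the gradient flowing into an intermediate tensor might look like this (a sketch with my own values; the hook fires during the backward pass with the gradient w.r.t. that tensor):

```python
import torch

seen = []
x = torch.tensor([-1.0, 4.0], requires_grad=True)
y = x.sqrt()
# Record the gradient w.r.t. y as it flows past during backward
y.register_hook(lambda grad: seen.append(grad.clone()))
y[1].backward()
print(seen[0])   # tensor([0., 1.]) -- the gradient of the indexing op
```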
Here's a very simple example to make this more concrete:
I'm a little unclear on what I'm looking for here. Again, the intuition is still that I've excluded the invalid gradients when constructing the …. Apologies for my confusion. |
Your |
@albanD Thanks for the verbose update :) It made things much clearer. So in conclusion, we can say that the mentioned behavior is a design choice when implementing a recursive differentiation framework. It produces (from my point of view) unintuitive behavior regarding "exclusion", but I guess that's the price to pay for intuitive behavior in other cases, e.g., your example.

Some thoughts on your solution ...

With recursive differentiation in mind, it's clear that exchanging

…

by

…

solves the issue. Would it be reasonable to propose an "exclusion" functionality where, locally, the multiplicative algebra is …? This is just a thought and may not be important, or a huge pain for maybe little gain. --cheers chris |
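One way to get the "excluded elements contribute zero" behavior without touching the engine is to substitute a safe dummy value before the non-differentiable op, so no nan/inf is ever produced in either pass. A sketch using `torch.where` (the function name and values are mine, not from the thread):

```python
import torch

def masked_sqrt(x):
    mask = x > 0
    # Replace excluded inputs by a safe value BEFORE sqrt, so neither
    # the forward nor the backward pass ever produces nan/inf there.
    safe = torch.where(mask, x, torch.ones_like(x))
    # Zero out the excluded positions in the output.
    return torch.where(mask, safe.sqrt(), torch.zeros_like(x))

x = torch.tensor([-1.0, 4.0], requires_grad=True)
masked_sqrt(x).sum().backward()
print(x.grad)   # tensor([0.0000, 0.2500])
```

Note that masking only the output (a single `torch.where` after `x.sqrt()`) is not enough: the backward of sqrt still produces nan at the excluded points, and 0 * nan = nan.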
Hi, after some discussion with other people, I think we found a simple answer to the problem: "The chain rule only works for differentiable functions, and you're working at a point where sqrt is not differentiable. Hell ensues!". In your case the problem is not that …

```python
# I did not run this code, there might be typos in it, sorry
# Your model
x = f_theta(x)

def my_excluding_functions(x_input):
    # Compute the function
    x = x_input.sqrt()
    # Compute the exclusion mask
    mask = x > 0
    # Mask the output
    output = x[mask]

    # Make sure the backward pass won't return nan.
    # We use a hook to modify the gradient computed for
    # grad_input and make sure it won't consider masked elements.
    def mask_hook(grad):
        # Never modify the input given by a hook
        grad = grad.clone()
        # Set the gradients (don't multiply) of the part we did not keep to 0
        grad[~mask] = 0
        return grad
    x_input.register_hook(mask_hook)
    return output

# Use the function that fixes the gradients.
# The forward will be the same as:
#   x = x.sqrt()
#   x = x[x > 0]
# but the backward will mask out gradients so that,
# whatever is returned by the first function, it will
# be zeroed out.
x = my_excluding_functions(x)
```

Note that if the function is not elementwise and changes the input size, you can compute an input mask and an output mask: one to mask out the output during the forward pass, and one to mask out gradients of the input during the backward pass. Hope this helps you do what you want! |
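As a quick sanity check of the masking-hook pattern above (a self-contained sketch with my own test values, for the elementwise case):

```python
import torch

def excluding_sqrt(x_input):
    x = x_input.sqrt()
    mask = x > 0          # nan > 0 is False, so bad elements are excluded
    def mask_hook(grad):
        grad = grad.clone()
        grad[~mask] = 0   # zero out (don't multiply): kills any nan
        return grad
    x_input.register_hook(mask_hook)
    return x[mask]

inp = torch.tensor([-1.0, 4.0], requires_grad=True)
excluding_sqrt(inp).sum().backward()
print(inp.grad)   # tensor([0.0000, 0.2500]) -- no nan
```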
@albanD: those explanations are very helpful-- thanks for your patience! |
@albanD The proposed solution does not solve the proposed problem (unless I've misunderstood something completely). To be clear, an example: …

In my understanding, the proposed pattern solves the issue if we have detailed information about …

I totally understand that technically this may cause problems, because mathematics usually does not care about its implementation ;). Comment: if one thinks of batch-wise application, the problem is maybe more obvious. Let … |
Issue description
If we have a calculation graph where a node includes values with undefined gradients and we de-select
those values, the resulting gradient is wrong.
Let f be a function not dependent on some input variable x; then the corresponding differential is df/dx = 0. E.g., let f(x_1, x_2) = x_2; then df/dx_1 = 0.

However, the following code snippet does not comply with this behavior:
Code example
Which gives ...
but I'd expect
Additional info
If we select first and then do the operation which could lead to an undefined gradient
it works as expected:
which gives
However, in the use case where this occurs, that is not an option. It is crucial that I can prune the coordinates of the output to those where the gradient is defined, and then use just those in the differentiation.
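The original code example is not reproduced above; the following is an illustrative sketch of the described behavior, with my own values (not the reporter's snippet):

```python
import torch

# Reported behavior: de-selecting AFTER the op still yields nan
x = torch.tensor([-1.0, 4.0], requires_grad=True)
y = x.sqrt()                 # y[0] is nan; we de-select it below
y[y > 0].sum().backward()
print(x.grad)                # tensor([nan, 0.2500]), expected tensor([0., 0.2500])

# Selecting first works as expected
x2 = torch.tensor([-1.0, 4.0], requires_grad=True)
x2[x2 > 0].sqrt().sum().backward()
print(x2.grad)               # tensor([0.0000, 0.2500])
```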
System Info