[BUG] Wrong gradient for slicings if parent nodes include undefined gradients #9688

Closed · c-hofer opened this issue Jul 22, 2018 · 16 comments


c-hofer commented Jul 22, 2018

Issue description

If we have a computation graph where a node includes values with undefined gradients and we de-select
those values, the resulting gradient is wrong.
Let f be a function that does not depend on some input variable x; then the corresponding differential is
df / dx = 0.
E.g., let f(x_1, x_2) = x_2; then df / dx_1 = 0.

However, the following code snippet does not comply with this behavior:

Code example

import torch

point = torch.tensor([0, 1], requires_grad=True, dtype=torch.float)

x = point
x = x.sqrt()
l = x[1]

l.backward()

point.grad

Which gives ...

tensor([   nan, 0.5000])

but I'd expect

tensor([ 0.0000, 0.5000])

Additional info

If we select first and then perform the operation that could lead to an undefined gradient,
it works as expected:

import torch

point = torch.tensor([0, 1], requires_grad=True, dtype=torch.float)

x = point
x = x[1]
x = x.sqrt()

l = x
l.backward()

point.grad

which gives

tensor([0.0000, 0.5000])

However, in the use case where this occurs, that is not an option. It is crucial that I can
prune the coordinates of the output to those where the gradient is defined and then use just those
in the differentiation.

System Info

PyTorch version: 0.5.0a0+3799b10
Is debug build: No
CUDA used to build PyTorch: 9.2.148

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.11.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.2.148
GPU models and configuration: 
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
GPU 2: GeForce GTX 1080 Ti
GPU 3: GeForce GTX 1080 Ti

Nvidia driver version: 396.37
cuDNN version: Probably one of the following:
/usr/local/cuda-9.2/targets/x86_64-linux/lib/libcudnn.so
/usr/local/cuda-9.2/targets/x86_64-linux/lib/libcudnn.so.7
/usr/local/cuda-9.2/targets/x86_64-linux/lib/libcudnn.so.7.1.4
/usr/local/cuda-9.2/targets/x86_64-linux/lib/libcudnn_static.a

Versions of relevant libraries:
[pip] numpy (1.14.5)
[pip] torch (0.5.0a0+3799b10)
[conda] magma-cuda90              2.3.0                         1    pytorch
[conda] torch                     0.5.0a0+3799b10           <pip>

jwdink commented Sep 22, 2018

Just wanted to second this issue. For my use case (Kalman filters), the idea is that some of our data may be missing and that our Kalman filter can naturally handle this. When computing the loss, it wasn't obvious that these two statements would have such different effects:

# leads to unexpected nan gradients:
loss = pointwise_lossfun(pred, actual)
loss[~torch.isnan(loss)].mean().backward()

# works:
loss = pointwise_lossfun(pred[~torch.isnan(actual)], actual[~torch.isnan(actual)])
loss.mean().backward()

soumith (Member) commented Nov 2, 2018

This is something worth digging into. Principally speaking, we should bring you the behavior you are expecting.
The very likely reason for this behavior is that we don't handle masked gradients separately from zeroed gradients. Hence, when we send a grad_output of [0, 1], grad_sqrt(0) turns into nan.

@albanD do you think there are any interesting solutions we can come up with at the autograd engine level?
One thing that seems easy to do for simple pointwise cases is to mask grad_input as well. But that doesn't work for cases where grad_input starts changing shape.
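
A minimal sketch of the mechanics described above (my own illustration, not from the original thread): backpropagating through sqrt with an explicit grad_output of [0, 1] reproduces the nan.

import torch

# The masked-out entry receives a grad_output of 0, but sqrt's local
# derivative at 0 is inf, and 0 * inf = nan.
x = torch.tensor([0.0, 1.0], requires_grad=True)
y = x.sqrt()
grad, = torch.autograd.grad(y, x, grad_outputs=torch.tensor([0.0, 1.0]))
print(grad)  # tensor([nan, 0.5000])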

albanD (Collaborator) commented Nov 2, 2018

From my point of view, this is the expected behavior and the correct answer.
This is the same discussion as this one. The norm accepts a subgradient of 0 at 0, but the decomposed version that takes the square root of the sum of squares will return nan.

Here [0, 1] is the correct gradient for the indexing op: since the first input is not used, it gets 0 and the second gets 1.
The composition with sqrt will then give you nan for the first entry.
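
The indexing gradient in isolation (a quick check of my own, not part of the original comment):

import torch

# Differentiating x[1] with respect to x: the unused entry gets gradient 0.
x = torch.tensor([0.0, 1.0], requires_grad=True)
print(torch.autograd.grad(x[1], x))  # (tensor([0., 1.]),)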

If you had a combined Function that did both ops (similar in spirit to the norm function above), then yes, you could return [0, 1]. But for the non-merged functions, this is the correct result.

In your case, I guess you should always do the selection first? What is the point of doing ops just to ignore them afterwards?

soumith (Member) commented Nov 2, 2018

I'm closing the issue based on Alban's response and reading the linked thread.

soumith closed this as completed Nov 2, 2018
fmassa (Member) commented Nov 2, 2018

Just as a data point, I believe that if we returned sparse tensors for the gradients of indexing ops, we could avoid the nan by simply operating on the non-sparse elements. Whether it's a good idea to do this, I'm less sure.

albanD (Collaborator) commented Nov 2, 2018

@fmassa but that is not the correct gradient. For a sparse tensor, the elements that are not specified are 0s. And so their gradient for the sqrt op should be nan.
I agree that this is not convenient :D But if you want correct gradients, you don't really have a choice.

c-hofer (Author) commented Nov 3, 2018

@soumith, @albanD
I'm sorry that I may have stated the problem confusingly. My concern is NOT the gradient for sqrt(0).
It's just the symptom of a deeper issue with PyTorch's interpretation of the differential operator. The following is a little lengthy and may not seem of practical value, hence I will motivate it with an example ;)

example

import torch
import torch.nn as nn

class ParametrizedFunction(nn.Module):
    def __init__(self):
        super().__init__()
        self.theta = nn.Parameter(torch.tensor(1.0))
        
    def forward(self, x):
        return x - self.theta.expand_as(x)

# that's what I want to do ... (case 1)
f_theta = ParametrizedFunction()
x = torch.tensor([float(i) for i in range(10)])
x = f_theta(x) # x == tensor([-1.,  0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.])
x = x.sqrt()
x = x[x > 0]
x = x.sum()
x.backward()

f_theta.theta.grad # tensor(nan) -> I did not expect that!

# that's what I have to do to get what I want to do ... (case 2)
f_theta   = ParametrizedFunction()
x_initial = torch.tensor([float(i) for i in range(10)])

x    = f_theta(x_initial)
x    = x.sqrt()
mask = (x > 0)

x = x_initial[mask]
x = f_theta(x)  # this is redundant and may be expensive!
x = x.sqrt()
x = x.sum()
x.backward()

f_theta.theta.grad # tensor(-2.1857) -> now I'm happy but not efficient ;)

If we exclude coordinates of f_theta(x_initial) that yield an undefined gradient, internally the gradient is nevertheless updated with the value nan.
My concern is that the value should be 0, as the above example would then also work in the first case.

formal justification (@albanD)

(this is less formal than I would like it to be, but math in markdown without MathJax is a pain in the ... )

First of all, what is your argument that the correct gradient of the mapping is nan? The derivative of the mapping f(x) = sqrt(x) at the value 0 in its first (and only) coordinate is simply not defined. This does not mean its value is nan.
From my perspective, PyTorch uses the convention that the symbol nan is assigned if the differential operation is not defined.
Assigning symbols to otherwise undefined values is common practice in mathematics. For example, in measure theory it is natural to have the symbol inf, which integrates algebraically into the real numbers by setting inf + x = x + inf = inf (inf - inf is not defined). The motivation for this convention is to formally simplify things.
However, this practice of assigning symbols to undefined values usually tries to ensure that those values "behave naturally" (see the example from measure theory above).

I suggest (as a convention) using 0 as the value for undefined differential operations instead of nan,
at least after slicing operations, as nan yields somewhat "unnatural" behaviour, which is nicely reflected in
the code example I stated at the very top.
To be precise, what is happening above is that we contradict the common convention that
df / dx = 0 if f does not depend on x.

Assigning 0 to undefined differential operations would comply with this behaviour if, e.g., f'(0) is not defined but f(0) is defined.
As we automatically sum the gradients in the grad property, nan is a somewhat unfortunate choice, as x + nan = nan. I totally understand that from an implementation perspective this may be convenient since, e.g., we get an integrated check of whether one of the involved differentiation operations is not defined. However, it has the mentioned side effects, which are from my point of view unwanted.
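
A small demonstration of this accumulation effect (my own hypothetical example): the same leaf is used on two paths, and the nan contribution from one path absorbs the finite contribution from the other.

import torch

# The sqrt path contributes [0*inf, 0.5] = [nan, 0.5] to x.grad; the direct
# x[0] path contributes [1, 0]; the accumulated .grad stays nan since nan + 1 = nan.
x = torch.tensor([0.0, 1.0], requires_grad=True)
(x.sqrt()[1] + x[0]).backward()
print(x.grad)  # tensor([nan, 0.5000])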

From this perspective I think this is an important issue, and as PyTorch is converging to its 1.0 release, one should maybe change the convention or at least document it with all its side effects.
I'm a mathematician, and at least for me this behavior was unexpected and did not feel natural.
However, I am not deep into the PyTorch sources and do not know all the implications.

Conclusion

Are there any reasons why we should prefer nan over 0 for undefined differential operations that
are stronger than the drawbacks (for example, the redundant call in the code example)?

albanD (Collaborator) commented Nov 3, 2018

Hey,
This answer grew more than I expected. It contains the current behavior that we give to the autograd (slightly different from what you were thinking), my point of view on your conclusion, the details of why you get nan in your case, and a simple fix that gives you the right gradients at no cost (actually faster than the version that gives nan).

The current specification is the following:

For every elementary Function that is implemented, what is returned as the gradient of the output wrt the input is defined as follows, where each rule is checked in this order (a quick check in code follows the list):

  • If the forward function is not defined (sqrt(-1) or log(-1) for example), then the gradient is undefined. Most functions will return nan, but for performance reasons some functions will return non-nan values (log for example).
  • If a gradient exists at the current point, return it.
  • If the function is convex, return one of its subgradients at the current point (the absolute value at 0, for example).
  • Define the gradient at the current point by continuity (note that inf is possible here); if two values are possible, pick one arbitrarily (for example clamp or sqrt).
  • No other rule is needed, as all of our functions are differentiable almost everywhere.
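
A quick check of these rules (my own sketch, assuming the behavior described above):

import torch

x = torch.tensor(0.0, requires_grad=True)
print(torch.autograd.grad(x.abs(), x))   # (tensor(0.),): convex, a subgradient at 0
y = torch.tensor(0.0, requires_grad=True)
print(torch.autograd.grad(y.sqrt(), y))  # (tensor(inf),): defined by continuity, inf possible
z = torch.tensor(-1.0, requires_grad=True)
print(torch.autograd.grad(z.sqrt(), z))  # (tensor(nan),): forward undefined -> nan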

When the user defines a function that is a combination of these elementary Functions, the chain rule is used to get the full gradient. If you have f(x) = g(h(x)), the engine cannot consider the properties of df/dx, as it does not know about f; it will only work with dg/dh and dh/dx, where g and h are elementary functions.

Your conclusion

So actually the gradient that is used can only be nan if the forward function was giving nan as well.
In your case, the chain rule makes a nan appear because the gradient of the input is 0*inf (see the next point).
I don't think we want to give 0 for all undefined operations, as:

  • I think it makes sense to define the gradient by continuity if there is a single point where the function is not differentiable.
  • Forcing the gradient to be 0 does not make sense in all cases, even though it does in yours. Just looking at sqrt at 0, consider the sample below:
import torch
def f(x):
    x1 = x.clone()
    x1[x1<0] = 0
    # x1 contains all the positive values
    x2 = x.clone()
    x2[x2>0] = 0
    # x2 contains all the negative values
    return (x1**2).sqrt() - (x2**2).sqrt()

inp = torch.tensor([-2., -1., 0, 1, 2], requires_grad=True)
out = f(inp)
print("out", out) # [-2., -1., 0, 1, 2]

res = torch.autograd.grad(out.sum(), inp)[0]
print(res)  # [1., 1., nan, 1., 1.]

It's a weird way to write the identity, but that's ok. At the moment, if you ask for the gradient of this function at 0, you will get nan, meaning that given the implementation and using the chain rule, we cannot compute the gradient at this point.
If we now set the gradient of the square root at 0 to 0, then this function will have a gradient of 0 at 0. For an identity function you expect 1.
At the moment, in the worst case, we return nan for something that could have had a proper gradient. But we always return correct gradients when the result is not nan. With your proposition, the autograd engine could return a well-formed gradient value that is wrong (in this example, 0 instead of 1).

Basically the chain rule allows you to shoot yourself in the foot if you try hard enough and will make differentiable functions impossible to differentiate.
Unfortunately, I don't see a good way to reduce this problem while still always returning valid gradients whenever we return one (i.e., not nan).

What causes the nan in your case

In your example, the problem is that you write y = x.sqrt() and out = y[y>0]. Then, for the values of x <= 0, y will be 0 or nan, and so the gradient of y (with respect to a given quantity that I will call loss) will be 0 for these entries. The chain rule for the gradient of x will then give you dloss/dx = dloss/dy * dy/dx, which is (a quick check in code follows the list):

  • 0 * nan if x < 0
  • 0 * inf if x == 0
  • dloss/dy * 1 / (2 * sqrt(x)) if x > 0
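
Checking the three cases (my own sketch):

import torch

x = torch.tensor([-1.0, 0.0, 4.0], requires_grad=True)
y = x.sqrt()
y[y > 0].sum().backward()
print(x.grad)  # tensor([nan, nan, 0.2500]): 0*nan, 0*inf, and 1/(2*sqrt(4))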

Your example is easy to fix

If you change:

x = x.sqrt()
x = x[x > 0]

to

x = x[x > 0]
x = x.sqrt()

For your code sample, you then get the gradient you expect at no cost! (For sqrt the mask is unchanged by the swap: x > 0 iff sqrt(x) > 0, and nan compares as False.) It will actually be faster, because instead of doing the sqrt op on the whole tensor before discarding part of it, you do it only on the part that you care about.
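
A self-contained version of the fixed code (my own sketch, inlining the ParametrizedFunction from the earlier comment):

import torch
import torch.nn as nn

class ParametrizedFunction(nn.Module):
    def __init__(self):
        super().__init__()
        self.theta = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        return x - self.theta.expand_as(x)

f_theta = ParametrizedFunction()
x = f_theta(torch.tensor([float(i) for i in range(10)]))
x = x[x > 0]   # select first ...
x = x.sqrt()   # ... then apply the op whose gradient can be undefined
x.sum().backward()
print(f_theta.theta.grad)  # tensor(-2.1857): no nan, no redundant forward pass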

jwdink commented Nov 3, 2018

Is there anywhere in the documentation where it's stated that these two options, like those in your example (or in mine), will have such different behavior? As it is set up right now, it feels very much like a "gotcha". The intuition it contradicts is something like:

  • When I perform an element-wise operation, I expect to have element-wise gradients.
  • If I later exclude the invalid elements before aggregating, then I have excluded those invalid gradients. The result of my aggregation should have a valid gradient.

I apologize for not stating this intuition more formally, but can you help me understand where this intuition gets things wrong?

albanD (Collaborator) commented Nov 3, 2018

@jwdink

If I later exclude the invalid elements before aggregating, then I have excluded those invalid gradients.

The "exclusion" operation, when backpropagating through it, will give 0 gradient for the elements that were excluded. If you look at the function f: (x, y) -> y, its gradient is df/dx = 0.

Now the question is how pointwise_lossfun will handle the backprop phase for an element that was nan initially and has a 0 grad_output.
I expect that you get nans in your case because the backward of pointwise_lossfun will involve the original output (which is nan) and will most certainly give a grad_input that is nan for those elements.
This nan will then propagate through your network.

To check that, you can print the gradients returned by your loss function by registering hooks on pred and actual.

jwdink commented Nov 3, 2018

Here's a very simple example to make this more concrete:

import torch
from torch import nn
from numpy import nan

# simple model for example:
class InterceptModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.intercept = nn.Parameter(torch.tensor(1.0))
        
    def forward(self, x):
        return self.intercept.expand_as(x)

# loss fun:
def pointwise_mse(pred, actual):
    return torch.pow(pred - actual, 2)

# the target, has some missing values:
actual = torch.randn(5) + 10.
actual[1] = nan
# actual.register_hook(lambda x: print(('actual',x))) 
# RuntimeError: cannot register a hook on a tensor that doesn't require gradient

# set up my model, generate predictions:
my_model = InterceptModel()
pred = my_model(actual)
pred.register_hook(lambda x: print(('pred',x)))

# this is an elementwise function
mses = pointwise_mse(pred, actual)

# now i'm excluding the invalid gradients:
loss = mses[actual==actual].mean()
loss.register_hook(lambda x: print(('loss',x)))

# why didn't excluding the invalid gradients work?
loss.backward()
my_model.intercept.grad # returns `tensor(nan)`

I'm a little unclear what I'm looking for here.

Again, the intuition is still that I've excluded the invalid gradients when constructing the loss scalar, so calling backward on this scalar shouldn't involve these gradients. It's not clear to me why these nans should propagate in that case.

Apologies for my confusion.

albanD (Collaborator) commented Nov 3, 2018

Your loss is 1 (i.e., dloss/dloss = 1), so mses' gradient is tensor([0.2500, 0.0000, 0.2500, 0.2500, 0.2500]).
Now, the formula for the mean squared error gradient, if you write o = mse(x, y), is
do/dx = 2*(x-y). As you can see, if y is nan for an entry, then pred.grad will be nan for that entry (you can see this in your hook for pred).
Then your model is doing o = model(p), where you repeat p with the expand_as. So do/dp = 1 (note that 1 is a vector here), and so dloss/dp = dloss/do * do/dp = sum(pred.grad). So if any element of pred.grad is nan, the gradient will be nan, because nan + number = nan.
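
A numeric check of this walkthrough (my own sketch, standalone rather than reusing the model above):

import torch

# The excluded entry still receives a grad_output of 0, but
# d(mse)/d(pred) = 2*(pred - actual) is nan there, and 0 * nan = nan;
# the expand_as backward then sums that nan into the scalar parameter's grad.
pred = torch.full((5,), 1.0, requires_grad=True)
actual = torch.tensor([10., float('nan'), 10., 10., 10.])
mses = (pred - actual) ** 2
mses[actual == actual].mean().backward()
print(pred.grad)  # tensor([-4.5000, nan, -4.5000, -4.5000, -4.5000])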

c-hofer (Author) commented Nov 3, 2018

@albanD Thanks for the detailed update :) It made things much clearer. So in conclusion, we can say that the mentioned behavior is a design choice when implementing a recursive differentiation framework. It produces (from my point of view) unintuitive behavior regarding "exclusion", but I guess that's the price to pay for intuitive behavior in other cases, e.g., your example.

Some thoughts on your solution ....

With recursive differentiation in mind, it's clear that replacing

x = x.sqrt()
x = x[x > 0]

with

x = x[x > 0]
x = x.sqrt()

solves the issue.
However, one could be in a situation where the latter is not an option, as we have no access to
x before the sqrt operation. For example, this could be a distance matrix which is handed to some functionality which then "excludes" unwanted "coordinates".
Or, more generally, you get a tensor D, and some algorithm, alg, uses D but can, based on D's pointwise values, decide which parts have defined gradients. In the current framework you need detailed information on D to be sure that autograd works; in terms of encapsulation this may be a disadvantage.
(This case is not hypothetical, I am currently in a similar situation ;) )

Would it be reasonable to propose an "exclusion" functionality where, locally, the multiplicative algebra is
changed such that nan * 0 = 0?
In this case, alg could decide solely on the values of D which index selection should be applied.

This is just a thought and may not be important, or a huge pain for maybe little gain.

--cheers chris

albanD (Collaborator) commented Nov 4, 2018

Hi,

After some discussion with other people, I think we found a simple answer to the problem: "The chain rule only works for differentiable functions, and you're working at a point where sqrt is not differentiable. Hell ensues!"
We will write all of that down explicitly in the docs to reduce such problems.

In your case the problem is not that nan * 0 = nan; only inf * 0 = nan.
Anyway, the proper way to deal with "I don't want the gradient that the chain rule gives" is hooks.
In your case, something like this should fix any use case you encounter:

# I did not run this code, there might be typos in it, sorry
# Your model
x = f_theta(x)

def my_excluding_functions(x_input):
    # Compute the function
    x = x_input.sqrt()
    # Compute the exclusion mask
    mask = x>0
    # Mask the output
    output = x[mask]
    # Make sure the backward pass won't return nan
    # We use a hook to modify the gradient computed for
    # grad_input and make sure it won't consider masked elements
    def mask_hook(grad):
        # Never modify the input given by a hook
        grad = grad.clone()
        # Set the gradients (don't multiply) of the part we did not keep to 0
        grad[~mask] = 0
        return grad
    x_input.register_hook(mask_hook)
    
    return output

# Use the function that fixes the gradients.
# The forward will be the same as:
# x = x.sqrt()
# x = x[x > 0]
# But the backward will mask out gradients so that, whatever
# gradient is computed for the excluded entries, it will
# be zeroed out.
x = my_excluding_functions(x)
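
A quick usage check of this pattern (my own hypothetical example, assuming my_excluding_functions as defined above):

import torch

theta = torch.tensor(1.0, requires_grad=True)
inp = torch.arange(10.) - theta           # plays the role of f_theta(x) above
out = my_excluding_functions(inp)
out.sum().backward()
print(theta.grad)  # tensor(-2.1857): the hook zeroed the masked entries, no nan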

Note that if the function is not elementwise and changes the input size, you can compute an input mask and an output mask: one to mask out the output during the forward pass and one to mask out the gradients of the input during the backward pass.

Hope this helps you do what you want!

jwdink commented Nov 4, 2018

@albanD: those explanations are very helpful-- thanks for your patience!

c-hofer (Author) commented Nov 5, 2018

@albanD
First, thanks again for your time.
Indeed, I used nan instead of inf as I was thinking of values < 0 too. Sorry for the sloppiness ;).

The proposed solution does not solve the proposed problem (unless I've misunderstood something completely). To be clear, an example:

x_0 = torch.tensor([float(i-1) for i in range(5)], requires_grad=True)
x_1 = x_0.sqrt() # this is the output of the client

x_input = x_1 # assumption: we only have access to x_1, and x_0 is hidden
              # from the perspective of my_excluding_functions
x_output = my_excluding_functions(x_input)
x_output.sum().backward()

# expected: x_0.grad contains no nan 

In my understanding, the proposed pattern solves the issue if we have detailed information about
x_input, i.e. x_0, and twists grad at the right position in the recursion such that the chain rule does not hurt you.
However, what is needed is to modify the very definition of multiplication at the current top level
of the recursion such that nan * 0 = inf * 0 = 0. In other words, from the point of the "exclusion" operation upwards, we do not care about the ancestors in the differentiation graph.
Mathematically this reflects the idea that we can always add an unknown to the pointwise definition of a function, and the related partial differentiation operation will yield 0 if it does not impact the pointwise value.

f(x_1, x_2) = x_1 ==> df / dx_2 = 0

I totally understand that technically this may cause problems, because mathematics usually does not care about its implementation ;).
Again, I want to point out that maybe this is not a big issue and not worth the time. I think I can work around the issue, but it simply makes things less elegant than I'd like them to be ;).

Comment: If one thinks of batch-wise application, the problem is maybe more obvious. Let xx_0 = torch.tensor([float(i+1) for i in range(5)], requires_grad=True) be the "second" batch element. All its gradients are defined, and my_excluding_functions would not exclude anything. However, the gradients in the backward pass will be added, and hence the nan from the first batch element, x_0, will destroy the gradient. It wouldn't if the gradient wrt the excluded coordinates of x_0 were zero.
