
Gradient of zero norm is nan #2421

Closed
zkolter opened this issue Aug 15, 2017 · 16 comments · Fixed by #2775

Comments

@zkolter
Contributor

zkolter commented Aug 15, 2017

If a norm is zero, its gradient returns nan:

x = Variable(torch.zeros(1), requires_grad=True)
x.norm().backward()
print(x.grad)

# Variable containing:
# nan
# [torch.FloatTensor of size 1]

Obviously this is just happening because the gradient divides by the norm, but the (sub)gradient here should probably be zero, or at least not nan, since the nan will propagate and make all updates nan. Probably low priority, as it won't be an issue in 99% of cases, but we're doing a few things with (exact) line searches where this caused a nan to appear, breaking everything downstream.
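For reference, the analytic gradient of the 2-norm is x / ‖x‖, which is 0/0 at the origin. A minimal plain-Python sketch of this failure mode (the helper `norm_grad_1d` is hypothetical, just illustrating the formula autograd evaluates):

```python
import math

def norm_grad_1d(x):
    # Analytic gradient of the 1-D 2-norm: d|x|/dx = x / |x|.
    # At x = 0 this is 0/0, which IEEE arithmetic turns into nan,
    # mirroring what autograd computes for x.norm() at zero.
    n = math.sqrt(x * x)
    try:
        return x / n
    except ZeroDivisionError:
        return float("nan")  # Python raises on 0.0 / 0.0; emulate IEEE nan

print(norm_grad_1d(2.0))  # 1.0
print(norm_grad_1d(0.0))  # nan
```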

@cai-lw

cai-lw commented Aug 17, 2017

I'm encountering exactly the same issue! I spent hours debugging, only to find that PyTorch has a bug in such a basic operation.

@el3ment

el3ment commented Sep 6, 2017

+1 just found this bug too

@dannysdeng

dannysdeng commented Sep 7, 2017

+1 for this bug. Temporarily changing my code to something like the following for the sake of debugging.

x = Variable(torch.zeros(1), requires_grad=True)
y = x + 1e-16
y.norm().backward()
print(x.grad)

@albanD
Collaborator

albanD commented Sep 7, 2017

The thing is that the 2-norm contains a square root, which has a gradient of +infinity at 0.
The gradient gives you nan because you then multiply 0 by infinity during the backward pass.
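That 0 × ∞ product can be reproduced with plain floats; nothing PyTorch-specific is involved, just IEEE 754 arithmetic:

```python
import math

# Local gradient of sqrt(z) at z = 0 is 1 / (2 * sqrt(0)) = +inf;
# the incoming gradient from the rest of the graph at x = 0 is 0.
# IEEE 754 defines 0 * inf as nan, which is what lands in x.grad.
local_grad = float("inf")
incoming_grad = 0.0
print(local_grad * incoming_grad)  # nan
```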

@ruotianluo
Contributor

For a scalar, the 2-norm is basically abs, but x.abs().backward() gives you a gradient of 0. In this sense, the behavior is not coherent.

@JianboTang

I found this error, too

@soumith
Member

soumith commented Sep 18, 2017

Alban fixed this behavior in #2775

@D-X-Y

D-X-Y commented Dec 31, 2017

@soumith Hi, the norm function gives us a 0 gradient now.
However, the following code still has the nan gradient problem:

x = torch.autograd.Variable(torch.zeros(1), requires_grad=True)
y = torch.sqrt( x * x )
y.backward()
print(x.grad)

@albanD
Collaborator

albanD commented Dec 31, 2017

Oh, the square root has no gradient at 0. This is expected behavior.

@D-X-Y

D-X-Y commented Dec 31, 2017

Hi @albanD, shouldn't the sub-gradient of the square root be zero?
Also, y = torch.sqrt(x * x) should be equal to x.norm(), so why do they have different gradients (0 and nan)?

@cai-lw

cai-lw commented Dec 31, 2017

@D-X-Y I think @albanD was right. The left-side derivative of sqrt(x) at x=0 is undefined, so it doesn't even have a subgradient at x=0.

@albanD
Collaborator

albanD commented Jan 5, 2018

@D-X-Y square root has no subgradient at 0. You could define a gradient by continuity, but then it would be +inf...
Given that PyTorch is using autograd, x.norm() and x.pow(2).sqrt() (equivalent to your torch.sqrt(x*x)) are completely different:

  • The first one is a single function that is convex and defined on R; it has a subgradient of 0 at 0.
  • The second one is composed of two functions. The first is the square function, which is differentiable and outputs values in [0, +inf[. The second is the square root, which is not convex; even though it is defined on [0, +inf[, it is only differentiable on ]0, +inf[, and its gradient at 0 is undefined.

Given that, even though x.norm() and x.pow(2).sqrt() return the same value, their gradients may differ at points where the function is not differentiable. This is because automatic differentiation looks at each step of the computation one by one: a subgradient may exist when we view multiple operations as a single function, but it does not always exist for each step taken separately, and in that case the gradient remains undefined.
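The distinction can be sketched without autograd at all: differentiate the norm as one primitive versus sqrt(x*x) op by op. This is a hand-rolled illustration, not PyTorch's actual code:

```python
import math

def grad_norm(x):
    # x.norm() differentiated as a single primitive; PyTorch (after #2775)
    # picks the subgradient 0 at x = 0.
    return x / abs(x) if x != 0 else 0.0

def grad_sqrt_of_square(x):
    # torch.sqrt(x * x) differentiated step by step via the chain rule:
    # d sqrt(z)/dz is +inf at z = 0, and dz/dx = 2x is 0 at x = 0,
    # so the product is 0 * inf = nan.
    z = x * x
    dz_dx = 2.0 * x
    dsqrt_dz = float("inf") if z == 0 else 0.5 / math.sqrt(z)
    return dsqrt_dz * dz_dx

print(grad_norm(0.0))            # 0.0
print(grad_sqrt_of_square(0.0))  # nan
print(grad_sqrt_of_square(3.0))  # 1.0
```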

@xforceco

xforceco commented Mar 17, 2018

I think math is math. Any root's gradient at zero is either infinite or undefined.
This issue should be handled by users themselves by adding a small value (as @dannysdeng did), but an error (or warning) message may be helpful, since it is pretty hard to debug.

Say: "Infinite/undefined gradient detected in X_function_X at line Y. Exiting."
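The epsilon pattern mentioned above, sketched in plain Python (`EPS` is an assumed value you would tune to your data's scale; the trade-off is that the stabilized norm is biased, returning sqrt(eps) rather than 0 at the origin):

```python
import math

EPS = 1e-12  # assumed stabilizer; tune relative to your data's scale

def stabilized_norm(xs, eps=EPS):
    # sqrt(sum(x^2) + eps) instead of sqrt(sum(x^2)):
    # the gradient x_i / sqrt(sum(x^2) + eps) is finite everywhere.
    return math.sqrt(sum(x * x for x in xs) + eps)

def stabilized_norm_grad(xs, eps=EPS):
    n = stabilized_norm(xs, eps)
    return [x / n for x in xs]

print(stabilized_norm_grad([0.0, 0.0]))  # [0.0, 0.0], no nan
print(stabilized_norm_grad([3.0, 4.0]))  # ~[0.6, 0.8]
```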

@ngimel
Collaborator

ngimel commented Mar 17, 2018

Agree, norm is not differentiable at 0 (https://math.stackexchange.com/questions/310325/is-the-euclidean-norm-differentiable-at-0). The band-aid that Alban put there in #2775 is wrong (even in the limit sense, the gradient at 0 should be 1, not 0), but it should not have been there at all. Norm is norm; if someone wants to add epsilons to their norms (like batchnorm does, e.g.), they are welcome to do so in user code. What would numpy do?
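On the limit question, in 1-D the limit of the gradient depends on the approach direction (it is sign(x)), which a quick difference-quotient check shows:

```python
def abs_slope(x0, h):
    # One-sided difference quotient of |x| at x0 with step h.
    return (abs(x0 + h) - abs(x0)) / h

# The two one-sided limits at 0 disagree, so no single limiting
# gradient exists there; 0 is the symmetric choice between them.
print(abs_slope(0.0, 1e-6))   # 1.0  (approaching from the right)
print(abs_slope(0.0, -1e-6))  # -1.0 (approaching from the left)
```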

@asford
Contributor

asford commented May 10, 2018

@ngimel @albanD

I've also run into a number of problems related to the change introduced in #2775. Is there a reason the subgradient is set to 0, rather than 1 (the limit as norm -> 0)?

As a minimal example:

import torch
from IPython.display import display  # display() is notebook-only; import it for plain Python

v = torch.linspace(0, 1e-6, steps=10).requires_grad_()

def bnorm(val):
    n = val.detach().clone().requires_grad_(True)
    c = n.reshape(-1, 1)
    
    nn = c.norm(dim=-1)
    torch.autograd.backward(nn, torch.ones_like(nn))
    return n.grad

def snorm(val):
    n = val.detach().clone().requires_grad_(True)
    c = n.reshape(-1, 1)
    
    nn = (c ** 2).sum(dim=-1).sqrt()
    
    torch.autograd.backward(nn, torch.ones_like(nn))
    return n.grad

print("torch.__version__:")
display(torch.__version__)

print("vals:")
display(v)

print("torch.norm(v, dim=-1) grad:")
display(bnorm(v))

print("(v**2).sum(dim=-1).sqrt() grad:")
display(snorm(v))

Produces

torch.__version__:
'0.4.0'
vals:
tensor(1.00000e-07 *
       [ 0.0000,  1.1111,  2.2222,  3.3333,  4.4444,  5.5556,  6.6667,
         7.7778,  8.8889, 10.0000])
torch.norm(v, dim=-1) grad:
tensor([ 0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.])
(v**2).sum(dim=-1).sqrt() grad:
tensor([nan,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.])

@albanD
Collaborator

albanD commented May 11, 2018

Well, any value in [-1, 1] is a valid subgradient for the 2-norm.
More generally, any vector in the unit ball of the dual norm is a valid subgradient.
This means that 0 is always going to be a subgradient, while 1 will not be for all p.

Anyway, the theory says that any of them can be taken and subgradient descent will work. I'm sure that depending on the application, one will be better than the other.
For example, the relu function also gives a 0 subgradient at 0; you could have chosen 1.
The main point here was to remove nans that make your network return nan for everything, which is not convenient.
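The relu comparison above, as a plain-Python sketch of the convention (the 0-at-0 choice is what the comment describes; any value in [0, 1] would be a mathematically valid subgradient at the kink):

```python
def relu(x):
    return max(x, 0.0)

def relu_subgrad(x):
    # Convention described above: take subgradient 0 at the kink x = 0,
    # even though any value in [0, 1] is mathematically valid there.
    return 1.0 if x > 0 else 0.0

print(relu(-1.0))          # 0.0
print(relu_subgrad(0.0))   # 0.0
print(relu_subgrad(2.0))   # 1.0
```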

imgemp added a commit to imgemp/Lyapunov-GANs that referenced this issue Sep 22, 2018
gradient of sqrt(x^2) near zero is infinite due to chain rule decomposition: 1/2*z^(-1/2)*2x where z = x^2.
pytorch/pytorch#2421