Improve gradient stability of logsumexp, softmax, log_softmax, logsigmoid at -inf (replace nan by zero) #31829
Comments
One intermediate (but practical) way could be implementing a hack like […]. This is somewhat akin to […].
At this time we have a workaround like the one below; it works well in all three cases:

```python
import torch
from torch import Tensor
from torch import jit

@jit.script
def logsumexp(x: Tensor, dim: int) -> Tensor:
    m, _ = x.max(dim=dim)
    mask = m == -float('inf')
    # Shift by the per-slice max, but use 0 as the shift where the whole
    # slice is -inf, so that x - m stays finite instead of -inf - (-inf).
    s = (x - m.masked_fill_(mask, 0).unsqueeze(dim=dim)).exp().sum(dim=dim)
    # For slices that are entirely -inf, force the sum to 1 (log 1 = 0)
    # and restore -inf in m, so the output is -inf with a zero gradient.
    return s.masked_fill_(mask, 1).log() + m.masked_fill_(mask, -float('inf'))

def check(x: Tensor, fn):
    x.grad = None
    y = fn(x, dim=-1)
    print(f'y => {y.view(-1)}')
    y.backward(torch.ones_like(y))
    print(f'x.grad => {x.grad.view(-1)}')

if __name__ == '__main__':
    a = torch.full((2, 2), -float('inf'), requires_grad=True)
    check(a, logsumexp)
    check(a, torch.logsumexp)
    # y => tensor([-inf, -inf], grad_fn=<ViewBackward>)
    # x.grad => tensor([0., 0., 0., 0.])
    # y => tensor([-inf, -inf], grad_fn=<ViewBackward>)
    # x.grad => tensor([nan, nan, nan, nan])

    a = torch.randn((2, 2), requires_grad=True)
    check(a, logsumexp)
    check(a, torch.logsumexp)
    # y => tensor([ 1.0678, -0.0330], grad_fn=<ViewBackward>)
    # x.grad => tensor([0.7920, 0.2080, 0.8099, 0.1901])
    # y => tensor([ 1.0678, -0.0330], grad_fn=<ViewBackward>)
    # x.grad => tensor([0.7920, 0.2080, 0.8099, 0.1901])

    a = torch.randn((2, 2))
    a[0, 0] = -float('inf')
    a.requires_grad_(True)
    check(a, logsumexp)
    check(a, torch.logsumexp)
    # y => tensor([-0.0910, 1.5311], grad_fn=<ViewBackward>)
    # x.grad => tensor([0.0000, 1.0000, 0.2983, 0.7017])
    # y => tensor([-0.0910, 1.5311], grad_fn=<ViewBackward>)
    # x.grad => tensor([0.0000, 1.0000, 0.2983, 0.7017])
```
@albanD Could you please tag this as "topic: NaNs and Infs"? https://github.com/pytorch/pytorch/issues?q=is%3Aissue+is%3Aopen+label%3A%22topic%3A+NaNs+and+Infs%22
Increased priority based on user activity
@ezyang fixing the original problem would be nice, but please also consider a […]
yeah, that seems pretty reasonable too
Or some more general replace-gradient-values version: replacing infs can be useful too, or sometimes replacing zeros by a tiny eps. Such a function could be useful in the forward pass too. There may be some nuances (like negative/positive zeros), but it should be possible to come up with a good design... (see the sketch below)
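For illustration, a minimal sketch of such a pseudo-function as a custom autograd Function (the names `ReplaceGradNan` and `replace_grad_nan` are hypothetical, not an existing PyTorch API): an identity in the forward pass whose backward replaces nan entries of the incoming gradient by zero.

```python
import torch
from torch import Tensor

class ReplaceGradNan(torch.autograd.Function):
    # Identity in forward; backward zeroes out nan entries of the
    # incoming gradient. Replacing infs, or zeros by a tiny eps,
    # would follow the same pattern.
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        return x

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tensor:
        return torch.where(torch.isnan(grad_output),
                           torch.zeros_like(grad_output),
                           grad_output)

def replace_grad_nan(x: Tensor) -> Tensor:
    return ReplaceGradNan.apply(x)

a = torch.full((2, 2), -float('inf'), requires_grad=True)
y = torch.logsumexp(replace_grad_nan(a), dim=-1)
y.backward(torch.ones_like(y))
print(a.grad)  # tensor([[0., 0.], [0., 0.]]) instead of nans
```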
This can be implemented with a custom autograd Function already, right? Or even a hook similar to […]. Or is your argument here more that we should have that in core PyTorch?
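For reference, a minimal sketch of the hook variant, assuming we just want nan gradients zeroed at one particular tensor:

```python
import torch

a = torch.full((2, 2), -float('inf'), requires_grad=True)
# The hook rewrites the gradient as it flows into `a`.
a.register_hook(lambda g: torch.where(torch.isnan(g), torch.zeros_like(g), g))

y = torch.logsumexp(a, dim=-1)
y.backward(torch.ones_like(y))
print(a.grad)  # tensor([[0., 0.], [0., 0.]]) instead of nans
```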
My argument is that these problems are so frequent (torch.where producing bad gradients, the absence of xlogy, the need to replace inf gradients to sidestep 0 * inf) and require workarounds that are not completely trivial to come up with (sometimes shifting, sometimes clamping, sometimes clamping the gradient) that PyTorch needs idioms for them in core (though hopefully the backward of torch.where can still be fixed and xlogy implemented). And in general, a series of idiomatic gradient pseudo-functions would be nice (e.g. for clipping large gradient values not in the optimizer, but at model-code granularity). For me, hooks are always a bit fishy: it's not clear when they are called or whether they would be applied at scripting/export time, while an explicit, standard-ish pseudo-function could be treated properly by PyTorch (if desired).
Do you have an issue about that? The only one I have in mind is about the fact that it generates "just" 0 gradients for elements that were not used, and I don't think there is any plan to change that.
The problem is that, as you said, the choice of which one to use is not trivial and requires a deep understanding of what is happening. But I feel like, by the time you understand that you need this (and which one you need), you have all the knowledge to fix it yourself without a special function from core. But I guess all of this is beside the point of the original post, and maybe we should move this discussion to a new issue if we really want to add such a function (or set of functions).
A problem with torch.where is that gradients would still be nan even if the bad value was not taken in the forward pass. At least it was like that for a long time, so the naive entropy computation had bad gradients (as sketched below). For the generic case, yes, the user should have control, but there should be supporting idioms (pseudo-functions) to make these choices easier and less error-prone. Yep, maybe we should factor this out into a separate issue. Will create one.
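The example elided above was presumably something along these lines (a reconstruction, not the original code): the untaken branch of torch.where still poisons the gradient.

```python
import torch

p = torch.tensor([0.0, 0.5, 0.5], requires_grad=True)
# Guard p == 0 with torch.where, hoping to avoid log(0) in the gradient.
plogp = torch.where(p > 0, p * p.log(), torch.zeros_like(p))
entropy = -plogp.sum()
entropy.backward()
print(p.grad)  # tensor([nan, -0.3069, -0.3069]): the untaken p * p.log()
               # branch still contributes 0 * -inf = nan via the mul backward
```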
Yes, but as was discussed in the issue, this is expected, I am afraid. The function before the […]
Isn't it torch.where that produces its final gradient? The whole point of using torch.where is to instruct it not to use […]
Well, for a long time it was […]. I guess I don't understand why you say that the computation of the gradient of […]
The gradients are computed in the reverse order of the forward pass. So the where is the first op for which gradients are computed, then the product, then the log, and then p.
You're right |
Isn't there a moment in the mul gradient where it multiplies the results by the […]? Maybe some pseudo-function could handle this :) (sketched below)
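For concreteness, a minimal sketch of that mul-backward behavior: d(a*b)/da is computed as grad_output * b, so even a zero upstream gradient turns into nan against an inf operand.

```python
import torch

a = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([float('inf')])
y = a * b                        # forward value is inf
y.backward(torch.zeros_like(y))  # zero upstream gradient
print(a.grad)                    # tensor([nan]): 0 * inf = nan
```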
The problem is that we have no way to differentiate a […]. Hence my comment above that the big blocker here is that we have no way to differentiate the two today :/
Closing this; further discussion in #49724.
Would you like to copy-paste the examples/test cases from my OP here over there, so that they are not forgotten (if they're different)?
One funny way could be to distinguish 0 and -0 to mean "zero" or "unused" :)
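For what it's worth, IEEE 754 does allow telling the two apart even though they compare equal, e.g.:

```python
import torch

g = torch.tensor([0.0, -0.0])
print(g == 0)            # tensor([True, True]): 0 and -0 compare equal
print(torch.signbit(g))  # tensor([False, True]): but the sign bit differs
```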
During a simple educational reimplementation of CTC I found that torch.logsumexp produces a nan gradient if all inputs happen to be -inf (it can also produce an inf output, but that is not a problem). A zero gradient is much better in this case (since zero accumulates fine with other non-nan gradients).
A previous related issue: #6864. (One argument by @apaszke there is that inf outputs are often bad anyway, but in the case of HMM-like algorithms in log-space they are natural.)
I agree this is not a very big issue, since the CTC problems go away if `float('-inf')` is replaced by `torch.finfo(torch.float).min` as the neutral addition element of the log-space semiring. However, if it's an easy fix it would be nice to have, since -inf is the natural neutral addition element for log-space.

In the case of `logsumexp`, the problems likely come from the subtraction of two `-inf` values (see `pytorch/tools/autograd/templates/Functions.cpp`, line 433 at commit 877c96c).
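For illustration, the substitution mentioned above: with the most negative finite float as the mask value, logsumexp stays finite and its gradient is well defined.

```python
import torch

neg_inf = torch.finfo(torch.float).min  # most negative finite float32
a = torch.full((2, 2), neg_inf, requires_grad=True)
y = torch.logsumexp(a, dim=-1)          # finite (about -3.4e38), not -inf
y.backward(torch.ones_like(y))
print(a.grad)  # tensor([[0.5000, 0.5000], [0.5000, 0.5000]]): no nans
```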
A full repro showcasing PyTorch ops that return a nan gradient in this case: […]
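The repro code appears to have been dropped from this copy; a minimal sketch of the case where all inputs are -inf, consistent with the outputs shown earlier in the thread:

```python
import torch

a = torch.full((2, 2), -float('inf'), requires_grad=True)
y = torch.logsumexp(a, dim=-1)  # tensor([-inf, -inf])
y.backward(torch.ones_like(y))
print(a.grad)                   # tensor([[nan, nan], [nan, nan]])
```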
In the custom reimplementation, a nan gradient also occurs (if the clamp is removed) because of division by zero (the sum of exps of the inputs happens to be zero).
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @gqchen @pearu @nikitaved @soulitzer @ssnl