Improve gradient stability of logsumexp, softmax, log_softmax, logsigmoid at -inf (replace nan by zero) #31829

Closed
vadimkantorov opened this issue Jan 3, 2020 · 22 comments
Labels
high priority, module: autograd, module: NaNs and Infs, needs research, triage review, triaged

Comments

@vadimkantorov (Contributor) commented Jan 3, 2020

During a simple educational reimplementation of CTC I found that torch.logsumexp produces a nan gradient if all inputs happen to be -inf (it also produces an infinite output, but that is not the problem). A zero gradient is much better in this case (since zero accumulates fine with other, non-nan gradients).

Previous related issue: #6864. (One argument by @apaszke there is that inf outputs are often bad anyway, but in the case of HMM-like algorithms in log space they are natural.)

I agree this is not a very big issue, since the CTC problems go away if float('-inf') is replaced by torch.finfo(torch.float).min as the neutral additive element of the log-space semiring. However, if it's an easy fix it would be nice to have, since -inf is the natural neutral additive element for log space.

In the case of logsumexp, the problem likely comes from the subtraction of two -inf values in the backward formula:

return grad * (self - result).exp();

which leads to nan in the gradient.
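
For reference, a minimal illustration of that failure mode (plain IEEE-754 arithmetic, independent of PyTorch internals): when both the input and the logsumexp result are -inf, the subtraction yields nan and exp keeps it.

import torch

x = torch.tensor(float('-inf'))       # an input element
result = torch.tensor(float('-inf'))  # logsumexp output when every input is -inf
print(x - result)          # tensor(nan): (-inf) - (-inf) is nan
print((x - result).exp())  # tensor(nan): so the backward formula returns nan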

A full repro showcasing PyTorch ops that return a nan gradient in this case:

import torch
import torch.nn.functional as F

class LogsumexpFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, x2):
        m = torch.max(torch.max(x0, x1), x2)
        # Replace an infinite max by 0 so that the shifted inputs x_i - m stay finite
        m = m.masked_fill_(torch.isinf(m), 0)
        e0 = (x0 - m).exp_()
        e1 = (x1 - m).exp_()
        e2 = (x2 - m).exp_()
        # Clamp the sum of exponentials away from zero so that backward never divides by zero
        e = (e0 + e1 + e2).clamp_(min = 1e-16)
        ctx.save_for_backward(e0, e1, e2, e)
        return e.log().add_(m)

    @staticmethod
    def backward(ctx, grad_output):
        e0, e1, e2, e = ctx.saved_tensors
        g = grad_output / e
        return g * e0, g * e1, g * e2

x0 = torch.tensor([float('-inf')], requires_grad = True)
x1 = torch.tensor([float('-inf')], requires_grad = True)
x2 = torch.tensor([float('-inf')], requires_grad = True)

logsumexp = torch.logsumexp(torch.stack([x0, x1, x2]), dim = 0)
softmax = torch.softmax(torch.stack([x0, x1, x2]), dim = 0)
log_softmax = torch.log_softmax(torch.stack([x0, x1, x2]), dim = 0)
sigmoid = torch.sigmoid(torch.stack([x0, x1, x2]))
logsigmoid = F.logsigmoid(torch.stack([x0, x1, x2]))

logsumexp_ = LogsumexpFunction.apply(x0, x1, x2)

print('torch.logsumexp', torch.autograd.grad(logsumexp, x0))
print('torch.softmax', torch.autograd.grad(softmax.sum(), x0))
print('torch.log_softmax', torch.autograd.grad(log_softmax.sum(), x0))
print('torch.log_sigmoid', torch.autograd.grad(logsigmoid.sum(), x0))
print('torch.sigmoid', torch.autograd.grad(sigmoid.sum(), x0))

print('custom logsumexp', torch.autograd.grad(logsumexp_, x0))

# torch.logsumexp (tensor([nan]),)
# torch.softmax (tensor([nan]),)
# torch.log_softmax (tensor([nan]),)
# torch.log_sigmoid (tensor([nan]),)
# torch.sigmoid (tensor([0.]),)
# custom logsumexp (tensor([0.]),)

In the custom reimplementation, a nan gradient also occurs (if the clamp is removed) because of division by zero (the sum of the exponentials of the inputs happens to be zero).

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @gqchen @pearu @nikitaved @soulitzer @ssnl

@vadimkantorov changed the title to "Improve gradient stability of logsumexp, softmax, log_softmax, logsigmoid at -inf (replace nan by zero)" on Jan 3, 2020
@jerryzh168 added the module: autograd and triaged labels on Jan 8, 2020
@vadimkantorov (Contributor, Author) commented Jan 15, 2020

One intermediate (but practical) way could be to implement a hack like F.zero_nan_grad_(x) that would zero out, in place, any nan gradients coming back from x (while a proper fix is being deliberated). E.g. torch.where / indexing used to have this problem when implementing cross-entropy or entropy (it would produce a nan gradient).

This is somewhat akin to the stop_gradient or gradient_reversal pseudo-functions that appear in GAN works. So maybe a whole torch.nn.functional.grad namespace is worth adding.
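
Not an existing PyTorch API, but a minimal sketch of what such a zero_nan_grad pseudo-function could look like as a custom autograd Function (the name and the placement under torch.nn.functional are assumptions):

import torch

class ZeroNanGrad(torch.autograd.Function):
    # Identity in forward; zeros out nan entries of the incoming gradient in backward.
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output != grad_output is True exactly at nan positions
        return grad_output.masked_fill(grad_output != grad_output, 0)

def zero_nan_grad(x):
    return ZeroNanGrad.apply(x)

x = torch.tensor([float('-inf')], requires_grad=True)
y = torch.logsumexp(torch.cat([zero_nan_grad(x)] * 3), dim=0)
print(torch.autograd.grad(y, x))  # (tensor([0.]),) instead of (tensor([nan]),)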

@speedcell4 commented Feb 3, 2020

For now we have a workaround like the one below; it works well in all three of these cases:

  1. containing only -float("inf")
  2. containing no -float("inf")
  3. containing some -float("inf")

import torch
from torch import Tensor
from torch import jit


@jit.script
def logsumexp(x: Tensor, dim: int) -> Tensor:
    m, _ = x.max(dim=dim)
    mask = m == -float('inf')

    # Use 0 instead of -inf as the shift where the row max is -inf, so that x - m stays finite
    s = (x - m.masked_fill_(mask, 0).unsqueeze(dim=dim)).exp().sum(dim=dim)
    # Where every input was -inf, force the sum to 1 (log -> 0) and add back the -inf result
    return s.masked_fill_(mask, 1).log() + m.masked_fill_(mask, -float('inf'))


def check(x: Tensor, fn):
    x.grad = None
    y = fn(x, dim=-1)
    print(f'y => {y.view(-1)}')
    y.backward(torch.ones_like(y))
    print(f'x.grad => {x.grad.view(-1)}')


if __name__ == '__main__':
    a = torch.full((2, 2), -float('inf'), requires_grad=True)

    check(a, logsumexp)
    check(a, torch.logsumexp)

    # y => tensor([-inf, -inf], grad_fn=<ViewBackward>)
    # x.grad => tensor([0., 0., 0., 0.])
    # y => tensor([-inf, -inf], grad_fn=<ViewBackward>)
    # x.grad => tensor([nan, nan, nan, nan])

    a = torch.randn((2, 2), requires_grad=True)

    check(a, logsumexp)
    check(a, torch.logsumexp)

    # y => tensor([ 1.0678, -0.0330], grad_fn=<ViewBackward>)
    # x.grad => tensor([0.7920, 0.2080, 0.8099, 0.1901])
    # y => tensor([ 1.0678, -0.0330], grad_fn=<ViewBackward>)
    # x.grad => tensor([0.7920, 0.2080, 0.8099, 0.1901])

    a = torch.randn((2, 2))
    a[0, 0] = -float('inf')
    a.requires_grad_(True)

    check(a, logsumexp)
    check(a, torch.logsumexp)

    # y => tensor([-0.0910,  1.5311], grad_fn=<ViewBackward>)
    # x.grad => tensor([0.0000, 1.0000, 0.2983, 0.7017])
    # y => tensor([-0.0910,  1.5311], grad_fn=<ViewBackward>)
    # x.grad => tensor([0.0000, 1.0000, 0.2983, 0.7017])

@albanD added the needs research label on Feb 7, 2020
@vadimkantorov (Contributor, Author):

@albanD Could you please tag this as "topic: NaNs and Infs"? https://github.com/pytorch/pytorch/issues?q=is%3Aissue+is%3Aopen+label%3A%22topic%3A+NaNs+and+Infs%22

@ezyang (Contributor) commented Jan 4, 2021

Increased priority based on user activity

@vadimkantorov (Contributor, Author):

@ezyang fixing the original problem would be nice, but please also consider a F.zero_nan_grad_(x) pseudo-function feature request.

@ezyang (Contributor) commented Jan 5, 2021

Yeah, that seems pretty reasonable too.

@vadimkantorov (Contributor, Author):

Or some more general replace-gradient-values version: replacing infs can be useful too, or sometimes replacing zeros with a tiny eps. Such a function could be useful in the forward pass too. There may be some nuances (like negative/positive zeros), but it should be possible to come up with a good design...

@albanD (Collaborator) commented Jan 5, 2021

please also consider a F.zero_nan_grad_(x) pseudo-function feature request.

This can be implemented with a custom autograd Function already, right? Or even with a hook similar to x.register_hook(lambda g: g.masked_fill_(g != g, 0))?

Or is your argument here more that we should have that in core PyTorch?
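
For concreteness, a minimal sketch of the hook approach (using an out-of-place masked_fill so the incoming gradient is not modified in place):

import torch

x = torch.tensor([float('-inf')] * 3, requires_grad=True)
# Returning a tensor from the hook replaces the gradient that gets accumulated into x.grad
x.register_hook(lambda g: g.masked_fill(g != g, 0))

torch.logsumexp(x, dim=0).backward()
print(x.grad)  # tensor([0., 0., 0.]) instead of tensor([nan, nan, nan])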

@vadimkantorov (Contributor, Author) commented Jan 5, 2021

My argument is that these problems are so frequent (torch.where producing bad gradients, the absence of xlogy, the need to replace inf gradients to sidestep 0 * inf) and require workarounds that are not completely trivial to come up with (sometimes shifting, sometimes clamping, sometimes clamping the gradient) that PyTorch needs idioms for these in core (though hopefully the backward of torch.where could still be fixed and xlogy implemented).

And in general, a series of idiomatic gradient pseudo-functions would be nice (e.g. for clipping large gradient values not within the optimizer, but at model-code granularity). For me, hooks are always a bit fishy: it is not clear when they are called and whether they would be applied during scripting / export, while an explicit, standard-ish pseudo-function could be treated properly by PyTorch (if desired).

@albanD (Collaborator) commented Jan 5, 2021

hopefully the backward of torch.where could still be fixed

Do you have an issue about that? The only one I have in mind is about the fact that it generates "just" 0 gradients for elements that were not used, and I don't think there is any plan to change that.

that PyTorch needs idioms for these in core

The problem is that, as you said, the choice of which one to use is not trivial and requires a deep understanding of what is happening. But I feel like, by the time you understand that you need this (and which one of these you need), you have all the knowledge to fix it yourself without a special function from the core.

But I guess all of this is beside the point of the original post, and maybe we should move this discussion to a new issue if we really want to add such a function (or set of functions).

@vadimkantorov (Contributor, Author):

A problem with torch.where is that the gradient would still be nan even if the bad value was not taken in the forward pass. At least it was like that for a long time, so the naive entropy computation had bad gradients: torch.where(p > 0, p * p.log(), 0)

For the generic case, yes, the user should have control, but there should be supporting idioms (pseudo-functions) to make these choices easier and less error-prone. Yep, maybe we should factor this out into a separate issue. Will create one.
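
A minimal repro of that entropy case (behaviour as observed on the versions discussed in this thread, written with torch.zeros_like for the masked branch):

import torch

p = torch.tensor([0.0, 0.5], requires_grad=True)
# The p * p.log() branch is masked out at p == 0 in the forward,
# but its backward still computes 0 * (-inf) = nan there.
h = torch.where(p > 0, p * p.log(), torch.zeros_like(p))
h.sum().backward()
print(p.grad)  # tensor([nan, 0.3069]): nan at the masked-out position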

@albanD (Collaborator) commented Jan 5, 2021

A problem with torch.where is that the gradient would still be nan even if the bad value was not taken in the forward pass. At least it was like that for a long time, so the naive entropy computation had bad gradients: torch.where(p > 0, p * p.log(), 0)

Yes, but as was discussed in that issue, this is expected, I am afraid. The function before the torch.where is the one that generates these bad gradients. So there is very little torch.where can do, unless we can have an "unused" (or something else) element for the gradient. But this would need to be stored within the grad Tensor, and all the float values are already used there.

@vadimkantorov (Contributor, Author):

Isn't it torch.where that produces its final gradient? The whole point of using torch.where is to instruct it to not use p * p.log() where it is not correctly defined. So it should be able to not use the p * p.log() gradient where p == 0, because that is exactly the passed mask.

@albanD (Collaborator) commented Jan 5, 2021

torch.where will generate the right gradients for all these inputs, and the gradient will be 0 because these inputs were not used.
The problem is that later in the grad computation (of log for example), this grad of 0 will end up being divided by another 0, leading to nan.

@vadimkantorov (Contributor, Author) commented Jan 5, 2021

Well, for a long time it was nan, and it was nan last time I checked. I'll dig up the discussions about the entropy function... Having good gradients for p * log(p) would be nice, and there is a feature request for a torch.xlogy function, but surprisingly even the fix-up with torch.where was not working.

I guess I don't understand why you say that the computation of the gradient of log comes later in the grad computation. First we get the gradient of p and p.log() (which will be bad as well), then of p * p.log() (which will have nan), and then torch.where should ignore these bad gradients because they are luckily masked out by the passed mask.

@albanD (Collaborator) commented Jan 5, 2021

The gradients are computed in the reverse order of the forward pass. So the where is the first one for which gradients are computed, then the product, then the log, and then p.

@vadimkantorov (Contributor, Author):

The gradients are computed in the reverse order of the forward pass. So the where is the first one for which gradients are computed, then the product, then the log, and then p.

You're right

@vadimkantorov (Contributor, Author) commented Jan 5, 2021

Isn't there a moment in mul's backward where it multiplies the results by the grad_output (that it received from torch.where)? It could probably copy over the zeros instead of multiplying zero by nan/inf and having a nan propagated. Or is the problem the perf hit of modifying mul's backward?

Maybe some pseudo-function could handle this :)
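
Purely as an illustration of that idea (not an existing op, and see the objection in the next reply about conflating genuine zeros with "unused" zeros), a hypothetical multiplication whose backward copies zeros of grad_output through instead of multiplying them:

import torch

class MulZeroGrad(torch.autograd.Function):
    # Hypothetical mul variant (assumes a and b have the same shape): where grad_output
    # is exactly 0, backward returns 0 directly instead of computing 0 * inf = nan.
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_a = torch.where(grad_output == 0, torch.zeros_like(grad_output), grad_output * b)
        grad_b = torch.where(grad_output == 0, torch.zeros_like(grad_output), grad_output * a)
        return grad_a, grad_b

a = torch.tensor([float('inf')], requires_grad=True)
b = torch.tensor([2.0], requires_grad=True)
mask = torch.tensor([False])
out = torch.where(mask, MulZeroGrad.apply(a, b), torch.zeros_like(a))
out.sum().backward()
print(b.grad)  # tensor([0.]); with a plain a * b here, b.grad would be nan (0 * inf)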

@albanD (Collaborator) commented Jan 5, 2021

The problem is that we have no way to differentiate a 0 that comes from "unused" from a 0 that comes from "it is numerically 0".
And while, for the first one, we could argue that we are happy for it to remain an "unused" 0 when multiplied with inf, for the second one we definitely want it to become a nan!

Hence my comment above that the big blocker here is that we have no way to differentiate the two today :/

@gchanan (Contributor) commented Jan 11, 2021

Closing this; further discussion in #49724.

@gchanan closed this as completed on Jan 11, 2021
@vadimkantorov (Contributor, Author) commented Jan 11, 2021

Would you like to copy-paste the examples / test cases from my OP here over there, so that they are not forgotten (if they're different)?

@vadimkantorov (Contributor, Author) commented Jun 10, 2022

The problem is that we have no way to differentiate a 0 that comes from "unused" from a 0 that comes from "it is numerically 0". And while, for the first one, we could argue that we are happy for it to remain an "unused" 0 when multiplied with inf, for the second one we definitely want it to become a nan!

Hence my comment above that the big blocker here is that we have no way to differentiate the two today :/

One funny way could be to distinguish 0 and -0 to mean "zero" or "unused" :)
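
For what it is worth, the sign bit is indeed observable even though the two zeros compare equal:

import torch

z = torch.tensor([0.0, -0.0])
print(z[0] == z[1])      # tensor(True): +0.0 and -0.0 compare equal
print(torch.signbit(z))  # tensor([False,  True]): but the sign bit tells them apart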
