
ENH: scipy.special.log_softmax could be 2**126 to 2**1022 times more accurate #19521

Open
JasonGross opened this issue Nov 14, 2023 · 2 comments · May be fixed by #19549
Labels: enhancement (A new feature or improvement), scipy.special

Comments

JasonGross commented Nov 14, 2023

Is your feature request related to a problem? Please describe.

Consider

import numpy as np
import scipy.special

eps = np.finfo(np.float32).eps
print(-scipy.special.log_softmax(np.array([1-np.log(2*eps), 0], dtype=np.float32))) # [1.19209275e-07 1.62492371e+01]
print(-scipy.special.log_softmax(np.array([1-np.log(eps), 0], dtype=np.float32)))   # [-0.       16.942385]

As I understand it, scipy implements log_softmax(x) as x - np.max(x) - np.log(np.sum(np.exp(x - np.max(x)))). However, when the largest value exceeds all the others by enough (by about 16 for float32, about 36 for float64), log_softmax returns exactly 0 at the maximum entry, even though it could give a much more precise answer.
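To make the precision loss concrete, here is a small demonstration (mine, not from the issue) of where the standard formulation rounds the information away while log1p keeps it:

import numpy as np

# At the maximum entry, the standard formulation evaluates log(1 + s), where s
# is the sum of exp(x_j - max(x)) over the *other* entries. Once s drops below
# half of machine epsilon, 1 + s rounds to exactly 1 and log returns exactly 0,
# whereas log1p(s) still returns approximately s.
s = np.float32(1e-10)
print(np.log(np.float32(1) + s))  # 0.0 -- the small term is rounded away
print(np.log1p(s))                # ~1e-10 -- the small term survives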

This came up when a transformer I was training with cross-entropy loss on a classification task had loss dominated by np.finfo(np.float32).eps.

Describe the solution you'd like.

Consider the following code, demonstrating a more accurate log_softmax:

import numpy as np
import scipy.special

def log_softmax_alt(x):
    maxi = np.argmax(x)
    xoffset = x - x[maxi]
    xoffsetexp = np.exp(xoffset)
    # xoffsetexp[maxi] is exactly 1 here, since xoffset[maxi] == 0 and exp(0) == 1
    xoffsetexp[maxi] = 0
    xoffsetexp_sum_m1 = np.sum(xoffsetexp)
    return xoffset - np.log1p(xoffsetexp_sum_m1)


for ty in (np.float32, np.float64):
    smallest_log_softmax, smallest_log_softmax_alt, smallest_log_softmax_val, smallest_log_softmax_alt_val = 0, 0, 0, 0
    for i in range(int(1-np.log2(np.finfo(ty).smallest_subnormal))):
        values = np.array([1 + i, 0], dtype=ty)
        log_softmax_values = scipy.special.log_softmax(values)
        log_softmax_values_alt = log_softmax_alt(values)
        if log_softmax_values[0] != 0: smallest_log_softmax, smallest_log_softmax_val = i, log_softmax_values[0]
        if log_softmax_values_alt[0] != 0: smallest_log_softmax_alt, smallest_log_softmax_alt_val = i, log_softmax_values_alt[0]
        if log_softmax_values[0] == 0 and log_softmax_values_alt[0] == 0: break
    print(f"For {ty}, diff in supported input accuracy is 2**-({smallest_log_softmax_alt} - {smallest_log_softmax}) = 2**-{smallest_log_softmax_alt - smallest_log_softmax}; diff in output accuracy is np.log2({-smallest_log_softmax_val}) - np.log2({-smallest_log_softmax_alt_val}) = {np.log2(-smallest_log_softmax_val) - np.log2(-smallest_log_softmax_alt_val)}")

which outputs the numbers in the title of this issue:
For <class 'numpy.float32'>, diff in supported input accuracy is 2**-(102 - 15) = 2**-87; diff in output accuracy is np.log2(1.1920927533992653e-07) - np.log2(1.401298464324817e-45) = 126.0
For <class 'numpy.float64'>, diff in supported input accuracy is 2**-(744 - 35) = 2**-709; diff in output accuracy is np.log2(2.2204460492503128e-16) - np.log2(5e-324) = 1022.0
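As a quick check (not part of the original report, and continuing from the snippets above), the second input from the first snippet, where -scipy.special.log_softmax printed -0. at the maximum, comes out as roughly eps/e ≈ 4.4e-8 there under the alternative implementation:

eps = np.finfo(np.float32).eps
print(-log_softmax_alt(np.array([1 - np.log(eps), 0], dtype=np.float32)))  # first entry ~4.4e-8 instead of -0.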

Describe alternatives you've considered.

No response

Additional context (e.g. screenshots, GIFs)

I originally posted this as a StackOverflow question.
Companion PyTorch issue: pytorch/pytorch#113708
Companion TensorFlow issue: tensorflow/tensorflow#62400

steppi (Contributor) commented Nov 16, 2023

Thanks @JasonGross, that's pretty clever to take advantage of the fact that $x_i - \max_k x_k = 0$ when $i = \operatorname{argmax}_k x_k$ and then use log1p. I'm just curious, do you have a use case for this? I'd imagine in most situations the difference wouldn't actually matter.
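For concreteness, the identity being used is: writing $m = \max_k x_k$,

$$\log \operatorname{softmax}(x)_i = (x_i - m) - \log\Bigl(1 + \sum_{j \neq \operatorname{argmax}_k x_k} e^{x_j - m}\Bigr),$$

and since the maximal entry contributes exactly $e^0 = 1$ to the sum, the logarithm can be evaluated as log1p of the remaining (possibly tiny) sum.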

At the moment your implementation only handles 1-d arrays and lacks the axis keyword argument. It also doesn't handle non-finite values. Check the current signature and implementation of log_softmax; if you can get your implementation to parity, feel free to submit a PR.
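For reference, here is a rough sketch (not SciPy's actual code and not necessarily what the eventual PR does; the name log_softmax_log1p is made up) of how the log1p variant could be extended to N-d arrays with an axis keyword, with a non-finite-max guard in the same spirit as the current implementation:

import numpy as np

def log_softmax_log1p(x, axis=-1):
    # Hypothetical axis-aware sketch of the log1p-based log_softmax.
    x = np.asarray(x)
    x_max = np.amax(x, axis=axis, keepdims=True)
    # Keep the shift finite when a slice's maximum is inf/-inf/nan.
    x_max = np.where(np.isfinite(x_max), x_max, 0)
    shifted = x - x_max
    exp_shifted = np.exp(shifted)
    # Zero out one maximal entry per slice: it contributes exactly
    # exp(0) == 1 to the sum, which log1p accounts for implicitly.
    idx = np.expand_dims(np.argmax(shifted, axis=axis), axis)
    np.put_along_axis(exp_shifted, idx, 0.0, axis=axis)
    return shifted - np.log1p(np.sum(exp_shifted, axis=axis, keepdims=True))

np.put_along_axis keeps the maximum-masking vectorized over the reduction axis; everything else is the usual shift-and-sum, just with log replaced by log1p.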

JasonGross added a commit to JasonGross/scipy that referenced this issue Nov 17, 2023
By taking advantage of the fact that `x - x_max` is going to be 0 at the
maximum and that `exp(0)` is 1, we can use `log1p` instead of `log` to
increase the accuracy of `log_softmax` at the maximum index by a factor
of about `2**126` (for float32) or about `2**1022` (for float64).

Fixes scipy#19521
JasonGross (Author) commented

In most situations the difference doesn't matter. It came up for me (in PyTorch) when I was overtraining a very small transformer and the loss capped out at lower confidence than it needed to. I'll submit a PR momentarily.

JasonGross added further commits to JasonGross/scipy referencing this issue on Jan 18 and Jan 19, 2024, with the same commit message.