ENH: improve scipy.special.log_softmax accuracy
By taking advantage of the fact that `x - x_max` is exactly 0 at the
maximum and that `exp(0)` is 1, we can drop the maximum's term from the
sum and use `log1p` instead of `log`, increasing the accuracy of
`log_softmax` at the maximum index by a factor of about `2**126` (for
float32) or about `2**1022` (for float64).

Fixes #19521
JasonGross committed Nov 17, 2023
1 parent f3cda9d commit e24a728
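
To see the failure mode this fixes, here is a small standalone demonstration (my own sketch, not part of the commit) of the float32 example from the updated docstring, comparing the plain `log` formulation against the `log1p` variant:

```python
import numpy as np

# Input from the new docstring: the second entry's softmax probability
# is the float32 smallest subnormal (~1.4e-45), so its log is ~ -103.28.
x = np.array([0, np.log(np.finfo(np.float32).smallest_subnormal)],
             dtype=np.float32)
tmp = x - x.max()  # the max is 0, so tmp == x

# Plain formulation: the max term contributes exp(0) == 1, the tiny second
# term is rounded away in the sum, and log(1.0) == 0 exactly.
naive = tmp[0] - np.log(np.exp(tmp).sum())
print(naive)   # 0.0 -- all information at the max index is lost

# log1p variant: leave the max term (known to be exactly 1) out of the sum
# and let log1p account for it; the tiny term now survives.
better = tmp[0] - np.log1p(np.exp(tmp[1]))
print(better)  # ~ -1.4013e-45, matching the patched log_softmax output
```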
Showing 2 changed files with 42 additions and 4 deletions.
41 changes: 37 additions & 4 deletions scipy/special/_logsumexp.py
@@ -275,24 +275,57 @@ def log_softmax(x, axis=None):
  >>> y
  array([ 0., -inf])
+ >>> x = np.array([0, np.log(np.finfo(np.float32).smallest_subnormal)], dtype=np.float32)
+ >>> y = log_softmax(x)
+ >>> y
+ array([-1.40130e-45, -1.03279e+02], dtype=float32)
+ >>> with np.errstate(divide='ignore'):
+ ...     y = np.log(softmax(x))
+ ...
+ >>> y
+ array([ 0. , -103.27893], dtype=float32)
+ >>> x = np.array([0, np.log(np.finfo(np.float64).smallest_subnormal)], dtype=np.float64)
+ >>> y = log_softmax(x)
+ >>> y
+ array([-4.9407e-324, -7.4444e+002])
+ >>> with np.errstate(divide='ignore'):
+ ...     y = np.log(softmax(x))
+ ...
+ >>> y
+ array([ 0. , -744.44007])
  """

  x = _asarray_validated(x, check_finite=False)

- x_max = np.amax(x, axis=axis, keepdims=True)
+ x_argmax = np.argmax(x, axis=axis, keepdims=True)
+ x_max = x[x_argmax] if x.ndim > 0 else x

+ finite_max_mask = np.isfinite(x_max)

  if x_max.ndim > 0:
-     x_max[~np.isfinite(x_max)] = 0
- elif not np.isfinite(x_max):
+     x_max[~finite_max_mask] = 0
+ elif not finite_max_mask:
      x_max = 0

  tmp = x - x_max
  exp_tmp = np.exp(tmp)

+ # we know that exp_tmp at the location of the max is either 1 or infinite,
+ # depending on finite_max_mask, so we can set it to zero and use log1p
+ if exp_tmp.ndim > 0:
+     exp_tmp[x_argmax[finite_max_mask]] = 0
+ elif finite_max_mask:
+     exp_tmp = 0

  # suppress warnings about log of zero
  with np.errstate(divide='ignore'):
      s = np.sum(exp_tmp, axis=axis, keepdims=True)
-     out = np.log(s)
+     out = np.log1p(s)

  out = tmp - out
  return out
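
Read as a whole, the new code computes the usual shifted sum but zeroes out the contribution of the (finite) maximum before summing, then re-introduces it through `log1p`. Below is a minimal 1-D re-implementation of that idea, using a hypothetical helper name `log_softmax_1d` (a sketch of the diff above, not the SciPy source; the real function additionally handles the `axis` argument and n-dimensional arrays):

```python
import numpy as np

def log_softmax_1d(x):
    # Sketch of the patched algorithm for 1-D input only (hypothetical
    # helper, not the SciPy implementation).
    x = np.asarray(x)
    i_max = np.argmax(x)
    finite_max = np.isfinite(x[i_max])
    x_max = x[i_max] if finite_max else 0

    tmp = x - x_max
    exp_tmp = np.exp(tmp)

    # For a finite max, the term at i_max is exp(0) == 1 exactly; remove it
    # from the sum and let log1p add the 1 back at full precision.
    if finite_max:
        exp_tmp[i_max] = 0
    return tmp - np.log1p(np.sum(exp_tmp))

x = np.array([0, np.log(np.finfo(np.float32).smallest_subnormal)],
             dtype=np.float32)
print(log_softmax_1d(x))  # ~[-1.40130e-45, -1.03279e+02], as in the docstring
```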
5 changes: 5 additions & 0 deletions scipy/special/tests/test_log_softmax.py
@@ -8,6 +8,11 @@

  @pytest.mark.parametrize('x, expected', [
      (np.array([1000, 1]), np.array([0, -999])),
+     # we shouldn't return zero on the smallest subnormal input
+     (np.array([-np.log(np.finfo(np.float32).smallest_subnormal), 0], dtype=np.float32),
+      np.array([float.fromhex('-0x1.00000p-149'), float.fromhex('-0x1.9d1dap+6')], dtype=np.float32)),
+     (np.array([-np.log(np.finfo(np.float64).smallest_subnormal), 0], dtype=np.float64),
+      np.array([float.fromhex('-0x0.0000000000001p-1022'), float.fromhex('-0x1.74385446d71c3p+9')], dtype=np.float64)),
      # Expected value computed using mpmath (with mpmath.mp.dps = 200) and then
      # converted to float.
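
The hex float literals pin down the expected values bit-for-bit. As a quick sanity check (mine, not part of the test suite), they encode exactly the smallest subnormals the commit message refers to:

```python
import numpy as np

# -0x1.00000p-149 is -(2**-149): the negative of the float32 smallest
# subnormal, i.e. the value log_softmax now preserves at the max index.
assert float.fromhex('-0x1.00000p-149') == -np.finfo(np.float32).smallest_subnormal

# -0x0.0000000000001p-1022 is -(2**-1074): the float64 counterpart.
assert float.fromhex('-0x0.0000000000001p-1022') == -np.finfo(np.float64).smallest_subnormal
```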
