ENH: improve scipy.special.log_softmax accuracy
Since `x - x_max` is exactly 0 at the maximum and `exp(0)` is 1, we can drop
that known term from the sum and use `log1p` instead of `log`, improving the
accuracy of `log_softmax` at the maximum index by a factor of about `2**126`
(for float32) or about `2**1022` (for float64).

Fixes #19521
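
To see concretely what the message describes, here is a minimal sketch (not part of the commit) contrasting the plain `log` formulation with the `log1p` one on the float32 example from the docstring below:

```python
import numpy as np

subnormal32 = np.finfo(np.float32).smallest_subnormal
x = np.array([0, np.log(subnormal32)], dtype=np.float32)
tmp = x - x.max()                      # exactly 0 at the maximum index

# Plain log: the maximum contributes exp(0) == 1 to the sum, and float32
# rounds 1 + 1.4e-45 back to 1.0, so the result at the maximum collapses to 0.
old = tmp - np.log(np.sum(np.exp(tmp)))

# log1p: drop the known 1 from the sum, then evaluate log(1 + rest) with
# log1p, which stays accurate for arguments all the way down to subnormals.
rest = np.exp(tmp)
rest[np.argmax(x)] = 0.0
new = tmp - np.log1p(np.sum(rest))

print(old[0])  # 0.0
print(new[0])  # ~ -1.4e-45, matching the improved log_softmax
```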
JasonGross committed Jan 18, 2024
1 parent 0a7631f commit 67cbf91
Showing 2 changed files with 50 additions and 5 deletions.
47 changes: 42 additions & 5 deletions scipy/special/_logsumexp.py
```diff
@@ -284,24 +284,61 @@ def log_softmax(x, axis=None):
     >>> y
     array([  0., -inf])
+    >>> subnormal32 = np.finfo(np.float32).smallest_subnormal
+    >>> x = np.array([0, np.log(subnormal32)], dtype=np.float32)
+    >>> y = log_softmax(x)
+    >>> y
+    array([-1.40130e-45, -1.03279e+02], dtype=float32)
+    >>> with np.errstate(divide='ignore'):
+    ...     y = np.log(softmax(x))
+    ...
+    >>> y
+    array([   0.     , -103.27893], dtype=float32)
+    >>> subnormal64 = np.finfo(np.float64).smallest_subnormal
+    >>> x = np.array([0, np.log(subnormal64)], dtype=np.float64)
+    >>> y = log_softmax(x)
+    >>> y
+    array([-4.9407e-324, -7.4444e+002])
+    >>> with np.errstate(divide='ignore'):
+    ...     y = np.log(softmax(x))
+    ...
+    >>> y
+    array([   0.     , -744.44007])
     """

     x = _asarray_validated(x, check_finite=False)

-    x_max = np.amax(x, axis=axis, keepdims=True)
+    x_argmax = np.argmax(x, axis=axis, keepdims=True)
+    # work around https://github.com/numpy/numpy/issues/25622
+    if axis is None: x_argmax = x_argmax.flatten()
+    x_max = np.take_along_axis(x, x_argmax, axis=axis)
+
+    finite_max_mask = np.isfinite(x_max)

     if x_max.ndim > 0:
-        x_max[~np.isfinite(x_max)] = 0
-    elif not np.isfinite(x_max):
-        x_max = 0
+        x_max[~finite_max_mask] = 0
+    elif not finite_max_mask:
+        x_max = np.zeros_like(x_max)

     tmp = x - x_max
     exp_tmp = np.exp(tmp)

+    # we know that exp_tmp at the location of the max is either 1 or infinite,
+    # depending on finite_max_mask, so we can set it to zero and use log1p
+    if exp_tmp.ndim > 0:
+        exp_tmp_max = np.take_along_axis(exp_tmp, x_argmax, axis=axis)
+        exp_tmp_max[finite_max_mask] = 0
+        np.put_along_axis(exp_tmp, x_argmax, exp_tmp_max, axis=axis)
+    elif finite_max_mask:
+        exp_tmp = np.zeros_like(exp_tmp)
+
     # suppress warnings about log of zero
     with np.errstate(divide='ignore'):
         s = np.sum(exp_tmp, axis=axis, keepdims=True)
-        out = np.log(s)
+        out = np.log1p(s)

     out = tmp - out
     return out
```
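
One detail worth a note outside the diff: the `numpy/numpy#25622` workaround is needed because with `axis=None`, `np.argmax(..., keepdims=True)` returns an index shaped like the input rather than the 1-D flat index that `np.take_along_axis(..., axis=None)` expects. A small illustration (my sketch, not part of the commit):

```python
import numpy as np

x = np.array([[1.0, 3.0],
              [2.0, 0.5]])

# With an explicit axis, argmax(..., keepdims=True) pairs directly
# with take_along_axis:
idx = np.argmax(x, axis=1, keepdims=True)            # shape (2, 1)
print(np.take_along_axis(x, idx, axis=1))            # [[3.], [2.]]

# With axis=None, keepdims=True yields shape (1, 1), but
# take_along_axis(..., axis=None) requires a 1-D flat index --
# hence the .flatten() in the patch above.
idx = np.argmax(x, axis=None, keepdims=True)         # shape (1, 1)
print(np.take_along_axis(x, idx.flatten(), axis=None))  # [3.]
```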
8 changes: 8 additions & 0 deletions scipy/special/tests/test_log_softmax.py
```diff
@@ -8,6 +8,14 @@

 @pytest.mark.parametrize('x, expected', [
     (np.array([1000, 1]), np.array([0, -999])),
+    # we shouldn't return zero on the smallest subnormal input
+    (np.array([-np.log(np.finfo(np.float32).smallest_subnormal), 0], dtype=np.float32),
+     np.array([float.fromhex('-0x1.00000p-149'), float.fromhex('-0x1.9d1dap+6')],
+              dtype=np.float32)),
+    (np.array([-np.log(np.finfo(np.float64).smallest_subnormal), 0], dtype=np.float64),
+     np.array([float.fromhex('-0x0.0000000000001p-1022'),
+               float.fromhex('-0x1.74385446d71c3p+9')],
+              dtype=np.float64)),
     # Expected value computed using mpmath (with mpmath.mp.dps = 200) and then
     # converted to float.
```
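
For reference, expected values of this kind can be regenerated roughly as follows; this is a sketch of the mpmath recipe the test comment describes, not the project's actual script:

```python
import mpmath
import numpy as np

mpmath.mp.dps = 200  # 200 decimal digits, as noted in the comment above

def log_softmax_ref(values):
    # Exact-precision log_softmax: x_i - log(sum_j exp(x_j)).
    xs = [mpmath.mpf(float(v)) for v in values]
    lse = mpmath.log(mpmath.fsum(mpmath.exp(v) for v in xs))
    return [v - lse for v in xs]

x0 = -np.log(np.finfo(np.float64).smallest_subnormal)  # ~744.44
ref = [float(r) for r in log_softmax_ref([x0, 0.0])]
print([r.hex() for r in ref])
# should print the float64 expectations above:
# ['-0x0.0000000000001p-1022', '-0x1.74385446d71c3p+9']
```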
