Skip to content

Commit

Permalink
ENH: log_softmax supports tuple axis again
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonGross committed Jan 19, 2024
1 parent e573855 commit 1c0dca0
Showing 1 changed file with 26 additions and 14 deletions.
40 changes: 26 additions & 14 deletions scipy/special/_logsumexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,15 @@ def log_softmax(x, axis=None):

x = _asarray_validated(x, check_finite=False)

x_argmax = np.argmax(x, axis=axis, keepdims=True)
# work around https://github.com/numpy/numpy/issues/25622
if axis is None:
x_argmax = x_argmax.flatten()
x_max = np.take_along_axis(x, x_argmax, axis=axis)
# work around https://github.com/numpy/numpy/issues/25623
if isinstance(axis, tuple):
x_max = np.amax(x, axis=axis, keepdims=True)
else:
x_argmax = np.argmax(x, axis=axis, keepdims=True)
# work around https://github.com/numpy/numpy/issues/25622
if axis is None:
x_argmax = x_argmax.flatten()
x_max = np.take_along_axis(x, x_argmax, axis=axis)

finite_max_mask = np.isfinite(x_max)

Expand All @@ -327,19 +331,27 @@ def log_softmax(x, axis=None):
tmp = x - x_max
exp_tmp = np.exp(tmp)

# we know that exp_tmp at the location of the max is either 1 or infinite,
# depending on finite_max_mask, so we can set it to zero and use log1p
if exp_tmp.ndim > 0:
exp_tmp_max = np.take_along_axis(exp_tmp, x_argmax, axis=axis)
exp_tmp_max[finite_max_mask] = 0
np.put_along_axis(exp_tmp, x_argmax, exp_tmp_max, axis=axis)
elif finite_max_mask:
exp_tmp = np.zeros_like(exp_tmp)
# work around https://github.com/numpy/numpy/issues/25623
if isinstance(axis, tuple):
pass
else:
# we know that exp_tmp at the location of the max is either 1 or infinite,
# depending on finite_max_mask, so we can set it to zero and use log1p
if exp_tmp.ndim > 0:
exp_tmp_max = np.take_along_axis(exp_tmp, x_argmax, axis=axis)
exp_tmp_max[finite_max_mask] = 0
np.put_along_axis(exp_tmp, x_argmax, exp_tmp_max, axis=axis)
elif finite_max_mask:
exp_tmp = np.zeros_like(exp_tmp)

# suppress warnings about log of zero
with np.errstate(divide='ignore'):
s = np.sum(exp_tmp, axis=axis, keepdims=True)
out = np.log1p(s)
# work around https://github.com/numpy/numpy/issues/25623
if isinstance(axis, tuple):
out = np.log(s)
else:
out = np.log1p(s)

out = tmp - out
return out

0 comments on commit 1c0dca0

Please sign in to comment.