Skip to content

Commit

Permalink
SQUASH ME: attempt to fix slicing in log_softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonGross committed Jan 18, 2024
1 parent eadbb1a commit 6048ec3
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions scipy/special/_logsumexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def log_softmax(x, axis=None):
x = _asarray_validated(x, check_finite=False)

x_argmax = np.argmax(x, axis=axis, keepdims=True)
x_max = x[x_argmax] if x.ndim > 0 else x
x_max = np.take_along_axis(x, x_argmax, axis=axis)

finite_max_mask = np.isfinite(x_max)

Expand All @@ -327,7 +327,9 @@ def log_softmax(x, axis=None):
# we know that exp_tmp at the location of the max is either 1 or infinite,
# depending on finite_max_mask, so we can set it to zero and use log1p
if exp_tmp.ndim > 0:
exp_tmp[x_argmax[finite_max_mask]] = 0
exp_tmp_max = np.take_along_axis(exp_tmp, x_argmax, axis=axis)
exp_tmp_max[finite_max_mask] = 0
np.put_along_axis(exp_tmp, x_argmax, exp_tmp_max, axis=axis)
elif finite_max_mask:
exp_tmp = 0

Expand Down

0 comments on commit 6048ec3

Please sign in to comment.