Skip to content

Commit

Permalink
Fix SoftmaxGrad failure with constant dy in numba backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and twiecki committed Sep 12, 2023
1 parent fc5e10f commit da66c2e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
7 changes: 6 additions & 1 deletion pytensor/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,12 @@ def softmax_grad_py_fn(dy, sm):
dx = dy_times_sm - sum_dy_times_sm * sm
return dx

softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
# The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
# softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
softmax_grad = numba_njit(
boundscheck=False,
fastmath=config.numba__fastmath,
)(softmax_grad_py_fn)

return softmax_grad

Expand Down
10 changes: 10 additions & 0 deletions tests/link/numba/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,16 @@ def test_SoftmaxGrad(dy, sm, axis, exc):
)


def test_SoftMaxGrad_constant_dy():
    """Regression test: SoftmaxGrad must compile in the numba backend when dy is a constant.

    A constant ``dy`` arrives with a readonly array, which previously made the
    signature inferred by ``jit_compile_reducer`` invalid.
    """
    constant_dy = at.constant(np.zeros((3,), dtype=config.floatX))
    softmax_output = at.vector(shape=(3,))

    grad_expr = SoftmaxGrad(axis=None)(constant_dy, softmax_output)

    compare_numba_and_py(
        FunctionGraph(outputs=[grad_expr]),
        [np.ones((3,), dtype=config.floatX)],
    )


@pytest.mark.parametrize(
"x, axis, exc",
[
Expand Down

0 comments on commit da66c2e

Please sign in to comment.