From 256c2007f4c9c58b3d9a5c232c26de73c5175e46 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 18 Nov 2025 10:39:30 +0100 Subject: [PATCH] Numba Pow: Fix failure with discrete integer exponents Workaround for: https://github.com/numba/numba/issues/9554 --- pytensor/link/numba/dispatch/scalar.py | 18 ++++++++++++++++++ tests/link/numba/test_scalar.py | 10 ++++++++++ 2 files changed, 28 insertions(+) diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 65c40f3b77..2e40dd4b34 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -22,6 +22,7 @@ Composite, Identity, Mul, + Pow, Reciprocal, ScalarOp, Second, @@ -165,6 +166,23 @@ def {binary_op_name}({input_signature}): return nary_fn +@register_funcify_and_cache_key(Pow) +def numba_funcify_Pow(op, node, **kwargs): + pow_dtype = node.inputs[1].type.dtype + if pow_dtype.startswith("int"): + # Numba power fails when exponents are non 64-bit discrete integers and fasthmath=True + # https://github.com/numba/numba/issues/9554 + + def pow(x, y): + return x ** np.asarray(y, dtype=np.int64).item() + else: + + def pow(x, y): + return x**y + + return numba_basic.numba_njit(pow), scalar_op_cache_key(op) + + @register_funcify_and_cache_key(Add) def numba_funcify_Add(op, node, **kwargs): nary_add_fn = binary_to_nary_func(node.inputs, "add", "+") diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py index 3e89db5859..6145d2da1b 100644 --- a/tests/link/numba/test_scalar.py +++ b/tests/link/numba/test_scalar.py @@ -189,6 +189,16 @@ def test_Softplus(dtype): ) +def test_discrete_power(): + # Test we don't fail to compile power with discrete exponents due to https://github.com/numba/numba/issues/9554 + x = pt.scalar("x", dtype="float64") + exponent = pt.scalar("exponent", dtype="int8") + out = pt.power(x, exponent) + compare_numba_and_py( + [x, exponent], [out], [np.array(0.5), np.array(2, dtype="int8")] + ) + + def test_cython_obj_mode_fallback(): """Test that unsupported cython signatures fallback to obj-mode"""