From 49d9f87ec984d596436d3d064de6a908183a427c Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 21 Nov 2025 16:59:54 +0100
Subject: [PATCH 1/3] Numba: fall back to object mode when the Cython signature for a dtype is missing

---
 pytensor/link/numba/dispatch/scalar.py | 18 +++++++++------
 tests/link/numba/test_scalar.py        | 32 ++++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 7 deletions(-)

diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index 50af695a2e..c626f7faf3 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -69,15 +69,19 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
         cython_func = getattr(scipy.special.cython_special, scalar_func_name, None)
         if cython_func is not None:
-            scalar_func_numba = wrap_cython_function(
-                cython_func, output_dtype, input_dtypes
-            )
-            has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch
-            input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes()
-            output_inner_dtype = scalar_func_numba.numpy_output_dtype()
+            try:
+                scalar_func_numba = wrap_cython_function(
+                    cython_func, output_dtype, input_dtypes
+                )
+            except NotImplementedError:
+                pass
+            else:
+                has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch
+                input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes()
+                output_inner_dtype = scalar_func_numba.numpy_output_dtype()
 
     if scalar_func_numba is None:
-        scalar_func_numba = generate_fallback_impl(op, node, **kwargs)
+        return generate_fallback_impl(op, node, **kwargs), None
 
     scalar_op_fn_name = get_name_for_object(scalar_func_numba)
 
     prefix = "x" if scalar_func_name != "x" else "y"
diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py
index 2125d7cc0e..e2665f2eef 100644
--- a/tests/link/numba/test_scalar.py
+++ b/tests/link/numba/test_scalar.py
@@ -1,11 +1,14 @@
 import numpy as np
 import pytest
+import scipy
 
 import pytensor.scalar as ps
 import pytensor.scalar.basic as psb
 import pytensor.scalar.math as psm
 import pytensor.tensor as pt
 from pytensor import config, function
+from pytensor.graph import Apply
+from pytensor.scalar import UnaryScalarOp
 from pytensor.scalar.basic import Composite
 from pytensor.tensor import tensor
 from pytensor.tensor.elemwise import Elemwise
@@ -184,3 +187,32 @@ def test_Softplus(dtype):
             strict=True,
             err_msg=f"Failed for value {value}",
         )
+
+
+def test_cython_obj_mode_fallback():
+    """Test that unsupported Cython signatures fall back to object mode."""
+
+    # Create a ScalarOp whose dtype has no matching Cython signature
+    class IntegerGamma(UnaryScalarOp):
+        # nfunc_spec makes the dispatcher look for a scipy Cython impl
+        nfunc_spec = ("scipy.special.gamma", 1, 1)
+
+        def make_node(self, x):
+            x = psb.as_scalar(x)
+            assert x.dtype == "int64"
+            out = x.type()
+            return Apply(self, [x], [out])
+
+        def impl(self, x):
+            return scipy.special.gamma(x).astype("int64")
+
+    x = pt.scalar("x", dtype="int64")
+    g = Elemwise(IntegerGamma())(x)
+    assert g.type.dtype == "int64"
+
+    with pytest.warns(UserWarning, match="Numba will use object mode"):
+        compare_numba_and_py(
+            [x],
+            [g],
+            [np.array(5, dtype="int64")],
+        )
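
Illustrative usage sketch, not part of the patch: it shows what the fallback means for a compiled function, reusing the IntegerGamma op defined in test_cython_obj_mode_fallback above. pytensor.function, mode="NUMBA", and the "Numba will use object mode" warning are existing PyTensor behaviour; the variable names are made up for the example.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.elemwise import Elemwise

    x = pt.scalar("x", dtype="int64")
    g = Elemwise(IntegerGamma())(x)  # IntegerGamma as defined in the test above
    fn = pytensor.function([x], g, mode="NUMBA")  # warns: "Numba will use object mode"
    assert fn(np.int64(5)) == 24  # gamma(5) = 4! = 24, computed through the Python impl
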
From c76807ce2341593f58e27c5e17559d0494972985 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 21 Nov 2025 18:46:11 +0100
Subject: [PATCH 2/3] Numba: fall back to ScalarOp dispatch for complex erf

---
 pytensor/link/numba/dispatch/scalar.py |  6 +++++-
 tests/link/numba/test_scalar.py        | 11 +++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index c626f7faf3..65c40f3b77 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -278,7 +278,11 @@ def logp1mexp(x):
 
 
 @register_funcify_and_cache_key(Erf)
-def numba_funcify_Erf(op, **kwargs):
+def numba_funcify_Erf(op, node, **kwargs):
+    if node.inputs[0].type.dtype.startswith("complex"):
+        # Complex inputs are not supported by Numba's math.erf
+        return numba_funcify_ScalarOp(op, node=node, **kwargs)
+
     @numba_basic.numba_njit
     def erf(x):
         return math.erf(x)
diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py
index e2665f2eef..3e89db5859 100644
--- a/tests/link/numba/test_scalar.py
+++ b/tests/link/numba/test_scalar.py
@@ -216,3 +216,14 @@ def impl(self, x):
             [g],
             [np.array(5, dtype="int64")],
         )
+
+
+def test_erf_complex():
+    x = pt.scalar("x", dtype="complex128")
+    g = pt.erf(x)
+
+    compare_numba_and_py(
+        [x],
+        [g],
+        [np.array(0.5 + 1j, dtype="complex128")],
+    )
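
Illustrative usage sketch, not part of the patch: with this change, erf on a complex input compiles under the Numba backend by routing through the generic ScalarOp dispatch (the scipy.special Cython implementation where a signature matches, otherwise the object-mode fallback from the previous patch). Variable names are made up for the example.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    z = pt.scalar("z", dtype="complex128")
    fn = pytensor.function([z], pt.erf(z), mode="NUMBA")
    print(fn(np.complex128(0.5 + 1j)))  # agrees with scipy.special.erf(0.5 + 1j)
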
From 682a77a94b2d42d529fd3d348e90daea97a61477 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 25 Nov 2025 14:19:15 +0100
Subject: [PATCH 3/3] Numba: fall back to object mode for non-implemented RVs

Closes https://github.com/pymc-devs/pytensor/issues/1245
---
 pytensor/link/numba/dispatch/basic.py  |  5 ++--
 pytensor/link/numba/dispatch/random.py | 14 ++++++++++++-
 tests/link/numba/test_random.py        | 34 ++++++++++++++++++++++++++++++++--
 3 files changed, 48 insertions(+), 5 deletions(-)

diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py
index 299479af07..07fa376699 100644
--- a/pytensor/link/numba/dispatch/basic.py
+++ b/pytensor/link/numba/dispatch/basic.py
@@ -20,6 +20,7 @@
 )
 from pytensor.scalar.basic import ScalarType
 from pytensor.sparse import SparseTensorType
+from pytensor.tensor.random.type import RandomGeneratorType
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.utils import hash_from_ndarray
 
@@ -129,8 +130,8 @@ def get_numba_type(
             return CSRMatrixType(numba_dtype)
         if pytensor_type.format == "csc":
             return CSCMatrixType(numba_dtype)
-
-        raise NotImplementedError()
+    elif isinstance(pytensor_type, RandomGeneratorType):
+        return numba.types.NumPyRandomGeneratorType("NumPyRandomGeneratorType")
     else:
         raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
 
diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py
index 91298eff45..a20881db7a 100644
--- a/pytensor/link/numba/dispatch/random.py
+++ b/pytensor/link/numba/dispatch/random.py
@@ -16,6 +16,7 @@
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
     direct_cast,
+    generate_fallback_impl,
     numba_funcify,
     register_funcify_and_cache_key,
 )
@@ -406,13 +407,24 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
 
     [rv_node] = op.fgraph.apply_nodes
     rv_op: RandomVariable = rv_node.op
+
+    try:
+        core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
+    except NotImplementedError:
+        py_impl = generate_fallback_impl(rv_op, node=rv_node, **kwargs)
+
+        @numba_basic.numba_njit
+        def fallback_rv(_core_shape, *args):
+            return py_impl(*args)
+
+        return fallback_rv, None
+
     size = rv_op.size_param(rv_node)
     dist_params = rv_op.dist_params(rv_node)
     size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
     core_shape_len = get_vector_length(core_shape)
     inplace = rv_op.inplace
 
-    core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
     nin = 1 + len(dist_params)  # rng + params
     core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
 
diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py
index c7da82b2db..20b1026e07 100644
--- a/tests/link/numba/test_random.py
+++ b/tests/link/numba/test_random.py
@@ -257,7 +257,7 @@ def test_multivariate_normal():
             ],
             pt.as_tensor([3, 2]),
         ),
-        pytest.param(
+        (
             ptr.hypergeometric,
             [
                 (
@@ -274,7 +274,6 @@
                 ),
             ],
             pt.as_tensor([3, 2]),
-            marks=pytest.mark.xfail,  # Not implemented
         ),
         (
             ptr.wald,
@@ -722,3 +721,34 @@ def test_repeated_args():
     final_node = fn.maker.fgraph.outputs[0].owner
     assert isinstance(final_node.op, RandomVariableWithCoreShape)
     assert final_node.inputs[-2] is final_node.inputs[-1]
+
+
+def test_rv_fallback():
+    """Test that random variables can fall back to object mode."""
+
+    class CustomRV(ptr.RandomVariable):
+        name = "custom"
+        signature = "()->()"
+        dtype = "float64"
+
+        def rng_fn(self, rng, value, size=None):
+            # Return the value plus standard normal noise
+            return value + rng.standard_normal(size=size)
+
+    custom_rv = CustomRV()
+
+    rng = shared(np.random.default_rng(123))
+    size = pt.scalar("size", dtype=int)
+    next_rng, x = custom_rv(np.pi, size=(size,), rng=rng).owner.outputs
+
+    fn = function([size], x, updates={rng: next_rng}, mode="NUMBA")
+
+    result1 = fn(1)
+    result2 = fn(1)
+    assert result1.shape == (1,)
+    assert result1 != result2
+
+    large_sample = fn(1000)
+    assert large_sample.shape == (1000,)
+    np.testing.assert_allclose(large_sample.mean(), np.pi, rtol=1e-2)
+    np.testing.assert_allclose(large_sample.std(), 1, rtol=1e-2)
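
Illustrative usage sketch, not part of the patch series: the same fallback is what lets the hypergeometric parametrization above drop its xfail mark. A random variable with no registered Numba implementation now compiles under mode="NUMBA" and draws through NumPy in object mode. The (ngood, nbad, nsample) argument order, the seed, and the variable names below are assumptions made for the example.

    import numpy as np
    import pytensor
    import pytensor.tensor.random as ptr

    rng = pytensor.shared(np.random.default_rng(0))
    draws = ptr.hypergeometric(7, 8, 5, size=(3, 2), rng=rng)
    fn = pytensor.function([], draws, mode="NUMBA")
    assert fn().shape == (3, 2)  # sampled via the object-mode fallback
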