From 776820dd2a79f703012170ab000c8a3f38f253b8 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Sun, 30 Nov 2025 15:20:06 +0100
Subject: [PATCH 1/2] Numba UnravelIndex: Handle arbitrary indices ndim and F-order

---
 pytensor/link/numba/dispatch/extra_ops.py | 69 +++++++++++++++--------
 pytensor/tensor/extra_ops.py              |  6 +-
 tests/link/numba/test_extra_ops.py        | 38 ++++++++-----
 3 files changed, 71 insertions(+), 42 deletions(-)

diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py
index e4a52b1a32..dee5d9ccdf 100644
--- a/pytensor/link/numba/dispatch/extra_ops.py
+++ b/pytensor/link/numba/dispatch/extra_ops.py
@@ -261,41 +261,60 @@ def unique(x):
 
 @register_funcify_and_cache_key(UnravelIndex)
 def numba_funcify_UnravelIndex(op, node, **kwargs):
-    order = op.order
-
-    if order != "C":
-        raise NotImplementedError(
-            "Numba does not support the `order` argument in `numpy.unravel_index`"
-        )
+    out_ndim = node.outputs[0].type.ndim
 
-    if len(node.outputs) == 1:
-
-        @numba_basic.numba_njit(inline="always")
-        def maybe_expand_dim(arr):
-            return arr
-
-    else:
+    if out_ndim == 0:
+        # Numba cannot build a tuple of 0d arrays without codegen, so fall back to object mode
+        return generate_fallback_impl(op, node=node), None
 
-        @numba_basic.numba_njit(inline="always")
-        def maybe_expand_dim(arr):
-            return np.expand_dims(arr, 1)
+    c_order = op.order == "C"
+    inp_ndim = node.inputs[0].type.ndim
+    transpose_axes = (inp_ndim, *range(inp_ndim))
 
     @numba_basic.numba_njit
-    def unravelindex(arr, shape):
+    def unravelindex(indices, shape):
         a = np.ones(len(shape), dtype=np.int64)
-        a[1:] = shape[:0:-1]
-        a = np.cumprod(a)[::-1]
+        if c_order:
+            # C-order: reverse the shape (dropping dim0), cumulative product, reverse back
+            # Strides: [dim1*dim2, dim2, 1]
+            a[1:] = shape[:0:-1]
+            a = np.cumprod(a)[::-1]
+        else:
+            # F-order: shift the shape right (dropping the last dim), cumulative product
+            # Strides: [1, dim0, dim0*dim1]
+            a[1:] = shape[:-1]
+            a = np.cumprod(a)
+
+        # Broadcast against the strides `a` and `shape` on the last axis
+        unraveled_coords = (indices[..., None] // a) % shape
 
-        # PyTensor actually returns a `tuple` of these values, instead of an
-        # `ndarray`; however, this `ndarray` result should be able to be
-        # unpacked into a `tuple`, so this discrepancy shouldn't really matter
-        return ((maybe_expand_dim(arr) // a) % shape).T
+        # Then move the coordinate axis to the front.
+        # Numba doesn't implement np.moveaxis, so use transpose instead:
+        # res = np.moveaxis(res, -1, 0)
+        unraveled_coords = unraveled_coords.transpose(transpose_axes)
+        # NumPy returns a tuple here, but this array unpacks into multiple
+        # variables with the same effect in the calling function
+        # (the single-output case is handled by a wrapper below)
+        return unraveled_coords
+
+    cache_version = 1
     cache_key = sha256(
-        str((type(op), op.order, len(node.outputs))).encode()
+        str((type(op), op.order, len(node.outputs), cache_version)).encode()
     ).hexdigest()
 
-    return unravelindex, cache_key
+    if len(node.outputs) == 1:
+
+        @numba_basic.numba_njit
+        def unravel_index_single_item(arr, shape):
+            # Unpack the single output
+            (res,) = unravelindex(arr, shape)
+            return res
+
+        return unravel_index_single_item, cache_key
+
+    else:
+        return unravelindex, cache_key
 
 
 @register_funcify_default_op_cache_key(SearchsortedOp)
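For reference, the stride trick the new `unravelindex` kernel relies on can be checked against NumPy directly. The sketch below is plain NumPy rather than Numba code, and `unravel_sketch` is an illustrative name, not part of the patch: per-dimension strides are built with a cumulative product (reversed for C order), and dividing the flat indices by the strides modulo the shape recovers the coordinates for either order.

```python
import numpy as np

def unravel_sketch(indices, shape, c_order=True):
    # Per-dimension strides: for shape (2, 3, 4), C order -> (12, 4, 1), F order -> (1, 2, 6)
    a = np.ones(len(shape), dtype=np.int64)
    if c_order:
        a[1:] = shape[:0:-1]
        a = np.cumprod(a)[::-1]
    else:
        a[1:] = shape[:-1]
        a = np.cumprod(a)
    # Divide by the strides and wrap by the shape, broadcasting on the last axis
    coords = (np.asarray(indices)[..., None] // a) % shape
    # One coordinate array per dimension, matching np.unravel_index's tuple layout
    return np.moveaxis(coords, -1, 0)

idx = np.array([9, 15, 1])
shape = np.array([2, 3, 4])
assert np.array_equal(unravel_sketch(idx, shape), np.stack(np.unravel_index(idx, shape)))
assert np.array_equal(
    unravel_sketch(idx, shape, c_order=False),
    np.stack(np.unravel_index(idx, shape, order="F")),
)
```

The kernel in the diff ends with `transpose(transpose_axes)` instead of the `np.moveaxis` used here, since Numba does not implement the latter.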
diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index 33a5a6b8dc..be2d61cd62 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py
@@ -1304,13 +1304,11 @@ def make_node(self, indices, dims):
         if dims.ndim != 1:
             raise TypeError("dims must be a 1D array")
 
+        out_type = indices.type.clone(dtype="int64")
         return Apply(
             self,
             [indices, dims],
-            [
-                TensorType(dtype="int64", shape=(None,) * indices.type.ndim)()
-                for i in range(ptb.get_vector_length(dims))
-            ],
+            [out_type() for _i in range(ptb.get_vector_length(dims))],
         )
 
     def infer_shape(self, fgraph, node, input_shapes):
diff --git a/tests/link/numba/test_extra_ops.py b/tests/link/numba/test_extra_ops.py
index abc0bf3f59..a2a7128407 100644
--- a/tests/link/numba/test_extra_ops.py
+++ b/tests/link/numba/test_extra_ops.py
@@ -1,4 +1,5 @@
 import contextlib
+from contextlib import nullcontext
 
 import numpy as np
 import pytest
@@ -295,37 +296,48 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
 
 
 @pytest.mark.parametrize(
-    "arr, shape, order, exc",
+    "arr, shape, requires_obj_mode",
     [
+        (
+            (pt.lscalar(), np.array(9, dtype="int64")),
+            pt.as_tensor([2, 3, 4]),
+            True,
+        ),
         (
             (pt.lvector(), np.array([9, 15, 1], dtype="int64")),
             pt.as_tensor([2, 3, 4]),
-            "C",
-            None,
+            False,
         ),
         (
             (pt.lvector(), np.array([1, 0], dtype="int64")),
             pt.as_tensor([2]),
-            "C",
-            None,
+            False,
        ),
         (
-            (pt.lvector(), np.array([9, 15, 1], dtype="int64")),
+            (pt.lmatrix(), np.array([[9, 15, 1], [1, 9, 15]], dtype="int64")),
             pt.as_tensor([2, 3, 4]),
-            "F",
-            NotImplementedError,
+            False,
         ),
     ],
 )
-def test_UnravelIndex(arr, shape, order, exc):
+def test_UnravelIndex(arr, shape, requires_obj_mode):
     arr, test_arr = arr
-    g = extra_ops.UnravelIndex(order)(arr, shape)
-
-    cm = contextlib.suppress() if exc is None else pytest.raises(exc)
+    g_c = extra_ops.UnravelIndex("C")(arr, shape)
+    g_f = extra_ops.UnravelIndex("F")(arr, shape)
+    if shape.type.shape == (1,):
+        outputs = [g_c, g_f]
+    else:
+        outputs = [*g_c, *g_f]
+
+    cm = (
+        pytest.warns(UserWarning, match="object mode")
+        if requires_obj_mode
+        else nullcontext()
+    )
     with cm:
         compare_numba_and_py(
             [arr],
-            g,
+            outputs,
             [test_arr],
         )
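The `make_node` change above leans on an invariant of `np.unravel_index` that the new Numba kernel reproduces: each coordinate output has the same shape as the indices input, so cloning `indices.type` with an `int64` dtype is sufficient. A quick NumPy check of that invariant (illustrative, not part of the patch):

```python
import numpy as np

# 2D indices, as exercised by the new pt.lmatrix() test case above
idx = np.array([[9, 15, 1], [1, 9, 15]])
coords = np.unravel_index(idx, (2, 3, 4))

# One coordinate array per dimension, each shaped like `idx`
assert len(coords) == 3
assert all(c.shape == idx.shape for c in coords)
```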
From 57ab68a49667649c24e39a97ee58b0b5f8b472d8 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 3 Dec 2025 16:37:09 +0100
Subject: [PATCH 2/2] Numba RavelMultiIndex: Handle arbitrary indices ndim and F-order

---
 pytensor/link/numba/dispatch/extra_ops.py | 98 ++++++++++-------------
 pytensor/tensor/extra_ops.py              | 10 +--
 tests/link/numba/test_extra_ops.py        | 42 +++++++---
 3 files changed, 75 insertions(+), 75 deletions(-)

diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py
index dee5d9ccdf..fc00a56447 100644
--- a/pytensor/link/numba/dispatch/extra_ops.py
+++ b/pytensor/link/numba/dispatch/extra_ops.py
@@ -135,65 +135,49 @@ def filldiagonaloffset(a, val, offset):
 def numba_funcify_RavelMultiIndex(op, node, **kwargs):
     mode = op.mode
     order = op.order
+    vec_indices = node.inputs[0].type.ndim > 0
 
-    if order != "C":
-        raise NotImplementedError(
-            "Numba does not implement `order` in `numpy.ravel_multi_index`"
-        )
-
-    if mode == "raise":
-
-        @numba_basic.numba_njit
-        def mode_fn(*args):
-            raise ValueError("invalid entry in coordinates array")
-
-    elif mode == "wrap":
-
-        @numba_basic.numba_njit(inline="always")
-        def mode_fn(new_arr, i, j, v, d):
-            new_arr[i, j] = v % d
-
-    elif mode == "clip":
-
-        @numba_basic.numba_njit(inline="always")
-        def mode_fn(new_arr, i, j, v, d):
-            new_arr[i, j] = min(max(v, 0), d - 1)
-
-    if node.inputs[0].ndim == 0:
-
-        @numba_basic.numba_njit
-        def ravelmultiindex(*inp):
-            shape = inp[-1]
-            arr = np.stack(inp[:-1])
-
-            new_arr = arr.T.astype(np.float64).copy()
-            for i, b in enumerate(new_arr):
-                if b < 0 or b >= shape[i]:
-                    mode_fn(new_arr, i, 0, b, shape[i])
-
-            a = np.ones(len(shape), dtype=np.float64)
-            a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
-            return np.array(a.dot(new_arr.T), dtype=np.int64)
-
-    else:
+    @numba_basic.numba_njit
+    def ravelmultiindex(*inp):
+        shape = inp[-1]
+        # Stack the per-dimension indices along a new last axis
+        stacked_indices = np.stack(inp[:-1], axis=-1)
+
+        # Handle invalid indices according to `mode`
+        for i, dim_limit in enumerate(shape):
+            if mode == "wrap":
+                stacked_indices[..., i] %= dim_limit
+            elif mode == "clip":
+                dim_indices = stacked_indices[..., i]
+                stacked_indices[..., i] = np.clip(dim_indices, 0, dim_limit - 1)
+            else:  # raise
+                dim_indices = stacked_indices[..., i]
+                invalid_indices = (dim_indices < 0) | (dim_indices >= shape[i])
+                # Numba cannot call np.any on a boolean scalar
+                if vec_indices:
+                    invalid_indices = invalid_indices.any()
+                if invalid_indices:
+                    raise ValueError("invalid entry in coordinates array")
+
+        # Calculate strides based on order
+        a = np.ones(len(shape), dtype=np.int64)
+        if order == "C":
+            # C-order: last dimension moves fastest (strides: large -> small -> 1)
+            # For shape (3, 4, 5) the multipliers are (20, 5, 1)
+            if len(shape) > 1:
+                a[:-1] = np.cumprod(shape[:0:-1])[::-1]
+        else:  # order == "F"
+            # F-order: first dimension moves fastest (strides: 1 -> small -> large)
+            # For shape (3, 4, 5) the multipliers are (1, 3, 12)
+            if len(shape) > 1:
+                a[1:] = np.cumprod(shape[:-1])
+
+        # Dot product of indices and strides, written as multiply-and-sum because
+        # Numba's matmul supports neither arbitrary left-operand ndim nor integer dtypes
+        return np.asarray((stacked_indices * a).sum(-1))
 
-        @numba_basic.numba_njit
-        def ravelmultiindex(*inp):
-            shape = inp[-1]
-            arr = np.stack(inp[:-1])
-
-            new_arr = arr.T.astype(np.float64).copy()
-            for i, b in enumerate(new_arr):
-                # no strict argument to this zip because numba doesn't support it
-                for j, (d, v) in enumerate(zip(shape, b)):
-                    if v < 0 or v >= d:
-                        mode_fn(new_arr, i, j, v, d)
-
-            a = np.ones(len(shape), dtype=np.float64)
-            a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
-            return a.dot(new_arr.T).astype(np.int64)
-
-    return ravelmultiindex
+    cache_version = 1
    return ravelmultiindex, cache_version
 
 
 @register_funcify_default_op_cache_key(Repeat)
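As with `UnravelIndex`, the new `ravelmultiindex` kernel can be sanity-checked against NumPy outside of Numba. The sketch below mirrors its logic under an illustrative name (`ravel_sketch` is not part of the patch): stack the per-dimension indices on a trailing axis, apply `mode`, then multiply by the order-dependent strides and sum.

```python
import numpy as np

def ravel_sketch(multi_index, shape, mode="raise", order="C"):
    # One row of coordinates per flat index, dimensions on the last axis
    stacked = np.stack(multi_index, axis=-1)
    for i, dim in enumerate(shape):
        if mode == "wrap":
            stacked[..., i] %= dim
        elif mode == "clip":
            stacked[..., i] = np.clip(stacked[..., i], 0, dim - 1)
        elif ((stacked[..., i] < 0) | (stacked[..., i] >= dim)).any():
            raise ValueError("invalid entry in coordinates array")

    # Strides: C order multiplies products of trailing dims, F order of leading dims
    a = np.ones(len(shape), dtype=np.int64)
    if len(shape) > 1:
        if order == "C":
            a[:-1] = np.cumprod(shape[:0:-1])[::-1]
        else:
            a[1:] = np.cumprod(shape[:-1])
    return (stacked * a).sum(-1)

mi = (np.array([0, 1]), np.array([2, 0]), np.array([1, 3]))
shape = np.array([2, 3, 4])
assert np.array_equal(ravel_sketch(mi, shape), np.ravel_multi_index(mi, shape))
assert np.array_equal(
    ravel_sketch(mi, shape, order="F"), np.ravel_multi_index(mi, shape, order="F")
)
```

The multiply-and-sum at the end stands in for a dot product of coordinates with strides, which is also why the kernel in the diff avoids `np.dot`: Numba's matmul accepts neither integer dtypes nor an arbitrary-ndim left operand.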
diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index be2d61cd62..0c6e59d72f 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py
@@ -1371,8 +1371,7 @@ def __init__(self, mode="raise", order="C"):
         self.order = order
 
     def make_node(self, *inp):
-        multi_index = [ptb.as_tensor_variable(i) for i in inp[:-1]]
-        dims = ptb.as_tensor_variable(inp[-1])
+        *multi_index, dims = map(ptb.as_tensor_variable, inp)
 
         for i in multi_index:
             if i.dtype not in int_dtypes:
@@ -1382,19 +1381,20 @@ def make_node(self, *inp):
         if dims.ndim != 1:
             raise TypeError("dims must be a 1D array")
 
+        out_type = multi_index[0].type.clone(dtype="int64")
         return Apply(
             self,
             [*multi_index, dims],
-            [TensorType(dtype="int64", shape=(None,) * multi_index[0].type.ndim)()],
+            [out_type()],
         )
 
     def infer_shape(self, fgraph, node, input_shapes):
         return [input_shapes[0]]
 
     def perform(self, node, inp, out):
-        multi_index, dims = inp[:-1], inp[-1]
+        *multi_index, dims = inp
         res = np.ravel_multi_index(multi_index, dims, mode=self.mode, order=self.order)
-        out[0][0] = np.asarray(res, node.outputs[0].dtype)
+        out[0][0] = np.asarray(res, "int64")
 
 
 def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
diff --git a/tests/link/numba/test_extra_ops.py b/tests/link/numba/test_extra_ops.py
index a2a7128407..813ec7765d 100644
--- a/tests/link/numba/test_extra_ops.py
+++ b/tests/link/numba/test_extra_ops.py
@@ -7,6 +7,7 @@
 import pytensor.tensor as pt
 from pytensor import config
 from pytensor.tensor import extra_ops
+from pytensor.tensor.extra_ops import RavelMultiIndex
 from tests.link.numba.test_basic import compare_numba_and_py
 
 
@@ -133,35 +134,34 @@ def test_FillDiagonalOffset(a, val, offset):
 
 
 @pytest.mark.parametrize(
-    "arr, shape, mode, order, exc",
+    "arr, shape, mode, exc",
     [
         (
             tuple((pt.lscalar(), v) for v in np.array([0])),
             (pt.lvector(), np.array([2])),
             "raise",
-            "C",
             None,
         ),
         (
             tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])),
             (pt.lvector(), np.array([2, 3, 4])),
             "raise",
-            "C",
             None,
         ),
         (
             tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])),
             (pt.lvector(), np.array([2, 3, 4])),
             "raise",
-            "C",
             None,
         ),
         (
-            tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])),
+            tuple(
+                (pt.lmatrix(), np.broadcast_to(v, (3, 2)).copy())
+                for v in np.array([[0, 1], [2, 0], [1, 3]])
+            ),
             (pt.lvector(), np.array([2, 3, 4])),
             "raise",
-            "F",
-            NotImplementedError,
+            None,
         ),
         (
             tuple(
@@ -169,7 +169,6 @@ def test_FillDiagonalOffset(a, val, offset):
             ),
             (pt.lvector(), np.array([2, 3, 4])),
             "raise",
-            "C",
             ValueError,
         ),
         (
@@ -178,7 +177,15 @@ def test_FillDiagonalOffset(a, val, offset):
             ),
             (pt.lvector(), np.array([2, 3, 4])),
             "wrap",
-            "C",
+            None,
+        ),
+        (
+            tuple(
+                (pt.ltensor3(), np.broadcast_to(v, (2, 2, 3)).copy())
+                for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
+            ),
+            (pt.lvector(), np.array([2, 3, 4])),
+            "wrap",
             None,
         ),
         (
@@ -187,21 +194,30 @@ def test_FillDiagonalOffset(a, val, offset):
             ),
             (pt.lvector(), np.array([2, 3, 4])),
             "clip",
-            "C",
+            None,
+        ),
+        (
+            tuple(
+                (pt.lmatrix(), np.broadcast_to(v, (2, 3)).copy())
+                for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
+            ),
+            (pt.lvector(), np.array([2, 3, 4])),
+            "clip",
             None,
         ),
     ],
 )
-def test_RavelMultiIndex(arr, shape, mode, order, exc):
+def test_RavelMultiIndex(arr, shape, mode, exc):
     arr, test_arr = zip(*arr, strict=True)
     shape, test_shape = shape
-    g = extra_ops.RavelMultiIndex(mode, order)(*arr, shape)
+    g_c = RavelMultiIndex(mode, order="C")(*arr, shape)
+    g_f = RavelMultiIndex(mode, order="F")(*arr, shape)
 
     cm = contextlib.suppress() if exc is None else pytest.raises(exc)
     with cm:
         compare_numba_and_py(
             [*arr, shape],
-            g,
+            [g_c, g_f],
             [*test_arr, test_shape],
         )
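Taken together, the two patches make the two ops inverses of each other under the Numba backend for any order. A minimal end-to-end sketch of what the tests exercise, assuming the usual `pytensor.function` API with `mode="NUMBA"` (the variable names here are illustrative):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor import extra_ops

idx = pt.lvector("idx")
dims = pt.as_tensor([2, 3, 4])

# Unravel flat indices in F order, then ravel the coordinates back in F order
coords = extra_ops.UnravelIndex("F")(idx, dims)
flat = extra_ops.RavelMultiIndex("raise", "F")(*coords, dims)

fn = pytensor.function([idx], [*coords, flat], mode="NUMBA")
# Prints the three F-order coordinate arrays, then the round-tripped flat indices
print(fn(np.array([9, 15, 1])))
```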