165 changes: 84 additions & 81 deletions pytensor/link/numba/dispatch/extra_ops.py
@@ -135,65 +135,49 @@ def filldiagonaloffset(a, val, offset):
def numba_funcify_RavelMultiIndex(op, node, **kwargs):
mode = op.mode
order = op.order
vec_indices = node.inputs[0].type.ndim > 0

if order != "C":
raise NotImplementedError(
"Numba does not implement `order` in `numpy.ravel_multi_index`"
)

if mode == "raise":

@numba_basic.numba_njit
def mode_fn(*args):
raise ValueError("invalid entry in coordinates array")

elif mode == "wrap":

@numba_basic.numba_njit(inline="always")
def mode_fn(new_arr, i, j, v, d):
new_arr[i, j] = v % d

elif mode == "clip":

@numba_basic.numba_njit(inline="always")
def mode_fn(new_arr, i, j, v, d):
new_arr[i, j] = min(max(v, 0), d - 1)

if node.inputs[0].ndim == 0:

@numba_basic.numba_njit
def ravelmultiindex(*inp):
shape = inp[-1]
arr = np.stack(inp[:-1])

new_arr = arr.T.astype(np.float64).copy()
for i, b in enumerate(new_arr):
if b < 0 or b >= shape[i]:
mode_fn(new_arr, i, 0, b, shape[i])

a = np.ones(len(shape), dtype=np.float64)
a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
return np.array(a.dot(new_arr.T), dtype=np.int64)

else:

@numba_basic.numba_njit
def ravelmultiindex(*inp):
shape = inp[-1]
arr = np.stack(inp[:-1])

new_arr = arr.T.astype(np.float64).copy()
for i, b in enumerate(new_arr):
# no strict argument to this zip because numba doesn't support it
for j, (d, v) in enumerate(zip(shape, b)):
if v < 0 or v >= d:
mode_fn(new_arr, i, j, v, d)
@numba_basic.numba_njit
def ravelmultiindex(*inp):
shape = inp[-1]
# Stack indices along a new last axis
stacked_indices = np.stack(inp[:-1], axis=-1)

# Handle out-of-range indices according to `mode`
for i, dim_limit in enumerate(shape):
if mode == "wrap":
stacked_indices[..., i] %= dim_limit
elif mode == "clip":
dim_indices = stacked_indices[..., i]
stacked_indices[..., i] = np.clip(dim_indices, 0, dim_limit - 1)
else: # raise
dim_indices = stacked_indices[..., i]
invalid_indices = (dim_indices < 0) | (dim_indices >= shape[i])
# numba cannot call .any() on a scalar boolean, so only reduce when the indices are arrays
if vec_indices:
invalid_indices = invalid_indices.any()
if invalid_indices:
raise ValueError("invalid entry in coordinates array")

# Calculate Strides based on Order
a = np.ones(len(shape), dtype=np.int64)
if order == "C":
# C-Order: Last dimension moves fastest (Strides: large -> small -> 1)
# For shape (3, 4, 5): Multipliers are (20, 5, 1)
if len(shape) > 1:
a[:-1] = np.cumprod(shape[:0:-1])[::-1]
else: # order == "F"
# F-Order: First dimension moves fastest (Strides: 1 -> small -> large)
# For shape (3, 4, 5): Multipliers are (1, 3, 12)
if len(shape) > 1:
a[1:] = np.cumprod(shape[:-1])

a = np.ones(len(shape), dtype=np.float64)
a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
return a.dot(new_arr.T).astype(np.int64)
# Dot product indices with strides
# (allow arbitrary left operand ndim and int dtype, which numba matmul doesn't support)
Review comment (Member): Is it worth having a code path that does the matmul for speed? Or is the difference marginal?

Reply (@ricardoV94, Member Author, Dec 5, 2025): I don't think we want matmul, as numba can only do matmul with floats. I also doubt this is (1) ever used, or (2) ever used with very large arrays, and the copy to float and back to int would add considerable overhead anyway.

return np.asarray((stacked_indices * a).sum(-1))

return ravelmultiindex
cache_version = 1
return ravelmultiindex, cache_version
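
(Aside, not part of the diff: the multiplier logic above can be sanity-checked in plain NumPy. The sketch below assumes C order and a shape of (3, 4, 5); it also shows that the elementwise multiply-and-sum used in the new kernel agrees with a matmul, which is the point of the review thread above.)

import numpy as np

shape = np.array([3, 4, 5])

# C-order multipliers for shape (3, 4, 5): (20, 5, 1)
a = np.ones(len(shape), dtype=np.int64)
a[:-1] = np.cumprod(shape[:0:-1])[::-1]

# Two index tuples, stacked on the last axis as in the new kernel
stacked_indices = np.array([[1, 2, 3], [2, 0, 4]])

raveled = (stacked_indices * a).sum(-1)              # elementwise multiply-and-sum
assert np.array_equal(raveled, np.array([33, 44]))
assert np.array_equal(raveled, np.ravel_multi_index(stacked_indices.T, shape))
assert np.array_equal(raveled, stacked_indices @ a)  # matmul agrees in plain NumPy
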


@register_funcify_default_op_cache_key(Repeat)
@@ -261,41 +245,60 @@ def unique(x):

@register_funcify_and_cache_key(UnravelIndex)
def numba_funcify_UnravelIndex(op, node, **kwargs):
order = op.order

if order != "C":
raise NotImplementedError(
"Numba does not support the `order` argument in `numpy.unravel_index`"
)
out_ndim = node.outputs[0].type.ndim

if len(node.outputs) == 1:

@numba_basic.numba_njit(inline="always")
def maybe_expand_dim(arr):
return arr

else:
if out_ndim == 0:
# Creating a tuple of 0d arrays in numba is basically impossible without codegen, so just go to obj_mode
return generate_fallback_impl(op, node=node), None

@numba_basic.numba_njit(inline="always")
def maybe_expand_dim(arr):
return np.expand_dims(arr, 1)
c_order = op.order == "C"
inp_ndim = node.inputs[0].type.ndim
transpose_axes = (inp_ndim, *range(inp_ndim))

@numba_basic.numba_njit
def unravelindex(arr, shape):
def unravelindex(indices, shape):
a = np.ones(len(shape), dtype=np.int64)
a[1:] = shape[:0:-1]
a = np.cumprod(a)[::-1]
if c_order:
# C-Order: Reverse shape (ignore dim0), cumulative product, then reverse back
# Strides: [dim1*dim2, dim2, 1]
a[1:] = shape[:0:-1]
a = np.cumprod(a)[::-1]
else:
# F-Order: Standard shape, cumulative product
# Strides: [1, dim0, dim0*dim1]
a[1:] = shape[:-1]
a = np.cumprod(a)

# Broadcast with a and shape on the last axis
unraveled_coords = (indices[..., None] // a) % shape

# Then transpose it to the front
# Numba doesn't have moveaxis (why would it), so we use transpose
# res = np.moveaxis(res, -1, 0)
unraveled_coords = unraveled_coords.transpose(transpose_axes)

# PyTensor actually returns a `tuple` of these values, instead of an
# `ndarray`; however, this `ndarray` result should be able to be
# unpacked into a `tuple`, so this discrepancy shouldn't really matter
return ((maybe_expand_dim(arr) // a) % shape).T
# This should be a tuple, but the array can be unpacked
# into multiple variables with the same effect by the outer function
# (special case for single entry is handled with an outer function below)
return unraveled_coords

cache_version = 1
cache_key = sha256(
str((type(op), op.order, len(node.outputs))).encode()
str((type(op), op.order, len(node.outputs), cache_version)).encode()
).hexdigest()

return unravelindex, cache_key
if len(node.outputs) == 1:

@numba_basic.numba_njit
def unravel_index_single_item(arr, shape):
# Unpack single entry
(res,) = unravelindex(arr, shape)
return res

return unravel_index_single_item, cache_key

else:
return unravelindex, cache_key
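
(Aside, not part of the diff: the floor-divide/modulo unraveling can be checked against np.unravel_index in plain NumPy. A minimal sketch assuming C order and a shape of (3, 4, 5):)

import numpy as np

shape = np.array([3, 4, 5])

# C-order strides: reverse shape (dropping dim 0), cumprod, reverse back -> [20, 5, 1]
a = np.ones(len(shape), dtype=np.int64)
a[1:] = shape[:0:-1]
a = np.cumprod(a)[::-1]

indices = np.array([33, 44])
coords = (indices[..., None] // a) % shape            # broadcast against strides and shape
# np.unravel_index returns one array per dimension; stack them to compare
expected = np.stack(np.unravel_index(indices, shape), axis=-1)
assert np.array_equal(coords, expected)               # [[1, 2, 3], [2, 0, 4]]
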


@register_funcify_default_op_cache_key(SearchsortedOp)
16 changes: 7 additions & 9 deletions pytensor/tensor/extra_ops.py
@@ -1304,13 +1304,11 @@ def make_node(self, indices, dims):
if dims.ndim != 1:
raise TypeError("dims must be a 1D array")

out_type = indices.type.clone(dtype="int64")
return Apply(
self,
[indices, dims],
[
TensorType(dtype="int64", shape=(None,) * indices.type.ndim)()
for i in range(ptb.get_vector_length(dims))
],
[out_type() for _i in range(ptb.get_vector_length(dims))],
)
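
(Aside, not part of the diff: a short sketch of what the cloned output type means at the graph level, assuming the public pt.unravel_index helper and the long-standing pt.lvector shorthand; each output is int64 and has the same ndim as the indices input.)

import pytensor.tensor as pt

indices = pt.lvector("indices")          # int64 vector of flat indices
dims = pt.as_tensor([3, 4, 5])
outs = pt.unravel_index(indices, dims)   # one output per entry of `dims`
print([out.type for out in outs])        # int64 outputs with the same ndim as `indices`
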

def infer_shape(self, fgraph, node, input_shapes):
@@ -1373,8 +1371,7 @@ def __init__(self, mode="raise", order="C"):
self.order = order

def make_node(self, *inp):
multi_index = [ptb.as_tensor_variable(i) for i in inp[:-1]]
dims = ptb.as_tensor_variable(inp[-1])
*multi_index, dims = map(ptb.as_tensor_variable, inp)

for i in multi_index:
if i.dtype not in int_dtypes:
@@ -1384,19 +1381,20 @@ def make_node(self, *inp):
if dims.ndim != 1:
raise TypeError("dims must be a 1D array")

out_type = multi_index[0].type.clone(dtype="int64")
return Apply(
self,
[*multi_index, dims],
[TensorType(dtype="int64", shape=(None,) * multi_index[0].type.ndim)()],
[out_type()],
)

def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]

def perform(self, node, inp, out):
multi_index, dims = inp[:-1], inp[-1]
*multi_index, dims = inp
res = np.ravel_multi_index(multi_index, dims, mode=self.mode, order=self.order)
out[0][0] = np.asarray(res, node.outputs[0].dtype)
out[0][0] = np.asarray(res, "int64")
Review comment (Member): Is it ok to hard cast to int64? What if floatX is set to half precision?

Reply (@ricardoV94, Member Author, Dec 5, 2025): floatX doesn't affect integers, and make_node already promised int64; we can change this if we ever change make_node.
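
(Aside, not part of the diff: a quick plain-NumPy illustration of the dtype point. np.ravel_multi_index returns the platform-default integer type (intp), and the cast merely pins the result to the int64 that make_node promises; floatX only governs float dtypes.)

import numpy as np

res = np.ravel_multi_index(([1, 2], [2, 0], [3, 4]), (3, 4, 5))
print(res.dtype)                  # intp: int64 on most 64-bit platforms
out = np.asarray(res, "int64")    # pin to the int64 dtype promised by make_node
print(out.dtype)                  # int64
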



def ravel_multi_index(multi_index, dims, mode="raise", order="C"):