Conversation

@ricardoV94
Member

In #811 I just disabled the custom impls and went into object mode

But in the cherry-pick I got nerd-sniped and implemented it to work with arbitrary ndims and F-order while we were at it.

Only one case falls back to object mode, because I didn't want to bother with code-gen; we can revisit later. I bet these Ops are virtually unused anyway.

Also stopped losing static shape in the Ops.

@ricardoV94
Member Author

ricardoV94 commented Dec 4, 2025

A follow-up will be to make these Ops purely symbolic; the impl here looks fine enough to use as a template.

The fewer Ops we have, the easier it is to maintain multiple backends.
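
To make the idea concrete, here is a minimal, hypothetical sketch of what a purely symbolic ravel could look like for a fixed C-order (3, 4, 5) shape, built only from existing tensor ops; the shape, strides, and variable names are illustrative assumptions, not the actual follow-up implementation.

```python
import pytensor
import pytensor.tensor as pt

# Made-up example shape and its C-order strides (in elements)
shape = (3, 4, 5)
strides = (20, 5, 1)

# One int64 vector of coordinates per dimension
i0, i1, i2 = pt.lvector("i0"), pt.lvector("i1"), pt.lvector("i2")

# Raveling is just a stride-weighted sum of coordinates,
# so it can be expressed with existing elemwise ops, no dedicated Op needed
flat = i0 * strides[0] + i1 * strides[1] + i2 * strides[2]

fn = pytensor.function([i0, i1, i2], flat)
print(fn([2, 3], [1, 0], [0, 4]))  # [45 64]
```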

a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
# Dot product indices with strides
# (allow arbitrary left operand ndim and int dtype, which numba matmul doesn't support)
return a.dot(new_arr.T).astype(np.int64)
Member

is it worth having a code path that does the matmul for speed? Or is the difference marginal?

Member Author

@ricardoV94 ricardoV94 Dec 5, 2025

I don't think we want matmul, as numba can only do it with floats. Also I doubt this is 1) ever being used and 2) ever being used with very large arrays. And the copy to float and back to int would cause considerable overhead anyway.
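
For context, a minimal NumPy sketch of the stride/dot-product raveling under discussion; the (3, 4, 5) shape and the multi_index values are made-up examples.

```python
import numpy as np

shape = (3, 4, 5)
multi_index = np.array([[2, 1, 0], [3, 0, 4]])  # one coordinate tuple per row

# C-order strides in elements: last stride is 1, the rest come from a reversed cumprod
strides = np.ones(len(shape), dtype=np.int64)
strides[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]  # [20, 5, 1]

# A plain dot stays in integer dtype, unlike numba's matmul which needs floats
flat = strides.dot(multi_index.T).astype(np.int64)

assert np.array_equal(flat, np.ravel_multi_index(tuple(multi_index.T), shape))
```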

unraveled_coords = (indices[..., None] // a) % shape

# Then transpose it to the front
# Numba doesn't have moveaxis (why would it), so we use transpose
Member

xD
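
And a matching NumPy sketch of the unravel step plus the transpose-instead-of-moveaxis trick, assuming C-order; the example shape and flat indices are made up.

```python
import numpy as np

shape = np.array([3, 4, 5])
indices = np.array([[45, 64], [0, 59]])  # flat indices of arbitrary ndim

# Same C-order strides as on the ravel side
strides = np.ones(len(shape), dtype=np.int64)
strides[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]

# Floor-divide by each dim's stride, then wrap by that dim's length -> shape (..., ndim)
unraveled_coords = (indices[..., None] // strides) % shape

# Move the trailing coordinate axis to the front with a plain transpose,
# since numba has no np.moveaxis
coords_first = unraveled_coords.transpose((indices.ndim, *range(indices.ndim)))

assert np.array_equal(coords_first, np.stack(np.unravel_index(indices, tuple(shape))))
```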

*multi_index, dims = inp
res = np.ravel_multi_index(multi_index, dims, mode=self.mode, order=self.order)
- out[0][0] = np.asarray(res, node.outputs[0].dtype)
+ out[0][0] = np.asarray(res, "int64")
Member

is it ok to hard cast to 64? What if floatX is set to half precision?

Member Author

@ricardoV94 ricardoV94 Dec 5, 2025

floatX doesn't affect integers, and make_node already promised int64; we can change it if we ever change make_node.
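
A tiny sanity check of that point (a made-up example, not the Op's code): np.ravel_multi_index returns the platform integer regardless of floatX, and the explicit cast just pins the result to the int64 dtype that make_node declared.

```python
import numpy as np

res = np.ravel_multi_index((np.array([2, 3]), np.array([1, 0])), (3, 4))
out = np.asarray(res, dtype="int64")
assert out.dtype == np.int64  # matches the output dtype promised by make_node
```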

@ricardoV94 ricardoV94 merged commit a782753 into pymc-devs:main Dec 5, 2025
56 checks passed
@ricardoV94 ricardoV94 deleted the ravel_unravel_index branch December 5, 2025 08:59