-
Notifications
You must be signed in to change notification settings - Fork 149
Handle F-order and arbitrary index ndim in numba UnravelIndex and RavelMultiIndex #1770
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
A follow up will be to make these ops purely symbolic, the impl here looks fine enough to use as a template. The less Ops we have, the easier to maintain multiple backends. |
| a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1] | ||
| return a.dot(new_arr.T).astype(np.int64) | ||
| # Dot product indices with strides | ||
| # (allow arbitrary left operand ndim and int dtype, which numba matmul doesn't support) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it worth having a code path that does the matmul for speed? Or is the difference marginal?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we want matmul, as numba can only do it with floats. Also I doubt this is 1) ever being used and 2) ever being used with very large arrays. And the copy to float and back to int would cause a considerable overhead anyway
| unraveled_coords = (indices[..., None] // a) % shape | ||
|
|
||
| # Then transpose it to the front | ||
| # Numba doesn't have moveaxis (why would it), so we use transpose |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
xD
| *multi_index, dims = inp | ||
| res = np.ravel_multi_index(multi_index, dims, mode=self.mode, order=self.order) | ||
| out[0][0] = np.asarray(res, node.outputs[0].dtype) | ||
| out[0][0] = np.asarray(res, "int64") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it ok to hard cast to 64? What if floatX is set to half precision
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
floatX doesn't affect integers, and make_node already promised int64, we can change if we ever change make_node
In #811 I just disabled the custom impls and went into object mode
But in the cherry pick I got nerd-snipped and implemented it to work with arbitrary ndims and F-order while we are at it.
Only one case fallback to objected mode because I didn't want to bother with code-gen, we can revisit later. I bet these Ops are virtually unused anyway.
Also stopped losing static shape in the Ops.