Handle F-order and arbitrary index ndim in numba UnravelIndex and RavelMultiIndex #1770
@@ -135,65 +135,49 @@ def filldiagonaloffset(a, val, offset):
 def numba_funcify_RavelMultiIndex(op, node, **kwargs):
     mode = op.mode
     order = op.order
+    vec_indices = node.inputs[0].type.ndim > 0

-    if order != "C":
-        raise NotImplementedError(
-            "Numba does not implement `order` in `numpy.ravel_multi_index`"
-        )
-
-    if mode == "raise":
-
-        @numba_basic.numba_njit
-        def mode_fn(*args):
-            raise ValueError("invalid entry in coordinates array")
-
-    elif mode == "wrap":
-
-        @numba_basic.numba_njit(inline="always")
-        def mode_fn(new_arr, i, j, v, d):
-            new_arr[i, j] = v % d
-
-    elif mode == "clip":
-
-        @numba_basic.numba_njit(inline="always")
-        def mode_fn(new_arr, i, j, v, d):
-            new_arr[i, j] = min(max(v, 0), d - 1)
-
-    if node.inputs[0].ndim == 0:
-
-        @numba_basic.numba_njit
-        def ravelmultiindex(*inp):
-            shape = inp[-1]
-            arr = np.stack(inp[:-1])
-
-            new_arr = arr.T.astype(np.float64).copy()
-            for i, b in enumerate(new_arr):
-                if b < 0 or b >= shape[i]:
-                    mode_fn(new_arr, i, 0, b, shape[i])
-
-            a = np.ones(len(shape), dtype=np.float64)
-            a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
-            return np.array(a.dot(new_arr.T), dtype=np.int64)
-
-    else:
-
-        @numba_basic.numba_njit
-        def ravelmultiindex(*inp):
-            shape = inp[-1]
-            arr = np.stack(inp[:-1])
-
-            new_arr = arr.T.astype(np.float64).copy()
-            for i, b in enumerate(new_arr):
-                # no strict argument to this zip because numba doesn't support it
-                for j, (d, v) in enumerate(zip(shape, b)):
-                    if v < 0 or v >= d:
-                        mode_fn(new_arr, i, j, v, d)
-
-            a = np.ones(len(shape), dtype=np.float64)
-            a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
-            return a.dot(new_arr.T).astype(np.int64)
-
-    return ravelmultiindex
+    @numba_basic.numba_njit
+    def ravelmultiindex(*inp):
+        shape = inp[-1]
+        # Concatenate indices along last axis
+        stacked_indices = np.stack(inp[:-1], axis=-1)
+
+        # Manage invalid indices
+        for i, dim_limit in enumerate(shape):
+            if mode == "wrap":
+                stacked_indices[..., i] %= dim_limit
+            elif mode == "clip":
+                dim_indices = stacked_indices[..., i]
+                stacked_indices[..., i] = np.clip(dim_indices, 0, dim_limit - 1)
+            else:  # raise
+                dim_indices = stacked_indices[..., i]
+                invalid_indices = (dim_indices < 0) | (dim_indices >= shape[i])
+                # Cannot call np.any on a boolean
+                if vec_indices:
+                    invalid_indices = invalid_indices.any()
+                if invalid_indices:
+                    raise ValueError("invalid entry in coordinates array")
+
+        # Calculate Strides based on Order
+        a = np.ones(len(shape), dtype=np.int64)
+        if order == "C":
+            # C-Order: Last dimension moves fastest (Strides: large -> small -> 1)
+            # For shape (3, 4, 5): Multipliers are (20, 5, 1)
+            if len(shape) > 1:
+                a[:-1] = np.cumprod(shape[:0:-1])[::-1]
+        else:  # order == "F"
+            # F-Order: First dimension moves fastest (Strides: 1 -> small -> large)
+            # For shape (3, 4, 5): Multipliers are (1, 3, 12)
+            if len(shape) > 1:
+                a[1:] = np.cumprod(shape[:-1])
+
+        # Dot product indices with strides
+        # (allow arbitrary left operand ndim and int dtype, which numba matmul doesn't support)
+        return np.asarray((stacked_indices * a).sum(-1))
+
+    cache_version = 1
+    return ravelmultiindex, cache_version


 @register_funcify_default_op_cache_key(Repeat)
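For readers following along, here is a rough plain-NumPy sketch (outside numba, not part of the diff) of what the new kernel computes: stack the index arrays on the last axis, build the per-dimension multipliers for the requested order, then multiply and sum over that axis. The shape (3, 4, 5) and the index values are made up for illustration, and the result is checked against np.ravel_multi_index.

import numpy as np

shape = np.array([3, 4, 5])
idx = (np.array([0, 2, 1]), np.array([3, 1, 0]), np.array([4, 4, 2]))

# Stack the index arrays on the last axis, as the numba kernel does
stacked = np.stack(idx, axis=-1)                      # shape (3 points, 3 dims)

def multipliers(shape, order):
    # Per-dimension multipliers ("strides" in elements) for the requested order
    a = np.ones(len(shape), dtype=np.int64)
    if order == "C":
        a[:-1] = np.cumprod(shape[:0:-1])[::-1]       # (20, 5, 1) for shape (3, 4, 5)
    else:  # "F"
        a[1:] = np.cumprod(shape[:-1])                # (1, 3, 12) for shape (3, 4, 5)
    return a

for order in ("C", "F"):
    a = multipliers(shape, order)
    flat = (stacked * a).sum(-1)                      # broadcast multiply, then sum over dims
    assert np.array_equal(flat, np.ravel_multi_index(idx, tuple(shape), order=order))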
@@ -261,41 +245,60 @@ def unique(x):

 @register_funcify_and_cache_key(UnravelIndex)
 def numba_funcify_UnravelIndex(op, node, **kwargs):
-    order = op.order
-
-    if order != "C":
-        raise NotImplementedError(
-            "Numba does not support the `order` argument in `numpy.unravel_index`"
-        )
-
-    if len(node.outputs) == 1:
-
-        @numba_basic.numba_njit(inline="always")
-        def maybe_expand_dim(arr):
-            return arr
-
-    else:
-
-        @numba_basic.numba_njit(inline="always")
-        def maybe_expand_dim(arr):
-            return np.expand_dims(arr, 1)
+    out_ndim = node.outputs[0].type.ndim
+
+    if out_ndim == 0:
+        # Creating a tuple of 0d arrays in numba is basically impossible without codegen, so just go to obj_mode
+        return generate_fallback_impl(op, node=node), None
+
+    c_order = op.order == "C"
+    inp_ndim = node.inputs[0].type.ndim
+    transpose_axes = (inp_ndim, *range(inp_ndim))

     @numba_basic.numba_njit
-    def unravelindex(arr, shape):
+    def unravelindex(indices, shape):
         a = np.ones(len(shape), dtype=np.int64)
-        a[1:] = shape[:0:-1]
-        a = np.cumprod(a)[::-1]
-
-        # PyTensor actually returns a `tuple` of these values, instead of an
-        # `ndarray`; however, this `ndarray` result should be able to be
-        # unpacked into a `tuple`, so this discrepancy shouldn't really matter
-        return ((maybe_expand_dim(arr) // a) % shape).T
+        if c_order:
+            # C-Order: Reverse shape (ignore dim0), cumulative product, then reverse back
+            # Strides: [dim1*dim2, dim2, 1]
+            a[1:] = shape[:0:-1]
+            a = np.cumprod(a)[::-1]
+        else:
+            # F-Order: Standard shape, cumulative product
+            # Strides: [1, dim0, dim0*dim1]
+            a[1:] = shape[:-1]
+            a = np.cumprod(a)
+
+        # Broadcast with a and shape on the last axis
+        unraveled_coords = (indices[..., None] // a) % shape
+
+        # Then transpose it to the front
+        # Numba doesn't have moveaxis (why would it), so we use transpose
+        # res = np.moveaxis(res, -1, 0)
+        unraveled_coords = unraveled_coords.transpose(transpose_axes)
+
+        # This should be a tuple, but the array can be unpacked
+        # into multiple variables with the same effect by the outer function
+        # (special case for single entry is handled with an outer function below)
+        return unraveled_coords

Review comment (Member), on the moveaxis remark above: xD

+    cache_version = 1
     cache_key = sha256(
-        str((type(op), op.order, len(node.outputs))).encode()
+        str((type(op), op.order, len(node.outputs), cache_version)).encode()
     ).hexdigest()

-    return unravelindex, cache_key
+    if len(node.outputs) == 1:
+
+        @numba_basic.numba_njit
+        def unravel_index_single_item(arr, shape):
+            # Unpack single entry
+            (res,) = unravelindex(arr, shape)
+            return res
+
+        return unravel_index_single_item, cache_key
+
+    else:
+        return unravelindex, cache_key


 @register_funcify_default_op_cache_key(SearchsortedOp)
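Similarly, a rough plain-NumPy sketch (not part of the diff) of what the rewritten unravelindex does: integer-divide the flat indices by per-dimension multipliers, take the remainder against the shape, then move the new coordinate axis to the front. The shape and flat indices are made up; the result is checked against np.unravel_index.

import numpy as np

shape = np.array([3, 4, 5])
flat = np.array([[0, 19], [37, 59]])                  # flat indices; any ndim is fine

def multipliers(shape, c_order):
    a = np.ones(len(shape), dtype=np.int64)
    if c_order:
        a[1:] = shape[:0:-1]
        a = np.cumprod(a)[::-1]                       # (20, 5, 1) for shape (3, 4, 5)
    else:
        a[1:] = shape[:-1]
        a = np.cumprod(a)                             # (1, 3, 12) for shape (3, 4, 5)
    return a

for order, c_order in (("C", True), ("F", False)):
    a = multipliers(shape, c_order)
    coords = (flat[..., None] // a) % shape           # one coordinate per dim on the last axis
    coords = coords.transpose((flat.ndim, *range(flat.ndim)))  # move the dim axis to the front
    ref = np.stack(np.unravel_index(flat, tuple(shape), order=order))
    assert np.array_equal(coords, ref)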
@@ -1304,13 +1304,11 @@ def make_node(self, indices, dims):
         if dims.ndim != 1:
             raise TypeError("dims must be a 1D array")

+        out_type = indices.type.clone(dtype="int64")
         return Apply(
             self,
             [indices, dims],
-            [
-                TensorType(dtype="int64", shape=(None,) * indices.type.ndim)()
-                for i in range(ptb.get_vector_length(dims))
-            ],
+            [out_type() for _i in range(ptb.get_vector_length(dims))],
         )

     def infer_shape(self, fgraph, node, input_shapes):
@@ -1373,8 +1371,7 @@ def __init__(self, mode="raise", order="C"):
         self.order = order

     def make_node(self, *inp):
-        multi_index = [ptb.as_tensor_variable(i) for i in inp[:-1]]
-        dims = ptb.as_tensor_variable(inp[-1])
+        *multi_index, dims = map(ptb.as_tensor_variable, inp)

         for i in multi_index:
             if i.dtype not in int_dtypes:
@@ -1384,19 +1381,20 @@ def make_node(self, *inp):
         if dims.ndim != 1:
             raise TypeError("dims must be a 1D array")

+        out_type = multi_index[0].type.clone(dtype="int64")
         return Apply(
             self,
             [*multi_index, dims],
-            [TensorType(dtype="int64", shape=(None,) * multi_index[0].type.ndim)()],
+            [out_type()],
         )

     def infer_shape(self, fgraph, node, input_shapes):
         return [input_shapes[0]]

     def perform(self, node, inp, out):
-        multi_index, dims = inp[:-1], inp[-1]
+        *multi_index, dims = inp
         res = np.ravel_multi_index(multi_index, dims, mode=self.mode, order=self.order)
-        out[0][0] = np.asarray(res, node.outputs[0].dtype)
+        out[0][0] = np.asarray(res, "int64")
Review comment (Member), on the cast above: Is it OK to hard-cast to int64? What if floatX is set to half precision?

Reply (Member, author): floatX doesn't affect integers, and make_node already promised int64; we can change this if we ever change make_node.
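For context on the cast (illustration only, not part of the diff): np.ravel_multi_index returns a platform-dependent integer dtype (intp), while make_node declares int64 for the output, hence the explicit cast.

import numpy as np

res = np.ravel_multi_index(([1, 2], [0, 3]), (3, 4))
print(res.dtype)                        # intp: int64 on 64-bit platforms, int32 on 32-bit
print(np.asarray(res, "int64").dtype)   # int64, the dtype make_node declares for the output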
 def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
Review comment: Is it worth having a code path that does the matmul for speed, or is the difference marginal?
Reply: I don't think we want matmul, as numba can only do it with floats. I also doubt this is (1) ever being used and (2) ever being used with very large arrays, and the copy to float and back to int would cause considerable overhead anyway.
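To make the trade-off concrete, a plain-NumPy sketch (not numba, not part of the diff) of the two candidate paths with made-up values: the broadcast-multiply-and-sum used in the PR stays in integer dtype and works for any left-operand ndim, whereas a matmul path would require a round-trip through float64.

import numpy as np

stacked_indices = np.array([[0, 3, 4], [2, 1, 4]])    # (n_points, n_dims), integer dtype
a = np.array([20, 5, 1])                              # C-order multipliers for shape (3, 4, 5)

# Path used in the PR: stays integer, works for any left-operand ndim
flat = (stacked_indices * a).sum(-1)

# Hypothetical matmul path: numba's matmul only handles floats (and 1-D/2-D operands),
# so it would need a cast to float64 and back
flat_matmul = (stacked_indices.astype(np.float64) @ a.astype(np.float64)).astype(np.int64)

assert np.array_equal(flat, flat_matmul)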