-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow indexing unindexed dimensions using dask arrays #5873
Changes from 17 commits
bc4271c
73696e9
fad4348
7dadbf2
b7c382b
46a4b16
ec4d6ee
944dbac
fb5b01e
335b5da
a11be00
d5e7646
9cde88d
8df0c2a
9f5e31b
3306329
32b73c3
d6170ce
aa1df48
c93b297
97fa188
3f008c8
ff42585
d15c7fe
220edc8
8445120
75a6299
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,9 @@ | |
import pandas as pd | ||
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined] | ||
|
||
from .npcompat import broadcast_shapes | ||
from .options import OPTIONS | ||
from .pycompat import is_duck_array | ||
|
||
try: | ||
import bottleneck as bn | ||
|
@@ -109,7 +111,11 @@ def _advanced_indexer_subspaces(key): | |
return (), () | ||
|
||
non_slices = [k for k in key if not isinstance(k, slice)] | ||
ndim = len(np.broadcast(*non_slices).shape) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is just an optimization. Happy to add it in a different PR. |
||
ndim = len( | ||
broadcast_shapes( | ||
*[item.shape if is_duck_array(item) else (0,) for item in non_slices] | ||
) | ||
) | ||
mixed_positions = advanced_index_positions[0] + np.arange(ndim) | ||
vindex_positions = np.arange(ndim) | ||
return mixed_positions, vindex_positions | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
cupy_array_type, | ||
dask_array_type, | ||
integer_types, | ||
is_0d_dask_array, | ||
is_duck_dask_array, | ||
sparse_array_type, | ||
) | ||
|
@@ -601,11 +602,12 @@ def _broadcast_indexes(self, key): | |
key = self._item_key_to_tuple(key) # key is a tuple | ||
# key is a tuple of full size | ||
key = indexing.expanded_indexer(key, self.ndim) | ||
# Convert a scalar Variable to an integer | ||
# Convert a scalar Variable to a 0d-array | ||
key = tuple( | ||
k.data.item() if isinstance(k, Variable) and k.ndim == 0 else k for k in key | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is the major change here. Everything else is avoid calling |
||
k.data if isinstance(k, Variable) and k.ndim == 0 else k for k in key | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Necessary for #4276 |
||
) | ||
# Convert a 0d-array to an integer | ||
# Convert a 0d numpy arrays to an integer | ||
# dask 0d arrays are passed through | ||
key = tuple( | ||
k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k for k in key | ||
) | ||
|
@@ -646,7 +648,8 @@ def _validate_indexers(self, key): | |
for dim, k in zip(self.dims, key): | ||
if not isinstance(k, BASIC_INDEXING_TYPES): | ||
if not isinstance(k, Variable): | ||
k = np.asarray(k) | ||
if not is_duck_array(k): | ||
k = np.asarray(k) | ||
if k.ndim > 1: | ||
raise IndexError( | ||
"Unlabeled multi-dimensional array cannot be " | ||
|
@@ -663,6 +666,13 @@ def _validate_indexers(self, key): | |
"{}-dimensional boolean indexing is " | ||
"not supported. ".format(k.ndim) | ||
) | ||
if is_duck_dask_array(k.data): | ||
raise KeyError( | ||
"Indexing with a boolean dask array is not allowed. " | ||
"This will result in a dask array of unknown shape. " | ||
"Such arrays are unsupported by Xarray." | ||
"Please compute the indexer first using .compute()" | ||
) | ||
if getattr(k, "dims", (dim,)) != (dim,): | ||
raise IndexError( | ||
"Boolean indexer should be unlabeled or on the " | ||
|
@@ -673,18 +683,20 @@ def _validate_indexers(self, key): | |
) | ||
|
||
def _broadcast_indexes_outer(self, key): | ||
# drop dim if k is integer or if k is a 0d dask array | ||
dims = tuple( | ||
k.dims[0] if isinstance(k, Variable) else dim | ||
for k, dim in zip(key, self.dims) | ||
if not isinstance(k, integer_types) | ||
if (not isinstance(k, integer_types) and not is_0d_dask_array(k)) | ||
) | ||
|
||
new_key = [] | ||
for k in key: | ||
if isinstance(k, Variable): | ||
k = k.data | ||
if not isinstance(k, BASIC_INDEXING_TYPES): | ||
k = np.asarray(k) | ||
if not is_duck_array(k): | ||
k = np.asarray(k) | ||
if k.size == 0: | ||
# Slice by empty list; numpy could not infer the dtype | ||
k = k.astype(int) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4063,45 +4063,49 @@ def test_query( | |
d = np.random.choice(["foo", "bar", "baz"], size=30, replace=True).astype( | ||
object | ||
) | ||
if backend == "numpy": | ||
aa = DataArray(data=a, dims=["x"], name="a") | ||
bb = DataArray(data=b, dims=["x"], name="b") | ||
cc = DataArray(data=c, dims=["y"], name="c") | ||
dd = DataArray(data=d, dims=["z"], name="d") | ||
aa = DataArray(data=a, dims=["x"], name="a", coords={"a2": ("x", a)}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix the tests to actually be lazy. |
||
bb = DataArray(data=b, dims=["x"], name="b", coords={"b2": ("x", b)}) | ||
cc = DataArray(data=c, dims=["y"], name="c", coords={"c2": ("y", c)}) | ||
dd = DataArray(data=d, dims=["z"], name="d", coords={"d2": ("z", d)}) | ||
|
||
elif backend == "dask": | ||
if backend == "dask": | ||
import dask.array as da | ||
|
||
aa = DataArray(data=da.from_array(a, chunks=3), dims=["x"], name="a") | ||
bb = DataArray(data=da.from_array(b, chunks=3), dims=["x"], name="b") | ||
cc = DataArray(data=da.from_array(c, chunks=7), dims=["y"], name="c") | ||
dd = DataArray(data=da.from_array(d, chunks=12), dims=["z"], name="d") | ||
aa = aa.copy(data=da.from_array(a, chunks=3)) | ||
bb = bb.copy(data=da.from_array(b, chunks=3)) | ||
cc = cc.copy(data=da.from_array(c, chunks=7)) | ||
dd = dd.copy(data=da.from_array(d, chunks=12)) | ||
|
||
# query single dim, single variable | ||
actual = aa.query(x="a > 5", engine=engine, parser=parser) | ||
with raise_if_dask_computes(): | ||
actual = aa.query(x="a2 > 5", engine=engine, parser=parser) | ||
expect = aa.isel(x=(a > 5)) | ||
assert_identical(expect, actual) | ||
|
||
# query single dim, single variable, via dict | ||
actual = aa.query(dict(x="a > 5"), engine=engine, parser=parser) | ||
with raise_if_dask_computes(): | ||
actual = aa.query(dict(x="a2 > 5"), engine=engine, parser=parser) | ||
expect = aa.isel(dict(x=(a > 5))) | ||
assert_identical(expect, actual) | ||
|
||
# query single dim, single variable | ||
actual = bb.query(x="b > 50", engine=engine, parser=parser) | ||
with raise_if_dask_computes(): | ||
actual = bb.query(x="b2 > 50", engine=engine, parser=parser) | ||
expect = bb.isel(x=(b > 50)) | ||
assert_identical(expect, actual) | ||
|
||
# query single dim, single variable | ||
actual = cc.query(y="c < .5", engine=engine, parser=parser) | ||
with raise_if_dask_computes(): | ||
actual = cc.query(y="c2 < .5", engine=engine, parser=parser) | ||
expect = cc.isel(y=(c < 0.5)) | ||
assert_identical(expect, actual) | ||
|
||
# query single dim, single string variable | ||
if parser == "pandas": | ||
# N.B., this query currently only works with the pandas parser | ||
# xref https://github.com/pandas-dev/pandas/issues/40436 | ||
actual = dd.query(z='d == "bar"', engine=engine, parser=parser) | ||
with raise_if_dask_computes(): | ||
actual = dd.query(z='d2 == "bar"', engine=engine, parser=parser) | ||
expect = dd.isel(z=(d == "bar")) | ||
assert_identical(expect, actual) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
k might be a 0d dask array (see below and #4276)