Skip to content

Commit

Permalink
Attempt to fix indexing for Dask
Browse files Browse the repository at this point in the history
This is a naive attempt to make `isel` work with Dask

Known limitation: it triggers the computation.
  • Loading branch information
bzah committed Nov 15, 2021
1 parent a883ed0 commit f7991fd
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
6 changes: 3 additions & 3 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
is_duck_dask_array,
sparse_array_type,
)
from .utils import maybe_cast_to_coords_dtype
from .utils import is_duck_array, maybe_cast_to_coords_dtype


def expanded_indexer(key, ndim):
Expand Down Expand Up @@ -307,7 +307,7 @@ def __init__(self, key):
for k in key:
if isinstance(k, slice):
k = as_integer_slice(k)
elif isinstance(k, np.ndarray):
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
f"invalid indexer array, does not have integer dtype: {k!r}"
Expand All @@ -320,7 +320,7 @@ def __init__(self, key):
"invalid indexer key: ndarray arguments "
f"have different numbers of dimensions: {ndims}"
)
k = np.asarray(k, dtype=np.int64)
k = k.astype(np.int64)
else:
raise TypeError(
f"unexpected indexer type for {type(self).__name__}: {k!r}"
Expand Down
15 changes: 15 additions & 0 deletions xarray/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from . import IndexerMaker, ReturnItem, assert_array_equal

da = pytest.importorskip("dask.array")

B = IndexerMaker(indexing.BasicIndexer)


Expand Down Expand Up @@ -729,3 +731,16 @@ def test_indexing_1d_object_array() -> None:
expected = DataArray(expected_data)

assert [actual.data.item()] == [expected.data.item()]


def test_indexing_dask_array():
da = DataArray(
np.ones(10 * 3 * 3).reshape((10, 3, 3)),
dims=("time", "x", "y"),
).chunk(dict(time=-1, x=1, y=1))
da[{"time": 9}] = 42

idx = da.argmax("time")
actual = da.isel(time=idx)

assert np.all(actual == 42)

0 comments on commit f7991fd

Please sign in to comment.