Skip to content

Commit

Permalink
Attempt to fix indexing for Dask
Browse files Browse the repository at this point in the history
This is a naive attempt to make `isel` work when the indexer is backed by a Dask array.

Known limitation: it triggers the computation.
  • Loading branch information
bzah committed Nov 3, 2021
1 parent 960010b commit 270873a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
8 changes: 4 additions & 4 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import timedelta
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union

import dask.array as da
import numpy as np
import pandas as pd

Expand All @@ -18,7 +19,7 @@
is_duck_dask_array,
sparse_array_type,
)
from .utils import maybe_cast_to_coords_dtype
from .utils import is_duck_array, maybe_cast_to_coords_dtype


def expanded_indexer(key, ndim):
Expand Down Expand Up @@ -307,7 +308,7 @@ def __init__(self, key):
for k in key:
if isinstance(k, slice):
k = as_integer_slice(k)
elif isinstance(k, np.ndarray):
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
f"invalid indexer array, does not have integer dtype: {k!r}"
Expand All @@ -320,7 +321,7 @@ def __init__(self, key):
"invalid indexer key: ndarray arguments "
f"have different numbers of dimensions: {ndims}"
)
k = np.asarray(k, dtype=np.int64)
k = k.astype(np.int64)
else:
raise TypeError(
f"unexpected indexer type for {type(self).__name__}: {k!r}"
Expand Down Expand Up @@ -973,7 +974,6 @@ def _arrayize_vectorized_indexer(indexer, shape):

def _dask_array_with_chunks_hint(array, chunks):
"""Create a dask array using the chunks hint for dimensions of size > 1."""
import dask.array as da

if len(chunks) < array.ndim:
raise ValueError("not enough chunks in hint")
Expand Down
15 changes: 15 additions & 0 deletions xarray/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from . import IndexerMaker, ReturnItem, assert_array_equal

da = pytest.importorskip("dask.array")

B = IndexerMaker(indexing.BasicIndexer)


Expand Down Expand Up @@ -729,3 +731,16 @@ def test_indexing_1d_object_array() -> None:
expected = DataArray(expected_data)

assert [actual.data.item()] == [expected.data.item()]


def test_indexing_dask_array() -> None:
    """Check that ``isel`` accepts an indexer produced from a chunked array.

    A chunked ``DataArray`` of ones has its last ``time`` slice set to 42,
    so ``argmax("time")`` should point at that slice everywhere and
    selecting along ``time`` with it should yield only 42s.
    """
    # Named ``arr`` rather than ``da`` so the module-level ``da``
    # (``dask.array`` obtained through ``pytest.importorskip``) is not
    # shadowed inside this test.
    arr = DataArray(
        np.ones(10 * 3 * 3).reshape((10, 3, 3)),
        dims=("time", "x", "y"),
    ).chunk(dict(time=-1, x=1, y=1))
    arr[{"time": 9}] = 42

    # ``idx`` is derived from the chunked array and is used directly as
    # the indexer along ``time``.
    idx = arr.argmax("time")
    actual = arr.isel(time=idx)

    assert np.all(actual == 42)

0 comments on commit 270873a

Please sign in to comment.