Skip to content

Commit

Permalink
Fix some issues and add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
bzah committed Oct 19, 2021
1 parent 71d5ff6 commit 952fe85
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
11 changes: 4 additions & 7 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from contextlib import suppress
from datetime import timedelta
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
import dask.array as da

import dask.array as da
import numpy as np
import pandas as pd

Expand All @@ -19,7 +19,7 @@
is_duck_dask_array,
sparse_array_type,
)
from .utils import maybe_cast_to_coords_dtype
from .utils import is_duck_array, maybe_cast_to_coords_dtype


def expanded_indexer(key, ndim):
Expand Down Expand Up @@ -308,7 +308,7 @@ def __init__(self, key):
for k in key:
if isinstance(k, slice):
k = as_integer_slice(k)
elif isinstance(k, np.ndarray) or isinstance(k, da.Array):
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
f"invalid indexer array, does not have integer dtype: {k!r}"
Expand All @@ -321,10 +321,7 @@ def __init__(self, key):
"invalid indexer key: ndarray arguments "
f"have different numbers of dimensions: {ndims}"
)
if isinstance(k, da.Array):
k = da.asarray(k, dtype=np.int64)
else:
k = np.asarray(k, dtype=np.int64)
k = k.astype(np.int64)
else:
raise TypeError(
f"unexpected indexer type for {type(self).__name__}: {k!r}"
Expand Down
15 changes: 15 additions & 0 deletions xarray/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from . import IndexerMaker, ReturnItem, assert_array_equal

da = pytest.importorskip("dask.array")

B = IndexerMaker(indexing.BasicIndexer)


Expand Down Expand Up @@ -729,3 +731,16 @@ def test_indexing_1d_object_array() -> None:
expected = DataArray(expected_data)

assert [actual.data.item()] == [expected.data.item()]


def test_indexing_dask_array():
da = DataArray(
np.ones(10 * 3 * 3).reshape((10, 3, 3)),
dims=("time", "x", "y"),
).chunk(dict(time=-1, x=1, y=1))
da[{"time": 9}] = 42

idx = da.argmax("time")
actual = da.isel(time=idx)

assert np.all(actual == 42)

0 comments on commit 952fe85

Please sign in to comment.