Allow indexing unindexed dimensions using dask arrays #5873

Merged
merged 27 commits into from
Mar 15, 2023
Changes from 1 commit
Commits
bc4271c
Attempt to fix indexing for Dask
bzah Oct 15, 2021
73696e9
Works now.
dcherian Dec 14, 2021
fad4348
avoid importorskip
dcherian Dec 14, 2021
7dadbf2
More tests and fixes
dcherian Dec 14, 2021
b7c382b
Merge branch 'main' into fix/dask_indexing
bzah Feb 15, 2022
46a4b16
Merge branch 'main' into fix/dask_indexing
dcherian Mar 18, 2022
ec4d6ee
Raise nicer error when indexing with boolean dask array
dcherian Mar 18, 2022
944dbac
Annotate tests
dcherian Mar 18, 2022
fb5b01e
Merge branch 'main' into fix/dask_indexing
bzah Mar 24, 2022
335b5da
edit query tests
dcherian Apr 12, 2022
a11be00
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2022
d5e7646
Merge branch 'main' into fix/dask_indexing
dcherian Jun 24, 2022
9cde88d
Fixes #4276
dcherian Jun 24, 2022
8df0c2a
Add xfail notes.
dcherian Jun 24, 2022
9f5e31b
backcompat: vendor np.broadcast_shapes
dcherian Jun 24, 2022
3306329
Small improvement
dcherian Jun 24, 2022
32b73c3
fix: Handle scalars properly.
dcherian Jun 24, 2022
d6170ce
fix bad test
dcherian Jun 25, 2022
aa1df48
Check computes with setitem
dcherian Jun 25, 2022
c93b297
Merge branch 'main' into fix/dask_indexing
dcherian Jun 25, 2022
97fa188
Merge remote-tracking branch 'upstream/main' into fix/dask_indexing
dcherian Feb 28, 2023
3f008c8
Merge branch 'main' into fix/dask_indexing
dcherian Mar 3, 2023
ff42585
Better error
dcherian Feb 28, 2023
d15c7fe
Cleanup
dcherian Mar 6, 2023
220edc8
Raise nice error with VectorizedIndexer and dask.
dcherian Mar 6, 2023
8445120
Add whats-new
dcherian Mar 6, 2023
75a6299
Merge branch 'main' into fix/dask_indexing
dcherian Mar 9, 2023
6 changes: 3 additions & 3 deletions xarray/core/indexing.py
@@ -18,7 +18,7 @@
     is_duck_dask_array,
     sparse_array_type,
 )
-from .utils import maybe_cast_to_coords_dtype
+from .utils import is_duck_array, maybe_cast_to_coords_dtype


 def expanded_indexer(key, ndim):
@@ -307,7 +307,7 @@ def __init__(self, key):
         for k in key:
             if isinstance(k, slice):
                 k = as_integer_slice(k)
-            elif isinstance(k, np.ndarray):
+            elif is_duck_array(k):
                 if not np.issubdtype(k.dtype, np.integer):
                     raise TypeError(
                         f"invalid indexer array, does not have integer dtype: {k!r}"
@@ -320,7 +320,7 @@ def __init__(self, key):
                         "invalid indexer key: ndarray arguments "
                         f"have different numbers of dimensions: {ndims}"
                     )
-                k = np.asarray(k, dtype=np.int64)
+                k = k.astype(np.int64)
             else:
                 raise TypeError(
                     f"unexpected indexer type for {type(self).__name__}: {k!r}"
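
The two changed lines in `indexing.py` can be illustrated with a small NumPy-only sketch. It is not xarray's code: `looks_like_duck_array` is a hypothetical, simplified stand-in for the `is_duck_array` helper, and `LazyInts` is a toy class that plays the role of a Dask array. The sketch shows why an `isinstance(k, np.ndarray)` check rejects such an object and why `np.asarray` would materialize it while `astype` need not.

```python
import numpy as np


def looks_like_duck_array(value):
    """Hypothetical, simplified stand-in for xarray's ``is_duck_array``."""
    if isinstance(value, np.ndarray):
        return True
    return all(hasattr(value, name) for name in ("ndim", "shape", "dtype", "astype"))


class LazyInts:
    """Toy lazy array (assumption: stands in for a Dask array)."""

    def __init__(self, values, dtype=np.int32):
        self._values = list(values)
        self.dtype = np.dtype(dtype)
        self.shape = (len(self._values),)
        self.ndim = 1
        self.computed = False

    def astype(self, dtype):
        # Stays lazy: returns another LazyInts with the requested dtype.
        return LazyInts(self._values, dtype)

    def __array__(self, dtype=None, copy=None):
        # np.asarray ends up here; this is where the data is materialized.
        self.computed = True
        return np.array(self._values, dtype=dtype or self.dtype)


key = LazyInts([0, 2, 1])

# The old isinstance check would send this key to the TypeError branch.
rejected_by_old_check = not isinstance(key, np.ndarray)
accepted_by_new_check = looks_like_duck_array(key)

converted = key.astype(np.int64)         # new conversion: nothing is computed
eager = np.asarray(key, dtype=np.int64)  # old conversion: forces materialization
```

In this sketch `converted.computed` stays `False`, while `key.computed` flips to `True` once `np.asarray` is called.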
15 changes: 15 additions & 0 deletions xarray/tests/test_indexing.py
@@ -9,6 +9,8 @@

 from . import IndexerMaker, ReturnItem, assert_array_equal

+da = pytest.importorskip("dask.array")
+
 B = IndexerMaker(indexing.BasicIndexer)


@@ -729,3 +731,16 @@ def test_indexing_1d_object_array() -> None:
     expected = DataArray(expected_data)

     assert [actual.data.item()] == [expected.data.item()]
+
+
+def test_indexing_dask_array():
+    da = DataArray(
+        np.ones(10 * 3 * 3).reshape((10, 3, 3)),
+        dims=("time", "x", "y"),
+    ).chunk(dict(time=-1, x=1, y=1))
+    da[{"time": 9}] = 42
+
+    idx = da.argmax("time")
+    actual = da.isel(time=idx)
+
+    assert np.all(actual == 42)
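
For readers less familiar with what `da.isel(time=idx)` selects here, the following is a rough eager NumPy equivalent of the test body. It is only a sketch: it assumes the `("time", "x", "y")` dimensions map to axes 0, 1 and 2, and it uses `np.take_along_axis` in place of xarray's vectorized indexing.

```python
import numpy as np

# Same layout as the test: dims ("time", "x", "y") with shape (10, 3, 3).
data = np.ones((10, 3, 3))
data[9] = 42  # the last time step holds the maximum everywhere

# argmax over "time" yields one integer per (x, y) cell.
idx = data.argmax(axis=0)  # shape (3, 3)

# Pointwise selection: picked[x, y] == data[idx[x, y], x, y].
picked = np.take_along_axis(data, idx[np.newaxis, ...], axis=0)[0]
```

The point of the test is that the same selection should also work when `idx` is backed by a Dask array rather than a NumPy array.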
Contributor

How about

Suggested change
-    idx = da.argmax("time")
-    actual = da.isel(time=idx)
-    assert np.all(actual == 42)
+    with raise_if_dask_computes():
+        actual = da.isel(time=dask.array.from_array([9], chunks=(1,)))
+    expected = da.isel(time=[9])
+    assert_identical(expected, actual)

with `from . import raise_if_dask_computes` at the top.

We could test the indexers 9, [9], and [9, 9].

Contributor Author

@bzah bzah Dec 14, 2021

When the indexer is an independent Dask array, a computation is triggered.
Modifying the unit test as follows raises "RuntimeError: Too many computes. Total: 1 > max: 0":

def test_indexing_dask_array():
    da = DataArray(
        np.ones(10 * 3 * 3).reshape((10, 3, 3)),
        dims=("time", "x", "y"),
    ).chunk(dict(time=-1, x=1, y=1))
    with raise_if_dask_computes():
        actual = da.isel(time=dask.array.from_array([9], chunks=(1,)))

I thought it would only be triggered when the indexer is constructed from da.
That's why the unit test was using idx = da.argmax("time").
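
For context on where that error message comes from, a guard of this kind can be sketched as a counting scheduler installed through `dask.config.set`. This is a minimal sketch rather than xarray's actual test helper: the names `CountingScheduler` and `raise_if_computes` are chosen for illustration, and it assumes Dask accepts any callable as `scheduler`.

```python
import dask
import dask.array as dask_array


class CountingScheduler:
    """Scheduler wrapper that fails once more graphs run than allowed."""

    def __init__(self, max_computes=0):
        self.total_computes = 0
        self.max_computes = max_computes

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            raise RuntimeError(
                f"Too many computes. Total: {self.total_computes} "
                f"> max: {self.max_computes}"
            )
        # Delegate to Dask's synchronous scheduler for the actual work.
        return dask.get(dsk, keys, **kwargs)


def raise_if_computes(max_computes=0):
    """Hypothetical helper: context manager installing the counting scheduler."""
    return dask.config.set(scheduler=CountingScheduler(max_computes))


lazy = dask_array.ones(4, chunks=2) + 1

with raise_if_computes():
    still_lazy = lazy * 2  # only builds a graph, so nothing is raised

try:
    with raise_if_computes():
        lazy.sum().compute()  # executes a graph, so the guard trips
except RuntimeError as err:
    message = str(err)
```

Any code path that silently materializes a Dask indexer inside the `with` block would trip the guard in the same way.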

Contributor

> I thought it would only be triggered when the indexer is constructed from da.

Yeah, that raise_if_dask_computes context manager is great for checking these ideas.

I've got it fixed but we should wait for @benbovy to finish with #5692 (though it doesn't look like that PR conflicts with these changes). It would be useful to add more test cases here.

Contributor

@dcherian dcherian Dec 14, 2021

Some other tests could be the ones in #4276 and #4663, plus this one from #2511:

import dask.array as da
import numpy as np
import xarray as xr
from xarray.tests import raise_if_dask_computes

darr = xr.DataArray(data=[0.2, 0.4, 0.6], coords={"z": range(3)}, dims=("z",))
indexer = xr.DataArray(
    data=np.random.randint(0, 3, 8).reshape(4, 2).astype(int),
    coords={"y": range(4), "x": range(2)},
    dims=("y", "x"),
)
with raise_if_dask_computes():
    actual = darr[indexer.chunk({"y": 2})]

xr.testing.assert_identical(actual, darr[indexer])

This does work now, but it ends up computing the indexer.

Contributor Author

TBH, I was not expecting these changes to be merged; they were only meant to support the discussion on the issue, but it's nice if they can be integrated.
I'll try to add the unit tests.