Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement setitem syntax for .oindex and .vindex properties #8845

Merged
171 changes: 114 additions & 57 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,18 +326,23 @@ def as_integer_slice(value):


class IndexCallable:
"""Provide getitem syntax for a callable object."""
"""Provide getitem and setitem syntax for callable objects."""

__slots__ = ("func",)
__slots__ = ("getter", "setter")

def __init__(self, func):
self.func = func
def __init__(self, getter, setter=None):
self.getter = getter
self.setter = setter

def __getitem__(self, key):
return self.func(key)
return self.getter(key)

def __setitem__(self, key, value):
raise NotImplementedError
if self.setter is None:
raise NotImplementedError(
"Setting values is not supported for this indexer."
)
self.setter(key, value)


class BasicIndexer(ExplicitIndexer):
Expand Down Expand Up @@ -486,10 +491,24 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
return np.asarray(self.get_duck_array(), dtype=dtype)

def _oindex_get(self, key):
raise NotImplementedError("This method should be overridden")
raise NotImplementedError(
f"{self.__class__.__name__}._oindex_get method should be overridden"
)

def _vindex_get(self, key):
raise NotImplementedError("This method should be overridden")
raise NotImplementedError(
f"{self.__class__.__name__}._vindex_get method should be overridden"
)

def _oindex_set(self, key, value):
raise NotImplementedError(
f"{self.__class__.__name__}._oindex_set method should be overridden"
)

def _vindex_set(self, key, value):
raise NotImplementedError(
f"{self.__class__.__name__}._vindex_set method should be overridden"
)

def _check_and_raise_if_non_basic_indexer(self, key):
if isinstance(key, (VectorizedIndexer, OuterIndexer)):
Expand All @@ -500,11 +519,11 @@ def _check_and_raise_if_non_basic_indexer(self, key):

@property
def oindex(self):
return IndexCallable(self._oindex_get)
return IndexCallable(self._oindex_get, self._oindex_set)

@property
def vindex(self):
return IndexCallable(self._vindex_get)
return IndexCallable(self._vindex_get, self._vindex_set)


class ImplicitToExplicitIndexingAdapter(NDArrayMixin):
Expand Down Expand Up @@ -616,12 +635,18 @@ def __getitem__(self, indexer):
self._check_and_raise_if_non_basic_indexer(indexer)
return type(self)(self.array, self._updated_key(indexer))

def _vindex_set(self, key, value):
raise NotImplementedError(
"Lazy item assignment with the vectorized indexer is not yet "
"implemented. Load your data first by .load() or compute()."
)

def _oindex_set(self, key, value):
full_key = self._updated_key(key)
self.array.oindex[full_key] = value

def __setitem__(self, key, value):
if isinstance(key, VectorizedIndexer):
raise NotImplementedError(
"Lazy item assignment with the vectorized indexer is not yet "
"implemented. Load your data first by .load() or compute()."
)
self._check_and_raise_if_non_basic_indexer(key)
full_key = self._updated_key(key)
self.array[full_key] = value

Expand Down Expand Up @@ -657,7 +682,6 @@ def shape(self) -> tuple[int, ...]:
return np.broadcast(*self.key.tuple).shape

def get_duck_array(self):

if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
array = apply_indexer(self.array, self.key)
else:
Expand Down Expand Up @@ -739,8 +763,18 @@ def __getitem__(self, key):
def transpose(self, order):
return self.array.transpose(order)

def _vindex_set(self, key, value):
self._ensure_copied()
self.array.vindex[key] = value

def _oindex_set(self, key, value):
self._ensure_copied()
self.array.oindex[key] = value

def __setitem__(self, key, value):
self._check_and_raise_if_non_basic_indexer(key)
self._ensure_copied()

self.array[key] = value

def __deepcopy__(self, memo):
Expand Down Expand Up @@ -779,7 +813,14 @@ def __getitem__(self, key):
def transpose(self, order):
return self.array.transpose(order)

def _vindex_set(self, key, value):
self.array.vindex[key] = value

def _oindex_set(self, key, value):
self.array.oindex[key] = value

def __setitem__(self, key, value):
self._check_and_raise_if_non_basic_indexer(key)
self.array[key] = value


Expand Down Expand Up @@ -950,6 +991,16 @@ def apply_indexer(indexable, indexer):
return indexable[indexer]


def set_with_indexer(indexable, indexer, value):
"""Set values in an indexable object using an indexer."""
if isinstance(indexer, VectorizedIndexer):
indexable.vindex[indexer] = value
elif isinstance(indexer, OuterIndexer):
indexable.oindex[indexer] = value
else:
indexable[indexer] = value


def decompose_indexer(
indexer: ExplicitIndexer, shape: tuple[int, ...], indexing_support: IndexingSupport
) -> tuple[ExplicitIndexer, ExplicitIndexer]:
Expand Down Expand Up @@ -1399,24 +1450,6 @@ def __init__(self, array):
)
self.array = array

def _indexing_array_and_key(self, key):
if isinstance(key, OuterIndexer):
array = self.array
key = _outer_to_numpy_indexer(key, self.array.shape)
elif isinstance(key, VectorizedIndexer):
array = NumpyVIndexAdapter(self.array)
key = key.tuple
elif isinstance(key, BasicIndexer):
array = self.array
# We want 0d slices rather than scalars. This is achieved by
# appending an ellipsis (see
# https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes).
key = key.tuple + (Ellipsis,)
else:
raise TypeError(f"unexpected key type: {type(key)}")

return array, key

def transpose(self, order):
return self.array.transpose(order)

Expand All @@ -1430,22 +1463,43 @@ def _vindex_get(self, key):

def __getitem__(self, key):
self._check_and_raise_if_non_basic_indexer(key)
array, key = self._indexing_array_and_key(key)

array = self.array
# We want 0d slices rather than scalars. This is achieved by
# appending an ellipsis (see
# https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes).
key = key.tuple + (Ellipsis,)
return array[key]

def __setitem__(self, key, value):
array, key = self._indexing_array_and_key(key)
def _safe_setitem(self, array, key, value):
try:
array[key] = value
except ValueError:
except ValueError as exc:
# More informative exception if read-only view
if not array.flags.writeable and not array.flags.owndata:
raise ValueError(
"Assignment destination is a view. "
"Do you want to .copy() array first?"
)
else:
raise
raise exc

def _oindex_set(self, key, value):
key = _outer_to_numpy_indexer(key, self.array.shape)
self._safe_setitem(self.array, key, value)

def _vindex_set(self, key, value):
array = NumpyVIndexAdapter(self.array)
self._safe_setitem(array, key.tuple, value)

def __setitem__(self, key, value):
self._check_and_raise_if_non_basic_indexer(key)
array = self.array
# We want 0d slices rather than scalars. This is achieved by
# appending an ellipsis (see
# https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes).
key = key.tuple + (Ellipsis,)
self._safe_setitem(array, key, value)


class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter):
Expand Down Expand Up @@ -1488,13 +1542,15 @@ def __getitem__(self, key):
self._check_and_raise_if_non_basic_indexer(key)
return self.array[key.tuple]

def _oindex_set(self, key, value):
self.array[key.tuple] = value

def _vindex_set(self, key, value):
raise TypeError("Vectorized indexing is not supported")

def __setitem__(self, key, value):
if isinstance(key, (BasicIndexer, OuterIndexer)):
self.array[key.tuple] = value
elif isinstance(key, VectorizedIndexer):
raise TypeError("Vectorized indexing is not supported")
else:
raise TypeError(f"Unrecognized indexer: {key}")
self._check_and_raise_if_non_basic_indexer(key)
self.array[key.tuple] = value

def transpose(self, order):
xp = self.array.__array_namespace__()
Expand Down Expand Up @@ -1530,19 +1586,20 @@ def __getitem__(self, key):
self._check_and_raise_if_non_basic_indexer(key)
return self.array[key.tuple]

def _oindex_set(self, key, value):
num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple)
if num_non_slices > 1:
raise NotImplementedError(
"xarray can't set arrays with multiple " "array indices to dask yet."
)
self.array[key.tuple] = value

def _vindex_set(self, key, value):
self.array.vindex[key.tuple] = value

def __setitem__(self, key, value):
if isinstance(key, BasicIndexer):
self.array[key.tuple] = value
elif isinstance(key, VectorizedIndexer):
self.array.vindex[key.tuple] = value
elif isinstance(key, OuterIndexer):
num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple)
if num_non_slices > 1:
raise NotImplementedError(
"xarray can't set arrays with multiple "
"array indices to dask yet."
)
self.array[key.tuple] = value
self._check_and_raise_if_non_basic_indexer(key)
self.array[key.tuple] = value

def transpose(self, order):
return self.array.transpose(order)
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ def __setitem__(self, key, value):
value = np.moveaxis(value, new_order, range(len(new_order)))

indexable = as_indexable(self._data)
indexable[index_tuple] = value
indexing.set_with_indexer(indexable, index_tuple, value)

@property
def encoding(self) -> dict[Any, Any]:
Expand Down
68 changes: 59 additions & 9 deletions xarray/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,28 @@
B = IndexerMaker(indexing.BasicIndexer)


class TestIndexCallable:
def test_getitem(self):
def getter(key):
return key * 2

indexer = indexing.IndexCallable(getter)
assert indexer[3] == 6
assert indexer[0] == 0
assert indexer[-1] == -2

def test_setitem(self):
def getter(key):
return key * 2

def setter(key, value):
raise NotImplementedError("Setter not implemented")

indexer = indexing.IndexCallable(getter, setter)
with pytest.raises(NotImplementedError):
indexer[3] = 6


class TestIndexers:
def set_to_zero(self, x, i):
x = x.copy()
Expand Down Expand Up @@ -361,15 +383,8 @@ def test_vectorized_lazily_indexed_array(self) -> None:

def check_indexing(v_eager, v_lazy, indexers):
for indexer in indexers:
if isinstance(indexer, indexing.VectorizedIndexer):
actual = v_lazy.vindex[indexer]
expected = v_eager.vindex[indexer]
elif isinstance(indexer, indexing.OuterIndexer):
actual = v_lazy.oindex[indexer]
expected = v_eager.oindex[indexer]
else:
actual = v_lazy[indexer]
expected = v_eager[indexer]
actual = v_lazy[indexer]
expected = v_eager[indexer]
assert expected.shape == actual.shape
assert isinstance(
actual._data,
Expand Down Expand Up @@ -406,6 +421,41 @@ def check_indexing(v_eager, v_lazy, indexers):
]
check_indexing(v_eager, v_lazy, indexers)

def test_lazily_indexed_array_vindex_setitem(self) -> None:

lazy = indexing.LazilyIndexedArray(np.random.rand(10, 20, 30))

# vectorized indexing
indexer = indexing.VectorizedIndexer(
(np.array([0, 1]), np.array([0, 1]), slice(None, None, None))
)
with pytest.raises(
NotImplementedError,
match=r"Lazy item assignment with the vectorized indexer is not yet",
):
lazy.vindex[indexer] = 0

@pytest.mark.parametrize(
"indexer_class, key, value",
[
(indexing.OuterIndexer, (0, 1, slice(None, None, None)), 10),
(indexing.BasicIndexer, (0, 1, slice(None, None, None)), 10),
],
)
def test_lazily_indexed_array_setitem(self, indexer_class, key, value) -> None:
original = np.random.rand(10, 20, 30)
x = indexing.NumpyIndexingAdapter(original)
lazy = indexing.LazilyIndexedArray(x)

if indexer_class is indexing.BasicIndexer:
indexer = indexer_class(key)
lazy[indexer] = value
elif indexer_class is indexing.OuterIndexer:
indexer = indexer_class(key)
lazy.oindex[indexer] = value

assert_array_equal(original[key], value)


class TestCopyOnWriteArray:
def test_setitem(self) -> None:
Expand Down