Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement setitem syntax for .oindex and .vindex properties #8845

Merged
142 changes: 104 additions & 38 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,18 +326,23 @@ def as_integer_slice(value):


class IndexCallable:
"""Provide getitem syntax for a callable object."""
"""Provide getitem and setitem syntax for callable objects."""

__slots__ = ("func",)
__slots__ = ("func_get", "func_set")
andersy005 marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, func):
self.func = func
def __init__(self, func_get, func_set=None):
self.func_get = func_get
self.func_set = func_set

def __getitem__(self, key):
return self.func(key)
return self.func_get(key)

def __setitem__(self, key, value):
raise NotImplementedError
if self.func_set is None:
raise NotImplementedError(
"Setting values is not supported for this indexer."
)
self.func_set(key, value)


class BasicIndexer(ExplicitIndexer):
Expand Down Expand Up @@ -486,10 +491,24 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
return np.asarray(self.get_duck_array(), dtype=dtype)

def _oindex_get(self, key):
raise NotImplementedError("This method should be overridden")
raise NotImplementedError(
f"{self.__class__.__name__}._oindex_get method should be overridden"
)

def _vindex_get(self, key):
raise NotImplementedError("This method should be overridden")
raise NotImplementedError(
f"{self.__class__.__name__}._vindex_get method should be overridden"
)

def _oindex_set(self, key, value):
raise NotImplementedError(
f"{self.__class__.__name__}._oindex_set method should be overridden"
)

def _vindex_set(self, key, value):
raise NotImplementedError(
f"{self.__class__.__name__}._vindex_set method should be overridden"
)

def _check_and_raise_if_non_basic_indexer(self, key):
if isinstance(key, (VectorizedIndexer, OuterIndexer)):
Expand All @@ -500,11 +519,11 @@ def _check_and_raise_if_non_basic_indexer(self, key):

@property
def oindex(self):
return IndexCallable(self._oindex_get)
return IndexCallable(self._oindex_get, self._oindex_set)

@property
def vindex(self):
return IndexCallable(self._vindex_get)
return IndexCallable(self._vindex_get, self._vindex_set)


class ImplicitToExplicitIndexingAdapter(NDArrayMixin):
Expand Down Expand Up @@ -616,12 +635,18 @@ def __getitem__(self, indexer):
self._check_and_raise_if_non_basic_indexer(indexer)
return type(self)(self.array, self._updated_key(indexer))

def _vindex_set(self, key, value):
raise NotImplementedError(
"Lazy item assignment with the vectorized indexer is not yet "
"implemented. Load your data first by .load() or compute()."
)

def _oindex_set(self, key, value):
full_key = self._updated_key(key)
self.array[full_key] = value
andersy005 marked this conversation as resolved.
Show resolved Hide resolved

def __setitem__(self, key, value):
if isinstance(key, VectorizedIndexer):
raise NotImplementedError(
"Lazy item assignment with the vectorized indexer is not yet "
"implemented. Load your data first by .load() or compute()."
)
self._check_and_raise_if_non_basic_indexer(key)
full_key = self._updated_key(key)
self.array[full_key] = value

Expand Down Expand Up @@ -657,7 +682,6 @@ def shape(self) -> tuple[int, ...]:
return np.broadcast(*self.key.tuple).shape

def get_duck_array(self):

if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
array = apply_indexer(self.array, self.key)
else:
Expand Down Expand Up @@ -739,8 +763,18 @@ def __getitem__(self, key):
def transpose(self, order):
return self.array.transpose(order)

def _vindex_set(self, key, value):
self._ensure_copied()
self.array.vindex[key] = value

def _oindex_set(self, key, value):
self._ensure_copied()
self.array.oindex[key] = value

def __setitem__(self, key, value):
self._check_and_raise_if_non_basic_indexer(key)
self._ensure_copied()

self.array[key] = value

def __deepcopy__(self, memo):
Expand Down Expand Up @@ -779,7 +813,14 @@ def __getitem__(self, key):
def transpose(self, order):
return self.array.transpose(order)

def _vindex_set(self, key, value):
self.array.vindex[key] = value

def _oindex_set(self, key, value):
self.array.oindex[key] = value

def __setitem__(self, key, value):
self._check_and_raise_if_non_basic_indexer(key)
self.array[key] = value


Expand Down Expand Up @@ -950,6 +991,16 @@ def apply_indexer(indexable, indexer):
return indexable[indexer]


def set_with_indexer(indexable, indexer, value):
"""Set values in an indexable object using an indexer."""
if isinstance(indexer, VectorizedIndexer):
indexable.vindex[indexer] = value
elif isinstance(indexer, OuterIndexer):
indexable.oindex[indexer] = value
else:
indexable[indexer] = value


def decompose_indexer(
indexer: ExplicitIndexer, shape: tuple[int, ...], indexing_support: IndexingSupport
) -> tuple[ExplicitIndexer, ExplicitIndexer]:
Expand Down Expand Up @@ -1433,19 +1484,31 @@ def __getitem__(self, key):
array, key = self._indexing_array_and_key(key)
return array[key]

def __setitem__(self, key, value):
array, key = self._indexing_array_and_key(key)
def _safe_setitem(self, array, key, value):
try:
array[key] = value
except ValueError:
except ValueError as exc:
# More informative exception if read-only view
if not array.flags.writeable and not array.flags.owndata:
raise ValueError(
"Assignment destination is a view. "
"Do you want to .copy() array first?"
)
else:
raise
raise exc

def _oindex_set(self, key, value):
key = _outer_to_numpy_indexer(key, self.array.shape)
self._safe_setitem(self.array, key, value)

def _vindex_set(self, key, value):
array = NumpyVIndexAdapter(self.array)
self._safe_setitem(array, key.tuple, value)

def __setitem__(self, key, value):
self._check_and_raise_if_non_basic_indexer(key)
array, key = self._indexing_array_and_key(key)
andersy005 marked this conversation as resolved.
Show resolved Hide resolved
self._safe_setitem(array, key, value)


class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter):
Expand Down Expand Up @@ -1488,13 +1551,15 @@ def __getitem__(self, key):
self._check_and_raise_if_non_basic_indexer(key)
return self.array[key.tuple]

def _oindex_set(self, key, value):
self.array[key.tuple] = value

def _vindex_set(self, key, value):
raise TypeError("Vectorized indexing is not supported")

def __setitem__(self, key, value):
if isinstance(key, (BasicIndexer, OuterIndexer)):
self.array[key.tuple] = value
elif isinstance(key, VectorizedIndexer):
raise TypeError("Vectorized indexing is not supported")
else:
raise TypeError(f"Unrecognized indexer: {key}")
self._check_and_raise_if_non_basic_indexer(key)
self.array[key.tuple] = value

def transpose(self, order):
xp = self.array.__array_namespace__()
Expand Down Expand Up @@ -1530,19 +1595,20 @@ def __getitem__(self, key):
self._check_and_raise_if_non_basic_indexer(key)
return self.array[key.tuple]

def _oindex_set(self, key, value):
num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple)
if num_non_slices > 1:
raise NotImplementedError(
"xarray can't set arrays with multiple " "array indices to dask yet."
)
self.array[key.tuple] = value

def _vindex_set(self, key, value):
self.array.vindex[key.tuple] = value

def __setitem__(self, key, value):
if isinstance(key, BasicIndexer):
self.array[key.tuple] = value
elif isinstance(key, VectorizedIndexer):
self.array.vindex[key.tuple] = value
elif isinstance(key, OuterIndexer):
num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple)
if num_non_slices > 1:
raise NotImplementedError(
"xarray can't set arrays with multiple "
"array indices to dask yet."
)
self.array[key.tuple] = value
self._check_and_raise_if_non_basic_indexer(key)
self.array[key.tuple] = value

def transpose(self, order):
return self.array.transpose(order)
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ def __setitem__(self, key, value):
value = np.moveaxis(value, new_order, range(len(new_order)))

indexable = as_indexable(self._data)
indexable[index_tuple] = value
indexing.set_with_indexer(indexable, index_tuple, value)

@property
def encoding(self) -> dict[Any, Any]:
Expand Down