Skip to content

Commit

Permalink
Add getitem to array protocol (#8406)
Browse files Browse the repository at this point in the history
* Update _typing.py

* Update _typing.py

* Update test_namedarray.py

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update _typing.py

* Update _typing.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Illviljan and pre-commit-ci[bot] committed Dec 12, 2023
1 parent 562f2f8 commit 0bf38c2
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
40 changes: 37 additions & 3 deletions xarray/namedarray/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Default(Enum):
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)


_dtype = np.dtype
_DType = TypeVar("_DType", bound=np.dtype[Any])
_DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any])
# A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic`
Expand Down Expand Up @@ -69,9 +69,16 @@ def dtype(self) -> _DType_co:
_Dims = tuple[_Dim, ...]

_DimsLike = Union[str, Iterable[_Dim]]
_AttrsLike = Union[Mapping[Any, Any], None]

_dtype = np.dtype
# https://data-apis.org/array-api/latest/API_specification/indexing.html
# TODO: np.array_api was bugged and didn't allow (None,), but should!
# https://github.com/numpy/numpy/pull/25022
# https://github.com/data-apis/array-api/pull/674
_IndexKey = Union[int, slice, "ellipsis"]
_IndexKeys = tuple[Union[_IndexKey], ...] # tuple[Union[_IndexKey, None], ...]
_IndexKeyLike = Union[_IndexKey, _IndexKeys]

_AttrsLike = Union[Mapping[Any, Any], None]


class _SupportsReal(Protocol[_T_co]):
Expand Down Expand Up @@ -113,6 +120,25 @@ class _arrayfunction(
Corresponds to np.ndarray.
"""

@overload
def __getitem__(
self, key: _arrayfunction[Any, Any] | tuple[_arrayfunction[Any, Any], ...], /
) -> _arrayfunction[Any, _DType_co]:
...

@overload
def __getitem__(self, key: _IndexKeyLike, /) -> Any:
...

def __getitem__(
self,
key: _IndexKeyLike
| _arrayfunction[Any, Any]
| tuple[_arrayfunction[Any, Any], ...],
/,
) -> _arrayfunction[Any, _DType_co] | Any:
...

@overload
def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]:
...
Expand Down Expand Up @@ -165,6 +191,14 @@ class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType
Corresponds to np.ndarray.
"""

def __getitem__(
self,
key: _IndexKeyLike
| Any, # TODO: Any should be _arrayapi[Any, _dtype[np.integer]]
/,
) -> _arrayapi[Any, Any]:
...

def __array_namespace__(self) -> ModuleType:
...

Expand Down
14 changes: 14 additions & 0 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
_AttrsLike,
_DimsLike,
_DType,
_IndexKeyLike,
_Shape,
duckarray,
)
Expand Down Expand Up @@ -58,6 +59,19 @@ class CustomArrayIndexable(
ExplicitlyIndexed,
Generic[_ShapeType_co, _DType_co],
):
def __getitem__(
self, key: _IndexKeyLike | CustomArrayIndexable[Any, Any], /
) -> CustomArrayIndexable[Any, _DType_co]:
if isinstance(key, CustomArrayIndexable):
if isinstance(key.array, type(self.array)):
# TODO: key.array is duckarray here, can it be narrowed down further?
# an _arrayapi cannot be used on a _arrayfunction for example.
return type(self)(array=self.array[key.array]) # type: ignore[index]
else:
raise TypeError("key must have the same array type as self")
else:
return type(self)(array=self.array[key])

def __array_namespace__(self) -> ModuleType:
return np

Expand Down

0 comments on commit 0bf38c2

Please sign in to comment.