Use a bound TypeVar for DataArray and Dataset methods #8208

Closed · wants to merge 6 commits
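In miniature, the pattern this PR applies (a sketch with stand-in method names, not xarray's actual code): methods that previously threaded the instance type through the bound TypeVar T_DataWithCoords instead annotate their return type as PEP 673 Self.

    from typing import TypeVar

    from typing_extensions import Self

    T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")

    class DataWithCoords:
        # Before: a bound TypeVar on ``self`` carries the subclass type through.
        def squeeze_old(self: T_DataWithCoords) -> T_DataWithCoords:
            return self

        # After: ``Self`` expresses the same thing without the TypeVar plumbing,
        # so calls on subclasses (DataArray, Dataset) still return the subclass type.
        def squeeze_new(self) -> Self:
            return self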
53 changes: 24 additions & 29 deletions xarray/core/common.py
@@ -5,7 +5,7 @@
from contextlib import suppress
from html import escape
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast, overload

import numpy as np
import pandas as pd
@@ -45,10 +45,11 @@
DatetimeLike,
DTypeLikeSave,
ScalarOrArray,
Self,
SideOptions,
T_Chunks,
T_DataWithCoords,
T_Variable,
T_Xarray,
)
from xarray.core.variable import Variable

@@ -381,11 +382,11 @@ class DataWithCoords(AttrAccessMixin):
__slots__ = ("_close",)

def squeeze(
self: T_DataWithCoords,
self,
dim: Hashable | Iterable[Hashable] | None = None,
drop: bool = False,
axis: int | Iterable[int] | None = None,
) -> T_DataWithCoords:
) -> Self:
"""Return a new object with squeezed data.

Parameters
@@ -411,15 +412,15 @@ def squeeze(
numpy.squeeze
"""
dims = get_squeeze_dims(self, dim, axis)
return self.isel(drop=drop, **{d: 0 for d in dims})
return self.isel(drop=drop, indexers={d: 0 for d in dims})

def clip(
self: T_DataWithCoords,
self: Self,
min: ScalarOrArray | None = None,
max: ScalarOrArray | None = None,
*,
keep_attrs: bool | None = None,
) -> T_DataWithCoords:
) -> Self:
"""
Return an array whose values are limited to ``[min, max]``.
At least one of max or min must be given.
@@ -472,10 +473,10 @@ def _calc_assign_results(
return {k: v(self) if callable(v) else v for k, v in kwargs.items()}

def assign_coords(
self: T_DataWithCoords,
self: Self,
coords: Mapping[Any, Any] | None = None,
**coords_kwargs: Any,
) -> T_DataWithCoords:
) -> Self:
"""Assign new coordinates to this object.

Returns a new object with all the original data in addition to the new
@@ -620,9 +621,7 @@ def assign_coords(
data.coords.update(results)
return data

def assign_attrs(
self: T_DataWithCoords, *args: Any, **kwargs: Any
) -> T_DataWithCoords:
def assign_attrs(self: Self, *args: Any, **kwargs: Any) -> Self:
"""Assign new attrs to this object.

Returns a new object equivalent to ``self.attrs.update(*args, **kwargs)``.
@@ -810,11 +809,11 @@ def pipe(
return func(self, *args, **kwargs)

def rolling_exp(
self: T_DataWithCoords,
self: Self,
window: Mapping[Any, int] | None = None,
window_type: str = "span",
**window_kwargs,
) -> RollingExp[T_DataWithCoords]:
) -> RollingExp[T_Xarray]:
Collaborator:

Suggested change:
-    ) -> RollingExp[T_Xarray]:
+    ) -> RollingExp[Self]:

Collaborator (author):

Unfortunately this doesn't work, similar to the reason outlined below — T_Xarray isn't compatible with T_DataWithCoords...

Collaborator:

> Unfortunately this doesn't work, similar to the reason outlined below — T_Xarray isn't compatible with T_DataWithCoords...

I think that's why I originally typed RollingExp using T_DataWithCoords.
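A minimal sketch of the incompatibility described in this thread (the class bodies are stubs, not xarray's real definitions): T_Xarray is a value-constrained TypeVar that must resolve to exactly DataArray or Dataset, while Self behaves like a TypeVar bound to the defining class, DataWithCoords, so it can never satisfy the constraint.

    from __future__ import annotations

    from typing import Generic, TypeVar

    from typing_extensions import Self

    class DataWithCoords:
        # mypy rejects this annotation: Self (bound to DataWithCoords) is not
        # one of the allowed substitutions for the constrained T_Xarray below.
        def rolling_exp(self) -> RollingExp[Self]:
            raise NotImplementedError

    class DataArray(DataWithCoords): ...

    class Dataset(DataWithCoords): ...

    # Constrained, not bound: valid substitutions are exactly these two types.
    T_Xarray = TypeVar("T_Xarray", DataArray, Dataset)

    class RollingExp(Generic[T_Xarray]):
        def __init__(self, obj: T_Xarray) -> None:
            self.obj = obj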

"""
Exponentially-weighted moving window.
Similar to EWM in pandas
@@ -849,7 +848,7 @@ def rolling_exp(

window = either_dict_or_kwargs(window, window_kwargs, "rolling_exp")

return rolling_exp.RollingExp(self, window, window_type)
return rolling_exp.RollingExp(cast("T_Xarray", self), window, window_type)

def _resample(
self,
@@ -1062,8 +1061,8 @@ def _resample(
)

def where(
self: T_DataWithCoords, cond: Any, other: Any = dtypes.NA, drop: bool = False
) -> T_DataWithCoords:
self: Self, cond: Any, other: Any = dtypes.NA, drop: bool = False
) -> Self:
"""Filter elements from this object according to a condition.

This operation follows the normal broadcasting and alignment rules that
@@ -1178,8 +1177,8 @@ def _dataset_indexer(dim: Hashable) -> DataArray:
for dim in cond.sizes.keys():
indexers[dim] = _get_indexer(dim)

self = self.isel(**indexers)
cond = cond.isel(**indexers)
self = self.isel(indexers=indexers)
cond = cond.isel(indexers=indexers)

return ops.where_method(self, cond, other)

@@ -1205,9 +1204,7 @@ def close(self) -> None:
self._close()
self._close = None

def isnull(
self: T_DataWithCoords, keep_attrs: bool | None = None
) -> T_DataWithCoords:
def isnull(self: Self, keep_attrs: bool | None = None) -> Self:
"""Test each value in the array for whether it is a missing value.

Parameters
@@ -1250,9 +1247,7 @@ def isnull(
keep_attrs=keep_attrs,
)

def notnull(
self: T_DataWithCoords, keep_attrs: bool | None = None
) -> T_DataWithCoords:
def notnull(self: Self, keep_attrs: bool | None = None) -> Self:
"""Test each value in the array for whether it is not a missing value.

Parameters
@@ -1295,7 +1290,7 @@ def notnull(
keep_attrs=keep_attrs,
)

def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords:
def isin(self: Self, test_elements: Any) -> Self:
"""Tests each value in the array for whether it is in test elements.

Parameters
@@ -1344,15 +1339,15 @@ def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords:
)

def astype(
self: T_DataWithCoords,
self: Self,
dtype,
*,
order=None,
casting=None,
subok=None,
copy=None,
keep_attrs=True,
) -> T_DataWithCoords:
) -> Self:
"""
Copy of the xarray object, with data cast to a specified type.
Leaves coordinate dtype unchanged.
@@ -1419,7 +1414,7 @@ def astype(
dask="allowed",
)

def __enter__(self: T_DataWithCoords) -> T_DataWithCoords:
def __enter__(self: Self) -> Self:
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
2 changes: 1 addition & 1 deletion xarray/core/coordinates.py
@@ -527,7 +527,7 @@ def merge(self, other: Mapping[Any, Any] | None) -> Dataset:
def __setitem__(self, key: Hashable, value: Any) -> None:
self.update({key: value})

def update(self, other: Mapping[Any, Any]) -> None:
def update(self: Self, other: Mapping[Any, Any]) -> None:
"""Update this Coordinates variables with other coordinate variables."""

if not len(other):
7 changes: 3 additions & 4 deletions xarray/core/dataarray.py
@@ -100,6 +100,7 @@
QueryEngineOptions,
QueryParserOptions,
ReindexMethodOptions,
Self,
SideOptions,
T_DataArray,
T_Xarray,
@@ -2986,14 +2987,12 @@ def transpose(
def T(self: T_DataArray) -> T_DataArray:
return self.transpose()

# change type of self and return to T_DataArray once
# https://github.com/python/mypy/issues/12846 is resolved
def drop_vars(
self,
self: Self,
names: Hashable | Iterable[Hashable],
*,
errors: ErrorOptions = "raise",
) -> DataArray:
) -> Self:
"""Returns an array with dropped variables.

Parameters
16 changes: 13 additions & 3 deletions xarray/core/dataset.py
@@ -21,7 +21,16 @@
from numbers import Number
from operator import methodcaller
from os import PathLike
from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload
from typing import (
IO,
TYPE_CHECKING,
Any,
Callable,
Generic,
Literal,
cast,
overload,
)

import numpy as np

@@ -146,6 +155,7 @@
QueryEngineOptions,
QueryParserOptions,
ReindexMethodOptions,
Self,
SideOptions,
T_Xarray,
)
@@ -1114,13 +1124,13 @@ def _replace(
return obj

def _replace_with_new_dims(
self: T_Dataset,
self: Self,
variables: dict[Hashable, Variable],
coord_names: set | None = None,
attrs: dict[Hashable, Any] | None | Default = _default,
indexes: dict[Hashable, Index] | None = None,
inplace: bool = False,
) -> T_Dataset:
) -> Self:
"""Replace variables with recalculated dimensions."""
dims = calculate_dimensions(variables)
return self._replace(
12 changes: 8 additions & 4 deletions xarray/core/groupby.py
@@ -12,6 +12,7 @@
Generic,
Literal,
Union,
cast,
)

import numpy as np
@@ -256,7 +257,7 @@ def to_dataarray(self) -> DataArray:

def _ensure_1d(
group: T_Group, obj: T_Xarray
) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable],]:
) -> tuple[T_Group, T_Xarray, str | None, list[Hashable]]:
# 1D cases: do nothing
if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1:
return group, obj, None, []
@@ -271,7 +272,7 @@ def _ensure_1d(
inserted_dims = [dim for dim in group.dims if dim not in group.coords]
newgroup = group.stack({stacked_dim: orig_dims})
newobj = obj.stack({stacked_dim: orig_dims})
return newgroup, newobj, stacked_dim, inserted_dims
return newgroup, cast(T_Xarray, newobj), stacked_dim, inserted_dims

raise TypeError(
f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}."
@@ -800,7 +801,10 @@ def __getitem__(self, key: GroupKey) -> T_Xarray:
"""
Get DataArray or Dataset corresponding to a particular group label.
"""
return self._obj.isel({self._group_dim: self.groups[key]})
return cast(
    T_Xarray,
    self._obj.isel(indexers={self._group_dim: self.groups[key]}),
)

Collaborator:

That's weird that these casts are now necessary...
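On the casts themselves: typing.cast is purely a static-typing assertion with no runtime effect, so these additions change what mypy believes about the value, not what the code does. A quick illustration:

    from typing import cast

    x: object = "hello"
    s = cast(str, x)  # returns x unchanged; no runtime check or conversion
    assert s is x     # the type checker now treats ``s`` as str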

def __len__(self) -> int:
(grouper,) = self.groupers
@@ -822,7 +826,7 @@ def __repr__(self) -> str:
def _iter_grouped(self) -> Iterator[T_Xarray]:
"""Iterate over each element in this group"""
for indices in self._group_indices:
yield self._obj.isel({self._group_dim: indices})
yield cast(T_Xarray, self._obj.isel({self._group_dim: indices}))

def _infer_concat_args(self, applied_example):
(grouper,) = self.groupers
11 changes: 7 additions & 4 deletions xarray/core/parallel.py
@@ -358,14 +358,17 @@ def _wrapper(

if template is None:
# infer template by providing zero-shaped arrays
template = infer_template(func, aligned[0], *args, **kwargs)
template_indexes = set(template._indexes)
template_inferred = infer_template(func, aligned[0], *args, **kwargs)
template = template_inferred
template_indexes = set(template_inferred._indexes)
preserved_indexes = template_indexes & set(input_indexes)
new_indexes = template_indexes - set(input_indexes)
indexes = {dim: input_indexes[dim] for dim in preserved_indexes}
indexes.update({k: template._indexes[k] for k in new_indexes})
indexes.update({k: template_inferred._indexes[k] for k in new_indexes})
output_chunks: Mapping[Hashable, tuple[int, ...]] = {
dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
dim: input_chunks[dim]
for dim in template_inferred.dims
if dim in input_chunks
}

else: