Skip to content

Commit

Permalink
TYP: Use Self instead of class-bound TypeVar II (pandas/core/arrays/) (
Browse files Browse the repository at this point in the history
…#51497)

TYP: Use Self for type checking (pandas/core/arrays/)
  • Loading branch information
topper-123 committed Mar 15, 2023
1 parent c8ea34c commit f8a37a7
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 209 deletions.
47 changes: 17 additions & 30 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
Any,
Literal,
Sequence,
TypeVar,
cast,
overload,
)
Expand All @@ -23,11 +22,11 @@
PositionalIndexer2D,
PositionalIndexerTuple,
ScalarIndexer,
Self,
SequenceIndexer,
Shape,
TakeIndexer,
npt,
type_t,
)
from pandas.errors import AbstractMethodError
from pandas.util._decorators import doc
Expand Down Expand Up @@ -61,10 +60,6 @@
from pandas.core.indexers import check_array_indexer
from pandas.core.sorting import nargminmax

NDArrayBackedExtensionArrayT = TypeVar(
"NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray"
)

if TYPE_CHECKING:
from pandas._typing import (
NumpySorter,
Expand Down Expand Up @@ -153,13 +148,13 @@ def view(self, dtype: Dtype | None = None) -> ArrayLike:
return arr.view(dtype=dtype) # type: ignore[arg-type]

def take(
self: NDArrayBackedExtensionArrayT,
self,
indices: TakeIndexer,
*,
allow_fill: bool = False,
fill_value: Any = None,
axis: AxisInt = 0,
) -> NDArrayBackedExtensionArrayT:
) -> Self:
if allow_fill:
fill_value = self._validate_scalar(fill_value)

Expand Down Expand Up @@ -218,17 +213,17 @@ def argmax(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[overri
raise NotImplementedError
return nargminmax(self, "argmax", axis=axis)

def unique(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
def unique(self) -> Self:
new_data = unique(self._ndarray)
return self._from_backing_data(new_data)

@classmethod
@doc(ExtensionArray._concat_same_type)
def _concat_same_type(
cls: type[NDArrayBackedExtensionArrayT],
to_concat: Sequence[NDArrayBackedExtensionArrayT],
cls,
to_concat: Sequence[Self],
axis: AxisInt = 0,
) -> NDArrayBackedExtensionArrayT:
) -> Self:
dtypes = {str(x.dtype) for x in to_concat}
if len(dtypes) != 1:
raise ValueError("to_concat must have the same dtype (tz)", dtypes)
Expand Down Expand Up @@ -268,15 +263,15 @@ def __getitem__(self, key: ScalarIndexer) -> Any:

@overload
def __getitem__(
self: NDArrayBackedExtensionArrayT,
self,
key: SequenceIndexer | PositionalIndexerTuple,
) -> NDArrayBackedExtensionArrayT:
) -> Self:
...

def __getitem__(
self: NDArrayBackedExtensionArrayT,
self,
key: PositionalIndexer2D,
) -> NDArrayBackedExtensionArrayT | Any:
) -> Self | Any:
if lib.is_integer(key):
# fast-path
result = self._ndarray[key]
Expand All @@ -303,9 +298,7 @@ def _fill_mask_inplace(
func(self._ndarray.T, limit=limit, mask=mask.T)

@doc(ExtensionArray.fillna)
def fillna(
self: NDArrayBackedExtensionArrayT, value=None, method=None, limit=None
) -> NDArrayBackedExtensionArrayT:
def fillna(self, value=None, method=None, limit=None) -> Self:
value, method = validate_fillna_kwargs(
value, method, validate_scalar_dict_value=False
)
Expand Down Expand Up @@ -369,9 +362,7 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:

np.putmask(self._ndarray, mask, value)

def _where(
self: NDArrayBackedExtensionArrayT, mask: npt.NDArray[np.bool_], value
) -> NDArrayBackedExtensionArrayT:
def _where(self: Self, mask: npt.NDArray[np.bool_], value) -> Self:
"""
Analogue to np.where(mask, self, value)
Expand All @@ -393,9 +384,7 @@ def _where(
# ------------------------------------------------------------------------
# Index compat methods

def insert(
self: NDArrayBackedExtensionArrayT, loc: int, item
) -> NDArrayBackedExtensionArrayT:
def insert(self, loc: int, item) -> Self:
"""
Make new ExtensionArray inserting new item at location. Follows
Python list.append semantics for negative values.
Expand Down Expand Up @@ -461,10 +450,10 @@ def value_counts(self, dropna: bool = True) -> Series:
return Series(result._values, index=index, name=result.name)

def _quantile(
self: NDArrayBackedExtensionArrayT,
self,
qs: npt.NDArray[np.float64],
interpolation: str,
) -> NDArrayBackedExtensionArrayT:
) -> Self:
# TODO: disable for Categorical if not ordered?

mask = np.asarray(self.isna())
Expand All @@ -488,9 +477,7 @@ def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
# numpy-like methods

@classmethod
def _empty(
cls: type_t[NDArrayBackedExtensionArrayT], shape: Shape, dtype: ExtensionDtype
) -> NDArrayBackedExtensionArrayT:
def _empty(cls, shape: Shape, dtype: ExtensionDtype) -> Self:
"""
Analogous to np.empty(shape, dtype=dtype)
Expand Down
36 changes: 14 additions & 22 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
Callable,
Literal,
Sequence,
TypeVar,
cast,
)

Expand All @@ -24,6 +23,7 @@
NpDtype,
PositionalIndexer,
Scalar,
Self,
SortKind,
TakeIndexer,
TimeAmbiguous,
Expand Down Expand Up @@ -140,8 +140,6 @@ def floordiv_compat(

from pandas import Series

ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")


def get_unit_from_pa_dtype(pa_dtype):
# https://github.com/pandas-dev/pandas/pull/50998#discussion_r1100344804
Expand Down Expand Up @@ -419,16 +417,16 @@ def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
"""Correctly construct numpy arrays when passed to `np.asarray()`."""
return self.to_numpy(dtype=dtype)

def __invert__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
def __invert__(self) -> Self:
return type(self)(pc.invert(self._pa_array))

def __neg__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
def __neg__(self) -> Self:
return type(self)(pc.negate_checked(self._pa_array))

def __pos__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
def __pos__(self) -> Self:
return type(self)(self._pa_array)

def __abs__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
def __abs__(self) -> Self:
return type(self)(pc.abs_checked(self._pa_array))

# GH 42600: __getstate__/__setstate__ not necessary once
Expand Down Expand Up @@ -733,7 +731,7 @@ def argmin(self, skipna: bool = True) -> int:
def argmax(self, skipna: bool = True) -> int:
return self._argmin_max(skipna, "max")

def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
def copy(self) -> Self:
"""
Return a shallow copy of the array.
Expand All @@ -745,7 +743,7 @@ def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
"""
return type(self)(self._pa_array)

def dropna(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
def dropna(self) -> Self:
"""
Return ArrowExtensionArray without NA values.
Expand All @@ -757,11 +755,11 @@ def dropna(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:

@doc(ExtensionArray.fillna)
def fillna(
self: ArrowExtensionArrayT,
self,
value: object | ArrayLike | None = None,
method: FillnaOptions | None = None,
limit: int | None = None,
) -> ArrowExtensionArrayT:
) -> Self:
value, method = validate_fillna_kwargs(value, method)

if limit is not None:
Expand Down Expand Up @@ -877,9 +875,7 @@ def reshape(self, *args, **kwargs):
f"as backed by a 1D pyarrow.ChunkedArray."
)

def round(
self: ArrowExtensionArrayT, decimals: int = 0, *args, **kwargs
) -> ArrowExtensionArrayT:
def round(self, decimals: int = 0, *args, **kwargs) -> Self:
"""
Round each value in the array a to the given number of decimals.
Expand Down Expand Up @@ -1052,7 +1048,7 @@ def to_numpy(
result[self.isna()] = na_value
return result

def unique(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
def unique(self) -> Self:
"""
Compute the ArrowExtensionArray of unique values.
Expand Down Expand Up @@ -1123,9 +1119,7 @@ def value_counts(self, dropna: bool = True) -> Series:
return Series(counts, index=index, name="count")

@classmethod
def _concat_same_type(
cls: type[ArrowExtensionArrayT], to_concat
) -> ArrowExtensionArrayT:
def _concat_same_type(cls, to_concat) -> Self:
"""
Concatenate multiple ArrowExtensionArrays.
Expand Down Expand Up @@ -1456,9 +1450,7 @@ def _rank(

return type(self)(result)

def _quantile(
self: ArrowExtensionArrayT, qs: npt.NDArray[np.float64], interpolation: str
) -> ArrowExtensionArrayT:
def _quantile(self, qs: npt.NDArray[np.float64], interpolation: str) -> Self:
"""
Compute the quantiles of self for each quantile in `qs`.
Expand Down Expand Up @@ -1495,7 +1487,7 @@ def _quantile(

return type(self)(result)

def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArrayT:
def _mode(self, dropna: bool = True) -> Self:
"""
Returns the mode(s) of the ExtensionArray.
Expand Down

0 comments on commit f8a37a7

Please sign in to comment.