diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 1350b7c1b..f7dcd1cfc 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -950,6 +950,8 @@ np_1darray_dt: TypeAlias = np_1darray[np.datetime64] np_1darray_td: TypeAlias = np_1darray[np.timedelta64] np_2darray: TypeAlias = np.ndarray[tuple[int, int], np.dtype[GenericT]] +NDArrayT = TypeVar("NDArrayT", bound=np.ndarray) + DtypeNp = TypeVar("DtypeNp", bound=np.dtype[np.generic]) KeysArgType: TypeAlias = Any ListLikeT = TypeVar("ListLikeT", bound=ListLike) diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index 09bb37a8a..6c6c10aa7 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -102,7 +102,11 @@ from pandas._typing import ( Level, MaskType, NaPosition, + NDArrayT, NumpyFloatNot16DtypeArg, + NumpyNotTimeDtypeArg, + NumpyTimedeltaDtypeArg, + NumpyTimestampDtypeArg, PandasAstypeFloatDtypeArg, PandasFloatDtypeArg, PyArrowFloatDtypeArg, @@ -374,7 +378,15 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]): def dtype(self) -> DtypeObj: ... @final def ravel(self, order: _str = "C") -> Self: ... - def view(self, cls=...): ... + @overload + def view(self, cls: None = None) -> Self: ... + @overload + def view(self, cls: type[NDArrayT]) -> NDArrayT: ... + @overload + def view( + self, + cls: NumpyNotTimeDtypeArg | NumpyTimedeltaDtypeArg | NumpyTimestampDtypeArg, + ) -> np_1darray: ... @overload def astype( self, @@ -596,7 +608,11 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]): def insert(self, loc: int, item: S1) -> Self: ... @overload def insert(self, loc: int, item: object) -> Index: ... - def drop(self, labels, errors: IgnoreRaise = "raise") -> Self: ... + def drop( + self, + labels: IndexOpsMixin | np_ndarray | Iterable[Hashable], + errors: IgnoreRaise = "raise", + ) -> Self: ... @property def shape(self) -> tuple[int, ...]: ... # Extra methods from old stubs diff --git a/pyproject.toml b/pyproject.toml index 46fc93532..437f0de72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -237,8 +237,12 @@ ignore = [ # TODO: remove when _libs is fully typed "ANN001", "ANN201", "ANN204", "ANN206", ] -"*base.pyi" = [ - # TODO: remove when base.pyi's are fully typed +"*core/base.pyi" = [ + # TODO: remove when core/base.pyi is fully typed + "ANN001", "ANN201", "ANN204", "ANN206", +] +"*excel/_base.pyi" = [ + # TODO: remove when excel/_base.pyi is fully typed "ANN001", "ANN201", "ANN204", "ANN206", ] "scripts/*" = [ diff --git a/tests/indexes/test_indexes.py b/tests/indexes/test_indexes.py index eda08354e..56e1e87ae 100644 --- a/tests/indexes/test_indexes.py +++ b/tests/indexes/test_indexes.py @@ -2,6 +2,7 @@ from collections.abc import Hashable import datetime as dt +import sys from typing import ( Any, cast, @@ -1667,3 +1668,35 @@ def test_index_slice_locs() -> None: start, end = idx.slice_locs(0, 1) check(assert_type(start, np.intp | int), np.integer) check(assert_type(end, np.intp | int), int) + + +def test_index_view() -> None: + ind = pd.Index([1, 2]) + check(assert_type(ind.view("int64"), np_1darray), np_1darray) + check(assert_type(ind.view(), "pd.Index[int]"), pd.Index) + if sys.version_info >= (3, 11): + # mypy and pyright differ here in what they report: + # - mypy: ndarray[Any, Any]" + # - pyright: ndarray[tuple[Any, ...], dtype[Any]] + check(assert_type(ind.view(np.ndarray), np.ndarray), np.ndarray) # type: ignore[assert-type] + else: + check(assert_type(ind.view(np.ndarray), np.ndarray), np.ndarray) + + class MyArray(np.ndarray): ... + + check(assert_type(ind.view(MyArray), MyArray), MyArray) + + +def test_index_drop() -> None: + ind = pd.Index([1, 2, 3]) + check(assert_type(ind.drop([1, 2]), "pd.Index[int]"), pd.Index, np.integer) + check( + assert_type(ind.drop(pd.Index([1, 2])), "pd.Index[int]"), pd.Index, np.integer + ) + check( + assert_type(ind.drop(pd.Series([1, 2])), "pd.Index[int]"), pd.Index, np.integer + ) + check( + assert_type(ind.drop(np.array([1, 2])), "pd.Index[int]"), pd.Index, np.integer + ) + check(assert_type(ind.drop(iter([1, 2])), "pd.Index[int]"), pd.Index, np.integer)