diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index b41353ef4..1c4a13caf 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -61,7 +61,7 @@ from pandas._typing import ( T_COMPLEX, T_INT, AnyAll, - ArrayLike, + AnyArrayLike, AxesData, CategoryDtypeArg, DropKeep, @@ -434,7 +434,7 @@ class Index(IndexOpsMixin[S1]): @property def values(self) -> np_1darray: ... def memory_usage(self, deep: bool = False): ... - def where(self, cond, other: Scalar | ArrayLike | None = None): ... + def where(self, cond, other: Scalar | AnyArrayLike | None = None) -> Self: ... def __contains__(self, key) -> bool: ... @final def __setitem__(self, key, value) -> None: ... diff --git a/tests/indexes/test_indexes.py b/tests/indexes/test_indexes.py index 892cea004..d079c9a8e 100644 --- a/tests/indexes/test_indexes.py +++ b/tests/indexes/test_indexes.py @@ -19,6 +19,7 @@ from pandas.core.arrays.timedeltas import TimedeltaArray from pandas.core.indexes.base import Index from pandas.core.indexes.category import CategoricalIndex +from pandas.core.indexes.datetimes import DatetimeIndex from typing_extensions import ( Never, assert_type, @@ -1608,3 +1609,27 @@ def test_to_series() -> None: np.complexfloating, ) check(assert_type(Index(["1"]).to_series(), "pd.Series[str]"), pd.Series, str) + + +def test_index_where() -> None: + """Test Index.where with multiple types of other GH1419.""" + idx = pd.Index(range(48)) + mask = np.ones(48, dtype=bool) + val_idx = idx.where(mask, idx) + check(assert_type(val_idx, "pd.Index[int]"), pd.Index, int) + + val_sr = idx.where(mask, (idx).to_series()) + check(assert_type(val_sr, "pd.Index[int]"), pd.Index, int) + + +def test_datetimeindex_where() -> None: + """Test DatetimeIndex.where with multiple types of other GH1419.""" + datetime_index = pd.date_range(start="2025-01-01", freq="h", periods=48) + mask = np.ones(48, dtype=bool) + val_idx = datetime_index.where(mask, datetime_index - pd.Timedelta(days=1)) + check(assert_type(val_idx, DatetimeIndex), DatetimeIndex) + + val_sr = datetime_index.where( + mask, (datetime_index - pd.Timedelta(days=1)).to_series() + ) + check(assert_type(val_sr, DatetimeIndex), DatetimeIndex)