diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 3f74a9b5d..8709a96c5 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -908,6 +908,7 @@ SeriesDType: TypeAlias = ( | datetime.datetime # includes pd.Timestamp | datetime.timedelta # includes pd.Timedelta ) + S1 = TypeVar("S1", bound=SeriesDType, default=Any) # Like S1, but without `default=Any`. S2 = TypeVar("S2", bound=SeriesDType) diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index d6b245983..3ff4db49a 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -405,7 +405,7 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]): __bool__ = ... def union( self, other: list[HashableT] | Self, sort: bool | None = None - ) -> Index: ... + ) -> Self: ... def intersection( self, other: list[S1] | Self, sort: bool | None = False ) -> Self: ... diff --git a/pandas-stubs/core/indexes/multi.pyi b/pandas-stubs/core/indexes/multi.pyi index e66c845ec..7dc76e7f7 100644 --- a/pandas-stubs/core/indexes/multi.pyi +++ b/pandas-stubs/core/indexes/multi.pyi @@ -135,7 +135,7 @@ class MultiIndex(Index): def append(self, other): ... def repeat(self, repeats, axis=...): ... def drop(self, codes, level: Level | None = None, errors: str = "raise") -> Self: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] - def swaplevel(self, i: int = -2, j: int = -1): ... + def swaplevel(self, i: int = -2, j: int = -1) -> Self: ... def reorder_levels(self, order): ... def sortlevel( self, diff --git a/pandas-stubs/py.typed b/pandas-stubs/py.typed deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/indexes/test_indexes.py b/tests/indexes/test_indexes.py index 3b2960b3c..d477572ee 100644 --- a/tests/indexes/test_indexes.py +++ b/tests/indexes/test_indexes.py @@ -1601,3 +1601,19 @@ def test_to_series() -> None: np.complexfloating, ) check(assert_type(Index(["1"]).to_series(), "pd.Series[str]"), pd.Series, str) + + +def test_multiindex_union() -> None: + """Test that MultiIndex.union returns MultiIndex""" + mi = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["let", "num"]) + mi2 = pd.MultiIndex.from_product([["a", "b"], [3, 4]], names=["let", "num"]) + + check(assert_type(mi.union(mi2), "pd.MultiIndex"), pd.MultiIndex) + check(assert_type(mi.union([("c", 3), ("d", 4)]), "pd.MultiIndex"), pd.MultiIndex) + check(assert_type(mi.union([1, 2, 3]), "pd.MultiIndex"), pd.MultiIndex) + + +def test_multiindex_swaplevel() -> None: + """Test that MultiIndex.swaplevel returns MultiIndex""" + mi = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["let", "num"]) + check(assert_type(mi.swaplevel(0, 1), "pd.MultiIndex"), pd.MultiIndex)