diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index ae760f9e7..7e567428e 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -36,7 +36,10 @@ from pandas._libs.tslibs import ( Timestamp, ) -from pandas.core.dtypes.dtypes import ExtensionDtype +from pandas.core.dtypes.dtypes import ( + CategoricalDtype, + ExtensionDtype, +) from pandas.io.formats.format import EngFormatter @@ -210,6 +213,7 @@ S1 = TypeVar( Interval[float], Interval[Timestamp], Interval[Timedelta], + CategoricalDtype, ) T1 = TypeVar( "T1", str, int, np.int64, np.uint64, np.float64, float, np.dtype[np.generic] diff --git a/pandas-stubs/core/reshape/tile.pyi b/pandas-stubs/core/reshape/tile.pyi index c630d6c37..a42be6006 100644 --- a/pandas-stubs/core/reshape/tile.pyi +++ b/pandas-stubs/core/reshape/tile.pyi @@ -7,12 +7,17 @@ from typing import ( import numpy as np from pandas import ( Categorical, + CategoricalDtype, + DatetimeIndex, Float64Index, Index, Int64Index, + Interval, IntervalIndex, Series, + Timestamp, ) +from pandas.core.series import TimestampSeries from pandas._typing import ( Label, @@ -46,6 +51,36 @@ def cut( ordered: bool = ..., ) -> tuple[npt.NDArray[np.intp], IntervalIndex]: ... @overload +def cut( # type: ignore[misc] + x: TimestampSeries, + bins: int + | TimestampSeries + | DatetimeIndex + | Sequence[Timestamp] + | Sequence[np.datetime64], + right: bool = ..., + labels: Literal[False] | Sequence[Label] | None = ..., + *, + retbins: Literal[True], + precision: int = ..., + include_lowest: bool = ..., + duplicates: Literal["raise", "drop"] = ..., + ordered: bool = ..., +) -> tuple[Series, DatetimeIndex]: ... +@overload +def cut( + x: TimestampSeries, + bins: IntervalIndex[Interval[Timestamp]], + right: bool = ..., + labels: Sequence[Label] | None = ..., + *, + retbins: Literal[True], + precision: int = ..., + include_lowest: bool = ..., + duplicates: Literal["raise", "drop"] = ..., + ordered: bool = ..., +) -> tuple[Series, IntervalIndex[Interval[Timestamp]]]: ... 
+@overload def cut( x: Series, bins: int | Series | Int64Index | Float64Index | Sequence[int] | Sequence[float], @@ -61,7 +96,7 @@ def cut( @overload def cut( x: Series, - bins: IntervalIndex, + bins: IntervalIndex[Interval[int]] | IntervalIndex[Interval[float]], right: bool = ..., labels: Sequence[Label] | None = ..., *, @@ -117,6 +152,23 @@ def cut( ordered: bool = ..., ) -> npt.NDArray[np.intp]: ... @overload +def cut( + x: TimestampSeries, + bins: int + | TimestampSeries + | DatetimeIndex + | Sequence[Timestamp] + | Sequence[np.datetime64] + | IntervalIndex[Interval[Timestamp]], + right: bool = ..., + labels: Literal[False] | Sequence[Label] | None = ..., + retbins: Literal[False] = ..., + precision: int = ..., + include_lowest: bool = ..., + duplicates: Literal["raise", "drop"] = ..., + ordered: bool = ..., +) -> Series[CategoricalDtype]: ... +@overload def cut( x: Series, bins: int diff --git a/tests/test_pandas.py b/tests/test_pandas.py index f941f76f4..b14465fcf 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -870,6 +870,32 @@ def test_cut() -> None: check(assert_type(n0, pd.Categorical), pd.Categorical) check(assert_type(n1, pd.IntervalIndex), pd.IntervalIndex) + s1 = pd.Series(data=pd.date_range("1/1/2020", periods=300)) + check( + assert_type( + pd.cut(s1, bins=[np.datetime64("2020-01-03"), np.datetime64("2020-09-01")]), + "pd.Series[pd.CategoricalDtype]", + ), + pd.Series, + ) + check( + assert_type( + pd.cut(s1, bins=10), + "pd.Series[pd.CategoricalDtype]", + ), + pd.Series, + pd.Interval, + ) + s0r, s1r = pd.cut(s1, bins=10, retbins=True) + check(assert_type(s0r, pd.Series), pd.Series, pd.Interval) + check(assert_type(s1r, pd.DatetimeIndex), pd.DatetimeIndex, pd.Timestamp) + s0rlf, s1rlf = pd.cut(s1, bins=10, labels=False, retbins=True) + check(assert_type(s0rlf, pd.Series), pd.Series, int) + check(assert_type(s1rlf, pd.DatetimeIndex), pd.DatetimeIndex, pd.Timestamp) + s0rls, s1rls = pd.cut(s1, bins=4, labels=["1", "2", "3", "4"], 
retbins=True) + check(assert_type(s0rls, pd.Series), pd.Series, str) + check(assert_type(s1rls, pd.DatetimeIndex), pd.DatetimeIndex, pd.Timestamp) + def test_qcut() -> None: val_list = [random.random() for _ in range(20)]