Skip to content

Commit

Permalink
Allow sequence in groupby level (pandas-dev#837)
Browse files Browse the repository at this point in the history
* Add test cases for level sequences

These tests currently fail, as the level parameter
does not yet accept sequences.

* Allow sequences for groupby level parameter

This fixes pandas-dev#836

* Add assert_type

* Remove unnecessary quotes
  • Loading branch information
jens-diewald authored and twoertwein committed Dec 24, 2023
1 parent 29afb64 commit ae102bf
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 16 deletions.
16 changes: 8 additions & 8 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ class DataFrame(NDFrame, OpsMixin):
self,
by: Scalar,
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -1022,7 +1022,7 @@ class DataFrame(NDFrame, OpsMixin):
self,
by: DatetimeIndex,
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -1035,7 +1035,7 @@ class DataFrame(NDFrame, OpsMixin):
self,
by: TimedeltaIndex,
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -1048,7 +1048,7 @@ class DataFrame(NDFrame, OpsMixin):
self,
by: PeriodIndex,
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -1061,7 +1061,7 @@ class DataFrame(NDFrame, OpsMixin):
self,
by: IntervalIndex[IntervalT],
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -1074,7 +1074,7 @@ class DataFrame(NDFrame, OpsMixin):
self,
by: MultiIndex | GroupByObjectNonScalar | None = ...,
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -1087,7 +1087,7 @@ class DataFrame(NDFrame, OpsMixin):
self,
by: Series[SeriesByT],
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -1100,7 +1100,7 @@ class DataFrame(NDFrame, OpsMixin):
self,
by: CategoricalIndex | Index | Series,
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand Down
17 changes: 9 additions & 8 deletions pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ from pandas._typing import (
HashableT3,
IgnoreRaise,
IndexingInt,
IndexLabel,
IntDtypeArg,
InterpolateOptions,
IntervalClosedType,
Expand Down Expand Up @@ -547,7 +548,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
self,
by: Scalar,
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -560,7 +561,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
self,
by: DatetimeIndex,
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -573,7 +574,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
self,
by: TimedeltaIndex,
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -586,7 +587,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
self,
by: PeriodIndex,
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -599,7 +600,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
self,
by: IntervalIndex[IntervalT],
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -612,7 +613,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
self,
by: MultiIndex | GroupByObjectNonScalar = ...,
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -625,7 +626,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
self,
by: Series[SeriesByT],
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand All @@ -638,7 +639,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
self,
by: CategoricalIndex | Index | Series,
axis: AxisIndex = ...,
level: Level | None = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
sort: _bool = ...,
group_keys: _bool = ...,
Expand Down
15 changes: 15 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,21 @@ def test_types_groupby_iter() -> None:
)


def test_types_groupby_level() -> None:
    """Group a DataFrame by a list of index level names (GH 836).

    Builds a frame indexed by three columns, groups on two of those
    levels by name, sums, and hands the result to ``check`` together
    with the statically asserted type ``pd.DataFrame``.
    """
    # GH 836
    frame = pd.DataFrame(
        data={
            "col1": [0, 0, 0],
            "col2": [0, 1, 0],
            "col3": [1, 2, 3],
            "col4": [1, 2, 3],
        }
    )
    indexed = frame.set_index(["col1", "col2", "col3"])
    # ``level`` receives a sequence of names rather than a single label.
    summed = indexed.groupby(level=["col1", "col2"]).sum()
    check(assert_type(summed, pd.DataFrame), pd.DataFrame)


def test_types_merge() -> None:
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5]})
df2 = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [0, 1, 0]})
Expand Down
13 changes: 13 additions & 0 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,19 @@ def test_types_max() -> None:
s.max(skipna=False)


def test_types_groupby_level() -> None:
    """Group a Series by a list of index level names (GH 836).

    Builds a Series over a three-level MultiIndex, groups on two of
    those levels by name, sums, and hands the result to ``check``
    together with the statically asserted type ``pd.Series[int]``.
    """
    # GH 836
    level_names = ["col1", "col2", "col3"]
    keys = [(0, 0, 1), (0, 1, 2), (0, 0, 3)]
    multi = pd.MultiIndex.from_tuples(keys, names=level_names)
    ser = pd.Series([1, 2, 3], index=multi)
    # ``level`` receives a sequence of names rather than a single label.
    grouped_sum = ser.groupby(level=["col1", "col2"]).sum()
    check(assert_type(grouped_sum, "pd.Series[int]"), pd.Series, np.integer)


def test_types_quantile() -> None:
s = pd.Series([1, 2, 3, 10])
s.quantile([0.25, 0.5])
Expand Down

0 comments on commit ae102bf

Please sign in to comment.