feat(python): Allow renaming expressions with keyword syntax in `group_by` (#14071)

Co-authored-by: Stijn de Gooijer <stijndegooijer@gmail.com>
deanm0000 and stinodego committed Jan 29, 2024
1 parent 5da14a0 commit b1f315a
Showing 8 changed files with 77 additions and 53 deletions.
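
In short: `group_by` now accepts keyword arguments whose values are grouping expressions, and each resulting group-key column is named after its keyword. A minimal sketch of the new syntax, adapted from the `test_group_by_named` test added in this commit (the rendered output is assumed from the semantics, not copied from this page):

>>> import polars as pl
>>> df = pl.DataFrame({"a": [1, 1, 2, 2, 3, 3], "b": range(6)})
>>> df.group_by(z=pl.col("a") * 2, maintain_order=True).agg(pl.col("b").min())
shape: (3, 2)
┌─────┬─────┐
│ z   ┆ b   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 2   ┆ 0   │
│ 4   ┆ 2   │
│ 6   ┆ 4   │
└─────┴─────┘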
16 changes: 9 additions & 7 deletions py-polars/polars/dataframe/frame.py
@@ -5362,22 +5362,21 @@ def with_row_count(self, name: str = "row_nr", offset: int = 0) -> Self:
"""
return self.with_row_index(name, offset)

+ @deprecate_parameter_as_positional("by", version="0.20.7")
def group_by(
self,
- by: IntoExpr | Iterable[IntoExpr],
- *more_by: IntoExpr,
+ *by: IntoExpr | Iterable[IntoExpr],
maintain_order: bool = False,
+ **named_by: IntoExpr,
) -> GroupBy:
"""
Start a group by operation.
Parameters
----------
- by
+ *by
Column(s) to group by. Accepts expression input. Strings are parsed as
column names.
- *more_by
-     Additional columns to group by, specified as positional arguments.
maintain_order
Ensure that the order of the groups is consistent with the input data.
This is slower than a default group by.
@@ -5387,6 +5386,9 @@ def group_by(
.. note::
Within each group, the order of rows is always preserved, regardless
of this argument.
+ **named_by
+     Additional columns to group by, specified as keyword arguments.
+     The columns will be renamed to the keyword used.
Returns
-------
@@ -5498,7 +5500,7 @@ def group_by(
│ c ┆ 3 ┆ 1 │
└─────┴─────┴─────┘
"""
- return GroupBy(self, by, *more_by, maintain_order=maintain_order)
+ return GroupBy(self, *by, **named_by, maintain_order=maintain_order)

def rolling(
self,
@@ -9191,7 +9193,7 @@ def n_unique(self, subset: str | Expr | Sequence[str | Expr] | None = None) -> i
In aggregate context there is also an equivalent method for returning the
unique values per-group:
- >>> df_agg_nunique = df.group_by(by=["a"]).n_unique()
+ >>> df_agg_nunique = df.group_by(["a"]).n_unique()
Examples
--------
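As the `**named_by` docs above state, the keyword form is shorthand for grouping by an aliased expression, so these two calls should be equivalent (illustrative sketch, not output from this page):

>>> df.group_by(z=pl.col("a") * 2)  # keyword syntax
>>> df.group_by((pl.col("a") * 2).alias("z"))  # explicit alias, same result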
46 changes: 21 additions & 25 deletions py-polars/polars/dataframe/group_by.py
@@ -35,9 +35,9 @@ class GroupBy:
def __init__(
self,
df: DataFrame,
- by: IntoExpr | Iterable[IntoExpr],
- *more_by: IntoExpr,
+ *by: IntoExpr | Iterable[IntoExpr],
maintain_order: bool,
+ **named_by: IntoExpr,
):
"""
Utility class for performing a group by operation over the given DataFrame.
@@ -48,18 +48,19 @@ def __init__(
----------
df
DataFrame to perform the group by operation over.
- by
+ *by
Column or columns to group by. Accepts expression input. Strings are parsed
as column names.
- *more_by
-     Additional columns to group by, specified as positional arguments.
maintain_order
Ensure that the order of the groups is consistent with the input data.
This is slower than a default group by.
+ **named_by
+     Additional column(s) to group by, specified as keyword arguments.
+     The columns will be named as the keyword used.
"""
self.df = df
self.by = by
- self.more_by = more_by
+ self.named_by = named_by
self.maintain_order = maintain_order

def __iter__(self) -> Self:
@@ -99,19 +100,21 @@ def __iter__(self) -> Self:
temp_col = "__POLARS_GB_GROUP_INDICES"
groups_df = (
self.df.lazy()
- .group_by(self.by, *self.more_by, maintain_order=self.maintain_order)
+ .group_by(*self.by, **self.named_by, maintain_order=self.maintain_order)
.agg(F.first().agg_groups().alias(temp_col))
.collect(no_optimization=True)
)

group_names = groups_df.select(F.all().exclude(temp_col))

self._group_names: Iterator[object] | Iterator[tuple[object, ...]]
- key_as_single_value = isinstance(self.by, str) and not self.more_by
+ key_as_single_value = (
+     len(self.by) == 1 and isinstance(self.by[0], str) and not self.named_by
+ )
if key_as_single_value:
issue_deprecation_warning(
"`group_by` iteration will change to always return group identifiers as tuples."
f" Pass `by` as a list to silence this warning, e.g. `group_by([{self.by!r}])`.",
f" Pass `by` as a list to silence this warning, e.g. `group_by([{self.by[0]!r}])`.",
version="0.20.4",
)
self._group_names = iter(group_names.to_series())
@@ -242,7 +245,7 @@ def agg(
"""
return (
self.df.lazy()
- .group_by(self.by, *self.more_by, maintain_order=self.maintain_order)
+ .group_by(*self.by, **self.named_by, maintain_order=self.maintain_order)
.agg(*aggs, **named_aggs)
.collect(no_optimization=True)
)
@@ -308,24 +311,17 @@ def map_groups(self, function: Callable[[DataFrame], DataFrame]) -> DataFrame:
... pl.int_range(pl.len()).shuffle().over("color") < 2
... ) # doctest: +IGNORE_RESULT
"""
- by: list[str]
-
- if isinstance(self.by, str):
-     by = [self.by]
- elif isinstance(self.by, Iterable) and all(isinstance(c, str) for c in self.by):
-     by = list(self.by)  # type: ignore[arg-type]
- else:
-     msg = "cannot call `map_groups` when grouping by an expression"
+ if self.named_by:
+     msg = "cannot call `map_groups` when grouping by named expressions"
raise TypeError(msg)

- if all(isinstance(c, str) for c in self.more_by):
-     by.extend(self.more_by)  # type: ignore[arg-type]
- else:
+ if not all(isinstance(c, str) for c in self.by):
msg = "cannot call `map_groups` when grouping by an expression"
raise TypeError(msg)

return self.df.__class__._from_pydf(
- self.df._df.group_by_map_groups(by, function, self.maintain_order)
+ self.df._df.group_by_map_groups(
+     list(self.by), function, self.maintain_order
+ )
)

def head(self, n: int = 5) -> DataFrame:
@@ -375,7 +371,7 @@ def head(self, n: int = 5) -> DataFrame:
"""
return (
self.df.lazy()
- .group_by(self.by, *self.more_by, maintain_order=self.maintain_order)
+ .group_by(*self.by, **self.named_by, maintain_order=self.maintain_order)
.head(n)
.collect(no_optimization=True)
)
@@ -427,7 +423,7 @@ def tail(self, n: int = 5) -> DataFrame:
"""
return (
self.df.lazy()
- .group_by(self.by, *self.more_by, maintain_order=self.maintain_order)
+ .group_by(*self.by, **self.named_by, maintain_order=self.maintain_order)
.tail(n)
.collect(no_optimization=True)
)
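Note the behavioral edge in `map_groups` above: grouping keys must still be plain column names, and named expressions now raise a dedicated error. A hedged sketch of both paths (the error message is taken from the code above; the success case ignores output since group order is unspecified):

>>> df = pl.DataFrame({"color": ["red", "blue", "red"], "n": [1, 2, 3]})
>>> df.group_by("color").map_groups(lambda g: g.head(1))  # doctest: +IGNORE_RESULT
>>> df.group_by(z=pl.col("n") * 2).map_groups(lambda g: g)
Traceback (most recent call last):
...
TypeError: cannot call `map_groups` when grouping by named expressions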
14 changes: 8 additions & 6 deletions py-polars/polars/lazyframe/frame.py
@@ -3126,27 +3126,29 @@ def select_seq(
)
return self._from_pyldf(self._ldf.select_seq(pyexprs))

+ @deprecate_parameter_as_positional("by", version="0.20.7")
def group_by(
self,
- by: IntoExpr | Iterable[IntoExpr],
- *more_by: IntoExpr,
+ *by: IntoExpr | Iterable[IntoExpr],
maintain_order: bool = False,
+ **named_by: IntoExpr,
) -> LazyGroupBy:
"""
Start a group by operation.
Parameters
----------
- by
+ *by
Column(s) to group by. Accepts expression input. Strings are parsed as
column names.
- *more_by
-     Additional columns to group by, specified as positional arguments.
maintain_order
Ensure that the order of the groups is consistent with the input data.
This is slower than a default group by.
Setting this to `True` blocks the possibility
to run on the streaming engine.
+ **named_by
+     Additional columns to group by, specified as keyword arguments.
+     The columns will be renamed to the keyword used.
Examples
--------
@@ -3219,7 +3221,7 @@ def group_by(
│ c ┆ 1 ┆ 1.0 │
└─────┴─────┴─────┘
"""
- exprs = parse_as_list_of_expressions(by, *more_by)
+ exprs = parse_as_list_of_expressions(*by, **named_by)
lgb = self._ldf.group_by(exprs, maintain_order)
return LazyGroupBy(lgb)

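On the lazy side, `parse_as_list_of_expressions(*by, **named_by)` folds the keyword arguments into ordinary aliased expressions before the group-by plan is built, so the eager and lazy APIs stay symmetric. A small sketch (output assumed from the semantics; `len` producing a `u32` column matches the `test_schema` change below):

>>> lf = pl.LazyFrame({"a": [1, 2, 3]})
>>> lf.group_by(key=pl.col("a") % 2, maintain_order=True).len().collect()
shape: (2, 2)
┌─────┬─────┐
│ key ┆ len │
│ --- ┆ --- │
│ i64 ┆ u32 │
╞═════╪═════╡
│ 1   ┆ 2   │
│ 0   ┆ 1   │
└─────┴─────┘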
25 changes: 14 additions & 11 deletions py-polars/polars/utils/deprecation.py
@@ -92,17 +92,20 @@ def myfunc(new_name):
def decorate(function: Callable[P, T]) -> Callable[P, T]:
@wraps(function)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
- if param_args := kwargs.pop(old_name, []):
-     issue_deprecation_warning(
-         f"named `{old_name}` param is deprecated; use positional `*args` instead.",
-         version=version,
-     )
-     if param_args:
-         if not isinstance(param_args, Sequence) or isinstance(param_args, str):
-             param_args = (param_args,)
-         elif not isinstance(param_args, tuple):
-             param_args = tuple(param_args)
-         args = args + param_args  # type: ignore[assignment]
+ try:
+     param_args = kwargs.pop(old_name)
+ except KeyError:
+     return function(*args, **kwargs)
+
+ issue_deprecation_warning(
+     f"named `{old_name}` param is deprecated; use positional `*args` instead.",
+     version=version,
+ )
+ if not isinstance(param_args, Sequence) or isinstance(param_args, str):
+     param_args = (param_args,)
+ elif not isinstance(param_args, tuple):
+     param_args = tuple(param_args)
+ args = args + param_args  # type: ignore[assignment]
return function(*args, **kwargs)

wrapper.__signature__ = inspect.signature(function) # type: ignore[attr-defined]
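The reworked decorator pops the deprecated keyword, warns, normalizes the value to a tuple, and appends it to `*args`; the early `return` short-circuits when the keyword is absent so no warning fires. A self-contained sketch of the call-site behavior (`f` is a hypothetical toy function, not from the commit; the warning is emitted to stderr):

>>> from polars.utils.deprecation import deprecate_parameter_as_positional
>>> @deprecate_parameter_as_positional("by", version="0.20.7")
... def f(*by):
...     return by
>>> f(by="a")  # single string is wrapped in a 1-tuple; DeprecationWarning issued
('a',)
>>> f(by=["a", "b"])  # a list is normalized to a tuple and passed positionally
('a', 'b')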
4 changes: 2 additions & 2 deletions py-polars/tests/unit/datatypes/test_categorical.py
@@ -574,9 +574,9 @@ def test_nested_categorical_aggregation_7848() -> None:
"letter": ["a", "b", "c", "d", "e", "f", "g"],
}
).with_columns([pl.col("letter").cast(pl.Categorical)]).group_by(
- maintain_order=True, by=["group"]
+ "group", maintain_order=True
).all().with_columns(pl.col("letter").list.len().alias("c_group")).group_by(
by=["c_group"], maintain_order=True
["c_group"], maintain_order=True
).agg(pl.col("letter")).to_dict(as_series=False) == {
"c_group": [2, 3],
"letter": [[["a", "b"], ["f", "g"]], [["c", "d", "e"]]],
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/test_aggregations.py
@@ -387,7 +387,7 @@ def test_agg_filter_over_empty_df_13610() -> None:

out = (
ldf.drop_nulls()
- .group_by(by=["a"], maintain_order=True)
+ .group_by(["a"], maintain_order=True)
.agg(pl.col("b").filter(pl.col("b").shift(1)))
.collect()
)
21 changes: 21 additions & 0 deletions py-polars/tests/unit/operations/test_group_by.py
@@ -917,3 +917,24 @@ def test_group_by_all_12869() -> None:
df = pl.DataFrame({"a": [1]})
result = next(iter(df.group_by(pl.all())))[1]
assert_frame_equal(df, result)
+
+
+ def test_group_by_named() -> None:
+     df = pl.DataFrame({"a": [1, 1, 2, 2, 3, 3], "b": range(6)})
+     result = df.group_by(z=pl.col("a") * 2, maintain_order=True).agg(pl.col("b").min())
+     expected = df.group_by((pl.col("a") * 2).alias("z"), maintain_order=True).agg(
+         pl.col("b").min()
+     )
+     assert_frame_equal(result, expected)
+
+
+ def test_group_by_deprecated_by_arg() -> None:
+     df = pl.DataFrame({"a": [1, 1, 2, 2, 3, 3], "b": range(6)})
+     with pytest.deprecated_call():
+         result = df.group_by(by=(pl.col("a") * 2), maintain_order=True).agg(
+             pl.col("b").min()
+         )
+     expected = df.group_by((pl.col("a") * 2), maintain_order=True).agg(
+         pl.col("b").min()
+     )
+     assert_frame_equal(result, expected)
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_schema.py
@@ -629,6 +629,6 @@ def test_literal_subtract_schema_13284() -> None:
assert (
pl.LazyFrame({"a": [23, 30]}, schema={"a": pl.UInt8})
.with_columns(pl.col("a") - pl.lit(1))
- .group_by(by="a")
+ .group_by("a")
.len()
).schema == OrderedDict([("a", pl.UInt8), ("len", pl.UInt32)])
