Skip to content

Commit

Permalink
Normalize handling of agg input types (#4283)
Browse files Browse the repository at this point in the history
Co-authored-by: Matteo Santamaria <msantama@gmail.com>
  • Loading branch information
matteosantama and Matteo Santamaria committed Aug 8, 2022
1 parent f4e7b08 commit a95915b
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 77 deletions.
20 changes: 18 additions & 2 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
UInt32,
py_type_to_dtype,
)
from polars.utils import _timedelta_to_pl_duration
from polars.utils import _timedelta_to_pl_duration, is_expr_sequence, is_pyexpr_sequence

try:
from polars.polars import PyExpr
Expand All @@ -47,13 +47,29 @@

def selection_to_pyexpr_list(
exprs: str | Expr | Sequence[str | Expr | pli.Series] | pli.Series,
) -> List[PyExpr]:
) -> list[PyExpr]:
if isinstance(exprs, (str, Expr, pli.Series)):
exprs = [exprs]

return [expr_to_lit_or_expr(e, str_to_lit=False)._pyexpr for e in exprs]


def ensure_list_of_pyexpr(exprs: object) -> list[PyExpr]:
    """
    Normalize a single expression or a sequence of expressions to ``PyExpr``s.

    Parameters
    ----------
    exprs
        A ``PyExpr``, a sequence of ``PyExpr``, an ``Expr``, or a sequence of
        ``Expr``.

    Returns
    -------
    A new list of ``PyExpr`` objects; ``Expr`` inputs contribute their
    ``_pyexpr`` attribute.

    Raises
    ------
    TypeError
        If ``exprs`` matches none of the accepted shapes.

    """
    # The raw PyExpr shapes are tested before the Expr shapes; the order is kept
    # as-is so inputs that satisfy more than one check resolve the same way.
    if isinstance(exprs, PyExpr):
        pyexprs = [exprs]
    elif is_pyexpr_sequence(exprs):
        pyexprs = list(exprs)
    elif isinstance(exprs, Expr):
        pyexprs = [exprs._pyexpr]
    elif is_expr_sequence(exprs):
        pyexprs = [wrapped._pyexpr for wrapped in exprs]
    else:
        raise TypeError(f"unexpected type '{type(exprs)}'")

    return pyexprs


def wrap_expr(pyexpr: PyExpr) -> Expr:
    """Wrap a ``PyExpr`` in an ``Expr`` by delegating to ``Expr._from_pyexpr``."""
    return Expr._from_pyexpr(pyexpr)

Expand Down
83 changes: 13 additions & 70 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6293,21 +6293,13 @@ def __init__(
self.closed = closed
self.by = by

def agg(
self,
column_to_agg: (
list[tuple[str, list[str]]]
| dict[str, str | list[str]]
| list[pli.Expr]
| pli.Expr
),
) -> DF:
def agg(self, aggs: pli.Expr | Sequence[pli.Expr]) -> DF:
return (
self.df.lazy()
.groupby_rolling(
self.time_column, self.period, self.offset, self.closed, self.by
)
.agg(column_to_agg) # type: ignore[arg-type]
.agg(aggs)
.collect(no_optimization=True, string_cache=False)
)

Expand Down Expand Up @@ -6340,15 +6332,7 @@ def __init__(
self.closed = closed
self.by = by

def agg(
self,
column_to_agg: (
list[tuple[str, list[str]]]
| dict[str, str | list[str]]
| list[pli.Expr]
| pli.Expr
),
) -> DF:
def agg(self, aggs: pli.Expr | Sequence[pli.Expr]) -> DF:
return (
self.df.lazy()
.groupby_dynamic(
Expand All @@ -6361,7 +6345,7 @@ def agg(
self.closed,
self.by,
)
.agg(column_to_agg) # type: ignore[arg-type]
.agg(aggs)
.collect(no_optimization=True, string_cache=False)
)

Expand Down Expand Up @@ -6652,23 +6636,15 @@ def apply(self, f: Callable[[DataFrame], DataFrame]) -> DF:
"""
return self._dataframe_class._from_pydf(self._df.groupby_apply(self.by, f))

def agg(
self,
column_to_agg: (
list[tuple[str, list[str]]]
| dict[str, str | list[str]]
| list[pli.Expr]
| pli.Expr
),
) -> DF:
def agg(self, aggs: pli.Expr | Sequence[pli.Expr]) -> DF:
"""
Use multiple aggregations on columns. This can be combined with complete lazy
API and is considered idiomatic polars.
Parameters
----------
column_to_agg
map column to aggregation functions.
aggs
Single / multiple aggregation expression(s).
Returns
-------
Expand Down Expand Up @@ -6697,45 +6673,12 @@ def agg(
└─────┴─────────┴──────────────┘
"""
# a single list comprehension would be cleaner, but mypy complains on different
# lines for py3.7 vs py3.10 about typing errors, so this is the same logic,
# but broken down into two small functions
def _str_to_list(y: Any) -> Any:
return [y] if isinstance(y, str) else y

def _wrangle(x: Any) -> list:
return [(xi[0], _str_to_list(xi[1])) for xi in x]

if isinstance(column_to_agg, pli.Expr):
column_to_agg = [column_to_agg]
if isinstance(column_to_agg, dict):
column_to_agg = _wrangle(column_to_agg.items())
elif isinstance(column_to_agg, list):

if isinstance(column_to_agg[0], tuple):
column_to_agg = _wrangle(column_to_agg)

elif isinstance(column_to_agg[0], pli.Expr):
return (
self._dataframe_class._from_pydf(self._df)
.lazy()
.groupby(self.by, maintain_order=self.maintain_order)
.agg(column_to_agg) # type: ignore[arg-type]
.collect(no_optimization=True, string_cache=False)
)
else:
raise ValueError(
f"argument: {column_to_agg} not understood, have you passed a list"
" of expressions?"
)
else:
raise ValueError(
f"argument: {column_to_agg} not understood, have you passed a list of"
" expressions?"
)

return self._dataframe_class._from_pydf(
self._df.groupby_agg(self.by, column_to_agg)
return (
self._dataframe_class._from_pydf(self._df)
.lazy()
.groupby(self.by, maintain_order=self.maintain_order)
.agg(aggs)
.collect(no_optimization=True, string_cache=False)
)

def head(self, n: int = 5) -> DF:
Expand Down
15 changes: 11 additions & 4 deletions py-polars/polars/internals/lazy_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Generic, Sequence, TypeVar, overload

from polars.internals.expr import ensure_list_of_pyexpr

if sys.version_info >= (3, 8):
from typing import Literal
else:
Expand All @@ -38,6 +40,7 @@
_process_null_values,
deprecated_alias,
format_path,
is_expr_sequence,
)

try:
Expand Down Expand Up @@ -2470,14 +2473,14 @@ def __init__(self, lgb: PyLazyGroupBy, lazyframe_class: type[LDF]) -> None:
self.lgb = lgb
self._lazyframe_class = lazyframe_class

def agg(self, aggs: list[pli.Expr] | pli.Expr) -> LDF:
def agg(self, aggs: pli.Expr | Sequence[pli.Expr]) -> LDF:
"""
Describe the aggregation that need to be done on a group.
Parameters
----------
aggs
Single/ Multiple aggregation expression(s).
Single / multiple aggregation expression(s).
Examples
--------
Expand All @@ -2493,8 +2496,12 @@ def agg(self, aggs: list[pli.Expr] | pli.Expr) -> LDF:
... ) # doctest: +SKIP
"""
aggs = pli.selection_to_pyexpr_list(aggs)
return self._lazyframe_class._from_pyldf(self.lgb.agg(aggs))
if not (isinstance(aggs, pli.Expr) or is_expr_sequence(aggs)):
msg = f"expected 'Expr | Sequence[Expr]', got '{type(aggs)}'"
raise TypeError(msg)

pyexprs = ensure_list_of_pyexpr(aggs)
return self._lazyframe_class._from_pyldf(self.lgb.agg(pyexprs))

def head(self, n: int = 5) -> LDF:
"""
Expand Down
18 changes: 18 additions & 0 deletions py-polars/polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from pathlib import Path
from typing import Any, Callable, Iterable, Sequence

import polars.internals as pli
from polars.datatypes import DataType, Date, Datetime

try:
from polars.polars import PyExpr
from polars.polars import pool_size as _pool_size

_DOCUMENTING = False
Expand Down Expand Up @@ -132,6 +134,22 @@ def is_int_sequence(val: Sequence[object]) -> TypeGuard[Sequence[int]]:
return _is_iterable_of(val, Sequence, int)


def is_expr_sequence(val: object) -> TypeGuard[Sequence[pli.Expr]]:
    """
    Report whether ``val`` is a ``Sequence`` whose elements are all ``Expr``.

    Anything that is not a ``Sequence`` yields ``False`` without inspecting
    elements.
    """
    # The explicit isinstance check narrows `val` from `object` before it is
    # handed to the element-wise helper.
    return isinstance(val, Sequence) and _is_iterable_of(val, Sequence, pli.Expr)


def is_pyexpr_sequence(val: object) -> TypeGuard[Sequence[PyExpr]]:
    """
    Report whether ``val`` is a ``Sequence`` whose elements are all ``PyExpr``.

    Anything that is not a ``Sequence`` yields ``False`` without inspecting
    elements.
    """
    # The explicit isinstance check narrows `val` from `object` before it is
    # handed to the element-wise helper.
    return isinstance(val, Sequence) and _is_iterable_of(val, Sequence, PyExpr)


def is_str_sequence(
val: Sequence[object], allow_str: bool = False
) -> TypeGuard[Sequence[str]]:
Expand Down
87 changes: 86 additions & 1 deletion py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,13 @@ def test_groupby() -> None:
}
)

gb_df = df.groupby("a").agg({"b": ["sum", "min"], "c": "count"})
gb_df = df.groupby("a").agg(
[
pl.col("b").sum().alias("b_sum"),
pl.col("b").min().alias("b_min"),
pl.col("c").count(),
]
)
assert "b_sum" in gb_df.columns
assert "b_min" in gb_df.columns

Expand Down Expand Up @@ -424,6 +430,85 @@ def test_groupby() -> None:
df.groupby("a", "b") # type: ignore[arg-type]


# Tuple- and dict-based aggregation specs; the agg input-type tests in this module
# expect each of these to be rejected with a TypeError.
BAD_AGG_PARAMETERS = [[("b", "sum")], [("b", ["sum"])], {"b": "sum"}, {"b": ["sum"]}]
# Inputs the same tests expect `agg` to accept: a list holding one expression, and
# a single bare expression.
GOOD_AGG_PARAMETERS: list[pl.Expr | list[pl.Expr]] = [
    [pl.col("b").sum()],
    pl.col("b").sum(),
]


@pytest.mark.parametrize("lazy", [True, False])
def test_groupby_agg_input_types(lazy: bool) -> None:
    """Check which argument shapes ``groupby(...).agg`` accepts and rejects."""
    eager = pl.DataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})
    frame: pl.DataFrame | pl.LazyFrame = eager.lazy() if lazy else eager

    # Every unsupported shape must surface a TypeError; for lazy frames the
    # collect call stays inside the `raises` block as well.
    for rejected in BAD_AGG_PARAMETERS:
        with pytest.raises(TypeError):
            out = frame.groupby("a").agg(rejected)  # type: ignore[arg-type]
            if lazy:
                out.collect()

    expected = pl.DataFrame({"a": [1, 2], "b": [3, 7]})

    # Both a bare expression and a list of expressions give the same sums.
    for accepted in GOOD_AGG_PARAMETERS:
        out = frame.groupby("a", maintain_order=True).agg(accepted)
        if lazy:
            out = out.collect()
        assert_frame_equal(out, expected)


@pytest.mark.parametrize("lazy", [True, False])
def test_groupby_rolling_agg_input_types(lazy: bool) -> None:
    """Check which argument shapes ``groupby_rolling(...).agg`` accepts and rejects."""
    eager = pl.DataFrame({"index_column": [0, 1, 2, 3], "b": [1, 3, 1, 2]})
    frame: pl.DataFrame | pl.LazyFrame = eager.lazy() if lazy else eager

    # Every unsupported shape must surface a TypeError; for lazy frames the
    # collect call stays inside the `raises` block as well.
    for rejected in BAD_AGG_PARAMETERS:
        with pytest.raises(TypeError):
            out = frame.groupby_rolling(
                index_column="index_column", period="2i"
            ).agg(
                rejected  # type: ignore[arg-type]
            )
            if lazy:
                out.collect()

    expected = pl.DataFrame({"index_column": [0, 1, 2, 3], "b": [1, 4, 4, 3]})

    # Both a bare expression and a list of expressions give the same window sums.
    for accepted in GOOD_AGG_PARAMETERS:
        out = frame.groupby_rolling(
            index_column="index_column", period="2i"
        ).agg(accepted)
        if lazy:
            out = out.collect()
        assert_frame_equal(out, expected)


@pytest.mark.parametrize("lazy", [True, False])
def test_groupby_dynamic_agg_input_types(lazy: bool) -> None:
    """Check which argument shapes ``groupby_dynamic(...).agg`` accepts and rejects."""
    eager = pl.DataFrame({"index_column": [0, 1, 2, 3], "b": [1, 3, 1, 2]})
    frame: pl.DataFrame | pl.LazyFrame = eager.lazy() if lazy else eager

    # Every unsupported shape must surface a TypeError; for lazy frames the
    # collect call stays inside the `raises` block as well.
    for rejected in BAD_AGG_PARAMETERS:
        with pytest.raises(TypeError):
            out = frame.groupby_dynamic(
                index_column="index_column", every="2i"
            ).agg(
                rejected  # type: ignore[arg-type]
            )
            if lazy:
                out.collect()

    expected = pl.DataFrame({"index_column": [0, 0, 2], "b": [1, 4, 2]})

    # Both a bare expression and a list of expressions give the same window sums.
    for accepted in GOOD_AGG_PARAMETERS:
        out = frame.groupby_dynamic(
            index_column="index_column", every="2i"
        ).agg(accepted)
        if lazy:
            out = out.collect()
        assert_frame_equal(out, expected)


@pytest.mark.parametrize(
"stack,exp_shape,exp_columns",
[
Expand Down

0 comments on commit a95915b

Please sign in to comment.