docs(python): Clarify documentation for the agg_list argument in Expr.map_batches #13625

Merged
merged 2 commits into from
Jan 12, 2024
45 changes: 44 additions & 1 deletion py-polars/polars/expr/expr.py
@@ -3991,7 +3991,10 @@ def map_batches(
If set to true this can run in the streaming engine, but may yield
incorrect results in group-by. Ensure you know what you are doing!
agg_list
Aggregate list.
Aggregate the values of the expression into a list before applying the
function. This parameter only works in a group-by context.
The function will be invoked only once on a list of groups, rather than
once per group.

Warnings
--------
@@ -4020,6 +4023,46 @@ def map_batches(
╞══════╪════════╡
│ 1 ┆ 0 │
└──────┴────────┘

In a group-by context, the `agg_list` parameter can improve performance if used
correctly. The following example has `agg_list` set to `False`, which causes
the function to be applied once per group. The input of the function is a
Series of type `Int64`. This is less efficient because the Python function is
called separately for every group.

>>> df = pl.DataFrame(
... {
... "a": [0, 1, 0, 1],
... "b": [1, 2, 3, 4],
... }
... )
>>> df.group_by("a").agg(
... pl.col("b").map_batches(lambda x: x.max(), agg_list=False)
... ) # doctest: +IGNORE_RESULT
shape: (2, 2)
┌─────┬───────────┐
│ a ┆ b │
│ --- ┆ --- │
│ i64 ┆ list[i64] │
╞═════╪═══════════╡
│ 1 ┆ [4] │
│ 0 ┆ [3] │
└─────┴───────────┘

Using `agg_list=True` would be more efficient. In this example, the input of
the function is a Series of type `List(Int64)`.

>>> df.group_by("a").agg(
... pl.col("b").map_batches(lambda x: x.list.max(), agg_list=True)
... ) # doctest: +IGNORE_RESULT
shape: (2, 2)
┌─────┬─────┐
│ a ┆ b │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 0 ┆ 3 │
│ 1 ┆ 4 │
└─────┴─────┘
"""
if return_dtype is not None:
return_dtype = py_type_to_dtype(return_dtype)
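The difference between the two calling patterns described in the docstring can be illustrated with a small plain-Python sketch. It does not use Polars and says nothing about Polars internals; the helper names `group_values`, `apply_per_group`, and `apply_on_list_of_groups` are hypothetical and exist only to count how often the user function is called in each mode.

```python
from collections import defaultdict


def group_values(keys, values):
    """Collect values per key, keeping keys in first-seen order."""
    groups = defaultdict(list)
    for key, value in zip(keys, values):
        groups[key].append(value)
    return dict(groups)


def apply_per_group(groups, func):
    """Mimics the agg_list=False pattern: func is called once per group."""
    calls = 0
    out = {}
    for key, vals in groups.items():
        out[key] = func(vals)
        calls += 1
    return out, calls


def apply_on_list_of_groups(groups, func):
    """Mimics the agg_list=True pattern: func is called once on all groups."""
    keys = list(groups)
    results = func([groups[key] for key in keys])  # single call
    return dict(zip(keys, results)), 1


# Same data as the docstring example: column "a" as keys, column "b" as values.
groups = group_values([0, 1, 0, 1], [1, 2, 3, 4])

per_group, n_calls_false = apply_per_group(groups, max)
batched, n_calls_true = apply_on_list_of_groups(
    groups, lambda lists: [max(vals) for vals in lists]
)

print(per_group, n_calls_false)
print(batched, n_calls_true)
```

Both modes produce the same per-group maxima here. The per-group mode calls the function once for each of the two groups, while the list mode calls it a single time, which is where the docstring's performance claim comes from when there are many groups.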