Skip to content

Commit

Permalink
python: allow horizontal expanding sum (#4242)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 3, 2022
1 parent 20032d1 commit 00b50c2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
6 changes: 4 additions & 2 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def min(column: str | list[pli.Expr | str] | pli.Series) -> pli.Expr | Any:


@overload
def sum(column: str | list[pli.Expr | str]) -> pli.Expr:
def sum(column: str | list[pli.Expr | str] | pli.Expr) -> pli.Expr:
...


Expand All @@ -384,7 +384,7 @@ def sum(column: pli.Series) -> int | float:
...


def sum(column: str | list[pli.Expr | str] | pli.Series) -> pli.Expr | Any:
def sum(column: str | list[pli.Expr | str] | pli.Series | pli.Expr) -> pli.Expr | Any:
"""
Get the sum value.
Expand All @@ -402,6 +402,8 @@ def sum(column: str | list[pli.Expr | str] | pli.Series) -> pli.Expr | Any:
if isinstance(first, str):
first = col(first)
return fold(first, lambda a, b: a + b, column[1:]).alias("sum")
elif isinstance(column, pli.Expr):
return fold(lit(0), lambda a, b: a + b, column).alias("sum")
else:
return col(column).sum()

Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/test_expr_multi_cols.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,16 @@ def test_fold_regex_expand() -> None:
"y_2": [1.0, 2.5, 3.5],
"y_sum": [2.1, 4.7, 6.8],
}


def test_expanding_sum() -> None:
df = pl.DataFrame(
{
"x": [0, 1, 2],
"y_1": [1.1, 2.2, 3.3],
"y_2": [1.0, 2.5, 3.5],
}
)
assert df.with_column(pl.sum(pl.col(r"^y_.*$")).alias("y_sum"))[
"y_sum"
].to_list() == [2.1, 4.7, 6.8]

0 comments on commit 00b50c2

Please sign in to comment.