Skip to content

Commit

Permalink
feat(rust, python): add reduce/cumreduce expression as an easier fold (
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 28, 2022
1 parent 552b4f5 commit c987304
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 34 deletions.
119 changes: 104 additions & 15 deletions polars/polars-lazy/polars-plan/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,25 @@ where
a.apply_many(function, &[b], output_type)
}

/// Build the output-type resolver used by the cumulative fold/reduce expressions.
///
/// The resolved field is named after the first input field and has dtype
/// `Struct` with one member per input field. Every member keeps its input
/// field's name and is typed as the common supertype of all input dtypes.
///
/// # Panics
///
/// The resolver indexes `fields[0]` and unwraps the result of `get_supertype`,
/// so it panics on an empty field list or when no supertype is produced.
#[cfg(feature = "dtype-struct")]
fn cumfold_dtype() -> GetOutput {
    GetOutput::map_fields(|fields| {
        // Fold every remaining dtype into the running supertype, seeded by the first.
        let supertype = fields[1..]
            .iter()
            .fold(fields[0].dtype.clone(), |st, fld| {
                get_supertype(&st, &fld.dtype).unwrap()
            });

        let members = fields
            .iter()
            .map(|fld| Field::new(fld.name(), supertype.clone()))
            .collect();

        Field::new(&fields[0].name, DataType::Struct(members))
    })
}

/// Accumulate over multiple columns horizontally / row wise.
pub fn fold_exprs<F: 'static, E: AsRef<[Expr]>>(acc: Expr, f: F, exprs: E) -> Expr
where
Expand Down Expand Up @@ -733,6 +752,90 @@ where
}
}

/// Reduce multiple columns horizontally / row wise with a left fold.
///
/// The first input series seeds the accumulator; `f(acc, value)` is then
/// applied to every remaining input series in order. See `fold_exprs` for the
/// variant that takes an explicit accumulator expression.
///
/// # Errors
///
/// When evaluated, the expression returns a `ComputeError` if it receives no
/// input series, and propagates any error returned by `f`.
pub fn reduce_exprs<F: 'static, E: AsRef<[Expr]>>(f: F, exprs: E) -> Expr
where
    F: Fn(Series, Series) -> PolarsResult<Series> + Send + Sync + Clone,
{
    let input = exprs.as_ref().to_vec();

    let function = SpecialEq::new(Arc::new(move |series: &mut [Series]| {
        match series.split_first() {
            // Thread the accumulator through the remaining series, stopping at the first error.
            Some((first, rest)) => rest
                .iter()
                .try_fold(first.clone(), |acc, s| f(acc, s.clone())),
            None => Err(PolarsError::ComputeError(
                "Reduce did not have any expressions to fold".into(),
            )),
        }
    }) as Arc<dyn SeriesUdf>);

    let options = FunctionOptions {
        collect_groups: ApplyOptions::ApplyGroups,
        input_wildcard_expansion: true,
        auto_explode: true,
        fmt_str: "reduce",
        ..Default::default()
    };

    Expr::AnonymousFunction {
        input,
        function,
        output_type: GetOutput::super_type(),
        options,
    }
}

/// Cumulatively reduce multiple columns horizontally / row wise with a left fold.
///
/// The first input series seeds the accumulator; `f(acc, value)` is then
/// applied to every remaining input series in order. The seed and each
/// intermediate accumulator are collected as the fields of a `Struct` series,
/// each field carrying the name of the input series that produced it.
///
/// # Errors
///
/// When evaluated, the expression returns a `ComputeError` if it receives no
/// input series, and propagates any error returned by `f` or by
/// `StructChunked::new`.
#[cfg(feature = "dtype-struct")]
#[cfg_attr(docsrs, doc(cfg(feature = "dtype-struct")))]
pub fn cumreduce_exprs<F: 'static, E: AsRef<[Expr]>>(f: F, exprs: E) -> Expr
where
    F: Fn(Series, Series) -> PolarsResult<Series> + Send + Sync + Clone,
{
    let exprs = exprs.as_ref().to_vec();

    let function = SpecialEq::new(Arc::new(move |series: &mut [Series]| {
        match series.split_first() {
            Some((first, rest)) => {
                let mut acc = first.clone();

                // One struct field per input series: the seed plus one per fold step.
                let mut result = Vec::with_capacity(rest.len() + 1);
                result.push(acc.clone());

                for s in rest {
                    // Capture the name up front so the field is labelled after
                    // the input column regardless of what `f` names its output.
                    let name = s.name().to_string();
                    acc = f(acc, s.clone())?;
                    acc.rename(&name);
                    result.push(acc.clone());
                }

                // NOTE(review): the struct is named after the last accumulator,
                // whereas `cumfold_dtype` names the output field after the first
                // input. Confirm the evaluator reconciles the two.
                StructChunked::new(acc.name(), &result).map(|ca| ca.into_series())
            }
            None => Err(PolarsError::ComputeError(
                "Reduce did not have any expressions to fold".into(),
            )),
        }
    }) as Arc<dyn SeriesUdf>);

    Expr::AnonymousFunction {
        input: exprs,
        function,
        output_type: cumfold_dtype(),
        options: FunctionOptions {
            collect_groups: ApplyOptions::ApplyGroups,
            input_wildcard_expansion: true,
            auto_explode: true,
            fmt_str: "cumreduce",
            ..Default::default()
        },
    }
}

/// Accumulate over multiple columns horizontally / row wise.
#[cfg(feature = "dtype-struct")]
#[cfg_attr(docsrs, doc(cfg(feature = "rank")))]
Expand Down Expand Up @@ -770,21 +873,7 @@ where
Expr::AnonymousFunction {
input: exprs,
function,
output_type: GetOutput::map_fields(|fields| {
let mut st = fields[0].dtype.clone();
for fld in &fields[1..] {
st = get_supertype(&st, &fld.dtype).unwrap();
}
Field::new(
&fields[0].name,
DataType::Struct(
fields
.iter()
.map(|fld| Field::new(fld.name(), st.clone()))
.collect(),
),
)
}),
output_type: cumfold_dtype(),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: true,
Expand Down
2 changes: 2 additions & 0 deletions py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ These functions can be used as expression and sometimes also in eager contexts.
coalesce
cov
cumfold
cumreduce
cumsum
date
datetime
Expand All @@ -45,6 +46,7 @@ These functions can be used as expression and sometimes also in eager contexts.
n_unique
pearson_corr
quantile
reduce
repeat
select
spearman_rank_corr
Expand Down
4 changes: 4 additions & 0 deletions py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def version() -> str:
count,
cov,
cumfold,
cumreduce,
cumsum,
duration,
element,
Expand All @@ -112,6 +113,7 @@ def version() -> str:
n_unique,
pearson_corr,
quantile,
reduce,
repeat,
select,
spearman_rank_corr,
Expand Down Expand Up @@ -254,6 +256,8 @@ def version() -> str:
"apply",
"fold",
"cumfold",
"reduce",
"cumreduce",
"cumsum",
"any",
"all",
Expand Down
66 changes: 66 additions & 0 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from polars.polars import count as _count
from polars.polars import cov as pycov
from polars.polars import cumfold as pycumfold
from polars.polars import cumreduce as pycumreduce
from polars.polars import dtype_cols as _dtype_cols
from polars.polars import first as _first
from polars.polars import fold as pyfold
Expand All @@ -49,6 +50,7 @@
from polars.polars import min_exprs as _min_exprs
from polars.polars import pearson_corr as pypearson_corr
from polars.polars import py_datetime, py_duration
from polars.polars import reduce as pyreduce
from polars.polars import repeat as _repeat
from polars.polars import spearman_rank_corr as pyspearman_rank_corr
from polars.polars import sum_exprs as _sum_exprs
Expand Down Expand Up @@ -1033,6 +1035,11 @@ def fold(
exprs
Expressions to aggregate over. May also be a wildcard expression.
Notes
-----
If you simply want the first encountered expression as accumulator,
consider using ``reduce``.
"""
# in case of pl.col("*")
acc = pli.expr_to_lit_or_expr(acc, str_to_lit=True)
Expand All @@ -1043,6 +1050,34 @@ def fold(
return pli.wrap_expr(pyfold(acc._pyexpr, f, exprs))


def reduce(
    f: Callable[[pli.Series, pli.Series], pli.Series],
    exprs: Sequence[pli.Expr | str] | pli.Expr,
) -> pli.Expr:
    """
    Accumulate over multiple columns horizontally / row wise with a left fold.

    Parameters
    ----------
    f
        Function to apply over the accumulator and the value.
        Fn(acc, value) -> new_value
    exprs
        Expressions to aggregate over. May also be a wildcard expression.

    Notes
    -----
    See ``fold`` for the version with an explicit accumulator.

    """
    # A single expression (e.g. pl.col("*")) is wrapped so that it is
    # processed like any other selection.
    selection = [exprs] if isinstance(exprs, pli.Expr) else exprs
    pyexprs = pli.selection_to_pyexpr_list(selection)
    return pli.wrap_expr(pyreduce(f, pyexprs))


def cumfold(
acc: IntoExpr,
f: Callable[[pli.Series, pli.Series], pli.Series],
Expand All @@ -1067,6 +1102,11 @@ def cumfold(
include_init
Include the initial accumulator state as struct field.
Notes
-----
If you simply want the first encountered expression as accumulator,
consider using ``cumreduce``.
""" # noqa E501
# in case of pl.col("*")
acc = pli.expr_to_lit_or_expr(acc, str_to_lit=True)
Expand All @@ -1077,6 +1117,32 @@ def cumfold(
return pli.wrap_expr(pycumfold(acc._pyexpr, f, exprs, include_init))


def cumreduce(
    f: Callable[[pli.Series, pli.Series], pli.Series],
    exprs: Sequence[pli.Expr | str] | pli.Expr,
) -> pli.Expr:
    """
    Cumulatively accumulate over multiple columns horizontally / row wise.

    The accumulation is a left fold. Every cumulative result is added as a
    separate field in a Struct column.

    Parameters
    ----------
    f
        Function to apply over the accumulator and the value.
        Fn(acc, value) -> new_value
    exprs
        Expressions to aggregate over. May also be a wildcard expression.

    """
    # A single expression (e.g. pl.col("*")) is wrapped so that it is
    # processed like any other selection.
    selection = [exprs] if isinstance(exprs, pli.Expr) else exprs
    pyexprs = pli.selection_to_pyexpr_list(selection)
    return pli.wrap_expr(pycumreduce(f, pyexprs))


def any(name: str | Sequence[str] | Sequence[pli.Expr] | pli.Expr) -> pli.Expr:
"""Evaluate columnwise or elementwise with a bitwise OR operation."""
if isinstance(name, str):
Expand Down
14 changes: 14 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1695,13 +1695,27 @@ pub fn fold(acc: PyExpr, lambda: PyObject, exprs: Vec<PyExpr>) -> PyExpr {
polars::lazy::dsl::fold_exprs(acc.inner, func, exprs).into()
}

/// Build a horizontal `reduce` expression driven by a Python callable.
///
/// Each `(acc, value)` pair is handed to `binary_lambda` together with
/// `lambda`; the converted expressions are passed on to `reduce_exprs`.
pub fn reduce(lambda: PyObject, exprs: Vec<PyExpr>) -> PyExpr {
    let inputs = py_exprs_to_exprs(exprs);
    let step = move |acc: Series, value: Series| binary_lambda(&lambda, acc, value);

    polars::lazy::dsl::reduce_exprs(step, inputs).into()
}

/// Build a cumulative horizontal fold expression driven by a Python callable.
///
/// `acc` supplies the initial accumulator expression. Each `(acc, value)` pair
/// is handed to `binary_lambda` together with `lambda`. `include_init` is
/// forwarded unchanged to `cumfold_exprs`.
pub fn cumfold(acc: PyExpr, lambda: PyObject, exprs: Vec<PyExpr>, include_init: bool) -> PyExpr {
    let inputs = py_exprs_to_exprs(exprs);
    let step = move |state: Series, value: Series| binary_lambda(&lambda, state, value);

    polars::lazy::dsl::cumfold_exprs(acc.inner, step, inputs, include_init).into()
}

/// Build a cumulative horizontal `reduce` expression driven by a Python callable.
///
/// Each `(acc, value)` pair is handed to `binary_lambda` together with
/// `lambda`; the converted expressions are passed on to `cumreduce_exprs`.
pub fn cumreduce(lambda: PyObject, exprs: Vec<PyExpr>) -> PyExpr {
    let inputs = py_exprs_to_exprs(exprs);
    let step = move |acc: Series, value: Series| binary_lambda(&lambda, acc, value);

    polars::lazy::dsl::cumreduce_exprs(step, inputs).into()
}

pub fn lit(value: &PyAny, allow_object: bool) -> PyResult<PyExpr> {
if let Ok(true) = value.is_instance_of::<PyBool>() {
let val = value.extract::<bool>().unwrap();
Expand Down
12 changes: 12 additions & 0 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,21 @@ fn fold(acc: PyExpr, lambda: PyObject, exprs: Vec<PyExpr>) -> PyExpr {
dsl::fold(acc, lambda, exprs)
}

/// Python-exposed `reduce`: forwards the callable and the expressions to `dsl::reduce`.
#[pyfunction]
fn reduce(lambda: PyObject, exprs: Vec<PyExpr>) -> PyExpr {
dsl::reduce(lambda, exprs)
}

/// Python-exposed `cumfold`: forwards the accumulator, callable, expressions and
/// `include_init` flag to `dsl::cumfold`.
#[pyfunction]
fn cumfold(acc: PyExpr, lambda: PyObject, exprs: Vec<PyExpr>, include_init: bool) -> PyExpr {
dsl::cumfold(acc, lambda, exprs, include_init)
}

/// Python-exposed `cumreduce`: forwards the callable and the expressions to `dsl::cumreduce`.
#[pyfunction]
fn cumreduce(lambda: PyObject, exprs: Vec<PyExpr>) -> PyExpr {
dsl::cumreduce(lambda, exprs)
}

#[pyfunction]
fn arange(low: PyExpr, high: PyExpr, step: usize) -> PyExpr {
polars::lazy::dsl::arange(low.inner, high.inner, step).into()
Expand Down Expand Up @@ -577,6 +587,8 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(lit)).unwrap();
m.add_wrapped(wrap_pyfunction!(fold)).unwrap();
m.add_wrapped(wrap_pyfunction!(cumfold)).unwrap();
m.add_wrapped(wrap_pyfunction!(reduce)).unwrap();
m.add_wrapped(wrap_pyfunction!(cumreduce)).unwrap();
m.add_wrapped(wrap_pyfunction!(binary_expr)).unwrap();
m.add_wrapped(wrap_pyfunction!(arange)).unwrap();
m.add_wrapped(wrap_pyfunction!(pearson_corr)).unwrap();
Expand Down
28 changes: 28 additions & 0 deletions py-polars/tests/unit/test_folds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,27 @@
import polars as pl


def test_fold() -> None:
    """Check horizontal sum/max/min, ``fold`` and ``reduce`` on a two-column frame."""
    df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]})

    aggregated = df.select(
        [
            pl.sum(["a", "b"]),
            pl.max(["a", pl.col("b") ** 2]),
            pl.min(["a", pl.col("b") ** 2]),
        ]
    )
    expected_columns = {
        "sum": [2.0, 4.0, 6.0],
        "max": [1.0, 4.0, 9.0],
        "min": [1.0, 2.0, 3.0],
    }
    for column, values in expected_columns.items():
        assert aggregated[column].series_equal(pl.Series(column, values))

    # ``fold`` starts from an explicit literal accumulator ...
    folded = df.select(
        pl.fold(acc=pl.lit(0), f=lambda acc, x: acc + x, exprs=pl.all()).alias("foo")
    )
    assert folded["foo"].to_list() == [2, 4, 6]

    # ... while ``reduce`` is called without one.
    reduced = df.select(
        pl.reduce(f=lambda acc, x: acc + x, exprs=pl.all()).alias("foo")
    )
    assert reduced["foo"].to_list() == [2, 4, 6]


def test_cumfold() -> None:
df = pl.DataFrame(
{
Expand All @@ -17,6 +38,13 @@ def test_cumfold() -> None:
"b": [6, 8, 10, 12],
"c": [16, 28, 40, 52],
}
assert df.select(
[pl.cumreduce(lambda a, b: a + b, pl.all()).alias("folded")]
).unnest("folded").to_dict(False) == {
"a": [1, 2, 3, 4],
"b": [6, 8, 10, 12],
"c": [16, 28, 40, 52],
}


def test_cumsum_fold() -> None:
Expand Down

0 comments on commit c987304

Please sign in to comment.