Skip to content

Commit

Permalink
allow wildcard in fold exprs
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 21, 2021
1 parent 35edfde commit 5abc08b
Show file tree
Hide file tree
Showing 13 changed files with 227 additions and 63 deletions.
81 changes: 70 additions & 11 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Domain specific language for the Lazy api.
use crate::logical_plan::Context;
use crate::prelude::*;
use crate::utils::{has_expr, output_name};
use crate::utils::{has_expr, has_wildcard, output_name};
use polars_core::prelude::*;

#[cfg(feature = "temporal")]
Expand All @@ -15,6 +15,7 @@ use std::{
};
// reexport the lazy method
pub use crate::frame::IntoLazy;
use polars_core::utils::get_supertype;

/// A wrapper trait for any closure `Fn(Vec<Series>) -> Result<Series>`
pub trait SeriesUdf: Send + Sync {
Expand Down Expand Up @@ -111,6 +112,26 @@ where
}
}

/// Options attached to an `Expr::Function`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FunctionOptions {
    /// When `true`, the groups are collected into a list before the function is applied.
    /// This can be important when the function runs in an aggregation context.
    pub(crate) collect_groups: bool,
    /// Selects how a wildcard among the function inputs is expanded.
    ///
    /// Given a schema with the columns 'a' and 'b' and a function `f`, the
    /// expression `f('*')` can be expanded in two ways:
    ///
    /// 1. `f('a', 'b')`: a single call that receives all columns as inputs
    /// 2. `f('a'), f('b')`: one call per column
    ///
    /// Setting this to `true` leads to behavior 1.
    pub(crate) input_wildcard_expansion: bool,
}

#[derive(PartialEq, Clone)]
pub enum AggExpr {
Min(Box<Expr>),
Expand Down Expand Up @@ -203,9 +224,7 @@ pub enum Expr {
function: NoEq<Arc<dyn SeriesUdf>>,
/// output dtype of the function
output_type: Option<DataType>,
/// if the groups should aggregated to list before
/// execution of the function.
collect_groups: bool,
options: FunctionOptions,
},
Shift {
input: Box<Expr>,
Expand Down Expand Up @@ -729,7 +748,10 @@ impl Expr {
input: vec![self],
function: NoEq::new(Arc::new(f)),
output_type,
collect_groups: false,
options: FunctionOptions {
collect_groups: false,
input_wildcard_expansion: false,
},
}
}

Expand All @@ -752,7 +774,10 @@ impl Expr {
input: vec![self],
function: NoEq::new(Arc::new(f)),
output_type,
collect_groups: true,
options: FunctionOptions {
collect_groups: true,
input_wildcard_expansion: false,
},
}
}

Expand Down Expand Up @@ -1206,14 +1231,48 @@ where
}

/// Accumulate over multiple columns horizontally / row wise.
pub fn fold_exprs<F: 'static>(mut acc: Expr, f: F, exprs: Vec<Expr>) -> Expr
pub fn fold_exprs<F: 'static>(mut acc: Expr, f: F, mut exprs: Vec<Expr>) -> Expr
where
F: Fn(Series, Series) -> Result<Series> + Send + Sync + Copy,
F: Fn(Series, Series) -> Result<Series> + Send + Sync + Clone,
{
for e in exprs {
acc = map_binary(acc, e, f, None);
if exprs.iter().any(has_wildcard) {
exprs.push(acc);

let function = NoEq::new(Arc::new(move |series: &mut [Series]| {
let mut series = series.to_vec();
let mut acc = series.pop().unwrap();

for s in series {
acc = f(acc, s)?
}
Ok(acc)
}) as Arc<dyn SeriesUdf>);

Expr::Function {
input: exprs,
function,
output_type: None,
options: FunctionOptions {
collect_groups: false,
input_wildcard_expansion: true,
},
}
} else {
for e in exprs {
acc = map_binary_lazy_field(
acc,
e,
f.clone(),
// written inline due to lifetime inference issues.
|_schema, _ctxt, f_l: &Field, f_r: &Field| {
get_supertype(f_l.data_type(), f_r.data_type())
.ok()
.map(|dt| Field::new(f_l.name(), dt))
},
);
}
acc
}
acc
}

/// Get the sum of the values per row
Expand Down
21 changes: 21 additions & 0 deletions polars/polars-lazy/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2411,4 +2411,25 @@ mod test {
assert_eq!(out.shape(), (6, 4));
Ok(())
}

#[test]
fn test_fold_wildcard() -> Result<()> {
    // Two identical integer columns, so the expected row-wise sum is easy to verify.
    let frame = df![
        "a" => [1, 2, 3],
        "b" => [1, 2, 3]
    ]?;

    // The wildcard is the only fold input; the fold starts from the literal 0
    // and the result is exposed under the name "foo".
    let folded = fold_exprs(lit(0), |a, b| Ok(&a + &b), vec![col("*")]).alias("foo");
    let out = frame.lazy().select(vec![folded]).collect()?;

    let foo = out.column("foo")?.i32()?;
    assert_eq!(Vec::from(foo), &[Some(2), Some(4), Some(6)]);
    Ok(())
}
}
10 changes: 8 additions & 2 deletions polars/polars-lazy/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ pub fn argsort_by(by: Vec<Expr>, reverse: &[bool]) -> Expr {
input: by,
function,
output_type: Some(DataType::UInt32),
collect_groups: true,
options: FunctionOptions {
collect_groups: true,
input_wildcard_expansion: false,
},
}
}

Expand All @@ -90,6 +93,9 @@ pub fn concat_str(s: Vec<Expr>, delimiter: &str) -> Expr {
input: s,
function,
output_type: Some(DataType::Utf8),
collect_groups: false,
options: FunctionOptions {
collect_groups: false,
input_wildcard_expansion: false,
},
}
}
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/logical_plan/aexpr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub enum AExpr {
input: Vec<Node>,
function: NoEq<Arc<dyn SeriesUdf>>,
output_type: Option<DataType>,
collect_groups: bool,
options: FunctionOptions,
},
Shift {
input: Node,
Expand Down
8 changes: 4 additions & 4 deletions polars/polars-lazy/src/logical_plan/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ pub(crate) fn to_aexpr(expr: Expr, arena: &mut Arena<AExpr>) -> Node {
input,
function,
output_type,
collect_groups,
options,
} => AExpr::Function {
input: to_aexprs(input, arena),
function,
output_type,
collect_groups,
options,
},
Expr::BinaryFunction {
input_a,
Expand Down Expand Up @@ -530,12 +530,12 @@ pub(crate) fn node_to_exp(node: Node, expr_arena: &Arena<AExpr>) -> Expr {
input,
function,
output_type,
collect_groups,
options,
} => Expr::Function {
input: nodes_to_exprs(&input, expr_arena),
function,
output_type,
collect_groups,
options,
},
AExpr::BinaryFunction {
input_a,
Expand Down
58 changes: 52 additions & 6 deletions polars/polars-lazy/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ use polars_io::csv_core::utils::infer_file_schema;
#[cfg(feature = "parquet")]
use polars_io::{parquet::ParquetReader, SerReader};

use crate::logical_plan::iterator::ArenaExprIter;
use crate::logical_plan::LogicalPlan::DataFrameScan;
use crate::utils::{
combine_predicates_expr, expr_to_root_column_name, expr_to_root_column_names, has_expr,
rename_expr_root_name,
has_wildcard, rename_expr_root_name,
};
use crate::{prelude::*, utils};
use polars_io::csv::NullValues;
Expand Down Expand Up @@ -713,15 +714,15 @@ fn replace_wildcard_with_column(expr: Expr, column_name: Arc<String>) -> Expr {
input,
function,
output_type,
collect_groups,
options,
} => Expr::Function {
input: input
.into_iter()
.map(|e| replace_wildcard_with_column(e, column_name.clone()))
.collect(),
function,
output_type,
collect_groups,
options,
},
Expr::BinaryFunction {
input_a,
Expand Down Expand Up @@ -847,9 +848,7 @@ fn rewrite_projections(exprs: Vec<Expr>, schema: &Schema) -> Vec<Expr> {
}
}

let has_wildcard = has_expr(&expr, |e| matches!(e, Expr::Wildcard));

if has_wildcard {
if has_wildcard(&expr) {
// if count wildcard. count one column
if has_expr(&expr, |e| matches!(e, Expr::Agg(AggExpr::Count(_)))) {
let new_name = Arc::new(schema.field(0).unwrap().name().clone());
Expand All @@ -864,6 +863,53 @@ fn rewrite_projections(exprs: Vec<Expr>, schema: &Schema) -> Vec<Expr> {

continue;
}
// This path prepares the wildcard as input for the Function Expr.
// To deal with the borrow checker we first create an arena from this expression.
// Then we clone that arena to `new_arena`, because we cannot hold a shared borrow
// and have mutable access at the same time. We then mutate the old arena:
// * Replace the wildcard column with the new column names (the new nodes are added to the new arena).
// * Move the inputs vec from the Expr::Function in the old arena to the Expr::Function in the new arena.
// * Convert from arena expr to boxed expr.
if has_expr(
&expr,
|e| matches!(e, Expr::Function { input, options, .. } if options.input_wildcard_expansion && input.iter().any(has_wildcard)),
) {
let mut arena = Arena::with_capacity(16);
let root = to_aexpr(expr, &mut arena);
let mut new_arena = arena.clone();
let iter = (&arena).iter(root);
let mut function_node = Default::default();
let mut function_inputs = Default::default();

for (node, ae) in iter {
if let AExpr::Function { .. } = ae {
function_node = node;
break;
}
}

if let AExpr::Function { input, .. } = arena.get_mut(function_node) {
let (idx, _) = input
.iter()
.find_position(|&node| matches!(new_arena.get(*node), AExpr::Wildcard))
.expect("should have wildcard");
input.remove(idx);

for field in schema.fields() {
let node = new_arena.add(AExpr::Column(Arc::new(field.name().clone())));
input.push(node);
}
function_inputs = std::mem::take(input);
}

if let AExpr::Function { input, .. } = new_arena.get_mut(function_node) {
*input = function_inputs;
}

let new_expr = node_to_exp(root, &new_arena);
result.push(new_expr);
continue;
}

for field in schema.fields() {
let name = field.name();
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,15 +814,15 @@ impl DefaultPlanner {
input,
function,
output_type,
collect_groups,
options,
} => {
let input = self.create_physical_expressions(&input, ctxt, expr_arena)?;
Ok(Arc::new(ApplyExpr {
inputs: input,
function,
output_type,
expr: node_to_exp(expression, expr_arena),
collect_groups,
collect_groups: options.collect_groups,
}))
}
BinaryFunction {
Expand Down
5 changes: 5 additions & 0 deletions polars/polars-lazy/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ where
current_expr.into_iter().any(|e| matches(e))
}

/// Check whether `current_expr` contains an `Expr::Wildcard`.
///
/// This check is used so much that it has its own function, to reduce inlining.
pub(crate) fn has_wildcard(current_expr: &Expr) -> bool {
    let is_wildcard = |e: &Expr| matches!(e, Expr::Wildcard);
    has_expr(current_expr, is_wildcard)
}

/// output name of expr
pub(crate) fn output_name(expr: &Expr) -> Result<Arc<String>> {
for e in expr {
Expand Down
14 changes: 9 additions & 5 deletions py-polars/polars/lazy/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..polars import concat_str as _concat_str
from ..polars import cov as pycov
from ..polars import except_ as pyexcept
from ..polars import fold as pyfold
from ..polars import lit as pylit
from ..polars import pearson_corr as pypearson_corr
from ..polars import series_from_range as _series_from_range
Expand Down Expand Up @@ -432,7 +433,7 @@ def map_binary(
def fold(
acc: "pl.Expr",
f: Callable[["pl.Series", "pl.Series"], "pl.Series"],
exprs: tp.List["pl.Expr"],
exprs: Union[tp.List["pl.Expr"], "pl.Expr"],
) -> "pl.Expr":
"""
Accumulate over multiple columns horizontally/ row wise with a left fold.
Expand All @@ -447,11 +448,14 @@ def fold(
Function to apply over the accumulator and the value.
Fn(acc, value) -> new_value
exprs
Expressions to aggregate over.
Expressions to aggregate over. May also be a wildcard expression.
"""
for e in exprs:
acc = map_binary(acc, e, f, None)
return acc
# in case of pl.col("*")
if isinstance(exprs, pl.Expr):
exprs = [exprs]

exprs = pl.lazy.expr._selection_to_pyexpr_list(exprs)
return pl.wrap_expr(pyfold(acc._pyexpr, f, exprs))


def any(name: Union[str, tp.List["pl.Expr"]]) -> "pl.Expr":
Expand Down

0 comments on commit 5abc08b

Please sign in to comment.