Skip to content

Commit

Permalink
Lazy make Udf/Function work over any number of inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 23, 2021
1 parent ff60255 commit 6ce9bb1
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 145 deletions.
24 changes: 14 additions & 10 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@ use std::{
// reexport the lazy method
pub use crate::frame::IntoLazy;

/// A wrapper trait for any closure `Fn(Vec<Series>) -> Result<Series>`
pub trait SeriesUdf: Send + Sync {
fn call_udf(&self, s: Series) -> Result<Series>;
fn call_udf(&self, s: Vec<Series>) -> Result<Series>;
}

impl<F> SeriesUdf for F
where
F: Fn(Series) -> Result<Series> + Send + Sync,
F: Fn(Vec<Series>) -> Result<Series> + Send + Sync,
{
fn call_udf(&self, s: Series) -> Result<Series> {
fn call_udf(&self, s: Vec<Series>) -> Result<Series> {
self(s)
}
}
Expand All @@ -35,6 +36,7 @@ impl Debug for dyn SeriesUdf {
}
}

/// A wrapper trait for any binary closure `Fn(Series, Series) -> Result<Series>`
pub trait SeriesBinaryUdf: Send + Sync {
fn call_udf(&self, a: Series, b: Series) -> Result<Series>;
}
Expand Down Expand Up @@ -194,8 +196,8 @@ pub enum Expr {
truthy: Box<Expr>,
falsy: Box<Expr>,
},
Udf {
input: Box<Expr>,
Function {
input: Vec<Expr>,
function: NoEq<Arc<dyn SeriesUdf>>,
output_type: Option<DataType>,
},
Expand Down Expand Up @@ -313,7 +315,7 @@ impl fmt::Debug for Expr {
"\nWHEN {:?}\n\t{:?}\nOTHERWISE\n\t{:?}",
predicate, truthy, falsy
),
Udf { input, .. } => write!(f, "APPLY({:?})", input),
Function { input, .. } => write!(f, "APPLY({:?})", input),
BinaryFunction {
input_a, input_b, ..
} => write!(f, "BinaryFunction({:?}, {:?})", input_a, input_b),
Expand Down Expand Up @@ -622,11 +624,13 @@ impl Expr {
/// the correct output_type. If None given the output type of the input expr is used.
pub fn map<F>(self, function: F, output_type: Option<DataType>) -> Self
where
F: SeriesUdf + 'static,
F: Fn(Series) -> Result<Series> + 'static + Send + Sync,
{
Expr::Udf {
input: Box::new(self),
function: NoEq::new(Arc::new(function)),
let f = move |mut s: Vec<Series>| function(s.pop().unwrap());

Expr::Function {
input: vec![self],
function: NoEq::new(Arc::new(f)),
output_type,
}
}
Expand Down
10 changes: 5 additions & 5 deletions polars/polars-lazy/src/logical_plan/aexpr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ pub enum AExpr {
truthy: Node,
falsy: Node,
},
Udf {
input: Node,
Function {
input: Vec<Node>,
function: NoEq<Arc<dyn SeriesUdf>>,
output_type: Option<DataType>,
},
Expand Down Expand Up @@ -289,12 +289,12 @@ impl AExpr {
Ok(Field::new(field.name(), data_type.clone()))
}
Ternary { truthy, .. } => arena.get(*truthy).to_field(schema, ctxt, arena),
Udf {
Function {
output_type, input, ..
} => match output_type {
None => arena.get(*input).to_field(schema, ctxt, arena),
None => arena.get(input[0]).to_field(schema, ctxt, arena),
Some(output_type) => {
let input_field = arena.get(*input).to_field(schema, ctxt, arena)?;
let input_field = arena.get(input[0]).to_field(schema, ctxt, arena)?;
Ok(Field::new(input_field.name(), output_type.clone()))
}
},
Expand Down
21 changes: 9 additions & 12 deletions polars/polars-lazy/src/logical_plan/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ pub(crate) fn to_aexpr(expr: Expr, arena: &mut Arena<AExpr>) -> Node {
falsy: f,
}
}
Expr::Udf {
Expr::Function {
input,
function,
output_type,
} => AExpr::Udf {
input: to_aexpr(*input, arena),
} => AExpr::Function {
input: input.into_iter().map(|e| to_aexpr(e, arena)).collect(),
function,
output_type,
},
Expand Down Expand Up @@ -533,18 +533,15 @@ pub(crate) fn node_to_exp(node: Node, expr_arena: &Arena<AExpr>) -> Expr {
falsy: Box::new(f),
}
}
AExpr::Udf {
AExpr::Function {
input,
function,
output_type,
} => {
let i = node_to_exp(input, expr_arena);
Expr::Udf {
input: Box::new(i),
function,
output_type,
}
}
} => Expr::Function {
input: nodes_to_exprs(&input, expr_arena),
function,
output_type,
},
AExpr::BinaryFunction {
input_a,
input_b,
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/logical_plan/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl<'a> Iterator for ExprIter<'a> {
push(falsy);
push(predicate)
}
Udf { input, .. } => push(input),
Function { input, .. } => input.iter().for_each(|e| push(e)),
Shift { input, .. } => push(input),
Reverse(e) => push(e),
Duplicated(e) => push(e),
Expand Down Expand Up @@ -164,7 +164,7 @@ impl AExpr {
push(falsy);
push(predicate)
}
Udf { input, .. } => push(input),
Function { input, .. } => input.iter().for_each(|e| push(e)),
Shift { input, .. } => push(input),
Reverse(e) => push(e),
Duplicated(e) => push(e),
Expand Down
9 changes: 6 additions & 3 deletions polars/polars-lazy/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,12 +702,15 @@ fn replace_wildcard_with_column(expr: Expr, column_name: Arc<String>) -> Expr {
truthy: Box::new(replace_wildcard_with_column(*truthy, column_name.clone())),
falsy: Box::new(replace_wildcard_with_column(*falsy, column_name)),
},
Expr::Udf {
Expr::Function {
input,
function,
output_type,
} => Expr::Udf {
input: Box::new(replace_wildcard_with_column(*input, column_name)),
} => Expr::Function {
input: input
.into_iter()
.map(|e| replace_wildcard_with_column(e, column_name.clone()))
.collect(),
function,
output_type,
},
Expand Down
31 changes: 0 additions & 31 deletions polars/polars-lazy/src/physical_plan/expressions/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,37 +248,6 @@ impl PhysicalAggregation for CastExpr {
}
}

impl PhysicalAggregation for ApplyExpr {
fn aggregate(
&self,
df: &DataFrame,
groups: &GroupTuples,
state: &ExecutionState,
) -> Result<Option<Series>> {
match self.input.as_agg_expr() {
// layer below is also an aggregation expr.
Ok(expr) => {
let aggregated = expr.aggregate(df, groups, state)?;
let out = aggregated.map(|s| self.function.call_udf(s));
out.transpose()
}
Err(_) => {
let series = self.input.evaluate(df, state)?;
series
.agg_list(groups)
.map(|s| {
let s = self.function.call_udf(s);
s.map(|mut s| {
s.rename(series.name());
s
})
})
.map_or(Ok(None), |v| v.map(Some))
}
}
}
}

pub struct AggQuantileExpr {
pub(crate) expr: Arc<dyn PhysicalExpr>,
pub(crate) quantile: f64,
Expand Down
69 changes: 61 additions & 8 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
use crate::physical_plan::state::ExecutionState;
use crate::physical_plan::PhysicalAggregation;
use crate::prelude::*;
use polars_core::frame::groupby::GroupTuples;
use polars_core::prelude::*;
use rayon::prelude::*;
use std::sync::Arc;

pub struct ApplyExpr {
pub input: Arc<dyn PhysicalExpr>,
pub inputs: Vec<Arc<dyn PhysicalExpr>>,
pub function: NoEq<Arc<dyn SeriesUdf>>,
pub output_type: Option<DataType>,
pub expr: Expr,
}

impl ApplyExpr {
pub fn new(
input: Arc<dyn PhysicalExpr>,
input: Vec<Arc<dyn PhysicalExpr>>,
function: NoEq<Arc<dyn SeriesUdf>>,
output_type: Option<DataType>,
expr: Expr,
) -> Self {
ApplyExpr {
input,
inputs: input,
function,
output_type,
expr,
Expand All @@ -33,9 +35,13 @@ impl PhysicalExpr for ApplyExpr {
}

fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> Result<Series> {
let input = self.input.evaluate(df, state)?;
let in_name = input.name().to_string();
let mut out = self.function.call_udf(input)?;
let inputs = self
.inputs
.iter()
.map(|e| e.evaluate(df, state))
.collect::<Result<Vec<_>>>()?;
let in_name = inputs[0].name().to_string();
let mut out = self.function.call_udf(inputs)?;
if in_name != out.name() {
out.rename(&in_name);
}
Expand All @@ -44,13 +50,60 @@ impl PhysicalExpr for ApplyExpr {
fn to_field(&self, input_schema: &Schema) -> Result<Field> {
match &self.output_type {
Some(output_type) => {
let input_field = self.input.to_field(input_schema)?;
let input_field = self.inputs[0].to_field(input_schema)?;
Ok(Field::new(input_field.name(), output_type.clone()))
}
None => self.input.to_field(input_schema),
None => self.inputs[0].to_field(input_schema),
}
}
fn as_agg_expr(&self) -> Result<&dyn PhysicalAggregation> {
Ok(self)
}
}

impl PhysicalAggregation for ApplyExpr {
fn aggregate(
&self,
df: &DataFrame,
groups: &GroupTuples,
state: &ExecutionState,
) -> Result<Option<Series>> {
// two possible paths
// all inputs may be final aggregations
// or they may be expression that can work on groups but not yet produce an aggregation

// we first collect the inputs
// if any of the input aggregations yields None, we return None as well
// we check this by comparing the length of the inputs before and after aggregation
let inputs: Vec<_> = match self.inputs[0].as_agg_expr() {
Ok(_) => {
let inputs = self
.inputs
.par_iter()
.map(|e| {
let e = e.as_agg_expr()?;
e.aggregate(df, groups, state)
})
.collect::<Result<Vec<_>>>()?;
inputs.into_iter().flatten().collect()
}
_ => {
let inputs = self
.inputs
.par_iter()
.map(|e| {
let (s, groups) = e.evaluate_on_groups(df, groups, state)?;
Ok(s.agg_list(&groups))
})
.collect::<Result<Vec<_>>>()?;
inputs.into_iter().flatten().collect()
}
};

if inputs.len() == self.inputs.len() {
self.function.call_udf(inputs).map(Some)
} else {
Ok(None)
}
}
}
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/physical_plan/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ impl PhysicalExpr for WindowExpr {
);

let out = match &self.function {
Expr::Udf { function, .. } => {
Expr::Function { function, .. } => {
let mut df = gb.agg_list()?;
df.may_apply_at_idx(1, |s| function.call_udf(s.clone()))?;
df.may_apply_at_idx(1, |s| function.call_udf(vec![s.clone()]))?;
Ok(df)
}
Expr::Agg(agg) => match agg {
Expand Down

0 comments on commit 6ce9bb1

Please sign in to comment.