Skip to content

Commit

Permalink
Fix and test accumulation in aggregation context; #678
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 25, 2021
1 parent 309ae2b commit 7fddd58
Show file tree
Hide file tree
Showing 9 changed files with 214 additions and 91 deletions.
56 changes: 46 additions & 10 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ pub use crate::frame::IntoLazy;

/// A wrapper trait for any closure `Fn(Vec<Series>) -> Result<Series>`
pub trait SeriesUdf: Send + Sync {
fn call_udf(&self, s: Vec<Series>) -> Result<Series>;
fn call_udf(&self, s: &mut [Series]) -> Result<Series>;
}

impl<F> SeriesUdf for F
where
F: Fn(Vec<Series>) -> Result<Series> + Send + Sync,
F: Fn(&mut [Series]) -> Result<Series> + Send + Sync,
{
fn call_udf(&self, s: Vec<Series>) -> Result<Series> {
fn call_udf(&self, s: &mut [Series]) -> Result<Series> {
self(s)
}
}
Expand Down Expand Up @@ -197,9 +197,15 @@ pub enum Expr {
falsy: Box<Expr>,
},
Function {
/// function arguments
input: Vec<Expr>,
/// function to apply
function: NoEq<Arc<dyn SeriesUdf>>,
/// output dtype of the function
output_type: Option<DataType>,
/// if the groups should be aggregated to a list before
/// execution of the function.
collect_groups: bool,
},
Shift {
input: Box<Expr>,
Expand Down Expand Up @@ -563,15 +569,15 @@ impl Expr {
if has_expr(&self, |e| matches!(e, Expr::Wildcard)) {
panic!("wildcard not supperted in unique expr");
}
self.map(|s: Series| s.unique(), None)
self.apply(|s: Series| s.unique(), None)
}

/// Get the first index of unique values of this expression.
pub fn arg_unique(self) -> Self {
if has_expr(&self, |e| matches!(e, Expr::Wildcard)) {
panic!("wildcard not supported in unique expr");
}
self.map(
self.apply(
|s: Series| s.arg_unique().map(|ca| ca.into_series()),
Some(DataType::UInt32),
)
Expand All @@ -582,7 +588,7 @@ impl Expr {
if has_expr(&self, |e| matches!(e, Expr::Wildcard)) {
panic!("wildcard not supported in unique expr");
}
self.map(
self.apply(
move |s: Series| Ok(s.argsort(reverse).into_series()),
Some(DataType::UInt32),
)
Expand Down Expand Up @@ -620,18 +626,48 @@ impl Expr {
}

/// Apply a function/closure once the logical plan gets executed.
///
/// This function is very similar to [apply](Expr::apply), but differs in how it handles aggregations.
///
/// * `map` should be used for operations that are independent of groups, e.g. `multiply * 2`, or `raise to the power`
/// * `apply` should be used for operations that work on a group of data. e.g. `sum`, `count`, etc.
///
/// It is the responsibility of the caller to make sure the schema is correct by giving
/// the correct output_type. If `None` is given, the output type of the input expr is used.
pub fn map<F>(self, function: F, output_type: Option<DataType>) -> Self
where
F: Fn(Series) -> Result<Series> + 'static + Send + Sync,
{
let f = move |mut s: Vec<Series>| function(s.pop().unwrap());
let f = move |s: &mut [Series]| function(std::mem::take(&mut s[0]));

Expr::Function {
input: vec![self],
function: NoEq::new(Arc::new(f)),
output_type,
collect_groups: false,
}
}

/// Apply a function/closure over the groups. This should only be used in a groupby aggregation.
///
/// It is the responsibility of the caller to make sure the schema is correct by giving
/// the correct `output_type`. If `None` is given, the output type of the input expr is used.
///
/// The difference with [map](Self::map) is that `apply` will create a separate `Series` per group.
///
/// * `map` should be used for operations that are independent of groups, e.g. `multiply * 2`, or `raise to the power`
/// * `apply` should be used for operations that work on a group of data, e.g. `sum`, `count`, etc.
///
/// The returned expression is an `Expr::Function` with `self` as its only input and
/// `collect_groups` set to `true`.
pub fn apply<F>(self, function: F, output_type: Option<DataType>) -> Self
where
F: Fn(Series) -> Result<Series> + 'static + Send + Sync,
{
// Adapt the single-`Series` closure to the slice-based UDF signature:
// `std::mem::take` moves the first argument out of the slice (leaving a
// default `Series` in its place) so `function` receives it by value.
// Only `s[0]` is read; indexing panics if the slice is empty.
let f = move |s: &mut [Series]| function(std::mem::take(&mut s[0]));

Expr::Function {
input: vec![self],
function: NoEq::new(Arc::new(f)),
output_type,
// This flag is the only difference from the expression built by `map`,
// which sets it to `false`.
collect_groups: true,
}
}

Expand Down Expand Up @@ -714,17 +750,17 @@ impl Expr {

/// Get an array with the cumulative sum computed at every element
pub fn cum_sum(self, reverse: bool) -> Self {
self.map(move |s: Series| Ok(s.cum_sum(reverse)), None)
self.apply(move |s: Series| Ok(s.cum_sum(reverse)), None)
}

/// Get an array with the cumulative min computed at every element
pub fn cum_min(self, reverse: bool) -> Self {
self.map(move |s: Series| Ok(s.cum_min(reverse)), None)
self.apply(move |s: Series| Ok(s.cum_min(reverse)), None)
}

/// Get an array with the cumulative max computed at every element
pub fn cum_max(self, reverse: bool) -> Self {
self.map(move |s: Series| Ok(s.cum_max(reverse)), None)
self.apply(move |s: Series| Ok(s.cum_max(reverse)), None)
}

/// Apply window function over a subgroup.
Expand Down
26 changes: 26 additions & 0 deletions polars/polars-lazy/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2101,4 +2101,30 @@ mod test {
dbg!(out1, out2);
Ok(())
}

// Checks that `cum_sum` used inside a groupby aggregation accumulates within
// each group rather than over the whole column.
#[test]
fn test_groupby_cumsum() -> Result<()> {
// Three groups of sizes 1, 2 and 3.
let df = df![
"groups" => [1, 2, 2, 3, 3, 3],
"vals" => [1, 5, 6, 3, 9, 8]
]?;

// Sorting on the group key presumably makes the row order (and thus the
// exploded order below) deterministic.
let out = df
.lazy()
.groupby(vec![col("groups")])
.agg(vec![col("vals").cum_sum(false)])
.sort("groups", false)
.collect()?;

// Expected per-group running sums: [1], [5, 11], [3, 12, 20], flattened by `explode`.
// NOTE(review): the aggregated column is looked up as "collected" — confirm
// this matches the output column name produced by the aggregation.
assert_eq!(
Vec::from(out.column("collected")?.explode()?.i32()?),
[1, 5, 11, 3, 12, 20]
.iter()
.copied()
.map(Some)
.collect::<Vec<_>>()
);

Ok(())
}
}
3 changes: 2 additions & 1 deletion polars/polars-lazy/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,14 @@ pub fn pearson_corr(a: Expr, b: Expr) -> Expr {
/// be used and so on.
pub fn argsort_by(by: Vec<Expr>, reverse: &[bool]) -> Expr {
let reverse = reverse.to_vec();
let function = NoEq::new(Arc::new(move |s: Vec<Series>| {
let function = NoEq::new(Arc::new(move |s: &mut [Series]| {
polars_core::functions::argsort_by(&s, &reverse).map(|ca| ca.into_series())
}) as Arc<dyn SeriesUdf>);

Expr::Function {
input: by,
function,
output_type: Some(DataType::UInt32),
collect_groups: true,
}
}
1 change: 1 addition & 0 deletions polars/polars-lazy/src/logical_plan/aexpr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ pub enum AExpr {
input: Vec<Node>,
function: NoEq<Arc<dyn SeriesUdf>>,
output_type: Option<DataType>,
collect_groups: bool,
},
Shift {
input: Node,
Expand Down
4 changes: 4 additions & 0 deletions polars/polars-lazy/src/logical_plan/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,12 @@ pub(crate) fn to_aexpr(expr: Expr, arena: &mut Arena<AExpr>) -> Node {
input,
function,
output_type,
collect_groups,
} => AExpr::Function {
input: input.into_iter().map(|e| to_aexpr(e, arena)).collect(),
function,
output_type,
collect_groups,
},
Expr::BinaryFunction {
input_a,
Expand Down Expand Up @@ -537,10 +539,12 @@ pub(crate) fn node_to_exp(node: Node, expr_arena: &Arena<AExpr>) -> Expr {
input,
function,
output_type,
collect_groups,
} => Expr::Function {
input: nodes_to_exprs(&input, expr_arena),
function,
output_type,
collect_groups,
},
AExpr::BinaryFunction {
input_a,
Expand Down
2 changes: 2 additions & 0 deletions polars/polars-lazy/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,13 +706,15 @@ fn replace_wildcard_with_column(expr: Expr, column_name: Arc<String>) -> Expr {
input,
function,
output_type,
collect_groups,
} => Expr::Function {
input: input
.into_iter()
.map(|e| replace_wildcard_with_column(e, column_name.clone()))
.collect(),
function,
output_type,
collect_groups,
},
Expr::BinaryFunction {
input_a,
Expand Down
74 changes: 54 additions & 20 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,7 @@ pub struct ApplyExpr {
pub function: NoEq<Arc<dyn SeriesUdf>>,
pub output_type: Option<DataType>,
pub expr: Expr,
}

impl ApplyExpr {
pub fn new(
input: Vec<Arc<dyn PhysicalExpr>>,
function: NoEq<Arc<dyn SeriesUdf>>,
output_type: Option<DataType>,
expr: Expr,
) -> Self {
ApplyExpr {
inputs: input,
function,
output_type,
expr,
}
}
pub collect_groups: bool,
}

impl PhysicalExpr for ApplyExpr {
Expand All @@ -35,13 +20,13 @@ impl PhysicalExpr for ApplyExpr {
}

fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> Result<Series> {
let inputs = self
let mut inputs = self
.inputs
.iter()
.map(|e| e.evaluate(df, state))
.collect::<Result<Vec<_>>>()?;
let in_name = inputs[0].name().to_string();
let mut out = self.function.call_udf(inputs)?;
let mut out = self.function.call_udf(&mut inputs)?;
if in_name != out.name() {
out.rename(&in_name);
}
Expand Down Expand Up @@ -75,7 +60,7 @@ impl PhysicalAggregation for ApplyExpr {
// we first collect the inputs
// if any of the input aggregations yields None, we return None as well
// we check this by comparing the length of the inputs before and after aggregation
let inputs: Vec<_> = match self.inputs[0].as_agg_expr() {
let mut inputs: Vec<_> = match self.inputs[0].as_agg_expr() {
Ok(_) => {
let inputs = self
.inputs
Expand All @@ -101,7 +86,56 @@ impl PhysicalAggregation for ApplyExpr {
};

if inputs.len() == self.inputs.len() {
self.function.call_udf(inputs).map(Some)
if inputs.len() == 1 {
let s = inputs.pop().unwrap();

match (s.list(), self.collect_groups) {
(Ok(ca), true) => {
let mut container = vec![Default::default()];

let ca: ListChunked = ca
.into_iter()
.map(|opt_s| {
opt_s.and_then(|s| {
container[0] = s;
self.function.call_udf(&mut container).ok()
})
})
.collect();
Ok(Some(ca.into_series()))
}
_ => self.function.call_udf(&mut [s]).map(Some),
}
} else {
match (inputs[0].list(), self.collect_groups) {
(Ok(_), true) => {
// container that will hold the arguments &[Series]
let mut args = Vec::with_capacity(inputs.len());
let takers: Vec<_> = inputs
.iter()
.map(|s| s.list().unwrap().take_rand())
.collect();
let ca: ListChunked = (0..inputs[0].len())
.map(|i| {
args.clear();

takers.iter().for_each(|taker| {
if let Some(s) = taker.get(i) {
args.push(s);
}
});
if args.len() == takers.len() {
self.function.call_udf(&mut args).ok()
} else {
None
}
})
.collect();
Ok(Some(ca.into_series()))
}
_ => self.function.call_udf(&mut inputs).map(Some),
}
}
} else {
Ok(None)
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/physical_plan/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl PhysicalExpr for WindowExpr {
let out = match &self.function {
Expr::Function { function, .. } => {
let mut df = gb.agg_list()?;
df.may_apply_at_idx(1, |s| function.call_udf(vec![s.clone()]))?;
df.may_apply_at_idx(1, |s| function.call_udf(&mut [s.clone()]))?;
Ok(df)
}
Expr::Agg(agg) => match agg {
Expand Down

0 comments on commit 7fddd58

Please sign in to comment.