Skip to content

Commit

Permalink
fix bug in groupby apply aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 5, 2021
1 parent 1057bbf commit aa91577
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
39 changes: 39 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::prelude::*;
use polars_core::frame::groupby::GroupTuples;
use polars_core::prelude::*;
use rayon::prelude::*;
use std::borrow::Cow;
use std::sync::Arc;

pub struct ApplyExpr {
Expand Down Expand Up @@ -32,6 +33,44 @@ impl PhysicalExpr for ApplyExpr {
}
Ok(out)
}
#[allow(clippy::ptr_arg)]
fn evaluate_on_groups<'a>(
&self,
df: &DataFrame,
groups: &'a GroupTuples,
state: &ExecutionState,
) -> Result<(Series, Cow<'a, GroupTuples>)> {
let mut owned_count = 0;
let mut inputs = Vec::with_capacity(self.inputs.len());
let mut groups_vec = Vec::with_capacity(self.inputs.len());
let mut owned_group = None;

self.inputs.iter().try_for_each::<_, Result<_>>(|e| {
let (s, groups_) = e.evaluate_on_groups(df, groups, state)?;
inputs.push(s);
if let Cow::Owned(_) = &groups_ {
owned_group = Some(groups_);
owned_count += 1;
return Ok(());
}
groups_vec.push(groups_);
Ok(())
})?;

let in_name = inputs[0].name().to_string();
let mut out = self.function.call_udf(&mut inputs)?;
if in_name != out.name() {
out.rename(&in_name);
}

match owned_count {
0 => Ok((out, groups_vec.pop().unwrap())),
1 => Ok((out, owned_group.unwrap())),
_ => Err(PolarsError::ValueError(
"Function may only have one input that contains a filter expression".into(),
)),
}
}
fn to_field(&self, input_schema: &Schema) -> Result<Field> {
match &self.output_type {
Some(output_type) => {
Expand Down
11 changes: 11 additions & 0 deletions polars/polars-lazy/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1407,6 +1407,7 @@ fn test_filter_in_groupby_agg() -> Result<()> {
]?;

let out = df
.clone()
.lazy()
.groupby(vec![col("a")])
.agg(vec![
Expand All @@ -1416,5 +1417,15 @@ fn test_filter_in_groupby_agg() -> Result<()> {

assert_eq!(out.column("b_mean")?.null_count(), 2);

let out = df
.lazy()
.groupby(vec![col("a")])
.agg(vec![(col("b")
.filter(col("b").eq(lit(100)))
.map(|s| Ok(s), None))
.mean()])
.collect()?;
assert_eq!(out.column("b_mean")?.null_count(), 2);

Ok(())
}

0 comments on commit aa91577

Please sign in to comment.