Skip to content

Commit

Permalink
fix bug in filter + binary exprs
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 5, 2021
1 parent 3330fb0 commit 1057bbf
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 2 deletions.
15 changes: 15 additions & 0 deletions polars/polars-core/src/frame/groupby/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ where
{
fn agg_mean(&self, groups: &[(u32, Vec<u32>)]) -> Option<Series> {
agg_helper::<Float64Type, _>(groups, |(first, idx)| {
// This can fail due to a bug in the lazy code.
// There, users can create filters in aggregations
// and thereby create columns that are shorter than the original group tuples.
// The group tuples are modified, but if that is done incorrectly there can be
// out-of-bounds access.
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
None
} else if idx.len() == 1 {
Expand Down Expand Up @@ -124,6 +130,7 @@ where

fn agg_min(&self, groups: &[(u32, Vec<u32>)]) -> Option<Series> {
agg_helper::<T, _>(groups, |(first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
None
} else if idx.len() == 1 {
Expand Down Expand Up @@ -158,6 +165,7 @@ where

fn agg_max(&self, groups: &[(u32, Vec<u32>)]) -> Option<Series> {
agg_helper::<T, _>(groups, |(first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
None
} else if idx.len() == 1 {
Expand Down Expand Up @@ -192,6 +200,7 @@ where

fn agg_sum(&self, groups: &[(u32, Vec<u32>)]) -> Option<Series> {
agg_helper::<T, _>(groups, |(first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
None
} else if idx.len() == 1 {
Expand Down Expand Up @@ -225,6 +234,7 @@ where
}
fn agg_var(&self, groups: &[(u32, Vec<u32>)]) -> Option<Series> {
agg_helper::<Float64Type, _>(groups, |(_first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
return None;
}
Expand All @@ -238,6 +248,7 @@ where
}
fn agg_std(&self, groups: &[(u32, Vec<u32>)]) -> Option<Series> {
agg_helper::<Float64Type, _>(groups, |(_first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
return None;
}
Expand All @@ -252,6 +263,7 @@ where
#[cfg(feature = "lazy")]
fn agg_valid_count(&self, groups: &[(u32, Vec<u32>)]) -> Option<Series> {
agg_helper::<UInt32Type, _>(groups, |(_first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
None
} else if self.null_count() == 0 {
Expand All @@ -273,6 +285,7 @@ macro_rules! impl_agg_first {
let mut ca = $groups
.iter()
.map(|(first, idx)| {
debug_assert!(idx.len() <= $self.len());
if idx.is_empty() {
return None;
}
Expand Down Expand Up @@ -353,6 +366,7 @@ macro_rules! impl_agg_last {
let mut ca = $groups
.iter()
.map(|(_first, idx)| {
debug_assert!(idx.len() <= $self.len());
if idx.is_empty() {
return None;
}
Expand Down Expand Up @@ -432,6 +446,7 @@ macro_rules! impl_agg_n_unique {
$groups
.into_par_iter()
.map(|(_first, idx)| {
debug_assert!(idx.len() <= $self.len());
if idx.is_empty() {
return 0;
}
Expand Down
47 changes: 47 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,34 @@ use crate::physical_plan::PhysicalAggregation;
use crate::prelude::*;
use polars_core::frame::groupby::GroupTuples;
use polars_core::{prelude::*, POOL};
use std::borrow::Cow;
use std::sync::Arc;

/// Picks the group tuples that belong to the result of a binary expression.
///
/// Only one side of a binary expression may change the size of its groups
/// (e.g. with a filter); otherwise the aggregation would produce flawed results.
///
/// * Both sides own their groups: `PolarsError::InvalidOperation` is returned.
/// * Both sides borrow their groups: they must point to the very same
///   `GroupTuples`, otherwise `PolarsError::ValueError` is returned.
/// * Exactly one side owns its groups: those owned groups are returned.
///
/// On success `out` is handed back unchanged together with the chosen groups.
pub(crate) fn binary_check_group_tuples<'a>(
    out: Series,
    groups_a: Cow<'a, GroupTuples>,
    groups_b: Cow<'a, GroupTuples>,
) -> Result<(Series, Cow<'a, GroupTuples>)> {
    let groups = match (groups_a, groups_b) {
        // Both operands produced their own groups; this cannot be reconciled.
        (Cow::Owned(_), Cow::Owned(_)) => {
            return Err(PolarsError::InvalidOperation(
                "Cannot apply two filters in a binary expression".into(),
            ));
        }
        // Neither operand replaced the groups: both must reference the same tuples.
        (Cow::Borrowed(lhs), Cow::Borrowed(rhs)) => {
            if std::ptr::eq(lhs, rhs) {
                Cow::Borrowed(lhs)
            } else {
                return Err(PolarsError::ValueError(
                    "filter predicates do not originate from same filter operation".into(),
                ));
            }
        }
        // Exactly one operand produced new groups; those win.
        (Cow::Owned(modified), Cow::Borrowed(_)) | (Cow::Borrowed(_), Cow::Owned(modified)) => {
            Cow::Owned(modified)
        }
    };

    Ok((out, groups))
}

pub struct BinaryExpr {
pub(crate) left: Arc<dyn PhysicalExpr>,
pub(crate) op: Operator,
Expand Down Expand Up @@ -60,6 +86,27 @@ impl PhysicalExpr for BinaryExpr {
});
apply_operator(&lhs?, &rhs?, self.op)
}

/// Evaluates both operands on `groups` in parallel, applies the operator to the
/// resulting series and returns the outcome together with the group tuples
/// selected by `binary_check_group_tuples`.
///
/// An error from the left operand is reported before one from the right operand.
// NOTE(review): the `ptr_arg` allow is presumably needed because `groups` must be
// `&GroupTuples` rather than a slice — confirm.
#[allow(clippy::ptr_arg)]
fn evaluate_on_groups<'a>(
    &self,
    df: &DataFrame,
    groups: &'a GroupTuples,
    state: &ExecutionState,
) -> Result<(Series, Cow<'a, GroupTuples>)> {
    let eval_left = || self.left.evaluate_on_groups(df, groups, state);
    let eval_right = || self.right.evaluate_on_groups(df, groups, state);

    // Run both sides concurrently on the shared thread pool.
    let (lhs, rhs) = POOL.install(|| rayon::join(eval_left, eval_right));
    let (lhs_series, lhs_groups) = lhs?;
    let (rhs_series, rhs_groups) = rhs?;

    apply_operator(&lhs_series, &rhs_series, self.op)
        .and_then(|combined| binary_check_group_tuples(combined, lhs_groups, rhs_groups))
}

fn to_field(&self, _input_schema: &Schema) -> Result<Field> {
todo!()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::logical_plan::Context;
use crate::physical_plan::expressions::binary::binary_check_group_tuples;
use crate::physical_plan::state::ExecutionState;
use crate::physical_plan::PhysicalAggregation;
use crate::prelude::*;
use polars_core::frame::groupby::GroupTuples;
use polars_core::{prelude::*, POOL};
use std::borrow::Cow;
use std::sync::Arc;

pub(crate) struct BinaryFunctionExpr {
Expand Down Expand Up @@ -41,6 +43,40 @@ impl PhysicalExpr for BinaryFunctionExpr {
})
}

/// Evaluates both inputs on `groups` in parallel, feeds the two series to the
/// user-defined binary function and renames the result.
///
/// The output name comes from `output_field`; when that yields no field, the
/// name falls back to `"binary_function"`. The group tuples returned alongside
/// the series are selected by `binary_check_group_tuples`.
// NOTE(review): the `ptr_arg` allow is presumably needed because `groups` must be
// `&GroupTuples` rather than a slice — confirm.
#[allow(clippy::ptr_arg)]
fn evaluate_on_groups<'a>(
    &self,
    df: &DataFrame,
    groups: &'a GroupTuples,
    state: &ExecutionState,
) -> Result<(Series, Cow<'a, GroupTuples>)> {
    // Run both inputs concurrently on the shared thread pool.
    let (result_a, result_b) = POOL.install(|| {
        rayon::join(
            || self.input_a.evaluate_on_groups(df, groups, state),
            || self.input_b.evaluate_on_groups(df, groups, state),
        )
    });
    let (series_a, groups_a) = result_a?;
    let (series_b, groups_b) = result_b?;

    // Resolve the output name before the series are moved into the UDF.
    let schema = df.schema();
    let field_a = Field::new(series_a.name(), series_a.dtype().clone());
    let field_b = Field::new(series_b.name(), series_b.dtype().clone());
    let name = match self
        .output_field
        .get_field(&schema, Context::Default, &field_a, &field_b)
    {
        Some(fld) => fld.name().clone(),
        None => "binary_function".to_string(),
    };

    let mut out = self.function.call_udf(series_a, series_b)?;
    out.rename(&name);

    binary_check_group_tuples(out, groups_a, groups_b)
}

fn to_field(&self, input_schema: &Schema) -> Result<Field> {
let field_a = self.input_a.to_field(input_schema)?;
let field_b = self.input_b.to_field(input_schema)?;
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/physical_plan/expressions/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ impl PhysicalAggregation for SortExpr {
groups: &GroupTuples,
state: &ExecutionState,
) -> Result<Option<Series>> {
let s = self.physical_expr.evaluate(df, state)?;
let agg_s = s.agg_list(groups);
let (s, groups) = self.physical_expr.evaluate_on_groups(df, groups, state)?;
let agg_s = s.agg_list(&groups);
let out = agg_s.map(|s| {
s.list()
.unwrap()
Expand Down
23 changes: 23 additions & 0 deletions polars/polars-lazy/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1395,3 +1395,26 @@ fn test_regex_selection() -> Result<()> {
assert_eq!(out.get_column_names(), &["anton", "arnold schwars"]);
Ok(())
}

#[test]
fn test_filter_in_groupby_agg() -> Result<()> {
    // Verifies that a filter nested inside a binary expression is handled correctly
    // during a groupby aggregation.
    // The predicate `b == 100` matches no rows, so the filter yields an empty column.
    // If the group tuples were left untouched, the aggregation could index out of
    // bounds, which would be undefined behavior.
    let frame = df![
        "a" => [1, 1, 2],
        "b" => [1, 2, 3]
    ]?;

    let filtered_b = col("b").filter(col("b").eq(lit(100)));
    let aggregated = frame
        .lazy()
        .groupby(vec![col("a")])
        .agg(vec![(filtered_b * lit(2)).mean()])
        .collect()?;

    // Every group is empty after filtering, so each mean must be null.
    let null_count = aggregated.column("b_mean")?.null_count();
    assert_eq!(null_count, 2);

    Ok(())
}

0 comments on commit 1057bbf

Please sign in to comment.