Skip to content

Commit

Permalink
fix[rust]: properly handle aggregation predicates in filter + groupby context (#4589)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 27, 2022
1 parent 7feddd0 commit 0f08272
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 111 deletions.
57 changes: 21 additions & 36 deletions polars/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use polars_core::prelude::*;
use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY;
#[cfg(feature = "rank")]
use polars_core::utils::coalesce_nulls_series;
use polars_core::utils::get_supertype;
use rayon::prelude::*;

#[cfg(feature = "arg_where")]
Expand All @@ -18,7 +17,6 @@ use crate::dsl::function_expr::FunctionExpr;
use crate::dsl::function_expr::ListFunction;
use crate::dsl::*;
use crate::prelude::*;
use crate::utils::has_wildcard;

/// Compute the covariance between two columns.
pub fn cov(a: Expr, b: Expr) -> Expr {
Expand Down Expand Up @@ -699,47 +697,34 @@ where
}

/// Accumulate over multiple columns horizontally / row wise.
pub fn fold_exprs<F: 'static, E: AsRef<[Expr]>>(mut acc: Expr, f: F, exprs: E) -> Expr
pub fn fold_exprs<F: 'static, E: AsRef<[Expr]>>(acc: Expr, f: F, exprs: E) -> Expr
where
F: Fn(Series, Series) -> Result<Series> + Send + Sync + Clone,
{
let mut exprs = exprs.as_ref().to_vec();
if exprs.iter().any(|e| has_wildcard(e) | has_regex(e)) {
exprs.push(acc);
exprs.push(acc);

let function = SpecialEq::new(Arc::new(move |series: &mut [Series]| {
let mut series = series.to_vec();
let mut acc = series.pop().unwrap();
let function = SpecialEq::new(Arc::new(move |series: &mut [Series]| {
let mut series = series.to_vec();
let mut acc = series.pop().unwrap();

for s in series {
acc = f(acc, s)?;
}
Ok(acc)
}) as Arc<dyn SeriesUdf>);

// Todo! make sure that output type is correct
Expr::AnonymousFunction {
input: exprs,
function,
output_type: GetOutput::same_type(),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
auto_explode: true,
fmt_str: "",
..Default::default()
},
}
} else {
for e in exprs {
acc = map_binary(
acc,
e,
f.clone(),
GetOutput::map_dtypes(|dt| get_supertype(dt[0], dt[1]).unwrap()),
);
for s in series {
acc = f(acc, s)?;
}
acc
Ok(acc)
}) as Arc<dyn SeriesUdf>);

Expr::AnonymousFunction {
input: exprs,
function,
output_type: GetOutput::super_type(),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: true,
auto_explode: true,
fmt_str: "fold",
..Default::default()
},
}
}

Expand Down
154 changes: 88 additions & 66 deletions polars/polars-lazy/src/physical_plan/expressions/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use polars_core::{prelude::*, POOL};
use rayon::prelude::*;

use crate::physical_plan::state::ExecutionState;
use crate::prelude::UpdateGroups::WithSeriesLen;
use crate::prelude::*;

pub struct FilterExpr {
Expand Down Expand Up @@ -44,78 +45,99 @@ impl PhysicalExpr for FilterExpr {
let ac_predicate_f = || self.by.evaluate_on_groups(df, groups, state);

let (ac_s, ac_predicate) = POOL.install(|| rayon::join(ac_s_f, ac_predicate_f));
let (mut ac_s, ac_predicate) = (ac_s?, ac_predicate?);

let groups = ac_s.groups();
let predicate_s = ac_predicate.flat_naive();
let predicate = predicate_s.bool()?.rechunk();

// all values true don't do anything
if predicate.all() {
return Ok(ac_s);
}
// all values false
// create empty groups
let groups = if !predicate.any() {
let groups = groups.iter().map(|gi| [gi.first(), 0]).collect::<Vec<_>>();
GroupsProxy::Slice {
groups,
rolling: false,
let (mut ac_s, mut ac_predicate) = (ac_s?, ac_predicate?);

match ac_predicate.is_aggregated() {
true => {
let preds = ac_predicate.iter_groups();
let s = ac_s.aggregated();
let ca = s.list()?;
let mut out = ca
.amortized_iter()
.zip(preds)
.map(|(opt_s, opt_pred)| match (opt_s, opt_pred) {
(Some(s), Some(pred)) => s.as_ref().filter(pred.as_ref().bool()?).map(Some),
_ => Ok(None),
})
.collect::<Result<ListChunked>>()?;
out.rename(s.name());
ac_s.with_series(out.into_series(), true);
ac_s.update_groups = WithSeriesLen;
Ok(ac_s)
}
}
// filter the indexes that are true
else {
let predicate = predicate.downcast_iter().next().unwrap();
POOL.install(|| {
match groups.as_ref() {
GroupsProxy::Idx(groups) => {
let groups = groups
.par_iter()
.map(|(first, idx)| unsafe {
let idx: Vec<IdxSize> = idx
.iter()
// Safety:
// just checked bounds in short circuited lhs
.filter_map(|i| {
match predicate.value(*i as usize)
&& predicate.is_valid_unchecked(*i as usize)
{
true => Some(*i),
_ => None,
}
false => {
let groups = ac_s.groups();
let predicate_s = ac_predicate.flat_naive();
let predicate = predicate_s.bool()?.rechunk();

// all values true don't do anything
if predicate.all() {
return Ok(ac_s);
}
// all values false
// create empty groups
let groups = if !predicate.any() {
let groups = groups.iter().map(|gi| [gi.first(), 0]).collect::<Vec<_>>();
GroupsProxy::Slice {
groups,
rolling: false,
}
}
// filter the indexes that are true
else {
let predicate = predicate.downcast_iter().next().unwrap();
POOL.install(|| {
match groups.as_ref() {
GroupsProxy::Idx(groups) => {
let groups = groups
.par_iter()
.map(|(first, idx)| unsafe {
let idx: Vec<IdxSize> = idx
.iter()
// Safety:
// just checked bounds in short circuited lhs
.filter_map(|i| {
match predicate.value(*i as usize)
&& predicate.is_valid_unchecked(*i as usize)
{
true => Some(*i),
_ => None,
}
})
.collect();

(*idx.first().unwrap_or(&first), idx)
})
.collect();

(*idx.first().unwrap_or(&first), idx)
})
.collect();

GroupsProxy::Idx(groups)
}
GroupsProxy::Slice { groups, .. } => {
let groups = groups
.par_iter()
.map(|&[first, len]| unsafe {
let idx: Vec<IdxSize> = (first..first + len)
// Safety:
// just checked bounds in short circuited lhs
.filter(|&i| {
predicate.value(i as usize)
&& predicate.is_valid_unchecked(i as usize)
GroupsProxy::Idx(groups)
}
GroupsProxy::Slice { groups, .. } => {
let groups = groups
.par_iter()
.map(|&[first, len]| unsafe {
let idx: Vec<IdxSize> = (first..first + len)
// Safety:
// just checked bounds in short circuited lhs
.filter(|&i| {
predicate.value(i as usize)
&& predicate.is_valid_unchecked(i as usize)
})
.collect();

(*idx.first().unwrap_or(&first), idx)
})
.collect();

(*idx.first().unwrap_or(&first), idx)
})
.collect();
GroupsProxy::Idx(groups)
}
}
})
};

ac_s.with_groups(groups).set_original_len(false);
Ok(ac_s)
GroupsProxy::Idx(groups)
}
}
})
};

ac_s.with_groups(groups).set_original_len(false);
Ok(ac_s)
}
}
}

fn to_field(&self, input_schema: &Schema) -> Result<Field> {
Expand Down
8 changes: 0 additions & 8 deletions polars/polars-lazy/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,6 @@ pub(crate) fn has_wildcard(current_expr: &Expr) -> bool {
has_expr(current_expr, |e| matches!(e, Expr::Wildcard))
}

// Used so often that it gets its own function, to reduce inlining.
//
// Hands `has_expr` a predicate that accepts `Expr::Column` nodes whose name
// starts with '^' and ends with '$'.
pub(crate) fn has_regex(current_expr: &Expr) -> bool {
    let is_regex_column = |node: &Expr| {
        if let Expr::Column(name) = node {
            name.starts_with('^') && name.ends_with('$')
        } else {
            false
        }
    };
    has_expr(current_expr, is_regex_column)
}

// Hands `has_expr` a predicate that accepts `Expr::Nth` nodes.
pub(crate) fn has_nth(current_expr: &Expr) -> bool {
    let is_nth = |node: &Expr| match node {
        Expr::Nth(_) => true,
        _ => false,
    };
    has_expr(current_expr, is_nth)
}
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ def fold(
return pli.wrap_expr(pyfold(acc._pyexpr, f, exprs))


def any(name: str | list[pli.Expr] | pli.Expr) -> pli.Expr:
def any(name: str | list[str] | list[pli.Expr] | pli.Expr) -> pli.Expr:
"""Evaluate columnwise or elementwise with a bitwise OR operation."""
if isinstance(name, (list, pli.Expr)):
return fold(lit(False), lambda a, b: a.cast(bool) | b.cast(bool), name).alias(
Expand Down
24 changes: 24 additions & 0 deletions py-polars/tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,27 @@ def test_filter_is_in_4572() -> None:
.agg(pl.col("k").filter(pl.col("k").is_in(["a"])).list())
.frame_equal(expected)
)


def test_filter_aggregation_any() -> None:
    """Filter a column by a ``pl.any`` predicate inside ``groupby().agg``.

    Compares both the per-group ``any`` lists and the filtered ``id`` values
    against the expected dictionary.
    """
    df = pl.DataFrame(
        {
            "id": [1, 2, 3, 4],
            "group": [1, 2, 1, 1],
            "pred_a": [False, True, False, False],
            "pred_b": [False, False, True, True],
        }
    )

    out = (
        df.groupby("group")
        .agg(
            [
                pl.any(["pred_a", "pred_b"]),
                pl.col("id").filter(pl.any(["pred_a", "pred_b"])).alias("filtered"),
            ]
        )
        # groupby output order is not relied upon; sort before comparing
        .sort("group")
        .to_dict(False)
    )

    assert out == {
        "group": [1, 2],
        "any": [[False, True, True], [True]],
        "filtered": [[3, 4], [2]],
    }

0 comments on commit 0f08272

Please sign in to comment.