Skip to content

Commit

Permalink
fix bug in predicate pushdown on dependent predicates (#3394)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 13, 2022
1 parent 6e6086b commit 35a05e6
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,42 @@ impl PredicatePushDown {

match lp {
Selection { predicate, input } => {

// If a predicate's result would be influenced by an earlier applied filter,
// we remove it and apply it locally
let mut apply_local = vec![];
let mut apply_local_nodes = vec![];
for (k, previous) in &acc_predicates {
if predicate_is_pushdown_boundary(*previous, expr_arena) {
apply_local.push(k.clone())
}
}
for name in &apply_local {
apply_local_nodes.push(acc_predicates.remove(name).unwrap());
}

let name = roots_to_key(&aexpr_to_root_names(predicate, expr_arena));
insert_and_combine_predicate(&mut acc_predicates, name, predicate, expr_arena);
let alp = lp_arena.take(input);
self.push_down(alp, acc_predicates, lp_arena, expr_arena)
let new_input = self.push_down(alp, acc_predicates, lp_arena, expr_arena)?;

// TODO!
// If a predicate's result would be influenced by earlier applied
// predicates, we simply don't push this one down past this node.
// However, we can do better and let it pass but store the order of the predicates
// so that we can apply them in the correct order at the deepest level.
if !apply_local_nodes.is_empty() {
let predicate = combine_predicates(apply_local_nodes.into_iter(), expr_arena);
let input = lp_arena.add(new_input);

Ok(Selection {
predicate,
input
})

} else {
Ok(new_input)
}
}
DataFrameScan {
df,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,33 @@ pub(super) fn get_insertion_name(
)
}

/// Checks whether a predicate coming from a node upstream is blocked by the
/// predicate in this filter, i.e. whether it may not be pushed past it.
///
/// A case where the predicate cannot pass:
///
/// ```text
/// .filter(a > 1)        # filter 2
/// .filter(a == min(a))  # filter 1
/// ```
///
/// `min(a)` is influenced by filter 2, so the predicate containing `min(a)`
/// should not pass.
///
/// Returns whatever `has_aexpr` reports for `node` when asked to look for the
/// expression kinds listed in the body.
pub(super) fn predicate_is_pushdown_boundary(node: Node, expr_arena: &Arena<AExpr>) -> bool {
    has_aexpr(node, expr_arena, |e: &AExpr| {
        matches!(
            e,
            // an aggregation needs all rows
            AExpr::Agg(_)
                // a window expression groups rows and needs all of them for its aggregation
                | AExpr::Window { .. }
                // results of shifts, sorts and reversals are influenced by an earlier filter
                | AExpr::Shift { .. }
                | AExpr::Sort { .. }
                | AExpr::SortBy { .. }
                | AExpr::Reverse(_)
                | AExpr::Explode { .. }
                // functions applied on groups can be something like shift, sort,
                // or an aggregation like skew; both need all values
                | AExpr::AnonymousFunction {
                    options: FunctionOptions {
                        collect_groups: ApplyOptions::ApplyGroups,
                        ..
                    },
                    ..
                }
                | AExpr::Function {
                    options: FunctionOptions {
                        collect_groups: ApplyOptions::ApplyGroups,
                        ..
                    },
                    ..
                }
        )
    })
}

/// Some predicates should not pass a projection if they would influence results of other columns.
/// For instance shifts | sorts results are influenced by a filter so we do all predicates before the shift | sort
/// The rule of thumb is any operation that changes the order of a column w/r/t other columns should be a
Expand All @@ -117,6 +144,7 @@ pub(super) fn other_column_is_pushdown_boundary(node: Node, expr_arena: &Arena<A
// Apply groups can be something like shift, sort, or an aggregation like skew
// both need all values
| AExpr::AnonymousFunction {options: FunctionOptions { collect_groups: ApplyOptions::ApplyGroups, .. }, ..}
| AExpr::Function {options: FunctionOptions { collect_groups: ApplyOptions::ApplyGroups, .. }, ..}
| AExpr::BinaryExpr {..}
| AExpr::Cast {data_type: DataType::Float32 | DataType::Float64, ..}
// cast may create nulls
Expand Down
22 changes: 22 additions & 0 deletions polars/tests/it/lazy/predicate_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,25 @@ fn test_many_filters() -> Result<()> {

Ok(())
}

/// Two stacked filters where the second one depends on the result of the
/// first must not be merged into a single predicate by the optimizer.
#[test]
fn test_filter_no_combine() -> Result<()> {
    let df = df!["vals" => [1, 2, 3, 4, 5]]?;

    // First filter: keeps 2, 3, 4, 5.
    let above_one = col("vals").gt(lit(1));
    // Second filter: the minimum is taken over the already filtered rows,
    // so this should be `> 2`. If the optimizer combined both predicates,
    // the minimum would be computed over all rows and the result would be flawed.
    let above_min = col("vals").gt(col("vals").min());

    let out = df.lazy().filter(above_one).filter(above_min).collect()?;

    let vals = out.column("vals")?.i32()?;
    assert_eq!(Vec::from(vals), &[Some(3), Some(4), Some(5)]);

    Ok(())
}

0 comments on commit 35a05e6

Please sign in to comment.