Skip to content

Commit

Permalink
fix predicate pushdown for predicates that do aggregations (#3396)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 13, 2022
1 parent 20b0ef3 commit 13e9fc8
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl PredicatePushDown {
// we should not pass these projections
if exprs
.iter()
.any(|e_n| other_column_is_pushdown_boundary(*e_n, expr_arena))
.any(|e_n| project_other_column_is_predicate_pushdown_boundary(*e_n, expr_arena))
{
return self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena);
}
Expand Down Expand Up @@ -165,16 +165,7 @@ impl PredicatePushDown {

// If a predicate's result would be influenced by an earlier applied filter,
// we remove it and apply it locally.
let mut apply_local = vec![];
let mut apply_local_nodes = vec![];
for (k, previous) in &acc_predicates {
if predicate_is_pushdown_boundary(*previous, expr_arena) {
apply_local.push(k.clone())
}
}
for name in &apply_local {
apply_local_nodes.push(acc_predicates.remove(name).unwrap());
}
let local_predicates = transfer_to_local_by_node(&mut acc_predicates, |node| predicate_is_pushdown_boundary(node, expr_arena));

let name = roots_to_key(&aexpr_to_root_names(predicate, expr_arena));
insert_and_combine_predicate(&mut acc_predicates, name, predicate, expr_arena);
Expand All @@ -186,18 +177,7 @@ impl PredicatePushDown {
// predicates, we simply don't push this one down past this node.
// However, we can do better and let it pass but store the order of the predicates
// so that we can apply them in correct order at the deepest level
if !apply_local_nodes.is_empty() {
let predicate = combine_predicates(apply_local_nodes.into_iter(), expr_arena);
let input = lp_arena.add(new_input);

Ok(Selection {
predicate,
input
})

} else {
Ok(new_input)
}
Ok(self.optional_apply_predicate(new_input, local_predicates, lp_arena, expr_arena))
}
DataFrameScan {
df,
Expand Down Expand Up @@ -231,7 +211,7 @@ impl PredicatePushDown {
|| args.value_vars.iter().any(|s| s.as_str() == name)
};
let local_predicates =
transfer_to_local(expr_arena, &mut acc_predicates, condition);
transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition);

self.pushdown_and_assign(input, acc_predicates, lp_arena, expr_arena)?;

Expand Down Expand Up @@ -324,8 +304,9 @@ impl PredicatePushDown {
}
Explode { input, columns, schema } => {
let condition = |name: Arc<str>| columns.iter().any(|s| s.as_str() == &*name);
let local_predicates =
transfer_to_local(expr_arena, &mut acc_predicates, condition);
let mut local_predicates =
transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition);
local_predicates.extend_from_slice(&transfer_to_local_by_node(&mut acc_predicates, |node| predicate_is_pushdown_boundary(node, expr_arena)));

self.pushdown_and_assign(input, acc_predicates, lp_arena, expr_arena)?;
let lp = Explode { input, columns, schema };
Expand All @@ -351,8 +332,9 @@ impl PredicatePushDown {
true
}
};
let local_predicates =
transfer_to_local(expr_arena, &mut acc_predicates, condition);
let mut local_predicates =
transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition);
local_predicates.extend_from_slice(&transfer_to_local_by_node(&mut acc_predicates, |node| predicate_is_pushdown_boundary(node, expr_arena)));

self.pushdown_and_assign(input, acc_predicates, lp_arena, expr_arena)?;
let lp = Distinct {
Expand Down Expand Up @@ -385,34 +367,41 @@ impl PredicatePushDown {
|e: &AExpr| matches!(e, AExpr::IsNull(_) | AExpr::IsNotNull(_));
if has_aexpr(predicate, expr_arena, matches)
// join might create null values.
|| has_aexpr(predicate, expr_arena, checks_nulls) && matches!(&options.how, JoinType::Left | JoinType::Outer | JoinType::Cross){
|| has_aexpr(predicate, expr_arena, checks_nulls)
// only these join types produce null values
&& matches!(&options.how, JoinType::Left | JoinType::Outer | JoinType::Cross){
local_predicates.push(predicate);
continue;
}
// these indicate to which tables we are going to push down the predicate
let mut filter_left = false;
let mut filter_right = false;

// no else if. predicate can be in both tables.
if check_input_node(predicate, schema_left, expr_arena) {
let name = get_insertion_name(expr_arena, predicate, schema_left);
insert_and_combine_predicate(
&mut pushdown_left,
name,
predicate,
expr_arena,
);
filter_left = true;
}
if check_input_node(predicate, schema_right, expr_arena) {
let name = get_insertion_name(expr_arena, predicate, schema_right);
insert_and_combine_predicate(
&mut pushdown_right,
name,
predicate,
expr_arena,
);
filter_right = true;
// predicate should not have an aggregation or window function as that would
// be influenced by join
if !predicate_is_pushdown_boundary(predicate, expr_arena) {
// no else if. predicate can be in both tables.
if check_input_node(predicate, schema_left, expr_arena) {
let name = get_insertion_name(expr_arena, predicate, schema_left);
insert_and_combine_predicate(
&mut pushdown_left,
name,
predicate,
expr_arena,
);
filter_left = true;
}

if check_input_node(predicate, schema_right, expr_arena) {
let name = get_insertion_name(expr_arena, predicate, schema_right);
insert_and_combine_predicate(
&mut pushdown_right,
name,
predicate,
expr_arena,
);
filter_right = true;
}
}
match (filter_left, filter_right, &options.how) {
// if not pushed down on one of the tables we have to do it locally.
Expand All @@ -429,15 +418,6 @@ impl PredicatePushDown {
// business as usual
_ => {}
}
// An outer join or left join may create null values.
// we also do it local
let matches = |e: &AExpr| matches!(e, AExpr::IsNotNull(_) | AExpr::IsNull(_));
if (options.how == JoinType::Outer) | (options.how == JoinType::Left)
&& has_aexpr(predicate, expr_arena, matches)
{
local_predicates.push(predicate);
continue;
}
}

self.pushdown_and_assign(input_left, pushdown_left, lp_arena, expr_arena)?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ pub(super) fn predicate_is_pushdown_boundary(node: Node, expr_arena: &Arena<AExp
/// predicate pushdown blocker.
///
/// This checks the boundary of other columns
pub(super) fn other_column_is_pushdown_boundary(node: Node, expr_arena: &Arena<AExpr>) -> bool {
pub(super) fn project_other_column_is_predicate_pushdown_boundary(
node: Node,
expr_arena: &Arena<AExpr>,
) -> bool {
let matches = |e: &AExpr| {
matches!(
e,
Expand Down Expand Up @@ -164,7 +167,10 @@ pub(super) fn other_column_is_pushdown_boundary(node: Node, expr_arena: &Arena<A
}

/// This checks the boundary of the same columns, i.e. the columns that are referred to in the predicate.
pub(super) fn predicate_column_is_pushdown_boundary(node: Node, expr_arena: &Arena<AExpr>) -> bool {
pub(super) fn projection_column_is_predicate_pushdown_boundary(
node: Node,
expr_arena: &Arena<AExpr>,
) -> bool {
let matches = |e: &AExpr| {
matches!(
e,
Expand Down Expand Up @@ -212,7 +218,7 @@ where
for projection_node in &projections {
// only if a predicate refers to this projection's output column.
let projection_maybe_boundary =
predicate_column_is_pushdown_boundary(*projection_node, expr_arena);
projection_column_is_predicate_pushdown_boundary(*projection_node, expr_arena);

let projection_expr = expr_arena.get(*projection_node);
let output_field = projection_expr
Expand Down Expand Up @@ -253,7 +259,8 @@ where

// we check if predicates can be done on the input above
// this can only be done if the current projection is not a projection boundary
let is_boundary = other_column_is_pushdown_boundary(*projection_node, expr_arena);
let is_boundary =
project_other_column_is_predicate_pushdown_boundary(*projection_node, expr_arena);

// remove predicates that cannot be done on the input above
let to_local = acc_predicates
Expand Down Expand Up @@ -314,13 +321,13 @@ pub(super) fn no_pushdown_preds<F>(
let columns = aexpr_to_root_names(node, arena);

let condition = |name: Arc<str>| columns.contains(&name);
local_predicates.extend(transfer_to_local(arena, acc_predicates, condition));
local_predicates.extend(transfer_to_local_by_name(arena, acc_predicates, condition));
}
}

/// Transfer a predicate from `acc_predicates` that will be pushed down
/// to a local_predicates vec based on a condition.
pub(super) fn transfer_to_local<F>(
pub(super) fn transfer_to_local_by_name<F>(
expr_arena: &Arena<AExpr>,
acc_predicates: &mut PlHashMap<Arc<str>, Node>,
mut condition: F,
Expand All @@ -347,3 +354,29 @@ where
}
local_predicates
}

/// Move every predicate whose expression node satisfies `condition` out of
/// `acc_predicates` and return the moved nodes.
///
/// Unlike `transfer_to_local_by_name`, the decision is made on the predicate
/// `Node` itself rather than on the key it is stored under. Predicates for
/// which `condition` returns `false` stay in `acc_predicates`.
pub(super) fn transfer_to_local_by_node<F>(
    acc_predicates: &mut PlHashMap<Arc<str>, Node>,
    mut condition: F,
) -> Vec<Node>
where
    F: FnMut(Node) -> bool,
{
    // Gather the matching keys first: entries cannot be removed from the map
    // while it is still being iterated.
    let selected: Vec<Arc<str>> = acc_predicates
        .iter()
        .filter(|(_, node)| condition(**node))
        .map(|(name, _)| Arc::clone(name))
        .collect();

    // Pull the selected entries out of the map, keeping only their nodes.
    selected
        .into_iter()
        .filter_map(|name| acc_predicates.remove(&*name))
        .collect()
}
22 changes: 22 additions & 0 deletions polars/tests/it/lazy/predicate_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,25 @@ fn test_filter_no_combine() -> Result<()> {

Ok(())
}

/// A predicate containing an aggregation must not be pushed below a join:
/// the aggregate has to be computed on the joined result.
#[test]
fn test_filter_block_join() -> Result<()> {
    let left = df![
        "a" => ["a", "b", "c"],
        "c" => [1, 4, 6]
    ]?;
    let right = df![
        "a" => ["a", "a", "c"],
        "d" => [2, 4, 3]
    ]?;

    // The mean of "c" is influenced by the join (the left join duplicates the
    // row with key "a"), so this filter has to be evaluated after the join.
    let predicate = col("c").mean().eq(col("d"));

    let joined = left.lazy().left_join(right.lazy(), "a", "a");
    let out = joined.filter(predicate).collect()?;

    // Exactly one row of the three-column join output satisfies the predicate.
    assert_eq!(out.shape(), (1, 3));

    Ok(())
}

0 comments on commit 13e9fc8

Please sign in to comment.