Skip to content

Commit

Permalink
fix predicate pushdown in union + count expression (#3882)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 4, 2022
1 parent 30db393 commit 9b4e00c
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 7 deletions.
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/logical_plan/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ pub trait Optimize {
// arbitrary constant to reduce reallocation.
const HASHMAP_SIZE: usize = 16;

pub(crate) fn init_hashmap<K, V>() -> PlHashMap<K, V> {
PlHashMap::with_capacity(HASHMAP_SIZE)
/// Creates an empty `PlHashMap` with a pre-allocated capacity.
///
/// The capacity is never larger than `HASHMAP_SIZE`: when `max_len` is
/// `Some(len)` the smaller of `len` and `HASHMAP_SIZE` is used, and when it
/// is `None` the capacity is exactly `HASHMAP_SIZE`.
pub(crate) fn init_hashmap<K, V>(max_len: Option<usize>) -> PlHashMap<K, V> {
    let capacity = match max_len {
        Some(len) => len.min(HASHMAP_SIZE),
        None => HASHMAP_SIZE,
    };
    PlHashMap::with_capacity(capacity)
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ impl PredicatePushDown {
// First we check whether we are able to push the predicate down past this node.
// It could be that this node just added the column on which the predicate is based.
let input_schema = lp_arena.get(node).schema(lp_arena);
let mut pushdown_predicates = optimizer::init_hashmap();
let mut pushdown_predicates =
optimizer::init_hashmap(Some(acc_predicates.len()));
for (name, &predicate) in acc_predicates.iter() {
// we can pushdown the predicate
if check_input_node(predicate, input_schema, expr_arena) {
Expand Down Expand Up @@ -128,7 +129,12 @@ impl PredicatePushDown {
.iter()
.map(|&node| {
let alp = lp_arena.take(node);
let alp = self.push_down(alp, init_hashmap(), lp_arena, expr_arena)?;
let alp = self.push_down(
alp,
init_hashmap(Some(acc_predicates.len())),
lp_arena,
expr_arena,
)?;
lp_arena.replace(node, alp);
Ok(node)
})
Expand Down Expand Up @@ -386,8 +392,8 @@ impl PredicatePushDown {
let schema_left = lp_arena.get(input_left).schema(lp_arena);
let schema_right = lp_arena.get(input_right).schema(lp_arena);

let mut pushdown_left = optimizer::init_hashmap();
let mut pushdown_right = optimizer::init_hashmap();
let mut pushdown_left = optimizer::init_hashmap(Some(acc_predicates.len()));
let mut pushdown_right = optimizer::init_hashmap(Some(acc_predicates.len()));
let mut local_predicates = Vec::with_capacity(acc_predicates.len());

for (_, predicate) in acc_predicates {
Expand Down Expand Up @@ -478,8 +484,27 @@ impl PredicatePushDown {
self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena)
}
}
lp @ Union { .. } => {
    // Predicates that must stay above the union instead of being pushed
    // into its inputs.
    let mut local_predicates = vec![];

    // A count is influenced by a Union/Vstack, so any predicate that
    // contains a count expression is taken out of the accumulated
    // predicates and kept local to this node.
    acc_predicates.retain(|_, predicate| {
        if has_aexpr(*predicate, expr_arena, |ae| matches!(ae, AExpr::Count)) {
            local_predicates.push(*predicate);
            false
        } else {
            true
        }
    });

    // Everything that is left can continue down into the inputs.
    let lp = self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false)?;

    // The held-back predicates have to be applied on top of the union.
    // Returning `lp` unchanged is only correct when nothing was held back;
    // otherwise those predicates would be dropped from the plan.
    Ok(if local_predicates.is_empty() {
        lp
    } else {
        self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)
    })
}
// Predicates are pushed down past these nodes
lp @ Cache { .. } | lp @ Union { .. } | lp @ Sort { .. } => {
lp @ Cache { .. } | lp @ Sort { .. } => {
self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false)
}
lp @ HStack {..} | lp @ Projection {..} => {
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,3 +1279,21 @@ def test_deadlocks_3409() -> None:
.with_columns([pl.col("col1").cumulative_eval(pl.element().map(lambda x: 0))])
.to_dict(False)
) == {"col1": [0, 0, 0]}


def test_predicate_count_vstack() -> None:
    # A predicate that contains a count has to be evaluated on the
    # concatenated frame. Within each input on its own every key occurs only
    # once, so evaluating the filter per input would return no rows at all.
    l1 = pl.DataFrame(
        {
            "k": ["x", "y"],
            "v": [3, 2],
        }
    ).lazy()
    l2 = pl.DataFrame(
        {
            "k": ["x", "y"],
            "v": [5, 7],
        }
    ).lazy()
    assert pl.concat([l1, l2]).filter(pl.count().over("k") == 2).collect()[
        "v"
    ].to_list() == [3, 2, 5, 7]

    # The expected result above equals the unfiltered concatenation, so it
    # cannot tell an applied filter from a dropped one. Here "z" occurs only
    # once in the concatenated frame and must be removed by the filter.
    l3 = pl.DataFrame(
        {
            "k": ["x", "y", "z"],
            "v": [5, 7, 1],
        }
    ).lazy()
    assert pl.concat([l1, l3]).filter(pl.count().over("k") == 2).collect()[
        "v"
    ].to_list() == [3, 2, 5, 7]

0 comments on commit 9b4e00c

Please sign in to comment.