Skip to content

Commit

Permalink
fix: Only push predicates depending on the subset columns past `unique()` (#14668)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Feb 25, 2024
1 parent b13a7f5 commit f52473a
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 22 deletions.
6 changes: 5 additions & 1 deletion crates/polars-plan/src/logical_plan/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,11 @@ impl LogicalPlan {
input._format(f, sub_indent)
},
Distinct { input, options } => {
write!(f, "{:indent$}UNIQUE BY {:?}", "", options.subset)?;
write!(
f,
"{:indent$}UNIQUE[maintain_order: {:?}, keep_strategy: {:?}] BY {:?}",
"", options.maintain_order, options.keep_strategy, options.subset
)?;
input._format(f, sub_indent)
},
Slice { input, offset, len } => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,24 +403,15 @@ impl<'a> PredicatePushDown<'a> {
input,
options
} => {

if matches!(options.keep_strategy, UniqueKeepStrategy::Any | UniqueKeepStrategy::None) {
// currently the distinct operation only keeps the first occurrences.
// this may have influence on the pushed down predicates. If the pushed down predicates
// contain a binary expression (thus depending on values in multiple columns)
// the final result may differ if it is pushed down.

let mut root_count = 0;

// If this condition is called more than once, it's a binary or ternary operation.
let condition = |_| {
if root_count == 0 {
root_count += 1;
false
} else {
true
}
if let Some(ref subset) = options.subset {
// Predicates on the subset can pass.
let subset = subset.clone();
let mut names_set = PlHashSet::<&str>::with_capacity(subset.len());
for name in subset.iter() {
names_set.insert(name.as_str());
};

let condition = |name: Arc<str>| !names_set.contains(name.as_ref());
let local_predicates =
transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,9 @@ pub(super) fn predicate_is_sort_boundary(node: Node, expr_arena: &Arena<AExpr>)
has_aexpr(node, expr_arena, matches)
}

/// Transfer a predicate from `acc_predicates` that will be pushed down
/// to a local_predicates vec based on a condition.
/// Evaluates `condition` on the root column names of every predicate in
/// `acc_predicates`. If the condition returns true for any of those names,
/// the predicate is transferred to local.
pub(super) fn transfer_to_local_by_name<F>(
expr_arena: &Arena<AExpr>,
acc_predicates: &mut PlHashMap<Arc<str>, Node>,
Expand All @@ -129,7 +130,7 @@ where
for name in root_names {
if condition(name) {
remove_keys.push(key.clone());
continue;
break;
}
}
}
Expand Down
8 changes: 7 additions & 1 deletion crates/polars-plan/src/logical_plan/tree_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,13 @@ impl<'a> TreeFmtNode<'a> {
.collect(),
),
NL(h, Distinct { input, options }) => ND(
wh(h, &format!("UNIQUE BY {:?}", options.subset)),
wh(
h,
&format!(
"UNIQUE[maintain_order: {:?}, keep_strategy: {:?}] BY {:?}",
options.maintain_order, options.keep_strategy, options.subset
),
),
vec![NL(None, input)],
),
NL(h, LogicalPlan::Slice { input, offset, len }) => ND(
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/operations/unique/test_unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ def test_unique_predicate_pd() -> None:
expected = pl.DataFrame({"x": ["abc"], "y": ["xxx"], "z": [True]})
assert_frame_equal(result, expected)

# Issue #14595: filter should not naively be pushed past unique()
for maintain_order in (True, False):
for keep in ("first", "last", "any", "none"):
q = (
lf.unique("x", maintain_order=maintain_order, keep=keep) # type: ignore[arg-type]
.filter(pl.col("x") == "abc")
.filter(pl.col("z"))
)
plan = q.explain()
assert r'FILTER col("z")' in plan
# We can push filters if they only depend on the subset columns of unique()
assert r'SELECTION: "[(col(\"x\")) == (String(abc))]"' in plan
assert_frame_equal(q.collect(predicate_pushdown=False), q.collect())


def test_unique_on_list_df() -> None:
assert pl.DataFrame(
Expand Down

0 comments on commit f52473a

Please sign in to comment.