Skip to content

Commit

Permalink
fix[rust]: cse don't see only a single side of a join trail as equal (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 24, 2022
1 parent b2265a7 commit 4085edd
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 35 deletions.
42 changes: 7 additions & 35 deletions polars/polars-lazy/src/logical_plan/optimizer/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,42 +226,14 @@ fn lp_node_equal(a: &ALogicalPlan, b: &ALogicalPlan, expr_arena: &Arena<AExpr>)
}
#[cfg(feature = "python")]
(PythonScan { options: l }, PythonScan { options: r, .. }) => l == r,
(
Join {
left_on: left_on_l,
right_on: right_on_l,
options: options_l,
..
},
Join {
left_on: left_on_r,
right_on: right_on_r,
options: options_r,
..
},
) => {
// the inputs should be checked by previous nodes
// as we iterate from leafs to roots
expr_nodes_equal(left_on_l, left_on_r, expr_arena)
&& expr_nodes_equal(right_on_l, right_on_r, expr_arena)
&& options_l == options_r
_ => {
// joins and unions are also false
// they do not originate from a single trail
// so we would need to follow every leaf that
// is below the joining/union root
// that gets complicated quick
false
}
(
Union {
inputs: l,
options: options_l,
},
Union {
inputs: r,
options: options_r,
},
) => {
// the inputs should be checked by previous nodes
// as we iterate from leafs to roots
options_l == options_r && l.len() == r.len()
}

_ => false,
}
}

Expand Down
58 changes: 58 additions & 0 deletions polars/polars-lazy/src/tests/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,61 @@ fn test_cse_union2_4925() -> PolarsResult<()> {

Ok(())
}

#[test]
fn test_cse_joins_4954() -> PolarsResult<()> {
let x = df![
"a"=> [1],
"b"=> [1],
"c"=> [1],
]?
.lazy();

let y = df![
"a"=> [1],
"b"=> [1],
]?
.lazy();

let z = df![
"a"=> [1],
]?
.lazy();

let a = x.left_join(z.clone(), col("a"), col("a"));
let b = y.left_join(z.clone(), col("a"), col("a"));
let c = a.join(
b,
&[col("a"), col("b")],
&[col("a"), col("b")],
JoinType::Left,
);

let (mut expr_arena, mut lp_arena) = get_arenas();
let lp = c.optimize(&mut lp_arena, &mut expr_arena).unwrap();

// ensure we get only one cache and the it is not above the join
// and ensure that every cache only has 1 hit.
let cache_ids = (&lp_arena)
.iter(lp)
.flat_map(|(_, lp)| {
use ALogicalPlan::*;
match lp {
Cache { id, count, input } => {
assert_eq!(*count, 1);
assert!(matches!(
lp_arena.get(*input),
ALogicalPlan::DataFrameScan { .. }
));

Some(*id)
}
_ => None,
}
})
.collect::<BTreeSet<_>>();

assert_eq!(cache_ids.len(), 1);

Ok(())
}

0 comments on commit 4085edd

Please sign in to comment.