Skip to content

Commit

Permalink
fix[rust]: asof-join projection pushdown of 'by' argument (#4609)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 29, 2022
1 parent 8d34e96 commit 560f87e
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,21 @@ impl ProjectionPushDown {
let mut names_right = init_set();
let mut local_projection = init_vec();

/// Registers `expr` with the accumulated pushdown state of one join input.
///
/// The node is always forwarded to `add_expr_to_accumulated` together with
/// `acc_projections`, `projected_names` and `expr_arena`. When `left_side` is
/// `true`, the node is additionally appended to `local_projection`, unless an
/// equal node is already present there.
fn add_nodes_to_accumulated_state(
    expr: Node,
    acc_projections: &mut Vec<Node>,
    local_projection: &mut Vec<Node>,
    projected_names: &mut PlHashSet<Arc<str>>,
    expr_arena: &mut Arena<AExpr>,
    // local names are only recorded for the left-hand side table
    left_side: bool,
) {
    add_expr_to_accumulated(expr, acc_projections, projected_names, expr_arena);

    // Right-hand side nodes never enter the local projection.
    if !left_side {
        return;
    }

    // Keep `local_projection` free of duplicate nodes.
    let already_local = local_projection.contains(&expr);
    if !already_local {
        local_projection.push(expr);
    }
}

// if there are no projections we don't have to do anything (all columns are projected)
// otherwise we build local projections to sort out proper column names due to the
// join operation
Expand All @@ -769,26 +784,57 @@ impl ProjectionPushDown {
let schema_left = lp_arena.get(input_left).schema(lp_arena);
let schema_right = lp_arena.get(input_right).schema(lp_arena);

// make sure that the asof join 'by' columns are projected
#[cfg(feature = "asof_join")]
if let JoinType::AsOf(asof_options) = &options.how {
if let (Some(left_by), Some(right_by)) =
(&asof_options.left_by, &asof_options.right_by)
{
for name in left_by {
let node = expr_arena.add(AExpr::Column(Arc::from(name.as_str())));
add_nodes_to_accumulated_state(
node,
&mut pushdown_left,
&mut local_projection,
&mut names_left,
expr_arena,
true,
);
}
for name in right_by {
let node = expr_arena.add(AExpr::Column(Arc::from(name.as_str())));
add_nodes_to_accumulated_state(
node,
&mut pushdown_right,
&mut local_projection,
&mut names_right,
expr_arena,
false,
);
}
}
}

// We need the join columns so we push the projection downwards
for e in &left_on {
add_expr_to_accumulated(
add_nodes_to_accumulated_state(
*e,
&mut pushdown_left,
&mut local_projection,
&mut names_left,
expr_arena,
true,
);
if !local_projection.contains(e) {
local_projection.push(*e)
}
}
for e in &right_on {
add_expr_to_accumulated(
add_nodes_to_accumulated_state(
*e,
&mut pushdown_right,
&mut local_projection,
&mut names_right,
expr_arena,
false,
);
// we don't add right column names to local_projection as they are removed
}

for proj in acc_projections {
Expand All @@ -803,7 +849,7 @@ impl ProjectionPushDown {
let proj = expr_arena.add(AExpr::Alias(node, name.clone()));
local_projection.push(proj)
}
// now we don
// now we don't
add_local = false;
}

Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,12 @@ def test_join_on_cast() -> None:
assert df_a.lazy().join(
df_b.lazy(), on=pl.col("a").cast(pl.Int64)
).collect().to_dict(False) == {"row_nr": [1, 2, 3, 5], "a": [-2, 3, 3, 10]}


def test_asof_join_projection_resolution_4606() -> None:
    """Check the output columns of a lazy asof join (with `by`) followed by an aggregation.

    Two single-row lazy frames are asof-joined on "a" with "b" as the `by` key;
    grouping the result by "a" and summing "c" must collect to exactly the
    columns ["a", "c"].
    """
    left = pl.DataFrame({"a": [1], "b": [2], "c": [3]}).lazy()
    right = pl.DataFrame({"a": [1], "b": [2], "d": [4]}).lazy()

    aggregated = (
        left.join_asof(right, on="a", by="b")
        .groupby("a")
        .agg([pl.col("c").sum().alias("c")])
        .collect()
    )

    assert aggregated.columns == ["a", "c"]

0 comments on commit 560f87e

Please sign in to comment.