Skip to content

Commit

Permalink
fix[rust]: inform projection optimizer of inline join expressions (#4725
Browse files Browse the repository at this point in the history
)
  • Loading branch information
ritchie46 committed Sep 4, 2022
1 parent beef7f0 commit a56ea7c
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 13 deletions.
14 changes: 13 additions & 1 deletion polars/polars-lazy/src/logical_plan/alp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,19 @@ impl<'a> ALogicalPlanBuilder<'a> {
})
.collect::<Vec<_>>();

let schema = det_join_schema(&schema_left, &schema_right, &right_names, &options);
let left_on_exprs = left_on
.iter()
.map(|node| node_to_expr(*node, self.expr_arena))
.collect::<Vec<_>>();

let schema = det_join_schema(
&schema_left,
&schema_right,
&left_on_exprs,
&right_names,
&options,
)
.unwrap();

let lp = ALogicalPlan::Join {
input_left: self.root,
Expand Down
12 changes: 11 additions & 1 deletion polars/polars-lazy/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,17 @@ impl LogicalPlanBuilder {
into
);

let schema = det_join_schema(&schema_left, &schema_right, &right_names, &options);
let schema = try_delayed!(
det_join_schema(
&schema_left,
&schema_right,
&left_on,
&right_names,
&options
),
self.0,
into
);

LogicalPlan::Join {
input_left: Box::new(self.0),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use crate::prelude::iterator::ArenaExprIter;
use crate::prelude::*;
use crate::utils::{
aexpr_assign_renamed_root, aexpr_to_root_names, aexpr_to_root_nodes, check_input_node,
has_aexpr,
};

fn init_vec() -> Vec<Node> {
Expand Down Expand Up @@ -311,7 +310,7 @@ impl ProjectionPushDown {
// only aliases should be projected locally in the rest of the projections.
} else {
for expr in expr {
if has_aexpr(expr, expr_arena, |e| matches!(e, AExpr::Alias(_, _))) {
if has_aexpr_alias(expr, expr_arena) {
local_projection.push(expr)
}
}
Expand Down Expand Up @@ -768,8 +767,32 @@ impl ProjectionPushDown {
left_side: bool,
) {
add_expr_to_accumulated(expr, acc_projections, projected_names, expr_arena);
if left_side && !local_projection.contains(&expr) {
local_projection.push(expr)
// the projections may do more than simply project.
// e.g. col("foo").truncate().alias("bar")
// that means we don't want to execute the projection as that is already done by
// the JOIN executor
// we only want to add the `col` and the `alias` as two `col()` expressions.
if left_side {
for node in aexpr_to_root_nodes(expr, expr_arena) {
if !local_projection.contains(&node) {
local_projection.push(node)
}
}
if let AExpr::Alias(_, alias_name) = expr_arena.get(expr) {
let mut add = true;
for node in local_projection.as_slice() {
if let AExpr::Column(col_name) = expr_arena.get(*node) {
if alias_name == col_name {
add = false;
break;
}
}
}
if add {
let node = expr_arena.add(AExpr::Column(alias_name.clone()));
local_projection.push(node);
}
}
}
}

Expand Down
14 changes: 12 additions & 2 deletions polars/polars-lazy/src/logical_plan/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ use crate::prelude::*;
pub(crate) fn det_join_schema(
schema_left: &Schema,
schema_right: &Schema,
left_on: &[Expr],
right_on: &[String],
options: &JoinOptions,
) -> SchemaRef {
) -> Result<SchemaRef> {
// column names of left table
let mut names: PlHashSet<&str> =
PlHashSet::with_capacity(schema_left.len() + schema_right.len());
Expand All @@ -18,6 +19,15 @@ pub(crate) fn det_join_schema(
new_schema.with_column(name.to_string(), dtype.clone())
}

// make sure that expression are assigned to the schema
// an expression can have an alias, and change a dtype.
// we only do this for the left hand side as the right hand side
// is dropped.
for e in left_on {
let field = e.to_field(schema_left, Context::Default)?;
new_schema.with_column(field.name, field.dtype)
}

let right_names: PlHashSet<_> = right_on.iter().map(|s| s.as_str()).collect();

for (name, dtype) in schema_right.iter() {
Expand Down Expand Up @@ -45,5 +55,5 @@ pub(crate) fn det_join_schema(
}
}

Arc::new(new_schema)
Ok(Arc::new(new_schema))
}
8 changes: 4 additions & 4 deletions polars/polars-lazy/src/physical_plan/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ impl PhysicalPlanner {
..
} => {
let input_schema = lp_arena.get(input).schema(lp_arena).into_owned();
let has_windows = expr.iter().any(|node| has_window_aexpr(*node, expr_arena));
let has_windows = expr.iter().any(|node| has_aexpr_window(*node, expr_arena));
let input = self.create_physical_plan(input, lp_arena, expr_arena)?;
let phys_expr =
self.create_physical_expressions(&expr, Context::Default, expr_arena)?;
Expand All @@ -185,7 +185,7 @@ impl PhysicalPlanner {
} => {
let input_schema = lp_arena.get(input).schema(lp_arena).into_owned();

let has_windows = expr.iter().any(|node| has_window_aexpr(*node, expr_arena));
let has_windows = expr.iter().any(|node| has_aexpr_window(*node, expr_arena));
let input = self.create_physical_plan(input, lp_arena, expr_arena)?;
let phys_expr =
self.create_physical_expressions(&expr, Context::Default, expr_arena)?;
Expand All @@ -207,7 +207,7 @@ impl PhysicalPlanner {
let has_windows = if let Some(projection) = &projection {
projection
.iter()
.any(|node| has_window_aexpr(*node, expr_arena))
.any(|node| has_aexpr_window(*node, expr_arena))
} else {
false
};
Expand Down Expand Up @@ -493,7 +493,7 @@ impl PhysicalPlanner {
}
HStack { input, exprs, .. } => {
let input_schema = lp_arena.get(input).schema(lp_arena).into_owned();
let has_windows = exprs.iter().any(|node| has_window_aexpr(*node, expr_arena));
let has_windows = exprs.iter().any(|node| has_aexpr_window(*node, expr_arena));
let input = self.create_physical_plan(input, lp_arena, expr_arena)?;
let phys_expr =
self.create_physical_expressions(&exprs, Context::Default, expr_arena)?;
Expand Down
6 changes: 5 additions & 1 deletion polars/polars-lazy/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ where
arena.iter(current_node).any(|(_node, e)| matches(e))
}

pub(crate) fn has_window_aexpr(current_node: Node, arena: &Arena<AExpr>) -> bool {
pub(crate) fn has_aexpr_alias(current_node: Node, arena: &Arena<AExpr>) -> bool {
has_aexpr(current_node, arena, |e| matches!(e, AExpr::Alias(_, _)))
}

pub(crate) fn has_aexpr_window(current_node: Node, arena: &Arena<AExpr>) -> bool {
has_aexpr(current_node, arena, |e| matches!(e, AExpr::Window { .. }))
}

Expand Down
58 changes: 58 additions & 0 deletions py-polars/tests/unit/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,61 @@ def test_join_chunks_alignment_4720() -> None:
"index2": [10, 10, 11, 11],
"index3": [100, 101, 100, 101],
}


def test_join_inline_alias_4694() -> None:
df = pl.DataFrame(
[
{"ts": datetime(2021, 2, 1, 9, 20), "a1": 1.04, "a2": 0.9},
{"ts": datetime(2021, 2, 1, 9, 50), "a1": 1.04, "a2": 0.9},
{"ts": datetime(2021, 2, 2, 10, 20), "a1": 1.04, "a2": 0.9},
{"ts": datetime(2021, 2, 2, 11, 20), "a1": 1.08, "a2": 0.9},
{"ts": datetime(2021, 2, 3, 11, 50), "a1": 1.08, "a2": 0.9},
{"ts": datetime(2021, 2, 3, 13, 20), "a1": 1.16, "a2": 0.8},
{"ts": datetime(2021, 2, 4, 13, 50), "a1": 1.18, "a2": 0.8},
]
).lazy()

join_against = pl.DataFrame(
[
{"d": datetime(2021, 2, 3, 0, 0), "ets": datetime(2021, 2, 4, 0, 0)},
{"d": datetime(2021, 2, 3, 0, 0), "ets": datetime(2021, 2, 5, 0, 0)},
{"d": datetime(2021, 2, 3, 0, 0), "ets": datetime(2021, 2, 6, 0, 0)},
]
).lazy()

# this adds "dd" column to the lhs followed by a projection
# the projection optimizer must realize that this column is added inline and ensure
# it is not dropped.
assert df.join(
join_against,
left_on=pl.col("ts").dt.truncate("1d").alias("dd"),
right_on=pl.col("d"),
).select(pl.all()).collect().to_dict(False) == {
"ts": [
datetime(2021, 2, 3, 11, 50),
datetime(2021, 2, 3, 11, 50),
datetime(2021, 2, 3, 11, 50),
datetime(2021, 2, 3, 13, 20),
datetime(2021, 2, 3, 13, 20),
datetime(2021, 2, 3, 13, 20),
],
"a1": [1.08, 1.08, 1.08, 1.16, 1.16, 1.16],
"a2": [0.9, 0.9, 0.9, 0.8, 0.8, 0.8],
"dd": [
datetime(2021, 2, 3, 0, 0),
datetime(2021, 2, 3, 0, 0),
datetime(2021, 2, 3, 0, 0),
datetime(2021, 2, 3, 0, 0),
datetime(2021, 2, 3, 0, 0),
datetime(2021, 2, 3, 0, 0),
],
"ets": [
datetime(2021, 2, 4, 0, 0),
datetime(2021, 2, 5, 0, 0),
datetime(2021, 2, 6, 0, 0),
datetime(2021, 2, 4, 0, 0),
datetime(2021, 2, 5, 0, 0),
datetime(2021, 2, 6, 0, 0),
],
}

0 comments on commit a56ea7c

Please sign in to comment.