Update optimize_projections.rs
berkaysynnada committed May 3, 2024
1 parent 5dd4289 commit 8bdd12f
Showing 2 changed files with 74 additions and 36 deletions.
90 changes: 64 additions & 26 deletions datafusion/core/src/physical_optimizer/optimize_projections.rs
@@ -699,7 +699,7 @@ impl ProjectionOptimizer {
         if all_columns_required(&requirement_map) {
             let required_columns = mem::take(&mut self.required_columns);
             for c in self.children_nodes.iter_mut() {
-                c.required_columns = required_columns.clone();
+                c.required_columns.clone_from(&required_columns);
             }
         } else {
             let plan = self.plan.clone();
@@ -3058,7 +3058,7 @@ impl PhysicalOptimizerRule for OptimizeProjections {
         let mut optimizer = ProjectionOptimizer::new_default(final_schema_determinant);

         // Insert the initial requirements to the root node, and run the rule.
-        optimizer.required_columns = initial_requirements.clone();
+        optimizer.required_columns.clone_from(&initial_requirements);
         let mut optimized = optimizer.transform_down(|o: ProjectionOptimizer| {
             o.adjust_node_with_requirements().map(Transformed::yes)
         })?;
@@ -3906,7 +3906,9 @@ fn update_hj_left_child(
     };

     let mut right_node = children.swap_remove(0);
-    right_node.required_columns = hj_right_requirements.clone();
+    right_node
+        .required_columns
+        .clone_from(hj_right_requirements);

     Ok((new_left_node, right_node))
 }
@@ -3941,7 +3943,7 @@ fn update_hj_right_child(
     };

     let mut left_node = children.swap_remove(0);
-    left_node.required_columns = hj_left_requirements.clone();
+    left_node.required_columns.clone_from(hj_left_requirements);

     Ok((left_node, new_right_node))
 }
@@ -4868,6 +4870,7 @@ fn expr_mapping(schema_mapping: IndexMap<Column, Column>) -> ExprMapping {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use std::any::Any;
     use std::sync::Arc;

     use crate::datasource::file_format::file_compression_type::FileCompressionType;
@@ -4886,9 +4889,14 @@ mod tests {

     use arrow_schema::{DataType, Field, Schema, SortOptions};
     use datafusion_common::config::ConfigOptions;
-    use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics};
+    use datafusion_common::{
+        plan_err, JoinSide, JoinType, Result, ScalarValue, Statistics,
+    };
     use datafusion_execution::object_store::ObjectStoreUrl;
-    use datafusion_expr::{Operator, ScalarFunctionDefinition, WindowFrame};
+    use datafusion_expr::{
+        ColumnarValue, Operator, ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl,
+        Signature, Volatility, WindowFrame,
+    };
     use datafusion_physical_expr::expressions::{
         rank, BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, RowNumber,
         Sum,
@@ -4903,6 +4911,44 @@ mod tests {
     use datafusion_physical_plan::union::UnionExec;
     use datafusion_physical_plan::InputOrderMode;

+    #[derive(Debug)]
+    struct AddOne {
+        signature: Signature,
+    }
+
+    impl AddOne {
+        fn new() -> Self {
+            Self {
+                signature: Signature::uniform(
+                    1,
+                    vec![DataType::Int32],
+                    Volatility::Immutable,
+                ),
+            }
+        }
+    }
+
+    impl ScalarUDFImpl for AddOne {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+        fn name(&self) -> &str {
+            "add_one"
+        }
+        fn signature(&self) -> &Signature {
+            &self.signature
+        }
+        fn return_type(&self, args: &[DataType]) -> Result<DataType> {
+            if !matches!(args.first(), Some(&DataType::Int32)) {
+                return plan_err!("add_one only accepts Int32 arguments");
+            }
+            Ok(DataType::Int32)
+        }
+        fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
+            unimplemented!()
+        }
+    }
+
     fn create_simple_csv_exec() -> Arc<dyn ExecutionPlan> {
         let schema = Arc::new(Schema::new(vec![
             Field::new("a", DataType::Int32, true),
@@ -4975,9 +5021,7 @@ mod tests {
             Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
             Arc::new(ScalarFunctionExpr::new(
                 "scalar_expr",
-                ScalarFunctionDefinition::Name(
-                    "dummy".to_owned().into_boxed_str().into(),
-                ),
+                ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::from(AddOne::new()))),
                 vec![
                     Arc::new(BinaryExpr::new(
                         Arc::new(Column::new("b", 1)),
@@ -5043,9 +5087,7 @@ mod tests {
             Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))),
             Arc::new(ScalarFunctionExpr::new(
                 "scalar_expr",
-                ScalarFunctionDefinition::Name(
-                    "dummy".to_owned().into_boxed_str().into(),
-                ),
+                ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::from(AddOne::new()))),
                 vec![
                     Arc::new(BinaryExpr::new(
                         Arc::new(Column::new("b", 1)),
@@ -5114,9 +5156,7 @@ mod tests {
             Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
             Arc::new(ScalarFunctionExpr::new(
                 "scalar_expr",
-                ScalarFunctionDefinition::Name(
-                    "dummy".to_owned().into_boxed_str().into(),
-                ),
+                ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::from(AddOne::new()))),
                 vec![
                     Arc::new(BinaryExpr::new(
                         Arc::new(Column::new("b", 1)),
@@ -5182,9 +5222,7 @@ mod tests {
             Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))),
             Arc::new(ScalarFunctionExpr::new(
                 "scalar_expr",
-                ScalarFunctionDefinition::Name(
-                    "dummy".to_owned().into_boxed_str().into(),
-                ),
+                ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::from(AddOne::new()))),
                 vec![
                     Arc::new(BinaryExpr::new(
                         Arc::new(Column::new("b_new", 1)),
@@ -5782,7 +5820,7 @@ mod tests {
         let initial = get_plan_string(&projection);
         let expected_initial = [
             "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]",
-            "  SortExec: expr=[b@1 ASC,c@2 + a@0 ASC]",
+            "  SortExec: expr=[b@1 ASC,c@2 + a@0 ASC], preserve_partitioning=[false]",
             "    CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false"
         ];
         assert_eq!(initial, expected_initial);
@@ -5792,7 +5830,7 @@ mod tests {

         let expected = [
             "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]",
-            "  SortExec: expr=[b@1 ASC,c@2 + a@0 ASC]",
+            "  SortExec: expr=[b@1 ASC,c@2 + a@0 ASC], preserve_partitioning=[false]",
             "    CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c], has_header=false"
         ];
         assert_eq!(get_plan_string(&after_optimize), expected);
@@ -5962,7 +6000,7 @@ mod tests {
             "FilterExec: sum@0 > 0",
             "  ProjectionExec: expr=[c@2 + x@0 as sum]",
             "    ProjectionExec: expr=[x@2 as x, x@0 as x, c@1 as c]",
-            "      SortExec: expr=[c@1 ASC,x@2 ASC]",
+            "      SortExec: expr=[c@1 ASC,x@2 ASC], preserve_partitioning=[false]",
             "        ProjectionExec: expr=[x@1 as x, c@0 as c, a@2 as x]",
             "          ProjectionExec: expr=[c@2 as c, e@4 as x, a@0 as a]",
             "            CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false"];
@@ -5975,7 +6013,7 @@ mod tests {
         let expected = [
             "FilterExec: sum@0 > 0",
             "  ProjectionExec: expr=[c@0 + x@1 as sum]",
-            "    SortExec: expr=[c@0 ASC,x@1 ASC]",
+            "    SortExec: expr=[c@0 ASC,x@1 ASC], preserve_partitioning=[false]",
             "      ProjectionExec: expr=[c@1 as c, a@0 as x]",
             "        CsvExec: file_groups={1 group: [[x]]}, projection=[a, c], has_header=false"];

@@ -6091,7 +6129,7 @@ mod tests {
         let expected_initial = [
             "AggregateExec: mode=Single, gby=[a@2 as a], aggr=[SUM(ROW_NUMBER())]",
             "  HashJoinExec: mode=Auto, join_type=LeftAnti, on=[(a@1, b@1)], filter=ROW_NUMBER()@0 < RANK@1, projection=[a@1, ROW_NUMBER()@0, a@1]",
-            "    SortExec: expr=[ROW_NUMBER()@4 ASC,d@2 ASC]",
+            "    SortExec: expr=[ROW_NUMBER()@4 ASC,d@2 ASC], preserve_partitioning=[false]",
             "      ProjectionExec: expr=[ROW_NUMBER()@5 as ROW_NUMBER(), a@0 as a, d@3 as d, d@3 as d, ROW_NUMBER()@5 as ROW_NUMBER()]",
             "        WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]",
             "          CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false",
@@ -6107,7 +6145,7 @@ mod tests {
             "AggregateExec: mode=Single, gby=[a@2 as a], aggr=[SUM(ROW_NUMBER())]",
             "  HashJoinExec: mode=Auto, join_type=LeftAnti, on=[(a@1, b@0)], filter=ROW_NUMBER()@0 < RANK@1, projection=[a@1, ROW_NUMBER()@0, a@1]",
             "    ProjectionExec: expr=[ROW_NUMBER()@0 as ROW_NUMBER(), a@1 as a, ROW_NUMBER()@3 as ROW_NUMBER()]",
-            "      SortExec: expr=[ROW_NUMBER()@3 ASC,d@2 ASC]",
+            "      SortExec: expr=[ROW_NUMBER()@3 ASC,d@2 ASC], preserve_partitioning=[false]",
             "        ProjectionExec: expr=[ROW_NUMBER()@5 as ROW_NUMBER(), a@0 as a, d@3 as d, ROW_NUMBER()@5 as ROW_NUMBER()]",
             "          WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]",
             "            CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false",
@@ -6228,7 +6266,7 @@ mod tests {
         let expected_initial = [
             "AggregateExec: mode=Single, gby=[b@2 as b], aggr=[SUM(a)]",
             "  HashJoinExec: mode=Auto, join_type=RightSemi, on=[(a@1, b@1)], filter=ROW_NUMBER()@0 < RANK@1, projection=[b@1, a@0, b@1]",
-            "    SortExec: expr=[ROW_NUMBER()@4 ASC,d@2 ASC]",
+            "    SortExec: expr=[ROW_NUMBER()@4 ASC,d@2 ASC], preserve_partitioning=[false]",
             "      ProjectionExec: expr=[ROW_NUMBER()@5 as ROW_NUMBER(), a@0 as a, d@3 as d, d@3 as d, ROW_NUMBER()@5 as ROW_NUMBER()]",
             "        WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]",
             "          CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false",
@@ -6244,7 +6282,7 @@ mod tests {
             "AggregateExec: mode=Single, gby=[b@2 as b], aggr=[SUM(a)]",
             "  HashJoinExec: mode=Auto, join_type=RightSemi, on=[(a@0, b@1)], filter=ROW_NUMBER()@0 < RANK@1, projection=[b@1, a@0, b@1]",
             "    ProjectionExec: expr=[a@0 as a, ROW_NUMBER()@2 as ROW_NUMBER()]",
-            "      SortExec: expr=[ROW_NUMBER()@2 ASC,d@1 ASC]",
+            "      SortExec: expr=[ROW_NUMBER()@2 ASC,d@1 ASC], preserve_partitioning=[false]",
             "        ProjectionExec: expr=[a@0 as a, d@3 as d, ROW_NUMBER()@5 as ROW_NUMBER()]",
             "          WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]",
             "            CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false",
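The `required_columns` changes above all apply one idiom: `x = y.clone()` becomes `x.clone_from(&y)`. `Clone::clone_from` defaults to `*self = source.clone()`, but collection types such as `Vec` and `String` override it to copy into the destination's existing allocation instead of dropping it and allocating a fresh clone. A minimal standalone sketch of the difference, using only the standard library (`source` and `dest` are illustrative names, not from this commit):

    // clone_from reuses dest's buffer where possible; assigning a fresh
    // clone would drop that buffer and allocate a new one.
    fn main() {
        let source = vec![1, 2, 3, 4];

        // A pre-sized destination, e.g. a buffer reused across iterations.
        let mut dest: Vec<i32> = Vec::with_capacity(1024);

        // Copies the four elements into dest's existing allocation.
        dest.clone_from(&source);

        assert_eq!(dest, source);
        assert!(dest.capacity() >= 1024); // the original allocation survives
    }

The hash-join helpers take their requirement sets by reference, which is why those call sites pass `hj_left_requirements` / `hj_right_requirements` directly rather than through `&`.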