diff --git a/src/optimizer/logical_plan_rewriter/bool_expr_simplification.rs b/src/optimizer/logical_plan_rewriter/bool_expr_simplification.rs index 095ad991..cbf244a9 100644 --- a/src/optimizer/logical_plan_rewriter/bool_expr_simplification.rs +++ b/src/optimizer/logical_plan_rewriter/bool_expr_simplification.rs @@ -50,7 +50,7 @@ impl PlanRewriter for BoolExprSimplificationRule { let child = self.rewrite(plan.child()); let new_plan = Arc::new(plan.clone_with_rewrite_expr(child, self)); match &new_plan.expr() { - Constant(Bool(false) | Null) => Arc::new(Dummy {}), + Constant(Bool(false) | Null) => Arc::new(Dummy::new(new_plan.schema())), Constant(Bool(true)) => return plan.child().clone(), _ => new_plan, } diff --git a/src/optimizer/plan_nodes/dummy.rs b/src/optimizer/plan_nodes/dummy.rs index 54a05aae..db675799 100644 --- a/src/optimizer/plan_nodes/dummy.rs +++ b/src/optimizer/plan_nodes/dummy.rs @@ -8,10 +8,23 @@ use super::*; /// A dummy plan. #[derive(Debug, Clone, Serialize)] -pub struct Dummy {} +pub struct Dummy { + schema: Vec, +} + +impl Dummy { + pub fn new(schema: Vec) -> Self { + Self { schema } + } +} + impl PlanTreeNodeLeaf for Dummy {} impl_plan_tree_node_for_leaf!(Dummy); -impl PlanNode for Dummy {} +impl PlanNode for Dummy { + fn schema(&self) -> Vec { + self.schema.clone() + } +} impl fmt::Display for Dummy { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!(f, "Dummy:") diff --git a/src/optimizer/plan_nodes/internal.rs b/src/optimizer/plan_nodes/internal.rs index f796bff9..8c85c909 100644 --- a/src/optimizer/plan_nodes/internal.rs +++ b/src/optimizer/plan_nodes/internal.rs @@ -60,6 +60,20 @@ impl PlanNode for Internal { fn schema(&self) -> Vec { self.column_descs.clone() } + + fn prune_col(&self, required_cols: BitSet) -> PlanRef { + let (column_ids, column_descs) = required_cols + .iter() + .map(|col_idx| (self.column_ids[col_idx], self.column_descs[col_idx].clone())) + .unzip(); + Internal::new( + self.table_name.clone(), + self.table_ref_id, + column_ids, + column_descs, + ) + .into_plan_ref() + } } impl fmt::Display for Internal { diff --git a/src/optimizer/plan_nodes/logical_aggregate.rs b/src/optimizer/plan_nodes/logical_aggregate.rs index 06c9e4db..2a058f02 100644 --- a/src/optimizer/plan_nodes/logical_aggregate.rs +++ b/src/optimizer/plan_nodes/logical_aggregate.rs @@ -5,7 +5,7 @@ use std::fmt; use serde::Serialize; use super::*; -use crate::binder::{BoundAggCall, BoundExpr}; +use crate::binder::{BoundAggCall, BoundExpr, ExprVisitor}; use crate::optimizer::logical_plan_rewriter::ExprRewriter; /// The logical plan of hash aggregate operation. @@ -90,6 +90,76 @@ impl PlanNode for LogicalAggregate { fn estimated_cardinality(&self) -> usize { self.child().estimated_cardinality() } + + fn prune_col(&self, required_cols: BitSet) -> PlanRef { + let group_keys_len = self.group_keys.len(); + + // Collect ref_idx of AggCall args + let mut visitor = + CollectRequiredCols(BitSet::with_capacity(group_keys_len + self.agg_calls.len())); + let mut new_agg_calls: Vec<_> = required_cols + .iter() + .filter(|&index| index >= group_keys_len) + .map(|index| { + let call = &self.agg_calls[index - group_keys_len]; + call.args.iter().for_each(|expr| { + visitor.visit_expr(expr); + }); + self.agg_calls[index - group_keys_len].clone() + }) + .collect(); + + // Collect ref_idx of GroupExpr + self.group_keys + .iter() + .for_each(|group| visitor.visit_expr(group)); + + let input_cols = visitor.0; + + let mapper = Mapper::new_with_bitset(&input_cols); + for call in &mut new_agg_calls { + call.args.iter_mut().for_each(|expr| { + mapper.rewrite_expr(expr); + }) + } + + let mut group_keys = self.group_keys.clone(); + group_keys + .iter_mut() + .for_each(|expr| mapper.rewrite_expr(expr)); + + let new_agg = LogicalAggregate::new( + new_agg_calls.clone(), + group_keys, + self.child.prune_col(input_cols), + ); + + let bitset = BitSet::from_iter(0..group_keys_len); + + if bitset.is_subset(&required_cols) { + new_agg.into_plan_ref() + } else { + // Need prune + let mut new_projection: Vec = required_cols + .iter() + .filter(|&i| i < group_keys_len) + .map(|index| { + BoundExpr::InputRef(BoundInputRef { + index, + return_type: self.group_keys[index].return_type().unwrap(), + }) + }) + .collect(); + + for (index, item) in new_agg_calls.iter().enumerate() { + new_projection.push(BoundExpr::InputRef(BoundInputRef { + index: group_keys_len + index, + return_type: item.return_type.clone(), + })) + } + LogicalProjection::new(new_projection, new_agg.into_plan_ref()).into_plan_ref() + } + } } impl fmt::Display for LogicalAggregate { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -129,7 +199,7 @@ mod tests { }, ], vec![], - Arc::new(Dummy {}), + Arc::new(Dummy::new(Vec::new())), ); let column_names = plan.out_names(); @@ -138,4 +208,88 @@ mod tests { assert_eq!(column_names[2], "count"); assert_eq!(column_names[3], "count"); } + + #[test] + /// Pruning + /// ```text + /// Agg(gk: input_ref(2), call: sum(input_ref(0)), avg(input_ref(1))) + /// TableScan(v1, v2, v3) + /// ``` + /// with required columns [2] will result in + /// ```text + /// Projection(input_ref(1)) + /// Agg(gk: input_ref(1), call: avg(input_ref(0))) + /// TableScan(v1, v3) + /// ``` + fn test_prune_aggregate() { + let ty = DataTypeKind::Int(None).not_null(); + let col_descs = vec![ + ty.clone().to_column("v1".into()), + ty.clone().to_column("v2".into()), + ty.clone().to_column("v3".into()), + ]; + + let table_scan = LogicalTableScan::new( + crate::catalog::TableRefId { + database_id: 0, + schema_id: 0, + table_id: 0, + }, + vec![1, 2, 3], + col_descs, + false, + false, + None, + ); + + let input_refs = vec![ + BoundExpr::InputRef(BoundInputRef { + index: 0, + return_type: ty.clone(), + }), + BoundExpr::InputRef(BoundInputRef { + index: 1, + return_type: ty.clone(), + }), + BoundExpr::InputRef(BoundInputRef { + index: 2, + return_type: ty, + }), + ]; + + let aggregate = LogicalAggregate::new( + vec![ + BoundAggCall { + kind: AggKind::Sum, + args: vec![input_refs[0].clone()], + return_type: DataTypeKind::Int(None).not_null(), + }, + BoundAggCall { + kind: AggKind::Avg, + args: vec![input_refs[1].clone()], + return_type: DataTypeKind::Int(None).not_null(), + }, + ], + vec![input_refs[2].clone()], + Arc::new(table_scan), + ); + + let mut required_cols = BitSet::new(); + required_cols.insert(2); + let plan = aggregate.prune_col(required_cols); + let plan = plan.as_logical_projection().unwrap(); + + assert_eq!(plan.project_expressions(), vec![input_refs[1].clone()]); + let plan = plan.child(); + let plan = plan.as_logical_aggregate().unwrap(); + + assert_eq!( + plan.agg_calls(), + vec![BoundAggCall { + kind: AggKind::Avg, + args: vec![input_refs[0].clone()], + return_type: DataTypeKind::Int(None).not_null(), + }] + ); + } } diff --git a/src/optimizer/plan_nodes/logical_copy_to_file.rs b/src/optimizer/plan_nodes/logical_copy_to_file.rs index 35742f67..f9fa2aa8 100644 --- a/src/optimizer/plan_nodes/logical_copy_to_file.rs +++ b/src/optimizer/plan_nodes/logical_copy_to_file.rs @@ -66,7 +66,13 @@ impl PlanTreeNodeUnary for LogicalCopyToFile { } } impl_plan_tree_node_for_unary!(LogicalCopyToFile); -impl PlanNode for LogicalCopyToFile {} +impl PlanNode for LogicalCopyToFile { + fn prune_col(&self, _required_cols: BitSet) -> PlanRef { + let input_cols = (0..self.child().out_types().len()).into_iter().collect(); + self.clone_with_child(self.child.prune_col(input_cols)) + .into_plan_ref() + } +} impl fmt::Display for LogicalCopyToFile { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { diff --git a/src/optimizer/plan_nodes/logical_delete.rs b/src/optimizer/plan_nodes/logical_delete.rs index f25730b4..efa6f0da 100644 --- a/src/optimizer/plan_nodes/logical_delete.rs +++ b/src/optimizer/plan_nodes/logical_delete.rs @@ -46,6 +46,12 @@ impl PlanNode for LogicalDelete { false, )] } + + fn prune_col(&self, _required_cols: BitSet) -> PlanRef { + let input_cols = (0..self.child().out_types().len()).into_iter().collect(); + self.clone_with_child(self.child.prune_col(input_cols)) + .into_plan_ref() + } } impl fmt::Display for LogicalDelete { diff --git a/src/optimizer/plan_nodes/logical_filter.rs b/src/optimizer/plan_nodes/logical_filter.rs index 883678e6..a3401703 100644 --- a/src/optimizer/plan_nodes/logical_filter.rs +++ b/src/optimizer/plan_nodes/logical_filter.rs @@ -1,6 +1,5 @@ // Copyright 2022 RisingLight Project Authors. Licensed under Apache-2.0. -use std::collections::HashMap; use std::fmt; use serde::Serialize; @@ -56,51 +55,31 @@ impl PlanNode for LogicalFilter { } fn prune_col(&self, required_cols: BitSet) -> PlanRef { - struct CollectRequiredCols(BitSet); - impl ExprVisitor for CollectRequiredCols { - fn visit_input_ref(&mut self, expr: &BoundInputRef) { - self.0.insert(expr.index); - } - } let mut visitor = CollectRequiredCols(required_cols.clone()); visitor.visit_expr(&self.expr); let input_cols = visitor.0; - let mut idx_table = HashMap::new(); - for (new_idx, old_idx) in input_cols.iter().enumerate() { - idx_table.insert(old_idx, new_idx); - } - - struct Mapper(HashMap); - impl ExprRewriter for Mapper { - fn rewrite_input_ref(&self, expr: &mut BoundExpr) { - match expr { - BoundExpr::InputRef(ref mut input_ref) => { - input_ref.index = self.0[&input_ref.index]; - } - _ => unreachable!(), - } - } - } - let mut expr = self.expr.clone(); - Mapper(idx_table.clone()).rewrite_expr(&mut expr); + let mapper = Mapper::new_with_bitset(&input_cols); + mapper.rewrite_expr(&mut expr); + let need_prune = required_cols != input_cols; let new_filter = Self { expr, - child: self.child.prune_col(input_cols.clone()), + child: self.child.prune_col(input_cols), } .into_plan_ref(); - if required_cols == input_cols { + if !need_prune { return new_filter; } + let input_types = self.out_types(); let exprs = required_cols .iter() .map(|old_idx| { BoundExpr::InputRef(BoundInputRef { - index: idx_table[&old_idx], + index: mapper[old_idx], return_type: input_types[old_idx].clone(), }) }) diff --git a/src/optimizer/plan_nodes/logical_insert.rs b/src/optimizer/plan_nodes/logical_insert.rs index da7a0322..c38b4d1f 100644 --- a/src/optimizer/plan_nodes/logical_insert.rs +++ b/src/optimizer/plan_nodes/logical_insert.rs @@ -55,6 +55,16 @@ impl PlanNode for LogicalInsert { false, )] } + + fn prune_col(&self, _required_cols: BitSet) -> PlanRef { + let input_cols = self + .column_ids + .iter() + .map(|&column_id| column_id as usize) + .collect(); + self.clone_with_child(self.child.prune_col(input_cols)) + .into_plan_ref() + } } impl fmt::Display for LogicalInsert { diff --git a/src/optimizer/plan_nodes/logical_join.rs b/src/optimizer/plan_nodes/logical_join.rs index ea549bc3..acf036e8 100644 --- a/src/optimizer/plan_nodes/logical_join.rs +++ b/src/optimizer/plan_nodes/logical_join.rs @@ -5,7 +5,7 @@ use std::fmt; use serde::Serialize; use super::*; -use crate::binder::BoundJoinOperator; +use crate::binder::{BoundJoinOperator, ExprVisitor}; use crate::optimizer::logical_plan_rewriter::ExprRewriter; /// The logical plan of join, it only records join tables and operators. @@ -98,6 +98,52 @@ impl PlanNode for LogicalJoin { fn estimated_cardinality(&self) -> usize { self.left().estimated_cardinality() * self.right().estimated_cardinality() } + + fn prune_col(&self, required_cols: BitSet) -> PlanRef { + let mut on_clause = self.predicate.to_on_clause(); + let mut visitor = CollectRequiredCols(required_cols.clone()); + visitor.visit_expr(&on_clause); + let input_cols = visitor.0; + + let mapper = Mapper::new_with_bitset(&input_cols); + mapper.rewrite_expr(&mut on_clause); + + let left_schema_len = self.left_plan.out_types().len(); + let left_input_cols = input_cols + .iter() + .filter(|&col_idx| col_idx < left_schema_len) + .collect::(); + let right_input_cols = input_cols + .iter() + .filter(|&col_idx| col_idx >= left_schema_len) + .map(|col_idx| col_idx - left_schema_len) + .collect(); + + let join_predicate = JoinPredicate::create(left_input_cols.len(), on_clause); + + let new_join = LogicalJoin::new( + self.left_plan.prune_col(left_input_cols), + self.right_plan.prune_col(right_input_cols), + self.join_op, + join_predicate, + ); + + if required_cols == input_cols { + new_join.into_plan_ref() + } else { + let out_types = self.out_types(); + let project_expressions = required_cols + .iter() + .map(|col_idx| { + BoundExpr::InputRef(BoundInputRef { + index: mapper[col_idx], + return_type: out_types[col_idx].clone(), + }) + }) + .collect(); + LogicalProjection::new(project_expressions, new_join.into_plan_ref()).into_plan_ref() + } + } } impl fmt::Display for LogicalJoin { diff --git a/src/optimizer/plan_nodes/logical_limit.rs b/src/optimizer/plan_nodes/logical_limit.rs index 8b186b65..f1c6febe 100644 --- a/src/optimizer/plan_nodes/logical_limit.rs +++ b/src/optimizer/plan_nodes/logical_limit.rs @@ -51,6 +51,11 @@ impl PlanNode for LogicalLimit { fn schema(&self) -> Vec { self.child.schema() } + + fn prune_col(&self, required_cols: BitSet) -> PlanRef { + LogicalLimit::new(self.offset, self.limit, self.child.prune_col(required_cols)) + .into_plan_ref() + } } impl fmt::Display for LogicalLimit { diff --git a/src/optimizer/plan_nodes/logical_order.rs b/src/optimizer/plan_nodes/logical_order.rs index b1138250..616a5fc4 100644 --- a/src/optimizer/plan_nodes/logical_order.rs +++ b/src/optimizer/plan_nodes/logical_order.rs @@ -5,7 +5,7 @@ use std::fmt; use serde::Serialize; use super::*; -use crate::binder::BoundOrderBy; +use crate::binder::{BoundOrderBy, ExprVisitor}; use crate::optimizer::logical_plan_rewriter::ExprRewriter; /// The logical plan of order. @@ -54,6 +54,26 @@ impl PlanNode for LogicalOrder { fn estimated_cardinality(&self) -> usize { self.child().estimated_cardinality() } + + fn prune_col(&self, required_cols: BitSet) -> PlanRef { + let mut visitor = CollectRequiredCols(required_cols); + for node in &self.comparators { + visitor.visit_expr(&node.expr); + } + let input_cols = visitor.0; + + let mapper = Mapper::new_with_bitset(&input_cols); + let mut comparators = self.comparators.clone(); + for node in &mut comparators { + mapper.rewrite_expr(&mut node.expr); + } + + Self { + comparators, + child: self.child.prune_col(input_cols), + } + .into_plan_ref() + } } impl fmt::Display for LogicalOrder { @@ -61,3 +81,81 @@ impl fmt::Display for LogicalOrder { writeln!(f, "LogicalOrder: {:?}", self.comparators) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{DataTypeExt, DataTypeKind}; + + #[test] + fn test_prune_order() { + let ty = DataTypeKind::Int(None).not_null(); + let col_descs = vec![ + ty.clone().to_column("v1".into()), + ty.clone().to_column("v2".into()), + ty.clone().to_column("v3".into()), + ]; + let table_scan = LogicalTableScan::new( + crate::catalog::TableRefId { + database_id: 0, + schema_id: 0, + table_id: 0, + }, + vec![1, 2, 3], + col_descs.clone(), + false, + false, + None, + ); + + let project_expressions = vec![ + BoundExpr::InputRef(BoundInputRef { + index: 0, + return_type: ty.clone(), + }), + BoundExpr::InputRef(BoundInputRef { + index: 1, + return_type: ty.clone(), + }), + BoundExpr::InputRef(BoundInputRef { + index: 2, + return_type: ty.clone(), + }), + ]; + + let projection = LogicalProjection::new(project_expressions, table_scan.into_plan_ref()); + + let node = vec![BoundOrderBy { + expr: BoundExpr::InputRef(BoundInputRef { + index: 1, + return_type: ty.clone(), + }), + descending: false, + }]; + + let orderby = LogicalOrder::new(node, projection.into_plan_ref()); + + let mut required_cols = BitSet::new(); + required_cols.insert(2); + let plan = orderby.prune_col(required_cols); + let orderby = plan.as_logical_order().unwrap(); + + assert_eq!( + orderby.comparators, + vec![BoundOrderBy { + expr: BoundExpr::InputRef(BoundInputRef { + index: 0, + return_type: ty, + }), + descending: false, + }] + ); + + let plan = orderby.child(); + let projection = plan.as_logical_projection().unwrap(); + let plan = projection.child(); + let table_scan = plan.as_logical_table_scan().unwrap(); + assert_eq!(table_scan.column_descs(), &col_descs[1..3]); + assert_eq!(table_scan.column_ids(), &[2, 3]); + } +} diff --git a/src/optimizer/plan_nodes/logical_projection.rs b/src/optimizer/plan_nodes/logical_projection.rs index 8fe64eef..92663e5a 100644 --- a/src/optimizer/plan_nodes/logical_projection.rs +++ b/src/optimizer/plan_nodes/logical_projection.rs @@ -5,7 +5,7 @@ use std::fmt; use serde::Serialize; use super::*; -use crate::binder::BoundExpr; +use crate::binder::{BoundExpr, ExprVisitor}; use crate::optimizer::logical_plan_rewriter::ExprRewriter; /// The logical plan of project operation. @@ -106,6 +106,28 @@ impl PlanNode for LogicalProjection { fn estimated_cardinality(&self) -> usize { self.child().estimated_cardinality() } + + fn prune_col(&self, required_cols: BitSet) -> PlanRef { + let mut new_projection_expressions: Vec = required_cols + .iter() + .map(|index| self.project_expressions[index].clone()) + .collect(); + + let mut visitor = CollectRequiredCols(BitSet::with_capacity(required_cols.len())); + new_projection_expressions + .iter() + .for_each(|expr| visitor.visit_expr(expr)); + + let input_cols = visitor.0; + + let mapper = Mapper::new_with_bitset(&input_cols); + new_projection_expressions.iter_mut().for_each(|expr| { + mapper.rewrite_expr(expr); + }); + + LogicalProjection::new(new_projection_expressions, self.child.prune_col(input_cols)) + .into_plan_ref() + } } impl fmt::Display for LogicalProjection { @@ -140,7 +162,7 @@ mod tests { }), BoundExpr::Constant(DataValue::Int32(0)), ], - Arc::new(Dummy {}), + Arc::new(Dummy::new(Vec::new())), ); let column_names = plan.out_names(); @@ -169,7 +191,7 @@ mod tests { }), BoundExpr::Constant(DataValue::Int32(0)), ], - Arc::new(Dummy {}), + Arc::new(Dummy::new(Vec::new())), ); let outer = LogicalProjection::new( @@ -199,4 +221,58 @@ mod tests { assert_eq!(outermost.out_names()[0], "v1"); assert!(outermost.child.as_dummy().is_ok()); } + + #[test] + fn test_prune_projection() { + let ty = DataTypeKind::Int(None).not_null(); + let col_descs = vec![ + ty.clone().to_column("v1".into()), + ty.clone().to_column("v2".into()), + ty.clone().to_column("v3".into()), + ]; + let table_scan = LogicalTableScan::new( + crate::catalog::TableRefId { + database_id: 0, + schema_id: 0, + table_id: 0, + }, + vec![1, 2, 3], + col_descs.clone(), + false, + false, + None, + ); + let project_expressions = vec![ + BoundExpr::InputRef(BoundInputRef { + index: 0, + return_type: ty.clone(), + }), + BoundExpr::InputRef(BoundInputRef { + index: 1, + return_type: ty.clone(), + }), + BoundExpr::InputRef(BoundInputRef { + index: 2, + return_type: ty.clone(), + }), + ]; + let projection = LogicalProjection::new(project_expressions, table_scan.into_plan_ref()); + + let mut required_cols = BitSet::new(); + required_cols.insert(1); + + let plan = projection.prune_col(required_cols); + let plan = plan.as_logical_projection().unwrap(); + assert_eq!(1, plan.project_expressions().len()); + assert_eq!( + plan.project_expressions()[0], + BoundExpr::InputRef(BoundInputRef { + index: 0, + return_type: ty, + }) + ); + let child = plan.child.as_logical_table_scan().unwrap(); + assert_eq!(child.column_descs(), &col_descs[1..2]); + assert_eq!(child.column_ids(), &[2]); + } } diff --git a/src/optimizer/plan_nodes/logical_table_scan.rs b/src/optimizer/plan_nodes/logical_table_scan.rs index 92747bc4..731e8ba6 100644 --- a/src/optimizer/plan_nodes/logical_table_scan.rs +++ b/src/optimizer/plan_nodes/logical_table_scan.rs @@ -1,12 +1,15 @@ // Copyright 2022 RisingLight Project Authors. Licensed under Apache-2.0. +use std::collections::HashMap; use std::fmt; use itertools::Itertools; use serde::Serialize; use super::*; +use crate::binder::ExprVisitor; use crate::catalog::{ColumnDesc, TableRefId}; +use crate::optimizer::logical_plan_rewriter::ExprRewriter; use crate::types::ColumnId; /// The logical plan of sequential scan operation. #[derive(Debug, Clone, Serialize)] @@ -82,24 +85,81 @@ impl PlanNode for LogicalTableScan { } fn prune_col(&self, required_cols: BitSet) -> PlanRef { - let (column_ids, column_descs) = required_cols + let mut visitor = CollectRequiredCols(BitSet::new()); + if self.expr.is_some() { + visitor.visit_expr(self.expr.as_ref().unwrap()); + } + let filter_cols = visitor.0; + + let mut need_rewrite = false; + + if !filter_cols.is_empty() + && filter_cols + .iter() + .any(|index| !required_cols.contains(index)) + { + need_rewrite = true; + } + + let mut idx_table = HashMap::new(); + let (mut column_ids, mut column_descs): (Vec<_>, Vec<_>) = required_cols .iter() + .filter(|&id| id < self.column_ids.len()) .map(|id| { + idx_table.insert(id, idx_table.len()); ( self.column_ids[id as usize], self.column_descs[id as usize].clone(), ) }) .unzip(); - Self { + + let mut offset = column_ids.len(); + let mut expr = self.expr.clone(); + if need_rewrite { + let (f_column_ids, f_column_descs): (Vec<_>, Vec<_>) = filter_cols + .iter() + .filter(|&id| id < self.column_ids.len() && !required_cols.contains(id)) + .map(|id| { + idx_table.insert(id, offset); + offset += 1; + ( + self.column_ids[id as usize], + self.column_descs[id as usize].clone(), + ) + }) + .unzip(); + + column_ids.extend(f_column_ids.into_iter()); + column_descs.extend(f_column_descs.into_iter()); + + Mapper(idx_table).rewrite_expr(expr.as_mut().unwrap()); + } + + let new_scan = Self { table_ref_id: self.table_ref_id, column_ids, - column_descs, + column_descs: column_descs.clone(), with_row_handler: self.with_row_handler, is_sorted: self.is_sorted, - expr: self.expr.clone(), + expr, + } + .into_plan_ref(); + + if need_rewrite { + let project_expressions = (0..required_cols.len()) + .into_iter() + .map(|index| { + BoundExpr::InputRef(BoundInputRef { + index, + return_type: column_descs[index].datatype().clone(), + }) + }) + .collect(); + LogicalProjection::new(project_expressions, new_scan).into_plan_ref() + } else { + new_scan } - .into_plan_ref() } } impl fmt::Display for LogicalTableScan { diff --git a/src/optimizer/plan_nodes/logical_top_n.rs b/src/optimizer/plan_nodes/logical_top_n.rs index fd894176..0daf9421 100644 --- a/src/optimizer/plan_nodes/logical_top_n.rs +++ b/src/optimizer/plan_nodes/logical_top_n.rs @@ -6,6 +6,8 @@ use serde::Serialize; use super::*; use crate::binder::statement::BoundOrderBy; +use crate::binder::ExprVisitor; +use crate::optimizer::logical_plan_rewriter::ExprRewriter; /// The logical plan of top N operation. #[derive(Debug, Clone, Serialize)] @@ -69,6 +71,49 @@ impl PlanNode for LogicalTopN { fn estimated_cardinality(&self) -> usize { self.limit } + + fn prune_col(&self, required_cols: BitSet) -> PlanRef { + let mut visitor = CollectRequiredCols(required_cols.clone()); + + self.comparators + .iter() + .for_each(|node| visitor.visit_expr(&node.expr)); + + let input_cols = visitor.0; + + let mapper = Mapper::new_with_bitset(&input_cols); + let mut comparators = self.comparators.clone(); + + comparators + .iter_mut() + .for_each(|node| mapper.rewrite_expr(&mut node.expr)); + + let need_prune = input_cols != required_cols; + + let new_topn = LogicalTopN::new( + self.offset, + self.limit, + comparators, + self.child.prune_col(input_cols), + ) + .into_plan_ref(); + + if !need_prune { + new_topn + } else { + let out_types = self.out_types(); + let project_expressions = required_cols + .iter() + .map(|col_idx| { + BoundExpr::InputRef(BoundInputRef { + index: mapper[col_idx], + return_type: out_types[col_idx].clone(), + }) + }) + .collect(); + LogicalProjection::new(project_expressions, new_topn).into_plan_ref() + } + } } impl fmt::Display for LogicalTopN { diff --git a/src/optimizer/plan_nodes/logical_values.rs b/src/optimizer/plan_nodes/logical_values.rs index 410b4bdf..3c3d26b6 100644 --- a/src/optimizer/plan_nodes/logical_values.rs +++ b/src/optimizer/plan_nodes/logical_values.rs @@ -47,7 +47,37 @@ impl LogicalValues { impl PlanTreeNodeLeaf for LogicalValues {} impl_plan_tree_node_for_leaf!(LogicalValues); -impl PlanNode for LogicalValues {} +impl PlanNode for LogicalValues { + fn schema(&self) -> Vec { + self.values[0] + .iter() + .map(|expr| { + let name = "?column?".to_string(); + expr.return_type().unwrap().to_column(name) + }) + .collect() + } + + fn prune_col(&self, required_cols: BitSet) -> PlanRef { + let types: Vec<_> = required_cols + .iter() + .map(|index| self.column_types[index].clone()) + .collect(); + + let new_values: Vec<_> = self + .values + .iter() + .map(|row_expr| { + required_cols + .iter() + .map(|index| row_expr[index].clone()) + .collect() + }) + .collect(); + + LogicalValues::new(types, new_values).into_plan_ref() + } +} impl fmt::Display for LogicalValues { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { diff --git a/src/optimizer/plan_nodes/mod.rs b/src/optimizer/plan_nodes/mod.rs index 49dc26ca..50fc8338 100644 --- a/src/optimizer/plan_nodes/mod.rs +++ b/src/optimizer/plan_nodes/mod.rs @@ -13,7 +13,9 @@ //! * [Callbacks](https://adventures.michaelfbryan.com/posts/non-trivial-macros/#callbacks) //! * [Type Exercise in Rust (Day 4)](https://github.com/skyzh/type-exercise-in-rust/blob/master/archive/day4/src/macros.rs) +use std::collections::HashMap; use std::fmt::{Debug, Display}; +use std::ops::Index; use std::sync::Arc; use bit_set::BitSet; @@ -21,7 +23,7 @@ use downcast_rs::{impl_downcast, Downcast}; use erased_serde::serialize_trait_object; use paste::paste; -use crate::binder::{BoundExpr, BoundInputRef}; +use crate::binder::{BoundExpr, BoundInputRef, ExprVisitor}; use crate::types::DataType; mod plan_tree_node; @@ -105,6 +107,7 @@ pub use physical_table_scan::*; pub use physical_top_n::*; pub use physical_values::*; +use super::logical_plan_rewriter::ExprRewriter; use crate::catalog::ColumnDesc; /// The upcast trait for `PlanNode`. @@ -299,3 +302,41 @@ macro_rules! impl_into_plan_ref { } } for_all_plan_nodes! {impl_into_plan_ref } + +struct CollectRequiredCols(BitSet); +impl ExprVisitor for CollectRequiredCols { + fn visit_input_ref(&mut self, expr: &BoundInputRef) { + self.0.insert(expr.index); + } +} + +struct Mapper(HashMap); + +impl Mapper { + fn new_with_bitset(bitset: &BitSet) -> Self { + let mut idx_table = HashMap::with_capacity(bitset.len()); + for (new_idx, old_idx) in bitset.iter().enumerate() { + idx_table.insert(old_idx, new_idx); + } + Self(idx_table) + } +} + +impl ExprRewriter for Mapper { + fn rewrite_input_ref(&self, expr: &mut BoundExpr) { + match expr { + BoundExpr::InputRef(ref mut input_ref) => { + input_ref.index = self.0[&input_ref.index]; + } + _ => unreachable!(), + } + } +} + +impl Index for Mapper { + type Output = usize; + + fn index(&self, index: usize) -> &Self::Output { + &self.0[&index] + } +} diff --git a/tests/planner_test/tpch.planner.sql b/tests/planner_test/tpch.planner.sql index 295447fa..70ea789c 100644 --- a/tests/planner_test/tpch.planner.sql +++ b/tests/planner_test/tpch.planner.sql @@ -121,10 +121,10 @@ PhysicalOrder: InputRef #3 (alias to sum_base_price) InputRef #4 (alias to sum_disc_price) InputRef #5 (alias to sum_charge) - (InputRef #2 / InputRef #7) (alias to avg_qty) - (InputRef #3 / InputRef #9) (alias to avg_price) - (InputRef #10 / InputRef #11) (alias to avg_disc) - InputRef #12 (alias to count_order) + (InputRef #2 / InputRef #6) (alias to avg_qty) + (InputRef #3 / InputRef #7) (alias to avg_price) + (InputRef #8 / InputRef #9) (alias to avg_disc) + InputRef #10 (alias to count_order) PhysicalHashAgg: InputRef #1 InputRef #2 @@ -132,9 +132,7 @@ PhysicalOrder: sum(InputRef #4) -> NUMERIC(15,2) sum((InputRef #4 * (1 - InputRef #5))) -> NUMERIC(15,2) (null) sum(((InputRef #4 * (1 - InputRef #5)) * (1 + InputRef #6))) -> NUMERIC(15,2) (null) - sum(InputRef #3) -> NUMERIC(15,2) count(InputRef #3) -> INT - sum(InputRef #4) -> NUMERIC(15,2) count(InputRef #4) -> INT sum(InputRef #5) -> NUMERIC(15,2) count(InputRef #5) -> INT @@ -180,34 +178,50 @@ PhysicalTopN: offset: 0, limit: 10, order by [InputRef #1 (desc), InputRef #2 (a InputRef #1 InputRef #2 PhysicalHashAgg: - InputRef #6 - InputRef #4 - InputRef #5 - sum((InputRef #8 * (1 - InputRef #9))) -> NUMERIC(15,2) (null) - PhysicalHashJoin: - op Inner, - predicate: Eq(InputRef #3, InputRef #6) + InputRef #2 + InputRef #0 + InputRef #1 + sum((InputRef #3 * (1 - InputRef #4))) -> NUMERIC(15,2) (null) + PhysicalProjection: + InputRef #1 + InputRef #2 + InputRef #3 + InputRef #4 + InputRef #5 PhysicalHashJoin: op Inner, - predicate: Eq(InputRef #1, InputRef #2) - PhysicalTableScan: - table #5, - columns [6, 0], - with_row_handler: false, - is_sorted: false, - expr: Eq(InputRef #0, String("BUILDING") (const)) - PhysicalTableScan: - table #6, - columns [1, 0, 4, 7], - with_row_handler: false, - is_sorted: false, - expr: Lt(InputRef #2, Date(Date(9204)) (const)) - PhysicalTableScan: - table #7, - columns [0, 10, 5, 6], - with_row_handler: false, - is_sorted: false, - expr: Gt(InputRef #1, Date(Date(9204)) (const)) + predicate: Eq(InputRef #0, InputRef #3) + PhysicalProjection: + InputRef #2 + InputRef #3 + InputRef #4 + PhysicalHashJoin: + op Inner, + predicate: Eq(InputRef #0, InputRef #1) + PhysicalProjection: + InputRef #0 + PhysicalTableScan: + table #5, + columns [0, 6], + with_row_handler: false, + is_sorted: false, + expr: Eq(InputRef #1, String("BUILDING") (const)) + PhysicalTableScan: + table #6, + columns [1, 0, 4, 7], + with_row_handler: false, + is_sorted: false, + expr: Lt(InputRef #2, Date(Date(9204)) (const)) + PhysicalProjection: + InputRef #0 + InputRef #1 + InputRef #2 + PhysicalTableScan: + table #7, + columns [0, 5, 6, 10], + with_row_handler: false, + is_sorted: false, + expr: Gt(InputRef #3, Date(Date(9204)) (const)) */ -- tpch-q5: TPC-H Q5 @@ -243,59 +257,85 @@ PhysicalOrder: InputRef #0 InputRef #1 (alias to revenue) PhysicalHashAgg: - InputRef #13 - sum((InputRef #7 * (1 - InputRef #8))) -> NUMERIC(15,2) (null) - PhysicalHashJoin: - op Inner, - predicate: Eq(InputRef #12, InputRef #14) + InputRef #2 + sum((InputRef #0 * (1 - InputRef #1))) -> NUMERIC(15,2) (null) + PhysicalProjection: + InputRef #0 + InputRef #1 + InputRef #3 PhysicalHashJoin: op Inner, - predicate: Eq(InputRef #10, InputRef #11) - PhysicalHashJoin: - op Inner, - predicate: And(Eq(InputRef #6, InputRef #9), Eq(InputRef #1, InputRef #10)) + predicate: Eq(InputRef #2, InputRef #4) + PhysicalProjection: + InputRef #0 + InputRef #1 + InputRef #4 + InputRef #5 PhysicalHashJoin: op Inner, - predicate: Eq(InputRef #3, InputRef #5) - PhysicalHashJoin: - op Inner, - predicate: Eq(InputRef #0, InputRef #2) - PhysicalTableScan: - table #5, - columns [0, 3], - with_row_handler: false, - is_sorted: false, - expr: None - PhysicalTableScan: - table #6, - columns [1, 0, 4], - with_row_handler: false, - is_sorted: false, - expr: And(GtEq(InputRef #2, Date(Date(8766)) (const)), Lt(InputRef #2, Date(Date(9131)) (const))) + predicate: Eq(InputRef #2, InputRef #3) + PhysicalProjection: + InputRef #2 + InputRef #3 + InputRef #5 + PhysicalHashJoin: + op Inner, + predicate: And(Eq(InputRef #1, InputRef #4), Eq(InputRef #0, InputRef #5)) + PhysicalProjection: + InputRef #0 + InputRef #3 + InputRef #4 + InputRef #5 + PhysicalHashJoin: + op Inner, + predicate: Eq(InputRef #1, InputRef #2) + PhysicalProjection: + InputRef #1 + InputRef #3 + PhysicalHashJoin: + op Inner, + predicate: Eq(InputRef #0, InputRef #2) + PhysicalTableScan: + table #5, + columns [0, 3], + with_row_handler: false, + is_sorted: false, + expr: None + PhysicalProjection: + InputRef #0 + InputRef #1 + PhysicalTableScan: + table #6, + columns [1, 0, 4], + with_row_handler: false, + is_sorted: false, + expr: And(GtEq(InputRef #2, Date(Date(8766)) (const)), Lt(InputRef #2, Date(Date(9131)) (const))) + PhysicalTableScan: + table #7, + columns [0, 2, 5, 6], + with_row_handler: false, + is_sorted: false, + expr: None + PhysicalTableScan: + table #3, + columns [0, 3], + with_row_handler: false, + is_sorted: false, + expr: None PhysicalTableScan: - table #7, - columns [0, 2, 5, 6], + table #0, + columns [0, 2, 1], with_row_handler: false, is_sorted: false, expr: None + PhysicalProjection: + InputRef #0 PhysicalTableScan: - table #3, - columns [0, 3], + table #1, + columns [0, 1], with_row_handler: false, is_sorted: false, - expr: None - PhysicalTableScan: - table #0, - columns [0, 2, 1], - with_row_handler: false, - is_sorted: false, - expr: None - PhysicalTableScan: - table #1, - columns [0, 1], - with_row_handler: false, - is_sorted: false, - expr: Eq(InputRef #1, String("AFRICA") (const)) + expr: Eq(InputRef #1, String("AFRICA") (const)) */ -- tpch-q6 @@ -313,13 +353,16 @@ where PhysicalProjection: InputRef #0 (alias to revenue) PhysicalSimpleAgg: - sum((InputRef #3 * InputRef #1)) -> NUMERIC(15,2) (null) - PhysicalTableScan: - table #7, - columns [10, 6, 4, 5], - with_row_handler: false, - is_sorted: false, - expr: And(And(And(GtEq(InputRef #0, Date(Date(8766)) (const)), Lt(InputRef #0, Date(Date(9131)) (const))), And(GtEq(InputRef #1, Decimal(0.07) (const)), LtEq(InputRef #1, Decimal(0.09) (const)))), Lt(InputRef #2, Decimal(24) (const))) + sum((InputRef #1 * InputRef #0)) -> NUMERIC(15,2) (null) + PhysicalProjection: + InputRef #0 + InputRef #1 + PhysicalTableScan: + table #7, + columns [6, 5, 10, 4], + with_row_handler: false, + is_sorted: false, + expr: And(And(And(GtEq(InputRef #2, Date(Date(8766)) (const)), Lt(InputRef #2, Date(Date(9131)) (const))), And(GtEq(InputRef #0, Decimal(0.07) (const)), LtEq(InputRef #0, Decimal(0.09) (const)))), Lt(InputRef #3, Decimal(24) (const))) */ -- tpch-q10: TPC-H Q10 @@ -369,45 +412,81 @@ PhysicalTopN: offset: 0, limit: 20, order by [InputRef #2 (desc)] InputRef #6 PhysicalHashAgg: InputRef #0 + InputRef #1 InputRef #2 + InputRef #4 + InputRef #8 InputRef #3 InputRef #5 - InputRef #15 - InputRef #4 - InputRef #6 - sum((InputRef #12 * (1 - InputRef #13))) -> NUMERIC(15,2) (null) - PhysicalHashJoin: - op Inner, - predicate: Eq(InputRef #1, InputRef #14) + sum((InputRef #6 * (1 - InputRef #7))) -> NUMERIC(15,2) (null) + PhysicalProjection: + InputRef #0 + InputRef #2 + InputRef #3 + InputRef #4 + InputRef #5 + InputRef #6 + InputRef #7 + InputRef #8 + InputRef #10 PhysicalHashJoin: op Inner, - predicate: Eq(InputRef #8, InputRef #10) - PhysicalHashJoin: - op Inner, - predicate: Eq(InputRef #0, InputRef #7) - PhysicalTableScan: - table #5, - columns [0, 3, 1, 5, 2, 4, 7], - with_row_handler: false, - is_sorted: false, - expr: None - PhysicalTableScan: - table #6, - columns [1, 0, 4], - with_row_handler: false, - is_sorted: false, - expr: And(GtEq(InputRef #2, Date(Date(8674)) (const)), Lt(InputRef #2, Date(Date(8766)) (const))) + predicate: Eq(InputRef #1, InputRef #9) + PhysicalProjection: + InputRef #0 + InputRef #1 + InputRef #2 + InputRef #3 + InputRef #4 + InputRef #5 + InputRef #6 + InputRef #9 + InputRef #10 + PhysicalHashJoin: + op Inner, + predicate: Eq(InputRef #7, InputRef #8) + PhysicalProjection: + InputRef #0 + InputRef #1 + InputRef #2 + InputRef #3 + InputRef #4 + InputRef #5 + InputRef #6 + InputRef #8 + PhysicalHashJoin: + op Inner, + predicate: Eq(InputRef #0, InputRef #7) + PhysicalTableScan: + table #5, + columns [0, 3, 1, 5, 2, 4, 7], + with_row_handler: false, + is_sorted: false, + expr: None + PhysicalProjection: + InputRef #0 + InputRef #1 + PhysicalTableScan: + table #6, + columns [1, 0, 4], + with_row_handler: false, + is_sorted: false, + expr: And(GtEq(InputRef #2, Date(Date(8674)) (const)), Lt(InputRef #2, Date(Date(8766)) (const))) + PhysicalProjection: + InputRef #0 + InputRef #1 + InputRef #2 + PhysicalTableScan: + table #7, + columns [0, 5, 6, 8], + with_row_handler: false, + is_sorted: false, + expr: Eq(InputRef #3, String("R") (const)) PhysicalTableScan: - table #7, - columns [0, 8, 5, 6], + table #0, + columns [0, 1], with_row_handler: false, is_sorted: false, - expr: Eq(InputRef #1, String("R") (const)) - PhysicalTableScan: - table #0, - columns [0, 1], - with_row_handler: false, - is_sorted: false, - expr: None + expr: None */