From 882db06dc89643d370fab66ba2d84381d854b3b7 Mon Sep 17 00:00:00 2001 From: Mingj Date: Wed, 13 Oct 2021 17:46:49 -0400 Subject: [PATCH] refactor(binder): support BoundTableRef enum --- src/binder/statement/select.rs | 23 ++++++++++++++++------- src/binder/table_ref/mod.rs | 19 +++++++++++++------ src/logical_plan/select.rs | 18 +++++++++++++----- 3 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/binder/statement/select.rs b/src/binder/statement/select.rs index 9911cbbc..0465b284 100644 --- a/src/binder/statement/select.rs +++ b/src/binder/statement/select.rs @@ -1,4 +1,5 @@ use super::*; +use crate::binder::BoundTableRef; use crate::parser::{Query, SelectItem, SetExpr}; #[derive(Debug, PartialEq, Clone)] @@ -30,6 +31,11 @@ impl Binder { }; // Bind table ref let mut from_table = vec![]; + // We don't support cross join now. + // The cross join will have multiple TableWithJoin in "from" struct. + // Other types of join will onyl have one TableWithJoin in "from" struct. + assert!(select.from.len() <= 1); + for table_ref in select.from.iter() { let table_ref = self.bind_table_ref(&table_ref.relation)?; from_table.push(table_ref); @@ -65,12 +71,15 @@ impl Binder { // Add referred columns for base table reference for table_ref in from_table.iter_mut() { - table_ref.column_ids = self - .context - .column_ids - .get(&table_ref.table_name) - .unwrap() - .clone(); + match table_ref { + BoundTableRef::BaseTableRef { + ref_id: _, + table_name, + column_ids, + } => { + *column_ids = self.context.column_ids.get(table_name).unwrap().clone(); + } + } } Ok(Box::new(BoundSelect { @@ -136,7 +145,7 @@ mod tests { return_type: Some(DataTypeKind::Int.not_null()), }, ], - from_table: vec![BoundTableRef { + from_table: vec![BoundTableRef::BaseTableRef { ref_id: TableRefId::new(0, 0, 0), table_name: "t".into(), column_ids: vec![1, 0], diff --git a/src/binder/table_ref/mod.rs b/src/binder/table_ref/mod.rs index 62713a15..a5da1f3a 100644 --- a/src/binder/table_ref/mod.rs +++ b/src/binder/table_ref/mod.rs @@ -1,11 +1,12 @@ use super::*; use crate::parser::TableFactor; - #[derive(Debug, PartialEq, Clone)] -pub struct BoundTableRef { - pub ref_id: TableRefId, - pub table_name: String, - pub column_ids: Vec, +pub enum BoundTableRef { + BaseTableRef { + ref_id: TableRefId, + table_name: String, + column_ids: Vec, + }, } impl Binder { @@ -33,12 +34,18 @@ impl Binder { self.context .column_ids .insert(table_name.into(), Vec::new()); - Ok(BoundTableRef { + Ok(BoundTableRef::BaseTableRef { ref_id, table_name: table_name.into(), column_ids: vec![], }) } + TableFactor::NestedJoin(table_with_joins) => { + let bounded_table_ref = self.bind_table_ref(&table_with_joins.relation)?; + // We only support cross join now. + assert_eq!(table_with_joins.joins.len(), 0); + Ok(bounded_table_ref) + } _ => panic!("bind table ref"), } } diff --git a/src/logical_plan/select.rs b/src/logical_plan/select.rs index 4ed31503..244935e7 100644 --- a/src/logical_plan/select.rs +++ b/src/logical_plan/select.rs @@ -1,14 +1,22 @@ use super::*; -use crate::binder::BoundSelect; +use crate::binder::{BoundSelect, BoundTableRef}; impl LogicalPlaner { pub fn plan_select(&self, stmt: Box) -> Result { let mut plan = LogicalPlan::Dummy; if let Some(table_ref) = stmt.from_table.get(0) { - plan = LogicalPlan::SeqScan(LogicalSeqScan { - table_ref_id: table_ref.ref_id, - column_ids: table_ref.column_ids.clone(), - }); + match table_ref { + BoundTableRef::BaseTableRef { + ref_id, + table_name: _, + column_ids, + } => { + plan = LogicalPlan::SeqScan(LogicalSeqScan { + table_ref_id: *ref_id, + column_ids: column_ids.to_vec(), + }); + } + } } if let Some(expr) = stmt.where_clause { plan = LogicalPlan::Filter(LogicalFilter {