diff --git a/derive/README.md b/derive/README.md index ec0fcb6fb..aadf99818 100644 --- a/derive/README.md +++ b/derive/README.md @@ -48,33 +48,86 @@ impl Visit for Bar { } ``` -Additionally certain types may wish to call a corresponding method on visitor before recursing +Some types may wish to call a corresponding method on the visitor: ```rust #[derive(Visit, VisitMut)] #[visit(with = "visit_expr")] enum Expr { - A(), - B(String, #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] ObjectName, bool), + IsNull(Box), + .. } ``` -Will generate +This will result in the following sequence of visitor calls when an `IsNull` +expression is visited + +``` +visitor.pre_visit_expr() +visitor.pre_visit_expr() +visitor.post_visit_expr() +visitor.post_visit_expr() +``` + +For some types it is only appropriate to call a particular visitor method in +some contexts. For example, not every `ObjectName` refers to a relation. + +In these cases, the `visit` attribute can be used on the field for which we'd +like to call the method: ```rust -impl Visit for Bar { +#[derive(Visit, VisitMut)] +#[visit(with = "visit_table_factor")] +pub enum TableFactor { + Table { + #[visit(with = "visit_relation")] + name: ObjectName, + alias: Option, + }, + .. +} +``` + +This will generate + +```rust +impl Visit for TableFactor { fn visit(&self, visitor: &mut V) -> ControlFlow { - visitor.visit_expr(self)?; + visitor.pre_visit_table_factor(self)?; match self { - Self::A() => {} - Self::B(_1, _2, _3) => { - _1.visit(visitor)?; - visitor.visit_relation(_3)?; - _2.visit(visitor)?; - _3.visit(visitor)?; + Self::Table { name, alias } => { + visitor.pre_visit_relation(name)?; + alias.visit(name)?; + visitor.post_visit_relation(name)?; + alias.visit(visitor)?; } } + visitor.post_visit_table_factor(self)?; ControlFlow::Continue(()) } } ``` + +Note that annotating both the type and the field is incorrect as it will result +in redundant calls to the method. For example + +```rust +#[derive(Visit, VisitMut)] +#[visit(with = "visit_expr")] +enum Expr { + IsNull(#[visit(with = "visit_expr")] Box), + .. +} +``` + +will result in these calls to the visitor + + +``` +visitor.pre_visit_expr() +visitor.pre_visit_expr() +visitor.pre_visit_expr() +visitor.post_visit_expr() +visitor.post_visit_expr() +visitor.post_visit_expr() +``` diff --git a/src/ast/query.rs b/src/ast/query.rs index d103f589e..f1ed75c02 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -26,6 +26,7 @@ use crate::ast::*; #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] pub struct Query { /// WITH (common table expressions, or CTEs) pub with: Option, @@ -739,7 +740,6 @@ pub enum TableFactor { /// For example `FROM monthly_sales PIVOT(sum(amount) FOR MONTH IN ('JAN', 'FEB'))` /// See Pivot { - #[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))] table: Box, aggregate_function: Expr, // Function expression value_column: Vec, @@ -755,7 +755,6 @@ pub enum TableFactor { /// /// See . Unpivot { - #[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))] table: Box, value: Ident, name: Ident, diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 99db16107..c4f9c494d 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -12,7 +12,7 @@ //! Recursive visitors for ast Nodes. See [`Visitor`] for more details. -use crate::ast::{Expr, ObjectName, Statement, TableFactor}; +use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor}; use core::ops::ControlFlow; /// A type that can be visited by a [`Visitor`]. See [`Visitor`] for @@ -179,6 +179,16 @@ pub trait Visitor { /// Type returned when the recursion returns early. type Break; + /// Invoked for any queries that appear in the AST before visiting children + fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow { + ControlFlow::Continue(()) + } + + /// Invoked for any queries that appear in the AST after visiting children + fn post_visit_query(&mut self, _query: &Query) -> ControlFlow { + ControlFlow::Continue(()) + } + /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow { ControlFlow::Continue(()) @@ -267,6 +277,16 @@ pub trait VisitorMut { /// Type returned when the recursion returns early. type Break; + /// Invoked for any queries that appear in the AST before visiting children + fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow { + ControlFlow::Continue(()) + } + + /// Invoked for any queries that appear in the AST after visiting children + fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow { + ControlFlow::Continue(()) + } + /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow { ControlFlow::Continue(()) @@ -626,6 +646,18 @@ mod tests { impl Visitor for TestVisitor { type Break = (); + /// Invoked for any queries that appear in the AST before visiting children + fn pre_visit_query(&mut self, query: &Query) -> ControlFlow { + self.visited.push(format!("PRE: QUERY: {query}")); + ControlFlow::Continue(()) + } + + /// Invoked for any queries that appear in the AST after visiting children + fn post_visit_query(&mut self, query: &Query) -> ControlFlow { + self.visited.push(format!("POST: QUERY: {query}")); + ControlFlow::Continue(()) + } + fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow { self.visited.push(format!("PRE: RELATION: {relation}")); ControlFlow::Continue(()) @@ -695,10 +727,12 @@ mod tests { "SELECT * from table_name as my_table", vec![ "PRE: STATEMENT: SELECT * FROM table_name AS my_table", + "PRE: QUERY: SELECT * FROM table_name AS my_table", "PRE: TABLE FACTOR: table_name AS my_table", "PRE: RELATION: table_name", "POST: RELATION: table_name", "POST: TABLE FACTOR: table_name AS my_table", + "POST: QUERY: SELECT * FROM table_name AS my_table", "POST: STATEMENT: SELECT * FROM table_name AS my_table", ], ), @@ -706,6 +740,7 @@ mod tests { "SELECT * from t1 join t2 on t1.id = t2.t1_id", vec![ "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", + "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", "PRE: TABLE FACTOR: t1", "PRE: RELATION: t1", "POST: RELATION: t1", @@ -720,6 +755,7 @@ mod tests { "PRE: EXPR: t2.t1_id", "POST: EXPR: t2.t1_id", "POST: EXPR: t1.id = t2.t1_id", + "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id", ], ), @@ -727,18 +763,22 @@ mod tests { "SELECT * from t1 where EXISTS(SELECT column from t2)", vec![ "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", + "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "PRE: TABLE FACTOR: t1", "PRE: RELATION: t1", "POST: RELATION: t1", "POST: TABLE FACTOR: t1", "PRE: EXPR: EXISTS (SELECT column FROM t2)", + "PRE: QUERY: SELECT column FROM t2", "PRE: EXPR: column", "POST: EXPR: column", "PRE: TABLE FACTOR: t2", "PRE: RELATION: t2", "POST: RELATION: t2", "POST: TABLE FACTOR: t2", + "POST: QUERY: SELECT column FROM t2", "POST: EXPR: EXISTS (SELECT column FROM t2)", + "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", ], ), @@ -746,18 +786,22 @@ mod tests { "SELECT * from t1 where EXISTS(SELECT column from t2)", vec![ "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", + "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "PRE: TABLE FACTOR: t1", "PRE: RELATION: t1", "POST: RELATION: t1", "POST: TABLE FACTOR: t1", "PRE: EXPR: EXISTS (SELECT column FROM t2)", + "PRE: QUERY: SELECT column FROM t2", "PRE: EXPR: column", "POST: EXPR: column", "PRE: TABLE FACTOR: t2", "PRE: RELATION: t2", "POST: RELATION: t2", "POST: TABLE FACTOR: t2", + "POST: QUERY: SELECT column FROM t2", "POST: EXPR: EXISTS (SELECT column FROM t2)", + "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)", ], ), @@ -765,25 +809,54 @@ mod tests { "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3", vec![ "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3", + "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3", "PRE: TABLE FACTOR: t1", "PRE: RELATION: t1", "POST: RELATION: t1", "POST: TABLE FACTOR: t1", "PRE: EXPR: EXISTS (SELECT column FROM t2)", + "PRE: QUERY: SELECT column FROM t2", "PRE: EXPR: column", "POST: EXPR: column", "PRE: TABLE FACTOR: t2", "PRE: RELATION: t2", "POST: RELATION: t2", "POST: TABLE FACTOR: t2", + "POST: QUERY: SELECT column FROM t2", "POST: EXPR: EXISTS (SELECT column FROM t2)", "PRE: TABLE FACTOR: t3", "PRE: RELATION: t3", "POST: RELATION: t3", "POST: TABLE FACTOR: t3", + "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3", "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3", ], ), + ( + concat!( + "SELECT * FROM monthly_sales ", + "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ", + "ORDER BY EMPID" + ), + vec![ + "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", + "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", + "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)", + "PRE: TABLE FACTOR: monthly_sales", + "PRE: RELATION: monthly_sales", + "POST: RELATION: monthly_sales", + "POST: TABLE FACTOR: monthly_sales", + "PRE: EXPR: SUM(a.amount)", + "PRE: EXPR: a.amount", + "POST: EXPR: a.amount", + "POST: EXPR: SUM(a.amount)", + "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)", + "PRE: EXPR: EMPID", + "POST: EXPR: EMPID", + "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", + "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", + ] + ) ]; for (sql, expected) in tests { let actual = do_visit(sql);