Skip to content

Commit

Permalink
add {pre,post}_visit_query to Visitor (#1044)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmhain committed Nov 27, 2023
1 parent 640b939 commit 86aa044
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 15 deletions.
77 changes: 65 additions & 12 deletions derive/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,33 +48,86 @@ impl Visit for Bar {
}
```

Additionally certain types may wish to call a corresponding method on visitor before recursing
Some types may wish to call a corresponding method on the visitor:

```rust
#[derive(Visit, VisitMut)]
#[visit(with = "visit_expr")]
enum Expr {
A(),
B(String, #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] ObjectName, bool),
IsNull(Box<Expr>),
..
}
```

Will generate
This will result in the following sequence of visitor calls when an `IsNull`
expression is visited

```
visitor.pre_visit_expr(<is null expr>)
visitor.pre_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null expr>)
```

For some types it is only appropriate to call a particular visitor method in
some contexts. For example, not every `ObjectName` refers to a relation.

In these cases, the `visit` attribute can be used on the field for which we'd
like to call the method:

```rust
impl Visit for Bar {
#[derive(Visit, VisitMut)]
#[visit(with = "visit_table_factor")]
pub enum TableFactor {
Table {
#[visit(with = "visit_relation")]
name: ObjectName,
alias: Option<TableAlias>,
},
..
}
```

This will generate

```rust
impl Visit for TableFactor {
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
visitor.visit_expr(self)?;
visitor.pre_visit_table_factor(self)?;
match self {
Self::A() => {}
Self::B(_1, _2, _3) => {
_1.visit(visitor)?;
visitor.visit_relation(_3)?;
_2.visit(visitor)?;
_3.visit(visitor)?;
Self::Table { name, alias } => {
visitor.pre_visit_relation(name)?;
alias.visit(name)?;
visitor.post_visit_relation(name)?;
alias.visit(visitor)?;
}
}
visitor.post_visit_table_factor(self)?;
ControlFlow::Continue(())
}
}
```

Note that annotating both the type and the field is incorrect as it will result
in redundant calls to the method. For example

```rust
#[derive(Visit, VisitMut)]
#[visit(with = "visit_expr")]
enum Expr {
IsNull(#[visit(with = "visit_expr")] Box<Expr>),
..
}
```

will result in these calls to the visitor


```
visitor.pre_visit_expr(<is null expr>)
visitor.pre_visit_expr(<is null operand>)
visitor.pre_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null expr>)
```
3 changes: 1 addition & 2 deletions src/ast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::ast::*;
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
pub struct Query {
/// WITH (common table expressions, or CTEs)
pub with: Option<With>,
Expand Down Expand Up @@ -739,7 +740,6 @@ pub enum TableFactor {
/// For example `FROM monthly_sales PIVOT(sum(amount) FOR MONTH IN ('JAN', 'FEB'))`
/// See <https://docs.snowflake.com/en/sql-reference/constructs/pivot>
Pivot {
#[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))]
table: Box<TableFactor>,
aggregate_function: Expr, // Function expression
value_column: Vec<Ident>,
Expand All @@ -755,7 +755,6 @@ pub enum TableFactor {
///
/// See <https://docs.snowflake.com/en/sql-reference/constructs/unpivot>.
Unpivot {
#[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))]
table: Box<TableFactor>,
value: Ident,
name: Ident,
Expand Down
75 changes: 74 additions & 1 deletion src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.

use crate::ast::{Expr, ObjectName, Statement, TableFactor};
use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
use core::ops::ControlFlow;

/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
Expand Down Expand Up @@ -179,6 +179,16 @@ pub trait Visitor {
/// Type returned when the recursion returns early.
type Break;

/// Invoked for any queries that appear in the AST before visiting children
fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any queries that appear in the AST after visiting children
fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
Expand Down Expand Up @@ -267,6 +277,16 @@ pub trait VisitorMut {
/// Type returned when the recursion returns early.
type Break;

/// Invoked for any queries that appear in the AST before visiting children
fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any queries that appear in the AST after visiting children
fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
Expand Down Expand Up @@ -626,6 +646,18 @@ mod tests {
impl Visitor for TestVisitor {
type Break = ();

/// Invoked for any queries that appear in the AST before visiting children
fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
self.visited.push(format!("PRE: QUERY: {query}"));
ControlFlow::Continue(())
}

/// Invoked for any queries that appear in the AST after visiting children
fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
self.visited.push(format!("POST: QUERY: {query}"));
ControlFlow::Continue(())
}

fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
self.visited.push(format!("PRE: RELATION: {relation}"));
ControlFlow::Continue(())
Expand Down Expand Up @@ -695,17 +727,20 @@ mod tests {
"SELECT * from table_name as my_table",
vec![
"PRE: STATEMENT: SELECT * FROM table_name AS my_table",
"PRE: QUERY: SELECT * FROM table_name AS my_table",
"PRE: TABLE FACTOR: table_name AS my_table",
"PRE: RELATION: table_name",
"POST: RELATION: table_name",
"POST: TABLE FACTOR: table_name AS my_table",
"POST: QUERY: SELECT * FROM table_name AS my_table",
"POST: STATEMENT: SELECT * FROM table_name AS my_table",
],
),
(
"SELECT * from t1 join t2 on t1.id = t2.t1_id",
vec![
"PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
"PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
"PRE: TABLE FACTOR: t1",
"PRE: RELATION: t1",
"POST: RELATION: t1",
Expand All @@ -720,70 +755,108 @@ mod tests {
"PRE: EXPR: t2.t1_id",
"POST: EXPR: t2.t1_id",
"POST: EXPR: t1.id = t2.t1_id",
"POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
"POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
],
),
(
"SELECT * from t1 where EXISTS(SELECT column from t2)",
vec![
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"PRE: TABLE FACTOR: t1",
"PRE: RELATION: t1",
"POST: RELATION: t1",
"POST: TABLE FACTOR: t1",
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT column FROM t2",
"PRE: EXPR: column",
"POST: EXPR: column",
"PRE: TABLE FACTOR: t2",
"PRE: RELATION: t2",
"POST: RELATION: t2",
"POST: TABLE FACTOR: t2",
"POST: QUERY: SELECT column FROM t2",
"POST: EXPR: EXISTS (SELECT column FROM t2)",
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
],
),
(
"SELECT * from t1 where EXISTS(SELECT column from t2)",
vec![
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"PRE: TABLE FACTOR: t1",
"PRE: RELATION: t1",
"POST: RELATION: t1",
"POST: TABLE FACTOR: t1",
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT column FROM t2",
"PRE: EXPR: column",
"POST: EXPR: column",
"PRE: TABLE FACTOR: t2",
"PRE: RELATION: t2",
"POST: RELATION: t2",
"POST: TABLE FACTOR: t2",
"POST: QUERY: SELECT column FROM t2",
"POST: EXPR: EXISTS (SELECT column FROM t2)",
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
],
),
(
"SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
vec![
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
"PRE: TABLE FACTOR: t1",
"PRE: RELATION: t1",
"POST: RELATION: t1",
"POST: TABLE FACTOR: t1",
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT column FROM t2",
"PRE: EXPR: column",
"POST: EXPR: column",
"PRE: TABLE FACTOR: t2",
"PRE: RELATION: t2",
"POST: RELATION: t2",
"POST: TABLE FACTOR: t2",
"POST: QUERY: SELECT column FROM t2",
"POST: EXPR: EXISTS (SELECT column FROM t2)",
"PRE: TABLE FACTOR: t3",
"PRE: RELATION: t3",
"POST: RELATION: t3",
"POST: TABLE FACTOR: t3",
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
],
),
(
concat!(
"SELECT * FROM monthly_sales ",
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
"ORDER BY EMPID"
),
vec![
"PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
"PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
"PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
"PRE: TABLE FACTOR: monthly_sales",
"PRE: RELATION: monthly_sales",
"POST: RELATION: monthly_sales",
"POST: TABLE FACTOR: monthly_sales",
"PRE: EXPR: SUM(a.amount)",
"PRE: EXPR: a.amount",
"POST: EXPR: a.amount",
"POST: EXPR: SUM(a.amount)",
"POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
"PRE: EXPR: EMPID",
"POST: EXPR: EMPID",
"POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
"POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
]
)
];
for (sql, expected) in tests {
let actual = do_visit(sql);
Expand Down

0 comments on commit 86aa044

Please sign in to comment.