Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add {pre,post}_visit_query to Visitor #1044

Merged
merged 4 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
77 changes: 65 additions & 12 deletions derive/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,33 +48,86 @@ impl Visit for Bar {
}
```

Additionally certain types may wish to call a corresponding method on visitor before recursing
Some types may wish to call a corresponding method on the visitor:

```rust
#[derive(Visit, VisitMut)]
#[visit(with = "visit_expr")]
enum Expr {
A(),
B(String, #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] ObjectName, bool),
IsNull(Box<Expr>),
..
}
```

Will generate
This will result in the following sequence of visitor calls when an `IsNull`
expression is visited

```
visitor.pre_visit_expr(<is null expr>)
visitor.pre_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null expr>)
```

For some types it is only appropriate to call a particular visitor method in
some contexts. For example, not every `ObjectName` refers to a relation.

In these cases, the `visit` attribute can be used on the field for which we'd
like to call the method:

```rust
impl Visit for Bar {
#[derive(Visit, VisitMut)]
#[visit(with = "visit_table_factor")]
pub enum TableFactor {
Table {
#[visit(with = "visit_relation")]
name: ObjectName,
alias: Option<TableAlias>,
},
..
}
```

This will generate

```rust
impl Visit for TableFactor {
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
visitor.visit_expr(self)?;
visitor.pre_visit_table_factor(self)?;
match self {
Self::A() => {}
Self::B(_1, _2, _3) => {
_1.visit(visitor)?;
visitor.visit_relation(_3)?;
_2.visit(visitor)?;
_3.visit(visitor)?;
Self::Table { name, alias } => {
visitor.pre_visit_relation(name)?;
alias.visit(name)?;
visitor.post_visit_relation(name)?;
alias.visit(visitor)?;
}
}
visitor.post_visit_table_factor(self)?;
ControlFlow::Continue(())
}
}
```

Note that annotating both the type and the field is incorrect as it will result
in redundant calls to the method. For example

```rust
#[derive(Visit, VisitMut)]
#[visit(with = "visit_expr")]
enum Expr {
IsNull(#[visit(with = "visit_expr")] Box<Expr>),
..
}
```

will result in these calls to the visitor


```
visitor.pre_visit_expr(<is null expr>)
visitor.pre_visit_expr(<is null operand>)
visitor.pre_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null operand>)
visitor.post_visit_expr(<is null expr>)
```
3 changes: 1 addition & 2 deletions src/ast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::ast::*;
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
pub struct Query {
/// WITH (common table expressions, or CTEs)
pub with: Option<With>,
Expand Down Expand Up @@ -739,7 +740,6 @@ pub enum TableFactor {
/// For example `FROM monthly_sales PIVOT(sum(amount) FOR MONTH IN ('JAN', 'FEB'))`
/// See <https://docs.snowflake.com/en/sql-reference/constructs/pivot>
Pivot {
#[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why were these removed?

Copy link
Contributor Author

@jmhain jmhain Nov 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TableFactor is already annotated with this, so it results in two calls to both the pre_ and post_visit_table_factor for the same table factor, which besides being unnecessary is likely to cause issues in user's code if their visitor doesn't happen to be idempotent.

As an example, this is the output of one of the new tests I added without this change:

 [
     "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
     "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
     "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
     "PRE: TABLE FACTOR: monthly_sales",
     "PRE: TABLE FACTOR: monthly_sales",
     "PRE: RELATION: monthly_sales",
     "POST: RELATION: monthly_sales",
     "POST: TABLE FACTOR: monthly_sales",
     "POST: TABLE FACTOR: monthly_sales",
     "PRE: EXPR: SUM(a.amount)",
     "PRE: EXPR: a.amount",
     "POST: EXPR: a.amount",
     "POST: EXPR: SUM(a.amount)",
     "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
     "PRE: EXPR: EMPID",
     "POST: EXPR: EMPID",
     "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
     "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
 ]

Note the duplicate calls for TABLE FACTOR: monthly_sales".

I added a section in the README.md for the derive crate that explains this in more detail: https://github.com/jmhain/sqlparser-rs/blob/visit_query/derive/README.md

I can split this into a separate PR if you'd like since it's separate from adding {pre,post}_visit_query, I just discovered it now because I almost introduced another instance of this same bug here.

table: Box<TableFactor>,
aggregate_function: Expr, // Function expression
value_column: Vec<Ident>,
Expand All @@ -755,7 +755,6 @@ pub enum TableFactor {
///
/// See <https://docs.snowflake.com/en/sql-reference/constructs/unpivot>.
Unpivot {
#[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))]
table: Box<TableFactor>,
value: Ident,
name: Ident,
Expand Down
75 changes: 74 additions & 1 deletion src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.

use crate::ast::{Expr, ObjectName, Statement, TableFactor};
use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
use core::ops::ControlFlow;

/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
Expand Down Expand Up @@ -179,6 +179,16 @@ pub trait Visitor {
/// Type returned when the recursion returns early.
type Break;

/// Invoked for any queries that appear in the AST before visiting children
fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any queries that appear in the AST after visiting children
fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
Expand Down Expand Up @@ -267,6 +277,16 @@ pub trait VisitorMut {
/// Type returned when the recursion returns early.
type Break;

/// Invoked for any queries that appear in the AST before visiting children
fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any queries that appear in the AST after visiting children
fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
Expand Down Expand Up @@ -626,6 +646,18 @@ mod tests {
impl Visitor for TestVisitor {
type Break = ();

/// Invoked for any queries that appear in the AST before visiting children
fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
self.visited.push(format!("PRE: QUERY: {query}"));
ControlFlow::Continue(())
}

/// Invoked for any queries that appear in the AST after visiting children
fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
self.visited.push(format!("POST: QUERY: {query}"));
ControlFlow::Continue(())
}

fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
self.visited.push(format!("PRE: RELATION: {relation}"));
ControlFlow::Continue(())
Expand Down Expand Up @@ -695,17 +727,20 @@ mod tests {
"SELECT * from table_name as my_table",
vec![
"PRE: STATEMENT: SELECT * FROM table_name AS my_table",
"PRE: QUERY: SELECT * FROM table_name AS my_table",
"PRE: TABLE FACTOR: table_name AS my_table",
"PRE: RELATION: table_name",
"POST: RELATION: table_name",
"POST: TABLE FACTOR: table_name AS my_table",
"POST: QUERY: SELECT * FROM table_name AS my_table",
"POST: STATEMENT: SELECT * FROM table_name AS my_table",
],
),
(
"SELECT * from t1 join t2 on t1.id = t2.t1_id",
vec![
"PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
"PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
"PRE: TABLE FACTOR: t1",
"PRE: RELATION: t1",
"POST: RELATION: t1",
Expand All @@ -720,70 +755,108 @@ mod tests {
"PRE: EXPR: t2.t1_id",
"POST: EXPR: t2.t1_id",
"POST: EXPR: t1.id = t2.t1_id",
"POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
"POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
],
),
(
"SELECT * from t1 where EXISTS(SELECT column from t2)",
vec![
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"PRE: TABLE FACTOR: t1",
"PRE: RELATION: t1",
"POST: RELATION: t1",
"POST: TABLE FACTOR: t1",
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT column FROM t2",
"PRE: EXPR: column",
"POST: EXPR: column",
"PRE: TABLE FACTOR: t2",
"PRE: RELATION: t2",
"POST: RELATION: t2",
"POST: TABLE FACTOR: t2",
"POST: QUERY: SELECT column FROM t2",
"POST: EXPR: EXISTS (SELECT column FROM t2)",
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
],
),
(
"SELECT * from t1 where EXISTS(SELECT column from t2)",
vec![
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"PRE: TABLE FACTOR: t1",
"PRE: RELATION: t1",
"POST: RELATION: t1",
"POST: TABLE FACTOR: t1",
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT column FROM t2",
"PRE: EXPR: column",
"POST: EXPR: column",
"PRE: TABLE FACTOR: t2",
"PRE: RELATION: t2",
"POST: RELATION: t2",
"POST: TABLE FACTOR: t2",
"POST: QUERY: SELECT column FROM t2",
"POST: EXPR: EXISTS (SELECT column FROM t2)",
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
],
),
(
"SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
vec![
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
"PRE: TABLE FACTOR: t1",
"PRE: RELATION: t1",
"POST: RELATION: t1",
"POST: TABLE FACTOR: t1",
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
"PRE: QUERY: SELECT column FROM t2",
"PRE: EXPR: column",
"POST: EXPR: column",
"PRE: TABLE FACTOR: t2",
"PRE: RELATION: t2",
"POST: RELATION: t2",
"POST: TABLE FACTOR: t2",
"POST: QUERY: SELECT column FROM t2",
"POST: EXPR: EXISTS (SELECT column FROM t2)",
"PRE: TABLE FACTOR: t3",
"PRE: RELATION: t3",
"POST: RELATION: t3",
"POST: TABLE FACTOR: t3",
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
],
),
(
concat!(
"SELECT * FROM monthly_sales ",
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
"ORDER BY EMPID"
),
vec![
"PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
"PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
"PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
"PRE: TABLE FACTOR: monthly_sales",
"PRE: RELATION: monthly_sales",
"POST: RELATION: monthly_sales",
"POST: TABLE FACTOR: monthly_sales",
"PRE: EXPR: SUM(a.amount)",
"PRE: EXPR: a.amount",
"POST: EXPR: a.amount",
"POST: EXPR: SUM(a.amount)",
"POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
"PRE: EXPR: EMPID",
"POST: EXPR: EMPID",
"POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
"POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
]
)
];
for (sql, expected) in tests {
let actual = do_visit(sql);
Expand Down