Skip to content

Commit

Permalink
visit_query
Browse files Browse the repository at this point in the history
  • Loading branch information
jmhain committed Nov 9, 2023
1 parent 4cdaa40 commit 23e42b8
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 17 deletions.
47 changes: 38 additions & 9 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{
parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
Ident, Index, Lit, Meta, MetaNameValue, NestedMeta,
Ident, Index, Lit, Meta, MetaNameValue, NestedMeta, Type
};


Expand Down Expand Up @@ -48,7 +48,7 @@ fn derive_visit(
let generics = add_trait_bounds(input.generics, visit_type);
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let (pre_visit, post_visit) = attributes.visit(quote!(self));
let (pre_visit, post_visit) = attributes.visit(quote!(self), false);
let children = visit_children(&input.data, visit_type);

let expanded = quote! {
Expand Down Expand Up @@ -111,19 +111,48 @@ impl Attributes {
}

/// Returns the pre and post visit token streams
fn visit(&self, s: TokenStream) -> (Option<TokenStream>, Option<TokenStream>) {
fn visit(&self, s: TokenStream, is_option: bool) -> (Option<TokenStream>, Option<TokenStream>) {
let pre_visit = self.with.as_ref().map(|m| {
let m = format_ident!("pre_{}", m);
quote!(visitor.#m(#s)?;)
if is_option {
quote! {
if let Some(f) = #s {
visitor.#m(f)?;
}
}
} else {
quote!(visitor.#m(#s)?;)
}
});
let post_visit = self.with.as_ref().map(|m| {
let m = format_ident!("post_{}", m);
quote!(visitor.#m(#s)?;)
if is_option {
quote! {
if let Some(f) = #s {
visitor.#m(f)?;
}
}
} else {
quote!(visitor.#m(#s)?;)
}
});
(pre_visit, post_visit)
}
}

fn is_option(mut ty: &Type) -> bool {
while let Type::Group(group) = ty {
ty = &group.elem;
}
let Type::Path(ty) = &ty else {
return false;
};
let Some(seg) = ty.path.segments.last() else {
return false;
};
seg.ident == "Option"
}

// Add a bound `T: Visit` to every type parameter T.
fn add_trait_bounds(mut generics: Generics, VisitType{visit_trait, ..}: &VisitType) -> Generics {
for param in &mut generics.params {
Expand All @@ -142,7 +171,7 @@ fn visit_children(data: &Data, VisitType{visit_trait, modifier, ..}: &VisitType)
let recurse = fields.named.iter().map(|f| {
let name = &f.ident;
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name), is_option(&f.ty));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit)
});
quote! {
Expand All @@ -153,7 +182,7 @@ fn visit_children(data: &Data, VisitType{visit_trait, modifier, ..}: &VisitType)
let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
let index = Index::from(i);
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index), is_option(&f.ty));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit)
});
quote! {
Expand All @@ -173,7 +202,7 @@ fn visit_children(data: &Data, VisitType{visit_trait, modifier, ..}: &VisitType)
let visit = fields.named.iter().map(|f| {
let name = &f.ident;
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream(), is_option(&f.ty));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
});

Expand All @@ -188,7 +217,7 @@ fn visit_children(data: &Data, VisitType{visit_trait, modifier, ..}: &VisitType)
let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
let name = format_ident!("_{}", i);
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream(), is_option(&f.ty));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
});

Expand Down
24 changes: 18 additions & 6 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ pub enum Expr {
/// `[ NOT ] IN (SELECT ...)`
InSubquery {
expr: Box<Expr>,
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
subquery: Box<Query>,
negated: bool,
},
Expand Down Expand Up @@ -600,12 +601,16 @@ pub enum Expr {
},
/// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like
/// `WHERE [ NOT ] EXISTS (SELECT ...)`.
Exists { subquery: Box<Query>, negated: bool },
Exists {
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
subquery: Box<Query>,
negated: bool,
},
/// A parenthesized subquery `(SELECT ...)`, used in expression like
/// `SELECT (subquery) AS x` or `WHERE (subquery) = x`
Subquery(Box<Query>),
Subquery(#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] Box<Query>),
/// An array subquery constructor, e.g. `SELECT ARRAY(SELECT 1 UNION SELECT 2)`
ArraySubquery(Box<Query>),
ArraySubquery(#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] Box<Query>),
/// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)`
ListAgg(ListAgg),
/// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)`
Expand Down Expand Up @@ -1368,7 +1373,7 @@ pub enum Statement {
partition_action: Option<AddDropSync>,
},
/// SELECT
Query(Box<Query>),
Query(#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] Box<Query>),
/// INSERT
Insert {
/// Only for Sqlite
Expand All @@ -1385,6 +1390,7 @@ pub enum Statement {
/// Overwrite (Hive)
overwrite: bool,
/// A SQL query that specifies what to insert
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
source: Box<Query>,
/// partitioned insert (Hive)
partitioned: Option<Vec<Expr>>,
Expand All @@ -1402,6 +1408,7 @@ pub enum Statement {
local: bool,
path: String,
file_format: Option<FileFormat>,
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
source: Box<Query>,
},
Copy {
Expand Down Expand Up @@ -1480,6 +1487,7 @@ pub enum Statement {
/// View name
name: ObjectName,
columns: Vec<Ident>,
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
query: Box<Query>,
with_options: Vec<SqlOption>,
cluster_by: Vec<Ident>,
Expand Down Expand Up @@ -1510,6 +1518,7 @@ pub enum Statement {
with_options: Vec<SqlOption>,
file_format: Option<FileFormat>,
location: Option<String>,
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
query: Option<Box<Query>>,
without_rowid: bool,
like: Option<ObjectName>,
Expand Down Expand Up @@ -1598,6 +1607,7 @@ pub enum Statement {
#[cfg_attr(feature = "visitor", visit(with = "visit_relation"))]
name: ObjectName,
columns: Vec<Ident>,
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
query: Box<Query>,
with_options: Vec<SqlOption>,
},
Expand Down Expand Up @@ -1665,6 +1675,7 @@ pub enum Statement {
/// Some(true) = WITH HOLD, specifies that the cursor can continue to be used after the transaction that created it successfully commits
/// Some(false) = WITHOUT HOLD, specifies that the cursor cannot be used outside of the transaction that created it
hold: Option<bool>,
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
query: Box<Query>,
},
/// FETCH - retrieve rows from a query using a cursor
Expand Down Expand Up @@ -1969,6 +1980,7 @@ pub enum Statement {
/// Table confs
options: Vec<SqlOption>,
/// Cache table as a Query
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
query: Option<Query>,
},
/// UNCACHE TABLE [ IF EXISTS ] <table_name>
Expand Down Expand Up @@ -4278,7 +4290,7 @@ pub enum CopySource {
/// are copied.
columns: Vec<Ident>,
},
Query(Box<Query>),
Query(#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] Box<Query>),
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
Expand Down Expand Up @@ -4795,7 +4807,7 @@ impl fmt::Display for MacroArg {
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum MacroDefinition {
Expr(Expr),
Table(Query),
Table(#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] Query),
}

impl fmt::Display for MacroDefinition {
Expand Down
5 changes: 4 additions & 1 deletion src/ast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::ast::*;
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
pub struct Query {
/// WITH (common table expressions, or CTEs)
pub with: Option<With>,
Expand Down Expand Up @@ -86,7 +87,7 @@ pub enum SetExpr {
Select(Box<Select>),
/// Parenthesized SELECT subquery, which may include more set operations
/// in its body and an optional ORDER BY / LIMIT.
Query(Box<Query>),
Query(#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] Box<Query>),
/// UNION/EXCEPT/INTERSECT of two queries
SetOperation {
op: SetOperator,
Expand Down Expand Up @@ -377,6 +378,7 @@ impl fmt::Display for With {
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Cte {
pub alias: TableAlias,
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
pub query: Box<Query>,
pub from: Option<Ident>,
}
Expand Down Expand Up @@ -687,6 +689,7 @@ pub enum TableFactor {
},
Derived {
lateral: bool,
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
subquery: Box<Query>,
alias: Option<TableAlias>,
},
Expand Down
22 changes: 21 additions & 1 deletion src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.

use crate::ast::{Expr, ObjectName, Statement, TableFactor};
use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
use core::ops::ControlFlow;

/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
Expand Down Expand Up @@ -179,6 +179,16 @@ pub trait Visitor {
/// Type returned when the recursion returns early.
type Break;

/// Invoked for any queries that appear in the AST before visiting children
fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any queries that appear in the AST after visiting children
fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
Expand Down Expand Up @@ -267,6 +277,16 @@ pub trait VisitorMut {
/// Type returned when the recursion returns early.
type Break;

/// Invoked for any queries that appear in the AST before visiting children
fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any queries that appear in the AST after visiting children
fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
Expand Down

0 comments on commit 23e42b8

Please sign in to comment.