Skip to content

Commit

Permalink
Add a mutable visitor (#782)
Browse files Browse the repository at this point in the history
* Add a mutable visitor

This adds the ability to mutate parsed sql queries.
Previously, only visitors taking an immutable reference to the visited structures were allowed.

* add utility functions for mutable visits

* bump version numbers
  • Loading branch information
lovasoa committed Jan 2, 2023
1 parent 86d71f2 commit 524b8a7
Show file tree
Hide file tree
Showing 14 changed files with 428 additions and 150 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Expand Up @@ -33,7 +33,7 @@ serde = { version = "1.0", features = ["derive"], optional = true }
# of dev-dependencies because of
# https://github.com/rust-lang/cargo/issues/1596
serde_json = { version = "1.0", optional = true }
sqlparser_derive = { version = "0.1", path = "derive", optional = true }
sqlparser_derive = { version = "0.1.1", path = "derive", optional = true }

[dev-dependencies]
simple_logger = "4.0"
Expand Down
2 changes: 1 addition & 1 deletion derive/Cargo.toml
@@ -1,7 +1,7 @@
[package]
name = "sqlparser_derive"
description = "proc macro for sqlparser"
version = "0.1.0"
version = "0.1.1"
authors = ["sqlparser-rs authors"]
homepage = "https://github.com/sqlparser-rs/sqlparser-rs"
documentation = "https://docs.rs/sqlparser_derive/"
Expand Down
6 changes: 3 additions & 3 deletions derive/README.md
Expand Up @@ -6,13 +6,13 @@ This crate contains a procedural macro that can automatically derive
implementations of the `Visit` trait in the [sqlparser](https://crates.io/crates/sqlparser) crate

```rust
#[derive(Visit)]
#[derive(Visit, VisitMut)]
struct Foo {
boolean: bool,
bar: Bar,
}

#[derive(Visit)]
#[derive(Visit, VisitMut)]
enum Bar {
A(),
B(String, bool),
Expand Down Expand Up @@ -51,7 +51,7 @@ impl Visit for Bar {
Additionally certain types may wish to call a corresponding method on visitor before recursing

```rust
#[derive(Visit)]
#[derive(Visit, VisitMut)]
#[visit(with = "visit_expr")]
enum Expr {
A(),
Expand Down
65 changes: 49 additions & 16 deletions derive/src/lib.rs
Expand Up @@ -6,25 +6,58 @@ use syn::{
Ident, Index, Lit, Meta, MetaNameValue, NestedMeta,
};


/// Implementation of `[#derive(Visit)]`
#[proc_macro_derive(VisitMut, attributes(visit))]
pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
derive_visit(input, &VisitType {
visit_trait: quote!(VisitMut),
visitor_trait: quote!(VisitorMut),
modifier: Some(quote!(mut)),
})
}

/// Implementation of `[#derive(Visit)]`
#[proc_macro_derive(Visit, attributes(visit))]
pub fn derive_visit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
derive_visit(input, &VisitType {
visit_trait: quote!(Visit),
visitor_trait: quote!(Visitor),
modifier: None,
})
}

struct VisitType {
visit_trait: TokenStream,
visitor_trait: TokenStream,
modifier: Option<TokenStream>,
}

fn derive_visit(
input: proc_macro::TokenStream,
visit_type: &VisitType,
) -> proc_macro::TokenStream {
// Parse the input tokens into a syntax tree.
let input = parse_macro_input!(input as DeriveInput);
let name = input.ident;

let VisitType { visit_trait, visitor_trait, modifier } = visit_type;

let attributes = Attributes::parse(&input.attrs);
// Add a bound `T: HeapSize` to every type parameter T.
let generics = add_trait_bounds(input.generics);
// Add a bound `T: Visit` to every type parameter T.
let generics = add_trait_bounds(input.generics, visit_type);
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let (pre_visit, post_visit) = attributes.visit(quote!(self));
let children = visit_children(&input.data);
let children = visit_children(&input.data, visit_type);

let expanded = quote! {
// The generated impl.
impl #impl_generics sqlparser::ast::Visit for #name #ty_generics #where_clause {
fn visit<V: sqlparser::ast::Visitor>(&self, visitor: &mut V) -> ::std::ops::ControlFlow<V::Break> {
impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
fn visit<V: sqlparser::ast::#visitor_trait>(
&#modifier self,
visitor: &mut V
) -> ::std::ops::ControlFlow<V::Break> {
#pre_visit
#children
#post_visit
Expand Down Expand Up @@ -92,25 +125,25 @@ impl Attributes {
}

// Add a bound `T: Visit` to every type parameter T.
fn add_trait_bounds(mut generics: Generics) -> Generics {
fn add_trait_bounds(mut generics: Generics, VisitType{visit_trait, ..}: &VisitType) -> Generics {
for param in &mut generics.params {
if let GenericParam::Type(ref mut type_param) = *param {
type_param.bounds.push(parse_quote!(sqlparser::ast::Visit));
type_param.bounds.push(parse_quote!(sqlparser::ast::#visit_trait));
}
}
generics
}

// Generate the body of the visit implementation for the given type
fn visit_children(data: &Data) -> TokenStream {
fn visit_children(data: &Data, VisitType{visit_trait, modifier, ..}: &VisitType) -> TokenStream {
match data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => {
let recurse = fields.named.iter().map(|f| {
let name = &f.ident;
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#name, visitor)?; #post_visit)
let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit)
});
quote! {
#(#recurse)*
Expand All @@ -121,7 +154,7 @@ fn visit_children(data: &Data) -> TokenStream {
let index = Index::from(i);
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#index, visitor)?; #post_visit)
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit)
});
quote! {
#(#recurse)*
Expand All @@ -140,8 +173,8 @@ fn visit_children(data: &Data) -> TokenStream {
let visit = fields.named.iter().map(|f| {
let name = &f.ident;
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
});

quote!(
Expand All @@ -155,8 +188,8 @@ fn visit_children(data: &Data) -> TokenStream {
let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
let name = format_ident!("_{}", i);
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
});

quote! {
Expand Down
12 changes: 6 additions & 6 deletions src/ast/data_type.rs
Expand Up @@ -18,7 +18,7 @@ use core::fmt;
use serde::{Deserialize, Serialize};

#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};

use crate::ast::ObjectName;

Expand All @@ -27,7 +27,7 @@ use super::value::escape_single_quote_string;
/// SQL data types
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum DataType {
/// Fixed-length character type e.g. CHARACTER(10)
Character(Option<CharacterLength>),
Expand Down Expand Up @@ -341,7 +341,7 @@ fn format_datetime_precision_and_tz(
/// guarantee compatibility with the input query we must maintain its exact information.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TimezoneInfo {
/// No information about time zone. E.g., TIMESTAMP
None,
Expand Down Expand Up @@ -389,7 +389,7 @@ impl fmt::Display for TimezoneInfo {
/// [standard]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#exact-numeric-type
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ExactNumberInfo {
/// No additional information e.g. `DECIMAL`
None,
Expand Down Expand Up @@ -420,7 +420,7 @@ impl fmt::Display for ExactNumberInfo {
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#character-length
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CharacterLength {
/// Default (if VARYING) or maximum (if not VARYING) length
pub length: u64,
Expand All @@ -443,7 +443,7 @@ impl fmt::Display for CharacterLength {
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#char-length-units
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum CharLengthUnits {
/// CHARACTERS unit
Characters,
Expand Down
22 changes: 11 additions & 11 deletions src/ast/ddl.rs
Expand Up @@ -21,7 +21,7 @@ use core::fmt;
use serde::{Deserialize, Serialize};

#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};

use crate::ast::value::escape_single_quote_string;
use crate::ast::{display_comma_separated, display_separated, DataType, Expr, Ident, ObjectName};
Expand All @@ -30,7 +30,7 @@ use crate::tokenizer::Token;
/// An `ALTER TABLE` (`Statement::AlterTable`) operation
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum AlterTableOperation {
/// `ADD <table_constraint>`
AddConstraint(TableConstraint),
Expand Down Expand Up @@ -100,7 +100,7 @@ pub enum AlterTableOperation {

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum AlterIndexOperation {
RenameIndex { index_name: ObjectName },
}
Expand Down Expand Up @@ -224,7 +224,7 @@ impl fmt::Display for AlterIndexOperation {
/// An `ALTER COLUMN` (`Statement::AlterTable`) operation
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum AlterColumnOperation {
/// `SET NOT NULL`
SetNotNull,
Expand Down Expand Up @@ -268,7 +268,7 @@ impl fmt::Display for AlterColumnOperation {
/// `ALTER TABLE ADD <constraint>` statement.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TableConstraint {
/// `[ CONSTRAINT <name> ] { PRIMARY KEY | UNIQUE } (<columns>)`
Unique {
Expand Down Expand Up @@ -433,7 +433,7 @@ impl fmt::Display for TableConstraint {
/// [1]: https://dev.mysql.com/doc/refman/8.0/en/create-table.html
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum KeyOrIndexDisplay {
/// Nothing to display
None,
Expand Down Expand Up @@ -469,7 +469,7 @@ impl fmt::Display for KeyOrIndexDisplay {
/// [3]: https://www.postgresql.org/docs/14/sql-createindex.html
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum IndexType {
BTree,
Hash,
Expand All @@ -488,7 +488,7 @@ impl fmt::Display for IndexType {
/// SQL column definition
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ColumnDef {
pub name: Ident,
pub data_type: DataType,
Expand Down Expand Up @@ -524,7 +524,7 @@ impl fmt::Display for ColumnDef {
/// "column options," and we allow any column option to be named.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ColumnOptionDef {
pub name: Option<Ident>,
pub option: ColumnOption,
Expand All @@ -540,7 +540,7 @@ impl fmt::Display for ColumnOptionDef {
/// TABLE` statement.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ColumnOption {
/// `NULL`
Null,
Expand Down Expand Up @@ -630,7 +630,7 @@ fn display_constraint_name(name: &'_ Option<Ident>) -> impl fmt::Display + '_ {
/// Used in foreign key constraints in `ON UPDATE` and `ON DELETE` options.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ReferentialAction {
Restrict,
Cascade,
Expand Down
4 changes: 2 additions & 2 deletions src/ast/helpers/stmt_create_table.rs
Expand Up @@ -5,7 +5,7 @@ use alloc::{boxed::Box, format, string::String, vec, vec::Vec};
use serde::{Deserialize, Serialize};

#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};

use crate::ast::{
ColumnDef, FileFormat, HiveDistributionStyle, HiveFormat, ObjectName, OnCommit, Query,
Expand Down Expand Up @@ -43,7 +43,7 @@ use crate::parser::ParserError;
/// [1]: crate::ast::Statement::CreateTable
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CreateTableBuilder {
pub or_replace: bool,
pub temporary: bool,
Expand Down

0 comments on commit 524b8a7

Please sign in to comment.