diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 146ebf349..46e8dda2c 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -67,21 +67,29 @@ pub trait Dialect: Debug + Any { fn is_identifier_start(&self, ch: char) -> bool; /// Determine if a character is a valid unquoted identifier character fn is_identifier_part(&self, ch: char) -> bool; - /// Custom prefix parser + /// Dialect-specific prefix parser override fn parse_prefix(&self, _parser: &mut Parser) -> Option> { + // return None to fall back to the default behavior None } - /// Custom infix parser + /// Dialect-specific infix parser override fn parse_infix( &self, _parser: &mut Parser, _expr: &Expr, _precendence: u8, ) -> Option> { + // return None to fall back to the default behavior None } - /// Custom statement parser + /// Dialect-specific precedence override + fn get_next_precedence(&self, _parser: &Parser) -> Option> { + // return None to fall back to the default behavior + None + } + /// Dialect-specific statement parser override fn parse_statement(&self, _parser: &mut Parser) -> Option> { + // return None to fall back to the default behavior None } } diff --git a/src/parser.rs b/src/parser.rs index a82c70400..5f1d2f145 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1485,6 +1485,11 @@ impl<'a> Parser<'a> { /// Get the precedence of the next token pub fn get_next_precedence(&self) -> Result { + // allow the dialect to override precedence logic + if let Some(precedence) = self.dialect.get_next_precedence(self) { + return precedence; + } + let token = self.peek_token(); debug!("get_next_precedence() {:?}", token); let token_0 = self.peek_nth_token(0); @@ -4905,4 +4910,52 @@ mod tests { assert_eq!(ast.to_string(), sql.to_string()); }); } + + #[test] + fn custom_infix_parser() -> Result<(), ParserError> { + #[derive(Debug)] + struct MyDialect {} + + impl Dialect for MyDialect { + fn is_identifier_start(&self, ch: char) -> bool { + // See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS + // We don't yet support identifiers beginning with "letters with + // diacritical marks and non-Latin letters" + ('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) || ch == '_' + } + + fn is_identifier_part(&self, ch: char) -> bool { + ('a'..='z').contains(&ch) + || ('A'..='Z').contains(&ch) + || ('0'..='9').contains(&ch) + || ch == '$' + || ch == '_' + } + + fn parse_infix( + &self, + parser: &mut Parser, + expr: &Expr, + _precendence: u8, + ) -> Option> { + if parser.peek_token() == Token::Plus { + assert!(parser.consume_token(&Token::Plus)); + Some(Ok(Expr::BinaryOp { + left: Box::new(expr.clone()), + op: BinaryOperator::Multiply, // translate Plus to Multiply + right: Box::new(parser.parse_expr().unwrap()), + })) + } else { + None + } + } + } + + let dialect = MyDialect {}; + let sql = "SELECT 1 + 2"; + let ast = Parser::parse_sql(&dialect, sql)?; + let query = &ast[0]; + assert_eq!("SELECT 1 * 2", &format!("{}", query)); + Ok(()) + } }