Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Aug 17, 2022
1 parent f518c30 commit 8640d1b
Showing 1 changed file with 83 additions and 11 deletions.
94 changes: 83 additions & 11 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4911,25 +4911,49 @@ mod tests {
});
}

#[test]
fn custom_prefix_parser() -> Result<(), ParserError> {
#[derive(Debug)]
struct MyDialect {}

impl Dialect for MyDialect {
fn is_identifier_start(&self, ch: char) -> bool {
is_identifier_start(ch)
}

fn is_identifier_part(&self, ch: char) -> bool {
is_identifier_part(ch)
}

fn parse_prefix(&self, parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
if parser.consume_token(&Token::Number("1".to_string(), false)) {
Some(Ok(Expr::Value(Value::Null)))
} else {
None
}
}
}

let dialect = MyDialect {};
let sql = "SELECT 1 + 2";
let ast = Parser::parse_sql(&dialect, sql)?;
let query = &ast[0];
assert_eq!("SELECT NULL + 2", &format!("{}", query));
Ok(())
}

#[test]
fn custom_infix_parser() -> Result<(), ParserError> {
#[derive(Debug)]
struct MyDialect {}

impl Dialect for MyDialect {
fn is_identifier_start(&self, ch: char) -> bool {
// See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
// We don't yet support identifiers beginning with "letters with
// diacritical marks and non-Latin letters"
('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) || ch == '_'
is_identifier_start(ch)
}

fn is_identifier_part(&self, ch: char) -> bool {
('a'..='z').contains(&ch)
|| ('A'..='Z').contains(&ch)
|| ('0'..='9').contains(&ch)
|| ch == '$'
|| ch == '_'
is_identifier_part(ch)
}

fn parse_infix(
Expand All @@ -4938,8 +4962,7 @@ mod tests {
expr: &Expr,
_precendence: u8,
) -> Option<Result<Expr, ParserError>> {
if parser.peek_token() == Token::Plus {
assert!(parser.consume_token(&Token::Plus));
if parser.consume_token(&Token::Plus) {
Some(Ok(Expr::BinaryOp {
left: Box::new(expr.clone()),
op: BinaryOperator::Multiply, // translate Plus to Multiply
Expand All @@ -4958,4 +4981,53 @@ mod tests {
assert_eq!("SELECT 1 * 2", &format!("{}", query));
Ok(())
}

#[test]
fn custom_statement_parser() -> Result<(), ParserError> {
#[derive(Debug)]
struct MyDialect {}

impl Dialect for MyDialect {
fn is_identifier_start(&self, ch: char) -> bool {
is_identifier_start(ch)
}

fn is_identifier_part(&self, ch: char) -> bool {
is_identifier_part(ch)
}

fn parse_statement(
&self,
parser: &mut Parser,
) -> Option<Result<Statement, ParserError>> {
if parser.parse_keyword(Keyword::SELECT) {
for _ in 0..3 {
let _ = parser.next_token();
}
Some(Ok(Statement::Commit { chain: false }))
} else {
None
}
}
}

let dialect = MyDialect {};
let sql = "SELECT 1 + 2";
let ast = Parser::parse_sql(&dialect, sql)?;
let query = &ast[0];
assert_eq!("COMMIT", &format!("{}", query));
Ok(())
}

fn is_identifier_start(ch: char) -> bool {
('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) || ch == '_'
}

fn is_identifier_part(ch: char) -> bool {
('a'..='z').contains(&ch)
|| ('A'..='Z').contains(&ch)
|| ('0'..='9').contains(&ch)
|| ch == '$'
|| ch == '_'
}
}

0 comments on commit 8640d1b

Please sign in to comment.