diff --git a/Cargo.toml b/Cargo.toml index 00b4d1a1c..05f3eba23 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ path = "src/lib.rs" [features] default = ["std"] -std = [] +std = ["scopeguard"] # Enable JSON output in the `cli` example: json_example = ["serde_json", "serde"] @@ -32,6 +32,7 @@ serde = { version = "1.0", features = ["derive"], optional = true } # of dev-dependencies because of # https://github.com/rust-lang/cargo/issues/1596 serde_json = { version = "1.0", optional = true } +scopeguard = { version = "1.1.0", optional = true } [dev-dependencies] simple_logger = "2.1" diff --git a/src/parser.rs b/src/parser.rs index 59198fa66..9a933f6d4 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -21,8 +21,14 @@ use alloc::{ vec::Vec, }; use core::fmt; +#[cfg(feature = "std")] +use std::cell::Cell; +#[cfg(feature = "std")] +use std::rc::Rc; use log::debug; +#[cfg(feature = "std")] +use scopeguard::defer; use crate::ast::*; use crate::dialect::*; @@ -33,6 +39,7 @@ use crate::tokenizer::*; pub enum ParserError { TokenizerError(String), ParserError(String), + RecursionLimitExceeded, } // Use `Parser::expected` instead, if possible @@ -51,6 +58,25 @@ macro_rules! return_ok_if_some { }}; } +#[cfg(feature = "std")] +macro_rules! check_recursion_depth { + ($this:ident) => { + let remaining_depth = $this.remaining_depth.clone(); + remaining_depth.set(remaining_depth.get().saturating_sub(1)); + if remaining_depth.get() == 0 { + return Err(ParserError::RecursionLimitExceeded); + } + defer! { + remaining_depth.set(remaining_depth.get() + 1); + } + }; +} + +#[cfg(not(feature = "std"))] +macro_rules! check_recursion_depth { + ($this:ident) => {}; +} + #[derive(PartialEq)] pub enum IsOptional { Optional, @@ -96,6 +122,7 @@ impl fmt::Display for ParserError { match self { ParserError::TokenizerError(s) => s, ParserError::ParserError(s) => s, + ParserError::RecursionLimitExceeded => "recursion limit exceeded", } ) } @@ -109,6 +136,8 @@ pub struct Parser<'a> { /// The index of the first unprocessed token in `self.tokens` index: usize, dialect: &'a dyn Dialect, + #[cfg(feature = "std")] + remaining_depth: Rc>, } impl<'a> Parser<'a> { @@ -118,6 +147,8 @@ impl<'a> Parser<'a> { tokens, index: 0, dialect, + #[cfg(feature = "std")] + remaining_depth: Rc::new(Cell::new(96)), } } @@ -152,6 +183,8 @@ impl<'a> Parser<'a> { /// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.), /// stopping before the statement separator, if any. pub fn parse_statement(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => match w.keyword { Keyword::KILL => Ok(self.parse_kill()?), @@ -211,6 +244,8 @@ impl<'a> Parser<'a> { } pub fn parse_msck(&mut self) -> Result { + check_recursion_depth!(self); + let repair = self.parse_keyword(Keyword::REPAIR); self.expect_keyword(Keyword::TABLE)?; let table_name = self.parse_object_name()?; @@ -238,6 +273,8 @@ impl<'a> Parser<'a> { } pub fn parse_truncate(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::TABLE)?; let table_name = self.parse_object_name()?; let mut partitions = None; @@ -253,6 +290,8 @@ impl<'a> Parser<'a> { } pub fn parse_analyze(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::TABLE)?; let table_name = self.parse_object_name()?; let mut for_columns = false; @@ -310,6 +349,8 @@ impl<'a> Parser<'a> { /// Parse a new expression including wildcard & qualified wildcard pub fn parse_wildcard_expr(&mut self) -> Result { + check_recursion_depth!(self); + let index = self.index; match self.next_token() { @@ -340,11 +381,15 @@ impl<'a> Parser<'a> { /// Parse a new expression pub fn parse_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.parse_subexpr(0) } /// Parse tokens until the precedence changes pub fn parse_subexpr(&mut self, precedence: u8) -> Result { + check_recursion_depth!(self); + debug!("parsing expr"); let mut expr = self.parse_prefix()?; debug!("prefix: {:?}", expr); @@ -362,6 +407,8 @@ impl<'a> Parser<'a> { } pub fn parse_assert(&mut self) -> Result { + check_recursion_depth!(self); + let condition = self.parse_expr()?; let message = if self.parse_keyword(Keyword::AS) { Some(self.parse_expr()?) @@ -373,12 +420,16 @@ impl<'a> Parser<'a> { } pub fn parse_savepoint(&mut self) -> Result { + check_recursion_depth!(self); + let name = self.parse_identifier()?; Ok(Statement::Savepoint { name }) } /// Parse an expression prefix pub fn parse_prefix(&mut self) -> Result { + check_recursion_depth!(self); + // PostgreSQL allows any string literal to be preceded by a type name, indicating that the // string literal represents a literal of that type. Some examples: // @@ -558,6 +609,7 @@ impl<'a> Parser<'a> { } pub fn parse_function(&mut self, name: ObjectName) -> Result { + check_recursion_depth!(self); self.expect_token(&Token::LParen)?; let distinct = self.parse_all_or_distinct()?; let args = self.parse_optional_args()?; @@ -601,6 +653,8 @@ impl<'a> Parser<'a> { } pub fn parse_time_functions(&mut self, name: ObjectName) -> Result { + check_recursion_depth!(self); + let args = if self.consume_token(&Token::LParen) { self.parse_optional_args()? } else { @@ -615,6 +669,8 @@ impl<'a> Parser<'a> { } pub fn parse_window_frame_units(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => match w.keyword { Keyword::ROWS => Ok(WindowFrameUnits::Rows), @@ -627,6 +683,8 @@ impl<'a> Parser<'a> { } pub fn parse_window_frame(&mut self) -> Result { + check_recursion_depth!(self); + let units = self.parse_window_frame_units()?; let (start_bound, end_bound) = if self.parse_keyword(Keyword::BETWEEN) { let start_bound = self.parse_window_frame_bound()?; @@ -645,6 +703,8 @@ impl<'a> Parser<'a> { /// Parse `CURRENT ROW` or `{ | UNBOUNDED } { PRECEDING | FOLLOWING }` pub fn parse_window_frame_bound(&mut self) -> Result { + check_recursion_depth!(self); + if self.parse_keywords(&[Keyword::CURRENT, Keyword::ROW]) { Ok(WindowFrameBound::CurrentRow) } else { @@ -666,6 +726,8 @@ impl<'a> Parser<'a> { /// parse a group by expr. a group by expr can be one of group sets, roll up, cube, or simple /// expr. fn parse_group_by_expr(&mut self) -> Result { + check_recursion_depth!(self); + if dialect_of!(self is PostgreSqlDialect) { if self.parse_keywords(&[Keyword::GROUPING, Keyword::SETS]) { self.expect_token(&Token::LParen)?; @@ -699,6 +761,8 @@ impl<'a> Parser<'a> { lift_singleton: bool, allow_empty: bool, ) -> Result, ParserError> { + check_recursion_depth!(self); + if lift_singleton { if self.consume_token(&Token::LParen) { let result = if allow_empty && self.consume_token(&Token::RParen) { @@ -726,6 +790,8 @@ impl<'a> Parser<'a> { } pub fn parse_case_expr(&mut self) -> Result { + check_recursion_depth!(self); + let mut operand = None; if !self.parse_keyword(Keyword::WHEN) { operand = Some(Box::new(self.parse_expr()?)); @@ -757,6 +823,8 @@ impl<'a> Parser<'a> { /// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)` pub fn parse_cast_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let expr = self.parse_expr()?; self.expect_keyword(Keyword::AS)?; @@ -770,6 +838,8 @@ impl<'a> Parser<'a> { /// Parse a SQL TRY_CAST function e.g. `TRY_CAST(expr AS FLOAT)` pub fn parse_try_cast_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let expr = self.parse_expr()?; self.expect_keyword(Keyword::AS)?; @@ -783,6 +853,8 @@ impl<'a> Parser<'a> { /// Parse a SQL EXISTS expression e.g. `WHERE EXISTS(SELECT ...)`. pub fn parse_exists_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let exists_node = Expr::Exists(Box::new(self.parse_query()?)); self.expect_token(&Token::RParen)?; @@ -790,6 +862,8 @@ impl<'a> Parser<'a> { } pub fn parse_extract_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let field = self.parse_date_time_field()?; self.expect_keyword(Keyword::FROM)?; @@ -820,6 +894,8 @@ impl<'a> Parser<'a> { } pub fn parse_substring_expr(&mut self) -> Result { + check_recursion_depth!(self); + // PARSE SUBSTRING (EXPR [FROM 1] [FOR 3]) self.expect_token(&Token::LParen)?; let expr = self.parse_expr()?; @@ -844,6 +920,8 @@ impl<'a> Parser<'a> { /// TRIM (WHERE 'text' FROM 'text')\ /// TRIM ('text') pub fn parse_trim_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let mut where_expr = None; if let Token::Word(word) = self.peek_token() { @@ -867,6 +945,8 @@ impl<'a> Parser<'a> { } pub fn parse_trim_where(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => match w.keyword { Keyword::BOTH => Ok(TrimWhereField::Both), @@ -881,6 +961,8 @@ impl<'a> Parser<'a> { /// Parses an array expression `[ex1, ex2, ..]` /// if `named` is `true`, came from an expression like `ARRAY[ex1, ex2]` pub fn parse_array_expr(&mut self, named: bool) -> Result { + check_recursion_depth!(self); + let exprs = self.parse_comma_separated(Parser::parse_expr)?; self.expect_token(&Token::RBracket)?; Ok(Expr::Array(Array { elem: exprs, named })) @@ -888,6 +970,8 @@ impl<'a> Parser<'a> { /// Parse a SQL LISTAGG expression, e.g. `LISTAGG(...) WITHIN GROUP (ORDER BY ...)`. pub fn parse_listagg_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let distinct = self.parse_all_or_distinct()?; let expr = Box::new(self.parse_expr()?); @@ -953,6 +1037,8 @@ impl<'a> Parser<'a> { // date/time fields than interval qualifiers, so this function may need to // be split in two. pub fn parse_date_time_field(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => match w.keyword { Keyword::YEAR => Ok(DateTimeField::Year), @@ -996,6 +1082,8 @@ impl<'a> Parser<'a> { /// /// Note that we do not currently attempt to parse the quoted value. pub fn parse_literal_interval(&mut self) -> Result { + check_recursion_depth!(self); + // The SQL standard allows an optional sign before the value string, but // it is not clear if any implementations support that syntax, so we // don't currently try to parse it. (The sign can instead be included @@ -1080,6 +1168,8 @@ impl<'a> Parser<'a> { /// Parse an operator following an expression pub fn parse_infix(&mut self, expr: Expr, precedence: u8) -> Result { + check_recursion_depth!(self); + let tok = self.next_token(); let regular_binary_operator = match &tok { @@ -1234,6 +1324,8 @@ impl<'a> Parser<'a> { } pub fn parse_array_index(&mut self, expr: Expr) -> Result { + check_recursion_depth!(self); + let index = self.parse_expr()?; self.expect_token(&Token::RBracket)?; let mut indexes: Vec = vec![index]; @@ -1249,6 +1341,8 @@ impl<'a> Parser<'a> { } pub fn parse_map_access(&mut self, expr: Expr) -> Result { + check_recursion_depth!(self); + let key = self.parse_map_key()?; let tok = self.consume_token(&Token::RBracket); debug!("Tok: {}", tok); @@ -1270,6 +1364,8 @@ impl<'a> Parser<'a> { /// Parses the parens following the `[ NOT ] IN` operator pub fn parse_in(&mut self, expr: Expr, negated: bool) -> Result { + check_recursion_depth!(self); + // BigQuery allows `IN UNNEST(array_expression)` // https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#in_operators if self.parse_keyword(Keyword::UNNEST) { @@ -1303,6 +1399,8 @@ impl<'a> Parser<'a> { /// Parses `BETWEEN AND `, assuming the `BETWEEN` keyword was already consumed pub fn parse_between(&mut self, expr: Expr, negated: bool) -> Result { + check_recursion_depth!(self); + // Stop parsing subexpressions for and on tokens with // precedence lower than that of `BETWEEN`, such as `AND`, `IS`, etc. let low = self.parse_subexpr(Self::BETWEEN_PREC)?; @@ -1318,6 +1416,8 @@ impl<'a> Parser<'a> { /// Parse a postgresql casting style which is in the form of `expr::datatype` pub fn parse_pg_cast(&mut self, expr: Expr) -> Result { + check_recursion_depth!(self); + Ok(Expr::Cast { expr: Box::new(expr), data_type: self.parse_data_type()?, @@ -1330,6 +1430,8 @@ impl<'a> Parser<'a> { /// Get the precedence of the next token pub fn get_next_precedence(&self) -> Result { + check_recursion_depth!(self); + let token = self.peek_token(); debug!("get_next_precedence() {:?}", token); match token { @@ -1439,6 +1541,8 @@ impl<'a> Parser<'a> { /// Report unexpected token fn expected(&self, expected: &str, found: Token) -> Result { + check_recursion_depth!(self); + parser_err!(format!("Expected {}, found: {}", expected, found)) } @@ -1488,6 +1592,8 @@ impl<'a> Parser<'a> { /// Bail out if the current token is not one of the expected keywords, or consume it if it is pub fn expect_one_of_keywords(&mut self, keywords: &[Keyword]) -> Result { + check_recursion_depth!(self); + if let Some(keyword) = self.parse_one_of_keywords(keywords) { Ok(keyword) } else { @@ -1501,6 +1607,8 @@ impl<'a> Parser<'a> { /// Bail out if the current token is not an expected keyword, or consume it if it is pub fn expect_keyword(&mut self, expected: Keyword) -> Result<(), ParserError> { + check_recursion_depth!(self); + if self.parse_keyword(expected) { Ok(()) } else { @@ -1511,6 +1619,8 @@ impl<'a> Parser<'a> { /// Bail out if the following tokens are not the expected sequence of /// keywords, or consume them if they are. pub fn expect_keywords(&mut self, expected: &[Keyword]) -> Result<(), ParserError> { + check_recursion_depth!(self); + for &kw in expected { self.expect_keyword(kw)?; } @@ -1530,6 +1640,8 @@ impl<'a> Parser<'a> { /// Bail out if the current token is not an expected keyword, or consume it if it is pub fn expect_token(&mut self, expected: &Token) -> Result<(), ParserError> { + check_recursion_depth!(self); + if self.consume_token(expected) { Ok(()) } else { @@ -1542,6 +1654,8 @@ impl<'a> Parser<'a> { where F: FnMut(&mut Parser<'a>) -> Result, { + check_recursion_depth!(self); + let mut values = vec![]; loop { values.push(f(self)?); @@ -1571,6 +1685,8 @@ impl<'a> Parser<'a> { /// Parse either `ALL` or `DISTINCT`. Returns `true` if `DISTINCT` is parsed and results in a /// `ParserError` if both `ALL` and `DISTINCT` are fround. pub fn parse_all_or_distinct(&mut self) -> Result { + check_recursion_depth!(self); + let all = self.parse_keyword(Keyword::ALL); let distinct = self.parse_keyword(Keyword::DISTINCT); if all && distinct { @@ -1582,6 +1698,8 @@ impl<'a> Parser<'a> { /// Parse a SQL CREATE statement pub fn parse_create(&mut self) -> Result { + check_recursion_depth!(self); + let or_replace = self.parse_keywords(&[Keyword::OR, Keyword::REPLACE]); let local = self.parse_one_of_keywords(&[Keyword::LOCAL]).is_some(); let global = self.parse_one_of_keywords(&[Keyword::GLOBAL]).is_some(); @@ -1626,6 +1744,8 @@ impl<'a> Parser<'a> { /// SQLite-specific `CREATE VIRTUAL TABLE` pub fn parse_create_virtual_table(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::TABLE)?; let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let table_name = self.parse_object_name()?; @@ -1645,6 +1765,8 @@ impl<'a> Parser<'a> { } pub fn parse_create_schema(&mut self) -> Result { + check_recursion_depth!(self); + let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let schema_name = self.parse_object_name()?; Ok(Statement::CreateSchema { @@ -1654,6 +1776,8 @@ impl<'a> Parser<'a> { } pub fn parse_create_database(&mut self) -> Result { + check_recursion_depth!(self); + let ine = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let db_name = self.parse_object_name()?; let mut location = None; @@ -1678,6 +1802,8 @@ impl<'a> Parser<'a> { pub fn parse_optional_create_function_using( &mut self, ) -> Result, ParserError> { + check_recursion_depth!(self); + if !self.parse_keyword(Keyword::USING) { return Ok(None); }; @@ -1698,6 +1824,8 @@ impl<'a> Parser<'a> { } pub fn parse_create_function(&mut self, temporary: bool) -> Result { + check_recursion_depth!(self); + let name = self.parse_object_name()?; self.expect_keyword(Keyword::AS)?; let class_name = self.parse_literal_string()?; @@ -1715,6 +1843,8 @@ impl<'a> Parser<'a> { &mut self, or_replace: bool, ) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::TABLE)?; let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let table_name = self.parse_object_name()?; @@ -1759,6 +1889,8 @@ impl<'a> Parser<'a> { } pub fn parse_file_format(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => match w.keyword { Keyword::AVRO => Ok(FileFormat::AVRO), @@ -1775,6 +1907,8 @@ impl<'a> Parser<'a> { } pub fn parse_create_view(&mut self, or_replace: bool) -> Result { + check_recursion_depth!(self); + let materialized = self.parse_keyword(Keyword::MATERIALIZED); self.expect_keyword(Keyword::VIEW)?; // Many dialects support `OR ALTER` right after `CREATE`, but we don't (yet). @@ -1796,6 +1930,8 @@ impl<'a> Parser<'a> { } pub fn parse_drop(&mut self) -> Result { + check_recursion_depth!(self); + let object_type = if self.parse_keyword(Keyword::TABLE) { ObjectType::Table } else if self.parse_keyword(Keyword::VIEW) { @@ -1829,6 +1965,8 @@ impl<'a> Parser<'a> { /// DECLARE name [ BINARY ] [ ASENSITIVE | INSENSITIVE ] [ [ NO ] SCROLL ] // CURSOR [ { WITH | WITHOUT } HOLD ] FOR query pub fn parse_declare(&mut self) -> Result { + check_recursion_depth!(self); + let name = self.parse_identifier()?; let binary = self.parse_keyword(Keyword::BINARY); @@ -1878,6 +2016,8 @@ impl<'a> Parser<'a> { // FETCH [ direction { FROM | IN } ] cursor INTO target; pub fn parse_fetch_statement(&mut self) -> Result { + check_recursion_depth!(self); + let direction = if self.parse_keyword(Keyword::NEXT) { FetchDirection::Next } else if self.parse_keyword(Keyword::PRIOR) { @@ -1938,6 +2078,8 @@ impl<'a> Parser<'a> { } pub fn parse_discard(&mut self) -> Result { + check_recursion_depth!(self); + let object_type = if self.parse_keyword(Keyword::ALL) { DiscardObject::ALL } else if self.parse_keyword(Keyword::PLANS) { @@ -1956,6 +2098,8 @@ impl<'a> Parser<'a> { } pub fn parse_create_index(&mut self, unique: bool) -> Result { + check_recursion_depth!(self); + let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let index_name = self.parse_object_name()?; self.expect_keyword(Keyword::ON)?; @@ -1974,6 +2118,8 @@ impl<'a> Parser<'a> { //TODO: Implement parsing for Skewed and Clustered pub fn parse_hive_distribution(&mut self) -> Result { + check_recursion_depth!(self); + if self.parse_keywords(&[Keyword::PARTITIONED, Keyword::BY]) { self.expect_token(&Token::LParen)?; let columns = self.parse_comma_separated(Parser::parse_column_def)?; @@ -1985,6 +2131,8 @@ impl<'a> Parser<'a> { } pub fn parse_hive_formats(&mut self) -> Result { + check_recursion_depth!(self); + let mut hive_format = HiveFormat::default(); loop { match self.parse_one_of_keywords(&[Keyword::ROW, Keyword::STORED, Keyword::LOCATION]) { @@ -2018,6 +2166,8 @@ impl<'a> Parser<'a> { } pub fn parse_row_format(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::FORMAT)?; match self.parse_one_of_keywords(&[Keyword::SERDE, Keyword::DELIMITED]) { Some(Keyword::SERDE) => { @@ -2034,6 +2184,8 @@ impl<'a> Parser<'a> { temporary: bool, global: Option, ) -> Result { + check_recursion_depth!(self); + let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let table_name = self.parse_object_name()?; let like = if self.parse_keyword(Keyword::LIKE) || self.parse_keyword(Keyword::ILIKE) { @@ -2132,6 +2284,8 @@ impl<'a> Parser<'a> { } pub fn parse_columns(&mut self) -> Result<(Vec, Vec), ParserError> { + check_recursion_depth!(self); + let mut columns = vec![]; let mut constraints = vec![]; if !self.consume_token(&Token::LParen) || self.consume_token(&Token::RParen) { @@ -2159,6 +2313,8 @@ impl<'a> Parser<'a> { } pub fn parse_column_def(&mut self) -> Result { + check_recursion_depth!(self); + let name = self.parse_identifier()?; let data_type = self.parse_data_type()?; let collation = if self.parse_keyword(Keyword::COLLATE) { @@ -2193,6 +2349,8 @@ impl<'a> Parser<'a> { } pub fn parse_optional_column_option(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + if self.parse_keywords(&[Keyword::CHARACTER, Keyword::SET]) { Ok(Some(ColumnOption::CharacterSet(self.parse_object_name()?))) } else if self.parse_keywords(&[Keyword::NOT, Keyword::NULL]) { @@ -2259,6 +2417,8 @@ impl<'a> Parser<'a> { } pub fn parse_referential_action(&mut self) -> Result { + check_recursion_depth!(self); + if self.parse_keyword(Keyword::RESTRICT) { Ok(ReferentialAction::Restrict) } else if self.parse_keyword(Keyword::CASCADE) { @@ -2280,6 +2440,8 @@ impl<'a> Parser<'a> { pub fn parse_optional_table_constraint( &mut self, ) -> Result, ParserError> { + check_recursion_depth!(self); + let name = if self.parse_keyword(Keyword::CONSTRAINT) { Some(self.parse_identifier()?) } else { @@ -2344,6 +2506,8 @@ impl<'a> Parser<'a> { } pub fn parse_options(&mut self, keyword: Keyword) -> Result, ParserError> { + check_recursion_depth!(self); + if self.parse_keyword(keyword) { self.expect_token(&Token::LParen)?; let options = self.parse_comma_separated(Parser::parse_sql_option)?; @@ -2355,6 +2519,8 @@ impl<'a> Parser<'a> { } pub fn parse_sql_option(&mut self) -> Result { + check_recursion_depth!(self); + let name = self.parse_identifier()?; self.expect_token(&Token::Eq)?; let value = self.parse_value()?; @@ -2362,6 +2528,8 @@ impl<'a> Parser<'a> { } pub fn parse_alter(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::TABLE)?; let _ = self.parse_keyword(Keyword::ONLY); let table_name = self.parse_object_name()?; @@ -2516,6 +2684,8 @@ impl<'a> Parser<'a> { /// Parse a copy statement pub fn parse_copy(&mut self) -> Result { + check_recursion_depth!(self); + let table_name = self.parse_object_name()?; let columns = self.parse_parenthesized_column_list(Optional)?; let to = match self.parse_one_of_keywords(&[Keyword::FROM, Keyword::TO]) { @@ -2713,6 +2883,8 @@ impl<'a> Parser<'a> { /// Parse a literal value (numbers, strings, date/time, booleans) pub fn parse_value(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => match w.keyword { Keyword::TRUE => Ok(Value::Boolean(true)), @@ -2742,6 +2914,8 @@ impl<'a> Parser<'a> { } pub fn parse_number_value(&mut self) -> Result { + check_recursion_depth!(self); + match self.parse_value()? { v @ Value::Number(_, _) => Ok(v), v @ Value::Placeholder(_) => Ok(v), @@ -2754,6 +2928,8 @@ impl<'a> Parser<'a> { /// Parse an unsigned literal integer/long pub fn parse_literal_uint(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Number(s, _) => s.parse::().map_err(|e| { ParserError::ParserError(format!("Could not parse '{}' as u64: {}", s, e)) @@ -2764,6 +2940,8 @@ impl<'a> Parser<'a> { /// Parse a literal string pub fn parse_literal_string(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(Word { value, keyword, .. }) if keyword == Keyword::NoKeyword => Ok(value), Token::SingleQuotedString(s) => Ok(s), @@ -2776,6 +2954,8 @@ impl<'a> Parser<'a> { /// Parse a map key string pub fn parse_map_key(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(Word { value, keyword, .. }) if keyword == Keyword::NoKeyword => { if self.peek_token() == Token::LParen { @@ -2794,6 +2974,8 @@ impl<'a> Parser<'a> { /// Parse a SQL datatype (in the context of a CREATE TABLE statement for example) pub fn parse_data_type(&mut self) -> Result { + check_recursion_depth!(self); + let mut data = match self.next_token() { Token::Word(w) => match w.keyword { Keyword::BOOLEAN => Ok(DataType::Boolean), @@ -2903,6 +3085,8 @@ impl<'a> Parser<'a> { } pub fn parse_string_values(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let mut values = Vec::new(); loop { @@ -2926,6 +3110,8 @@ impl<'a> Parser<'a> { &mut self, reserved_kwds: &[Keyword], ) -> Result, ParserError> { + check_recursion_depth!(self); + let after_as = self.parse_keyword(Keyword::AS); match self.next_token() { // Accept any identifier after `AS` (though many dialects have restrictions on @@ -2967,6 +3153,8 @@ impl<'a> Parser<'a> { &mut self, reserved_kwds: &[Keyword], ) -> Result, ParserError> { + check_recursion_depth!(self); + match self.parse_optional_alias(reserved_kwds)? { Some(name) => { let columns = self.parse_parenthesized_column_list(Optional)?; @@ -2979,6 +3167,8 @@ impl<'a> Parser<'a> { /// Parse a possibly qualified, possibly quoted identifier, e.g. /// `foo` or `myschema."table" pub fn parse_object_name(&mut self) -> Result { + check_recursion_depth!(self); + let mut idents = vec![]; loop { idents.push(self.parse_identifier()?); @@ -2991,6 +3181,8 @@ impl<'a> Parser<'a> { /// Parse identifiers strictly i.e. don't parse keywords pub fn parse_identifiers_non_keywords(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + let mut idents = vec![]; loop { match self.peek_token() { @@ -3013,6 +3205,8 @@ impl<'a> Parser<'a> { /// Parse identifiers pub fn parse_identifiers(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + let mut idents = vec![]; loop { match self.next_token() { @@ -3029,6 +3223,8 @@ impl<'a> Parser<'a> { /// Parse a simple one-word identifier (possibly quoted, possibly a keyword) pub fn parse_identifier(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => Ok(w.to_ident()), Token::SingleQuotedString(s) => Ok(Ident::with_quote('\'', s)), @@ -3041,6 +3237,8 @@ impl<'a> Parser<'a> { &mut self, optional: IsOptional, ) -> Result, ParserError> { + check_recursion_depth!(self); + if self.consume_token(&Token::LParen) { let cols = self.parse_comma_separated(Parser::parse_identifier)?; self.expect_token(&Token::RParen)?; @@ -3053,6 +3251,8 @@ impl<'a> Parser<'a> { } pub fn parse_optional_precision(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + if self.consume_token(&Token::LParen) { let n = self.parse_literal_uint()?; self.expect_token(&Token::RParen)?; @@ -3065,6 +3265,8 @@ impl<'a> Parser<'a> { pub fn parse_optional_precision_scale( &mut self, ) -> Result<(Option, Option), ParserError> { + check_recursion_depth!(self); + if self.consume_token(&Token::LParen) { let n = self.parse_literal_uint()?; let scale = if self.consume_token(&Token::Comma) { @@ -3080,6 +3282,8 @@ impl<'a> Parser<'a> { } pub fn parse_delete(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::FROM)?; let table_name = self.parse_object_name()?; let selection = if self.parse_keyword(Keyword::WHERE) { @@ -3121,6 +3325,8 @@ impl<'a> Parser<'a> { } pub fn parse_explain(&mut self, describe_alias: bool) -> Result { + check_recursion_depth!(self); + let analyze = self.parse_keyword(Keyword::ANALYZE); let verbose = self.parse_keyword(Keyword::VERBOSE); @@ -3146,6 +3352,8 @@ impl<'a> Parser<'a> { /// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't /// expect the initial keyword to be already consumed pub fn parse_query(&mut self) -> Result { + check_recursion_depth!(self); + let with = if self.parse_keyword(Keyword::WITH) { Some(With { recursive: self.parse_keyword(Keyword::RECURSIVE), @@ -3220,6 +3428,8 @@ impl<'a> Parser<'a> { /// Parse a CTE (`alias [( col1, col2, ... )] AS (subquery)`) pub fn parse_cte(&mut self) -> Result { + check_recursion_depth!(self); + let name = self.parse_identifier()?; let mut cte = if self.parse_keyword(Keyword::AS) { @@ -3263,6 +3473,8 @@ impl<'a> Parser<'a> { /// set_operation ::= query_body { 'UNION' | 'EXCEPT' | 'INTERSECT' } [ 'ALL' ] query_body /// ``` pub fn parse_query_body(&mut self, precedence: u8) -> Result { + check_recursion_depth!(self); + // We parse the expression using a Pratt parser, as in `parse_expr()`. // Start by parsing a restricted SELECT or a `(subquery)`: let mut expr = if self.parse_keyword(Keyword::SELECT) { @@ -3319,6 +3531,8 @@ impl<'a> Parser<'a> { /// Parse a restricted `SELECT` statement (no CTEs / `UNION` / `ORDER BY`), /// assuming the initial `SELECT` was already consumed pub fn parse_select(&mut self) -> Result { + check_recursion_depth!(self); + let distinct = self.parse_all_or_distinct()?; let top = if self.parse_keyword(Keyword::TOP) { @@ -3448,6 +3662,8 @@ impl<'a> Parser<'a> { } pub fn parse_set(&mut self) -> Result { + check_recursion_depth!(self); + let modifier = self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::HIVEVAR]); if let Some(Keyword::HIVEVAR) = modifier { @@ -3523,6 +3739,8 @@ impl<'a> Parser<'a> { } pub fn parse_show(&mut self) -> Result { + check_recursion_depth!(self); + if self .parse_one_of_keywords(&[ Keyword::EXTENDED, @@ -3544,6 +3762,8 @@ impl<'a> Parser<'a> { } pub fn parse_show_create(&mut self) -> Result { + check_recursion_depth!(self); + let obj_type = match self.expect_one_of_keywords(&[ Keyword::TABLE, Keyword::TRIGGER, @@ -3568,6 +3788,8 @@ impl<'a> Parser<'a> { } pub fn parse_show_columns(&mut self) -> Result { + check_recursion_depth!(self); + let extended = self.parse_keyword(Keyword::EXTENDED); let full = self.parse_keyword(Keyword::FULL); self.expect_one_of_keywords(&[Keyword::COLUMNS, Keyword::FIELDS])?; @@ -3588,6 +3810,8 @@ impl<'a> Parser<'a> { pub fn parse_show_statement_filter( &mut self, ) -> Result, ParserError> { + check_recursion_depth!(self); + if self.parse_keyword(Keyword::LIKE) { Ok(Some(ShowStatementFilter::Like( self.parse_literal_string()?, @@ -3604,6 +3828,8 @@ impl<'a> Parser<'a> { } pub fn parse_table_and_joins(&mut self) -> Result { + check_recursion_depth!(self); + let relation = self.parse_table_factor()?; // Note that for keywords to be properly handled here, they need to be // added to `RESERVED_FOR_TABLE_ALIAS`, otherwise they may be parsed as @@ -3677,6 +3903,8 @@ impl<'a> Parser<'a> { /// A table name or a parenthesized subquery, followed by optional `[AS] alias` pub fn parse_table_factor(&mut self) -> Result { + check_recursion_depth!(self); + if self.parse_keyword(Keyword::LATERAL) { // LATERAL must always be followed by a subquery. if !self.consume_token(&Token::LParen) { @@ -3829,6 +4057,8 @@ impl<'a> Parser<'a> { &mut self, lateral: IsLateral, ) -> Result { + check_recursion_depth!(self); + let subquery = Box::new(self.parse_query()?); self.expect_token(&Token::RParen)?; let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; @@ -3843,6 +4073,8 @@ impl<'a> Parser<'a> { } pub fn parse_join_constraint(&mut self, natural: bool) -> Result { + check_recursion_depth!(self); + if natural { Ok(JoinConstraint::Natural) } else if self.parse_keyword(Keyword::ON) { @@ -3859,6 +4091,8 @@ impl<'a> Parser<'a> { /// Parse a GRANT statement. pub fn parse_grant(&mut self) -> Result { + check_recursion_depth!(self); + let (privileges, objects) = self.parse_grant_revoke_privileges_objects()?; self.expect_keyword(Keyword::TO)?; @@ -3883,6 +4117,8 @@ impl<'a> Parser<'a> { pub fn parse_grant_revoke_privileges_objects( &mut self, ) -> Result<(Privileges, GrantObjects), ParserError> { + check_recursion_depth!(self); + let privileges = if self.parse_keyword(Keyword::ALL) { Privileges::All { with_privileges_keyword: self.parse_keyword(Keyword::PRIVILEGES), @@ -3958,6 +4194,8 @@ impl<'a> Parser<'a> { } pub fn parse_grant_permission(&mut self) -> Result<(Keyword, Option>), ParserError> { + check_recursion_depth!(self); + if let Some(kw) = self.parse_one_of_keywords(&[ Keyword::CONNECT, Keyword::CREATE, @@ -3991,6 +4229,8 @@ impl<'a> Parser<'a> { /// Parse a REVOKE statement pub fn parse_revoke(&mut self) -> Result { + check_recursion_depth!(self); + let (privileges, objects) = self.parse_grant_revoke_privileges_objects()?; self.expect_keyword(Keyword::FROM)?; @@ -4017,6 +4257,8 @@ impl<'a> Parser<'a> { /// Parse an INSERT statement pub fn parse_insert(&mut self) -> Result { + check_recursion_depth!(self); + let or = if !dialect_of!(self is SQLiteDialect) { None } else if self.parse_keywords(&[Keyword::OR, Keyword::REPLACE]) { @@ -4102,6 +4344,8 @@ impl<'a> Parser<'a> { } pub fn parse_update(&mut self) -> Result { + check_recursion_depth!(self); + let table = self.parse_table_and_joins()?; self.expect_keyword(Keyword::SET)?; let assignments = self.parse_comma_separated(Parser::parse_assignment)?; @@ -4125,6 +4369,8 @@ impl<'a> Parser<'a> { /// Parse a `var = expr` assignment, used in an UPDATE statement pub fn parse_assignment(&mut self) -> Result { + check_recursion_depth!(self); + let id = self.parse_identifiers_non_keywords()?; self.expect_token(&Token::Eq)?; let value = self.parse_expr()?; @@ -4132,6 +4378,8 @@ impl<'a> Parser<'a> { } pub fn parse_function_args(&mut self) -> Result { + check_recursion_depth!(self); + if self.peek_nth_token(1) == Token::RArrow { let name = self.parse_identifier()?; @@ -4145,6 +4393,8 @@ impl<'a> Parser<'a> { } pub fn parse_optional_args(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + if self.consume_token(&Token::RParen) { Ok(vec![]) } else { @@ -4156,6 +4406,8 @@ impl<'a> Parser<'a> { /// Parse a comma-delimited list of projections after SELECT pub fn parse_select_item(&mut self) -> Result { + check_recursion_depth!(self); + match self.parse_wildcard_expr()? { WildcardExpr::Expr(expr) => self .parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS) @@ -4170,6 +4422,8 @@ impl<'a> Parser<'a> { /// Parse an expression, optionally followed by ASC or DESC (used in ORDER BY) pub fn parse_order_by_expr(&mut self) -> Result { + check_recursion_depth!(self); + let expr = self.parse_expr()?; let asc = if self.parse_keyword(Keyword::ASC) { @@ -4198,6 +4452,8 @@ impl<'a> Parser<'a> { /// Parse a TOP clause, MSSQL equivalent of LIMIT, /// that follows after SELECT [DISTINCT]. pub fn parse_top(&mut self) -> Result { + check_recursion_depth!(self); + let quantity = if self.consume_token(&Token::LParen) { let quantity = self.parse_expr()?; self.expect_token(&Token::RParen)?; @@ -4219,6 +4475,8 @@ impl<'a> Parser<'a> { /// Parse a LIMIT clause pub fn parse_limit(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + if self.parse_keyword(Keyword::ALL) { Ok(None) } else { @@ -4228,6 +4486,8 @@ impl<'a> Parser<'a> { /// Parse an OFFSET clause pub fn parse_offset(&mut self) -> Result { + check_recursion_depth!(self); + let value = Expr::Value(self.parse_number_value()?); let rows = if self.parse_keyword(Keyword::ROW) { OffsetRows::Row @@ -4241,6 +4501,8 @@ impl<'a> Parser<'a> { /// Parse a FETCH clause pub fn parse_fetch(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_one_of_keywords(&[Keyword::FIRST, Keyword::NEXT])?; let (quantity, percent) = if self .parse_one_of_keywords(&[Keyword::ROW, Keyword::ROWS]) @@ -4269,6 +4531,8 @@ impl<'a> Parser<'a> { /// Parse a FOR UPDATE/FOR SHARE clause pub fn parse_lock(&mut self) -> Result { + check_recursion_depth!(self); + match self.expect_one_of_keywords(&[Keyword::UPDATE, Keyword::SHARE])? { Keyword::UPDATE => Ok(LockType::Update), Keyword::SHARE => Ok(LockType::Share), @@ -4277,6 +4541,8 @@ impl<'a> Parser<'a> { } pub fn parse_values(&mut self) -> Result { + check_recursion_depth!(self); + let values = self.parse_comma_separated(|parser| { parser.expect_token(&Token::LParen)?; let exprs = parser.parse_comma_separated(Parser::parse_expr)?; @@ -4287,6 +4553,8 @@ impl<'a> Parser<'a> { } pub fn parse_start_transaction(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::TRANSACTION)?; Ok(Statement::StartTransaction { modes: self.parse_transaction_modes()?, @@ -4294,6 +4562,8 @@ impl<'a> Parser<'a> { } pub fn parse_begin(&mut self) -> Result { + check_recursion_depth!(self); + let _ = self.parse_one_of_keywords(&[Keyword::TRANSACTION, Keyword::WORK]); Ok(Statement::StartTransaction { modes: self.parse_transaction_modes()?, @@ -4301,6 +4571,8 @@ impl<'a> Parser<'a> { } pub fn parse_transaction_modes(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + let mut modes = vec![]; let mut required = false; loop { @@ -4337,18 +4609,24 @@ impl<'a> Parser<'a> { } pub fn parse_commit(&mut self) -> Result { + check_recursion_depth!(self); + Ok(Statement::Commit { chain: self.parse_commit_rollback_chain()?, }) } pub fn parse_rollback(&mut self) -> Result { + check_recursion_depth!(self); + Ok(Statement::Rollback { chain: self.parse_commit_rollback_chain()?, }) } pub fn parse_commit_rollback_chain(&mut self) -> Result { + check_recursion_depth!(self); + let _ = self.parse_one_of_keywords(&[Keyword::TRANSACTION, Keyword::WORK]); if self.parse_keyword(Keyword::AND) { let chain = !self.parse_keyword(Keyword::NO); @@ -4360,12 +4638,16 @@ impl<'a> Parser<'a> { } pub fn parse_deallocate(&mut self) -> Result { + check_recursion_depth!(self); + let prepare = self.parse_keyword(Keyword::PREPARE); let name = self.parse_identifier()?; Ok(Statement::Deallocate { name, prepare }) } pub fn parse_execute(&mut self) -> Result { + check_recursion_depth!(self); + let name = self.parse_identifier()?; let mut parameters = vec![]; @@ -4378,6 +4660,8 @@ impl<'a> Parser<'a> { } pub fn parse_prepare(&mut self) -> Result { + check_recursion_depth!(self); + let name = self.parse_identifier()?; let mut data_types = vec![]; @@ -4396,6 +4680,8 @@ impl<'a> Parser<'a> { } pub fn parse_comment(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::ON)?; let token = self.next_token(); @@ -4425,6 +4711,8 @@ impl<'a> Parser<'a> { } pub fn parse_merge_clauses(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + let mut clauses: Vec = vec![]; loop { if self.peek_token() == Token::EOF || self.peek_token() == Token::SemiColon { @@ -4502,6 +4790,8 @@ impl<'a> Parser<'a> { } pub fn parse_merge(&mut self) -> Result { + check_recursion_depth!(self); + let into = self.parse_keyword(Keyword::INTO); let table = self.parse_table_factor()?; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index bc715a096..125506598 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -4944,3 +4944,10 @@ fn parse_discard() { _ => unreachable!(), } } + +#[test] +fn parse_with_lots_of_stack_recursion() { + let sql = "(".repeat(1000); + let res = parse_sql_statements(&sql); + assert_eq!(ParserError::RecursionLimitExceeded, res.unwrap_err()); +}