From b7cd371296ea6355376ce23912e435aae61956a8 Mon Sep 17 00:00:00 2001 From: Miki Mokrysz Date: Wed, 18 May 2022 14:09:31 +0100 Subject: [PATCH] Add a recursion limit to prevent stack overflows Until now, it's been able to trigger a stack overflow crash by providing a string with excessive recursion. For instance a string of 1000 left brackets causes the parser to recurse down 1000 times, and overflow the stack. This commit adds protection against excessive recursion. It adds a field to `Parser` for tracking the current recursion depth. Every function that returns a `Result` gains a recursion depth check. This isn't quite every method on the `Parser`, but it's the vast majority. An alternative implemention would be to only protect against AST recursions, rather than recursive function calls in `Parser`. That isn't as easy to implement because the parser is so large. --- Cargo.toml | 3 +- examples/parse_select.rs | 5 +- src/parser.rs | 280 ++++++++++++++++++++++++++++++++++++++ tests/sqlparser_common.rs | 10 ++ 4 files changed, 293 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 00b4d1a1c..05f3eba23 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ path = "src/lib.rs" [features] default = ["std"] -std = [] +std = ["scopeguard"] # Enable JSON output in the `cli` example: json_example = ["serde_json", "serde"] @@ -32,6 +32,7 @@ serde = { version = "1.0", features = ["derive"], optional = true } # of dev-dependencies because of # https://github.com/rust-lang/cargo/issues/1596 serde_json = { version = "1.0", optional = true } +scopeguard = { version = "1.1.0", optional = true } [dev-dependencies] simple_logger = "2.1" diff --git a/examples/parse_select.rs b/examples/parse_select.rs index e7aa16307..bde4e426c 100644 --- a/examples/parse_select.rs +++ b/examples/parse_select.rs @@ -16,10 +16,7 @@ use sqlparser::dialect::GenericDialect; use sqlparser::parser::*; fn main() { - let sql = "SELECT a, b, 123, myfunc(b) \ - FROM table_1 \ - WHERE a > b AND b < 100 \ - ORDER BY a DESC, b"; + let sql = "(((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((("; let dialect = GenericDialect {}; diff --git a/src/parser.rs b/src/parser.rs index 5ee3d5cb6..3ddc6d1a5 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -21,8 +21,14 @@ use alloc::{ vec::Vec, }; use core::fmt; +#[cfg(feature = "std")] +use std::rc::Rc; +#[cfg(feature = "std")] +use std::cell::Cell; use log::debug; +#[cfg(feature = "std")] +use scopeguard::defer; use crate::ast::*; use crate::dialect::*; @@ -33,6 +39,7 @@ use crate::tokenizer::*; pub enum ParserError { TokenizerError(String), ParserError(String), + RecursionLimitExceeded, } // Use `Parser::expected` instead, if possible @@ -51,6 +58,25 @@ macro_rules! return_ok_if_some { }}; } +#[cfg(feature = "std")] +macro_rules! check_recursion_depth { + ($this:ident) => { + let remaining_depth = $this.remaining_depth.clone(); + remaining_depth.set(remaining_depth.get().saturating_sub(1)); + if remaining_depth.get() == 0 { + return Err(ParserError::RecursionLimitExceeded); + } + defer! { + remaining_depth.set(remaining_depth.get() + 1); + } + }; +} + +#[cfg(not(feature = "std"))] +macro_rules! check_recursion_depth { + ($this:ident) => {}; +} + #[derive(PartialEq)] pub enum IsOptional { Optional, @@ -96,6 +122,7 @@ impl fmt::Display for ParserError { match self { ParserError::TokenizerError(s) => s, ParserError::ParserError(s) => s, + ParserError::RecursionLimitExceeded => "recursion limit exceeded", } ) } @@ -109,6 +136,8 @@ pub struct Parser<'a> { /// The index of the first unprocessed token in `self.tokens` index: usize, dialect: &'a dyn Dialect, + #[cfg(feature = "std")] + remaining_depth: Rc>, } impl<'a> Parser<'a> { @@ -118,6 +147,8 @@ impl<'a> Parser<'a> { tokens, index: 0, dialect, + #[cfg(feature = "std")] + remaining_depth: Rc::new(Cell::new(96)), } } @@ -152,6 +183,8 @@ impl<'a> Parser<'a> { /// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.), /// stopping before the statement separator, if any. pub fn parse_statement(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => match w.keyword { Keyword::KILL => Ok(self.parse_kill()?), @@ -208,6 +241,8 @@ impl<'a> Parser<'a> { } pub fn parse_msck(&mut self) -> Result { + check_recursion_depth!(self); + let repair = self.parse_keyword(Keyword::REPAIR); self.expect_keyword(Keyword::TABLE)?; let table_name = self.parse_object_name()?; @@ -235,6 +270,8 @@ impl<'a> Parser<'a> { } pub fn parse_truncate(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::TABLE)?; let table_name = self.parse_object_name()?; let mut partitions = None; @@ -250,6 +287,8 @@ impl<'a> Parser<'a> { } pub fn parse_analyze(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::TABLE)?; let table_name = self.parse_object_name()?; let mut for_columns = false; @@ -307,6 +346,8 @@ impl<'a> Parser<'a> { /// Parse a new expression including wildcard & qualified wildcard pub fn parse_wildcard_expr(&mut self) -> Result { + check_recursion_depth!(self); + let index = self.index; match self.next_token() { @@ -337,11 +378,15 @@ impl<'a> Parser<'a> { /// Parse a new expression pub fn parse_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.parse_subexpr(0) } /// Parse tokens until the precedence changes pub fn parse_subexpr(&mut self, precedence: u8) -> Result { + check_recursion_depth!(self); + debug!("parsing expr"); let mut expr = self.parse_prefix()?; debug!("prefix: {:?}", expr); @@ -359,6 +404,8 @@ impl<'a> Parser<'a> { } pub fn parse_assert(&mut self) -> Result { + check_recursion_depth!(self); + let condition = self.parse_expr()?; let message = if self.parse_keyword(Keyword::AS) { Some(self.parse_expr()?) @@ -370,12 +417,16 @@ impl<'a> Parser<'a> { } pub fn parse_savepoint(&mut self) -> Result { + check_recursion_depth!(self); + let name = self.parse_identifier()?; Ok(Statement::Savepoint { name }) } /// Parse an expression prefix pub fn parse_prefix(&mut self) -> Result { + check_recursion_depth!(self); + // PostgreSQL allows any string literal to be preceded by a type name, indicating that the // string literal represents a literal of that type. Some examples: // @@ -549,6 +600,7 @@ impl<'a> Parser<'a> { } pub fn parse_function(&mut self, name: ObjectName) -> Result { + check_recursion_depth!(self); self.expect_token(&Token::LParen)?; let distinct = self.parse_all_or_distinct()?; let args = self.parse_optional_args()?; @@ -592,6 +644,8 @@ impl<'a> Parser<'a> { } pub fn parse_time_functions(&mut self, name: ObjectName) -> Result { + check_recursion_depth!(self); + let args = if self.consume_token(&Token::LParen) { self.parse_optional_args()? } else { @@ -606,6 +660,8 @@ impl<'a> Parser<'a> { } pub fn parse_window_frame_units(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => match w.keyword { Keyword::ROWS => Ok(WindowFrameUnits::Rows), @@ -618,6 +674,8 @@ impl<'a> Parser<'a> { } pub fn parse_window_frame(&mut self) -> Result { + check_recursion_depth!(self); + let units = self.parse_window_frame_units()?; let (start_bound, end_bound) = if self.parse_keyword(Keyword::BETWEEN) { let start_bound = self.parse_window_frame_bound()?; @@ -636,6 +694,8 @@ impl<'a> Parser<'a> { /// Parse `CURRENT ROW` or `{ | UNBOUNDED } { PRECEDING | FOLLOWING }` pub fn parse_window_frame_bound(&mut self) -> Result { + check_recursion_depth!(self); + if self.parse_keywords(&[Keyword::CURRENT, Keyword::ROW]) { Ok(WindowFrameBound::CurrentRow) } else { @@ -657,6 +717,8 @@ impl<'a> Parser<'a> { /// parse a group by expr. a group by expr can be one of group sets, roll up, cube, or simple /// expr. fn parse_group_by_expr(&mut self) -> Result { + check_recursion_depth!(self); + if dialect_of!(self is PostgreSqlDialect) { if self.parse_keywords(&[Keyword::GROUPING, Keyword::SETS]) { self.expect_token(&Token::LParen)?; @@ -690,6 +752,8 @@ impl<'a> Parser<'a> { lift_singleton: bool, allow_empty: bool, ) -> Result, ParserError> { + check_recursion_depth!(self); + if lift_singleton { if self.consume_token(&Token::LParen) { let result = if allow_empty && self.consume_token(&Token::RParen) { @@ -717,6 +781,8 @@ impl<'a> Parser<'a> { } pub fn parse_case_expr(&mut self) -> Result { + check_recursion_depth!(self); + let mut operand = None; if !self.parse_keyword(Keyword::WHEN) { operand = Some(Box::new(self.parse_expr()?)); @@ -748,6 +814,8 @@ impl<'a> Parser<'a> { /// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)` pub fn parse_cast_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let expr = self.parse_expr()?; self.expect_keyword(Keyword::AS)?; @@ -761,6 +829,8 @@ impl<'a> Parser<'a> { /// Parse a SQL TRY_CAST function e.g. `TRY_CAST(expr AS FLOAT)` pub fn parse_try_cast_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let expr = self.parse_expr()?; self.expect_keyword(Keyword::AS)?; @@ -774,6 +844,8 @@ impl<'a> Parser<'a> { /// Parse a SQL EXISTS expression e.g. `WHERE EXISTS(SELECT ...)`. pub fn parse_exists_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let exists_node = Expr::Exists(Box::new(self.parse_query()?)); self.expect_token(&Token::RParen)?; @@ -781,6 +853,8 @@ impl<'a> Parser<'a> { } pub fn parse_extract_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let field = self.parse_date_time_field()?; self.expect_keyword(Keyword::FROM)?; @@ -811,6 +885,8 @@ impl<'a> Parser<'a> { } pub fn parse_substring_expr(&mut self) -> Result { + check_recursion_depth!(self); + // PARSE SUBSTRING (EXPR [FROM 1] [FOR 3]) self.expect_token(&Token::LParen)?; let expr = self.parse_expr()?; @@ -835,6 +911,8 @@ impl<'a> Parser<'a> { /// TRIM (WHERE 'text' FROM 'text')\ /// TRIM ('text') pub fn parse_trim_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let mut where_expr = None; if let Token::Word(word) = self.peek_token() { @@ -858,6 +936,8 @@ impl<'a> Parser<'a> { } pub fn parse_trim_where(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => match w.keyword { Keyword::BOTH => Ok(TrimWhereField::Both), @@ -872,6 +952,8 @@ impl<'a> Parser<'a> { /// Parses an array expression `[ex1, ex2, ..]` /// if `named` is `true`, came from an expression like `ARRAY[ex1, ex2]` pub fn parse_array_expr(&mut self, named: bool) -> Result { + check_recursion_depth!(self); + let exprs = self.parse_comma_separated(Parser::parse_expr)?; self.expect_token(&Token::RBracket)?; Ok(Expr::Array(Array { elem: exprs, named })) @@ -879,6 +961,8 @@ impl<'a> Parser<'a> { /// Parse a SQL LISTAGG expression, e.g. `LISTAGG(...) WITHIN GROUP (ORDER BY ...)`. pub fn parse_listagg_expr(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let distinct = self.parse_all_or_distinct()?; let expr = Box::new(self.parse_expr()?); @@ -943,6 +1027,8 @@ impl<'a> Parser<'a> { // date/time fields than interval qualifiers, so this function may need to // be split in two. pub fn parse_date_time_field(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => match w.keyword { Keyword::YEAR => Ok(DateTimeField::Year), @@ -986,6 +1072,8 @@ impl<'a> Parser<'a> { /// /// Note that we do not currently attempt to parse the quoted value. pub fn parse_literal_interval(&mut self) -> Result { + check_recursion_depth!(self); + // The SQL standard allows an optional sign before the value string, but // it is not clear if any implementations support that syntax, so we // don't currently try to parse it. (The sign can instead be included @@ -1070,6 +1158,8 @@ impl<'a> Parser<'a> { /// Parse an operator following an expression pub fn parse_infix(&mut self, expr: Expr, precedence: u8) -> Result { + check_recursion_depth!(self); + let tok = self.next_token(); let regular_binary_operator = match &tok { @@ -1232,6 +1322,8 @@ impl<'a> Parser<'a> { } pub fn parse_array_index(&mut self, expr: Expr) -> Result { + check_recursion_depth!(self); + let index = self.parse_expr()?; self.expect_token(&Token::RBracket)?; let mut indexs: Vec = vec![index]; @@ -1247,6 +1339,8 @@ impl<'a> Parser<'a> { } pub fn parse_map_access(&mut self, expr: Expr) -> Result { + check_recursion_depth!(self); + let key = self.parse_map_key()?; let tok = self.consume_token(&Token::RBracket); debug!("Tok: {}", tok); @@ -1268,6 +1362,8 @@ impl<'a> Parser<'a> { /// Parses the parens following the `[ NOT ] IN` operator pub fn parse_in(&mut self, expr: Expr, negated: bool) -> Result { + check_recursion_depth!(self); + // BigQuery allows `IN UNNEST(array_expression)` // https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#in_operators if self.parse_keyword(Keyword::UNNEST) { @@ -1301,6 +1397,8 @@ impl<'a> Parser<'a> { /// Parses `BETWEEN AND `, assuming the `BETWEEN` keyword was already consumed pub fn parse_between(&mut self, expr: Expr, negated: bool) -> Result { + check_recursion_depth!(self); + // Stop parsing subexpressions for and on tokens with // precedence lower than that of `BETWEEN`, such as `AND`, `IS`, etc. let low = self.parse_subexpr(Self::BETWEEN_PREC)?; @@ -1316,6 +1414,8 @@ impl<'a> Parser<'a> { /// Parse a postgresql casting style which is in the form of `expr::datatype` pub fn parse_pg_cast(&mut self, expr: Expr) -> Result { + check_recursion_depth!(self); + Ok(Expr::Cast { expr: Box::new(expr), data_type: self.parse_data_type()?, @@ -1328,6 +1428,8 @@ impl<'a> Parser<'a> { /// Get the precedence of the next token pub fn get_next_precedence(&self) -> Result { + check_recursion_depth!(self); + let token = self.peek_token(); debug!("get_next_precedence() {:?}", token); match token { @@ -1437,6 +1539,8 @@ impl<'a> Parser<'a> { /// Report unexpected token fn expected(&self, expected: &str, found: Token) -> Result { + check_recursion_depth!(self); + parser_err!(format!("Expected {}, found: {}", expected, found)) } @@ -1486,6 +1590,8 @@ impl<'a> Parser<'a> { /// Bail out if the current token is not one of the expected keywords, or consume it if it is pub fn expect_one_of_keywords(&mut self, keywords: &[Keyword]) -> Result { + check_recursion_depth!(self); + if let Some(keyword) = self.parse_one_of_keywords(keywords) { Ok(keyword) } else { @@ -1499,6 +1605,8 @@ impl<'a> Parser<'a> { /// Bail out if the current token is not an expected keyword, or consume it if it is pub fn expect_keyword(&mut self, expected: Keyword) -> Result<(), ParserError> { + check_recursion_depth!(self); + if self.parse_keyword(expected) { Ok(()) } else { @@ -1509,6 +1617,8 @@ impl<'a> Parser<'a> { /// Bail out if the following tokens are not the expected sequence of /// keywords, or consume them if they are. pub fn expect_keywords(&mut self, expected: &[Keyword]) -> Result<(), ParserError> { + check_recursion_depth!(self); + for &kw in expected { self.expect_keyword(kw)?; } @@ -1528,6 +1638,8 @@ impl<'a> Parser<'a> { /// Bail out if the current token is not an expected keyword, or consume it if it is pub fn expect_token(&mut self, expected: &Token) -> Result<(), ParserError> { + check_recursion_depth!(self); + if self.consume_token(expected) { Ok(()) } else { @@ -1540,6 +1652,8 @@ impl<'a> Parser<'a> { where F: FnMut(&mut Parser<'a>) -> Result, { + check_recursion_depth!(self); + let mut values = vec![]; loop { values.push(f(self)?); @@ -1569,6 +1683,8 @@ impl<'a> Parser<'a> { /// Parse either `ALL` or `DISTINCT`. Returns `true` if `DISTINCT` is parsed and results in a /// `ParserError` if both `ALL` and `DISTINCT` are fround. pub fn parse_all_or_distinct(&mut self) -> Result { + check_recursion_depth!(self); + let all = self.parse_keyword(Keyword::ALL); let distinct = self.parse_keyword(Keyword::DISTINCT); if all && distinct { @@ -1580,6 +1696,8 @@ impl<'a> Parser<'a> { /// Parse a SQL CREATE statement pub fn parse_create(&mut self) -> Result { + check_recursion_depth!(self); + let or_replace = self.parse_keywords(&[Keyword::OR, Keyword::REPLACE]); let local = self.parse_one_of_keywords(&[Keyword::LOCAL]).is_some(); let global = self.parse_one_of_keywords(&[Keyword::GLOBAL]).is_some(); @@ -1622,6 +1740,8 @@ impl<'a> Parser<'a> { /// SQLite-specific `CREATE VIRTUAL TABLE` pub fn parse_create_virtual_table(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::TABLE)?; let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let table_name = self.parse_object_name()?; @@ -1641,6 +1761,8 @@ impl<'a> Parser<'a> { } pub fn parse_create_schema(&mut self) -> Result { + check_recursion_depth!(self); + let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let schema_name = self.parse_object_name()?; Ok(Statement::CreateSchema { @@ -1650,6 +1772,8 @@ impl<'a> Parser<'a> { } pub fn parse_create_database(&mut self) -> Result { + check_recursion_depth!(self); + let ine = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let db_name = self.parse_object_name()?; let mut location = None; @@ -1675,6 +1799,8 @@ impl<'a> Parser<'a> { &mut self, or_replace: bool, ) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::TABLE)?; let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let table_name = self.parse_object_name()?; @@ -1719,6 +1845,8 @@ impl<'a> Parser<'a> { } pub fn parse_file_format(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => match w.keyword { Keyword::AVRO => Ok(FileFormat::AVRO), @@ -1735,6 +1863,8 @@ impl<'a> Parser<'a> { } pub fn parse_create_view(&mut self, or_replace: bool) -> Result { + check_recursion_depth!(self); + let materialized = self.parse_keyword(Keyword::MATERIALIZED); self.expect_keyword(Keyword::VIEW)?; // Many dialects support `OR ALTER` right after `CREATE`, but we don't (yet). @@ -1756,6 +1886,8 @@ impl<'a> Parser<'a> { } pub fn parse_drop(&mut self) -> Result { + check_recursion_depth!(self); + let object_type = if self.parse_keyword(Keyword::TABLE) { ObjectType::Table } else if self.parse_keyword(Keyword::VIEW) { @@ -1787,6 +1919,8 @@ impl<'a> Parser<'a> { } pub fn parse_create_index(&mut self, unique: bool) -> Result { + check_recursion_depth!(self); + let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let index_name = self.parse_object_name()?; self.expect_keyword(Keyword::ON)?; @@ -1805,6 +1939,8 @@ impl<'a> Parser<'a> { //TODO: Implement parsing for Skewed and Clustered pub fn parse_hive_distribution(&mut self) -> Result { + check_recursion_depth!(self); + if self.parse_keywords(&[Keyword::PARTITIONED, Keyword::BY]) { self.expect_token(&Token::LParen)?; let columns = self.parse_comma_separated(Parser::parse_column_def)?; @@ -1816,6 +1952,8 @@ impl<'a> Parser<'a> { } pub fn parse_hive_formats(&mut self) -> Result { + check_recursion_depth!(self); + let mut hive_format = HiveFormat::default(); loop { match self.parse_one_of_keywords(&[Keyword::ROW, Keyword::STORED, Keyword::LOCATION]) { @@ -1849,6 +1987,8 @@ impl<'a> Parser<'a> { } pub fn parse_row_format(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::FORMAT)?; match self.parse_one_of_keywords(&[Keyword::SERDE, Keyword::DELIMITED]) { Some(Keyword::SERDE) => { @@ -1865,6 +2005,8 @@ impl<'a> Parser<'a> { temporary: bool, global: Option, ) -> Result { + check_recursion_depth!(self); + let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let table_name = self.parse_object_name()?; let like = if self.parse_keyword(Keyword::LIKE) || self.parse_keyword(Keyword::ILIKE) { @@ -1963,6 +2105,8 @@ impl<'a> Parser<'a> { } pub fn parse_columns(&mut self) -> Result<(Vec, Vec), ParserError> { + check_recursion_depth!(self); + let mut columns = vec![]; let mut constraints = vec![]; if !self.consume_token(&Token::LParen) || self.consume_token(&Token::RParen) { @@ -1990,6 +2134,8 @@ impl<'a> Parser<'a> { } pub fn parse_column_def(&mut self) -> Result { + check_recursion_depth!(self); + let name = self.parse_identifier()?; let data_type = self.parse_data_type()?; let collation = if self.parse_keyword(Keyword::COLLATE) { @@ -2024,6 +2170,8 @@ impl<'a> Parser<'a> { } pub fn parse_optional_column_option(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + if self.parse_keywords(&[Keyword::CHARACTER, Keyword::SET]) { Ok(Some(ColumnOption::CharacterSet(self.parse_object_name()?))) } else if self.parse_keywords(&[Keyword::NOT, Keyword::NULL]) { @@ -2090,6 +2238,8 @@ impl<'a> Parser<'a> { } pub fn parse_referential_action(&mut self) -> Result { + check_recursion_depth!(self); + if self.parse_keyword(Keyword::RESTRICT) { Ok(ReferentialAction::Restrict) } else if self.parse_keyword(Keyword::CASCADE) { @@ -2111,6 +2261,8 @@ impl<'a> Parser<'a> { pub fn parse_optional_table_constraint( &mut self, ) -> Result, ParserError> { + check_recursion_depth!(self); + let name = if self.parse_keyword(Keyword::CONSTRAINT) { Some(self.parse_identifier()?) } else { @@ -2175,6 +2327,8 @@ impl<'a> Parser<'a> { } pub fn parse_options(&mut self, keyword: Keyword) -> Result, ParserError> { + check_recursion_depth!(self); + if self.parse_keyword(keyword) { self.expect_token(&Token::LParen)?; let options = self.parse_comma_separated(Parser::parse_sql_option)?; @@ -2186,6 +2340,8 @@ impl<'a> Parser<'a> { } pub fn parse_sql_option(&mut self) -> Result { + check_recursion_depth!(self); + let name = self.parse_identifier()?; self.expect_token(&Token::Eq)?; let value = self.parse_value()?; @@ -2193,6 +2349,8 @@ impl<'a> Parser<'a> { } pub fn parse_alter(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::TABLE)?; let _ = self.parse_keyword(Keyword::ONLY); let table_name = self.parse_object_name()?; @@ -2347,6 +2505,8 @@ impl<'a> Parser<'a> { /// Parse a copy statement pub fn parse_copy(&mut self) -> Result { + check_recursion_depth!(self); + let table_name = self.parse_object_name()?; let columns = self.parse_parenthesized_column_list(Optional)?; let to = match self.parse_one_of_keywords(&[Keyword::FROM, Keyword::TO]) { @@ -2544,6 +2704,8 @@ impl<'a> Parser<'a> { /// Parse a literal value (numbers, strings, date/time, booleans) pub fn parse_value(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => match w.keyword { Keyword::TRUE => Ok(Value::Boolean(true)), @@ -2572,6 +2734,8 @@ impl<'a> Parser<'a> { } pub fn parse_number_value(&mut self) -> Result { + check_recursion_depth!(self); + match self.parse_value()? { v @ Value::Number(_, _) => Ok(v), _ => { @@ -2583,6 +2747,8 @@ impl<'a> Parser<'a> { /// Parse an unsigned literal integer/long pub fn parse_literal_uint(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Number(s, _) => s.parse::().map_err(|e| { ParserError::ParserError(format!("Could not parse '{}' as u64: {}", s, e)) @@ -2593,6 +2759,8 @@ impl<'a> Parser<'a> { /// Parse a literal string pub fn parse_literal_string(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(Word { value, keyword, .. }) if keyword == Keyword::NoKeyword => Ok(value), Token::SingleQuotedString(s) => Ok(s), @@ -2602,6 +2770,8 @@ impl<'a> Parser<'a> { /// Parse a map key string pub fn parse_map_key(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(Word { value, keyword, .. }) if keyword == Keyword::NoKeyword => { if self.peek_token() == Token::LParen { @@ -2620,6 +2790,8 @@ impl<'a> Parser<'a> { /// Parse a SQL datatype (in the context of a CREATE TABLE statement for example) pub fn parse_data_type(&mut self) -> Result { + check_recursion_depth!(self); + let mut data = match self.next_token() { Token::Word(w) => match w.keyword { Keyword::BOOLEAN => Ok(DataType::Boolean), @@ -2718,6 +2890,8 @@ impl<'a> Parser<'a> { } pub fn parse_string_values(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + self.expect_token(&Token::LParen)?; let mut values = Vec::new(); loop { @@ -2741,6 +2915,8 @@ impl<'a> Parser<'a> { &mut self, reserved_kwds: &[Keyword], ) -> Result, ParserError> { + check_recursion_depth!(self); + let after_as = self.parse_keyword(Keyword::AS); match self.next_token() { // Accept any identifier after `AS` (though many dialects have restrictions on @@ -2782,6 +2958,8 @@ impl<'a> Parser<'a> { &mut self, reserved_kwds: &[Keyword], ) -> Result, ParserError> { + check_recursion_depth!(self); + match self.parse_optional_alias(reserved_kwds)? { Some(name) => { let columns = self.parse_parenthesized_column_list(Optional)?; @@ -2794,6 +2972,8 @@ impl<'a> Parser<'a> { /// Parse a possibly qualified, possibly quoted identifier, e.g. /// `foo` or `myschema."table" pub fn parse_object_name(&mut self) -> Result { + check_recursion_depth!(self); + let mut idents = vec![]; loop { idents.push(self.parse_identifier()?); @@ -2806,6 +2986,8 @@ impl<'a> Parser<'a> { /// Parse identifiers strictly i.e. don't parse keywords pub fn parse_identifiers_non_keywords(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + let mut idents = vec![]; loop { match self.peek_token() { @@ -2828,6 +3010,8 @@ impl<'a> Parser<'a> { /// Parse identifiers pub fn parse_identifiers(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + let mut idents = vec![]; loop { match self.next_token() { @@ -2844,6 +3028,8 @@ impl<'a> Parser<'a> { /// Parse a simple one-word identifier (possibly quoted, possibly a keyword) pub fn parse_identifier(&mut self) -> Result { + check_recursion_depth!(self); + match self.next_token() { Token::Word(w) => Ok(w.to_ident()), Token::SingleQuotedString(s) => Ok(Ident::with_quote('\'', s)), @@ -2856,6 +3042,8 @@ impl<'a> Parser<'a> { &mut self, optional: IsOptional, ) -> Result, ParserError> { + check_recursion_depth!(self); + if self.consume_token(&Token::LParen) { let cols = self.parse_comma_separated(Parser::parse_identifier)?; self.expect_token(&Token::RParen)?; @@ -2868,6 +3056,8 @@ impl<'a> Parser<'a> { } pub fn parse_optional_precision(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + if self.consume_token(&Token::LParen) { let n = self.parse_literal_uint()?; self.expect_token(&Token::RParen)?; @@ -2880,6 +3070,8 @@ impl<'a> Parser<'a> { pub fn parse_optional_precision_scale( &mut self, ) -> Result<(Option, Option), ParserError> { + check_recursion_depth!(self); + if self.consume_token(&Token::LParen) { let n = self.parse_literal_uint()?; let scale = if self.consume_token(&Token::Comma) { @@ -2895,6 +3087,8 @@ impl<'a> Parser<'a> { } pub fn parse_delete(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::FROM)?; let table_name = self.parse_object_name()?; let selection = if self.parse_keyword(Keyword::WHERE) { @@ -2936,6 +3130,8 @@ impl<'a> Parser<'a> { } pub fn parse_explain(&mut self, describe_alias: bool) -> Result { + check_recursion_depth!(self); + let analyze = self.parse_keyword(Keyword::ANALYZE); let verbose = self.parse_keyword(Keyword::VERBOSE); @@ -2961,6 +3157,8 @@ impl<'a> Parser<'a> { /// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't /// expect the initial keyword to be already consumed pub fn parse_query(&mut self) -> Result { + check_recursion_depth!(self); + let with = if self.parse_keyword(Keyword::WITH) { Some(With { recursive: self.parse_keyword(Keyword::RECURSIVE), @@ -3035,6 +3233,8 @@ impl<'a> Parser<'a> { /// Parse a CTE (`alias [( col1, col2, ... )] AS (subquery)`) pub fn parse_cte(&mut self) -> Result { + check_recursion_depth!(self); + let name = self.parse_identifier()?; let mut cte = if self.parse_keyword(Keyword::AS) { @@ -3078,6 +3278,8 @@ impl<'a> Parser<'a> { /// set_operation ::= query_body { 'UNION' | 'EXCEPT' | 'INTERSECT' } [ 'ALL' ] query_body /// ``` pub fn parse_query_body(&mut self, precedence: u8) -> Result { + check_recursion_depth!(self); + // We parse the expression using a Pratt parser, as in `parse_expr()`. // Start by parsing a restricted SELECT or a `(subquery)`: let mut expr = if self.parse_keyword(Keyword::SELECT) { @@ -3134,6 +3336,8 @@ impl<'a> Parser<'a> { /// Parse a restricted `SELECT` statement (no CTEs / `UNION` / `ORDER BY`), /// assuming the initial `SELECT` was already consumed pub fn parse_select(&mut self) -> Result { + check_recursion_depth!(self); + let distinct = self.parse_all_or_distinct()?; let top = if self.parse_keyword(Keyword::TOP) { @@ -3262,6 +3466,8 @@ impl<'a> Parser<'a> { } pub fn parse_set(&mut self) -> Result { + check_recursion_depth!(self); + let modifier = self.parse_one_of_keywords(&[Keyword::SESSION, Keyword::LOCAL, Keyword::HIVEVAR]); if let Some(Keyword::HIVEVAR) = modifier { @@ -3327,6 +3533,8 @@ impl<'a> Parser<'a> { } pub fn parse_show(&mut self) -> Result { + check_recursion_depth!(self); + if self .parse_one_of_keywords(&[ Keyword::EXTENDED, @@ -3348,6 +3556,8 @@ impl<'a> Parser<'a> { } pub fn parse_show_create(&mut self) -> Result { + check_recursion_depth!(self); + let obj_type = match self.expect_one_of_keywords(&[ Keyword::TABLE, Keyword::TRIGGER, @@ -3372,6 +3582,8 @@ impl<'a> Parser<'a> { } pub fn parse_show_columns(&mut self) -> Result { + check_recursion_depth!(self); + let extended = self.parse_keyword(Keyword::EXTENDED); let full = self.parse_keyword(Keyword::FULL); self.expect_one_of_keywords(&[Keyword::COLUMNS, Keyword::FIELDS])?; @@ -3392,6 +3604,8 @@ impl<'a> Parser<'a> { pub fn parse_show_statement_filter( &mut self, ) -> Result, ParserError> { + check_recursion_depth!(self); + if self.parse_keyword(Keyword::LIKE) { Ok(Some(ShowStatementFilter::Like( self.parse_literal_string()?, @@ -3408,6 +3622,8 @@ impl<'a> Parser<'a> { } pub fn parse_table_and_joins(&mut self) -> Result { + check_recursion_depth!(self); + let relation = self.parse_table_factor()?; // Note that for keywords to be properly handled here, they need to be @@ -3482,6 +3698,8 @@ impl<'a> Parser<'a> { /// A table name or a parenthesized subquery, followed by optional `[AS] alias` pub fn parse_table_factor(&mut self) -> Result { + check_recursion_depth!(self); + if self.parse_keyword(Keyword::LATERAL) { // LATERAL must always be followed by a subquery. if !self.consume_token(&Token::LParen) { @@ -3610,6 +3828,8 @@ impl<'a> Parser<'a> { &mut self, lateral: IsLateral, ) -> Result { + check_recursion_depth!(self); + let subquery = Box::new(self.parse_query()?); self.expect_token(&Token::RParen)?; let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; @@ -3624,6 +3844,8 @@ impl<'a> Parser<'a> { } pub fn parse_join_constraint(&mut self, natural: bool) -> Result { + check_recursion_depth!(self); + if natural { Ok(JoinConstraint::Natural) } else if self.parse_keyword(Keyword::ON) { @@ -3640,6 +3862,8 @@ impl<'a> Parser<'a> { /// Parse a GRANT statement. pub fn parse_grant(&mut self) -> Result { + check_recursion_depth!(self); + let (privileges, objects) = self.parse_grant_revoke_privileges_objects()?; self.expect_keyword(Keyword::TO)?; @@ -3664,6 +3888,8 @@ impl<'a> Parser<'a> { pub fn parse_grant_revoke_privileges_objects( &mut self, ) -> Result<(Privileges, GrantObjects), ParserError> { + check_recursion_depth!(self); + let privileges = if self.parse_keyword(Keyword::ALL) { Privileges::All { with_privileges_keyword: self.parse_keyword(Keyword::PRIVILEGES), @@ -3739,6 +3965,8 @@ impl<'a> Parser<'a> { } pub fn parse_grant_permission(&mut self) -> Result<(Keyword, Option>), ParserError> { + check_recursion_depth!(self); + if let Some(kw) = self.parse_one_of_keywords(&[ Keyword::CONNECT, Keyword::CREATE, @@ -3772,6 +4000,8 @@ impl<'a> Parser<'a> { /// Parse a REVOKE statement pub fn parse_revoke(&mut self) -> Result { + check_recursion_depth!(self); + let (privileges, objects) = self.parse_grant_revoke_privileges_objects()?; self.expect_keyword(Keyword::FROM)?; @@ -3798,6 +4028,8 @@ impl<'a> Parser<'a> { /// Parse an INSERT statement pub fn parse_insert(&mut self) -> Result { + check_recursion_depth!(self); + let or = if !dialect_of!(self is SQLiteDialect) { None } else if self.parse_keywords(&[Keyword::OR, Keyword::REPLACE]) { @@ -3883,6 +4115,8 @@ impl<'a> Parser<'a> { } pub fn parse_update(&mut self) -> Result { + check_recursion_depth!(self); + let table = self.parse_table_and_joins()?; self.expect_keyword(Keyword::SET)?; let assignments = self.parse_comma_separated(Parser::parse_assignment)?; @@ -3906,6 +4140,8 @@ impl<'a> Parser<'a> { /// Parse a `var = expr` assignment, used in an UPDATE statement pub fn parse_assignment(&mut self) -> Result { + check_recursion_depth!(self); + let id = self.parse_identifiers_non_keywords()?; self.expect_token(&Token::Eq)?; let value = self.parse_expr()?; @@ -3913,6 +4149,8 @@ impl<'a> Parser<'a> { } pub fn parse_function_args(&mut self) -> Result { + check_recursion_depth!(self); + if self.peek_nth_token(1) == Token::RArrow { let name = self.parse_identifier()?; @@ -3926,6 +4164,8 @@ impl<'a> Parser<'a> { } pub fn parse_optional_args(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + if self.consume_token(&Token::RParen) { Ok(vec![]) } else { @@ -3937,6 +4177,8 @@ impl<'a> Parser<'a> { /// Parse a comma-delimited list of projections after SELECT pub fn parse_select_item(&mut self) -> Result { + check_recursion_depth!(self); + match self.parse_wildcard_expr()? { WildcardExpr::Expr(expr) => self .parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS) @@ -3951,6 +4193,8 @@ impl<'a> Parser<'a> { /// Parse an expression, optionally followed by ASC or DESC (used in ORDER BY) pub fn parse_order_by_expr(&mut self) -> Result { + check_recursion_depth!(self); + let expr = self.parse_expr()?; let asc = if self.parse_keyword(Keyword::ASC) { @@ -3979,6 +4223,8 @@ impl<'a> Parser<'a> { /// Parse a TOP clause, MSSQL equivalent of LIMIT, /// that follows after SELECT [DISTINCT]. pub fn parse_top(&mut self) -> Result { + check_recursion_depth!(self); + let quantity = if self.consume_token(&Token::LParen) { let quantity = self.parse_expr()?; self.expect_token(&Token::RParen)?; @@ -4000,6 +4246,8 @@ impl<'a> Parser<'a> { /// Parse a LIMIT clause pub fn parse_limit(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + if self.parse_keyword(Keyword::ALL) { Ok(None) } else { @@ -4009,6 +4257,8 @@ impl<'a> Parser<'a> { /// Parse an OFFSET clause pub fn parse_offset(&mut self) -> Result { + check_recursion_depth!(self); + let value = Expr::Value(self.parse_number_value()?); let rows = if self.parse_keyword(Keyword::ROW) { OffsetRows::Row @@ -4022,6 +4272,8 @@ impl<'a> Parser<'a> { /// Parse a FETCH clause pub fn parse_fetch(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_one_of_keywords(&[Keyword::FIRST, Keyword::NEXT])?; let (quantity, percent) = if self .parse_one_of_keywords(&[Keyword::ROW, Keyword::ROWS]) @@ -4050,6 +4302,8 @@ impl<'a> Parser<'a> { /// Parse a FOR UPDATE/FOR SHARE clause pub fn parse_lock(&mut self) -> Result { + check_recursion_depth!(self); + match self.expect_one_of_keywords(&[Keyword::UPDATE, Keyword::SHARE])? { Keyword::UPDATE => Ok(LockType::Update), Keyword::SHARE => Ok(LockType::Share), @@ -4058,6 +4312,8 @@ impl<'a> Parser<'a> { } pub fn parse_values(&mut self) -> Result { + check_recursion_depth!(self); + let values = self.parse_comma_separated(|parser| { parser.expect_token(&Token::LParen)?; let exprs = parser.parse_comma_separated(Parser::parse_expr)?; @@ -4068,6 +4324,8 @@ impl<'a> Parser<'a> { } pub fn parse_start_transaction(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::TRANSACTION)?; Ok(Statement::StartTransaction { modes: self.parse_transaction_modes()?, @@ -4075,6 +4333,8 @@ impl<'a> Parser<'a> { } pub fn parse_begin(&mut self) -> Result { + check_recursion_depth!(self); + let _ = self.parse_one_of_keywords(&[Keyword::TRANSACTION, Keyword::WORK]); Ok(Statement::StartTransaction { modes: self.parse_transaction_modes()?, @@ -4082,6 +4342,8 @@ impl<'a> Parser<'a> { } pub fn parse_transaction_modes(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + let mut modes = vec![]; let mut required = false; loop { @@ -4118,18 +4380,24 @@ impl<'a> Parser<'a> { } pub fn parse_commit(&mut self) -> Result { + check_recursion_depth!(self); + Ok(Statement::Commit { chain: self.parse_commit_rollback_chain()?, }) } pub fn parse_rollback(&mut self) -> Result { + check_recursion_depth!(self); + Ok(Statement::Rollback { chain: self.parse_commit_rollback_chain()?, }) } pub fn parse_commit_rollback_chain(&mut self) -> Result { + check_recursion_depth!(self); + let _ = self.parse_one_of_keywords(&[Keyword::TRANSACTION, Keyword::WORK]); if self.parse_keyword(Keyword::AND) { let chain = !self.parse_keyword(Keyword::NO); @@ -4141,12 +4409,16 @@ impl<'a> Parser<'a> { } pub fn parse_deallocate(&mut self) -> Result { + check_recursion_depth!(self); + let prepare = self.parse_keyword(Keyword::PREPARE); let name = self.parse_identifier()?; Ok(Statement::Deallocate { name, prepare }) } pub fn parse_execute(&mut self) -> Result { + check_recursion_depth!(self); + let name = self.parse_identifier()?; let mut parameters = vec![]; @@ -4159,6 +4431,8 @@ impl<'a> Parser<'a> { } pub fn parse_prepare(&mut self) -> Result { + check_recursion_depth!(self); + let name = self.parse_identifier()?; let mut data_types = vec![]; @@ -4177,6 +4451,8 @@ impl<'a> Parser<'a> { } pub fn parse_comment(&mut self) -> Result { + check_recursion_depth!(self); + self.expect_keyword(Keyword::ON)?; let token = self.next_token(); @@ -4206,6 +4482,8 @@ impl<'a> Parser<'a> { } pub fn parse_merge_clauses(&mut self) -> Result, ParserError> { + check_recursion_depth!(self); + let mut clauses: Vec = vec![]; loop { if self.peek_token() == Token::EOF { @@ -4283,6 +4561,8 @@ impl<'a> Parser<'a> { } pub fn parse_merge(&mut self) -> Result { + check_recursion_depth!(self); + let into = self.parse_keyword(Keyword::INTO); let table = self.parse_table_factor()?; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 0986e407e..ab9e60c23 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -4740,3 +4740,13 @@ fn parse_is_boolean() { res.unwrap_err() ); } + +#[test] +fn parse_with_lots_of_stack_recursion() { + let sql = "(".repeat(1000); + let res = parse_sql_statements(&sql); + assert_eq!( + ParserError::RecursionLimitExceeded, + res.unwrap_err() + ); +}