From 4b270945fe48d8b1a8d2e0822d4a8d26a9ccb158 Mon Sep 17 00:00:00 2001 From: Drew Thomas Date: Thu, 9 Nov 2023 15:38:06 +1100 Subject: [PATCH 1/2] Add support for release and rollback to savepoint syntax --- src/ast/mod.rs | 26 ++++++++++++++--- src/parser/mod.rs | 26 +++++++++++++++-- tests/sqlparser_common.rs | 57 +++++++++++++++++++++++++++++++++++-- tests/sqlparser_postgres.rs | 10 ------- 4 files changed, 100 insertions(+), 19 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index ab917dc4c..709c7dc7c 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1802,8 +1802,11 @@ pub enum Statement { }, /// `COMMIT [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]` Commit { chain: bool }, - /// `ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]` - Rollback { chain: bool }, + /// `ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ] [ TO [ SAVEPOINT ] savepoint_name ]` + Rollback { + chain: bool, + savepoint: Option, + }, /// CREATE SCHEMA CreateSchema { /// ` | AUTHORIZATION | AUTHORIZATION ` @@ -1940,6 +1943,8 @@ pub enum Statement { }, /// SAVEPOINT -- define a new savepoint within the current transaction Savepoint { name: Ident }, + /// RELEASE [ SAVEPOINT ] savepoint_name + ReleaseSavepoint { name: Ident }, // MERGE INTO statement, based on Snowflake. See Merge { // optional INTO keyword @@ -3079,8 +3084,18 @@ impl fmt::Display for Statement { Statement::Commit { chain } => { write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" },) } - Statement::Rollback { chain } => { - write!(f, "ROLLBACK{}", if *chain { " AND CHAIN" } else { "" },) + Statement::Rollback { chain, savepoint } => { + write!(f, "ROLLBACK")?; + + if *chain { + write!(f, " AND CHAIN")?; + } + + if let Some(savepoint) = savepoint { + write!(f, " TO SAVEPOINT {savepoint}")?; + } + + Ok(()) } Statement::CreateSchema { schema_name, @@ -3177,6 +3192,9 @@ impl fmt::Display for Statement { write!(f, "SAVEPOINT ")?; write!(f, "{name}") } + Statement::ReleaseSavepoint { name } => { + write!(f, "RELEASE SAVEPOINT {name}") + } Statement::Merge { into, table, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 1964437eb..61eddfbc2 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -502,6 +502,7 @@ impl<'a> Parser<'a> { // by at least PostgreSQL and MySQL. Keyword::BEGIN => Ok(self.parse_begin()?), Keyword::SAVEPOINT => Ok(self.parse_savepoint()?), + Keyword::RELEASE => Ok(self.parse_release()?), Keyword::COMMIT => Ok(self.parse_commit()?), Keyword::ROLLBACK => Ok(self.parse_rollback()?), Keyword::ASSERT => Ok(self.parse_assert()?), @@ -747,6 +748,13 @@ impl<'a> Parser<'a> { Ok(Statement::Savepoint { name }) } + pub fn parse_release(&mut self) -> Result { + let _ = self.parse_keyword(Keyword::SAVEPOINT); + let name = self.parse_identifier()?; + + Ok(Statement::ReleaseSavepoint { name }) + } + /// Parse an expression prefix pub fn parse_prefix(&mut self) -> Result { // allow the dialect to override prefix parsing @@ -7664,9 +7672,10 @@ impl<'a> Parser<'a> { } pub fn parse_rollback(&mut self) -> Result { - Ok(Statement::Rollback { - chain: self.parse_commit_rollback_chain()?, - }) + let chain = self.parse_commit_rollback_chain()?; + let savepoint = self.parse_rollback_savepoint()?; + + Ok(Statement::Rollback { chain, savepoint }) } pub fn parse_commit_rollback_chain(&mut self) -> Result { @@ -7680,6 +7689,17 @@ impl<'a> Parser<'a> { } } + pub fn parse_rollback_savepoint(&mut self) -> Result, ParserError> { + if self.parse_keyword(Keyword::TO) { + let _ = self.parse_keyword(Keyword::SAVEPOINT); + let savepoint = self.parse_identifier()?; + + Ok(Some(savepoint)) + } else { + Ok(None) + } + } + pub fn parse_deallocate(&mut self) -> Result { let prepare = self.parse_keyword(Keyword::PREPARE); let name = self.parse_identifier()?; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index befdf5129..19da1e2f9 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -6230,12 +6230,38 @@ fn parse_commit() { #[test] fn parse_rollback() { match verified_stmt("ROLLBACK") { - Statement::Rollback { chain: false } => (), + Statement::Rollback { + chain: false, + savepoint: None, + } => (), _ => unreachable!(), } match verified_stmt("ROLLBACK AND CHAIN") { - Statement::Rollback { chain: true } => (), + Statement::Rollback { + chain: true, + savepoint: None, + } => (), + _ => unreachable!(), + } + + match verified_stmt("ROLLBACK TO SAVEPOINT test1") { + Statement::Rollback { + chain: false, + savepoint, + } => { + assert_eq!(savepoint, Some(Ident::new("test1"))); + } + _ => unreachable!(), + } + + match verified_stmt("ROLLBACK AND CHAIN TO SAVEPOINT test1") { + Statement::Rollback { + chain: true, + savepoint, + } => { + assert_eq!(savepoint, Some(Ident::new("test1"))); + } _ => unreachable!(), } @@ -6246,6 +6272,11 @@ fn parse_rollback() { one_statement_parses_to("ROLLBACK TRANSACTION AND CHAIN", "ROLLBACK AND CHAIN"); one_statement_parses_to("ROLLBACK WORK", "ROLLBACK"); one_statement_parses_to("ROLLBACK TRANSACTION", "ROLLBACK"); + one_statement_parses_to("ROLLBACK TO test1", "ROLLBACK TO SAVEPOINT test1"); + one_statement_parses_to( + "ROLLBACK AND CHAIN TO test1", + "ROLLBACK AND CHAIN TO SAVEPOINT test1", + ); } #[test] @@ -7859,3 +7890,25 @@ fn parse_binary_operators_without_whitespace() { "SELECT tbl1.field % tbl2.field FROM tbl1 JOIN tbl2 ON tbl1.id = tbl2.entity_id", ); } + +#[test] +fn test_savepoint() { + match verified_stmt("SAVEPOINT test1") { + Statement::Savepoint { name } => { + assert_eq!(Ident::new("test1"), name); + } + _ => unreachable!(), + } +} + +#[test] +fn test_release_savepoint() { + match verified_stmt("RELEASE SAVEPOINT test1") { + Statement::ReleaseSavepoint { name } => { + assert_eq!(Ident::new("test1"), name); + } + _ => unreachable!(), + } + + one_statement_parses_to("RELEASE test1", "RELEASE SAVEPOINT test1"); +} diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 606d835f1..0601a8868 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -2085,16 +2085,6 @@ fn test_transaction_statement() { ); } -#[test] -fn test_savepoint() { - match pg().verified_stmt("SAVEPOINT test1") { - Statement::Savepoint { name } => { - assert_eq!(Ident::new("test1"), name); - } - _ => unreachable!(), - } -} - #[test] fn test_json() { let sql = "SELECT params ->> 'name' FROM events"; From 598e92f760ea390a577aed5301f21b019c16d64b Mon Sep 17 00:00:00 2001 From: Drew Thomas Date: Tue, 21 Nov 2023 07:41:59 +1100 Subject: [PATCH 2/2] Escape square brackets in `ReleaseSavepoint` docs The brackets here need to be escaped so that rustdoc doesn't try to treat them as an intra doc link. --- src/ast/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 709c7dc7c..5575014c0 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1943,7 +1943,7 @@ pub enum Statement { }, /// SAVEPOINT -- define a new savepoint within the current transaction Savepoint { name: Ident }, - /// RELEASE [ SAVEPOINT ] savepoint_name + /// RELEASE \[ SAVEPOINT \] savepoint_name ReleaseSavepoint { name: Ident }, // MERGE INTO statement, based on Snowflake. See Merge {