Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for release and rollback to savepoint syntax #1045

Merged
merged 3 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1839,8 +1839,11 @@ pub enum Statement {
},
/// `COMMIT [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]`
Commit { chain: bool },
/// `ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]`
Rollback { chain: bool },
/// `ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ] [ TO [ SAVEPOINT ] savepoint_name ]`
Rollback {
chain: bool,
savepoint: Option<Ident>,
},
/// CREATE SCHEMA
CreateSchema {
/// `<schema name> | AUTHORIZATION <schema authorization identifier> | <schema name> AUTHORIZATION <schema authorization identifier>`
Expand Down Expand Up @@ -1977,6 +1980,8 @@ pub enum Statement {
},
/// SAVEPOINT -- define a new savepoint within the current transaction
Savepoint { name: Ident },
/// RELEASE \[ SAVEPOINT \] savepoint_name
ReleaseSavepoint { name: Ident },
// MERGE INTO statement, based on Snowflake. See <https://docs.snowflake.com/en/sql-reference/sql/merge.html>
Merge {
// optional INTO keyword
Expand Down Expand Up @@ -3127,8 +3132,18 @@ impl fmt::Display for Statement {
Statement::Commit { chain } => {
write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" },)
}
Statement::Rollback { chain } => {
write!(f, "ROLLBACK{}", if *chain { " AND CHAIN" } else { "" },)
Statement::Rollback { chain, savepoint } => {
write!(f, "ROLLBACK")?;

if *chain {
write!(f, " AND CHAIN")?;
}

if let Some(savepoint) = savepoint {
write!(f, " TO SAVEPOINT {savepoint}")?;
}

Ok(())
}
Statement::CreateSchema {
schema_name,
Expand Down Expand Up @@ -3225,6 +3240,9 @@ impl fmt::Display for Statement {
write!(f, "SAVEPOINT ")?;
write!(f, "{name}")
}
Statement::ReleaseSavepoint { name } => {
write!(f, "RELEASE SAVEPOINT {name}")
}
Statement::Merge {
into,
table,
Expand Down
26 changes: 23 additions & 3 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ impl<'a> Parser<'a> {
// by at least PostgreSQL and MySQL.
Keyword::BEGIN => Ok(self.parse_begin()?),
Keyword::SAVEPOINT => Ok(self.parse_savepoint()?),
Keyword::RELEASE => Ok(self.parse_release()?),
Keyword::COMMIT => Ok(self.parse_commit()?),
Keyword::ROLLBACK => Ok(self.parse_rollback()?),
Keyword::ASSERT => Ok(self.parse_assert()?),
Expand Down Expand Up @@ -747,6 +748,13 @@ impl<'a> Parser<'a> {
Ok(Statement::Savepoint { name })
}

pub fn parse_release(&mut self) -> Result<Statement, ParserError> {
let _ = self.parse_keyword(Keyword::SAVEPOINT);
let name = self.parse_identifier()?;

Ok(Statement::ReleaseSavepoint { name })
}

/// Parse an expression prefix
pub fn parse_prefix(&mut self) -> Result<Expr, ParserError> {
// allow the dialect to override prefix parsing
Expand Down Expand Up @@ -7843,9 +7851,10 @@ impl<'a> Parser<'a> {
}

pub fn parse_rollback(&mut self) -> Result<Statement, ParserError> {
Ok(Statement::Rollback {
chain: self.parse_commit_rollback_chain()?,
})
let chain = self.parse_commit_rollback_chain()?;
let savepoint = self.parse_rollback_savepoint()?;

Ok(Statement::Rollback { chain, savepoint })
}

pub fn parse_commit_rollback_chain(&mut self) -> Result<bool, ParserError> {
Expand All @@ -7859,6 +7868,17 @@ impl<'a> Parser<'a> {
}
}

pub fn parse_rollback_savepoint(&mut self) -> Result<Option<Ident>, ParserError> {
if self.parse_keyword(Keyword::TO) {
let _ = self.parse_keyword(Keyword::SAVEPOINT);
let savepoint = self.parse_identifier()?;

Ok(Some(savepoint))
} else {
Ok(None)
}
}

pub fn parse_deallocate(&mut self) -> Result<Statement, ParserError> {
let prepare = self.parse_keyword(Keyword::PREPARE);
let name = self.parse_identifier()?;
Expand Down
57 changes: 55 additions & 2 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6234,12 +6234,38 @@ fn parse_commit() {
#[test]
fn parse_rollback() {
match verified_stmt("ROLLBACK") {
Statement::Rollback { chain: false } => (),
Statement::Rollback {
chain: false,
savepoint: None,
} => (),
_ => unreachable!(),
}

match verified_stmt("ROLLBACK AND CHAIN") {
Statement::Rollback { chain: true } => (),
Statement::Rollback {
chain: true,
savepoint: None,
} => (),
_ => unreachable!(),
}

match verified_stmt("ROLLBACK TO SAVEPOINT test1") {
Statement::Rollback {
chain: false,
savepoint,
} => {
assert_eq!(savepoint, Some(Ident::new("test1")));
}
_ => unreachable!(),
}

match verified_stmt("ROLLBACK AND CHAIN TO SAVEPOINT test1") {
Statement::Rollback {
chain: true,
savepoint,
} => {
assert_eq!(savepoint, Some(Ident::new("test1")));
}
_ => unreachable!(),
}

Expand All @@ -6250,6 +6276,11 @@ fn parse_rollback() {
one_statement_parses_to("ROLLBACK TRANSACTION AND CHAIN", "ROLLBACK AND CHAIN");
one_statement_parses_to("ROLLBACK WORK", "ROLLBACK");
one_statement_parses_to("ROLLBACK TRANSACTION", "ROLLBACK");
one_statement_parses_to("ROLLBACK TO test1", "ROLLBACK TO SAVEPOINT test1");
one_statement_parses_to(
"ROLLBACK AND CHAIN TO test1",
"ROLLBACK AND CHAIN TO SAVEPOINT test1",
);
}

#[test]
Expand Down Expand Up @@ -7864,3 +7895,25 @@ fn parse_binary_operators_without_whitespace() {
"SELECT tbl1.field % tbl2.field FROM tbl1 JOIN tbl2 ON tbl1.id = tbl2.entity_id",
);
}

#[test]
fn test_savepoint() {
match verified_stmt("SAVEPOINT test1") {
Statement::Savepoint { name } => {
assert_eq!(Ident::new("test1"), name);
}
_ => unreachable!(),
}
}

#[test]
fn test_release_savepoint() {
match verified_stmt("RELEASE SAVEPOINT test1") {
Statement::ReleaseSavepoint { name } => {
assert_eq!(Ident::new("test1"), name);
}
_ => unreachable!(),
}

one_statement_parses_to("RELEASE test1", "RELEASE SAVEPOINT test1");
}
10 changes: 0 additions & 10 deletions tests/sqlparser_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2093,16 +2093,6 @@ fn test_transaction_statement() {
);
}

#[test]
fn test_savepoint() {
match pg().verified_stmt("SAVEPOINT test1") {
Statement::Savepoint { name } => {
assert_eq!(Ident::new("test1"), name);
}
_ => unreachable!(),
}
}

#[test]
fn test_json() {
let sql = "SELECT params ->> 'name' FROM events";
Expand Down