diff --git a/crates/pgt_statement_splitter/src/lib.rs b/crates/pgt_statement_splitter/src/lib.rs index 6fb81c092..c67bc0e97 100644 --- a/crates/pgt_statement_splitter/src/lib.rs +++ b/crates/pgt_statement_splitter/src/lib.rs @@ -92,7 +92,7 @@ mod tests { assert_eq!( self.result.ranges.len(), expected.len(), - "Expected {} statements for input {}, got {}: {:?}", + "Expected {} statements for input\n{}\ngot {}:\n{:?}", expected.len(), self.input, self.result.ranges.len(), @@ -133,6 +133,40 @@ mod tests { } } + #[test] + fn begin_commit() { + Tester::from( + "BEGIN; +SELECT 1; +COMMIT;", + ) + .expect_statements(vec!["BEGIN;", "SELECT 1;", "COMMIT;"]); + } + + #[test] + fn begin_atomic() { + Tester::from( + "CREATE OR REPLACE FUNCTION public.test_fn(some_in TEXT) +RETURNS TEXT +LANGUAGE sql +IMMUTABLE +STRICT +BEGIN ATOMIC + SELECT $1 || 'foo'; +END;", + ) + .expect_statements(vec![ + "CREATE OR REPLACE FUNCTION public.test_fn(some_in TEXT) +RETURNS TEXT +LANGUAGE sql +IMMUTABLE +STRICT +BEGIN ATOMIC + SELECT $1 || 'foo'; +END;", + ]); + } + #[test] fn ts_with_timezone() { Tester::from("alter table foo add column bar timestamp with time zone;").expect_statements( diff --git a/crates/pgt_statement_splitter/src/splitter/common.rs b/crates/pgt_statement_splitter/src/splitter/common.rs index 54db04e8b..fcb851dac 100644 --- a/crates/pgt_statement_splitter/src/splitter/common.rs +++ b/crates/pgt_statement_splitter/src/splitter/common.rs @@ -58,6 +58,33 @@ pub(crate) fn statement(p: &mut Splitter) { p.close_stmt(); } +pub(crate) fn begin_end(p: &mut Splitter) { + p.expect(SyntaxKind::BEGIN_KW); + + let mut depth = 1; + + loop { + match p.current() { + SyntaxKind::BEGIN_KW => { + p.advance(); + depth += 1; + } + SyntaxKind::END_KW | SyntaxKind::EOF => { + if p.current() == SyntaxKind::END_KW { + p.advance(); + } + depth -= 1; + if depth == 0 { + break; + } + } + _ => { + p.advance(); + } + } + } +} + pub(crate) fn parenthesis(p: &mut Splitter) { p.expect(SyntaxKind::L_PAREN); @@ -163,6 +190,14 @@ pub(crate) fn unknown(p: &mut Splitter, exclude: &[SyntaxKind]) { SyntaxKind::L_PAREN => { parenthesis(p); } + SyntaxKind::BEGIN_KW => { + if p.look_ahead(true) != SyntaxKind::SEMICOLON { + // BEGIN; should be treated as a statement terminator + begin_end(p); + } else { + p.advance(); + } + } t => match at_statement_start(t, exclude) { Some(SyntaxKind::SELECT_KW) => { let prev = p.look_back(true); @@ -188,6 +223,8 @@ pub(crate) fn unknown(p: &mut Splitter, exclude: &[SyntaxKind]) { // for revoke SyntaxKind::REVOKE_KW, SyntaxKind::COMMA, + // for BEGIN ATOMIC + SyntaxKind::ATOMIC_KW, ] .iter() .all(|x| Some(x) != prev.as_ref()) @@ -255,7 +292,6 @@ pub(crate) fn unknown(p: &mut Splitter, exclude: &[SyntaxKind]) { } p.advance(); } - Some(SyntaxKind::CREATE_KW) => { let prev = p.look_back(true); if [