Skip to content

Commit 414078f

Browse files
authored
fix: begin atomic split (#538)
1 parent df6abda commit 414078f

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

crates/pgt_statement_splitter/src/lib.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ mod tests {
9292
assert_eq!(
9393
self.result.ranges.len(),
9494
expected.len(),
95-
"Expected {} statements for input {}, got {}: {:?}",
95+
"Expected {} statements for input\n{}\ngot {}:\n{:?}",
9696
expected.len(),
9797
self.input,
9898
self.result.ranges.len(),
@@ -133,6 +133,40 @@ mod tests {
133133
}
134134
}
135135

136+
#[test]
137+
fn begin_commit() {
138+
Tester::from(
139+
"BEGIN;
140+
SELECT 1;
141+
COMMIT;",
142+
)
143+
.expect_statements(vec!["BEGIN;", "SELECT 1;", "COMMIT;"]);
144+
}
145+
146+
#[test]
147+
fn begin_atomic() {
148+
Tester::from(
149+
"CREATE OR REPLACE FUNCTION public.test_fn(some_in TEXT)
150+
RETURNS TEXT
151+
LANGUAGE sql
152+
IMMUTABLE
153+
STRICT
154+
BEGIN ATOMIC
155+
SELECT $1 || 'foo';
156+
END;",
157+
)
158+
.expect_statements(vec![
159+
"CREATE OR REPLACE FUNCTION public.test_fn(some_in TEXT)
160+
RETURNS TEXT
161+
LANGUAGE sql
162+
IMMUTABLE
163+
STRICT
164+
BEGIN ATOMIC
165+
SELECT $1 || 'foo';
166+
END;",
167+
]);
168+
}
169+
136170
#[test]
137171
fn ts_with_timezone() {
138172
Tester::from("alter table foo add column bar timestamp with time zone;").expect_statements(

crates/pgt_statement_splitter/src/splitter/common.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,33 @@ pub(crate) fn statement(p: &mut Splitter) {
5858
p.close_stmt();
5959
}
6060

61+
pub(crate) fn begin_end(p: &mut Splitter) {
62+
p.expect(SyntaxKind::BEGIN_KW);
63+
64+
let mut depth = 1;
65+
66+
loop {
67+
match p.current() {
68+
SyntaxKind::BEGIN_KW => {
69+
p.advance();
70+
depth += 1;
71+
}
72+
SyntaxKind::END_KW | SyntaxKind::EOF => {
73+
if p.current() == SyntaxKind::END_KW {
74+
p.advance();
75+
}
76+
depth -= 1;
77+
if depth == 0 {
78+
break;
79+
}
80+
}
81+
_ => {
82+
p.advance();
83+
}
84+
}
85+
}
86+
}
87+
6188
pub(crate) fn parenthesis(p: &mut Splitter) {
6289
p.expect(SyntaxKind::L_PAREN);
6390

@@ -163,6 +190,14 @@ pub(crate) fn unknown(p: &mut Splitter, exclude: &[SyntaxKind]) {
163190
SyntaxKind::L_PAREN => {
164191
parenthesis(p);
165192
}
193+
SyntaxKind::BEGIN_KW => {
194+
if p.look_ahead(true) != SyntaxKind::SEMICOLON {
195+
// BEGIN; should be treated as a statement terminator
196+
begin_end(p);
197+
} else {
198+
p.advance();
199+
}
200+
}
166201
t => match at_statement_start(t, exclude) {
167202
Some(SyntaxKind::SELECT_KW) => {
168203
let prev = p.look_back(true);
@@ -188,6 +223,8 @@ pub(crate) fn unknown(p: &mut Splitter, exclude: &[SyntaxKind]) {
188223
// for revoke
189224
SyntaxKind::REVOKE_KW,
190225
SyntaxKind::COMMA,
226+
// for BEGIN ATOMIC
227+
SyntaxKind::ATOMIC_KW,
191228
]
192229
.iter()
193230
.all(|x| Some(x) != prev.as_ref())
@@ -255,7 +292,6 @@ pub(crate) fn unknown(p: &mut Splitter, exclude: &[SyntaxKind]) {
255292
}
256293
p.advance();
257294
}
258-
259295
Some(SyntaxKind::CREATE_KW) => {
260296
let prev = p.look_back(true);
261297
if [

0 commit comments

Comments
 (0)