From 4aef274221df465dc73701ef3a12f9e7ea763143 Mon Sep 17 00:00:00 2001 From: toverdijk Date: Sun, 22 May 2022 19:56:03 +0200 Subject: [PATCH 1/3] - Renamed PostgresqlLexer to PostgresqlParser - Renamed TokenizedSql to ParsedSql - Lexing/parsing is now done in two steps: first only tokenize, then parse into statements - Added support for function bodies ("BEGIN ATOMIC") - Added a test case for newly supported grammar --- .../{TokenizedSql.java => ParsedSql.java} | 34 +- .../io/r2dbc/postgresql/PostgresqlBatch.java | 2 +- ...SqlLexer.java => PostgresqlSqlParser.java} | 105 +++--- .../r2dbc/postgresql/PostgresqlStatement.java | 28 +- .../postgresql/PostgresqlSqlLexerTest.java | 291 ---------------- .../postgresql/PostgresqlSqlParserTest.java | 313 ++++++++++++++++++ 6 files changed, 404 insertions(+), 369 deletions(-) rename src/main/java/io/r2dbc/postgresql/{TokenizedSql.java => ParsedSql.java} (84%) rename src/main/java/io/r2dbc/postgresql/{PostgresqlSqlLexer.java => PostgresqlSqlParser.java} (60%) delete mode 100644 src/test/java/io/r2dbc/postgresql/PostgresqlSqlLexerTest.java create mode 100644 src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java diff --git a/src/main/java/io/r2dbc/postgresql/TokenizedSql.java b/src/main/java/io/r2dbc/postgresql/ParsedSql.java similarity index 84% rename from src/main/java/io/r2dbc/postgresql/TokenizedSql.java rename to src/main/java/io/r2dbc/postgresql/ParsedSql.java index 1f692aea7..3480d6b97 100644 --- a/src/main/java/io/r2dbc/postgresql/TokenizedSql.java +++ b/src/main/java/io/r2dbc/postgresql/ParsedSql.java @@ -20,24 +20,24 @@ import java.util.Set; import java.util.TreeSet; -class TokenizedSql { +class ParsedSql { private final String sql; - private final List statements; + private final List statements; private final int statementCount; private final int parameterCount; - public TokenizedSql(String sql, List statements) { + public ParsedSql(String sql, List statements) { this.sql = sql; this.statements = statements; this.statementCount = statements.size(); this.parameterCount = getParameterCount(statements); } - List getStatements() { + List getStatements() { return this.statements; } @@ -53,16 +53,16 @@ public String getSql() { return sql; } - private static int getParameterCount(List statements) { + private static int getParameterCount(List statements) { int sum = 0; - for (TokenizedStatement statement : statements){ + for (Statement statement : statements){ sum += statement.getParameterCount(); } return sum; } public boolean hasDefaultTokenValue(String... tokenValues) { - for (TokenizedStatement statement : this.statements) { + for (Statement statement : this.statements) { for (Token token : statement.getTokens()) { if (token.getType() == TokenType.DEFAULT) { for (String value : tokenValues) { @@ -129,24 +129,17 @@ public String toString() { } - static class TokenizedStatement { - - private final String sql; + static class Statement { private final List tokens; private final int parameterCount; - public TokenizedStatement(String sql, List tokens) { + public Statement(List tokens) { this.tokens = tokens; - this.sql = sql; this.parameterCount = readParameterCount(tokens); } - public String getSql() { - return this.sql; - } - public List getTokens() { return this.tokens; } @@ -164,19 +157,14 @@ public boolean equals(Object o) { return false; } - TokenizedStatement that = (TokenizedStatement) o; + Statement that = (Statement) o; - if (!this.sql.equals(that.sql)) { - return false; - } return this.tokens.equals(that.tokens); } @Override public int hashCode() { - int result = this.sql.hashCode(); - result = 31 * result + this.tokens.hashCode(); - return result; + return this.tokens.hashCode(); } @Override diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java b/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java index d1f5ccd13..41ed92917 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java @@ -40,7 +40,7 @@ final class PostgresqlBatch implements io.r2dbc.postgresql.api.PostgresqlBatch { public PostgresqlBatch add(String sql) { Assert.requireNonNull(sql, "sql must not be null"); - if (!(PostgresqlSqlLexer.tokenize(sql).getParameterCount() == 0)) { + if (!(PostgresqlSqlParser.parse(sql).getParameterCount() == 0)) { throw new IllegalArgumentException(String.format("Statement '%s' is not supported. This is often due to the presence of parameters.", sql)); } diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlSqlLexer.java b/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java similarity index 60% rename from src/main/java/io/r2dbc/postgresql/PostgresqlSqlLexer.java rename to src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java index 4139648cf..b9884a370 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlSqlLexer.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java @@ -27,7 +27,7 @@ * * @since 0.9 */ -class PostgresqlSqlLexer { +class PostgresqlSqlParser { private static final char[] SPECIAL_AND_OPERATOR_CHARS = { '+', '-', '*', '/', '<', '>', '=', '~', '!', '@', '#', '%', '^', '&', '|', '`', '?', @@ -38,15 +38,12 @@ class PostgresqlSqlLexer { Arrays.sort(SPECIAL_AND_OPERATOR_CHARS); } - public static TokenizedSql tokenize(String sql) { - List tokens = new ArrayList<>(); - List statements = new ArrayList<>(); - - int statementStartIndex = 0; + private static List tokenize(String sql) { + List tokens = new ArrayList<>(); int i = 0; while (i < sql.length()) { char c = sql.charAt(i); - TokenizedSql.Token token = null; + ParsedSql.Token token = null; if (isWhitespace(c)) { i++; @@ -73,55 +70,82 @@ public static TokenizedSql tokenize(String sql) { token = getParameterOrDollarQuoteToken(sql, i); break; case ';': - token = new TokenizedSql.Token(TokenizedSql.TokenType.STATEMENT_END, ";"); + token = new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";"); break; default: break; } if (token == null) { if (isSpecialOrOperatorChar(c)) { - token = new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, Character.toString(c));//getSpecialOrOperatorToken(sql, i); + token = new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, Character.toString(c));//getSpecialOrOperatorToken(sql, i); } else { token = getDefaultToken(sql, i); } } i += token.getValue().length(); + tokens.add(token); + } + return tokens; + } - if (token.getType() == TokenizedSql.TokenType.STATEMENT_END) { + public static ParsedSql parse(String sql) { + List tokens = tokenize(sql); + List statements = new ArrayList<>(); + List functionBodyList = new ArrayList<>(); - tokens.add(token); - statements.add(new TokenizedSql.TokenizedStatement(sql.substring(statementStartIndex, i), tokens)); + List currentStatementTokens = new ArrayList<>(); + for (int i = 0; i < tokens.size(); i++) { + ParsedSql.Token current = tokens.get(i); + currentStatementTokens.add(current); - tokens = new ArrayList<>(); - statementStartIndex = i + 1; - } else { - tokens.add(token); + if (current.getType() == ParsedSql.TokenType.DEFAULT) { + String currentValue = current.getValue(); + + if (currentValue.equalsIgnoreCase("BEGIN")) { + if (i + 1 < tokens.size() && tokens.get(i + 1).getValue().equalsIgnoreCase("ATOMIC")) { + functionBodyList.add(true); + } else { + functionBodyList.add(false); + } + } else if (currentValue.equalsIgnoreCase("END") && !functionBodyList.isEmpty()) { + functionBodyList.remove(functionBodyList.size() - 1); + } + } else if (current.getType().equals(ParsedSql.TokenType.STATEMENT_END)) { + boolean inFunctionBody = false; + + for (boolean b : functionBodyList) { + inFunctionBody |= b; + } + if (!inFunctionBody) { + statements.add(new ParsedSql.Statement(currentStatementTokens)); + currentStatementTokens = new ArrayList<>(); + } } } - // If tokens is not empty, implicit statement end - if (!tokens.isEmpty()) { - statements.add(new TokenizedSql.TokenizedStatement(sql.substring(statementStartIndex), tokens)); + + if (!currentStatementTokens.isEmpty()) { + statements.add(new ParsedSql.Statement(currentStatementTokens)); } - return new TokenizedSql(sql, statements); + return new ParsedSql(sql, statements); } - private static TokenizedSql.Token getDefaultToken(String sql, int beginIndex) { + private static ParsedSql.Token getDefaultToken(String sql, int beginIndex) { for (int i = beginIndex + 1; i < sql.length(); i++) { char c = sql.charAt(i); if (Character.isWhitespace(c) || isSpecialOrOperatorChar(c)) { - return new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, sql.substring(beginIndex, i)); + return new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, sql.substring(beginIndex, i)); } } - return new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, sql.substring(beginIndex)); + return new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, sql.substring(beginIndex)); } private static boolean isSpecialOrOperatorChar(char c) { return Arrays.binarySearch(SPECIAL_AND_OPERATOR_CHARS, c) >= 0; } - private static TokenizedSql.Token getBlockCommentToken(String sql, int beginIndex) { + private static ParsedSql.Token getBlockCommentToken(String sql, int beginIndex) { int depth = 1; for (int i = beginIndex + 2; i < (sql.length() - 1); i++) { char c1 = sql.charAt(i); @@ -134,44 +158,44 @@ private static TokenizedSql.Token getBlockCommentToken(String sql, int beginInde i++; } if (depth == 0) { - return new TokenizedSql.Token(TokenizedSql.TokenType.COMMENT, sql.substring(beginIndex, i + 1)); + return new ParsedSql.Token(ParsedSql.TokenType.COMMENT, sql.substring(beginIndex, i + 1)); } } throw new IllegalArgumentException("Sql cannot be parsed: unclosed block comment (comment opened at index " + beginIndex + ") in statement: " + sql); } - private static TokenizedSql.Token getCommentToLineEndToken(String sql, int beginIndex) { + private static ParsedSql.Token getCommentToLineEndToken(String sql, int beginIndex) { int lineEnding = sql.indexOf('\n', beginIndex); if (lineEnding == -1) { - return new TokenizedSql.Token(TokenizedSql.TokenType.COMMENT, sql.substring(beginIndex)); + return new ParsedSql.Token(ParsedSql.TokenType.COMMENT, sql.substring(beginIndex)); } else { - return new TokenizedSql.Token(TokenizedSql.TokenType.COMMENT, sql.substring(beginIndex, lineEnding)); + return new ParsedSql.Token(ParsedSql.TokenType.COMMENT, sql.substring(beginIndex, lineEnding)); } } - private static TokenizedSql.Token getDollarQuoteToken(String sql, String tag, int beginIndex) { + private static ParsedSql.Token getDollarQuoteToken(String sql, String tag, int beginIndex) { int nextQuote = sql.indexOf(tag, beginIndex + tag.length()); if (nextQuote == -1) { throw new IllegalArgumentException("Sql cannot be parsed: unclosed quote (quote opened at index " + beginIndex + ") in statement: " + sql); } else { - return new TokenizedSql.Token(TokenizedSql.TokenType.STRING_CONSTANT, sql.substring(beginIndex, nextQuote + tag.length())); + return new ParsedSql.Token(ParsedSql.TokenType.STRING_CONSTANT, sql.substring(beginIndex, nextQuote + tag.length())); } } - private static TokenizedSql.Token getParameterToken(String sql, int beginIndex) { + private static ParsedSql.Token getParameterToken(String sql, int beginIndex) { for (int i = beginIndex + 1; i < sql.length(); i++) { char c = sql.charAt(i); if (isWhitespace(c) || isSpecialOrOperatorChar(c)) { - return new TokenizedSql.Token(TokenizedSql.TokenType.PARAMETER, sql.substring(beginIndex, i)); + return new ParsedSql.Token(ParsedSql.TokenType.PARAMETER, sql.substring(beginIndex, i)); } if (!isAsciiDigit(c)) { throw new IllegalArgumentException("Sql cannot be parsed: illegal character in parameter or dollar-quote tag: " + c); } } - return new TokenizedSql.Token(TokenizedSql.TokenType.PARAMETER, sql.substring(beginIndex)); + return new ParsedSql.Token(ParsedSql.TokenType.PARAMETER, sql.substring(beginIndex)); } - private static TokenizedSql.Token getParameterOrDollarQuoteToken(String sql, int beginIndex) { + private static ParsedSql.Token getParameterOrDollarQuoteToken(String sql, int beginIndex) { char firstChar = sql.charAt(beginIndex + 1); if (firstChar == '$') { return getDollarQuoteToken(sql, "$$", beginIndex); @@ -191,30 +215,31 @@ private static TokenizedSql.Token getParameterOrDollarQuoteToken(String sql, int } } - private static TokenizedSql.Token getStandardQuoteToken(String sql, int beginIndex) { + private static ParsedSql.Token getStandardQuoteToken(String sql, int beginIndex) { int nextQuote = sql.indexOf('\'', beginIndex + 1); if (nextQuote == -1) { throw new IllegalArgumentException("Sql cannot be parsed: unclosed quote (quote opened at index " + beginIndex + ") in statement: " + sql); } else { - return new TokenizedSql.Token(TokenizedSql.TokenType.STRING_CONSTANT, sql.substring(beginIndex, nextQuote + 1)); + return new ParsedSql.Token(ParsedSql.TokenType.STRING_CONSTANT, sql.substring(beginIndex, nextQuote + 1)); } } - private static TokenizedSql.Token getQuotedIdentifierToken(String sql, int beginIndex) { + private static ParsedSql.Token getQuotedIdentifierToken(String sql, int beginIndex) { int nextQuote = sql.indexOf('\"', beginIndex + 1); if (nextQuote == -1) { throw new IllegalArgumentException("Sql cannot be parsed: unclosed quoted identifier (identifier opened at index " + beginIndex + ") in statement: " + sql); } else { - return new TokenizedSql.Token(TokenizedSql.TokenType.QUOTED_IDENTIFIER, sql.substring(beginIndex, nextQuote + 1)); + return new ParsedSql.Token(ParsedSql.TokenType.QUOTED_IDENTIFIER, sql.substring(beginIndex, nextQuote + 1)); } } - private static boolean isAsciiLetter(char c){ + private static boolean isAsciiLetter(char c) { char lower = Character.toLowerCase(c); return lower >= 'a' && lower <= 'z'; } - private static boolean isAsciiDigit(char c){ + private static boolean isAsciiDigit(char c) { return c >= '0' && c <= '9'; } + } diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java b/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java index d09d8b78e..d612472a1 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java @@ -65,7 +65,7 @@ final class PostgresqlStatement implements io.r2dbc.postgresql.api.PostgresqlSta private final ConnectionContext connectionContext; - private final TokenizedSql tokenizedSql; + private final ParsedSql parsedSql; private int fetchSize; @@ -73,11 +73,11 @@ final class PostgresqlStatement implements io.r2dbc.postgresql.api.PostgresqlSta PostgresqlStatement(ConnectionResources resources, String sql) { this.resources = Assert.requireNonNull(resources, "resources must not be null"); - this.tokenizedSql = PostgresqlSqlLexer.tokenize(Assert.requireNonNull(sql, "sql must not be null")); + this.parsedSql = PostgresqlSqlParser.parse(Assert.requireNonNull(sql, "sql must not be null")); this.connectionContext = resources.getClient().getContext(); - this.bindings = new ArrayDeque<>(this.tokenizedSql.getParameterCount()); + this.bindings = new ArrayDeque<>(this.parsedSql.getParameterCount()); - if (this.tokenizedSql.getStatementCount() > 1 && this.tokenizedSql.getParameterCount() > 0) { + if (this.parsedSql.getStatementCount() > 1 && this.parsedSql.getParameterCount() > 0) { throw new IllegalArgumentException(String.format("Statement '%s' cannot be created. This is often due to the presence of both multiple statements and parameters at the same time.", sql)); } @@ -90,7 +90,7 @@ public PostgresqlStatement add() { if (binding != null) { binding.validate(); } - this.bindings.add(new Binding(this.tokenizedSql.getParameterCount())); + this.bindings.add(new Binding(this.parsedSql.getParameterCount())); return this; } @@ -117,8 +117,8 @@ public PostgresqlStatement bindNull(String identifier, Class type) { public PostgresqlStatement bindNull(int index, Class type) { Assert.requireNonNull(type, "type must not be null"); - if (index >= this.tokenizedSql.getParameterCount()) { - throw new UnsupportedOperationException(String.format("Cannot bind parameter %d, statement has %d parameters", index, this.tokenizedSql.getParameterCount())); + if (index >= this.parsedSql.getParameterCount()) { + throw new UnsupportedOperationException(String.format("Cannot bind parameter %d, statement has %d parameters", index, this.parsedSql.getParameterCount())); } BindingLogger.logBindNull(this.connectionContext, index, type); @@ -130,7 +130,7 @@ public PostgresqlStatement bindNull(int index, Class type) { private Binding getCurrentOrFirstBinding() { Binding binding = this.bindings.peekLast(); if (binding == null) { - Binding newBinding = new Binding(this.tokenizedSql.getParameterCount()); + Binding newBinding = new Binding(this.parsedSql.getParameterCount()); this.bindings.add(newBinding); return newBinding; } else { @@ -141,20 +141,20 @@ private Binding getCurrentOrFirstBinding() { @Override public Flux execute() { if (this.generatedColumns == null) { - return execute(this.tokenizedSql.getSql()); + return execute(this.parsedSql.getSql()); } - return execute(GeneratedValuesUtils.augment(this.tokenizedSql.getSql(), this.generatedColumns)); + return execute(GeneratedValuesUtils.augment(this.parsedSql.getSql(), this.generatedColumns)); } @Override public PostgresqlStatement returnGeneratedValues(String... columns) { Assert.requireNonNull(columns, "columns must not be null"); - if (this.tokenizedSql.hasDefaultTokenValue("RETURNING")) { + if (this.parsedSql.hasDefaultTokenValue("RETURNING")) { throw new IllegalStateException("Statement already includes RETURNING clause"); } - if (!this.tokenizedSql.hasDefaultTokenValue("DELETE", "INSERT", "UPDATE")) { + if (!this.parsedSql.hasDefaultTokenValue("DELETE", "INSERT", "UPDATE")) { throw new IllegalStateException("Statement is not a DELETE, INSERT, or UPDATE command"); } @@ -174,7 +174,7 @@ public String toString() { return "PostgresqlStatement{" + "bindings=" + this.bindings + ", context=" + this.resources + - ", sql='" + this.tokenizedSql.getSql() + '\'' + + ", sql='" + this.parsedSql.getSql() + '\'' + ", generatedColumns=" + Arrays.toString(this.generatedColumns) + '}'; } @@ -199,7 +199,7 @@ private int getIdentifierIndex(String identifier) { private Flux execute(String sql) { ExceptionFactory factory = ExceptionFactory.withSql(sql); - if (this.tokenizedSql.getParameterCount() != 0) { + if (this.parsedSql.getParameterCount() != 0) { // Extended query protocol if (this.bindings.size() == 0) { throw new IllegalStateException("No parameters have been bound"); diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlSqlLexerTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlSqlLexerTest.java deleted file mode 100644 index 679ec121a..000000000 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlSqlLexerTest.java +++ /dev/null @@ -1,291 +0,0 @@ -/* - * Copyright 2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.r2dbc.postgresql; - -import org.junit.jupiter.api.Nested; -import org.junit.jupiter.api.Test; - -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - -class PostgresqlSqlLexerTest { - - @Nested - class SingleStatementTests { - - @Nested - class SingleTokenTests { - - @Test - void singleQuotedStringIsTokenized() { - assertSingleStatementEqualsCompleteToken("'Test'", TokenizedSql.TokenType.STRING_CONSTANT); - } - - @Test - void dollarQuotedStringIsTokenized() { - assertSingleStatementEqualsCompleteToken("$$test$$", TokenizedSql.TokenType.STRING_CONSTANT); - } - - @Test - void dollarQuotedTaggedStringIsTokenized() { - assertSingleStatementEqualsCompleteToken("$a$test$a$", TokenizedSql.TokenType.STRING_CONSTANT); - } - - @Test - void quotedIdentifierIsTokenized() { - assertSingleStatementEqualsCompleteToken("\"test\"", TokenizedSql.TokenType.QUOTED_IDENTIFIER); - } - - @Test - void lineCommentIsTokenized() { - assertSingleStatementEqualsCompleteToken("--test", TokenizedSql.TokenType.COMMENT); - } - - @Test - void cStyleCommentIsTokenized() { - assertSingleStatementEqualsCompleteToken("/*Test*/", TokenizedSql.TokenType.COMMENT); - assertSingleStatementEqualsCompleteToken("/**/", TokenizedSql.TokenType.COMMENT); - assertSingleStatementEqualsCompleteToken("/*T*/", TokenizedSql.TokenType.COMMENT); - } - - @Test - void nestedCStyleCommentIsTokenizedAsSingleToken() { - assertSingleStatementEqualsCompleteToken("/*/*Test*/*/", TokenizedSql.TokenType.COMMENT); - } - - @Test - void windowsMultiLineCStyleCommentIsTokenizedAsSingleToken() { - assertSingleStatementEqualsCompleteToken("/*Test\r\n Test*/", TokenizedSql.TokenType.COMMENT); - } - - @Test - void unixMultiLineCStyleCommentIsTokenizedAsSingleToken() { - assertSingleStatementEqualsCompleteToken("/*Test\n Test*/", TokenizedSql.TokenType.COMMENT); - } - - @Test - void digitIsTokenizedAsDefaultToken() { - assertSingleStatementEqualsCompleteToken("1", TokenizedSql.TokenType.DEFAULT); - } - - @Test - void alphaIsTokenizedAsDefaultToken() { - assertSingleStatementEqualsCompleteToken("a", TokenizedSql.TokenType.DEFAULT); - } - - @Test - void multipleDefaultTokensAreTokenizedAsSingleDefaultToken() { - assertSingleStatementEqualsCompleteToken("atest123", TokenizedSql.TokenType.DEFAULT); - } - - @Test - void parameterIsTokenized() { - assertSingleStatementEqualsCompleteToken("$1", TokenizedSql.TokenType.PARAMETER); - } - - @Test - void statementEndIsTokenized() { - assertSingleStatementEqualsCompleteToken(";", TokenizedSql.TokenType.STATEMENT_END); - } - - void assertSingleStatementEqualsCompleteToken(String sql, TokenizedSql.TokenType token) { - assertSingleStatementEquals(sql, new TokenizedSql.Token(token, sql)); - } - - } - - @Nested - class SingleTokenExceptionTests { - - @Test - void unclosedSingleQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("'test")); - } - - @Test - void unclosedDollarQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("$$test")); - } - - @Test - void unclosedTaggedDollarQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("$abc$test")); - } - - @Test - void unclosedQuotedIdentifierThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("\"test")); - } - - @Test - void unclosedBlockCommentThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("/*test")); - } - - @Test - void unclosedNestedBlockCommentThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("/*/*test*/")); - } - - @Test - void invalidParameterCharacterThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("$1test")); - } - - @Test - void invalidTaggedDollarQuoteThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("$a b$test$a b$")); - } - - @Test - void unclosedTaggedDollarQuoteThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("$abc")); - } - } - - @Nested - class MultipleTokenTests { - - @Test - void defaultTokenIsEndedBySpecialCharacter() { - assertSingleStatementEquals("abc[", - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "abc"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "[")); - } - - @Test - void defaultTokenIsEndedByOperatorCharacter() { - assertSingleStatementEquals("abc-", - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "abc"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "-")); - } - - @Test - void defaultTokenIsEndedByStatementEndCharacter() { - assertSingleStatementEquals("abc;", - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "abc"), - new TokenizedSql.Token(TokenizedSql.TokenType.STATEMENT_END, ";")); - } - - @Test - void defaultTokenIsEndedByQuoteCharacter() { - assertSingleStatementEquals("abc\"def\"", - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "abc"), - new TokenizedSql.Token(TokenizedSql.TokenType.QUOTED_IDENTIFIER, "\"def\"")); - } - - @Test - void parameterTokenIsEndedByQuoteCharacter() { - assertSingleStatementEquals("$1+", - new TokenizedSql.Token(TokenizedSql.TokenType.PARAMETER, "$1"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "+")); - } - - @Test - void parameterIsRecognizedBetweenSpecialCharacters() { - assertSingleStatementEquals("($1)", - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "("), - new TokenizedSql.Token(TokenizedSql.TokenType.PARAMETER, "$1"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, ")") - ); - } - - @Test - void lineCommentIsEndedAtNewline() { - assertSingleStatementEquals("--abc\ndef", - new TokenizedSql.Token(TokenizedSql.TokenType.COMMENT, "--abc"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "def")); - } - - @Test - void multipleOperatorsAreSeparatelyTokenized() { - assertSingleStatementEquals("**", - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "*") - ); - } - - } - - @Nested - class AssortedRealStatementTests { - - @Test - void simpleSelectStatementIsTokenized() { - assertSingleStatementEquals("SELECT * FROM /* A Comment */ table WHERE \"SELECT\" = $1", - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "SELECT"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "FROM"), - new TokenizedSql.Token(TokenizedSql.TokenType.COMMENT, "/* A Comment */"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "table"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "WHERE"), - new TokenizedSql.Token(TokenizedSql.TokenType.QUOTED_IDENTIFIER, "\"SELECT\""), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "="), - new TokenizedSql.Token(TokenizedSql.TokenType.PARAMETER, "$1") - ); - } - - } - - void assertSingleStatementEquals(String sql, TokenizedSql.Token... tokens) { - TokenizedSql tokenizedSql = PostgresqlSqlLexer.tokenize(sql); - assertEquals(1, tokenizedSql.getStatements().size(), "Parse returned zero or more than 2 statements"); - TokenizedSql.TokenizedStatement statement = tokenizedSql.getStatements().get(0); - assertEquals(new TokenizedSql.TokenizedStatement(sql, Arrays.asList(tokens)), statement); - } - - } - - @Nested - class MultipleStatementTests { - - @Test - void simpleMultipleStatementIsTokenized() { - TokenizedSql tokenizedSql = PostgresqlSqlLexer.tokenize("DELETE * FROM X; SELECT 1;"); - List statements = tokenizedSql.getStatements(); - assertEquals(2, statements.size()); - TokenizedSql.TokenizedStatement statementA = statements.get(0); - TokenizedSql.TokenizedStatement statementB = statements.get(1); - - assertEquals(new TokenizedSql.TokenizedStatement("DELETE * FROM X;", - Arrays.asList( - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "DELETE"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "FROM"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "X"), - new TokenizedSql.Token(TokenizedSql.TokenType.STATEMENT_END, ";") - )), - statementA - ); - - assertEquals(new TokenizedSql.TokenizedStatement("SELECT 1;", - Arrays.asList( - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "SELECT"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "1"), - new TokenizedSql.Token(TokenizedSql.TokenType.STATEMENT_END, ";") - )), - statementB - ); - - } - - } - -} diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java new file mode 100644 index 000000000..ad1008b98 --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java @@ -0,0 +1,313 @@ +/* + * Copyright 2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.postgresql; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertIterableEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class PostgresqlSqlParserTest { + + @Nested + class SingleStatementTests { + + @Nested + class SingleTokenTests { + + @Test + void singleQuotedStringIsTokenized() { + assertSingleStatementEqualsCompleteToken("'Test'", ParsedSql.TokenType.STRING_CONSTANT); + } + + @Test + void dollarQuotedStringIsTokenized() { + assertSingleStatementEqualsCompleteToken("$$test$$", ParsedSql.TokenType.STRING_CONSTANT); + } + + @Test + void dollarQuotedTaggedStringIsTokenized() { + assertSingleStatementEqualsCompleteToken("$a$test$a$", ParsedSql.TokenType.STRING_CONSTANT); + } + + @Test + void quotedIdentifierIsTokenized() { + assertSingleStatementEqualsCompleteToken("\"test\"", ParsedSql.TokenType.QUOTED_IDENTIFIER); + } + + @Test + void lineCommentIsTokenized() { + assertSingleStatementEqualsCompleteToken("--test", ParsedSql.TokenType.COMMENT); + } + + @Test + void cStyleCommentIsTokenized() { + assertSingleStatementEqualsCompleteToken("/*Test*/", ParsedSql.TokenType.COMMENT); + assertSingleStatementEqualsCompleteToken("/**/", ParsedSql.TokenType.COMMENT); + assertSingleStatementEqualsCompleteToken("/*T*/", ParsedSql.TokenType.COMMENT); + } + + @Test + void nestedCStyleCommentIsTokenizedAsSingleToken() { + assertSingleStatementEqualsCompleteToken("/*/*Test*/*/", ParsedSql.TokenType.COMMENT); + } + + @Test + void windowsMultiLineCStyleCommentIsTokenizedAsSingleToken() { + assertSingleStatementEqualsCompleteToken("/*Test\r\n Test*/", ParsedSql.TokenType.COMMENT); + } + + @Test + void unixMultiLineCStyleCommentIsTokenizedAsSingleToken() { + assertSingleStatementEqualsCompleteToken("/*Test\n Test*/", ParsedSql.TokenType.COMMENT); + } + + @Test + void digitIsTokenizedAsDefaultToken() { + assertSingleStatementEqualsCompleteToken("1", ParsedSql.TokenType.DEFAULT); + } + + @Test + void alphaIsTokenizedAsDefaultToken() { + assertSingleStatementEqualsCompleteToken("a", ParsedSql.TokenType.DEFAULT); + } + + @Test + void multipleDefaultTokensAreTokenizedAsSingleDefaultToken() { + assertSingleStatementEqualsCompleteToken("atest123", ParsedSql.TokenType.DEFAULT); + } + + @Test + void parameterIsTokenized() { + assertSingleStatementEqualsCompleteToken("$1", ParsedSql.TokenType.PARAMETER); + } + + @Test + void statementEndIsTokenized() { + assertSingleStatementEqualsCompleteToken(";", ParsedSql.TokenType.STATEMENT_END); + } + + void assertSingleStatementEqualsCompleteToken(String sql, ParsedSql.TokenType token) { + assertSingleStatementEquals(sql, new ParsedSql.Token(token, sql)); + } + + } + + @Nested + class SingleTokenExceptionTests { + + @Test + void unclosedSingleQuotedStringThrowsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("'test")); + } + + @Test + void unclosedDollarQuotedStringThrowsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$$test")); + } + + @Test + void unclosedTaggedDollarQuotedStringThrowsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$abc$test")); + } + + @Test + void unclosedQuotedIdentifierThrowsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("\"test")); + } + + @Test + void unclosedBlockCommentThrowsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("/*test")); + } + + @Test + void unclosedNestedBlockCommentThrowsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("/*/*test*/")); + } + + @Test + void invalidParameterCharacterThrowsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$1test")); + } + + @Test + void invalidTaggedDollarQuoteThrowsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$a b$test$a b$")); + } + + @Test + void unclosedTaggedDollarQuoteThrowsIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$abc")); + } + + } + + @Nested + class MultipleTokenTests { + + @Test + void defaultTokenIsEndedBySpecialCharacter() { + assertSingleStatementEquals("abc[", + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "abc"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "[")); + } + + @Test + void defaultTokenIsEndedByOperatorCharacter() { + assertSingleStatementEquals("abc-", + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "abc"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "-")); + } + + @Test + void defaultTokenIsEndedByStatementEndCharacter() { + assertSingleStatementEquals("abc;", + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "abc"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";")); + } + + @Test + void defaultTokenIsEndedByQuoteCharacter() { + assertSingleStatementEquals("abc\"def\"", + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "abc"), + new ParsedSql.Token(ParsedSql.TokenType.QUOTED_IDENTIFIER, "\"def\"")); + } + + @Test + void parameterTokenIsEndedByQuoteCharacter() { + assertSingleStatementEquals("$1+", + new ParsedSql.Token(ParsedSql.TokenType.PARAMETER, "$1"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "+")); + } + + @Test + void parameterIsRecognizedBetweenSpecialCharacters() { + assertSingleStatementEquals("($1)", + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "("), + new ParsedSql.Token(ParsedSql.TokenType.PARAMETER, "$1"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, ")") + ); + } + + @Test + void lineCommentIsEndedAtNewline() { + assertSingleStatementEquals("--abc\ndef", + new ParsedSql.Token(ParsedSql.TokenType.COMMENT, "--abc"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "def")); + } + + @Test + void multipleOperatorsAreSeparatelyTokenized() { + assertSingleStatementEquals("**", + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*") + ); + } + + } + + @Nested + class AssortedRealStatementTests { + + @Test + void simpleSelectStatementIsTokenized() { + assertSingleStatementEquals("SELECT * FROM /* A Comment */ table WHERE \"SELECT\" = $1", + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FROM"), + new ParsedSql.Token(ParsedSql.TokenType.COMMENT, "/* A Comment */"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "table"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "WHERE"), + new ParsedSql.Token(ParsedSql.TokenType.QUOTED_IDENTIFIER, "\"SELECT\""), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "="), + new ParsedSql.Token(ParsedSql.TokenType.PARAMETER, "$1") + ); + } + + @Test + void simpleSelectStatementWithFunctionBodyIsTokenized() { + assertSingleStatementEquals("CREATE FUNCTION test() BEGIN ATOMIC SELECT 1; SELECT 2; END", + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "CREATE"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FUNCTION"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "test"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "("), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, ")"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "BEGIN"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "ATOMIC"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "2"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "END") + ); + } + + } + + void assertSingleStatementEquals(String sql, ParsedSql.Token... tokens) { + ParsedSql parsedSql = PostgresqlSqlParser.parse(sql); + assertEquals(1, parsedSql.getStatements().size(), "Parse returned zero or more than 2 statements"); + ParsedSql.Statement statement = parsedSql.getStatements().get(0); + assertIterableEquals(Arrays.asList(tokens), statement.getTokens()); + } + + } + + @Nested + class MultipleStatementTests { + + @Test + void simpleMultipleStatementIsTokenized() { + ParsedSql parsedSql = PostgresqlSqlParser.parse("DELETE * FROM X; SELECT 1;"); + List statements = parsedSql.getStatements(); + assertEquals(2, statements.size()); + ParsedSql.Statement statementA = statements.get(0); + ParsedSql.Statement statementB = statements.get(1); + + assertIterableEquals( + Arrays.asList( + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "DELETE"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FROM"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "X"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") + ), + statementA.getTokens() + ); + + assertIterableEquals( + Arrays.asList( + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") + ), + statementB.getTokens() + ); + + } + + } + +} From e3aa98da8b84857b97e1774d464548adc7bd8abb Mon Sep 17 00:00:00 2001 From: toverdijk Date: Tue, 24 May 2022 17:11:07 +0200 Subject: [PATCH 2/3] - Renamed PostgresqlLexer to PostgresqlParser - Renamed TokenizedSql to ParsedSql --- .../{TokenizedSql.java => ParsedSql.java} | 4 +- .../io/r2dbc/postgresql/PostgresqlBatch.java | 2 +- ...SqlLexer.java => PostgresqlSqlParser.java} | 58 +++---- .../r2dbc/postgresql/PostgresqlStatement.java | 28 ++-- ...Test.java => PostgresqlSqlParserTest.java} | 146 +++++++++--------- 5 files changed, 119 insertions(+), 119 deletions(-) rename src/main/java/io/r2dbc/postgresql/{TokenizedSql.java => ParsedSql.java} (98%) rename src/main/java/io/r2dbc/postgresql/{PostgresqlSqlLexer.java => PostgresqlSqlParser.java} (71%) rename src/test/java/io/r2dbc/postgresql/{PostgresqlSqlLexerTest.java => PostgresqlSqlParserTest.java} (55%) diff --git a/src/main/java/io/r2dbc/postgresql/TokenizedSql.java b/src/main/java/io/r2dbc/postgresql/ParsedSql.java similarity index 98% rename from src/main/java/io/r2dbc/postgresql/TokenizedSql.java rename to src/main/java/io/r2dbc/postgresql/ParsedSql.java index 1f692aea7..e1ee68046 100644 --- a/src/main/java/io/r2dbc/postgresql/TokenizedSql.java +++ b/src/main/java/io/r2dbc/postgresql/ParsedSql.java @@ -20,7 +20,7 @@ import java.util.Set; import java.util.TreeSet; -class TokenizedSql { +class ParsedSql { private final String sql; @@ -30,7 +30,7 @@ class TokenizedSql { private final int parameterCount; - public TokenizedSql(String sql, List statements) { + public ParsedSql(String sql, List statements) { this.sql = sql; this.statements = statements; this.statementCount = statements.size(); diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java b/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java index d1f5ccd13..0a68118a4 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java @@ -40,7 +40,7 @@ final class PostgresqlBatch implements io.r2dbc.postgresql.api.PostgresqlBatch { public PostgresqlBatch add(String sql) { Assert.requireNonNull(sql, "sql must not be null"); - if (!(PostgresqlSqlLexer.tokenize(sql).getParameterCount() == 0)) { + if (!(PostgresqlSqlParser.tokenize(sql).getParameterCount() == 0)) { throw new IllegalArgumentException(String.format("Statement '%s' is not supported. This is often due to the presence of parameters.", sql)); } diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlSqlLexer.java b/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java similarity index 71% rename from src/main/java/io/r2dbc/postgresql/PostgresqlSqlLexer.java rename to src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java index 4139648cf..9932ac17a 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlSqlLexer.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java @@ -27,7 +27,7 @@ * * @since 0.9 */ -class PostgresqlSqlLexer { +class PostgresqlSqlParser { private static final char[] SPECIAL_AND_OPERATOR_CHARS = { '+', '-', '*', '/', '<', '>', '=', '~', '!', '@', '#', '%', '^', '&', '|', '`', '?', @@ -38,15 +38,15 @@ class PostgresqlSqlLexer { Arrays.sort(SPECIAL_AND_OPERATOR_CHARS); } - public static TokenizedSql tokenize(String sql) { - List tokens = new ArrayList<>(); - List statements = new ArrayList<>(); + public static ParsedSql tokenize(String sql) { + List tokens = new ArrayList<>(); + List statements = new ArrayList<>(); int statementStartIndex = 0; int i = 0; while (i < sql.length()) { char c = sql.charAt(i); - TokenizedSql.Token token = null; + ParsedSql.Token token = null; if (isWhitespace(c)) { i++; @@ -73,14 +73,14 @@ public static TokenizedSql tokenize(String sql) { token = getParameterOrDollarQuoteToken(sql, i); break; case ';': - token = new TokenizedSql.Token(TokenizedSql.TokenType.STATEMENT_END, ";"); + token = new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";"); break; default: break; } if (token == null) { if (isSpecialOrOperatorChar(c)) { - token = new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, Character.toString(c));//getSpecialOrOperatorToken(sql, i); + token = new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, Character.toString(c));//getSpecialOrOperatorToken(sql, i); } else { token = getDefaultToken(sql, i); } @@ -88,10 +88,10 @@ public static TokenizedSql tokenize(String sql) { i += token.getValue().length(); - if (token.getType() == TokenizedSql.TokenType.STATEMENT_END) { + if (token.getType() == ParsedSql.TokenType.STATEMENT_END) { tokens.add(token); - statements.add(new TokenizedSql.TokenizedStatement(sql.substring(statementStartIndex, i), tokens)); + statements.add(new ParsedSql.TokenizedStatement(sql.substring(statementStartIndex, i), tokens)); tokens = new ArrayList<>(); statementStartIndex = i + 1; @@ -101,27 +101,27 @@ public static TokenizedSql tokenize(String sql) { } // If tokens is not empty, implicit statement end if (!tokens.isEmpty()) { - statements.add(new TokenizedSql.TokenizedStatement(sql.substring(statementStartIndex), tokens)); + statements.add(new ParsedSql.TokenizedStatement(sql.substring(statementStartIndex), tokens)); } - return new TokenizedSql(sql, statements); + return new ParsedSql(sql, statements); } - private static TokenizedSql.Token getDefaultToken(String sql, int beginIndex) { + private static ParsedSql.Token getDefaultToken(String sql, int beginIndex) { for (int i = beginIndex + 1; i < sql.length(); i++) { char c = sql.charAt(i); if (Character.isWhitespace(c) || isSpecialOrOperatorChar(c)) { - return new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, sql.substring(beginIndex, i)); + return new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, sql.substring(beginIndex, i)); } } - return new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, sql.substring(beginIndex)); + return new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, sql.substring(beginIndex)); } private static boolean isSpecialOrOperatorChar(char c) { return Arrays.binarySearch(SPECIAL_AND_OPERATOR_CHARS, c) >= 0; } - private static TokenizedSql.Token getBlockCommentToken(String sql, int beginIndex) { + private static ParsedSql.Token getBlockCommentToken(String sql, int beginIndex) { int depth = 1; for (int i = beginIndex + 2; i < (sql.length() - 1); i++) { char c1 = sql.charAt(i); @@ -134,44 +134,44 @@ private static TokenizedSql.Token getBlockCommentToken(String sql, int beginInde i++; } if (depth == 0) { - return new TokenizedSql.Token(TokenizedSql.TokenType.COMMENT, sql.substring(beginIndex, i + 1)); + return new ParsedSql.Token(ParsedSql.TokenType.COMMENT, sql.substring(beginIndex, i + 1)); } } throw new IllegalArgumentException("Sql cannot be parsed: unclosed block comment (comment opened at index " + beginIndex + ") in statement: " + sql); } - private static TokenizedSql.Token getCommentToLineEndToken(String sql, int beginIndex) { + private static ParsedSql.Token getCommentToLineEndToken(String sql, int beginIndex) { int lineEnding = sql.indexOf('\n', beginIndex); if (lineEnding == -1) { - return new TokenizedSql.Token(TokenizedSql.TokenType.COMMENT, sql.substring(beginIndex)); + return new ParsedSql.Token(ParsedSql.TokenType.COMMENT, sql.substring(beginIndex)); } else { - return new TokenizedSql.Token(TokenizedSql.TokenType.COMMENT, sql.substring(beginIndex, lineEnding)); + return new ParsedSql.Token(ParsedSql.TokenType.COMMENT, sql.substring(beginIndex, lineEnding)); } } - private static TokenizedSql.Token getDollarQuoteToken(String sql, String tag, int beginIndex) { + private static ParsedSql.Token getDollarQuoteToken(String sql, String tag, int beginIndex) { int nextQuote = sql.indexOf(tag, beginIndex + tag.length()); if (nextQuote == -1) { throw new IllegalArgumentException("Sql cannot be parsed: unclosed quote (quote opened at index " + beginIndex + ") in statement: " + sql); } else { - return new TokenizedSql.Token(TokenizedSql.TokenType.STRING_CONSTANT, sql.substring(beginIndex, nextQuote + tag.length())); + return new ParsedSql.Token(ParsedSql.TokenType.STRING_CONSTANT, sql.substring(beginIndex, nextQuote + tag.length())); } } - private static TokenizedSql.Token getParameterToken(String sql, int beginIndex) { + private static ParsedSql.Token getParameterToken(String sql, int beginIndex) { for (int i = beginIndex + 1; i < sql.length(); i++) { char c = sql.charAt(i); if (isWhitespace(c) || isSpecialOrOperatorChar(c)) { - return new TokenizedSql.Token(TokenizedSql.TokenType.PARAMETER, sql.substring(beginIndex, i)); + return new ParsedSql.Token(ParsedSql.TokenType.PARAMETER, sql.substring(beginIndex, i)); } if (!isAsciiDigit(c)) { throw new IllegalArgumentException("Sql cannot be parsed: illegal character in parameter or dollar-quote tag: " + c); } } - return new TokenizedSql.Token(TokenizedSql.TokenType.PARAMETER, sql.substring(beginIndex)); + return new ParsedSql.Token(ParsedSql.TokenType.PARAMETER, sql.substring(beginIndex)); } - private static TokenizedSql.Token getParameterOrDollarQuoteToken(String sql, int beginIndex) { + private static ParsedSql.Token getParameterOrDollarQuoteToken(String sql, int beginIndex) { char firstChar = sql.charAt(beginIndex + 1); if (firstChar == '$') { return getDollarQuoteToken(sql, "$$", beginIndex); @@ -191,21 +191,21 @@ private static TokenizedSql.Token getParameterOrDollarQuoteToken(String sql, int } } - private static TokenizedSql.Token getStandardQuoteToken(String sql, int beginIndex) { + private static ParsedSql.Token getStandardQuoteToken(String sql, int beginIndex) { int nextQuote = sql.indexOf('\'', beginIndex + 1); if (nextQuote == -1) { throw new IllegalArgumentException("Sql cannot be parsed: unclosed quote (quote opened at index " + beginIndex + ") in statement: " + sql); } else { - return new TokenizedSql.Token(TokenizedSql.TokenType.STRING_CONSTANT, sql.substring(beginIndex, nextQuote + 1)); + return new ParsedSql.Token(ParsedSql.TokenType.STRING_CONSTANT, sql.substring(beginIndex, nextQuote + 1)); } } - private static TokenizedSql.Token getQuotedIdentifierToken(String sql, int beginIndex) { + private static ParsedSql.Token getQuotedIdentifierToken(String sql, int beginIndex) { int nextQuote = sql.indexOf('\"', beginIndex + 1); if (nextQuote == -1) { throw new IllegalArgumentException("Sql cannot be parsed: unclosed quoted identifier (identifier opened at index " + beginIndex + ") in statement: " + sql); } else { - return new TokenizedSql.Token(TokenizedSql.TokenType.QUOTED_IDENTIFIER, sql.substring(beginIndex, nextQuote + 1)); + return new ParsedSql.Token(ParsedSql.TokenType.QUOTED_IDENTIFIER, sql.substring(beginIndex, nextQuote + 1)); } } diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java b/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java index d09d8b78e..62b49e5ec 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java @@ -65,7 +65,7 @@ final class PostgresqlStatement implements io.r2dbc.postgresql.api.PostgresqlSta private final ConnectionContext connectionContext; - private final TokenizedSql tokenizedSql; + private final ParsedSql parsedSql; private int fetchSize; @@ -73,11 +73,11 @@ final class PostgresqlStatement implements io.r2dbc.postgresql.api.PostgresqlSta PostgresqlStatement(ConnectionResources resources, String sql) { this.resources = Assert.requireNonNull(resources, "resources must not be null"); - this.tokenizedSql = PostgresqlSqlLexer.tokenize(Assert.requireNonNull(sql, "sql must not be null")); + this.parsedSql = PostgresqlSqlParser.tokenize(Assert.requireNonNull(sql, "sql must not be null")); this.connectionContext = resources.getClient().getContext(); - this.bindings = new ArrayDeque<>(this.tokenizedSql.getParameterCount()); + this.bindings = new ArrayDeque<>(this.parsedSql.getParameterCount()); - if (this.tokenizedSql.getStatementCount() > 1 && this.tokenizedSql.getParameterCount() > 0) { + if (this.parsedSql.getStatementCount() > 1 && this.parsedSql.getParameterCount() > 0) { throw new IllegalArgumentException(String.format("Statement '%s' cannot be created. This is often due to the presence of both multiple statements and parameters at the same time.", sql)); } @@ -90,7 +90,7 @@ public PostgresqlStatement add() { if (binding != null) { binding.validate(); } - this.bindings.add(new Binding(this.tokenizedSql.getParameterCount())); + this.bindings.add(new Binding(this.parsedSql.getParameterCount())); return this; } @@ -117,8 +117,8 @@ public PostgresqlStatement bindNull(String identifier, Class type) { public PostgresqlStatement bindNull(int index, Class type) { Assert.requireNonNull(type, "type must not be null"); - if (index >= this.tokenizedSql.getParameterCount()) { - throw new UnsupportedOperationException(String.format("Cannot bind parameter %d, statement has %d parameters", index, this.tokenizedSql.getParameterCount())); + if (index >= this.parsedSql.getParameterCount()) { + throw new UnsupportedOperationException(String.format("Cannot bind parameter %d, statement has %d parameters", index, this.parsedSql.getParameterCount())); } BindingLogger.logBindNull(this.connectionContext, index, type); @@ -130,7 +130,7 @@ public PostgresqlStatement bindNull(int index, Class type) { private Binding getCurrentOrFirstBinding() { Binding binding = this.bindings.peekLast(); if (binding == null) { - Binding newBinding = new Binding(this.tokenizedSql.getParameterCount()); + Binding newBinding = new Binding(this.parsedSql.getParameterCount()); this.bindings.add(newBinding); return newBinding; } else { @@ -141,20 +141,20 @@ private Binding getCurrentOrFirstBinding() { @Override public Flux execute() { if (this.generatedColumns == null) { - return execute(this.tokenizedSql.getSql()); + return execute(this.parsedSql.getSql()); } - return execute(GeneratedValuesUtils.augment(this.tokenizedSql.getSql(), this.generatedColumns)); + return execute(GeneratedValuesUtils.augment(this.parsedSql.getSql(), this.generatedColumns)); } @Override public PostgresqlStatement returnGeneratedValues(String... columns) { Assert.requireNonNull(columns, "columns must not be null"); - if (this.tokenizedSql.hasDefaultTokenValue("RETURNING")) { + if (this.parsedSql.hasDefaultTokenValue("RETURNING")) { throw new IllegalStateException("Statement already includes RETURNING clause"); } - if (!this.tokenizedSql.hasDefaultTokenValue("DELETE", "INSERT", "UPDATE")) { + if (!this.parsedSql.hasDefaultTokenValue("DELETE", "INSERT", "UPDATE")) { throw new IllegalStateException("Statement is not a DELETE, INSERT, or UPDATE command"); } @@ -174,7 +174,7 @@ public String toString() { return "PostgresqlStatement{" + "bindings=" + this.bindings + ", context=" + this.resources + - ", sql='" + this.tokenizedSql.getSql() + '\'' + + ", sql='" + this.parsedSql.getSql() + '\'' + ", generatedColumns=" + Arrays.toString(this.generatedColumns) + '}'; } @@ -199,7 +199,7 @@ private int getIdentifierIndex(String identifier) { private Flux execute(String sql) { ExceptionFactory factory = ExceptionFactory.withSql(sql); - if (this.tokenizedSql.getParameterCount() != 0) { + if (this.parsedSql.getParameterCount() != 0) { // Extended query protocol if (this.bindings.size() == 0) { throw new IllegalStateException("No parameters have been bound"); diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlSqlLexerTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java similarity index 55% rename from src/test/java/io/r2dbc/postgresql/PostgresqlSqlLexerTest.java rename to src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java index 679ec121a..08ba9428d 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlSqlLexerTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java @@ -25,7 +25,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; -class PostgresqlSqlLexerTest { +class PostgresqlSqlParserTest { @Nested class SingleStatementTests { @@ -35,78 +35,78 @@ class SingleTokenTests { @Test void singleQuotedStringIsTokenized() { - assertSingleStatementEqualsCompleteToken("'Test'", TokenizedSql.TokenType.STRING_CONSTANT); + assertSingleStatementEqualsCompleteToken("'Test'", ParsedSql.TokenType.STRING_CONSTANT); } @Test void dollarQuotedStringIsTokenized() { - assertSingleStatementEqualsCompleteToken("$$test$$", TokenizedSql.TokenType.STRING_CONSTANT); + assertSingleStatementEqualsCompleteToken("$$test$$", ParsedSql.TokenType.STRING_CONSTANT); } @Test void dollarQuotedTaggedStringIsTokenized() { - assertSingleStatementEqualsCompleteToken("$a$test$a$", TokenizedSql.TokenType.STRING_CONSTANT); + assertSingleStatementEqualsCompleteToken("$a$test$a$", ParsedSql.TokenType.STRING_CONSTANT); } @Test void quotedIdentifierIsTokenized() { - assertSingleStatementEqualsCompleteToken("\"test\"", TokenizedSql.TokenType.QUOTED_IDENTIFIER); + assertSingleStatementEqualsCompleteToken("\"test\"", ParsedSql.TokenType.QUOTED_IDENTIFIER); } @Test void lineCommentIsTokenized() { - assertSingleStatementEqualsCompleteToken("--test", TokenizedSql.TokenType.COMMENT); + assertSingleStatementEqualsCompleteToken("--test", ParsedSql.TokenType.COMMENT); } @Test void cStyleCommentIsTokenized() { - assertSingleStatementEqualsCompleteToken("/*Test*/", TokenizedSql.TokenType.COMMENT); - assertSingleStatementEqualsCompleteToken("/**/", TokenizedSql.TokenType.COMMENT); - assertSingleStatementEqualsCompleteToken("/*T*/", TokenizedSql.TokenType.COMMENT); + assertSingleStatementEqualsCompleteToken("/*Test*/", ParsedSql.TokenType.COMMENT); + assertSingleStatementEqualsCompleteToken("/**/", ParsedSql.TokenType.COMMENT); + assertSingleStatementEqualsCompleteToken("/*T*/", ParsedSql.TokenType.COMMENT); } @Test void nestedCStyleCommentIsTokenizedAsSingleToken() { - assertSingleStatementEqualsCompleteToken("/*/*Test*/*/", TokenizedSql.TokenType.COMMENT); + assertSingleStatementEqualsCompleteToken("/*/*Test*/*/", ParsedSql.TokenType.COMMENT); } @Test void windowsMultiLineCStyleCommentIsTokenizedAsSingleToken() { - assertSingleStatementEqualsCompleteToken("/*Test\r\n Test*/", TokenizedSql.TokenType.COMMENT); + assertSingleStatementEqualsCompleteToken("/*Test\r\n Test*/", ParsedSql.TokenType.COMMENT); } @Test void unixMultiLineCStyleCommentIsTokenizedAsSingleToken() { - assertSingleStatementEqualsCompleteToken("/*Test\n Test*/", TokenizedSql.TokenType.COMMENT); + assertSingleStatementEqualsCompleteToken("/*Test\n Test*/", ParsedSql.TokenType.COMMENT); } @Test void digitIsTokenizedAsDefaultToken() { - assertSingleStatementEqualsCompleteToken("1", TokenizedSql.TokenType.DEFAULT); + assertSingleStatementEqualsCompleteToken("1", ParsedSql.TokenType.DEFAULT); } @Test void alphaIsTokenizedAsDefaultToken() { - assertSingleStatementEqualsCompleteToken("a", TokenizedSql.TokenType.DEFAULT); + assertSingleStatementEqualsCompleteToken("a", ParsedSql.TokenType.DEFAULT); } @Test void multipleDefaultTokensAreTokenizedAsSingleDefaultToken() { - assertSingleStatementEqualsCompleteToken("atest123", TokenizedSql.TokenType.DEFAULT); + assertSingleStatementEqualsCompleteToken("atest123", ParsedSql.TokenType.DEFAULT); } @Test void parameterIsTokenized() { - assertSingleStatementEqualsCompleteToken("$1", TokenizedSql.TokenType.PARAMETER); + assertSingleStatementEqualsCompleteToken("$1", ParsedSql.TokenType.PARAMETER); } @Test void statementEndIsTokenized() { - assertSingleStatementEqualsCompleteToken(";", TokenizedSql.TokenType.STATEMENT_END); + assertSingleStatementEqualsCompleteToken(";", ParsedSql.TokenType.STATEMENT_END); } - void assertSingleStatementEqualsCompleteToken(String sql, TokenizedSql.TokenType token) { - assertSingleStatementEquals(sql, new TokenizedSql.Token(token, sql)); + void assertSingleStatementEqualsCompleteToken(String sql, ParsedSql.TokenType token) { + assertSingleStatementEquals(sql, new ParsedSql.Token(token, sql)); } } @@ -116,47 +116,47 @@ class SingleTokenExceptionTests { @Test void unclosedSingleQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("'test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("'test")); } @Test void unclosedDollarQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("$$test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$$test")); } @Test void unclosedTaggedDollarQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("$abc$test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$abc$test")); } @Test void unclosedQuotedIdentifierThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("\"test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("\"test")); } @Test void unclosedBlockCommentThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("/*test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("/*test")); } @Test void unclosedNestedBlockCommentThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("/*/*test*/")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("/*/*test*/")); } @Test void invalidParameterCharacterThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("$1test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$1test")); } @Test void invalidTaggedDollarQuoteThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("$a b$test$a b$")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$a b$test$a b$")); } @Test void unclosedTaggedDollarQuoteThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlLexer.tokenize("$abc")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$abc")); } } @@ -166,59 +166,59 @@ class MultipleTokenTests { @Test void defaultTokenIsEndedBySpecialCharacter() { assertSingleStatementEquals("abc[", - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "abc"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "[")); + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "abc"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "[")); } @Test void defaultTokenIsEndedByOperatorCharacter() { assertSingleStatementEquals("abc-", - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "abc"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "-")); + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "abc"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "-")); } @Test void defaultTokenIsEndedByStatementEndCharacter() { assertSingleStatementEquals("abc;", - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "abc"), - new TokenizedSql.Token(TokenizedSql.TokenType.STATEMENT_END, ";")); + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "abc"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";")); } @Test void defaultTokenIsEndedByQuoteCharacter() { assertSingleStatementEquals("abc\"def\"", - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "abc"), - new TokenizedSql.Token(TokenizedSql.TokenType.QUOTED_IDENTIFIER, "\"def\"")); + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "abc"), + new ParsedSql.Token(ParsedSql.TokenType.QUOTED_IDENTIFIER, "\"def\"")); } @Test void parameterTokenIsEndedByQuoteCharacter() { assertSingleStatementEquals("$1+", - new TokenizedSql.Token(TokenizedSql.TokenType.PARAMETER, "$1"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "+")); + new ParsedSql.Token(ParsedSql.TokenType.PARAMETER, "$1"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "+")); } @Test void parameterIsRecognizedBetweenSpecialCharacters() { assertSingleStatementEquals("($1)", - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "("), - new TokenizedSql.Token(TokenizedSql.TokenType.PARAMETER, "$1"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, ")") + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "("), + new ParsedSql.Token(ParsedSql.TokenType.PARAMETER, "$1"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, ")") ); } @Test void lineCommentIsEndedAtNewline() { assertSingleStatementEquals("--abc\ndef", - new TokenizedSql.Token(TokenizedSql.TokenType.COMMENT, "--abc"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "def")); + new ParsedSql.Token(ParsedSql.TokenType.COMMENT, "--abc"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "def")); } @Test void multipleOperatorsAreSeparatelyTokenized() { assertSingleStatementEquals("**", - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "*") + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*") ); } @@ -230,25 +230,25 @@ class AssortedRealStatementTests { @Test void simpleSelectStatementIsTokenized() { assertSingleStatementEquals("SELECT * FROM /* A Comment */ table WHERE \"SELECT\" = $1", - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "SELECT"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "FROM"), - new TokenizedSql.Token(TokenizedSql.TokenType.COMMENT, "/* A Comment */"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "table"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "WHERE"), - new TokenizedSql.Token(TokenizedSql.TokenType.QUOTED_IDENTIFIER, "\"SELECT\""), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "="), - new TokenizedSql.Token(TokenizedSql.TokenType.PARAMETER, "$1") + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FROM"), + new ParsedSql.Token(ParsedSql.TokenType.COMMENT, "/* A Comment */"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "table"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "WHERE"), + new ParsedSql.Token(ParsedSql.TokenType.QUOTED_IDENTIFIER, "\"SELECT\""), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "="), + new ParsedSql.Token(ParsedSql.TokenType.PARAMETER, "$1") ); } } - void assertSingleStatementEquals(String sql, TokenizedSql.Token... tokens) { - TokenizedSql tokenizedSql = PostgresqlSqlLexer.tokenize(sql); - assertEquals(1, tokenizedSql.getStatements().size(), "Parse returned zero or more than 2 statements"); - TokenizedSql.TokenizedStatement statement = tokenizedSql.getStatements().get(0); - assertEquals(new TokenizedSql.TokenizedStatement(sql, Arrays.asList(tokens)), statement); + void assertSingleStatementEquals(String sql, ParsedSql.Token... tokens) { + ParsedSql parsedSql = PostgresqlSqlParser.tokenize(sql); + assertEquals(1, parsedSql.getStatements().size(), "Parse returned zero or more than 2 statements"); + ParsedSql.TokenizedStatement statement = parsedSql.getStatements().get(0); + assertEquals(new ParsedSql.TokenizedStatement(sql, Arrays.asList(tokens)), statement); } } @@ -258,28 +258,28 @@ class MultipleStatementTests { @Test void simpleMultipleStatementIsTokenized() { - TokenizedSql tokenizedSql = PostgresqlSqlLexer.tokenize("DELETE * FROM X; SELECT 1;"); - List statements = tokenizedSql.getStatements(); + ParsedSql parsedSql = PostgresqlSqlParser.tokenize("DELETE * FROM X; SELECT 1;"); + List statements = parsedSql.getStatements(); assertEquals(2, statements.size()); - TokenizedSql.TokenizedStatement statementA = statements.get(0); - TokenizedSql.TokenizedStatement statementB = statements.get(1); + ParsedSql.TokenizedStatement statementA = statements.get(0); + ParsedSql.TokenizedStatement statementB = statements.get(1); - assertEquals(new TokenizedSql.TokenizedStatement("DELETE * FROM X;", + assertEquals(new ParsedSql.TokenizedStatement("DELETE * FROM X;", Arrays.asList( - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "DELETE"), - new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "FROM"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "X"), - new TokenizedSql.Token(TokenizedSql.TokenType.STATEMENT_END, ";") + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "DELETE"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FROM"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "X"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") )), statementA ); - assertEquals(new TokenizedSql.TokenizedStatement("SELECT 1;", + assertEquals(new ParsedSql.TokenizedStatement("SELECT 1;", Arrays.asList( - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "SELECT"), - new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, "1"), - new TokenizedSql.Token(TokenizedSql.TokenType.STATEMENT_END, ";") + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") )), statementB ); From e55aab17424b1599214ad6247430ea08b77a33db Mon Sep 17 00:00:00 2001 From: toverdijk Date: Tue, 24 May 2022 17:13:13 +0200 Subject: [PATCH 3/3] - Lexing/parsing is now done in two steps: first only tokenize, then parse into statements - Added support for function bodies ("BEGIN ATOMIC") - Added a test case for newly supported grammar --- .../java/io/r2dbc/postgresql/ParsedSql.java | 32 +++---- .../io/r2dbc/postgresql/PostgresqlBatch.java | 2 +- .../r2dbc/postgresql/PostgresqlSqlParser.java | 57 ++++++++---- .../r2dbc/postgresql/PostgresqlStatement.java | 2 +- .../postgresql/PostgresqlSqlParserTest.java | 88 ++++++++++++------- 5 files changed, 108 insertions(+), 73 deletions(-) diff --git a/src/main/java/io/r2dbc/postgresql/ParsedSql.java b/src/main/java/io/r2dbc/postgresql/ParsedSql.java index e1ee68046..3480d6b97 100644 --- a/src/main/java/io/r2dbc/postgresql/ParsedSql.java +++ b/src/main/java/io/r2dbc/postgresql/ParsedSql.java @@ -24,20 +24,20 @@ class ParsedSql { private final String sql; - private final List statements; + private final List statements; private final int statementCount; private final int parameterCount; - public ParsedSql(String sql, List statements) { + public ParsedSql(String sql, List statements) { this.sql = sql; this.statements = statements; this.statementCount = statements.size(); this.parameterCount = getParameterCount(statements); } - List getStatements() { + List getStatements() { return this.statements; } @@ -53,16 +53,16 @@ public String getSql() { return sql; } - private static int getParameterCount(List statements) { + private static int getParameterCount(List statements) { int sum = 0; - for (TokenizedStatement statement : statements){ + for (Statement statement : statements){ sum += statement.getParameterCount(); } return sum; } public boolean hasDefaultTokenValue(String... tokenValues) { - for (TokenizedStatement statement : this.statements) { + for (Statement statement : this.statements) { for (Token token : statement.getTokens()) { if (token.getType() == TokenType.DEFAULT) { for (String value : tokenValues) { @@ -129,24 +129,17 @@ public String toString() { } - static class TokenizedStatement { - - private final String sql; + static class Statement { private final List tokens; private final int parameterCount; - public TokenizedStatement(String sql, List tokens) { + public Statement(List tokens) { this.tokens = tokens; - this.sql = sql; this.parameterCount = readParameterCount(tokens); } - public String getSql() { - return this.sql; - } - public List getTokens() { return this.tokens; } @@ -164,19 +157,14 @@ public boolean equals(Object o) { return false; } - TokenizedStatement that = (TokenizedStatement) o; + Statement that = (Statement) o; - if (!this.sql.equals(that.sql)) { - return false; - } return this.tokens.equals(that.tokens); } @Override public int hashCode() { - int result = this.sql.hashCode(); - result = 31 * result + this.tokens.hashCode(); - return result; + return this.tokens.hashCode(); } @Override diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java b/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java index 0a68118a4..41ed92917 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java @@ -40,7 +40,7 @@ final class PostgresqlBatch implements io.r2dbc.postgresql.api.PostgresqlBatch { public PostgresqlBatch add(String sql) { Assert.requireNonNull(sql, "sql must not be null"); - if (!(PostgresqlSqlParser.tokenize(sql).getParameterCount() == 0)) { + if (!(PostgresqlSqlParser.parse(sql).getParameterCount() == 0)) { throw new IllegalArgumentException(String.format("Statement '%s' is not supported. This is often due to the presence of parameters.", sql)); } diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java b/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java index 9932ac17a..b9884a370 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlSqlParser.java @@ -38,11 +38,8 @@ class PostgresqlSqlParser { Arrays.sort(SPECIAL_AND_OPERATOR_CHARS); } - public static ParsedSql tokenize(String sql) { + private static List tokenize(String sql) { List tokens = new ArrayList<>(); - List statements = new ArrayList<>(); - - int statementStartIndex = 0; int i = 0; while (i < sql.length()) { char c = sql.charAt(i); @@ -87,21 +84,48 @@ public static ParsedSql tokenize(String sql) { } i += token.getValue().length(); + tokens.add(token); + } + return tokens; + } - if (token.getType() == ParsedSql.TokenType.STATEMENT_END) { + public static ParsedSql parse(String sql) { + List tokens = tokenize(sql); + List statements = new ArrayList<>(); + List functionBodyList = new ArrayList<>(); - tokens.add(token); - statements.add(new ParsedSql.TokenizedStatement(sql.substring(statementStartIndex, i), tokens)); + List currentStatementTokens = new ArrayList<>(); + for (int i = 0; i < tokens.size(); i++) { + ParsedSql.Token current = tokens.get(i); + currentStatementTokens.add(current); - tokens = new ArrayList<>(); - statementStartIndex = i + 1; - } else { - tokens.add(token); + if (current.getType() == ParsedSql.TokenType.DEFAULT) { + String currentValue = current.getValue(); + + if (currentValue.equalsIgnoreCase("BEGIN")) { + if (i + 1 < tokens.size() && tokens.get(i + 1).getValue().equalsIgnoreCase("ATOMIC")) { + functionBodyList.add(true); + } else { + functionBodyList.add(false); + } + } else if (currentValue.equalsIgnoreCase("END") && !functionBodyList.isEmpty()) { + functionBodyList.remove(functionBodyList.size() - 1); + } + } else if (current.getType().equals(ParsedSql.TokenType.STATEMENT_END)) { + boolean inFunctionBody = false; + + for (boolean b : functionBodyList) { + inFunctionBody |= b; + } + if (!inFunctionBody) { + statements.add(new ParsedSql.Statement(currentStatementTokens)); + currentStatementTokens = new ArrayList<>(); + } } } - // If tokens is not empty, implicit statement end - if (!tokens.isEmpty()) { - statements.add(new ParsedSql.TokenizedStatement(sql.substring(statementStartIndex), tokens)); + + if (!currentStatementTokens.isEmpty()) { + statements.add(new ParsedSql.Statement(currentStatementTokens)); } return new ParsedSql(sql, statements); @@ -209,12 +233,13 @@ private static ParsedSql.Token getQuotedIdentifierToken(String sql, int beginInd } } - private static boolean isAsciiLetter(char c){ + private static boolean isAsciiLetter(char c) { char lower = Character.toLowerCase(c); return lower >= 'a' && lower <= 'z'; } - private static boolean isAsciiDigit(char c){ + private static boolean isAsciiDigit(char c) { return c >= '0' && c <= '9'; } + } diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java b/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java index 62b49e5ec..d612472a1 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java @@ -73,7 +73,7 @@ final class PostgresqlStatement implements io.r2dbc.postgresql.api.PostgresqlSta PostgresqlStatement(ConnectionResources resources, String sql) { this.resources = Assert.requireNonNull(resources, "resources must not be null"); - this.parsedSql = PostgresqlSqlParser.tokenize(Assert.requireNonNull(sql, "sql must not be null")); + this.parsedSql = PostgresqlSqlParser.parse(Assert.requireNonNull(sql, "sql must not be null")); this.connectionContext = resources.getClient().getContext(); this.bindings = new ArrayDeque<>(this.parsedSql.getParameterCount()); diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java index 08ba9428d..ad1008b98 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlSqlParserTest.java @@ -23,6 +23,7 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertIterableEquals; import static org.junit.jupiter.api.Assertions.assertThrows; class PostgresqlSqlParserTest { @@ -116,48 +117,49 @@ class SingleTokenExceptionTests { @Test void unclosedSingleQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("'test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("'test")); } @Test void unclosedDollarQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$$test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$$test")); } @Test void unclosedTaggedDollarQuotedStringThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$abc$test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$abc$test")); } @Test void unclosedQuotedIdentifierThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("\"test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("\"test")); } @Test void unclosedBlockCommentThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("/*test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("/*test")); } @Test void unclosedNestedBlockCommentThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("/*/*test*/")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("/*/*test*/")); } @Test void invalidParameterCharacterThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$1test")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$1test")); } @Test void invalidTaggedDollarQuoteThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$a b$test$a b$")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$a b$test$a b$")); } @Test void unclosedTaggedDollarQuoteThrowsIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.tokenize("$abc")); + assertThrows(IllegalArgumentException.class, () -> PostgresqlSqlParser.parse("$abc")); } + } @Nested @@ -242,13 +244,33 @@ void simpleSelectStatementIsTokenized() { ); } + @Test + void simpleSelectStatementWithFunctionBodyIsTokenized() { + assertSingleStatementEquals("CREATE FUNCTION test() BEGIN ATOMIC SELECT 1; SELECT 2; END", + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "CREATE"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FUNCTION"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "test"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "("), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, ")"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "BEGIN"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "ATOMIC"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "2"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "END") + ); + } + } void assertSingleStatementEquals(String sql, ParsedSql.Token... tokens) { - ParsedSql parsedSql = PostgresqlSqlParser.tokenize(sql); + ParsedSql parsedSql = PostgresqlSqlParser.parse(sql); assertEquals(1, parsedSql.getStatements().size(), "Parse returned zero or more than 2 statements"); - ParsedSql.TokenizedStatement statement = parsedSql.getStatements().get(0); - assertEquals(new ParsedSql.TokenizedStatement(sql, Arrays.asList(tokens)), statement); + ParsedSql.Statement statement = parsedSql.getStatements().get(0); + assertIterableEquals(Arrays.asList(tokens), statement.getTokens()); } } @@ -258,30 +280,30 @@ class MultipleStatementTests { @Test void simpleMultipleStatementIsTokenized() { - ParsedSql parsedSql = PostgresqlSqlParser.tokenize("DELETE * FROM X; SELECT 1;"); - List statements = parsedSql.getStatements(); + ParsedSql parsedSql = PostgresqlSqlParser.parse("DELETE * FROM X; SELECT 1;"); + List statements = parsedSql.getStatements(); assertEquals(2, statements.size()); - ParsedSql.TokenizedStatement statementA = statements.get(0); - ParsedSql.TokenizedStatement statementB = statements.get(1); - - assertEquals(new ParsedSql.TokenizedStatement("DELETE * FROM X;", - Arrays.asList( - new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "DELETE"), - new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), - new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FROM"), - new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "X"), - new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") - )), - statementA + ParsedSql.Statement statementA = statements.get(0); + ParsedSql.Statement statementB = statements.get(1); + + assertIterableEquals( + Arrays.asList( + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "DELETE"), + new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, "*"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "FROM"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "X"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") + ), + statementA.getTokens() ); - assertEquals(new ParsedSql.TokenizedStatement("SELECT 1;", - Arrays.asList( - new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), - new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"), - new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") - )), - statementB + assertIterableEquals( + Arrays.asList( + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "SELECT"), + new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, "1"), + new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";") + ), + statementB.getTokens() ); }