Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,24 @@
import java.util.Set;
import java.util.TreeSet;

class TokenizedSql {
class ParsedSql {

private final String sql;

private final List<TokenizedStatement> statements;
private final List<Statement> statements;

private final int statementCount;

private final int parameterCount;

public TokenizedSql(String sql, List<TokenizedStatement> statements) {
public ParsedSql(String sql, List<Statement> statements) {
this.sql = sql;
this.statements = statements;
this.statementCount = statements.size();
this.parameterCount = getParameterCount(statements);
}

List<TokenizedStatement> getStatements() {
List<Statement> getStatements() {
return this.statements;
}

Expand All @@ -53,16 +53,16 @@ public String getSql() {
return sql;
}

private static int getParameterCount(List<TokenizedStatement> statements) {
private static int getParameterCount(List<Statement> statements) {
int sum = 0;
for (TokenizedStatement statement : statements){
for (Statement statement : statements){
sum += statement.getParameterCount();
}
return sum;
}

public boolean hasDefaultTokenValue(String... tokenValues) {
for (TokenizedStatement statement : this.statements) {
for (Statement statement : this.statements) {
for (Token token : statement.getTokens()) {
if (token.getType() == TokenType.DEFAULT) {
for (String value : tokenValues) {
Expand Down Expand Up @@ -129,24 +129,17 @@ public String toString() {

}

static class TokenizedStatement {

private final String sql;
static class Statement {

private final List<Token> tokens;

private final int parameterCount;

public TokenizedStatement(String sql, List<Token> tokens) {
public Statement(List<Token> tokens) {
this.tokens = tokens;
this.sql = sql;
this.parameterCount = readParameterCount(tokens);
}

public String getSql() {
return this.sql;
}

public List<Token> getTokens() {
return this.tokens;
}
Expand All @@ -164,19 +157,14 @@ public boolean equals(Object o) {
return false;
}

TokenizedStatement that = (TokenizedStatement) o;
Statement that = (Statement) o;

if (!this.sql.equals(that.sql)) {
return false;
}
return this.tokens.equals(that.tokens);
}

@Override
public int hashCode() {
int result = this.sql.hashCode();
result = 31 * result + this.tokens.hashCode();
return result;
return this.tokens.hashCode();
}

@Override
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/io/r2dbc/postgresql/PostgresqlBatch.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ final class PostgresqlBatch implements io.r2dbc.postgresql.api.PostgresqlBatch {
public PostgresqlBatch add(String sql) {
Assert.requireNonNull(sql, "sql must not be null");

if (!(PostgresqlSqlLexer.tokenize(sql).getParameterCount() == 0)) {
if (!(PostgresqlSqlParser.parse(sql).getParameterCount() == 0)) {
throw new IllegalArgumentException(String.format("Statement '%s' is not supported. This is often due to the presence of parameters.", sql));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
*
* @since 0.9
*/
class PostgresqlSqlLexer {
class PostgresqlSqlParser {

private static final char[] SPECIAL_AND_OPERATOR_CHARS = {
'+', '-', '*', '/', '<', '>', '=', '~', '!', '@', '#', '%', '^', '&', '|', '`', '?',
Expand All @@ -38,15 +38,12 @@ class PostgresqlSqlLexer {
Arrays.sort(SPECIAL_AND_OPERATOR_CHARS);
}

public static TokenizedSql tokenize(String sql) {
List<TokenizedSql.Token> tokens = new ArrayList<>();
List<TokenizedSql.TokenizedStatement> statements = new ArrayList<>();

int statementStartIndex = 0;
private static List<ParsedSql.Token> tokenize(String sql) {
List<ParsedSql.Token> tokens = new ArrayList<>();
int i = 0;
while (i < sql.length()) {
char c = sql.charAt(i);
TokenizedSql.Token token = null;
ParsedSql.Token token = null;

if (isWhitespace(c)) {
i++;
Expand All @@ -73,55 +70,82 @@ public static TokenizedSql tokenize(String sql) {
token = getParameterOrDollarQuoteToken(sql, i);
break;
case ';':
token = new TokenizedSql.Token(TokenizedSql.TokenType.STATEMENT_END, ";");
token = new ParsedSql.Token(ParsedSql.TokenType.STATEMENT_END, ";");
break;
default:
break;
}
if (token == null) {
if (isSpecialOrOperatorChar(c)) {
token = new TokenizedSql.Token(TokenizedSql.TokenType.SPECIAL_OR_OPERATOR, Character.toString(c));//getSpecialOrOperatorToken(sql, i);
token = new ParsedSql.Token(ParsedSql.TokenType.SPECIAL_OR_OPERATOR, Character.toString(c));//getSpecialOrOperatorToken(sql, i);
} else {
token = getDefaultToken(sql, i);
}
}

i += token.getValue().length();
tokens.add(token);
}
return tokens;
}

if (token.getType() == TokenizedSql.TokenType.STATEMENT_END) {
public static ParsedSql parse(String sql) {
List<ParsedSql.Token> tokens = tokenize(sql);
List<ParsedSql.Statement> statements = new ArrayList<>();
List<Boolean> functionBodyList = new ArrayList<>();

tokens.add(token);
statements.add(new TokenizedSql.TokenizedStatement(sql.substring(statementStartIndex, i), tokens));
List<ParsedSql.Token> currentStatementTokens = new ArrayList<>();
for (int i = 0; i < tokens.size(); i++) {
ParsedSql.Token current = tokens.get(i);
currentStatementTokens.add(current);

tokens = new ArrayList<>();
statementStartIndex = i + 1;
} else {
tokens.add(token);
if (current.getType() == ParsedSql.TokenType.DEFAULT) {
String currentValue = current.getValue();

if (currentValue.equalsIgnoreCase("BEGIN")) {
if (i + 1 < tokens.size() && tokens.get(i + 1).getValue().equalsIgnoreCase("ATOMIC")) {
functionBodyList.add(true);
} else {
functionBodyList.add(false);
}
} else if (currentValue.equalsIgnoreCase("END") && !functionBodyList.isEmpty()) {
functionBodyList.remove(functionBodyList.size() - 1);
}
} else if (current.getType().equals(ParsedSql.TokenType.STATEMENT_END)) {
boolean inFunctionBody = false;

for (boolean b : functionBodyList) {
inFunctionBody |= b;
}
if (!inFunctionBody) {
statements.add(new ParsedSql.Statement(currentStatementTokens));
currentStatementTokens = new ArrayList<>();
}
}
}
// If tokens is not empty, implicit statement end
if (!tokens.isEmpty()) {
statements.add(new TokenizedSql.TokenizedStatement(sql.substring(statementStartIndex), tokens));

if (!currentStatementTokens.isEmpty()) {
statements.add(new ParsedSql.Statement(currentStatementTokens));
}

return new TokenizedSql(sql, statements);
return new ParsedSql(sql, statements);
}

private static TokenizedSql.Token getDefaultToken(String sql, int beginIndex) {
private static ParsedSql.Token getDefaultToken(String sql, int beginIndex) {
for (int i = beginIndex + 1; i < sql.length(); i++) {
char c = sql.charAt(i);
if (Character.isWhitespace(c) || isSpecialOrOperatorChar(c)) {
return new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, sql.substring(beginIndex, i));
return new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, sql.substring(beginIndex, i));
}
}
return new TokenizedSql.Token(TokenizedSql.TokenType.DEFAULT, sql.substring(beginIndex));
return new ParsedSql.Token(ParsedSql.TokenType.DEFAULT, sql.substring(beginIndex));
}

private static boolean isSpecialOrOperatorChar(char c) {
return Arrays.binarySearch(SPECIAL_AND_OPERATOR_CHARS, c) >= 0;
}

private static TokenizedSql.Token getBlockCommentToken(String sql, int beginIndex) {
private static ParsedSql.Token getBlockCommentToken(String sql, int beginIndex) {
int depth = 1;
for (int i = beginIndex + 2; i < (sql.length() - 1); i++) {
char c1 = sql.charAt(i);
Expand All @@ -134,44 +158,44 @@ private static TokenizedSql.Token getBlockCommentToken(String sql, int beginInde
i++;
}
if (depth == 0) {
return new TokenizedSql.Token(TokenizedSql.TokenType.COMMENT, sql.substring(beginIndex, i + 1));
return new ParsedSql.Token(ParsedSql.TokenType.COMMENT, sql.substring(beginIndex, i + 1));
}
}
throw new IllegalArgumentException("Sql cannot be parsed: unclosed block comment (comment opened at index " + beginIndex + ") in statement: " + sql);
}

private static TokenizedSql.Token getCommentToLineEndToken(String sql, int beginIndex) {
private static ParsedSql.Token getCommentToLineEndToken(String sql, int beginIndex) {
int lineEnding = sql.indexOf('\n', beginIndex);
if (lineEnding == -1) {
return new TokenizedSql.Token(TokenizedSql.TokenType.COMMENT, sql.substring(beginIndex));
return new ParsedSql.Token(ParsedSql.TokenType.COMMENT, sql.substring(beginIndex));
} else {
return new TokenizedSql.Token(TokenizedSql.TokenType.COMMENT, sql.substring(beginIndex, lineEnding));
return new ParsedSql.Token(ParsedSql.TokenType.COMMENT, sql.substring(beginIndex, lineEnding));
}
}

private static TokenizedSql.Token getDollarQuoteToken(String sql, String tag, int beginIndex) {
private static ParsedSql.Token getDollarQuoteToken(String sql, String tag, int beginIndex) {
int nextQuote = sql.indexOf(tag, beginIndex + tag.length());
if (nextQuote == -1) {
throw new IllegalArgumentException("Sql cannot be parsed: unclosed quote (quote opened at index " + beginIndex + ") in statement: " + sql);
} else {
return new TokenizedSql.Token(TokenizedSql.TokenType.STRING_CONSTANT, sql.substring(beginIndex, nextQuote + tag.length()));
return new ParsedSql.Token(ParsedSql.TokenType.STRING_CONSTANT, sql.substring(beginIndex, nextQuote + tag.length()));
}
}

private static TokenizedSql.Token getParameterToken(String sql, int beginIndex) {
private static ParsedSql.Token getParameterToken(String sql, int beginIndex) {
for (int i = beginIndex + 1; i < sql.length(); i++) {
char c = sql.charAt(i);
if (isWhitespace(c) || isSpecialOrOperatorChar(c)) {
return new TokenizedSql.Token(TokenizedSql.TokenType.PARAMETER, sql.substring(beginIndex, i));
return new ParsedSql.Token(ParsedSql.TokenType.PARAMETER, sql.substring(beginIndex, i));
}
if (!isAsciiDigit(c)) {
throw new IllegalArgumentException("Sql cannot be parsed: illegal character in parameter or dollar-quote tag: " + c);
}
}
return new TokenizedSql.Token(TokenizedSql.TokenType.PARAMETER, sql.substring(beginIndex));
return new ParsedSql.Token(ParsedSql.TokenType.PARAMETER, sql.substring(beginIndex));
}

private static TokenizedSql.Token getParameterOrDollarQuoteToken(String sql, int beginIndex) {
private static ParsedSql.Token getParameterOrDollarQuoteToken(String sql, int beginIndex) {
char firstChar = sql.charAt(beginIndex + 1);
if (firstChar == '$') {
return getDollarQuoteToken(sql, "$$", beginIndex);
Expand All @@ -191,30 +215,31 @@ private static TokenizedSql.Token getParameterOrDollarQuoteToken(String sql, int
}
}

private static TokenizedSql.Token getStandardQuoteToken(String sql, int beginIndex) {
private static ParsedSql.Token getStandardQuoteToken(String sql, int beginIndex) {
int nextQuote = sql.indexOf('\'', beginIndex + 1);
if (nextQuote == -1) {
throw new IllegalArgumentException("Sql cannot be parsed: unclosed quote (quote opened at index " + beginIndex + ") in statement: " + sql);
} else {
return new TokenizedSql.Token(TokenizedSql.TokenType.STRING_CONSTANT, sql.substring(beginIndex, nextQuote + 1));
return new ParsedSql.Token(ParsedSql.TokenType.STRING_CONSTANT, sql.substring(beginIndex, nextQuote + 1));
}
}

private static TokenizedSql.Token getQuotedIdentifierToken(String sql, int beginIndex) {
private static ParsedSql.Token getQuotedIdentifierToken(String sql, int beginIndex) {
int nextQuote = sql.indexOf('\"', beginIndex + 1);
if (nextQuote == -1) {
throw new IllegalArgumentException("Sql cannot be parsed: unclosed quoted identifier (identifier opened at index " + beginIndex + ") in statement: " + sql);
} else {
return new TokenizedSql.Token(TokenizedSql.TokenType.QUOTED_IDENTIFIER, sql.substring(beginIndex, nextQuote + 1));
return new ParsedSql.Token(ParsedSql.TokenType.QUOTED_IDENTIFIER, sql.substring(beginIndex, nextQuote + 1));
}
}

private static boolean isAsciiLetter(char c){
private static boolean isAsciiLetter(char c) {
char lower = Character.toLowerCase(c);
return lower >= 'a' && lower <= 'z';
}

private static boolean isAsciiDigit(char c){
private static boolean isAsciiDigit(char c) {
return c >= '0' && c <= '9';
}

}
Loading