Permalink
Browse files

fix: make sure {fn now()} jdbc translation is not performed in dollar…

…-quoted strings

closes #511
  • Loading branch information...
vlsi committed Feb 13, 2016
1 parent a3e2045 commit 9109451c65d43328b8e4344642331d7750d79cf6
@@ -14,6 +14,7 @@
import org.postgresql.core.BaseStatement;
import org.postgresql.core.Field;
import org.postgresql.core.ParameterList;
import org.postgresql.core.Parser;
import org.postgresql.core.Query;
import org.postgresql.core.QueryExecutor;
import org.postgresql.core.ResultCursor;
@@ -145,14 +146,16 @@
protected ResultWrapper generatedKeys = null;
// Static variables for parsing SQL when replaceProcessing is true.
private static final short IN_SQLCODE = 0;
private static final short IN_STRING = 1;
private static final short IN_IDENTIFIER = 6;
private static final short BACKSLASH = 2;
private static final short ESC_TIMEDATE = 3;
private static final short ESC_FUNCTION = 4;
private static final short ESC_OUTERJOIN = 5;
private static final short ESC_ESCAPECHAR = 7;
private enum SqlParseState {
IN_SQLCODE,
IN_STRING,
IN_IDENTIFIER,
BACKSLASH,
ESC_TIMEDATE,
ESC_FUNCTION,
ESC_OUTERJOIN,
ESC_ESCAPECHAR;
}
protected Query lastSimpleQuery;
@@ -653,7 +656,7 @@ static String replaceProcessing(String p_sql, boolean replaceProcessingEnabled,
*/
protected static int parseSql(String p_sql, int i, StringBuilder newsql, boolean stopOnComma,
boolean stdStrings) throws SQLException {
short state = IN_SQLCODE;
SqlParseState state = SqlParseState.IN_SQLCODE;
int len = p_sql.length();
int nestedParenthesis = 0;
boolean endOfNested = false;
@@ -664,12 +667,31 @@ protected static int parseSql(String p_sql, int i, StringBuilder newsql, boolean
char c = p_sql.charAt(i);
switch (state) {
case IN_SQLCODE:
if (c == '\'') {
if (c == '$' && (i == 0 || !Parser.isIdentifierContChar(p_sql.charAt(i - 1)))) {
// start of a dollar-quoted string
int tagEnd = -1;
if (i + 1 < len) {
tagEnd = p_sql.indexOf('$', i + 1);
}
if (tagEnd != -1) {
String dollarQuoteTag = p_sql.substring(i, tagEnd + 1);
int nextPos = p_sql.indexOf(dollarQuoteTag, i + dollarQuoteTag.length());
if (nextPos > 0) {
tagEnd = nextPos + dollarQuoteTag.length();
}
}
if (tagEnd == -1) {
tagEnd = len;
}
newsql.append(p_sql, i, tagEnd); // tagEnd is excluding
i = tagEnd - 1;
break;
} else if (c == '\'') {
// start of a string?
state = IN_STRING;
state = SqlParseState.IN_STRING;
} else if (c == '"') {
// start of a identifier?
state = IN_IDENTIFIER;
state = SqlParseState.IN_IDENTIFIER;
} else if (c == '(') { // begin nested sql
nestedParenthesis++;
} else if (c == ')') { // end of nested sql
@@ -686,12 +708,12 @@ protected static int parseSql(String p_sql, int i, StringBuilder newsql, boolean
char next = p_sql.charAt(i + 1);
char nextnext = (i + 2 < len) ? p_sql.charAt(i + 2) : '\0';
if (next == 'd' || next == 'D') {
state = ESC_TIMEDATE;
state = SqlParseState.ESC_TIMEDATE;
i++;
newsql.append("DATE ");
break;
} else if (next == 't' || next == 'T') {
state = ESC_TIMEDATE;
state = SqlParseState.ESC_TIMEDATE;
if (nextnext == 's' || nextnext == 'S') {
// timestamp constant
i += 2;
@@ -703,16 +725,16 @@ protected static int parseSql(String p_sql, int i, StringBuilder newsql, boolean
}
break;
} else if (next == 'f' || next == 'F') {
state = ESC_FUNCTION;
state = SqlParseState.ESC_FUNCTION;
i += (nextnext == 'n' || nextnext == 'N') ? 2 : 1;
break;
} else if (next == 'o' || next == 'O') {
state = ESC_OUTERJOIN;
state = SqlParseState.ESC_OUTERJOIN;
i += (nextnext == 'j' || nextnext == 'J') ? 2 : 1;
break;
} else if (next == 'e' || next == 'E') {
// we assume that escape is the only escape sequence beginning with e
state = ESC_ESCAPECHAR;
state = SqlParseState.ESC_ESCAPECHAR;
break;
}
}
@@ -723,10 +745,10 @@ protected static int parseSql(String p_sql, int i, StringBuilder newsql, boolean
case IN_STRING:
if (c == '\'') {
// end of string?
state = IN_SQLCODE;
state = SqlParseState.IN_SQLCODE;
} else if (c == '\\' && !stdStrings) {
// a backslash?
state = BACKSLASH;
state = SqlParseState.BACKSLASH;
}
newsql.append(c);
@@ -735,13 +757,13 @@ protected static int parseSql(String p_sql, int i, StringBuilder newsql, boolean
case IN_IDENTIFIER:
if (c == '"') {
// end of identifier
state = IN_SQLCODE;
state = SqlParseState.IN_SQLCODE;
}
newsql.append(c);
break;
case BACKSLASH:
state = IN_STRING;
state = SqlParseState.IN_STRING;
newsql.append(c);
break;
@@ -764,13 +786,13 @@ protected static int parseSql(String p_sql, int i, StringBuilder newsql, boolean
while (i < len && p_sql.charAt(i) != '}') {
newsql.append(p_sql.charAt(i++));
}
state = IN_SQLCODE; // end of escaped function (or query)
state = SqlParseState.IN_SQLCODE; // end of escaped function (or query)
break;
case ESC_TIMEDATE:
case ESC_OUTERJOIN:
case ESC_ESCAPECHAR:
if (c == '}') {
state = IN_SQLCODE; // end of escape code.
state = SqlParseState.IN_SQLCODE; // end of escape code.
} else {
newsql.append(c);
}
@@ -67,6 +67,7 @@ public static TestSuite suite() throws Exception {
suite.addTestSuite(PreparedStatementTest.class);
suite.addTestSuite(PreparedStatementBinaryTest.class);
suite.addTestSuite(StatementTest.class);
suite.addTest(new JUnit4TestAdapter(QuotationTest.class));
// ServerSide Prepared Statements
suite.addTestSuite(ServerPreparedStmtTest.class);
@@ -0,0 +1,127 @@
package org.postgresql.test.jdbc2;
import org.postgresql.test.TestUtil;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
@RunWith(Parameterized.class)
public class QuotationTest extends BaseTest4 {
enum QuoteStyle {
SIMPLE("'"), DOLLAR_NOTAG("$$"), DOLLAR_A("$a$"), DOLLAR_DEF("$DEF$");
private final String quote;
QuoteStyle(String quote) {
this.quote = quote;
}
@Override
public String toString() {
return quote;
}
}
private final String expr;
private final String expected;
public QuotationTest(QuoteStyle quoteStyle, String expected, String expr) {
this.expected = expected;
this.expr = expr;
}
@Parameterized.Parameters(name = "{index}: quotes(style={0}, src={1}, quoted={2})")
public static Iterable<Object[]> data() {
Collection<Object[]> ids = new ArrayList<Object[]>();
Collection<String> garbageValues = new ArrayList<String>();
garbageValues.add("{fn now}");
garbageValues.add("{extract}");
garbageValues.add("{select}");
garbageValues.add("?select");
garbageValues.add("select?");
garbageValues.add("??select");
garbageValues.add("}{");
garbageValues.add("{");
garbageValues.add("}");
garbageValues.add("--");
garbageValues.add("/*");
garbageValues.add("*/");
for (QuoteStyle quoteStyle : QuoteStyle.values()) {
garbageValues.add(quoteStyle.toString());
}
for (char ch = 'a'; ch <= 'z'; ch++) {
garbageValues.add(Character.toString(ch));
}
for (QuoteStyle quoteStyle : QuoteStyle.values()) {
for (String garbage : garbageValues) {
String unquoted = garbage;
for (int i = 0; i < 3; i++) {
String quoted = unquoted;
if (quoteStyle == QuoteStyle.SIMPLE) {
quoted = quoted.replaceAll("'", "''");
}
quoted = quoteStyle.toString() + quoted + quoteStyle.toString();
if (quoted.endsWith("$$$") && quoteStyle == QuoteStyle.DOLLAR_NOTAG) {
// $$$a$$$ is parsed like $$ $a $$ $ -> thus we skip this test
continue;
}
if (quoteStyle != QuoteStyle.SIMPLE && garbage.equals(quoteStyle.toString())) {
// $a$$a$$a$ is not valid
continue;
}
String expected = unquoted;
ids.add(new Object[]{quoteStyle, expected, quoted});
if (unquoted.length() == 1) {
char ch = unquoted.charAt(0);
if (ch >= 'a' && ch <= 'z') {
// Will assume if 'a' works, then 'aa', 'aaa' will also work
break;
}
}
unquoted += garbage;
}
}
}
return ids;
}
@Test
public void quotedString() throws SQLException {
PreparedStatement ps = con.prepareStatement("select " + expr);
try {
ResultSet rs = ps.executeQuery();
rs.next();
String val = rs.getString(1);
Assert.assertEquals(expected, val);
} catch (SQLException e) {
TestUtil.closeQuietly(ps);
}
}
@Test
public void bindInTheMiddle() throws SQLException {
PreparedStatement ps = con.prepareStatement("select " + expr + ", ?, " + expr);
try {
ps.setInt(1, 42);
ResultSet rs = ps.executeQuery();
rs.next();
String val1 = rs.getString(1);
String val3 = rs.getString(3);
Assert.assertEquals(expected, val1);
Assert.assertEquals(expected, val3);
} catch (SQLException e) {
TestUtil.closeQuietly(ps);
}
}
}
@@ -652,4 +652,30 @@ protected void finalize() throws Throwable {
throw new IllegalStateException("Detected failure in cleanup thread", cleanupFailure.get());
}
}
/**
* Test that $JAVASCRIPT$ protects curly braces from JDBC {fn now()} kind of syntax.
* @throws SQLException if something goes wrong
*/
public void testJavascriptFunction() throws SQLException {
String str = " var _modules = {};\n"
+ " var _current_stack = [];\n"
+ "\n"
+ " // modules start\n"
+ " _modules[\"/root/aidbox/fhirbase/src/core\"] = {\n"
+ " init: function(){\n"
+ " var exports = {};\n"
+ " _current_stack.push({file: \"core\", dir: \"/root/aidbox/fhirbase/src\"})\n"
+ " var module = {exports: exports};";
PreparedStatement ps = null;
try {
ps = con.prepareStatement("select $JAVASCRIPT$" + str + "$JAVASCRIPT$");
ResultSet rs = ps.executeQuery();
rs.next();
assertEquals("Javascript code has been protected with $JAVASCRIPT$", str, rs.getString(1));
} finally {
TestUtil.closeQuietly(ps);
}
}
}

0 comments on commit 9109451

Please sign in to comment.