Skip to content

Commit

Permalink
Merge pull request from GHSA-24rp-q3w6-vc56
Browse files Browse the repository at this point in the history
* SQL Injection via line comment generation for 42_3_x

* fix: Add parentheses around NULL parameter values in simple query mode

* simplify code, handle binary and add tests

---------

Co-authored-by: Sehrope Sarkuni <sehrope@jackdb.com>
  • Loading branch information
davecramer and sehrope committed Feb 20, 2024
1 parent 16a4fb7 commit d93c741
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ public void setNull(@Positive int index, int oid) throws SQLException {
* {}
* </pre>
**/
private static String quoteAndCast(String text, String type, boolean standardConformingStrings) {
private static String quoteAndCast(String text, @Nullable String type, boolean standardConformingStrings) {
StringBuilder sb = new StringBuilder((text.length() + 10) / 10 * 11); // Add 10% for escaping.
sb.append("('");
try {
Expand Down Expand Up @@ -233,80 +233,103 @@ public String toString(@Positive int index, boolean standardConformingStrings) {
return "?";
} else if (paramValue == NULL_OBJECT) {
return "(NULL)";
} else if ((flags[index] & BINARY) == BINARY) {
}
String textValue;
String type;
if ((flags[index] & BINARY) == BINARY) {
// handle some of the numeric types

switch (paramTypes[index]) {
case Oid.INT2:
short s = ByteConverter.int2((byte[]) paramValue, 0);
return quoteAndCast(Short.toString(s), "int2", standardConformingStrings);
textValue = Short.toString(s);
type = "int2";
break;

case Oid.INT4:
int i = ByteConverter.int4((byte[]) paramValue, 0);
return quoteAndCast(Integer.toString(i), "int4", standardConformingStrings);
textValue = Integer.toString(i);
type = "int4";
break;

case Oid.INT8:
long l = ByteConverter.int8((byte[]) paramValue, 0);
return quoteAndCast(Long.toString(l), "int8", standardConformingStrings);
textValue = Long.toString(l);
type = "int8";
break;

case Oid.FLOAT4:
float f = ByteConverter.float4((byte[]) paramValue, 0);
if (Float.isNaN(f)) {
return "('NaN'::real)";
}
return quoteAndCast(Float.toString(f), "float", standardConformingStrings);
textValue = Float.toString(f);
type = "real";
break;

case Oid.FLOAT8:
double d = ByteConverter.float8((byte[]) paramValue, 0);
if (Double.isNaN(d)) {
return "('NaN'::double precision)";
}
return quoteAndCast(Double.toString(d), "double precision", standardConformingStrings);
textValue = Double.toString(d);
type = "double precision";
break;

case Oid.NUMERIC:
Number n = ByteConverter.numeric((byte[]) paramValue);
if (n instanceof Double) {
assert ((Double) n).isNaN();
return "('NaN'::numeric)";
}
return n.toString();
textValue = n.toString();
type = "numeric";
break;

case Oid.UUID:
String uuid =
textValue =
new UUIDArrayAssistant().buildElement((byte[]) paramValue, 0, 16).toString();
return quoteAndCast(uuid, "uuid", standardConformingStrings);
type = "uuid";
break;

case Oid.POINT:
PGpoint pgPoint = new PGpoint();
pgPoint.setByteValue((byte[]) paramValue, 0);
return quoteAndCast(pgPoint.toString(), "point", standardConformingStrings);
textValue = pgPoint.toString();
type = "point";
break;

case Oid.BOX:
PGbox pgBox = new PGbox();
pgBox.setByteValue((byte[]) paramValue, 0);
return quoteAndCast(pgBox.toString(), "box", standardConformingStrings);
textValue = pgBox.toString();
type = "box";
break;

default:
return "?";
}
return "?";
} else {
String param = paramValue.toString();
textValue = paramValue.toString();
int paramType = paramTypes[index];
if (paramType == Oid.TIMESTAMP) {
return quoteAndCast(param, "timestamp", standardConformingStrings);
type = "timestamp";
} else if (paramType == Oid.TIMESTAMPTZ) {
return quoteAndCast(param, "timestamp with time zone", standardConformingStrings);
type = "timestamp with time zone";
} else if (paramType == Oid.TIME) {
return quoteAndCast(param, "time", standardConformingStrings);
type = "time";
} else if (paramType == Oid.TIMETZ) {
return quoteAndCast(param, "time with time zone", standardConformingStrings);
type = "time with time zone";
} else if (paramType == Oid.DATE) {
return quoteAndCast(param, "date", standardConformingStrings);
type = "date";
} else if (paramType == Oid.INTERVAL) {
return quoteAndCast(param, "interval", standardConformingStrings);
type = "interval";
} else if (paramType == Oid.NUMERIC) {
return quoteAndCast(param, "numeric", standardConformingStrings);
type = "numeric";
} else {
type = null;
}
return quoteAndCast(param, null, standardConformingStrings);
}
return quoteAndCast(textValue, type, standardConformingStrings);
}

@Override
Expand Down
155 changes: 116 additions & 39 deletions pgjdbc/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,56 +12,133 @@

import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

public class ParameterInjectionTest {
@Test
public void negateParameter() throws Exception {
try (Connection conn = TestUtil.openDB()) {
PreparedStatement stmt = conn.prepareStatement("SELECT -?");
private interface ParameterBinder {
void bind(PreparedStatement stmt) throws SQLException;
}

stmt.setInt(1, 1);
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next());
assertEquals(1, rs.getMetaData().getColumnCount(), "number of result columns must match");
int value = rs.getInt(1);
assertEquals(-1, value, "Input value 1");
}
private void testParamInjection(ParameterBinder bindPositiveOne, ParameterBinder bindNegativeOne)
throws SQLException {
try (Connection conn = TestUtil.openDB()) {
{
PreparedStatement stmt = conn.prepareStatement("SELECT -?");
bindPositiveOne.bind(stmt);
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next());
assertEquals(1, rs.getMetaData().getColumnCount(),
"number of result columns must match");
int value = rs.getInt(1);
assertEquals(-1, value);
}
bindNegativeOne.bind(stmt);
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next());
assertEquals(1, rs.getMetaData().getColumnCount(),
"number of result columns must match");
int value = rs.getInt(1);
assertEquals(1, value);
}
}
{
PreparedStatement stmt = conn.prepareStatement("SELECT -?, ?");
bindPositiveOne.bind(stmt);
stmt.setString(2, "\nWHERE false --");
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next(), "ResultSet should contain a row");
assertEquals(2, rs.getMetaData().getColumnCount(),
"rs.getMetaData().getColumnCount(");
int value = rs.getInt(1);
assertEquals(-1, value);
}

stmt.setInt(1, -1);
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next());
assertEquals(1, rs.getMetaData().getColumnCount(), "number of result columns must match");
int value = rs.getInt(1);
assertEquals(1, value, "Input value -1");
}
bindNegativeOne.bind(stmt);
stmt.setString(2, "\nWHERE false --");
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next(), "ResultSet should contain a row");
assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount(");
int value = rs.getInt(1);
assertEquals(1, value);
}

}
}
}

@Test
public void negateParameterWithContinuation() throws Exception {
try (Connection conn = TestUtil.openDB()) {
PreparedStatement stmt = conn.prepareStatement("SELECT -?, ?");
@Test
public void handleInt2() throws SQLException {
testParamInjection(
stmt -> {
stmt.setShort(1, (short) 1);
},
stmt -> {
stmt.setShort(1, (short) -1);
}
);
}

stmt.setInt(1, 1);
stmt.setString(2, "\nWHERE false --");
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next(), "ResultSet should contain a row");
assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount(");
int value = rs.getInt(1);
assertEquals(-1, value);
}
@Test
public void handleInt4() throws SQLException {
testParamInjection(
stmt -> {
stmt.setInt(1, 1);
},
stmt -> {
stmt.setInt(1, -1);
}
);
}

stmt.setInt(1, -1);
stmt.setString(2, "\nWHERE false --");
try (ResultSet rs = stmt.executeQuery()) {
assertTrue(rs.next(), "ResultSet should contain a row");
assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount(");
int value = rs.getInt(1);
assertEquals(1, value);
}
@Test
public void handleBigInt() throws SQLException {
testParamInjection(
stmt -> {
stmt.setLong(1, (long) 1);
},
stmt -> {
stmt.setLong(1, (long) -1);
}
}
);
}

@Test
public void handleNumeric() throws SQLException {
testParamInjection(
stmt -> {
stmt.setBigDecimal(1, new BigDecimal("1"));
},
stmt -> {
stmt.setBigDecimal(1, new BigDecimal("-1"));
}
);
}

@Test
public void handleFloat() throws SQLException {
testParamInjection(
stmt -> {
stmt.setFloat(1, 1);
},
stmt -> {
stmt.setFloat(1, -1);
}
);
}

@Test
public void handleDouble() throws SQLException {
testParamInjection(
stmt -> {
stmt.setDouble(1, 1);
},
stmt -> {
stmt.setDouble(1, -1);
}
);
}
}

0 comments on commit d93c741

Please sign in to comment.