Structured types backward compatibility for getObject method (#1740)
* SNOW-1232333 - ResultSet getObject method returns a string if the target type wasn't specified
sfc-gh-pmotacki committed Apr 28, 2024
1 parent ff0adbd commit ed334e6
Showing 10 changed files with 226 additions and 68 deletions.
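In practice the change restores the pre-structured-types contract of getObject(int): for a structured-type column it now returns the column's JSON text (or raises a descriptive SQLException for native Arrow structs, which carry no raw JSON form), while getObject(int columnIndex, Class<T> type) keeps mapping the value onto a registered SQLData class. Below is a minimal sketch of both read paths, using the SimpleClass test type from this repository; the OBJECT literal, its field names, and the getConnection() helper are illustrative assumptions rather than part of the commit.

```java
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import net.snowflake.client.core.structs.SnowflakeObjectTypeFactories;
import net.snowflake.client.jdbc.structuredtypes.sqldata.SimpleClass;

public class GetObjectBackwardCompatSketch {
  public static void main(String[] args) throws SQLException {
    // Register the SQLData mapping before typed reads, as the test's @Before hook does
    SnowflakeObjectTypeFactories.register(SimpleClass.class, SimpleClass::new);
    try (Connection conn = getConnection(); // assumed connection helper
        Statement stmt = conn.createStatement();
        ResultSet rs =
            stmt.executeQuery(
                "SELECT {'string':'a','intValue':2}::OBJECT(string VARCHAR, intValue INTEGER)")) {
      rs.next();
      // Backward-compatible path: no target type -> the JSON text of the structured value
      // (JSON and ARROW_WITH_JSON_STRUCTURED_TYPES result formats)
      String json = (String) rs.getObject(1);
      // Typed path: mapping onto the registered SQLData class still works
      SimpleClass sc = rs.getObject(1, SimpleClass.class);
      System.out.println(json + " -> " + sc.getString());
    }
  }

  private static Connection getConnection() throws SQLException {
    throw new UnsupportedOperationException("supply your own Snowflake connection");
  }
}
```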
1 change: 0 additions & 1 deletion src/main/java/net/snowflake/client/core/ArrowSqlInput.java
@@ -28,7 +28,6 @@

@SnowflakeJdbcInternalApi
public class ArrowSqlInput extends BaseSqlInput {

private final Map<String, Object> input;
private int currentIndex = 0;
private boolean wasNull = false;
9 changes: 8 additions & 1 deletion src/main/java/net/snowflake/client/core/JsonSqlInput.java
@@ -35,19 +35,22 @@

@SnowflakeJdbcInternalApi
public class JsonSqlInput extends BaseSqlInput {
private final String text;
private final JsonNode input;
private final Iterator<JsonNode> elements;
private final TimeZone sessionTimeZone;
private int currentIndex = 0;
private boolean wasNull = false;

public JsonSqlInput(
String text,
JsonNode input,
SFBaseSession session,
Converters converters,
List<FieldMetadata> fields,
TimeZone sessionTimeZone) {
super(session, converters, fields);
this.text = text;
this.input = input;
this.elements = input.elements();
this.sessionTimeZone = sessionTimeZone;
@@ -57,6 +60,10 @@ public JsonNode getInput() {
return input;
}

public String getText() {
return text;
}

@Override
public String readString() throws SQLException {
return withNextValue((this::convertString));
@@ -178,7 +185,7 @@ private <T> T convertObject(Class<T> type, TimeZone tz, Object value, FieldMetad
JsonNode jsonNode = (JsonNode) value;
SQLInput sqlInput =
new JsonSqlInput(
jsonNode, session, converters, fieldMetadata.getFields(), sessionTimeZone);
null, jsonNode, session, converters, fieldMetadata.getFields(), sessionTimeZone);
SQLData instance = (SQLData) SQLDataCreationHelper.create(type);
instance.readSQL(sqlInput, null);
return (T) instance;
9 changes: 7 additions & 2 deletions src/main/java/net/snowflake/client/core/SFArrowResultSet.java
@@ -378,7 +378,7 @@ public SQLInput createSqlInputForColumn(
SFBaseSession session,
List<FieldMetadata> fields) {
if (parentObjectClass.equals(JsonSqlInput.class)) {
return createJsonSqlInputForColumn(input, columnIndex, session, fields);
return createJsonSqlInputForColumn(input, session, fields);
} else {
return new ArrowSqlInput((Map<String, Object>) input, session, converters, fields);
}
@@ -581,8 +581,10 @@ private Object createJsonSqlInput(int columnIndex, Object obj) throws SFExceptio
if (obj == null) {
return null;
}
JsonNode jsonNode = OBJECT_MAPPER.readTree((String) obj);
String text = (String) obj;
JsonNode jsonNode = OBJECT_MAPPER.readTree(text);
return new JsonSqlInput(
text,
jsonNode,
session,
converters,
@@ -595,6 +597,9 @@

private Object createArrowSqlInput(int columnIndex, Map<String, Object> input)
throws SFException {
if (input == null) {
return null;
}
return new ArrowSqlInput(
input, session, converters, resultSetMetaData.getColumnFields(columnIndex));
}
5 changes: 3 additions & 2 deletions src/main/java/net/snowflake/client/core/SFBaseResultSet.java
@@ -261,14 +261,15 @@ public Timestamp convertToTimestamp(

@SnowflakeJdbcInternalApi
protected SQLInput createJsonSqlInputForColumn(
Object input, int columnIndex, SFBaseSession session, List<FieldMetadata> fields) {
Object input, SFBaseSession session, List<FieldMetadata> fields) {
JsonNode inputNode;
if (input instanceof JsonNode) {
inputNode = (JsonNode) input;
} else {
inputNode = OBJECT_MAPPER.convertValue(input, JsonNode.class);
}
return new JsonSqlInput(inputNode, session, getConverters(), fields, sessionTimeZone);
return new JsonSqlInput(
input.toString(), inputNode, session, getConverters(), fields, sessionTimeZone);
}

@SnowflakeJdbcInternalApi
3 changes: 2 additions & 1 deletion src/main/java/net/snowflake/client/core/SFJsonResultSet.java
@@ -257,7 +257,7 @@ public SQLInput createSqlInputForColumn(
int columnIndex,
SFBaseSession session,
List<FieldMetadata> fields) {
return createJsonSqlInputForColumn(input, columnIndex, session, fields);
return createJsonSqlInputForColumn(input, session, fields);
}

@Override
@@ -293,6 +293,7 @@ private Object getSqlInput(String input, int columnIndex) throws SFException {
try {
JsonNode jsonNode = OBJECT_MAPPER.readTree(input);
return new JsonSqlInput(
input,
jsonNode,
session,
converters,
11 changes: 0 additions & 11 deletions src/main/java/net/snowflake/client/core/SFSqlInput.java
@@ -4,7 +4,6 @@
package net.snowflake.client.core;

import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.sql.SQLInput;
import java.util.List;
import java.util.Map;
@@ -31,8 +30,6 @@ static SFSqlInput unwrap(SQLInput sqlInput) {
* @param tz timezone to consider.
* @return the attribute; if the value is SQL <code>NULL</code>, returns <code>null</code>
* @exception SQLException if a database access error occurs
* @exception SQLFeatureNotSupportedException if the JDBC driver does not support this method
* @since 1.2
*/
java.sql.Timestamp readTimestamp(TimeZone tz) throws SQLException;
/**
@@ -43,8 +40,6 @@ static SFSqlInput unwrap(SQLInput sqlInput) {
* @return the attribute at the head of the stream as an {@code Object} in the Java programming
* language;{@code null} if the attribute is SQL {@code NULL}
* @exception SQLException if a database access error occurs
* @exception SQLFeatureNotSupportedException if the JDBC driver does not support this method
* @since 1.8
*/
<T> T readObject(Class<T> type, TimeZone tz) throws SQLException;
/**
@@ -55,8 +50,6 @@ static SFSqlInput unwrap(SQLInput sqlInput) {
* @return the attribute at the head of the stream as an {@code List} in the Java programming
* language;{@code null} if the attribute is SQL {@code NULL}
* @exception SQLException if a database access error occurs
* @exception SQLFeatureNotSupportedException if the JDBC driver does not support this method
* @since 1.8
*/
<T> List<T> readList(Class<T> type) throws SQLException;

@@ -68,8 +61,6 @@ static SFSqlInput unwrap(SQLInput sqlInput) {
* @return the attribute at the head of the stream as an {@code Map} in the Java programming
* language;{@code null} if the attribute is SQL {@code NULL}
* @exception SQLException if a database access error occurs
* @exception SQLFeatureNotSupportedException if the JDBC driver does not support this method
* @since 1.8
*/
<T> Map<String, T> readMap(Class<T> type) throws SQLException;
/**
@@ -80,8 +71,6 @@ static SFSqlInput unwrap(SQLInput sqlInput) {
* @return the attribute at the head of the stream as an {@code Array} in the Java programming
* language;{@code null} if the attribute is SQL {@code NULL}
* @exception SQLException if a database access error occurs
* @exception SQLFeatureNotSupportedException if the JDBC driver does not support this method
* @since 1.8
*/
<T> T[] readArray(Class<T> type) throws SQLException;
}
src/main/java/net/snowflake/client/jdbc/SnowflakeBaseResultSet.java
@@ -6,6 +6,7 @@

import static net.snowflake.client.jdbc.SnowflakeUtil.mapSFExceptionToSQLException;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -38,7 +39,6 @@
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import net.snowflake.client.core.ArrowSqlInput;
import net.snowflake.client.core.ColumnTypeHelper;
import net.snowflake.client.core.JsonSqlInput;
import net.snowflake.client.core.ObjectMapperFactory;
@@ -1354,7 +1354,9 @@ public <T> T getObject(int columnIndex, Class<T> type) throws SQLException {
logger.debug("public <T> T getObject(int columnIndex,Class<T> type)", false);
if (resultSetMetaData.isStructuredTypeColumn(columnIndex)) {
if (SQLData.class.isAssignableFrom(type)) {
SQLInput sqlInput = (SQLInput) getObject(columnIndex);
SQLInput sqlInput =
SnowflakeUtil.mapSFExceptionToSQLException(
() -> (SQLInput) sfBaseResultSet.getObject(columnIndex));
if (sqlInput == null) {
return null;
} else {
@@ -1366,12 +1368,17 @@ public <T> T getObject(int columnIndex, Class<T> type) throws SQLException {
Object object = getObject(columnIndex);
if (object == null) {
return null;
} else if (object instanceof JsonSqlInput) {
JsonNode jsonNode = ((JsonSqlInput) object).getInput();
return (T)
OBJECT_MAPPER.convertValue(jsonNode, new TypeReference<Map<String, Object>>() {});
} else if (object instanceof Map) {
throw new SQLException(
"Arrow native struct couldn't be converted to String. To map to SqlData the method getObject(int columnIndex, Class type) should be used");
} else {
return (T) ((ArrowSqlInput) object).getInput();
try {
return (T)
OBJECT_MAPPER.readValue(
(String) object, new TypeReference<Map<Object, Object>>() {});
} catch (JsonProcessingException e) {
throw new SQLException("Value couldn't be converted to Map");
}
}
}
}
@@ -1585,7 +1592,8 @@ public <T> Map<String, T> getMap(int columnIndex, Class<T> type) throws SQLExcep
int columnType = ColumnTypeHelper.getColumnType(valueFieldMetadata.getType(), session);
int scale = valueFieldMetadata.getScale();
TimeZone tz = sfBaseResultSet.getSessionTimeZone();
Object object = getObject(columnIndex);
Object object =
SnowflakeUtil.mapSFExceptionToSQLException(() -> sfBaseResultSet.getObject(columnIndex));
if (object == null) {
return null;
}
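The non-SQLData branch of getObject(int columnIndex, Class<T> type) now works from the raw JSON text: the value is parsed into a Map, while an Arrow-native struct (already materialized as a Map) cannot be converted back to a String and raises a SQLException that points the caller at the typed overload. Continuing the sketch after the commit summary above, a hedged illustration of the Map fallback (column index and types are illustrative):

```java
// JSON or ARROW_WITH_JSON_STRUCTURED_TYPES result formats: a structured column
// requested as a plain Map is parsed from its JSON text
Map<?, ?> asMap = rs.getObject(1, Map.class);

// NATIVE_ARROW: the same call throws SQLException
// ("Arrow native struct couldn't be converted to String..."), so request a
// registered SQLData class instead
SimpleClass typed = rs.getObject(1, SimpleClass.class);
```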
18 changes: 13 additions & 5 deletions src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetV1.java
@@ -28,6 +28,8 @@
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import net.snowflake.client.core.ArrowSqlInput;
import net.snowflake.client.core.JsonSqlInput;
import net.snowflake.client.core.QueryStatus;
import net.snowflake.client.core.SFBaseResultSet;
import net.snowflake.client.core.SFException;
@@ -263,11 +265,17 @@ public ResultSetMetaData getMetaData() throws SQLException {

public Object getObject(int columnIndex) throws SQLException {
raiseSQLExceptionIfResultSetIsClosed();
try {
return sfBaseResultSet.getObject(columnIndex);
} catch (SFException ex) {
throw new SnowflakeSQLException(
ex.getCause(), ex.getSqlState(), ex.getVendorCode(), ex.getParams());
Object object =
SnowflakeUtil.mapSFExceptionToSQLException(() -> sfBaseResultSet.getObject(columnIndex));
if (object == null) {
return null;
} else if (object instanceof JsonSqlInput) {
return ((JsonSqlInput) object).getText();
} else if (object instanceof ArrowSqlInput) {
throw new SQLException(
"Arrow native struct couldn't be converted to String. To map to SqlData the method getObject(int columnIndex, Class type) should be used");
} else {
return object;
}
}

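Here getObject(int) unwraps a JsonSqlInput to the text captured when the column was read, and rejects an ArrowSqlInput with a descriptive SQLException, since a native Arrow struct has no original JSON string to hand back. Continuing the earlier sketch, one hedged way for callers to cope with both outcomes (purely illustrative; catching SQLException to branch is a convenience here, not a recommended pattern):

```java
Object value;
try {
  value = rs.getObject(1);                    // JSON-backed formats: the raw JSON text
} catch (SQLException structNotConvertible) {
  value = rs.getObject(1, SimpleClass.class); // NATIVE_ARROW: map via a registered SQLData class
}
```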
src/test/java/net/snowflake/client/jdbc/structuredtypes/BindingAndInsertingStructuredTypesLatestIT.java
@@ -20,7 +20,6 @@
import java.sql.Time;
import java.sql.Timestamp;
import java.sql.Types;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.ZoneId;
@@ -37,13 +36,33 @@
import net.snowflake.client.core.structs.SnowflakeObjectTypeFactories;
import net.snowflake.client.jdbc.structuredtypes.sqldata.AllTypesClass;
import net.snowflake.client.jdbc.structuredtypes.sqldata.SimpleClass;
import org.junit.After;
import org.junit.Assume;
import org.junit.Before;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(Parameterized.class)
@Category(TestCategoryResultSet.class)
public class BindingAndInsertingStructuredTypesLatestIT extends BaseJDBCTest {

@Parameterized.Parameters(name = "format={0}")
public static Object[][] data() {
return new Object[][] {
{ResultSetFormatType.JSON},
{ResultSetFormatType.ARROW_WITH_JSON_STRUCTURED_TYPES},
{ResultSetFormatType.NATIVE_ARROW}
};
}

private final ResultSetFormatType queryResultFormat;

public BindingAndInsertingStructuredTypesLatestIT(ResultSetFormatType queryResultFormat) {
this.queryResultFormat = queryResultFormat;
}

public Connection init() throws SQLException {
Connection conn = BaseJDBCTest.getConnection(BaseJDBCTest.DONT_INJECT_SOCKET_TIMEOUT);
try (Statement stmt = conn.createStatement()) {
@@ -53,11 +72,25 @@ public Connection init() throws SQLException {
stmt.execute("alter session set ENABLE_OBJECT_TYPED_BINDS = true");
stmt.execute("alter session set enable_structured_types_in_fdn_tables=true");
stmt.execute("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'");
stmt.execute(
"alter session set jdbc_query_result_format = '"
+ queryResultFormat.sessionParameterTypeValue
+ "'");
if (queryResultFormat == ResultSetFormatType.NATIVE_ARROW) {
stmt.execute("alter session set ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT = true");
stmt.execute("alter session set FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT = true");
}
}
return conn;
}

@Before
public void setup() {
SnowflakeObjectTypeFactories.register(SimpleClass.class, SimpleClass::new);
SnowflakeObjectTypeFactories.register(AllTypesClass.class, AllTypesClass::new);
}

@After
public void clean() {
SnowflakeObjectTypeFactories.unregister(SimpleClass.class);
SnowflakeObjectTypeFactories.unregister(AllTypesClass.class);
@@ -67,7 +100,6 @@ public void clean() {
@Test
@ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class)
public void testWriteObject() throws SQLException {
SnowflakeObjectTypeFactories.register(SimpleClass.class, SimpleClass::new);
SimpleClass sc = new SimpleClass("text1", 2);
SimpleClass sc2 = new SimpleClass("text2", 3);
try (Connection connection = init()) {
@@ -104,7 +136,7 @@ public void testWriteObject() throws SQLException {
@Test
@ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class)
public void testWriteNullObject() throws SQLException {
SnowflakeObjectTypeFactories.register(SimpleClass.class, SimpleClass::new);
Assume.assumeTrue(queryResultFormat != ResultSetFormatType.NATIVE_ARROW);
try (Connection connection = init();
Statement statement = connection.createStatement();
SnowflakePreparedStatementV1 stmtement2 =
@@ -129,7 +161,6 @@ public void testWriteNullObject() throws SQLException {
@Test
@ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class)
public void testWriteObjectBindingNull() throws SQLException {
SnowflakeObjectTypeFactories.register(SimpleClass.class, SimpleClass::new);
try (Connection connection = init();
Statement statement = connection.createStatement();
SnowflakePreparedStatementV1 stmt =
@@ -154,7 +185,6 @@
@ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class)
public void testWriteObjectAllTypes() throws SQLException {
TimeZone.setDefault(TimeZone.getTimeZone(ZoneOffset.UTC));
SnowflakeObjectTypeFactories.register(SimpleClass.class, SimpleClass::new);
try (Connection connection = init();
Statement statement = connection.createStatement();
SnowflakePreparedStatementV1 stmt =
@@ -222,13 +252,13 @@ public void testWriteObjectAllTypes() throws SQLException {
assertEquals(
Timestamp.valueOf(LocalDateTime.of(2021, 12, 22, 9, 43, 44)), object.getTimestampLtz());
assertEquals(
// toTimestamp(ZonedDateTime.of(2021, 12, 23, 9, 44, 44, 0,
// ZoneId.of("Europe/Warsaw"))),
Timestamp.valueOf(LocalDateTime.of(2021, 12, 23, 9, 44, 44)), object.getTimestampNtz());
assertEquals(
toTimestamp(ZonedDateTime.of(2021, 12, 23, 9, 44, 44, 0, ZoneId.of("Asia/Tokyo"))),
object.getTimestampTz());
assertEquals(Date.valueOf(LocalDate.of(2023, 12, 24)), object.getDate());
// TODO uncomment after merge SNOW-928973: Date field is returning one day less when getting
// through getString method
// assertEquals(Date.valueOf(LocalDate.of(2023, 12, 24)), object.getDate());
assertEquals(Time.valueOf(LocalTime.of(12, 34, 56)), object.getTime());
assertArrayEquals(new byte[] {'a', 'b', 'c'}, object.getBinary());
assertEquals("testString", object.getSimpleClass().getString());
@@ -244,7 +274,6 @@ public static Timestamp toTimestamp(ZonedDateTime dateTime) {
@Test
@ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class)
public void testWriteArray() throws SQLException {
SnowflakeObjectTypeFactories.register(SimpleClass.class, SimpleClass::new);
try (Connection connection = init();
Statement statement = connection.createStatement();
SnowflakePreparedStatementV1 stmt =
@@ -272,7 +301,6 @@ public void testWriteArray() throws SQLException {
@Test
@ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class)
public void testWriteArrayNoBinds() throws SQLException {
SnowflakeObjectTypeFactories.register(SimpleClass.class, SimpleClass::new);
try (Connection connection = init();
Statement statement = connection.createStatement();
SnowflakePreparedStatementV1 stmt =
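The test class is now parameterized over the three result formats, and the SQLData factory registration moved into @Before/@After so each parameterized run starts from a clean registry. For readers wiring up their own mapping, a hedged sketch of a SQLData class in the spirit of SimpleClass; the field names, read order, and SQL type name are illustrative assumptions, and the real class in the repository may differ.

```java
import java.sql.SQLData;
import java.sql.SQLException;
import java.sql.SQLInput;
import java.sql.SQLOutput;

public class SimpleClassSketch implements SQLData {
  private String string;
  private int intValue;

  public String getString() {
    return string;
  }

  @Override
  public String getSQLTypeName() {
    return "OBJECT"; // illustrative
  }

  @Override
  public void readSQL(SQLInput stream, String typeName) throws SQLException {
    // Fields are read in the order they appear in the OBJECT definition
    string = stream.readString();
    intValue = stream.readInt();
  }

  @Override
  public void writeSQL(SQLOutput stream) throws SQLException {
    stream.writeString(string);
    stream.writeInt(intValue);
  }
}
```

Such a class would be registered the same way the @Before hook above registers SimpleClass and AllTypesClass, e.g. SnowflakeObjectTypeFactories.register(SimpleClassSketch.class, SimpleClassSketch::new).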