diff --git a/presto-postgresql/src/main/java/io/prestosql/plugin/postgresql/PostgreSqlClient.java b/presto-postgresql/src/main/java/io/prestosql/plugin/postgresql/PostgreSqlClient.java index 412e60b499a16..2f63eeabcaef4 100644 --- a/presto-postgresql/src/main/java/io/prestosql/plugin/postgresql/PostgreSqlClient.java +++ b/presto-postgresql/src/main/java/io/prestosql/plugin/postgresql/PostgreSqlClient.java @@ -37,11 +37,14 @@ import io.prestosql.plugin.jdbc.StatsCollecting; import io.prestosql.plugin.jdbc.WriteMapping; import io.prestosql.spi.PrestoException; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.connector.ConnectorTableMetadata; import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.connector.TableNotFoundException; import io.prestosql.spi.type.ArrayType; +import io.prestosql.spi.type.MapType; import io.prestosql.spi.type.StandardTypes; import io.prestosql.spi.type.TinyintType; import io.prestosql.spi.type.Type; @@ -107,6 +110,7 @@ public class PostgreSqlClient private final Type jsonType; private final Type uuidType; + private final MapType varcharMapType; private final boolean supportArrays; @Inject @@ -119,6 +123,7 @@ public PostgreSqlClient( super(config, "\"", connectionFactory); this.jsonType = typeManager.getType(new TypeSignature(StandardTypes.JSON)); this.uuidType = typeManager.getType(new TypeSignature(StandardTypes.UUID)); + this.varcharMapType = (MapType) typeManager.getType(TypeSignature.parseTypeSignature("map(varchar, varchar)")); switch (postgreSqlConfig.getArrayMapping()) { case DISABLED: @@ -265,6 +270,8 @@ public Optional toPrestoType(ConnectorSession session, Connection case "timestamptz": // PostgreSQL's "timestamp with time zone" is reported as Types.TIMESTAMP rather than Types.TIMESTAMP_WITH_TIMEZONE return Optional.of(timestampWithTimeZoneColumnMapping()); + case "hstore": + return Optional.of(hstoreColumnMapping()); } if (typeHandle.getJdbcType() == Types.VARCHAR && !jdbcTypeName.equals("varchar")) { // This can be e.g. an ENUM @@ -361,6 +368,39 @@ private static LongWriteFunction timestampWithTimeZoneWriteFunction() }; } + private ColumnMapping hstoreColumnMapping() + { + return ColumnMapping.blockMapping( + varcharMapType, + varcharMapReadFunction(), + (statement, index, block) -> { throw new PrestoException(NOT_SUPPORTED, "PosgtreSQL hstore write is not supported"); }, + DISABLE_PUSHDOWN); + } + + private BlockReadFunction varcharMapReadFunction() + { + return (resultSet, columnIndex) -> { + @SuppressWarnings("unchecked") + Map map = (Map) resultSet.getObject(columnIndex); + BlockBuilder keyBlockBuilder = varcharMapType.getKeyType().createBlockBuilder(null, map.size()); + BlockBuilder valueBlockBuilder = varcharMapType.getValueType().createBlockBuilder(null, map.size()); + for (Map.Entry entry : map.entrySet()) { + if (entry.getKey() == null) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "hstore key is null"); + } + varcharMapType.getKeyType().writeSlice(keyBlockBuilder, utf8Slice(entry.getKey())); + if (entry.getValue() == null) { + valueBlockBuilder.appendNull(); + } + else { + varcharMapType.getValueType().writeSlice(valueBlockBuilder, utf8Slice(entry.getValue())); + } + } + return varcharMapType.createBlockFromKeyValue(Optional.empty(), new int[] {0, map.size()}, keyBlockBuilder.build(), valueBlockBuilder.build()) + .getObject(0, Block.class); + }; + } + private static ColumnMapping arrayColumnMapping(ConnectorSession session, ArrayType arrayType, String elementJdbcTypeName) { return ColumnMapping.blockMapping( diff --git a/presto-postgresql/src/test/java/io/prestosql/plugin/postgresql/TestPostgreSqlTypeMapping.java b/presto-postgresql/src/test/java/io/prestosql/plugin/postgresql/TestPostgreSqlTypeMapping.java index b8eade5912fb1..5b63e9e835d2b 100644 --- a/presto-postgresql/src/test/java/io/prestosql/plugin/postgresql/TestPostgreSqlTypeMapping.java +++ b/presto-postgresql/src/test/java/io/prestosql/plugin/postgresql/TestPostgreSqlTypeMapping.java @@ -41,9 +41,11 @@ import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; +import java.util.Collections; import java.util.List; -import java.util.UUID; +import java.util.Map; import java.util.function.Function; +import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; @@ -52,6 +54,7 @@ import static io.prestosql.spi.type.SmallintType.SMALLINT; import static io.prestosql.spi.type.TimeZoneKey.UTC_KEY; import static io.prestosql.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; +import static io.prestosql.spi.type.TypeSignature.parseTypeSignature; import static io.prestosql.spi.type.VarbinaryType.VARBINARY; import static io.prestosql.tests.datatype.DataType.bigintDataType; import static io.prestosql.tests.datatype.DataType.booleanDataType; @@ -76,6 +79,7 @@ import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static java.util.function.Function.identity; +import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; @Test @@ -148,6 +152,9 @@ public void setUp() timeGapInKathmandu = LocalDateTime.of(1986, 1, 1, 0, 13, 7); checkIsGap(kathmandu, timeGapInKathmandu); + + JdbcSqlExecutor executor = new JdbcSqlExecutor(postgreSqlServer.getJdbcUrl()); + executor.execute("CREATE EXTENSION hstore"); } @Test @@ -656,6 +663,19 @@ private DataTypeTest jsonTestCases(DataType jsonDataType) .addRoundTrip(jsonDataType, "[]"); } + @Test + public void testHstore() + { + DataTypeTest.create() + .addRoundTrip(hstoreDataType(), null) + .addRoundTrip(hstoreDataType(), ImmutableMap.of()) + .addRoundTrip(hstoreDataType(), ImmutableMap.of("key1", "value1")) + .addRoundTrip(hstoreDataType(), ImmutableMap.of("key1", "value1", "key2", "value2", "key3", "value3")) + .addRoundTrip(hstoreDataType(), ImmutableMap.of("key1", " \" ", "key2", " ' ", "key3", " ]) ")) + .addRoundTrip(hstoreDataType(), Collections.singletonMap("key1", null)) + .execute(getQueryRunner(), postgresCreateAndInsert("tpch.postgresql_test_hstore")); + } + @Test public void testUuid() { @@ -714,6 +734,23 @@ public static DataType jsonbDataType() identity()); } + private DataType> hstoreDataType() + { + return dataType( + "hstore", + getQueryRunner().getMetadata().getType(parseTypeSignature("map(varchar, varchar)")), + value -> value.entrySet().stream() + .flatMap(entry -> Stream.of(entry.getKey(), entry.getValue())) + .map(string -> { + if (string == null) { + return "null"; + } + return DataType.formatStringLiteral(string); + }) + .collect(joining(",", "hstore(ARRAY[", "]::varchar[])")), + identity()); + } + public static DataType uuidDataType() { return dataType( diff --git a/presto-tests/src/main/java/io/prestosql/tests/TestingPrestoClient.java b/presto-tests/src/main/java/io/prestosql/tests/TestingPrestoClient.java index 89fadb29cfcb6..adbcea1ebcfa6 100644 --- a/presto-tests/src/main/java/io/prestosql/tests/TestingPrestoClient.java +++ b/presto-tests/src/main/java/io/prestosql/tests/TestingPrestoClient.java @@ -44,13 +44,13 @@ import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.OptionalLong; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -238,10 +238,12 @@ else if (type instanceof ArrayType) { .collect(toList()); } else if (type instanceof MapType) { - return ((Map) value).entrySet().stream() - .collect(Collectors.toMap( - e -> convertToRowValue(((MapType) type).getKeyType(), e.getKey()), - e -> convertToRowValue(((MapType) type).getValueType(), e.getValue()))); + Map result = new HashMap<>(); + ((Map) value) + .forEach((k, v) -> result.put( + convertToRowValue(((MapType) type).getKeyType(), k), + convertToRowValue(((MapType) type).getValueType(), v))); + return result; } else if (type instanceof DecimalType) { return new BigDecimal((String) value);