Skip to content

Commit

Permalink
Support reading Postgres hstore
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiafi authored and martint committed Jul 25, 2019
1 parent b4dd1c8 commit c39358e
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 6 deletions.
Expand Up @@ -37,11 +37,14 @@
import io.prestosql.plugin.jdbc.StatsCollecting;
import io.prestosql.plugin.jdbc.WriteMapping;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.connector.ConnectorTableMetadata;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.connector.TableNotFoundException;
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.MapType;
import io.prestosql.spi.type.StandardTypes;
import io.prestosql.spi.type.TinyintType;
import io.prestosql.spi.type.Type;
Expand Down Expand Up @@ -107,6 +110,7 @@ public class PostgreSqlClient

private final Type jsonType;
private final Type uuidType;
private final MapType varcharMapType;
private final boolean supportArrays;

@Inject
Expand All @@ -119,6 +123,7 @@ public PostgreSqlClient(
super(config, "\"", connectionFactory);
this.jsonType = typeManager.getType(new TypeSignature(StandardTypes.JSON));
this.uuidType = typeManager.getType(new TypeSignature(StandardTypes.UUID));
this.varcharMapType = (MapType) typeManager.getType(TypeSignature.parseTypeSignature("map(varchar, varchar)"));

switch (postgreSqlConfig.getArrayMapping()) {
case DISABLED:
Expand Down Expand Up @@ -265,6 +270,8 @@ public Optional<ColumnMapping> toPrestoType(ConnectorSession session, Connection
case "timestamptz":
// PostgreSQL's "timestamp with time zone" is reported as Types.TIMESTAMP rather than Types.TIMESTAMP_WITH_TIMEZONE
return Optional.of(timestampWithTimeZoneColumnMapping());
case "hstore":
return Optional.of(hstoreColumnMapping());
}
if (typeHandle.getJdbcType() == Types.VARCHAR && !jdbcTypeName.equals("varchar")) {
// This can be e.g. an ENUM
Expand Down Expand Up @@ -361,6 +368,39 @@ private static LongWriteFunction timestampWithTimeZoneWriteFunction()
};
}

private ColumnMapping hstoreColumnMapping()
{
return ColumnMapping.blockMapping(
varcharMapType,
varcharMapReadFunction(),
(statement, index, block) -> { throw new PrestoException(NOT_SUPPORTED, "PosgtreSQL hstore write is not supported"); },
DISABLE_PUSHDOWN);
}

private BlockReadFunction varcharMapReadFunction()
{
return (resultSet, columnIndex) -> {
@SuppressWarnings("unchecked")
Map<String, String> map = (Map<String, String>) resultSet.getObject(columnIndex);
BlockBuilder keyBlockBuilder = varcharMapType.getKeyType().createBlockBuilder(null, map.size());
BlockBuilder valueBlockBuilder = varcharMapType.getValueType().createBlockBuilder(null, map.size());
for (Map.Entry<String, String> entry : map.entrySet()) {
if (entry.getKey() == null) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "hstore key is null");
}
varcharMapType.getKeyType().writeSlice(keyBlockBuilder, utf8Slice(entry.getKey()));
if (entry.getValue() == null) {
valueBlockBuilder.appendNull();
}
else {
varcharMapType.getValueType().writeSlice(valueBlockBuilder, utf8Slice(entry.getValue()));
}
}
return varcharMapType.createBlockFromKeyValue(Optional.empty(), new int[] {0, map.size()}, keyBlockBuilder.build(), valueBlockBuilder.build())
.getObject(0, Block.class);
};
}

private static ColumnMapping arrayColumnMapping(ConnectorSession session, ArrayType arrayType, String elementJdbcTypeName)
{
return ColumnMapping.blockMapping(
Expand Down
Expand Up @@ -41,9 +41,11 @@
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
Expand All @@ -52,6 +54,7 @@
import static io.prestosql.spi.type.SmallintType.SMALLINT;
import static io.prestosql.spi.type.TimeZoneKey.UTC_KEY;
import static io.prestosql.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static io.prestosql.spi.type.TypeSignature.parseTypeSignature;
import static io.prestosql.spi.type.VarbinaryType.VARBINARY;
import static io.prestosql.tests.datatype.DataType.bigintDataType;
import static io.prestosql.tests.datatype.DataType.booleanDataType;
Expand All @@ -76,6 +79,7 @@
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;

@Test
Expand Down Expand Up @@ -148,6 +152,9 @@ public void setUp()

timeGapInKathmandu = LocalDateTime.of(1986, 1, 1, 0, 13, 7);
checkIsGap(kathmandu, timeGapInKathmandu);

JdbcSqlExecutor executor = new JdbcSqlExecutor(postgreSqlServer.getJdbcUrl());
executor.execute("CREATE EXTENSION hstore");
}

@Test
Expand Down Expand Up @@ -656,6 +663,19 @@ private DataTypeTest jsonTestCases(DataType<String> jsonDataType)
.addRoundTrip(jsonDataType, "[]");
}

@Test
public void testHstore()
{
DataTypeTest.create()
.addRoundTrip(hstoreDataType(), null)
.addRoundTrip(hstoreDataType(), ImmutableMap.of())
.addRoundTrip(hstoreDataType(), ImmutableMap.of("key1", "value1"))
.addRoundTrip(hstoreDataType(), ImmutableMap.of("key1", "value1", "key2", "value2", "key3", "value3"))
.addRoundTrip(hstoreDataType(), ImmutableMap.of("key1", " \" ", "key2", " ' ", "key3", " ]) "))
.addRoundTrip(hstoreDataType(), Collections.singletonMap("key1", null))
.execute(getQueryRunner(), postgresCreateAndInsert("tpch.postgresql_test_hstore"));
}

@Test
public void testUuid()
{
Expand Down Expand Up @@ -714,6 +734,23 @@ public static DataType<String> jsonbDataType()
identity());
}

private DataType<Map<String, String>> hstoreDataType()
{
return dataType(
"hstore",
getQueryRunner().getMetadata().getType(parseTypeSignature("map(varchar, varchar)")),
value -> value.entrySet().stream()
.flatMap(entry -> Stream.of(entry.getKey(), entry.getValue()))
.map(string -> {
if (string == null) {
return "null";
}
return DataType.formatStringLiteral(string);
})
.collect(joining(",", "hstore(ARRAY[", "]::varchar[])")),
identity());
}

public static DataType<java.util.UUID> uuidDataType()
{
return dataType(
Expand Down
Expand Up @@ -44,13 +44,13 @@
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeParseException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
Expand Down Expand Up @@ -238,10 +238,12 @@ else if (type instanceof ArrayType) {
.collect(toList());
}
else if (type instanceof MapType) {
return ((Map<Object, Object>) value).entrySet().stream()
.collect(Collectors.toMap(
e -> convertToRowValue(((MapType) type).getKeyType(), e.getKey()),
e -> convertToRowValue(((MapType) type).getValueType(), e.getValue())));
Map<Object, Object> result = new HashMap<>();
((Map<Object, Object>) value)
.forEach((k, v) -> result.put(
convertToRowValue(((MapType) type).getKeyType(), k),
convertToRowValue(((MapType) type).getValueType(), v)));
return result;
}
else if (type instanceof DecimalType) {
return new BigDecimal((String) value);
Expand Down

0 comments on commit c39358e

Please sign in to comment.