diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ColumnReaderFactory.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ColumnReaderFactory.java
index 3c58f36a15935..ebafcaa95f05e 100644
--- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ColumnReaderFactory.java
+++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ColumnReaderFactory.java
@@ -83,8 +83,6 @@ public final class ColumnReaderFactory
 {
-    private static final int MAX_INT_DIGITS = 10;
-
     private ColumnReaderFactory() {}
 
     public static ColumnReader create(PrimitiveField field, DateTimeZone timeZone, AggregatedMemoryContext aggregatedMemoryContext, ParquetReaderOptions options)
@@ -220,10 +218,8 @@ public static ColumnReader create(PrimitiveField field, DateTimeZone timeZone, A
             }
             if (type instanceof DecimalType decimalType && decimalType.isShort()
                     && isIntegerOrDecimalPrimitive(primitiveType)) {
-                if (decimalType.getScale() == 0 && decimalType.getPrecision() >= MAX_INT_DIGITS
-                        && primitiveType == INT32
-                        && isIntegerAnnotation(annotation)) {
-                    return createColumnReader(field, TransformingValueDecoders::getInt32ToLongDecoder, LONG_ADAPTER, memoryContext);
+                if (primitiveType == INT32 && isIntegerAnnotation(annotation)) {
+                    return createColumnReader(field, TransformingValueDecoders::getInt32ToShortDecimalDecoder, LONG_ADAPTER, memoryContext);
                 }
                 if (!(annotation instanceof DecimalLogicalTypeAnnotation decimalAnnotation)) {
                     throw unsupportedException(type, field);
@@ -270,7 +266,12 @@ && isIntegerOrDecimalPrimitive(primitiveType)) {
 
         return switch (primitiveType) {
             case BOOLEAN -> new BooleanColumnReader(field);
-            case INT32 -> createDecimalColumnReader(field).orElse(new IntColumnReader(field));
+            case INT32 -> createDecimalColumnReader(field).orElseGet(() -> {
+                if (type instanceof DecimalType decimalType && decimalType.isShort()) {
+                    return new Int32ShortDecimalColumnReader(field);
+                }
+                return new IntColumnReader(field);
+            });
             case INT64 -> {
                 if (annotation instanceof TimeLogicalTypeAnnotation timeAnnotation) {
                     if (field.getType() instanceof TimeType && timeAnnotation.getUnit() == MICROS) {
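The factory change above removes the old narrow carve-out (scale 0 and precision >= 10, previously gated by MAX_INT_DIGITS) and routes every INT32-to-short-DECIMAL mapping through a read path that validates each value against the target precision. The rule being enforced is the one implemented by io.trino.spi.type.Decimals.overflows(long, int); the following is a minimal self-contained sketch of that rule, with class and method names of our own choosing:

    public class PrecisionCheckSketch
    {
        // Mirrors the digit-bound rule: a value overflows DECIMAL(precision, s)
        // when its magnitude has more than 'precision' decimal digits.
        // Sufficient here since inputs come from INT32 and precision <= 18.
        static boolean overflows(long unscaledValue, int precision)
        {
            long bound = 1;
            for (int i = 0; i < precision; i++) {
                bound *= 10; // bound = 10^precision; fits in a long for precision <= 18
            }
            return Math.abs(unscaledValue) >= bound;
        }

        public static void main(String[] args)
        {
            System.out.println(overflows(Integer.MAX_VALUE, 9));  // true: 2147483647 has 10 digits
            System.out.println(overflows(Integer.MAX_VALUE, 10)); // false: fits in DECIMAL(10, 0)
        }
    }
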
diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/Int32ShortDecimalColumnReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/Int32ShortDecimalColumnReader.java
new file mode 100644
index 0000000000000..cfd67252662b8
--- /dev/null
+++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/Int32ShortDecimalColumnReader.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.parquet.reader;
+
+import io.trino.parquet.PrimitiveField;
+import io.trino.spi.TrinoException;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.type.DecimalType;
+import io.trino.spi.type.Type;
+
+import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
+import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
+import static io.trino.spi.type.Decimals.overflows;
+import static java.lang.String.format;
+
+public class Int32ShortDecimalColumnReader
+        extends PrimitiveColumnReader
+{
+    public Int32ShortDecimalColumnReader(PrimitiveField field)
+    {
+        super(field);
+    }
+
+    @Override
+    protected void readValue(BlockBuilder blockBuilder, Type trinoType)
+    {
+        if (trinoType instanceof DecimalType trinoDecimalType && trinoDecimalType.isShort()) {
+            long value = valuesReader.readInteger();
+            if (overflows(value, trinoDecimalType.getPrecision())) {
+                throw new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot read parquet INT32 value '%s' as DECIMAL(%s, %s)", value, trinoDecimalType.getPrecision(), trinoDecimalType.getScale()));
+            }
+
+            trinoType.writeLong(blockBuilder, value);
+        }
+        else {
+            throw new TrinoException(NOT_SUPPORTED, format("Unsupported Trino column type (%s) for Parquet column (%s)", trinoType, field.getDescriptor()));
+        }
+    }
+}
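Note that the new reader never rescales: the INT32 is written verbatim as the short decimal's unscaled value, and the column's declared scale only determines where the decimal point falls on read. A plain-Java illustration of that semantics (java.math only, not Trino API), matching the DECIMAL(10, 1) expectation in the AbstractTestParquetReader test further down:

    import java.math.BigDecimal;
    import java.math.BigInteger;

    public class UnscaledValueDemo
    {
        public static void main(String[] args)
        {
            BigInteger unscaled = BigInteger.valueOf(Integer.MAX_VALUE);
            // The same unscaled value surfaces differently depending on the declared scale:
            System.out.println(new BigDecimal(unscaled, 0)); // DECIMAL(10, 0) -> 2147483647
            System.out.println(new BigDecimal(unscaled, 1)); // DECIMAL(10, 1) -> 214748364.7
        }
    }
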
diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/decoders/TransformingValueDecoders.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/decoders/TransformingValueDecoders.java
index 018911c4e4e23..f944bd3ff9f1e 100644
--- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/decoders/TransformingValueDecoders.java
+++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/decoders/TransformingValueDecoders.java
@@ -18,6 +18,7 @@
 import io.trino.parquet.PrimitiveField;
 import io.trino.parquet.reader.SimpleSliceInputStream;
 import io.trino.parquet.reader.flat.BinaryBuffer;
+import io.trino.spi.TrinoException;
 import io.trino.spi.type.DecimalConversions;
 import io.trino.spi.type.DecimalType;
 import io.trino.spi.type.Decimals;
@@ -44,8 +45,10 @@
 import static io.trino.parquet.reader.decoders.ValueDecoders.getRealDecoder;
 import static io.trino.parquet.reader.decoders.ValueDecoders.getShortDecimalDecoder;
 import static io.trino.parquet.reader.flat.Int96ColumnAdapter.Int96Buffer;
+import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
 import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone;
 import static io.trino.spi.type.Decimals.longTenToNth;
+import static io.trino.spi.type.Decimals.overflows;
 import static io.trino.spi.type.TimeZoneKey.UTC_KEY;
 import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_MILLISECOND;
 import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND;
@@ -58,6 +61,7 @@
 import static java.lang.Math.floorDiv;
 import static java.lang.Math.floorMod;
 import static java.lang.Math.toIntExact;
+import static java.lang.String.format;
 import static java.util.Objects.requireNonNull;
 import static org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
@@ -663,6 +667,39 @@ public void skip(int n)
         };
     }
 
+    public static ValueDecoder<long[]> getInt32ToShortDecimalDecoder(ParquetEncoding encoding, PrimitiveField field)
+    {
+        DecimalType decimalType = (DecimalType) field.getType();
+        ValueDecoder<int[]> delegate = getInt32Decoder(encoding, field);
+        return new ValueDecoder<>()
+        {
+            @Override
+            public void init(SimpleSliceInputStream input)
+            {
+                delegate.init(input);
+            }
+
+            @Override
+            public void read(long[] values, int offset, int length)
+            {
+                int[] buffer = new int[length];
+                delegate.read(buffer, 0, length);
+                for (int i = 0; i < length; i++) {
+                    if (overflows(buffer[i], decimalType.getPrecision())) {
+                        throw new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot read parquet INT32 value '%s' as DECIMAL(%s, %s)", buffer[i], decimalType.getPrecision(), decimalType.getScale()));
+                    }
+                    values[i + offset] = buffer[i];
+                }
+            }
+
+            @Override
+            public void skip(int n)
+            {
+                delegate.skip(n);
+            }
+        };
+    }
+
     public static ValueDecoder<long[]> getInt32ToLongDecoder(ParquetEncoding encoding, PrimitiveField field)
     {
         ValueDecoder<int[]> delegate = getInt32Decoder(encoding, field);
diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestingColumnReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestingColumnReader.java
index 65f3d0c072021..bb84dd2316c7b 100644
--- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestingColumnReader.java
+++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestingColumnReader.java
@@ -651,6 +651,8 @@ private static ColumnReaderFormat<?>[] columnReaders()
                 new ColumnReaderFormat<>(FLOAT, DoubleType.DOUBLE, PLAIN_WRITER, DICTIONARY_FLOAT_WRITER, WRITE_FLOAT, ASSERT_DOUBLE_STORED_AS_FLOAT),
                 new ColumnReaderFormat<>(DOUBLE, DoubleType.DOUBLE, PLAIN_WRITER, DICTIONARY_DOUBLE_WRITER, WRITE_DOUBLE, ASSERT_DOUBLE),
                 new ColumnReaderFormat<>(INT32, decimalType(0, 8), createDecimalType(8), PLAIN_WRITER, DICTIONARY_INT_WRITER, WRITE_INT, ASSERT_INT),
+                // INT32 can be read as a short DecimalType in Trino even without a decimal logical type annotation
+                new ColumnReaderFormat<>(INT32, createDecimalType(8, 0), PLAIN_WRITER, DICTIONARY_INT_WRITER, WRITE_INT, ASSERT_INT),
                 new ColumnReaderFormat<>(INT32, BIGINT, PLAIN_WRITER, DICTIONARY_INT_WRITER, WRITE_INT, ASSERT_LONG),
                 new ColumnReaderFormat<>(INT32, INTEGER, PLAIN_WRITER, DICTIONARY_INT_WRITER, WRITE_INT, ASSERT_INT),
                 new ColumnReaderFormat<>(INT32, SMALLINT, PLAIN_WRITER, DICTIONARY_INT_WRITER, WRITE_SHORT, ASSERT_SHORT),
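getInt32ToShortDecimalDecoder above follows the same wrap-and-widen shape as the neighboring getInt32ToLongDecoder: delegate the raw decoding to the plain INT32 decoder into a scratch int[], then widen into the caller's long[], here with a per-element precision guard added. A self-contained sketch of that structure, using hypothetical stand-ins for ValueDecoder<int[]> and ValueDecoder<long[]>:

    import java.util.Arrays;

    public class WrapAndWidenSketch
    {
        interface IntDecoder { void read(int[] out, int offset, int length); }
        interface LongDecoder { void read(long[] out, int offset, int length); }

        static LongDecoder widening(IntDecoder delegate, int precision)
        {
            long bound = 1;
            for (int i = 0; i < precision; i++) {
                bound *= 10; // 10^precision, the smallest magnitude with too many digits
            }
            long limit = bound; // effectively final copy for capture by the lambda
            return (out, offset, length) -> {
                int[] buffer = new int[length]; // decode a whole batch into a scratch buffer
                delegate.read(buffer, 0, length);
                for (int i = 0; i < length; i++) {
                    if (Math.abs((long) buffer[i]) >= limit) {
                        throw new ArithmeticException("too many digits for precision: " + buffer[i]);
                    }
                    out[offset + i] = buffer[i]; // widen int -> long unscaled value
                }
            };
        }

        public static void main(String[] args)
        {
            IntDecoder source = (out, offset, length) -> Arrays.fill(out, offset, offset + length, 9999);
            long[] values = new long[4];
            widening(source, 4).read(values, 0, 4); // 9999 fits DECIMAL(4, 0)
            System.out.println(Arrays.toString(values));
        }
    }
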
diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/AbstractTestParquetReader.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/AbstractTestParquetReader.java
index fb35fc78f3895..c8da88c636c18 100644
--- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/AbstractTestParquetReader.java
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/AbstractTestParquetReader.java
@@ -1192,6 +1192,49 @@ public void testParquetShortDecimalWriteToTrinoTinyBlockWithNonZeroScale()
                 .isInstanceOf(TrinoException.class);
     }
 
+    @Test
+    public void testReadParquetInt32AsTrinoShortDecimal()
+            throws Exception
+    {
+        Iterable<Integer> writeValues = intsBetween(0, 31_234);
+        Optional<MessageType> parquetSchema = Optional.of(parseMessageType("message hive_decimal { optional INT32 test; }"));
+        // Read INT32 as a short decimal of precision >= 10 with zero scale
+        tester.testRoundTrip(
+                javaIntObjectInspector,
+                writeValues,
+                transform(writeValues, value -> new SqlDecimal(BigInteger.valueOf(value), 10, 0)),
+                createDecimalType(10),
+                parquetSchema);
+
+        // Read INT32 as a short decimal of precision >= 10 with non-zero scale
+        tester.testRoundTrip(
+                javaIntObjectInspector,
+                ImmutableList.of(Integer.MAX_VALUE),
+                ImmutableList.of(new SqlDecimal(BigInteger.valueOf(Integer.MAX_VALUE), 10, 1)),
+                createDecimalType(10, 1),
+                parquetSchema);
+
+        // Read INT32 as a short decimal if value is within supported precision
+        tester.testRoundTrip(
+                javaIntObjectInspector,
+                ImmutableList.of(9999),
+                ImmutableList.of(new SqlDecimal(BigInteger.valueOf(9999), 4, 0)),
+                createDecimalType(4, 0),
+                parquetSchema);
+
+        // Cannot read INT32 as a short decimal if value exceeds supported precision
+        assertThatThrownBy(() -> tester.assertRoundTripWithHiveWriter(
+                List.of(javaIntObjectInspector),
+                new Iterable<?>[] {ImmutableList.of(Integer.MAX_VALUE)},
+                new Iterable<?>[] {ImmutableList.of(new SqlDecimal(BigInteger.valueOf(Integer.MAX_VALUE), 9, 0))},
+                List.of("test"),
+                List.of(createDecimalType(9, 0)),
+                parquetSchema,
+                ParquetSchemaOptions.defaultOptions()))
+                .hasMessage("Cannot read parquet INT32 value '2147483647' as DECIMAL(9, 0)")
+                .isInstanceOf(TrinoException.class);
+    }
+
     @Test(dataProvider = "timestampPrecision")
     public void testTimestamp(HiveTimestampPrecision precision)
             throws Exception
diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java
index e98e82ebe9968..34be2e686b80b 100644
--- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java
@@ -356,6 +356,41 @@ void assertRoundTrip(
             Optional<MessageType> parquetSchema,
             ParquetSchemaOptions schemaOptions)
             throws Exception
+    {
+        assertRoundTripWithHiveWriter(objectInspectors, writeValues, readValues, columnNames, columnTypes, parquetSchema, schemaOptions);
+
+        // write Trino parquet
+        for (CompressionCodec compressionCodec : writerCompressions) {
+            for (ConnectorSession session : sessions) {
+                try (TempFile tempFile = new TempFile("test", "parquet")) {
+                    OptionalInt min = stream(writeValues).mapToInt(Iterables::size).min();
+                    checkState(min.isPresent());
+                    writeParquetColumnTrino(tempFile.getFile(), columnTypes, columnNames, getIterators(readValues), min.getAsInt(), compressionCodec, schemaOptions);
+                    assertFileContents(
+                            session,
+                            tempFile.getFile(),
+                            getIterators(readValues),
+                            columnNames,
+                            columnTypes);
+                }
+            }
+        }
+    }
+
+    // Certain tests need the ability to specify a parquet schema which the writer wouldn't choose by itself based on the engine type.
+    // Explicitly provided parquetSchema is supported only by the hive writer.
+    // This method should be used when we need to assert that an exception is thrown when reading a file written with the specified
+    // parquetSchema, to avoid being misled by an exception thrown while reading the file produced by the Trino parquet writer, which
+    // may not follow the specified parquetSchema.
+    void assertRoundTripWithHiveWriter(
+            List<ObjectInspector> objectInspectors,
+            Iterable<?>[] writeValues,
+            Iterable<?>[] readValues,
+            List<String> columnNames,
+            List<Type> columnTypes,
+            Optional<MessageType> parquetSchema,
+            ParquetSchemaOptions schemaOptions)
+            throws Exception
     {
         for (WriterVersion version : versions) {
             for (CompressionCodec compressionCodec : compressions) {
@@ -385,23 +420,6 @@ void assertRoundTrip(
                 }
             }
         }
-
-        // write Trino parquet
-        for (CompressionCodec compressionCodec : writerCompressions) {
-            for (ConnectorSession session : sessions) {
-                try (TempFile tempFile = new TempFile("test", "parquet")) {
-                    OptionalInt min = stream(writeValues).mapToInt(Iterables::size).min();
-                    checkState(min.isPresent());
-                    writeParquetColumnTrino(tempFile.getFile(), columnTypes, columnNames, getIterators(readValues), min.getAsInt(), compressionCodec, schemaOptions);
-                    assertFileContents(
-                            session,
-                            tempFile.getFile(),
-                            getIterators(readValues),
-                            columnNames,
-                            columnTypes);
-                }
-            }
-        }
     }
 
     void testMaxReadBytes(ObjectInspector objectInspector, Iterable<?> writeValues, Iterable<?> readValues, Type type, DataSize maxReadBlockSize)
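The ParquetTester change is a pure extraction: assertRoundTrip keeps its previous behavior by delegating to the new assertRoundTripWithHiveWriter and then running the Trino-writer pass that used to live inline, while tests that pin a hand-written parquetSchema (like the failing-precision case above) call the hive-writer variant directly so the expected read-side error is not masked by the Trino writer choosing a different physical type. The shape of the refactor in miniature (all names hypothetical):

    import java.util.List;

    public class ExtractedRoundTrip
    {
        static void assertRoundTripWithHiveWriter(List<Integer> values)
        {
            System.out.println("hive-writer pass (honors an explicit schema): " + values);
        }

        static void assertRoundTrip(List<Integer> values)
        {
            assertRoundTripWithHiveWriter(values);              // the extracted, schema-pinned pass
            System.out.println("trino-writer pass: " + values); // the pass callers can now skip
        }

        public static void main(String[] args)
        {
            assertRoundTrip(List.of(1, 2, 3));
        }
    }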