Fix reading parquet INT32 values as short decimal
The optimized parquet reader is improved to handle reading INT32
values into a short decimal of precision lower than 10, as long as
the value does not overflow the decimal.
The unoptimized parquet reader is fixed to apply overflow checks
when reading INT32 values into a short decimal.
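
For illustration, the overflow check that both readers now apply comes down to Decimals.overflows from trino-spi; a minimal sketch (the class name and main method are hypothetical, and the trino-spi artifact is assumed to be on the classpath):

import static io.trino.spi.type.Decimals.overflows;

// Hypothetical illustration class; not part of this commit
public class Int32ShortDecimalOverflowCheck
{
    public static void main(String[] args)
    {
        // DECIMAL(10, 0) holds any INT32, since 10^10 > Integer.MAX_VALUE = 2147483647
        System.out.println(overflows(Integer.MAX_VALUE, 10)); // false
        // DECIMAL(9, 0) does not: |2147483647| >= 10^9, so the reader throws INVALID_CAST_ARGUMENT
        System.out.println(overflows(Integer.MAX_VALUE, 9)); // true
        // 9999 fits in DECIMAL(4, 0), so the read succeeds
        System.out.println(overflows(9999L, 4)); // false
    }
}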
raunaqmorarka committed Apr 11, 2023
1 parent ccb06d5 commit f013793
Showing 6 changed files with 175 additions and 24 deletions.
@@ -83,8 +83,6 @@
 
 public final class ColumnReaderFactory
 {
-    private static final int MAX_INT_DIGITS = 10;
-
     private ColumnReaderFactory() {}
 
     public static ColumnReader create(PrimitiveField field, DateTimeZone timeZone, AggregatedMemoryContext aggregatedMemoryContext, ParquetReaderOptions options)
@@ -220,10 +218,8 @@ public static ColumnReader create(PrimitiveField field, DateTimeZone timeZone, AggregatedMemoryContext aggregatedMemoryContext, ParquetReaderOptions options)
             }
             if (type instanceof DecimalType decimalType && decimalType.isShort()
                     && isIntegerOrDecimalPrimitive(primitiveType)) {
-                if (decimalType.getScale() == 0 && decimalType.getPrecision() >= MAX_INT_DIGITS
-                        && primitiveType == INT32
-                        && isIntegerAnnotation(annotation)) {
-                    return createColumnReader(field, TransformingValueDecoders::getInt32ToLongDecoder, LONG_ADAPTER, memoryContext);
+                if (primitiveType == INT32 && isIntegerAnnotation(annotation)) {
+                    return createColumnReader(field, TransformingValueDecoders::getInt32ToShortDecimalDecoder, LONG_ADAPTER, memoryContext);
                 }
                 if (!(annotation instanceof DecimalLogicalTypeAnnotation decimalAnnotation)) {
                     throw unsupportedException(type, field);
@@ -270,7 +266,12 @@ && isIntegerOrDecimalPrimitive(primitiveType)) {
 
         return switch (primitiveType) {
             case BOOLEAN -> new BooleanColumnReader(field);
-            case INT32 -> createDecimalColumnReader(field).orElse(new IntColumnReader(field));
+            case INT32 -> createDecimalColumnReader(field).orElseGet(() -> {
+                if (type instanceof DecimalType decimalType && decimalType.isShort()) {
+                    return new Int32ShortDecimalColumnReader(field);
+                }
+                return new IntColumnReader(field);
+            });
             case INT64 -> {
                 if (annotation instanceof TimeLogicalTypeAnnotation timeAnnotation) {
                     if (field.getType() instanceof TimeType && timeAnnotation.getUnit() == MICROS) {
@@ -0,0 +1,50 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.parquet.reader;
+
+import io.trino.parquet.PrimitiveField;
+import io.trino.spi.TrinoException;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.type.DecimalType;
+import io.trino.spi.type.Type;
+
+import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
+import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
+import static io.trino.spi.type.Decimals.overflows;
+import static java.lang.String.format;
+
+public class Int32ShortDecimalColumnReader
+        extends PrimitiveColumnReader
+{
+    public Int32ShortDecimalColumnReader(PrimitiveField field)
+    {
+        super(field);
+    }
+
+    @Override
+    protected void readValue(BlockBuilder blockBuilder, Type trinoType)
+    {
+        if (trinoType instanceof DecimalType trinoDecimalType && trinoDecimalType.isShort()) {
+            long value = valuesReader.readInteger();
+            if (overflows(value, trinoDecimalType.getPrecision())) {
+                throw new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot read parquet INT32 value '%s' as DECIMAL(%s, %s)", value, trinoDecimalType.getPrecision(), trinoDecimalType.getScale()));
+            }
+
+            trinoType.writeLong(blockBuilder, value);
+        }
+        else {
+            throw new TrinoException(NOT_SUPPORTED, format("Unsupported Trino column type (%s) for Parquet column (%s)", trinoType, field.getDescriptor()));
+        }
+    }
+}
@@ -18,6 +18,7 @@
 import io.trino.parquet.PrimitiveField;
 import io.trino.parquet.reader.SimpleSliceInputStream;
 import io.trino.parquet.reader.flat.BinaryBuffer;
+import io.trino.spi.TrinoException;
 import io.trino.spi.type.DecimalConversions;
 import io.trino.spi.type.DecimalType;
 import io.trino.spi.type.Decimals;
@@ -44,8 +45,10 @@
 import static io.trino.parquet.reader.decoders.ValueDecoders.getRealDecoder;
 import static io.trino.parquet.reader.decoders.ValueDecoders.getShortDecimalDecoder;
 import static io.trino.parquet.reader.flat.Int96ColumnAdapter.Int96Buffer;
+import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
 import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone;
 import static io.trino.spi.type.Decimals.longTenToNth;
+import static io.trino.spi.type.Decimals.overflows;
 import static io.trino.spi.type.TimeZoneKey.UTC_KEY;
 import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_MILLISECOND;
 import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND;
@@ -58,6 +61,7 @@
 import static java.lang.Math.floorDiv;
 import static java.lang.Math.floorMod;
 import static java.lang.Math.toIntExact;
+import static java.lang.String.format;
 import static java.util.Objects.requireNonNull;
 import static org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
 
@@ -663,6 +667,39 @@ public void skip(int n)
         };
     }
 
+    public static ValueDecoder<long[]> getInt32ToShortDecimalDecoder(ParquetEncoding encoding, PrimitiveField field)
+    {
+        DecimalType decimalType = (DecimalType) field.getType();
+        ValueDecoder<int[]> delegate = getInt32Decoder(encoding, field);
+        return new ValueDecoder<>()
+        {
+            @Override
+            public void init(SimpleSliceInputStream input)
+            {
+                delegate.init(input);
+            }
+
+            @Override
+            public void read(long[] values, int offset, int length)
+            {
+                int[] buffer = new int[length];
+                delegate.read(buffer, 0, length);
+                for (int i = 0; i < length; i++) {
+                    if (overflows(buffer[i], decimalType.getPrecision())) {
+                        throw new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot read parquet INT32 value '%s' as DECIMAL(%s, %s)", buffer[i], decimalType.getPrecision(), decimalType.getScale()));
+                    }
+                    values[i + offset] = buffer[i];
+                }
+            }
+
+            @Override
+            public void skip(int n)
+            {
+                delegate.skip(n);
+            }
+        };
+    }
+
     public static ValueDecoder<long[]> getInt32ToLongDecoder(ParquetEncoding encoding, PrimitiveField field)
 {
     ValueDecoder<int[]> delegate = getInt32Decoder(encoding, field);
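
The new decoder follows a widen-with-check pattern: decode a batch into an int[] scratch buffer via the delegate, then copy each value into the caller's long[] after a range check. A dependency-free sketch of that pattern, with hypothetical names (WidenWithCheckSketch, IntBatchSource) that are not Trino API:

// Hypothetical, dependency-free sketch of the widen-with-check pattern; not part of this commit
public class WidenWithCheckSketch
{
    interface IntBatchSource
    {
        void read(int[] values, int offset, int length);
    }

    static void widenWithCheck(IntBatchSource source, long[] values, int offset, int length, int precision)
    {
        int[] buffer = new int[length];
        source.read(buffer, 0, length);
        long limit = (long) Math.pow(10, precision); // exact for precision <= 18
        for (int i = 0; i < length; i++) {
            // widen to long before abs() to avoid the Integer.MIN_VALUE overflow trap
            if (Math.abs((long) buffer[i]) >= limit) {
                throw new ArithmeticException("value " + buffer[i] + " overflows DECIMAL(" + precision + ", 0)");
            }
            values[offset + i] = buffer[i];
        }
    }
}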
@@ -651,6 +651,8 @@ private static ColumnReaderFormat<?>[] columnReaders()
                 new ColumnReaderFormat<>(FLOAT, DoubleType.DOUBLE, PLAIN_WRITER, DICTIONARY_FLOAT_WRITER, WRITE_FLOAT, ASSERT_DOUBLE_STORED_AS_FLOAT),
                 new ColumnReaderFormat<>(DOUBLE, DoubleType.DOUBLE, PLAIN_WRITER, DICTIONARY_DOUBLE_WRITER, WRITE_DOUBLE, ASSERT_DOUBLE),
                 new ColumnReaderFormat<>(INT32, decimalType(0, 8), createDecimalType(8), PLAIN_WRITER, DICTIONARY_INT_WRITER, WRITE_INT, ASSERT_INT),
+                // INT32 can be read as a short DecimalType in Trino even without a decimal logical type annotation
+                new ColumnReaderFormat<>(INT32, createDecimalType(8, 0), PLAIN_WRITER, DICTIONARY_INT_WRITER, WRITE_INT, ASSERT_INT),
                 new ColumnReaderFormat<>(INT32, BIGINT, PLAIN_WRITER, DICTIONARY_INT_WRITER, WRITE_INT, ASSERT_LONG),
                 new ColumnReaderFormat<>(INT32, INTEGER, PLAIN_WRITER, DICTIONARY_INT_WRITER, WRITE_INT, ASSERT_INT),
                 new ColumnReaderFormat<>(INT32, SMALLINT, PLAIN_WRITER, DICTIONARY_INT_WRITER, WRITE_SHORT, ASSERT_SHORT),
@@ -1192,6 +1192,49 @@ public void testParquetShortDecimalWriteToTrinoTinyBlockWithNonZeroScale()
                 .isInstanceOf(TrinoException.class);
     }
 
+    @Test
+    public void testReadParquetInt32AsTrinoShortDecimal()
+            throws Exception
+    {
+        Iterable<Integer> writeValues = intsBetween(0, 31_234);
+        Optional<MessageType> parquetSchema = Optional.of(parseMessageType("message hive_decimal { optional INT32 test; }"));
+        // Read INT32 as a short decimal of precision >= 10 with zero scale
+        tester.testRoundTrip(
+                javaIntObjectInspector,
+                writeValues,
+                transform(writeValues, value -> new SqlDecimal(BigInteger.valueOf(value), 10, 0)),
+                createDecimalType(10),
+                parquetSchema);
+
+        // Read INT32 as a short decimal of precision >= 10 with non-zero scale
+        tester.testRoundTrip(
+                javaIntObjectInspector,
+                ImmutableList.of(Integer.MAX_VALUE),
+                ImmutableList.of(new SqlDecimal(BigInteger.valueOf(Integer.MAX_VALUE), 10, 1)),
+                createDecimalType(10, 1),
+                parquetSchema);
+
+        // Read INT32 as a short decimal if value is within supported precision
+        tester.testRoundTrip(
+                javaIntObjectInspector,
+                ImmutableList.of(9999),
+                ImmutableList.of(new SqlDecimal(BigInteger.valueOf(9999), 4, 0)),
+                createDecimalType(4, 0),
+                parquetSchema);
+
+        // Cannot read INT32 as a short decimal if value exceeds supported precision
+        assertThatThrownBy(() -> tester.assertRoundTripWithHiveWriter(
+                List.of(javaIntObjectInspector),
+                new Iterable[] {ImmutableList.of(Integer.MAX_VALUE)},
+                new Iterable[] {ImmutableList.of(new SqlDecimal(BigInteger.valueOf(Integer.MAX_VALUE), 9, 0))},
+                List.of("test"),
+                List.of(createDecimalType(9, 0)),
+                parquetSchema,
+                ParquetSchemaOptions.defaultOptions()))
+                .hasMessage("Cannot read parquet INT32 value '2147483647' as DECIMAL(9, 0)")
+                .isInstanceOf(TrinoException.class);
+    }
+
     @Test(dataProvider = "timestampPrecision")
     public void testTimestamp(HiveTimestampPrecision precision)
             throws Exception
@@ -356,6 +356,41 @@ void assertRoundTrip(
             Optional<MessageType> parquetSchema,
             ParquetSchemaOptions schemaOptions)
             throws Exception
     {
+        assertRoundTripWithHiveWriter(objectInspectors, writeValues, readValues, columnNames, columnTypes, parquetSchema, schemaOptions);
+
+        // write Trino parquet
+        for (CompressionCodec compressionCodec : writerCompressions) {
+            for (ConnectorSession session : sessions) {
+                try (TempFile tempFile = new TempFile("test", "parquet")) {
+                    OptionalInt min = stream(writeValues).mapToInt(Iterables::size).min();
+                    checkState(min.isPresent());
+                    writeParquetColumnTrino(tempFile.getFile(), columnTypes, columnNames, getIterators(readValues), min.getAsInt(), compressionCodec, schemaOptions);
+                    assertFileContents(
+                            session,
+                            tempFile.getFile(),
+                            getIterators(readValues),
+                            columnNames,
+                            columnTypes);
+                }
+            }
+        }
+    }
+
+    // Certain tests need the ability to specify a parquet schema which the writer wouldn't choose by itself based on the engine type.
+    // An explicitly provided parquetSchema is supported only by the hive writer.
+    // This method should be used when we need to assert that an exception is thrown when reading from a file written with the specified
+    // parquetSchema, to avoid being misled by an exception thrown from reading a file produced by the trino parquet writer, which may
+    // not follow the specified parquetSchema.
+    void assertRoundTripWithHiveWriter(
+            List<ObjectInspector> objectInspectors,
+            Iterable<?>[] writeValues,
+            Iterable<?>[] readValues,
+            List<String> columnNames,
+            List<Type> columnTypes,
+            Optional<MessageType> parquetSchema,
+            ParquetSchemaOptions schemaOptions)
+            throws Exception
+    {
         for (WriterVersion version : versions) {
             for (CompressionCodec compressionCodec : compressions) {
@@ -385,23 +420,6 @@
                 }
             }
         }
-
-        // write Trino parquet
-        for (CompressionCodec compressionCodec : writerCompressions) {
-            for (ConnectorSession session : sessions) {
-                try (TempFile tempFile = new TempFile("test", "parquet")) {
-                    OptionalInt min = stream(writeValues).mapToInt(Iterables::size).min();
-                    checkState(min.isPresent());
-                    writeParquetColumnTrino(tempFile.getFile(), columnTypes, columnNames, getIterators(readValues), min.getAsInt(), compressionCodec, schemaOptions);
-                    assertFileContents(
-                            session,
-                            tempFile.getFile(),
-                            getIterators(readValues),
-                            columnNames,
-                            columnTypes);
-                }
-            }
-        }
     }
 
     void testMaxReadBytes(ObjectInspector objectInspector, Iterable<?> writeValues, Iterable<?> readValues, Type type, DataSize maxReadBlockSize)
