Skip to content

Commit

Permalink
Allow rescaling parquet decimals
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Feb 27, 2020
1 parent 6c71a14 commit 34678fa
Show file tree
Hide file tree
Showing 4 changed files with 261 additions and 56 deletions.
Expand Up @@ -21,6 +21,7 @@
import com.google.common.collect.Range;
import com.google.common.primitives.Shorts;
import io.airlift.units.DataSize;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.RowType;
import io.prestosql.spi.type.SqlDate;
Expand All @@ -32,7 +33,6 @@
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
import org.apache.parquet.io.ParquetDecodingException;
import org.apache.parquet.schema.MessageType;
import org.joda.time.DateTimeZone;
import org.testng.annotations.BeforeClass;
Expand Down Expand Up @@ -73,6 +73,7 @@
import static io.prestosql.spi.type.DateType.DATE;
import static io.prestosql.spi.type.DecimalType.createDecimalType;
import static io.prestosql.spi.type.Decimals.MAX_PRECISION;
import static io.prestosql.spi.type.Decimals.longTenToNth;
import static io.prestosql.spi.type.DoubleType.DOUBLE;
import static io.prestosql.spi.type.IntegerType.INTEGER;
import static io.prestosql.spi.type.RealType.REAL;
Expand Down Expand Up @@ -876,12 +877,10 @@ public void testDecimalBackedByINT64()

@Test
public void testParquetShortDecimalWriteToPrestoDecimalWithNonMatchingScale()
throws Exception
{
assertThatThrownBy(() -> {
MessageType parquetSchema = parseMessageType(format("message hive_decimal { optional INT64 test (DECIMAL(%d, %d)); }", 10, 1));
tester.testRoundTrip(javaLongObjectInspector, ImmutableList.of(1L), ImmutableList.of(1L), createDecimalType(10, 2), Optional.of(parquetSchema));
}).hasMessage("Presto decimal column type has different scale (2) than Parquet decimal column (1)")
.isInstanceOf(ParquetDecodingException.class);
MessageType parquetSchema = parseMessageType(format("message hive_decimal { optional INT64 test (DECIMAL(%d, %d)); }", 10, 1));
tester.testRoundTrip(javaLongObjectInspector, ImmutableList.of(10L), ImmutableList.of(SqlDecimal.of(100L, 10, 2)), createDecimalType(10, 2), Optional.of(parquetSchema));
}

@Test
Expand Down Expand Up @@ -910,15 +909,13 @@ public void testDecimalBackedByFixedLenByteArray()

@Test
public void testParquetLongDecimalWriteToPrestoDecimalWithNonMatchingScale()
throws Exception
{
assertThatThrownBy(() ->
tester.testRoundTrip(
new JavaHiveDecimalObjectInspector(new DecimalTypeInfo(38, 10)),
ImmutableList.of(HiveDecimal.create(0)),
ImmutableList.of(new SqlDecimal(BigInteger.ZERO, 38, 10)),
createDecimalType(38, 9)))
.hasMessage("Presto decimal column type has different scale (9) than Parquet decimal column (10)")
.isInstanceOf(ParquetDecodingException.class);
tester.testRoundTrip(
new JavaHiveDecimalObjectInspector(new DecimalTypeInfo(38, 10)),
ImmutableList.of(HiveDecimal.create(100 * longTenToNth(10), 10)),
ImmutableList.of(new SqlDecimal(BigInteger.valueOf(100 * longTenToNth(9)), 38, 9)),
createDecimalType(38, 9));
}

@Test
Expand Down Expand Up @@ -987,8 +984,8 @@ public void testParquetShortDecimalWriteToPrestoBigintBlockWithNonZeroScale()
assertThatThrownBy(() -> {
MessageType parquetSchema = parseMessageType(format("message hive_decimal { optional INT64 test (DECIMAL(%d, %d)); }", 10, 1));
tester.testRoundTrip(javaLongObjectInspector, ImmutableList.of(1L), ImmutableList.of(1L), BIGINT, Optional.of(parquetSchema));
}).hasMessage("Parquet decimal column type with non-zero scale (1) cannot be converted to Presto bigint column type")
.isInstanceOf(ParquetDecodingException.class);
}).hasMessage("Unsupported Presto column type (bigint) for Parquet column ([test] optional int64 test (DECIMAL(10,1)))")
.isInstanceOf(PrestoException.class);
}

@Test
Expand Down
Expand Up @@ -42,6 +42,7 @@
import java.io.IOException;
import java.io.UncheckedIOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.file.Files;
import java.util.Collections;
import java.util.Iterator;
Expand Down Expand Up @@ -140,9 +141,151 @@ public Object[][] testReadingMatchingPrecisionDataProvider()
};
}

/**
* Tests if Parquet decimal with given precision and scale can be read into Presto decimal with different precision and scale
* if Parquet decimal value could be rescaled into Presto decimal without loosing most and least significant digits.
*/
@Test(dataProvider = "testReadingRescaledDecimalsProvider")
public void testReadingRescaledDecimals(int precision, int scale, boolean forceFixedLengthArray, int schemaPrecision, int schemaScale, List<String> values, List<String> expected)
{
String tableName = generateTableName("rescaled_decimals", precision, scale);

createTable(tableName, schemaPrecision, schemaScale);

writeParquetDecimalsRecord(
getParquetWritePath(tableName),
ImmutableList.of(new ParquetDecimalInsert("value", forceFixedLengthArray, precision, scale, values)));

assertValues(tableName, schemaScale, expected);

dropTable(tableName);
}

@DataProvider
public Object[][] testReadingRescaledDecimalsProvider()
{
// parquetPrecision, parquetScale, useFixedLengthArray, schemaPrecision, schemaScale, writeValues, expectedValues
return new Object[][] {
{10, 2, false, 12, 4,
ImmutableList.of("10.01", "10", "1.23", maximumValue(10, 2), minimumValue(10, 2)),
ImmutableList.of("10.0100", "10.0000", "1.2300", maximumValue(10, 2), minimumValue(10, 2))},
{10, 2, true, 13, 5,
ImmutableList.of("10.01", "10", "1.23", maximumValue(10, 2), minimumValue(10, 2)),
ImmutableList.of("10.01000", "10.0000", "1.23000", maximumValue(10, 2), minimumValue(10, 2))},
{4, 2, false, 6, 4,
ImmutableList.of("10.01", "10", "1.23", maximumValue(4, 2), minimumValue(4, 2)),
ImmutableList.of("10.0100", "10.0000", "1.2300", maximumValue(4, 2), minimumValue(4, 2))},
{4, 2, false, 6, 2,
ImmutableList.of("10.01", "10", "1.23", maximumValue(4, 2), minimumValue(4, 2)),
ImmutableList.of("10.01", "10.00", "1.23", maximumValue(4, 2), minimumValue(4, 2))},
{10, 2, false, 11, 3,
ImmutableList.of("10.01", "10", "1.23", maximumValue(10, 2), minimumValue(10, 2)),
ImmutableList.of("10.010", "10.000", "1.230", maximumValue(10, 2), minimumValue(10, 2))},
{10, 2, true, 12, 4,
ImmutableList.of("10.01", "10", "1.23", maximumValue(10, 2), minimumValue(10, 2)),
ImmutableList.of("10.0100", "10.0000", "1.2300", maximumValue(10, 2), minimumValue(10, 2))},
{4, 2, false, 10, 5,
ImmutableList.of("10.01", "10", "1.23", maximumValue(4, 2), minimumValue(4, 2)),
ImmutableList.of("10.01000", "10.00000", "1.23000", maximumValue(4, 2), minimumValue(4, 2))},
{4, 2, true, 10, 5,
ImmutableList.of("10.01", "10", "1.23", maximumValue(4, 2), minimumValue(4, 2)),
ImmutableList.of("10.01000", "10.00000", "1.23000", maximumValue(4, 2), minimumValue(4, 2))},
{14, 2, false, 20, 3,
ImmutableList.of("10.01", "10", "1.23", maximumValue(14, 2), minimumValue(14, 2)),
ImmutableList.of("10.010", "10.000", "1.230", maximumValue(14, 2), minimumValue(14, 2))},
{6, 3, false, 9, 6,
ImmutableList.of("10.01", "10", "1.23", maximumValue(6, 3), minimumValue(6, 3)),
ImmutableList.of("10.010000", "10.000000", "1.230000", maximumValue(6, 3), minimumValue(6, 3))},
{6, 3, true, 9, 6,
ImmutableList.of("10.01", "10", "1.23", maximumValue(6, 3), minimumValue(6, 3)),
ImmutableList.of("10.010000", "10.000000", "1.230000", maximumValue(6, 3), minimumValue(6, 3))},
{10, 2, false, 38, 4,
ImmutableList.of("10.01", "10", "1.23", maximumValue(10, 2), minimumValue(10, 2)),
ImmutableList.of("10.0100", "10.0000", "1.2300", maximumValue(10, 2), minimumValue(10, 2))},
{18, 4, false, 38, 14,
ImmutableList.of("10.01", "10", "1.23", maximumValue(18, 4), minimumValue(18, 4)),
ImmutableList.of("10.0100", "10.0000", "1.2300", maximumValue(18, 4), minimumValue(18, 4))},
};
}

/**
* Tests if Parquet decimal with given precision and scale can be read into Presto decimal with different precision and scale
* if Parquet decimal value will be rounded to fit into Presto decimal.
*/
@Test(dataProvider = "testReadingRoundedDecimalsProvider")
public void testReadingRoundedDecimals(int precision, int scale, boolean forceFixedLengthArray, int schemaPrecision, int schemaScale, List<String> values, List<String> expected)
{
String tableName = generateTableName("rounded_decimals", precision, scale);

createTable(tableName, schemaPrecision, schemaScale);

writeParquetDecimalsRecord(
getParquetWritePath(tableName),
ImmutableList.of(new ParquetDecimalInsert("value", forceFixedLengthArray, precision, scale, values)));

assertRoundedValues(tableName, schemaScale, expected);

dropTable(tableName);
}

@DataProvider
public Object[][] testReadingRoundedDecimalsProvider()
{
// parquetPrecision, parquetScale, useFixedLengthArray, schemaPrecision, schemaScale, writeValues, expectedValues
return new Object[][] {
{10, 2, false, 12, 1,
ImmutableList.of("10.01", "10", "1.23", maximumValue(10, 2), minimumValue(10, 2)),
ImmutableList.of("10.01", "10", "1.23", maximumValue(10, 2), minimumValue(10, 2))},
{9, 2, true, 12, 1,
ImmutableList.of("10.01", "10", "1.23", maximumValue(9, 2), minimumValue(9, 2)),
ImmutableList.of("10.01", "10", "1.23", maximumValue(9, 2), minimumValue(9, 2))},
{4, 2, false, 7, 1,
ImmutableList.of("10.01", "10", "1.23", maximumValue(4, 2), minimumValue(4, 2)),
ImmutableList.of("10.01", "10", "1.23", maximumValue(4, 2), minimumValue(4, 2))},
{10, 2, false, 12, 1,
ImmutableList.of("10.01", "10", "1.23", maximumValue(10, 2), minimumValue(10, 2)),
ImmutableList.of("10.01", "10", "1.23", maximumValue(10, 2), minimumValue(10, 2))},
};
}

/**
* Tests if Parquet decimal with given precision and scale cannot be read into Presto decimal with different precision and scale
* because when rescaling decimal we would loose most significant digits.
*/
@Test(dataProvider = "testReadingNonRescalableDecimalsProvider")
public void testReadingNonRescalableDecimals(int precision, int scale, boolean forceFixedLengthArray, int schemaPrecision, int schemaScale, List<String> values)
{
String tableName = generateTableName("non_rescalable", precision, scale);

createTable(tableName, schemaPrecision, schemaScale);

writeParquetDecimalsRecord(
getParquetWritePath(tableName),
ImmutableList.of(new ParquetDecimalInsert("value", forceFixedLengthArray, precision, scale, values)));

assertQueryFails(format("SELECT * FROM tpch.%s", tableName), format("Cannot cast DECIMAL\\(%d, %d\\) '.*' to DECIMAL\\(%d, %d\\)", precision, scale, schemaPrecision, schemaScale));

dropTable(tableName);
}

@DataProvider
public Object[][] testReadingNonRescalableDecimalsProvider()
{
// parquetPrecision, parquetScale, useFixedLengthArray, schemaPrecision, schemaScale, writeValues
return new Object[][] {
{4, 2, false, 4, 3, ImmutableList.of("10.01")},
{10, 2, false, 10, 3, ImmutableList.of("12345678.91")},
{10, 2, false, 3, 2, ImmutableList.of("10.01")},
{10, 2, true, 14, 7, ImmutableList.of("99999999.99")},
{10, 2, false, 10, 4, ImmutableList.of("99999999.99")},
{18, 8, false, 32, 23, ImmutableList.of("1234567890.12345678")},
{20, 8, false, 32, 21, ImmutableList.of("123456789012.12345678")},
};
}

protected void createTable(String tableName, int precision, int scale)
{
assertUpdate(format("CREATE TABLE %s (value decimal(%d, %d)) WITH (format = 'PARQUET')", tableName, precision, scale));
assertUpdate(format("CREATE TABLE tpch.%s (value decimal(%d, %d)) WITH (format = 'PARQUET')", tableName, precision, scale));
}

protected void dropTable(String tableName)
Expand All @@ -152,7 +295,7 @@ protected void dropTable(String tableName)

protected void assertValues(String tableName, int scale, List<String> expected)
{
MaterializedResult materializedRows = computeActual(format("SELECT value FROM %s", tableName));
MaterializedResult materializedRows = computeActual(format("SELECT value FROM tpch.%s", tableName));

List<BigDecimal> actualValues = materializedRows.getMaterializedRows().stream()
.map(row -> row.getField(0))
Expand All @@ -166,6 +309,22 @@ protected void assertValues(String tableName, int scale, List<String> expected)
assertThat(actualValues).containsExactlyInAnyOrder(expectedValues);
}

protected void assertRoundedValues(String tableName, int scale, List<String> expected)
{
MaterializedResult materializedRows = computeActual(format("SELECT value FROM tpch.%s", tableName));

List<BigDecimal> actualValues = materializedRows.getMaterializedRows().stream()
.map(row -> row.getField(0))
.map(BigDecimal.class::cast)
.collect(toImmutableList());

BigDecimal[] expectedValues = expected.stream()
.map(value -> new BigDecimal(value).setScale(scale, RoundingMode.HALF_UP))
.toArray(BigDecimal[]::new);

assertThat(actualValues).containsExactlyInAnyOrder(expectedValues);
}

private static java.nio.file.Path getBasePath()
{
try {
Expand Down
Expand Up @@ -13,6 +13,7 @@
*/
package io.prestosql.parquet.reader;

import io.airlift.slice.Slice;
import io.prestosql.parquet.RichColumnDescriptor;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.type.DecimalType;
Expand All @@ -23,7 +24,8 @@

import java.math.BigInteger;

import static io.prestosql.spi.type.Decimals.isLongDecimal;
import static io.prestosql.spi.type.DecimalConversions.longToLongCast;
import static io.prestosql.spi.type.DecimalConversions.longToShortCast;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

Expand All @@ -41,20 +43,32 @@ public class LongDecimalColumnReader
@Override
protected void readValue(BlockBuilder blockBuilder, Type prestoType)
{
if (!isLongDecimal(prestoType)) {
if (!(prestoType instanceof DecimalType)) {
throw new ParquetDecodingException(format("Unsupported Presto column type (%s) for Parquet column (%s)", prestoType, columnDescriptor));
}

DecimalType prestoDecimalType = (DecimalType) prestoType;
if (prestoDecimalType.getScale() != parquetDecimalType.getScale()) {
throw new ParquetDecodingException(format(
"Presto decimal column type has different scale (%s) than Parquet decimal column (%s)",
prestoDecimalType.getScale(),
parquetDecimalType.getScale()));
}

if (definitionLevel == columnDescriptor.getMaxDefinitionLevel()) {
Binary value = valuesReader.readBytes();
prestoType.writeSlice(blockBuilder, Decimals.encodeUnscaledValue(new BigInteger(value.getBytes())));
Binary binary = valuesReader.readBytes();
Slice value = Decimals.encodeUnscaledValue(new BigInteger(binary.getBytes()));

if (prestoDecimalType.isShort()) {
prestoType.writeLong(blockBuilder, longToShortCast(
value,
parquetDecimalType.getPrecision(),
parquetDecimalType.getScale(),
prestoDecimalType.getPrecision(),
prestoDecimalType.getScale()));
}
else {
prestoType.writeSlice(blockBuilder, longToLongCast(
value,
parquetDecimalType.getPrecision(),
parquetDecimalType.getScale(),
prestoDecimalType.getPrecision(),
prestoDecimalType.getScale()));
}
}
else if (isValueNull()) {
blockBuilder.appendNull();
Expand Down

0 comments on commit 34678fa

Please sign in to comment.