Skip to content

Commit

Permalink
fix decimal stats rescale in presto-orc tuple domain predicate
Browse files Browse the repository at this point in the history
  • Loading branch information
sopel39 authored and cberner committed Aug 6, 2016
1 parent 8925a43 commit 3baae80
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 8 deletions.
Expand Up @@ -67,6 +67,7 @@
import static com.facebook.presto.spi.type.DateType.DATE;
import static com.facebook.presto.spi.type.Decimals.isLongDecimal;
import static com.facebook.presto.spi.type.Decimals.isShortDecimal;
import static com.facebook.presto.spi.type.Decimals.rescale;
import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
import static com.facebook.presto.spi.type.IntegerType.INTEGER;
import static com.facebook.presto.spi.type.SmallintType.SMALLINT;
Expand All @@ -82,7 +83,6 @@
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.lang.String.format;
import static java.math.BigDecimal.ROUND_HALF_UP;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -580,11 +580,7 @@ private void parseDecimalColumn(int column, byte[] bytes, int start, int length)
}
else {
DecimalType columnType = (DecimalType) types[column];
BigDecimal decimal = new BigDecimal(new String(bytes, start, length, UTF_8));

checkState(decimal.scale() <= columnType.getScale(), "Read decimal value scale larger than column scale");
decimal = decimal.setScale(columnType.getScale(), ROUND_HALF_UP);
checkState(decimal.precision() <= columnType.getPrecision(), "Read decimal precision larger than column precision");
BigDecimal decimal = rescale(new BigDecimal(new String(bytes, start, length, UTF_8)), columnType);

if (columnType.isShort()) {
longs[column] = decimal.unscaledValue().longValue();
Expand Down
Expand Up @@ -20,6 +20,7 @@
import com.facebook.presto.spi.predicate.Range;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.predicate.ValueSet;
import com.facebook.presto.spi.type.DecimalType;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.Type;
import com.google.common.annotations.VisibleForTesting;
Expand All @@ -36,6 +37,7 @@
import static com.facebook.presto.spi.type.Decimals.encodeUnscaledValue;
import static com.facebook.presto.spi.type.Decimals.isLongDecimal;
import static com.facebook.presto.spi.type.Decimals.isShortDecimal;
import static com.facebook.presto.spi.type.Decimals.rescale;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -107,10 +109,10 @@ public static Domain getDomain(Type type, long rowCount, ColumnStatistics column
}
}
else if (isShortDecimal(type)) {
return createDomain(type, hasNullValue, columnStatistics.getDecimalStatistics(), value -> value.unscaledValue().longValue());
return createDomain(type, hasNullValue, columnStatistics.getDecimalStatistics(), value -> rescale(value, (DecimalType) type).unscaledValue().longValue());
}
else if (isLongDecimal(type)) {
return createDomain(type, hasNullValue, columnStatistics.getDecimalStatistics(), value -> encodeUnscaledValue(value.unscaledValue()));
return createDomain(type, hasNullValue, columnStatistics.getDecimalStatistics(), value -> encodeUnscaledValue(rescale(value, (DecimalType) type).unscaledValue()));
}
else if (type.getTypeSignature().getBase().equals(StandardTypes.DATE) && columnStatistics.getDateStatistics() != null) {
return createDomain(type, hasNullValue, columnStatistics.getDateStatistics(), value -> (long) value);
Expand Down
Expand Up @@ -222,11 +222,14 @@ public void testDecimal()
assertEquals(getDomain(SHORT_DECIMAL, 10, decimalColumnStats(10L, null, null)), notNull(SHORT_DECIMAL));

assertEquals(getDomain(SHORT_DECIMAL, 10, decimalColumnStats(10L, "999.99", "999.99")), singleValue(SHORT_DECIMAL, shortDecimal("999.99")));
assertEquals(getDomain(SHORT_DECIMAL, 10, decimalColumnStats(10L, "999.9", "999.9")), singleValue(SHORT_DECIMAL, shortDecimal("999.90")));
assertEquals(getDomain(LONG_DECIMAL, 10, decimalColumnStats(10L, "1234567890.0987654321", "1234567890.0987654321")),
singleValue(LONG_DECIMAL, longDecimal("1234567890.0987654321")));

assertEquals(getDomain(SHORT_DECIMAL, 10, decimalColumnStats(10L, "-999.99", "999.99")),
create(ValueSet.ofRanges(range(SHORT_DECIMAL, shortDecimal("-999.99"), true, shortDecimal("999.99"), true)), false));
assertEquals(getDomain(SHORT_DECIMAL, 10, decimalColumnStats(10L, "10.5", "20")),
create(ValueSet.ofRanges(range(SHORT_DECIMAL, shortDecimal("10.50"), true, shortDecimal("20.00"), true)), false));
assertEquals(getDomain(SHORT_DECIMAL, 10, decimalColumnStats(10L, null, "999.99")),
create(ValueSet.ofRanges(lessThanOrEqual(SHORT_DECIMAL, shortDecimal("999.99"))), false));
assertEquals(getDomain(SHORT_DECIMAL, 10, decimalColumnStats(10L, "-999.99", null)),
Expand Down
Expand Up @@ -30,6 +30,7 @@
import static java.lang.Math.pow;
import static java.lang.Math.round;
import static java.lang.String.format;
import static java.math.BigDecimal.ROUND_UNNECESSARY;
import static java.math.BigInteger.TEN;

public class Decimals
Expand Down Expand Up @@ -248,6 +249,16 @@ public static void writeBigDecimal(DecimalType decimalType, BlockBuilder blockBu
decimalType.writeSlice(blockBuilder, encodeScaledValue(value));
}

public static BigDecimal rescale(BigDecimal value, DecimalType type)
{
value = value.setScale(type.getScale(), ROUND_UNNECESSARY);

if (value.precision() > type.getPrecision()) {
throw new IllegalArgumentException("decimal precision larger than column precision");
}
return value;
}

public static long rescale(long value, int fromScale, int toScale)
{
if (toScale < fromScale) {
Expand Down

0 comments on commit 3baae80

Please sign in to comment.