Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Two's complement representation for decimals #10051

Merged
merged 3 commits into from Dec 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -15,7 +15,6 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionMetadata;
Expand All @@ -30,6 +29,7 @@
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;

Expand All @@ -42,14 +42,11 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.trino.metadata.FunctionKind.AGGREGATE;
import static io.trino.spi.type.Decimals.overflows;
import static io.trino.spi.type.Decimals.writeShortDecimal;
import static io.trino.spi.type.Int128Math.addWithOverflow;
import static io.trino.spi.type.Int128Math.divideRoundUp;
import static io.trino.spi.type.TypeSignatureParameter.typeVariable;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.SIGN_LONG_MASK;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.addWithOverflow;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.divideRoundUp;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimalToBigInteger;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimalToUnscaledLong;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.util.Reflection.methodHandle;
import static java.math.BigDecimal.ROUND_HALF_UP;
Expand All @@ -69,7 +66,7 @@ public class DecimalAverageAggregation
private static final MethodHandle COMBINE_FUNCTION = methodHandle(DecimalAverageAggregation.class, "combine", LongDecimalWithOverflowAndLongState.class, LongDecimalWithOverflowAndLongState.class);

private static final BigInteger TWO = new BigInteger("2");
private static final BigInteger OVERFLOW_MULTIPLIER = TWO.shiftLeft(UNSCALED_DECIMAL_128_SLICE_LENGTH * 8 - 2);
private static final BigInteger OVERFLOW_MULTIPLIER = TWO.pow(128);

public DecimalAverageAggregation()
{
Expand Down Expand Up @@ -130,19 +127,16 @@ public static void inputShortDecimal(LongDecimalWithOverflowAndLongState state,
int offset = state.getDecimalArrayOffset();

long rightLow = block.getLong(position, 0);
long rightHigh = 0;
if (rightLow < 0) {
rightLow = -rightLow;
rightHigh = SIGN_LONG_MASK;
}
long rightHigh = rightLow >> 63;

long overflow = addWithOverflow(
decimal[offset],
decimal[offset + 1],
rightLow,
rightHigh,
rightLow,
decimal,
offset);

state.addOverflow(overflow);
}

Expand All @@ -155,44 +149,46 @@ public static void inputLongDecimal(LongDecimalWithOverflowAndLongState state, B
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long rightHigh = block.getLong(position, 0);
long rightLow = block.getLong(position, SIZE_OF_LONG);

long overflow = addWithOverflow(
decimal[offset],
decimal[offset + 1],
block.getLong(position, 0),
block.getLong(position, SIZE_OF_LONG),
rightHigh,
rightLow,
decimal,
offset);

state.addOverflow(overflow);
}

public static void combine(LongDecimalWithOverflowAndLongState state, LongDecimalWithOverflowAndLongState otherState)
{
state.addLong(otherState.getLong()); // row counter

long overflow = otherState.getOverflow();

long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long[] otherDecimal = otherState.getDecimalArray();
int otherOffset = otherState.getDecimalArrayOffset();

if (state.isNotNull()) {
overflow += addWithOverflow(
long overflow = addWithOverflow(
decimal[offset],
decimal[offset + 1],
otherDecimal[otherOffset],
otherDecimal[otherOffset + 1],
decimal,
offset);
state.addOverflow(overflow + otherState.getOverflow());
}
else {
state.setNotNull();
decimal[offset] = otherDecimal[otherOffset];
decimal[offset + 1] = otherDecimal[otherOffset + 1];
state.setOverflow(otherState.getOverflow());
}

state.addOverflow(overflow);
}

public static void outputShortDecimal(DecimalType type, LongDecimalWithOverflowAndLongState state, BlockBuilder out)
Expand All @@ -201,7 +197,7 @@ public static void outputShortDecimal(DecimalType type, LongDecimalWithOverflowA
out.appendNull();
}
else {
writeShortDecimal(out, unscaledDecimalToUnscaledLong(average(state, type)));
writeShortDecimal(out, average(state, type).toLongExact());
}
}

Expand All @@ -211,23 +207,29 @@ public static void outputLongDecimal(DecimalType type, LongDecimalWithOverflowAn
out.appendNull();
}
else {
type.writeSlice(out, average(state, type));
type.writeObject(out, average(state, type));
}
}

@VisibleForTesting
public static Slice average(LongDecimalWithOverflowAndLongState state, DecimalType type)
public static Int128 average(LongDecimalWithOverflowAndLongState state, DecimalType type)
{
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long overflow = state.getOverflow();
if (overflow != 0) {
BigDecimal sum = new BigDecimal(unscaledDecimalToBigInteger(decimal[offset], decimal[offset + 1]), type.getScale());
BigDecimal count = BigDecimal.valueOf(state.getLong());
BigDecimal sum = new BigDecimal(Int128.valueOf(decimal[offset], decimal[offset + 1]).toBigInteger(), type.getScale());
sum = sum.add(new BigDecimal(OVERFLOW_MULTIPLIER.multiply(BigInteger.valueOf(overflow))));

BigDecimal count = BigDecimal.valueOf(state.getLong());
return Decimals.encodeScaledValue(sum.divide(count, type.getScale(), ROUND_HALF_UP), type.getScale());
}
return divideRoundUp(decimal[offset], decimal[offset + 1], 0, state.getLong(), 0);

Int128 result = divideRoundUp(decimal[offset], decimal[offset + 1], 0, 0, state.getLong(), 0);
if (overflows(result)) {
throw new ArithmeticException("Decimal overflow");
}
return result;
}
}
Expand Up @@ -27,6 +27,7 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;

Expand All @@ -37,12 +38,9 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.trino.metadata.FunctionKind.AGGREGATE;
import static io.trino.spi.type.Int128Math.addWithOverflow;
import static io.trino.spi.type.TypeSignatureParameter.numericParameter;
import static io.trino.spi.type.TypeSignatureParameter.typeVariable;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.SIGN_LONG_MASK;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.addWithOverflow;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.throwIfOverflows;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.throwOverflowException;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.util.Reflection.methodHandle;

Expand Down Expand Up @@ -111,20 +109,16 @@ public static void inputShortDecimal(LongDecimalWithOverflowState state, Block b
int offset = state.getDecimalArrayOffset();

long rightLow = block.getLong(position, 0);
long rightHigh = 0;
if (rightLow < 0) {
rightLow = -rightLow;
rightHigh = SIGN_LONG_MASK;
}
long rightHigh = rightLow >> 63;

long overflow = addWithOverflow(
decimal[offset],
decimal[offset + 1],
rightLow,
rightHigh,
rightLow,
decimal,
offset);
state.addOverflow(overflow);
state.setOverflow(Math.addExact(overflow, state.getOverflow()));
}

public static void inputLongDecimal(LongDecimalWithOverflowState state, Block block, int position)
Expand All @@ -134,60 +128,62 @@ public static void inputLongDecimal(LongDecimalWithOverflowState state, Block bl
long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long rightHigh = block.getLong(position, 0);
long rightLow = block.getLong(position, SIZE_OF_LONG);

long overflow = addWithOverflow(
decimal[offset],
decimal[offset + 1],
block.getLong(position, 0),
block.getLong(position, SIZE_OF_LONG),
rightHigh,
rightLow,
decimal,
offset);

state.addOverflow(overflow);
}

public static void combine(LongDecimalWithOverflowState state, LongDecimalWithOverflowState otherState)
{
long overflow = otherState.getOverflow();

long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long[] otherDecimal = otherState.getDecimalArray();
int otherOffset = otherState.getDecimalArrayOffset();

if (state.isNotNull()) {
overflow += addWithOverflow(
long overflow = addWithOverflow(
decimal[offset],
decimal[offset + 1],
otherDecimal[otherOffset],
otherDecimal[otherOffset + 1],
decimal,
offset);
state.addOverflow(Math.addExact(overflow, otherState.getOverflow()));
}
else {
state.setNotNull();
decimal[offset] = otherDecimal[otherOffset];
decimal[offset + 1] = otherDecimal[otherOffset + 1];
state.setOverflow(otherState.getOverflow());
}

state.addOverflow(overflow);
}

public static void outputLongDecimal(LongDecimalWithOverflowState state, BlockBuilder out)
{
if (state.isNotNull()) {
if (state.getOverflow() != 0) {
throwOverflowException();
throw new ArithmeticException("Decimal overflow");
}

long[] decimal = state.getDecimalArray();
int offset = state.getDecimalArrayOffset();

long rawLow = decimal[offset];
long rawHigh = decimal[offset + 1];
long rawHigh = decimal[offset];
long rawLow = decimal[offset + 1];

throwIfOverflows(rawLow, rawHigh);
out.writeLong(rawLow);
Decimals.throwIfOverflows(rawHigh, rawLow);
out.writeLong(rawHigh);
out.writeLong(rawLow);
out.closeEntry();
}
else {
Expand Down
Expand Up @@ -18,15 +18,15 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Type;
import io.trino.spi.type.UnscaledDecimal128Arithmetic;

import static io.trino.spi.type.VarbinaryType.VARBINARY;

public class LongDecimalWithOverflowAndLongStateSerializer
implements AccumulatorStateSerializer<LongDecimalWithOverflowAndLongState>
{
private static final int SERIALIZED_SIZE = (Long.BYTES * 2) + UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH;
private static final int SERIALIZED_SIZE = (Long.BYTES * 2) + Int128.SIZE;

@Override
public Type getSerializedType()
Expand Down
Expand Up @@ -18,15 +18,15 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Type;
import io.trino.spi.type.UnscaledDecimal128Arithmetic;

import static io.trino.spi.type.VarbinaryType.VARBINARY;

public class LongDecimalWithOverflowStateSerializer
implements AccumulatorStateSerializer<LongDecimalWithOverflowState>
{
private static final int SERIALIZED_SIZE = Long.BYTES + UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH;
private static final int SERIALIZED_SIZE = Long.BYTES + Int128.SIZE;

@Override
public Type getSerializedType()
Expand Down
Expand Up @@ -19,13 +19,14 @@
import io.trino.spi.function.LiteralParameters;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Int128;

import java.math.BigDecimal;
import java.math.BigInteger;

import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE;
import static io.trino.spi.type.Decimals.encodeUnscaledValue;
import static java.lang.Character.isDigit;
import static java.lang.String.format;

Expand All @@ -37,7 +38,7 @@ private DataSizeFunctions() {}
@ScalarFunction(value = "parse_data_size", alias = "parse_presto_data_size")
@LiteralParameters("x")
@SqlType("decimal(38,0)")
public static Slice parsePrestoDataSize(@SqlType("varchar(x)") Slice input)
public static Int128 parsePrestoDataSize(@SqlType("varchar(x)") Slice input)
{
String dataSize = input.toStringUtf8();

Expand All @@ -60,7 +61,7 @@ public static Slice parsePrestoDataSize(@SqlType("varchar(x)") Slice input)
Unit unit = Unit.parse(dataSize.substring(valueLength), dataSize);
BigInteger bytes = value.multiply(unit.getFactor()).toBigInteger();
try {
return encodeUnscaledValue(bytes);
return Decimals.valueOf(bytes);
}
catch (ArithmeticException e) {
throw new TrinoException(NUMERIC_VALUE_OUT_OF_RANGE, format("Value out of range: '%s' ('%sB')", dataSize, bytes));
Expand Down
Expand Up @@ -24,6 +24,7 @@
import io.trino.spi.function.LiteralParameters;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.type.Int128;
import io.trino.spi.type.LongTimestampWithTimeZone;
import io.trino.spi.type.StandardTypes;
import io.trino.spi.type.TimeZoneKey;
Expand All @@ -50,10 +51,9 @@
import static io.trino.operator.scalar.QuarterOfYearDateTimeField.QUARTER_OF_YEAR;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone;
import static io.trino.spi.type.Int128Math.rescale;
import static io.trino.spi.type.TimeZoneKey.getTimeZoneKeyForOffset;
import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_SECOND;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.rescale;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimalToBigInteger;
import static io.trino.type.DateTimes.PICOSECONDS_PER_NANOSECOND;
import static io.trino.type.DateTimes.PICOSECONDS_PER_SECOND;
import static io.trino.type.DateTimes.scaleEpochMillisToMicros;
Expand Down Expand Up @@ -155,10 +155,11 @@ private FromUnixtimeNanosDecimal() {}

@LiteralParameters({"p", "s"})
@SqlType("timestamp(9) with time zone")
public static LongTimestampWithTimeZone fromLong(@LiteralParameter("s") long scale, ConnectorSession session, @SqlType("decimal(p, s)") Slice unixTimeNanos)
public static LongTimestampWithTimeZone fromLong(@LiteralParameter("s") long scale, ConnectorSession session, @SqlType("decimal(p, s)") Int128 unixTimeNanos)
{
// TODO (https://github.com/trinodb/trino/issues/5781)
BigInteger unixTimeNanosInt = unscaledDecimalToBigInteger(rescale(unixTimeNanos, -(int) scale));
Int128 decimal = rescale(unixTimeNanos, -(int) scale);
BigInteger unixTimeNanosInt = decimal.toBigInteger();
long epochSeconds = unixTimeNanosInt.divide(BigInteger.valueOf(NANOSECONDS_PER_SECOND)).longValue();
long nanosOfSecond = unixTimeNanosInt.remainder(BigInteger.valueOf(NANOSECONDS_PER_SECOND)).longValue();
long picosOfSecond = nanosOfSecond * PICOSECONDS_PER_NANOSECOND;
Expand Down