-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Sum aggregation function for decimal(p,s)
This implementation uses the specialize() method explicitly. The return type is a decimal(38,S) where S is the input scale. The accumulator uses a BigInteger interally.
- Loading branch information
Showing
8 changed files
with
532 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
172 changes: 172 additions & 0 deletions
172
...to-main/src/main/java/com/facebook/presto/operator/aggregation/DecimalSumAggregation.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package com.facebook.presto.operator.aggregation; | ||
|
||
import com.facebook.presto.bytecode.DynamicClassLoader; | ||
import com.facebook.presto.metadata.BoundVariables; | ||
import com.facebook.presto.metadata.FunctionKind; | ||
import com.facebook.presto.metadata.FunctionRegistry; | ||
import com.facebook.presto.metadata.SqlAggregationFunction; | ||
import com.facebook.presto.operator.aggregation.state.BigIntegerState; | ||
import com.facebook.presto.operator.aggregation.state.BigIntegerStateFactory; | ||
import com.facebook.presto.operator.aggregation.state.BigIntegerStateSerializer; | ||
import com.facebook.presto.spi.block.Block; | ||
import com.facebook.presto.spi.block.BlockBuilder; | ||
import com.facebook.presto.spi.function.AccumulatorState; | ||
import com.facebook.presto.spi.function.AccumulatorStateSerializer; | ||
import com.facebook.presto.spi.type.DecimalType; | ||
import com.facebook.presto.spi.type.Type; | ||
import com.facebook.presto.spi.type.TypeManager; | ||
import com.google.common.collect.ImmutableList; | ||
import com.google.common.collect.ImmutableSet; | ||
|
||
import java.lang.invoke.MethodHandle; | ||
import java.math.BigDecimal; | ||
import java.math.BigInteger; | ||
import java.util.List; | ||
|
||
import static com.facebook.presto.metadata.SignatureBinder.bindVariables; | ||
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata; | ||
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX; | ||
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL; | ||
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; | ||
import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; | ||
import static com.facebook.presto.spi.type.Decimals.checkOverflow; | ||
import static com.facebook.presto.spi.type.Decimals.decodeUnscaledValue; | ||
import static com.facebook.presto.spi.type.Decimals.writeBigDecimal; | ||
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; | ||
import static com.facebook.presto.util.ImmutableCollectors.toImmutableList; | ||
import static com.facebook.presto.util.Reflection.methodHandle; | ||
import static com.google.common.base.Preconditions.checkArgument; | ||
import static com.google.common.collect.Iterables.getOnlyElement; | ||
|
||
public class DecimalSumAggregation | ||
extends SqlAggregationFunction | ||
{ | ||
public static final DecimalSumAggregation DECIMAL_SUM_AGGREGATION = new DecimalSumAggregation(); | ||
private static final String NAME = "sum"; | ||
private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "inputShortDecimal", Type.class, BigIntegerState.class, Block.class, int.class); | ||
private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "inputLongDecimal", Type.class, BigIntegerState.class, Block.class, int.class); | ||
|
||
private static final MethodHandle LONG_DECIMAL_OUTPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "outputLongDecimal", DecimalType.class, BigIntegerState.class, BlockBuilder.class); | ||
|
||
private static final MethodHandle COMBINE_FUNCTION = methodHandle(DecimalSumAggregation.class, "combine", BigIntegerState.class, BigIntegerState.class); | ||
|
||
public DecimalSumAggregation() | ||
{ | ||
super(NAME, | ||
ImmutableList.of(), | ||
ImmutableList.of(), | ||
parseTypeSignature("decimal(38,s)", ImmutableSet.of("s")), | ||
ImmutableList.of(parseTypeSignature("decimal(p,s)", ImmutableSet.of("p", "s"))), | ||
FunctionKind.AGGREGATE); | ||
} | ||
|
||
@Override | ||
public String getDescription() | ||
{ | ||
return "Calculates the sum over the input values"; | ||
} | ||
|
||
@Override | ||
public InternalAggregationFunction specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) | ||
{ | ||
Type inputType = typeManager.getType(getOnlyElement(bindVariables(getSignature().getArgumentTypes(), boundVariables))); | ||
Type outputType = typeManager.getType(bindVariables(getSignature().getReturnType(), boundVariables)); | ||
return generateAggregation(inputType, outputType); | ||
} | ||
|
||
private static InternalAggregationFunction generateAggregation(Type inputType, Type outputType) | ||
{ | ||
checkArgument(inputType instanceof DecimalType, "type must be Decimal"); | ||
DynamicClassLoader classLoader = new DynamicClassLoader(DecimalSumAggregation.class.getClassLoader()); | ||
List<Type> inputTypes = ImmutableList.of(inputType); | ||
MethodHandle inputFunction; | ||
Class<? extends AccumulatorState> stateInterface = BigIntegerState.class; | ||
AccumulatorStateSerializer<?> stateSerializer = new BigIntegerStateSerializer(); | ||
|
||
if (((DecimalType) inputType).isShort()) { | ||
inputFunction = SHORT_DECIMAL_INPUT_FUNCTION; | ||
} | ||
else { | ||
inputFunction = LONG_DECIMAL_INPUT_FUNCTION; | ||
} | ||
|
||
AggregationMetadata metadata = new AggregationMetadata( | ||
generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), | ||
createInputParameterMetadata(inputType), | ||
inputFunction.bindTo(inputType), | ||
COMBINE_FUNCTION, | ||
LONG_DECIMAL_OUTPUT_FUNCTION.bindTo(outputType), | ||
stateInterface, | ||
stateSerializer, | ||
new BigIntegerStateFactory(), | ||
outputType, | ||
false); | ||
|
||
Type intermediateType = stateSerializer.getSerializedType(); | ||
GenericAccumulatorFactoryBinder factory = new AccumulatorCompiler().generateAccumulatorFactoryBinder(metadata, classLoader); | ||
return new InternalAggregationFunction(NAME, inputTypes, intermediateType, outputType, true, false, factory); | ||
} | ||
|
||
private static List<ParameterMetadata> createInputParameterMetadata(Type type) | ||
{ | ||
return ImmutableList.of(new ParameterMetadata(STATE), new ParameterMetadata(BLOCK_INPUT_CHANNEL, type), new ParameterMetadata(BLOCK_INDEX)); | ||
} | ||
|
||
public static void inputShortDecimal(Type type, BigIntegerState state, Block block, int position) | ||
{ | ||
accumulateValueInState(BigInteger.valueOf(type.getLong(block, position)), state); | ||
} | ||
|
||
public static void inputLongDecimal(Type type, BigIntegerState state, Block block, int position) | ||
{ | ||
accumulateValueInState(decodeUnscaledValue(type.getSlice(block, position)), state); | ||
} | ||
|
||
private static void accumulateValueInState(BigInteger value, BigIntegerState state) | ||
{ | ||
initializeIfNeeded(state); | ||
state.setBigInteger(state.getBigInteger().add(value)); | ||
} | ||
|
||
private static void initializeIfNeeded(BigIntegerState state) | ||
{ | ||
if (state.getBigInteger() == null) { | ||
state.setBigInteger(BigInteger.valueOf(0)); | ||
} | ||
} | ||
|
||
public static void combine(BigIntegerState state, BigIntegerState otherState) | ||
{ | ||
if (state.getBigInteger() == null) { | ||
state.setBigInteger(otherState.getBigInteger()); | ||
} | ||
else { | ||
state.setBigInteger(state.getBigInteger().add(otherState.getBigInteger())); | ||
} | ||
} | ||
|
||
public static void outputLongDecimal(DecimalType type, BigIntegerState state, BlockBuilder out) | ||
{ | ||
if (state.getBigInteger() == null) { | ||
out.appendNull(); | ||
} | ||
else { | ||
BigDecimal value = new BigDecimal(state.getBigInteger(), type.getScale()); | ||
checkOverflow(state.getBigInteger()); | ||
writeBigDecimal(type, out, value); | ||
} | ||
} | ||
} |
28 changes: 28 additions & 0 deletions
28
...to-main/src/main/java/com/facebook/presto/operator/aggregation/state/BigIntegerState.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package com.facebook.presto.operator.aggregation.state; | ||
|
||
import com.facebook.presto.spi.function.AccumulatorState; | ||
import com.facebook.presto.spi.function.AccumulatorStateMetadata; | ||
|
||
import java.math.BigInteger; | ||
|
||
@AccumulatorStateMetadata(stateFactoryClass = BigIntegerStateFactory.class, stateSerializerClass = BigIntegerStateSerializer.class) | ||
public interface BigIntegerState | ||
extends AccumulatorState | ||
{ | ||
BigInteger getBigInteger(); | ||
|
||
void setBigInteger(BigInteger value); | ||
} |
114 changes: 114 additions & 0 deletions
114
.../src/main/java/com/facebook/presto/operator/aggregation/state/BigIntegerStateFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package com.facebook.presto.operator.aggregation.state; | ||
|
||
import com.facebook.presto.array.ObjectBigArray; | ||
import com.facebook.presto.spi.function.AccumulatorStateFactory; | ||
|
||
import java.math.BigInteger; | ||
|
||
import static com.facebook.presto.operator.aggregation.state.BigIntegerAndLongStateFactory.BIG_INTEGER_APPROX_SIZE; | ||
import static io.airlift.slice.SizeOf.SIZE_OF_LONG; | ||
import static java.util.Objects.requireNonNull; | ||
|
||
public class BigIntegerStateFactory | ||
implements AccumulatorStateFactory<BigIntegerState> | ||
{ | ||
@Override | ||
public BigIntegerState createSingleState() | ||
{ | ||
return new SingleBigIntegerState(); | ||
} | ||
|
||
@Override | ||
public Class<? extends BigIntegerState> getSingleStateClass() | ||
{ | ||
return SingleBigIntegerState.class; | ||
} | ||
|
||
@Override | ||
public BigIntegerState createGroupedState() | ||
{ | ||
return new GroupedBigIntegerState(); | ||
} | ||
|
||
@Override | ||
public Class<? extends BigIntegerState> getGroupedStateClass() | ||
{ | ||
return GroupedBigIntegerState.class; | ||
} | ||
|
||
public static class GroupedBigIntegerState | ||
extends AbstractGroupedAccumulatorState | ||
implements BigIntegerState | ||
{ | ||
private final ObjectBigArray<BigInteger> bigIntegers = new ObjectBigArray<>(); | ||
private long estimatedSizeOfBigIntegerObjects; | ||
|
||
@Override | ||
public void ensureCapacity(long size) | ||
{ | ||
bigIntegers.ensureCapacity(size); | ||
} | ||
|
||
@Override | ||
public BigInteger getBigInteger() | ||
{ | ||
return bigIntegers.get(getGroupId()); | ||
} | ||
|
||
@Override | ||
public void setBigInteger(BigInteger value) | ||
{ | ||
requireNonNull(value, "value is null"); | ||
if (getBigInteger() == null) { | ||
estimatedSizeOfBigIntegerObjects += BIG_INTEGER_APPROX_SIZE; | ||
} | ||
bigIntegers.set(getGroupId(), value); | ||
} | ||
|
||
@Override | ||
public long getEstimatedSize() | ||
{ | ||
return bigIntegers.sizeOf() + estimatedSizeOfBigIntegerObjects + SIZE_OF_LONG; | ||
} | ||
} | ||
|
||
public static class SingleBigIntegerState | ||
implements BigIntegerState | ||
{ | ||
private BigInteger bigInteger; | ||
|
||
@Override | ||
public BigInteger getBigInteger() | ||
{ | ||
return bigInteger; | ||
} | ||
|
||
@Override | ||
public void setBigInteger(BigInteger bigInteger) | ||
{ | ||
this.bigInteger = bigInteger; | ||
} | ||
|
||
@Override | ||
public long getEstimatedSize() | ||
{ | ||
if (bigInteger == null) { | ||
return 0; | ||
} | ||
return BIG_INTEGER_APPROX_SIZE; | ||
} | ||
} | ||
} |
53 changes: 53 additions & 0 deletions
53
...c/main/java/com/facebook/presto/operator/aggregation/state/BigIntegerStateSerializer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package com.facebook.presto.operator.aggregation.state; | ||
|
||
import com.facebook.presto.spi.block.Block; | ||
import com.facebook.presto.spi.block.BlockBuilder; | ||
import com.facebook.presto.spi.function.AccumulatorStateSerializer; | ||
import com.facebook.presto.spi.type.Type; | ||
import io.airlift.slice.Slices; | ||
|
||
import java.math.BigInteger; | ||
|
||
import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; | ||
|
||
public class BigIntegerStateSerializer | ||
implements AccumulatorStateSerializer<BigIntegerState> | ||
{ | ||
@Override | ||
public Type getSerializedType() | ||
{ | ||
return VARBINARY; | ||
} | ||
|
||
@Override | ||
public void serialize(BigIntegerState state, BlockBuilder out) | ||
{ | ||
if (state.getBigInteger() == null) { | ||
out.appendNull(); | ||
} | ||
else { | ||
VARBINARY.writeSlice(out, Slices.wrappedBuffer(state.getBigInteger().toByteArray())); | ||
} | ||
} | ||
|
||
@Override | ||
public void deserialize(Block block, int index, BigIntegerState state) | ||
{ | ||
if (!block.isNull(index)) { | ||
state.setBigInteger(new BigInteger(VARBINARY.getSlice(block, index).getBytes())); | ||
} | ||
} | ||
} |
Oops, something went wrong.