Skip to content

Commit

Permalink
Sum aggregation function for decimal(p,s)
Browse files Browse the repository at this point in the history
This implementation uses the specialize() method explicitly.
The return type is a decimal(38,S) where S is the input scale. The
accumulator uses a BigInteger interally.
  • Loading branch information
Wojciech Biela authored and haozhun committed Sep 10, 2016
1 parent 2754aaa commit 423098b
Show file tree
Hide file tree
Showing 8 changed files with 532 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@
import static com.facebook.presto.operator.aggregation.ChecksumAggregationFunction.CHECKSUM_AGGREGATION;
import static com.facebook.presto.operator.aggregation.CountColumn.COUNT_COLUMN;
import static com.facebook.presto.operator.aggregation.DecimalAverageAggregation.DECIMAL_AVERAGE_AGGREGATION;
import static com.facebook.presto.operator.aggregation.DecimalSumAggregation.DECIMAL_SUM_AGGREGATION;
import static com.facebook.presto.operator.aggregation.Histogram.HISTOGRAM;
import static com.facebook.presto.operator.aggregation.MapAggregationFunction.MAP_AGG;
import static com.facebook.presto.operator.aggregation.MapUnionAggregation.MAP_UNION;
Expand Down Expand Up @@ -489,6 +490,7 @@ public WindowFunctionSupplier load(SpecializedFunctionKey key)
.function(castVarcharToRe2JRegexp(featuresConfig.getRe2JDfaStatesLimit(), featuresConfig.getRe2JDfaRetries()))
.function(castCharToRe2JRegexp(featuresConfig.getRe2JDfaStatesLimit(), featuresConfig.getRe2JDfaRetries()))
.function(DECIMAL_AVERAGE_AGGREGATION)
.function(DECIMAL_SUM_AGGREGATION)
.function(TRY_CAST);

builder.function(new ArrayAggregationFunction(featuresConfig.isLegacyArrayAgg()));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionKind;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.SqlAggregationFunction;
import com.facebook.presto.operator.aggregation.state.BigIntegerState;
import com.facebook.presto.operator.aggregation.state.BigIntegerStateFactory;
import com.facebook.presto.operator.aggregation.state.BigIntegerStateSerializer;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.function.AccumulatorState;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.type.DecimalType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.lang.invoke.MethodHandle;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.List;

import static com.facebook.presto.metadata.SignatureBinder.bindVariables;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName;
import static com.facebook.presto.spi.type.Decimals.checkOverflow;
import static com.facebook.presto.spi.type.Decimals.decodeUnscaledValue;
import static com.facebook.presto.spi.type.Decimals.writeBigDecimal;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;
import static com.facebook.presto.util.Reflection.methodHandle;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Iterables.getOnlyElement;

public class DecimalSumAggregation
extends SqlAggregationFunction
{
public static final DecimalSumAggregation DECIMAL_SUM_AGGREGATION = new DecimalSumAggregation();
private static final String NAME = "sum";
private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "inputShortDecimal", Type.class, BigIntegerState.class, Block.class, int.class);
private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "inputLongDecimal", Type.class, BigIntegerState.class, Block.class, int.class);

private static final MethodHandle LONG_DECIMAL_OUTPUT_FUNCTION = methodHandle(DecimalSumAggregation.class, "outputLongDecimal", DecimalType.class, BigIntegerState.class, BlockBuilder.class);

private static final MethodHandle COMBINE_FUNCTION = methodHandle(DecimalSumAggregation.class, "combine", BigIntegerState.class, BigIntegerState.class);

public DecimalSumAggregation()
{
super(NAME,
ImmutableList.of(),
ImmutableList.of(),
parseTypeSignature("decimal(38,s)", ImmutableSet.of("s")),
ImmutableList.of(parseTypeSignature("decimal(p,s)", ImmutableSet.of("p", "s"))),
FunctionKind.AGGREGATE);
}

@Override
public String getDescription()
{
return "Calculates the sum over the input values";
}

@Override
public InternalAggregationFunction specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry)
{
Type inputType = typeManager.getType(getOnlyElement(bindVariables(getSignature().getArgumentTypes(), boundVariables)));
Type outputType = typeManager.getType(bindVariables(getSignature().getReturnType(), boundVariables));
return generateAggregation(inputType, outputType);
}

private static InternalAggregationFunction generateAggregation(Type inputType, Type outputType)
{
checkArgument(inputType instanceof DecimalType, "type must be Decimal");
DynamicClassLoader classLoader = new DynamicClassLoader(DecimalSumAggregation.class.getClassLoader());
List<Type> inputTypes = ImmutableList.of(inputType);
MethodHandle inputFunction;
Class<? extends AccumulatorState> stateInterface = BigIntegerState.class;
AccumulatorStateSerializer<?> stateSerializer = new BigIntegerStateSerializer();

if (((DecimalType) inputType).isShort()) {
inputFunction = SHORT_DECIMAL_INPUT_FUNCTION;
}
else {
inputFunction = LONG_DECIMAL_INPUT_FUNCTION;
}

AggregationMetadata metadata = new AggregationMetadata(
generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())),
createInputParameterMetadata(inputType),
inputFunction.bindTo(inputType),
COMBINE_FUNCTION,
LONG_DECIMAL_OUTPUT_FUNCTION.bindTo(outputType),
stateInterface,
stateSerializer,
new BigIntegerStateFactory(),
outputType,
false);

Type intermediateType = stateSerializer.getSerializedType();
GenericAccumulatorFactoryBinder factory = new AccumulatorCompiler().generateAccumulatorFactoryBinder(metadata, classLoader);
return new InternalAggregationFunction(NAME, inputTypes, intermediateType, outputType, true, false, factory);
}

private static List<ParameterMetadata> createInputParameterMetadata(Type type)
{
return ImmutableList.of(new ParameterMetadata(STATE), new ParameterMetadata(BLOCK_INPUT_CHANNEL, type), new ParameterMetadata(BLOCK_INDEX));
}

public static void inputShortDecimal(Type type, BigIntegerState state, Block block, int position)
{
accumulateValueInState(BigInteger.valueOf(type.getLong(block, position)), state);
}

public static void inputLongDecimal(Type type, BigIntegerState state, Block block, int position)
{
accumulateValueInState(decodeUnscaledValue(type.getSlice(block, position)), state);
}

private static void accumulateValueInState(BigInteger value, BigIntegerState state)
{
initializeIfNeeded(state);
state.setBigInteger(state.getBigInteger().add(value));
}

private static void initializeIfNeeded(BigIntegerState state)
{
if (state.getBigInteger() == null) {
state.setBigInteger(BigInteger.valueOf(0));
}
}

public static void combine(BigIntegerState state, BigIntegerState otherState)
{
if (state.getBigInteger() == null) {
state.setBigInteger(otherState.getBigInteger());
}
else {
state.setBigInteger(state.getBigInteger().add(otherState.getBigInteger()));
}
}

public static void outputLongDecimal(DecimalType type, BigIntegerState state, BlockBuilder out)
{
if (state.getBigInteger() == null) {
out.appendNull();
}
else {
BigDecimal value = new BigDecimal(state.getBigInteger(), type.getScale());
checkOverflow(state.getBigInteger());
writeBigDecimal(type, out, value);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation.state;

import com.facebook.presto.spi.function.AccumulatorState;
import com.facebook.presto.spi.function.AccumulatorStateMetadata;

import java.math.BigInteger;

@AccumulatorStateMetadata(stateFactoryClass = BigIntegerStateFactory.class, stateSerializerClass = BigIntegerStateSerializer.class)
public interface BigIntegerState
extends AccumulatorState
{
BigInteger getBigInteger();

void setBigInteger(BigInteger value);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation.state;

import com.facebook.presto.array.ObjectBigArray;
import com.facebook.presto.spi.function.AccumulatorStateFactory;

import java.math.BigInteger;

import static com.facebook.presto.operator.aggregation.state.BigIntegerAndLongStateFactory.BIG_INTEGER_APPROX_SIZE;
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static java.util.Objects.requireNonNull;

public class BigIntegerStateFactory
implements AccumulatorStateFactory<BigIntegerState>
{
@Override
public BigIntegerState createSingleState()
{
return new SingleBigIntegerState();
}

@Override
public Class<? extends BigIntegerState> getSingleStateClass()
{
return SingleBigIntegerState.class;
}

@Override
public BigIntegerState createGroupedState()
{
return new GroupedBigIntegerState();
}

@Override
public Class<? extends BigIntegerState> getGroupedStateClass()
{
return GroupedBigIntegerState.class;
}

public static class GroupedBigIntegerState
extends AbstractGroupedAccumulatorState
implements BigIntegerState
{
private final ObjectBigArray<BigInteger> bigIntegers = new ObjectBigArray<>();
private long estimatedSizeOfBigIntegerObjects;

@Override
public void ensureCapacity(long size)
{
bigIntegers.ensureCapacity(size);
}

@Override
public BigInteger getBigInteger()
{
return bigIntegers.get(getGroupId());
}

@Override
public void setBigInteger(BigInteger value)
{
requireNonNull(value, "value is null");
if (getBigInteger() == null) {
estimatedSizeOfBigIntegerObjects += BIG_INTEGER_APPROX_SIZE;
}
bigIntegers.set(getGroupId(), value);
}

@Override
public long getEstimatedSize()
{
return bigIntegers.sizeOf() + estimatedSizeOfBigIntegerObjects + SIZE_OF_LONG;
}
}

public static class SingleBigIntegerState
implements BigIntegerState
{
private BigInteger bigInteger;

@Override
public BigInteger getBigInteger()
{
return bigInteger;
}

@Override
public void setBigInteger(BigInteger bigInteger)
{
this.bigInteger = bigInteger;
}

@Override
public long getEstimatedSize()
{
if (bigInteger == null) {
return 0;
}
return BIG_INTEGER_APPROX_SIZE;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation.state;

import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.type.Type;
import io.airlift.slice.Slices;

import java.math.BigInteger;

import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY;

public class BigIntegerStateSerializer
implements AccumulatorStateSerializer<BigIntegerState>
{
@Override
public Type getSerializedType()
{
return VARBINARY;
}

@Override
public void serialize(BigIntegerState state, BlockBuilder out)
{
if (state.getBigInteger() == null) {
out.appendNull();
}
else {
VARBINARY.writeSlice(out, Slices.wrappedBuffer(state.getBigInteger().toByteArray()));
}
}

@Override
public void deserialize(Block block, int index, BigIntegerState state)
{
if (!block.isNull(index)) {
state.setBigInteger(new BigInteger(VARBINARY.getSlice(block, index).getBytes()));
}
}
}

0 comments on commit 423098b

Please sign in to comment.