Skip to content

Commit

Permalink
Require explicit output functions for aggregations
Browse files Browse the repository at this point in the history
Relying on the state serializer to write to the output builder
is wrong because blocks should always be written using the type.
  • Loading branch information
electrum committed Jul 17, 2015
1 parent 5481318 commit 04d9050
Show file tree
Hide file tree
Showing 23 changed files with 190 additions and 159 deletions.
Expand Up @@ -23,16 +23,13 @@
import com.facebook.presto.operator.aggregation.state.AccumulatorStateFactory;
import com.facebook.presto.operator.aggregation.state.AccumulatorStateSerializer;
import com.facebook.presto.operator.aggregation.state.NullableBooleanState;
import com.facebook.presto.operator.aggregation.state.NullableBooleanStateSerializer;
import com.facebook.presto.operator.aggregation.state.NullableDoubleState;
import com.facebook.presto.operator.aggregation.state.NullableDoubleStateSerializer;
import com.facebook.presto.operator.aggregation.state.NullableLongState;
import com.facebook.presto.operator.aggregation.state.NullableLongStateSerializer;
import com.facebook.presto.operator.aggregation.state.SliceState;
import com.facebook.presto.operator.aggregation.state.SliceStateSerializer;
import com.facebook.presto.operator.aggregation.state.StateCompiler;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.google.common.base.Throwables;
Expand Down Expand Up @@ -60,6 +57,11 @@ public abstract class AbstractMinMaxAggregation
private static final MethodHandle SLICE_INPUT_FUNCTION = methodHandle(AbstractMinMaxAggregation.class, "input", MethodHandle.class, SliceState.class, Slice.class);
private static final MethodHandle BOOLEAN_INPUT_FUNCTION = methodHandle(AbstractMinMaxAggregation.class, "input", MethodHandle.class, NullableBooleanState.class, boolean.class);

private static final MethodHandle LONG_OUTPUT_FUNCTION = methodHandle(NullableLongState.class, "write", Type.class, NullableLongState.class, BlockBuilder.class);
private static final MethodHandle DOUBLE_OUTPUT_FUNCTION = methodHandle(NullableDoubleState.class, "write", Type.class, NullableDoubleState.class, BlockBuilder.class);
private static final MethodHandle SLICE_OUTPUT_FUNCTION = methodHandle(SliceState.class, "write", Type.class, SliceState.class, BlockBuilder.class);
private static final MethodHandle BOOLEAN_OUTPUT_FUNCTION = methodHandle(NullableBooleanState.class, "write", Type.class, NullableBooleanState.class, BlockBuilder.class);

private final String name;
private final OperatorType operatorType;
private final Signature signature;
Expand Down Expand Up @@ -97,40 +99,39 @@ protected InternalAggregationFunction generateAggregation(Type type, MethodHandl

List<Type> inputTypes = ImmutableList.of(type);

AccumulatorStateSerializer<?> stateSerializer;
AccumulatorStateFactory<?> stateFactory;
MethodHandle inputFunction;
MethodHandle outputFunction;
Class<? extends AccumulatorState> stateInterface;

if (type.getJavaType() == long.class) {
stateFactory = compiler.generateStateFactory(NullableLongState.class, classLoader);
stateSerializer = new NullableLongStateSerializer(type);
stateInterface = NullableLongState.class;
inputFunction = LONG_INPUT_FUNCTION;
outputFunction = LONG_OUTPUT_FUNCTION;
}
else if (type.getJavaType() == double.class) {
stateFactory = compiler.generateStateFactory(NullableDoubleState.class, classLoader);
stateSerializer = new NullableDoubleStateSerializer(type);
stateInterface = NullableDoubleState.class;
inputFunction = DOUBLE_INPUT_FUNCTION;
outputFunction = DOUBLE_OUTPUT_FUNCTION;
}
else if (type.getJavaType() == Slice.class) {
stateFactory = compiler.generateStateFactory(SliceState.class, classLoader);
stateSerializer = new SliceStateSerializer(type);
stateInterface = SliceState.class;
inputFunction = SLICE_INPUT_FUNCTION;
outputFunction = SLICE_OUTPUT_FUNCTION;
}
else if (type.getJavaType() == boolean.class) {
stateFactory = compiler.generateStateFactory(NullableBooleanState.class, classLoader);
stateSerializer = new NullableBooleanStateSerializer(type);
stateInterface = NullableBooleanState.class;
inputFunction = BOOLEAN_INPUT_FUNCTION;
outputFunction = BOOLEAN_OUTPUT_FUNCTION;
}
else {
throw new PrestoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, "Argument type to max/min unsupported");
}

inputFunction = inputFunction.bindTo(compareMethodHandle);
outputFunction = outputFunction.bindTo(type);

AccumulatorStateFactory<?> stateFactory = compiler.generateStateFactory(stateInterface, classLoader);
AccumulatorStateSerializer<?> stateSerializer = compiler.generateStateSerializer(stateInterface, classLoader);

Type intermediateType = stateSerializer.getSerializedType();
List<ParameterMetadata> inputParameterMetadata = createInputParameterMetadata(type);
Expand All @@ -141,7 +142,7 @@ else if (type.getJavaType() == boolean.class) {
inputParameterMetadata,
inputFunction,
null,
null,
outputFunction,
stateInterface,
stateSerializer,
stateFactory,
Expand Down
Expand Up @@ -62,7 +62,6 @@
import static com.facebook.presto.sql.gen.CompilerUtils.defineClass;
import static com.facebook.presto.sql.gen.CompilerUtils.makeClassName;
import static com.facebook.presto.sql.gen.SqlTypeByteCodeExpression.constantType;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;

Expand Down Expand Up @@ -152,10 +151,10 @@ private static <T> Class<? extends T> generateAccumulatorClass(
}

if (grouped) {
generateGroupedEvaluateFinal(definition, confidenceField, stateSerializerField, stateField, metadata.getOutputFunction(), metadata.isApproximate(), callSiteBinder);
generateGroupedEvaluateFinal(definition, confidenceField, stateField, metadata.getOutputFunction(), metadata.isApproximate(), callSiteBinder);
}
else {
generateEvaluateFinal(definition, confidenceField, stateSerializerField, stateField, metadata.getOutputFunction(), metadata.isApproximate(), callSiteBinder);
generateEvaluateFinal(definition, confidenceField, stateField, metadata.getOutputFunction(), metadata.isApproximate(), callSiteBinder);
}

return defineClass(definition, accumulatorInterface, callSiteBinder.getBindings(), classLoader);
Expand Down Expand Up @@ -648,9 +647,8 @@ private static void generateEvaluateIntermediate(ClassDefinition definition, Fie
private static void generateGroupedEvaluateFinal(
ClassDefinition definition,
FieldDefinition confidenceField,
FieldDefinition stateSerializerField,
FieldDefinition stateField,
@Nullable MethodHandle outputFunction,
MethodHandle outputFunction,
boolean approximate,
CallSiteBinder callSiteBinder)
{
Expand All @@ -665,30 +663,22 @@ private static void generateGroupedEvaluateFinal(

body.append(state.invoke("setGroupId", void.class, groupId.cast(long.class)));

if (outputFunction != null) {
body.comment("output(state, out)");
body.append(state);
if (approximate) {
checkNotNull(confidenceField, "confidenceField is null");
body.append(thisVariable.getField(confidenceField));
}
body.append(out);
body.append(invoke(callSiteBinder.bind(outputFunction), "output"));
}
else {
checkArgument(!approximate, "Approximate aggregations must specify an output function");
ByteCodeExpression stateSerializer = thisVariable.getField(stateSerializerField);
body.append(stateSerializer.invoke("serialize", void.class, state.cast(Object.class), out));
body.comment("output(state, out)");
body.append(state);
if (approximate) {
checkNotNull(confidenceField, "confidenceField is null");
body.append(thisVariable.getField(confidenceField));
}
body.append(out);
body.append(invoke(callSiteBinder.bind(outputFunction), "output"));

body.ret();
}

private static void generateEvaluateFinal(
ClassDefinition definition,
FieldDefinition confidenceField,
FieldDefinition stateSerializerField,
FieldDefinition stateField,
@Nullable
MethodHandle outputFunction,
boolean approximate,
CallSiteBinder callSiteBinder)
Expand All @@ -705,21 +695,15 @@ private static void generateEvaluateFinal(

ByteCodeExpression state = thisVariable.getField(stateField);

if (outputFunction != null) {
body.comment("output(state, out)");
body.append(state);
if (approximate) {
checkNotNull(confidenceField, "confidenceField is null");
body.append(thisVariable.getField(confidenceField));
}
body.append(out);
body.append(invoke(callSiteBinder.bind(outputFunction), "output"));
}
else {
checkArgument(!approximate, "Approximate aggregations must specify an output function");
ByteCodeExpression stateSerializer = thisVariable.getField(stateSerializerField);
body.append(stateSerializer.invoke("serialize", void.class, state.cast(Object.class), out));
body.comment("output(state, out)");
body.append(state);
if (approximate) {
checkNotNull(confidenceField, "confidenceField is null");
body.append(thisVariable.getField(confidenceField));
}
body.append(out);
body.append(invoke(callSiteBinder.bind(outputFunction), "output"));

body.ret();
}

Expand Down
Expand Up @@ -32,7 +32,6 @@
import java.lang.invoke.MethodHandle;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
Expand Down Expand Up @@ -210,11 +209,7 @@ private static List<Method> getOutputFunctions(Class<?> clazz, Class<?> stateCla
.filter(method -> method.getParameterTypes()[0] == stateClass)
.collect(toImmutableList());

if (methods.isEmpty()) {
List<Method> noOutputFunction = new ArrayList<>();
noOutputFunction.add(null);
return noOutputFunction;
}
checkArgument(!methods.isEmpty(), "Aggregation has no output functions");
return methods;
}

Expand Down
Expand Up @@ -56,7 +56,6 @@ public class AggregationMetadata
private final MethodHandle intermediateInputFunction;
@Nullable
private final MethodHandle combineFunction;
@Nullable
private final MethodHandle outputFunction;
private final AccumulatorStateSerializer<?> stateSerializer;
private final AccumulatorStateFactory<?> stateFactory;
Expand All @@ -70,7 +69,7 @@ public AggregationMetadata(
@Nullable List<ParameterMetadata> intermediateInputMetadata,
@Nullable MethodHandle intermediateInputFunction,
@Nullable MethodHandle combineFunction,
@Nullable MethodHandle outputFunction,
MethodHandle outputFunction,
Class<?> stateInterface,
AccumulatorStateSerializer<?> stateSerializer,
AccumulatorStateFactory<?> stateFactory,
Expand All @@ -92,7 +91,7 @@ public AggregationMetadata(
checkArgument(combineFunction != null || intermediateInputFunction != null, "Aggregation must have either a combine or a intermediate input method");
this.intermediateInputFunction = intermediateInputFunction;
this.combineFunction = combineFunction;
this.outputFunction = outputFunction;
this.outputFunction = checkNotNull(outputFunction, "outputFunction is null");
this.stateSerializer = checkNotNull(stateSerializer, "stateSerializer is null");
this.stateFactory = checkNotNull(stateFactory, "stateFactory is null");
this.approximate = approximate;
Expand Down Expand Up @@ -150,7 +149,6 @@ public MethodHandle getCombineFunction()
return combineFunction;
}

@Nullable
public MethodHandle getOutputFunction()
{
return outputFunction;
Expand Down
Expand Up @@ -14,6 +14,8 @@
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.operator.aggregation.state.TriStateBooleanState;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.type.SqlType;

Expand Down Expand Up @@ -41,4 +43,10 @@ public static void booleanAnd(TriStateBooleanState state, @SqlType(StandardTypes
}
}
}

/**
 * Output function for this aggregation: writes the result held in {@code state}
 * to {@code out} as a BOOLEAN value.
 *
 * @param state the accumulated tri-state boolean state to emit
 * @param out the block builder that receives the final value
 */
@OutputFunction(StandardTypes.BOOLEAN)
public static void output(TriStateBooleanState state, BlockBuilder out)
{
// Delegate to the state's type-aware writer, passing the SQL type explicitly.
TriStateBooleanState.write(BooleanType.BOOLEAN, state, out);
}
}
Expand Up @@ -14,6 +14,8 @@
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.operator.aggregation.state.TriStateBooleanState;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.type.SqlType;

Expand Down Expand Up @@ -41,4 +43,10 @@ public static void booleanOr(TriStateBooleanState state, @SqlType(StandardTypes.
}
}
}

/**
 * Output function for this aggregation: writes the result held in {@code state}
 * to {@code out} as a BOOLEAN value.
 *
 * @param state the accumulated tri-state boolean state to emit
 * @param out the block builder that receives the final value
 */
@OutputFunction(StandardTypes.BOOLEAN)
public static void output(TriStateBooleanState state, BlockBuilder out)
{
// Delegate to the state's type-aware writer, passing the SQL type explicitly.
TriStateBooleanState.write(BooleanType.BOOLEAN, state, out);
}
}
Expand Up @@ -14,6 +14,10 @@
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.operator.aggregation.state.LongState;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.StandardTypes;

import static com.facebook.presto.spi.type.BigintType.BIGINT;

@AggregationFunction("count")
public final class CountAggregation
Expand All @@ -35,4 +39,10 @@ public static void combine(LongState state, LongState otherState)
{
state.setLong(state.getLong() + otherState.getLong());
}

/**
 * Output function for {@code count}: emits the accumulated count as a BIGINT.
 *
 * @param state the state holding the running count
 * @param out the block builder that receives the final value
 */
@OutputFunction(StandardTypes.BIGINT)
public static void output(LongState state, BlockBuilder out)
{
    long count = state.getLong();
    BIGINT.writeLong(out, count);
}
}
Expand Up @@ -23,6 +23,7 @@
import com.facebook.presto.operator.aggregation.state.LongState;
import com.facebook.presto.operator.aggregation.state.StateCompiler;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
Expand Down Expand Up @@ -50,6 +51,7 @@ public class CountColumn
private static final Signature SIGNATURE = new Signature(NAME, ImmutableList.of(typeParameter("T")), StandardTypes.BIGINT, ImmutableList.of("T"), false, false);
private static final MethodHandle INPUT_FUNCTION = methodHandle(CountColumn.class, "input", LongState.class, Block.class, int.class);
private static final MethodHandle COMBINE_FUNCTION = methodHandle(CountColumn.class, "combine", LongState.class, LongState.class);
private static final MethodHandle OUTPUT_FUNCTION = methodHandle(CountColumn.class, "output", LongState.class, BlockBuilder.class);

@Override
public Signature getSignature()
Expand Down Expand Up @@ -89,7 +91,7 @@ private static InternalAggregationFunction generateAggregation(Type type)
null,
null,
COMBINE_FUNCTION,
null,
OUTPUT_FUNCTION,
LongState.class,
stateSerializer,
stateFactory,
Expand All @@ -114,4 +116,9 @@ public static void combine(LongState state, LongState otherState)
{
state.setLong(state.getLong() + otherState.getLong());
}

/**
 * Emits the accumulated count held in {@code state} as a BIGINT.
 *
 * @param state the state holding the running count
 * @param out the block builder that receives the final value
 */
public static void output(LongState state, BlockBuilder out)
{
    long count = state.getLong();
    BIGINT.writeLong(out, count);
}
}
Expand Up @@ -14,9 +14,12 @@
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.operator.aggregation.state.LongState;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.type.SqlType;

import static com.facebook.presto.spi.type.BigintType.BIGINT;

@AggregationFunction("count_if")
public final class CountIfAggregation
{
Expand All @@ -35,4 +38,10 @@ public static void combine(LongState state, LongState otherState)
{
state.setLong(state.getLong() + otherState.getLong());
}

/**
 * Output function for {@code count_if}: emits the accumulated count as a BIGINT.
 *
 * @param state the state holding the running count
 * @param out the block builder that receives the final value
 */
@OutputFunction(StandardTypes.BIGINT)
public static void output(LongState state, BlockBuilder out)
{
    long matched = state.getLong();
    BIGINT.writeLong(out, matched);
}
}
Expand Up @@ -14,6 +14,8 @@
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.operator.aggregation.state.NullableDoubleState;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.DoubleType;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.type.SqlType;

Expand All @@ -31,4 +33,10 @@ public static void sum(NullableDoubleState state, @SqlType(StandardTypes.DOUBLE)
state.setNull(false);
state.setDouble(state.getDouble() + value);
}

/**
 * Output function for this aggregation: writes the result held in {@code state}
 * to {@code out} as a DOUBLE value.
 *
 * @param state the accumulated nullable double state to emit
 * @param out the block builder that receives the final value
 */
@OutputFunction(StandardTypes.DOUBLE)
public static void output(NullableDoubleState state, BlockBuilder out)
{
// Delegate to the state's type-aware writer, passing the SQL type explicitly.
NullableDoubleState.write(DoubleType.DOUBLE, state, out);
}
}

0 comments on commit 04d9050

Please sign in to comment.