diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ReduceAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ReduceAggregationFunction.java index 4725fd8c87cfa3..f28e8752f7a9e1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ReduceAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ReduceAggregationFunction.java @@ -14,6 +14,7 @@ package io.trino.operator.aggregation; import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; import io.trino.metadata.SqlAggregationFunction; import io.trino.operator.aggregation.state.GenericBooleanState; import io.trino.operator.aggregation.state.GenericBooleanStateSerializer; @@ -21,6 +22,8 @@ import io.trino.operator.aggregation.state.GenericDoubleStateSerializer; import io.trino.operator.aggregation.state.GenericLongState; import io.trino.operator.aggregation.state.GenericLongStateSerializer; +import io.trino.operator.aggregation.state.GenericSliceState; +import io.trino.operator.aggregation.state.GenericSliceStateSerializer; import io.trino.operator.aggregation.state.StateCompiler; import io.trino.spi.TrinoException; import io.trino.spi.block.BlockBuilder; @@ -51,14 +54,17 @@ public class ReduceAggregationFunction private static final MethodHandle LONG_STATE_INPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "input", GenericLongState.class, Object.class, long.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class); private static final MethodHandle DOUBLE_STATE_INPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "input", GenericDoubleState.class, Object.class, double.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class); private static final MethodHandle BOOLEAN_STATE_INPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "input", GenericBooleanState.class, Object.class, boolean.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class); + private static final MethodHandle SLICE_STATE_INPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "input", GenericSliceState.class, Object.class, Slice.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class); private static final MethodHandle LONG_STATE_COMBINE_FUNCTION = methodHandle(ReduceAggregationFunction.class, "combine", GenericLongState.class, GenericLongState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class); private static final MethodHandle DOUBLE_STATE_COMBINE_FUNCTION = methodHandle(ReduceAggregationFunction.class, "combine", GenericDoubleState.class, GenericDoubleState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class); private static final MethodHandle BOOLEAN_STATE_COMBINE_FUNCTION = methodHandle(ReduceAggregationFunction.class, "combine", GenericBooleanState.class, GenericBooleanState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class); + private static final MethodHandle SLICE_STATE_COMBINE_FUNCTION = methodHandle(ReduceAggregationFunction.class, "combine", GenericSliceState.class, GenericSliceState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class); private static final MethodHandle LONG_STATE_OUTPUT_FUNCTION = methodHandle(GenericLongState.class, "write", Type.class, GenericLongState.class, BlockBuilder.class); private static final MethodHandle DOUBLE_STATE_OUTPUT_FUNCTION = methodHandle(GenericDoubleState.class, "write", Type.class, GenericDoubleState.class, BlockBuilder.class); private static final MethodHandle BOOLEAN_STATE_OUTPUT_FUNCTION = methodHandle(GenericBooleanState.class, "write", Type.class, GenericBooleanState.class, BlockBuilder.class); + private static final MethodHandle SLICE_STATE_OUTPUT_FUNCTION = methodHandle(GenericSliceState.class, "write", Type.class, GenericSliceState.class, BlockBuilder.class); public ReduceAggregationFunction() { @@ -122,7 +128,21 @@ public AggregationImplementation specialize(BoundSignature boundSignature) .lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class) .build(); } - // State with Slice or Block as native container type is intentionally not supported yet, + + if (stateType.getJavaType() == Slice.class) { + return AggregationImplementation.builder() + .inputFunction(normalizeInputMethod(boundSignature, inputType, SLICE_STATE_INPUT_FUNCTION)) + .combineFunction(SLICE_STATE_COMBINE_FUNCTION) + .outputFunction(SLICE_STATE_OUTPUT_FUNCTION.bindTo(stateType)) + .accumulatorStateDescriptor( + GenericSliceState.class, + new GenericSliceStateSerializer(stateType), + StateCompiler.generateStateFactory(GenericSliceState.class)) + .lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class) + .build(); + } + + // State with Block as native container type is intentionally not supported yet, // as it may result in excessive JVM memory usage of remembered set. // See JDK-8017163. throw new TrinoException(NOT_SUPPORTED, format("State type not supported for %s: %s", NAME, stateType.getDisplayName())); @@ -162,6 +182,15 @@ public static void input(GenericBooleanState state, Object value, boolean initia state.setValue((boolean) inputFunction.apply(state.getValue(), value)); } + public static void input(GenericSliceState state, Object value, Slice initialStateValue, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) + { + if (state.isNull()) { + state.setNull(false); + state.setValue(initialStateValue); + } + state.setValue((Slice) inputFunction.apply(state.getValue(), value)); + } + public static void combine(GenericLongState state, GenericLongState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) { if (state.isNull()) { @@ -188,4 +217,13 @@ public static void combine(GenericBooleanState state, GenericBooleanState otherS } state.setValue((boolean) combineFunction.apply(state.getValue(), otherState.getValue())); } + + public static void combine(GenericSliceState state, GenericSliceState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) + { + if (state.isNull()) { + state.set(otherState); + return; + } + state.setValue((Slice) combineFunction.apply(state.getValue(), otherState.getValue())); + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/GenericSliceState.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/GenericSliceState.java new file mode 100644 index 00000000000000..c09bef024c0593 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/GenericSliceState.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.state; + +import io.airlift.slice.Slice; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateMetadata; +import io.trino.spi.type.Type; + +@AccumulatorStateMetadata(stateSerializerClass = GenericSliceStateSerializer.class) +public interface GenericSliceState + extends AccumulatorState +{ + Slice getValue(); + + void setValue(Slice value); + + @InitialBooleanValue(true) + boolean isNull(); + + void setNull(boolean value); + + default void set(GenericSliceState state) + { + setValue(state.getValue()); + setNull(state.isNull()); + } + + static void write(Type type, GenericSliceState state, BlockBuilder out) + { + if (state.isNull()) { + out.appendNull(); + } + else { + type.writeSlice(out, state.getValue()); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/GenericSliceStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/GenericSliceStateSerializer.java new file mode 100644 index 00000000000000..d1cd8f7ff2db54 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/GenericSliceStateSerializer.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.state; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.type.Type; + +import static java.util.Objects.requireNonNull; + +public class GenericSliceStateSerializer + implements AccumulatorStateSerializer +{ + private final Type serializedType; + + public GenericSliceStateSerializer(Type serializedType) + { + this.serializedType = requireNonNull(serializedType, "serializedType is null"); + } + + @Override + public Type getSerializedType() + { + return serializedType; + } + + @Override + public void serialize(GenericSliceState state, BlockBuilder out) + { + if (state.isNull()) { + out.appendNull(); + } + else { + serializedType.writeSlice(out, state.getValue()); + } + } + + @Override + public void deserialize(Block block, int index, GenericSliceState state) + { + state.setNull(false); + state.setValue(serializedType.getSlice(block, index)); + } +} diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java index 734cd082242f29..eff1a5922ca348 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java @@ -2219,6 +2219,11 @@ public void testReduceAgg() "FROM (VALUES (1, CAST(5 AS DOUBLE)), (1, 6), (1, 7), (2, 8), (2, 9), (3, 10)) AS t(x, y) " + "GROUP BY x", "VALUES (1, CAST(5 AS DOUBLE) + 6 + 7), (2, 8 + 9), (3, 10)"); + assertQuery( + "SELECT x, reduce_agg(y, '', (a, b) -> a || b, (a, b) -> a || b) " + + "FROM (VALUES ('1', '5'), ('1', '6'), ('1', '7'), ('2', '8'), ('2', '9'), ('3', '10')) AS t(x, y) " + + "GROUP BY x", + "VALUES ('1', '567'), ('2', '89'), ('3', '10')"); } @Test