Skip to content

Commit

Permalink
Add Slice support to reduce_agg function
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Jan 24, 2024
1 parent d92ad42 commit 314c85a
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 1 deletion.
Expand Up @@ -14,13 +14,16 @@
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.state.GenericBooleanState;
import io.trino.operator.aggregation.state.GenericBooleanStateSerializer;
import io.trino.operator.aggregation.state.GenericDoubleState;
import io.trino.operator.aggregation.state.GenericDoubleStateSerializer;
import io.trino.operator.aggregation.state.GenericLongState;
import io.trino.operator.aggregation.state.GenericLongStateSerializer;
import io.trino.operator.aggregation.state.GenericSliceState;
import io.trino.operator.aggregation.state.GenericSliceStateSerializer;
import io.trino.operator.aggregation.state.StateCompiler;
import io.trino.spi.TrinoException;
import io.trino.spi.block.BlockBuilder;
Expand Down Expand Up @@ -51,14 +54,17 @@ public class ReduceAggregationFunction
private static final MethodHandle LONG_STATE_INPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "input", GenericLongState.class, Object.class, long.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
private static final MethodHandle DOUBLE_STATE_INPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "input", GenericDoubleState.class, Object.class, double.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
private static final MethodHandle BOOLEAN_STATE_INPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "input", GenericBooleanState.class, Object.class, boolean.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
private static final MethodHandle SLICE_STATE_INPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "input", GenericSliceState.class, Object.class, Slice.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);

private static final MethodHandle LONG_STATE_COMBINE_FUNCTION = methodHandle(ReduceAggregationFunction.class, "combine", GenericLongState.class, GenericLongState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
private static final MethodHandle DOUBLE_STATE_COMBINE_FUNCTION = methodHandle(ReduceAggregationFunction.class, "combine", GenericDoubleState.class, GenericDoubleState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
private static final MethodHandle BOOLEAN_STATE_COMBINE_FUNCTION = methodHandle(ReduceAggregationFunction.class, "combine", GenericBooleanState.class, GenericBooleanState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
private static final MethodHandle SLICE_STATE_COMBINE_FUNCTION = methodHandle(ReduceAggregationFunction.class, "combine", GenericSliceState.class, GenericSliceState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);

private static final MethodHandle LONG_STATE_OUTPUT_FUNCTION = methodHandle(GenericLongState.class, "write", Type.class, GenericLongState.class, BlockBuilder.class);
private static final MethodHandle DOUBLE_STATE_OUTPUT_FUNCTION = methodHandle(GenericDoubleState.class, "write", Type.class, GenericDoubleState.class, BlockBuilder.class);
private static final MethodHandle BOOLEAN_STATE_OUTPUT_FUNCTION = methodHandle(GenericBooleanState.class, "write", Type.class, GenericBooleanState.class, BlockBuilder.class);
private static final MethodHandle SLICE_STATE_OUTPUT_FUNCTION = methodHandle(GenericSliceState.class, "write", Type.class, GenericSliceState.class, BlockBuilder.class);

public ReduceAggregationFunction()
{
Expand Down Expand Up @@ -122,7 +128,21 @@ public AggregationImplementation specialize(BoundSignature boundSignature)
.lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class)
.build();
}
// State with Slice or Block as native container type is intentionally not supported yet,

if (stateType.getJavaType() == Slice.class) {
return AggregationImplementation.builder()
.inputFunction(normalizeInputMethod(boundSignature, inputType, SLICE_STATE_INPUT_FUNCTION))
.combineFunction(SLICE_STATE_COMBINE_FUNCTION)
.outputFunction(SLICE_STATE_OUTPUT_FUNCTION.bindTo(stateType))
.accumulatorStateDescriptor(
GenericSliceState.class,
new GenericSliceStateSerializer(stateType),
StateCompiler.generateStateFactory(GenericSliceState.class))
.lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class)
.build();
}

// State with Block as native container type is intentionally not supported yet,
// as it may result in excessive JVM memory usage of remembered set.
// See JDK-8017163.
throw new TrinoException(NOT_SUPPORTED, format("State type not supported for %s: %s", NAME, stateType.getDisplayName()));
Expand Down Expand Up @@ -162,6 +182,15 @@ public static void input(GenericBooleanState state, Object value, boolean initia
state.setValue((boolean) inputFunction.apply(state.getValue(), value));
}

public static void input(GenericSliceState state, Object value, Slice initialStateValue, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction)
{
if (state.isNull()) {
state.setNull(false);
state.setValue(initialStateValue);
}
state.setValue((Slice) inputFunction.apply(state.getValue(), value));
}

public static void combine(GenericLongState state, GenericLongState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction)
{
if (state.isNull()) {
Expand All @@ -188,4 +217,13 @@ public static void combine(GenericBooleanState state, GenericBooleanState otherS
}
state.setValue((boolean) combineFunction.apply(state.getValue(), otherState.getValue()));
}

public static void combine(GenericSliceState state, GenericSliceState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction)
{
if (state.isNull()) {
state.set(otherState);
return;
}
state.setValue((Slice) combineFunction.apply(state.getValue(), otherState.getValue()));
}
}
@@ -0,0 +1,50 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation.state;

import io.airlift.slice.Slice;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateMetadata;
import io.trino.spi.type.Type;

@AccumulatorStateMetadata(stateSerializerClass = GenericSliceStateSerializer.class)
public interface GenericSliceState
extends AccumulatorState
{
Slice getValue();

void setValue(Slice value);

@InitialBooleanValue(true)
boolean isNull();

void setNull(boolean value);

default void set(GenericSliceState state)
{
setValue(state.getValue());
setNull(state.isNull());
}

static void write(Type type, GenericSliceState state, BlockBuilder out)
{
if (state.isNull()) {
out.appendNull();
}
else {
type.writeSlice(out, state.getValue());
}
}
}
@@ -0,0 +1,56 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation.state;

import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.Type;

import static java.util.Objects.requireNonNull;

public class GenericSliceStateSerializer
implements AccumulatorStateSerializer<GenericSliceState>
{
private final Type serializedType;

public GenericSliceStateSerializer(Type serializedType)
{
this.serializedType = requireNonNull(serializedType, "serializedType is null");
}

@Override
public Type getSerializedType()
{
return serializedType;
}

@Override
public void serialize(GenericSliceState state, BlockBuilder out)
{
if (state.isNull()) {
out.appendNull();
}
else {
serializedType.writeSlice(out, state.getValue());
}
}

@Override
public void deserialize(Block block, int index, GenericSliceState state)
{
state.setNull(false);
state.setValue(serializedType.getSlice(block, index));
}
}
Expand Up @@ -2219,6 +2219,11 @@ public void testReduceAgg()
"FROM (VALUES (1, CAST(5 AS DOUBLE)), (1, 6), (1, 7), (2, 8), (2, 9), (3, 10)) AS t(x, y) " +
"GROUP BY x",
"VALUES (1, CAST(5 AS DOUBLE) + 6 + 7), (2, 8 + 9), (3, 10)");
assertQuery(
"SELECT x, reduce_agg(y, '', (a, b) -> a || b, (a, b) -> a || b) " +
"FROM (VALUES ('1', '5'), ('1', '6'), ('1', '7'), ('2', '8'), ('2', '9'), ('3', '10')) AS t(x, y) " +
"GROUP BY x",
"VALUES ('1', '567'), ('2', '89'), ('3', '10')");
}

@Test
Expand Down

0 comments on commit 314c85a

Please sign in to comment.