Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Slice support to reduce_agg function #20452

Merged
merged 1 commit into from Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -14,13 +14,16 @@
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.state.GenericBooleanState;
import io.trino.operator.aggregation.state.GenericBooleanStateSerializer;
import io.trino.operator.aggregation.state.GenericDoubleState;
import io.trino.operator.aggregation.state.GenericDoubleStateSerializer;
import io.trino.operator.aggregation.state.GenericLongState;
import io.trino.operator.aggregation.state.GenericLongStateSerializer;
import io.trino.operator.aggregation.state.GenericSliceState;
import io.trino.operator.aggregation.state.GenericSliceStateSerializer;
import io.trino.operator.aggregation.state.StateCompiler;
import io.trino.spi.TrinoException;
import io.trino.spi.block.BlockBuilder;
Expand Down Expand Up @@ -51,14 +54,17 @@ public class ReduceAggregationFunction
// Handles to the static input(...) overloads of this class, one per state container
// (long, double, boolean, Slice). Every overload takes the state, the input value as Object,
// a value of the state's native Java type, and two BinaryFunctionInterface lambdas.
// NOTE(review): methodHandle(...) is a helper defined outside this view; presumably a reflective lookup by name and parameter types — confirm.
private static final MethodHandle LONG_STATE_INPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "input", GenericLongState.class, Object.class, long.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
private static final MethodHandle DOUBLE_STATE_INPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "input", GenericDoubleState.class, Object.class, double.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
private static final MethodHandle BOOLEAN_STATE_INPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "input", GenericBooleanState.class, Object.class, boolean.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
private static final MethodHandle SLICE_STATE_INPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "input", GenericSliceState.class, Object.class, Slice.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);

// Handles to the static combine(...) overloads of this class; each takes two states of the
// same kind followed by the same two BinaryFunctionInterface lambdas.
private static final MethodHandle LONG_STATE_COMBINE_FUNCTION = methodHandle(ReduceAggregationFunction.class, "combine", GenericLongState.class, GenericLongState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
private static final MethodHandle DOUBLE_STATE_COMBINE_FUNCTION = methodHandle(ReduceAggregationFunction.class, "combine", GenericDoubleState.class, GenericDoubleState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
private static final MethodHandle BOOLEAN_STATE_COMBINE_FUNCTION = methodHandle(ReduceAggregationFunction.class, "combine", GenericBooleanState.class, GenericBooleanState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
private static final MethodHandle SLICE_STATE_COMBINE_FUNCTION = methodHandle(ReduceAggregationFunction.class, "combine", GenericSliceState.class, GenericSliceState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);

// Handles to the write(Type, state, BlockBuilder) method looked up on each state type.
// The leading Type parameter is left unbound here; specialize() binds it with bindTo(stateType)
// for the Slice variant before registering the handle as the output function.
private static final MethodHandle LONG_STATE_OUTPUT_FUNCTION = methodHandle(GenericLongState.class, "write", Type.class, GenericLongState.class, BlockBuilder.class);
private static final MethodHandle DOUBLE_STATE_OUTPUT_FUNCTION = methodHandle(GenericDoubleState.class, "write", Type.class, GenericDoubleState.class, BlockBuilder.class);
private static final MethodHandle BOOLEAN_STATE_OUTPUT_FUNCTION = methodHandle(GenericBooleanState.class, "write", Type.class, GenericBooleanState.class, BlockBuilder.class);
private static final MethodHandle SLICE_STATE_OUTPUT_FUNCTION = methodHandle(GenericSliceState.class, "write", Type.class, GenericSliceState.class, BlockBuilder.class);

public ReduceAggregationFunction()
{
Expand Down Expand Up @@ -122,7 +128,21 @@ public AggregationImplementation specialize(BoundSignature boundSignature)
.lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class)
.build();
}
// State with Slice or Block as native container type is intentionally not supported yet,

if (stateType.getJavaType() == Slice.class) {
return AggregationImplementation.builder()
.inputFunction(normalizeInputMethod(boundSignature, inputType, SLICE_STATE_INPUT_FUNCTION))
.combineFunction(SLICE_STATE_COMBINE_FUNCTION)
.outputFunction(SLICE_STATE_OUTPUT_FUNCTION.bindTo(stateType))
.accumulatorStateDescriptor(
GenericSliceState.class,
new GenericSliceStateSerializer(stateType),
StateCompiler.generateStateFactory(GenericSliceState.class))
.lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class)
.build();
}

// State with Block as native container type is intentionally not supported yet,
// as it may result in excessive JVM memory usage of remembered set.
// See JDK-8017163.
throw new TrinoException(NOT_SUPPORTED, format("State type not supported for %s: %s", NAME, stateType.getDisplayName()));
Expand Down Expand Up @@ -162,6 +182,15 @@ public static void input(GenericBooleanState state, Object value, boolean initia
state.setValue((boolean) inputFunction.apply(state.getValue(), value));
}

/**
 * Folds one input value into a Slice-backed reduce_agg state.
 *
 * <p>If the state is still flagged as null, it is first marked non-null and seeded with
 * {@code initialStateValue}. The input lambda is then applied to the current state value and
 * {@code value}, and its result (cast to {@link Slice}) becomes the new state value.
 *
 * @param state accumulator state to update in place
 * @param value the input value handed to the input lambda as its second argument
 * @param initialStateValue seed used only while the state is still null
 * @param inputFunction lambda invoked as {@code apply(currentState, value)}
 * @param combineFunction not used by this method
 */
public static void input(GenericSliceState state, Object value, Slice initialStateValue, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction)
{
    boolean unseeded = state.isNull();
    if (unseeded) {
        // Seed the accumulator before the very first reduction step.
        state.setNull(false);
        state.setValue(initialStateValue);
    }

    Slice current = state.getValue();
    Slice reduced = (Slice) inputFunction.apply(current, value);
    state.setValue(reduced);
}

public static void combine(GenericLongState state, GenericLongState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction)
{
if (state.isNull()) {
Expand All @@ -188,4 +217,13 @@ public static void combine(GenericBooleanState state, GenericBooleanState otherS
}
state.setValue((boolean) combineFunction.apply(state.getValue(), otherState.getValue()));
}

/**
 * Merges {@code otherState} into {@code state} for Slice-backed reduce_agg states.
 *
 * <p>When {@code state} is flagged as null it simply takes over the contents of
 * {@code otherState} via {@link GenericSliceState#set}. Otherwise the combine lambda is applied
 * to both state values and its result (cast to {@link Slice}) is stored in {@code state}.
 *
 * @param state target state, updated in place
 * @param otherState state whose contents are merged into {@code state}
 * @param inputFunction not used by this method
 * @param combineFunction lambda invoked as {@code apply(state value, otherState value)}
 */
public static void combine(GenericSliceState state, GenericSliceState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction)
{
    if (!state.isNull()) {
        Slice left = state.getValue();
        Slice right = otherState.getValue();
        state.setValue((Slice) combineFunction.apply(left, right));
    }
    else {
        // Nothing accumulated on this side yet: adopt the other state wholesale.
        state.set(otherState);
    }
}
}
@@ -0,0 +1,50 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation.state;

import io.airlift.slice.Slice;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateMetadata;
import io.trino.spi.type.Type;

/**
 * Accumulator state holding a single {@link Slice} value together with an explicit null flag.
 * Serialization is delegated to {@link GenericSliceStateSerializer}, as declared in the
 * {@link AccumulatorStateMetadata} annotation.
 */
@AccumulatorStateMetadata(stateSerializerClass = GenericSliceStateSerializer.class)
public interface GenericSliceState
        extends AccumulatorState
{
    Slice getValue();

    void setValue(Slice value);

    // NOTE(review): @InitialBooleanValue(true) presumably makes a freshly created state report
    // isNull() == true — confirm against the state compiler.
    @InitialBooleanValue(true)
    boolean isNull();

    void setNull(boolean value);

    /**
     * Copies the value and the null flag of {@code state} into this state.
     */
    default void set(GenericSliceState state)
    {
        Slice copiedValue = state.getValue();
        setValue(copiedValue);
        boolean copiedNullFlag = state.isNull();
        setNull(copiedNullFlag);
    }

    /**
     * Appends the content of {@code state} to {@code out}: a null entry when the state is
     * flagged as null, otherwise the value written through {@code type.writeSlice}.
     */
    static void write(Type type, GenericSliceState state, BlockBuilder out)
    {
        if (state.isNull()) {
            out.appendNull();
            return;
        }
        type.writeSlice(out, state.getValue());
    }
}
@@ -0,0 +1,56 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation.state;

import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.Type;

import static java.util.Objects.requireNonNull;

/**
 * Serializer for {@link GenericSliceState} that reads and writes the state value through the
 * {@link Type} supplied at construction time.
 */
public class GenericSliceStateSerializer
        implements AccumulatorStateSerializer<GenericSliceState>
{
    private final Type serializedType;

    /**
     * @param serializedType type used to read and write the state value; must not be null
     */
    public GenericSliceStateSerializer(Type serializedType)
    {
        this.serializedType = requireNonNull(serializedType, "serializedType is null");
    }

    @Override
    public Type getSerializedType()
    {
        return this.serializedType;
    }

    /**
     * Appends the state to {@code out}: the value via {@code writeSlice} when the state is
     * non-null, a null entry otherwise.
     */
    @Override
    public void serialize(GenericSliceState state, BlockBuilder out)
    {
        if (!state.isNull()) {
            this.serializedType.writeSlice(out, state.getValue());
            return;
        }
        out.appendNull();
    }

    /**
     * Loads position {@code index} of {@code block} into {@code state} and marks the state
     * non-null. No null check is made on the position.
     * NOTE(review): presumably callers only pass non-null positions — confirm.
     */
    @Override
    public void deserialize(Block block, int index, GenericSliceState state)
    {
        // Clear the null flag first, then read and store the value (same order as before).
        state.setNull(false);
        state.setValue(this.serializedType.getSlice(block, index));
    }
}
2 changes: 1 addition & 1 deletion docs/src/main/sphinx/functions/aggregate.md
Expand Up @@ -634,5 +634,5 @@ GROUP BY id;
-- (2, 42)
```

The state type must be a boolean, integer, floating-point, or date/time/interval.
The state type must be a boolean, integer, floating-point, char, varchar, or date/time/interval.
:::
Expand Up @@ -2219,6 +2219,11 @@ public void testReduceAgg()
"FROM (VALUES (1, CAST(5 AS DOUBLE)), (1, 6), (1, 7), (2, 8), (2, 9), (3, 10)) AS t(x, y) " +
"GROUP BY x",
"VALUES (1, CAST(5 AS DOUBLE) + 6 + 7), (2, 8 + 9), (3, 10)");
assertQuery(
"SELECT x, reduce_agg(y, '', (a, b) -> a || b, (a, b) -> a || b) " +
"FROM (VALUES ('1', '5'), ('1', '6'), ('1', '7'), ('2', '8'), ('2', '9'), ('3', '10')) AS t(x, y) " +
"GROUP BY x",
"VALUES ('1', '567'), ('2', '89'), ('3', '10')");
}

@Test
Expand Down