Skip to content

Commit

Permalink
Support string joining in SQL via LISTAGG clause
Browse files Browse the repository at this point in the history
LISTAGG joins the varchar values from a group of rows
using a configurable delimiter.

The syntax tree should represent the structure of the
original parsed query as closely as possible and any semantic interpretation
should be part of the analysis/planning phase.
In the case of the `LISTAGG` aggregation function, though, it is more pragmatic
for now to create a synthetic FunctionCall expression while parsing
the syntax tree.
  • Loading branch information
findinpath authored and martint committed Sep 8, 2021
1 parent 253a6b0 commit 4c321a9
Show file tree
Hide file tree
Showing 23 changed files with 1,844 additions and 9 deletions.
Expand Up @@ -287,6 +287,7 @@
import static io.trino.operator.aggregation.RealAverageAggregation.REAL_AVERAGE_AGGREGATION;
import static io.trino.operator.aggregation.ReduceAggregationFunction.REDUCE_AGG;
import static io.trino.operator.aggregation.arrayagg.ArrayAggregationFunction.ARRAY_AGG;
import static io.trino.operator.aggregation.listagg.ListaggAggregationFunction.LISTAGG;
import static io.trino.operator.aggregation.minmaxby.MaxByAggregationFunction.MAX_BY;
import static io.trino.operator.aggregation.minmaxby.MinByAggregationFunction.MIN_BY;
import static io.trino.operator.scalar.ArrayConcatFunction.ARRAY_CONCAT_FUNCTION;
Expand Down Expand Up @@ -563,6 +564,7 @@ public FunctionRegistry(
.function(ARRAY_CONCAT_FUNCTION)
.functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, JSON_TO_ARRAY, JSON_STRING_TO_ARRAY)
.function(ARRAY_AGG)
.function(LISTAGG)
.functions(new MapSubscriptOperator())
.functions(MAP_CONSTRUCTOR, JSON_TO_MAP, JSON_STRING_TO_MAP)
.functions(new MapAggregationFunction(blockTypeOperators), new MapUnionAggregation(blockTypeOperators))
Expand Down
Expand Up @@ -158,7 +158,9 @@ public void forEach(T consumer)
short currentBlockId = headBlockIndex.get(getGroupId());
int currentPosition = headPosition.get(getGroupId());
while (currentBlockId != NULL) {
accept(consumer, values.get(currentBlockId), currentPosition);
if (!accept(consumer, values.get(currentBlockId), currentPosition)) {
break;
}

long absoluteCurrentAddress = toAbsolutePosition(currentBlockId, currentPosition);
currentBlockId = nextBlockIndex.get(absoluteCurrentAddress);
Expand All @@ -181,5 +183,5 @@ private long toAbsolutePosition(short blockId, int position)
return sumPositions.get(blockId) + position;
}

protected abstract void accept(T consumer, PageBuilder pageBuilder, int currentPosition);
protected abstract boolean accept(T consumer, PageBuilder pageBuilder, int currentPosition);
}
Expand Up @@ -39,8 +39,9 @@ public final void add(Block block, int position)
}

@Override
protected final void accept(ArrayAggregationStateConsumer consumer, PageBuilder pageBuilder, int currentPosition)
protected final boolean accept(ArrayAggregationStateConsumer consumer, PageBuilder pageBuilder, int currentPosition)
{
consumer.accept(pageBuilder.getBlockBuilder(VALUE_CHANNEL), currentPosition);
return true;
}
}
@@ -0,0 +1,101 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation.listagg;

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.trino.operator.aggregation.AbstractGroupCollectionAggregationState;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.Block;
import io.trino.spi.type.Type;

/**
 * Grouped {@link ListaggAggregationState} implementation.
 *
 * <p>Collected values are stored through the inherited
 * {@link AbstractGroupCollectionAggregationState} machinery in a single-channel
 * {@link PageBuilder}; the LISTAGG options (separator, overflow behavior) are kept
 * as plain fields on this object.
 *
 * <p>NOTE(review): the option fields are stored once per state object, not per
 * group — presumably they are identical for every group of one aggregation; confirm
 * against the caller.
 */
public final class GroupListaggAggregationState
        extends AbstractGroupCollectionAggregationState<ListaggAggregationStateConsumer>
        implements ListaggAggregationState
{
    /** Index of the only channel in the backing page builder. */
    private static final int VALUE_CHANNEL = 0;

    /** Limit handed to {@link PageBuilder#withMaxPageSize} (presumably bytes, i.e. 1 MiB — confirm). */
    private static final int MAX_BLOCK_SIZE = 1024 * 1024;

    // LISTAGG options; all start out at their Java defaults (null / false) until set.
    private Slice separator;
    private Slice overflowFiller;
    private boolean overflowError;
    private boolean showOverflowEntryCount;

    /**
     * Creates a state whose backing page builder has one channel of the given type.
     *
     * @param valueType type of the values that will be collected
     */
    public GroupListaggAggregationState(Type valueType)
    {
        super(PageBuilder.withMaxPageSize(MAX_BLOCK_SIZE, ImmutableList.of(valueType)));
    }

    // ---- collected values -------------------------------------------------

    /**
     * Appends the value at {@code position} of {@code block} to the value channel.
     *
     * @param block block holding the value to append
     * @param position position of the value inside {@code block}
     */
    @Override
    public final void add(Block block, int position)
    {
        // Both helpers are inherited; prepareAdd() must run before the append.
        prepareAdd();
        appendAtChannel(VALUE_CHANNEL, block, position);
    }

    /**
     * Hands one stored value to the consumer.
     *
     * @param consumer receiver of the value
     * @param pageBuilder page builder that holds the stored value
     * @param currentPosition position of the value inside the value channel
     * @return always {@code true}; this implementation never asks the caller to stop iterating
     */
    @Override
    protected final boolean accept(ListaggAggregationStateConsumer consumer, PageBuilder pageBuilder, int currentPosition)
    {
        consumer.accept(pageBuilder.getBlockBuilder(VALUE_CHANNEL), currentPosition);
        return true;
    }

    // ---- LISTAGG options --------------------------------------------------

    /** @return the separator last set, or {@code null} if none was set */
    @Override
    public Slice getSeparator()
    {
        return separator;
    }

    /** @param separator separator to remember for this state */
    @Override
    public void setSeparator(Slice separator)
    {
        this.separator = separator;
    }

    /** @return the overflow filler last set, or {@code null} if none was set */
    @Override
    public Slice getOverflowFiller()
    {
        return overflowFiller;
    }

    /** @param overflowFiller overflow filler to remember for this state */
    @Override
    public void setOverflowFiller(Slice overflowFiller)
    {
        this.overflowFiller = overflowFiller;
    }

    /** @return the overflow-error flag last set ({@code false} by default) */
    @Override
    public boolean isOverflowError()
    {
        return overflowError;
    }

    /** @param overflowError new value of the overflow-error flag */
    @Override
    public void setOverflowError(boolean overflowError)
    {
        this.overflowError = overflowError;
    }

    /** @return the show-overflow-entry-count flag last set ({@code false} by default) */
    @Override
    public boolean showOverflowEntryCount()
    {
        return showOverflowEntryCount;
    }

    /** @param showOverflowEntryCount new value of the show-overflow-entry-count flag */
    @Override
    public void setShowOverflowEntryCount(boolean showOverflowEntryCount)
    {
        this.showOverflowEntryCount = showOverflowEntryCount;
    }
}

0 comments on commit 4c321a9

Please sign in to comment.