Add execution support for multiple grouping sets
- Add a GroupIdOperator that, for every input row, emits rows for each
  grouping set specified. The row generated will contain NULLs in the
  channels corresponding to the columns present in the union of all
  grouping columns but not present in the current grouping set.
- Modify QueryPlanner to chain a GroupIdNode to an AggregationNode if
  multiple grouping sets are present.
- Add support for ALL and DISTINCT set quantifiers to GROUP BY.
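
To make the first bullet concrete, here is a minimal, self-contained sketch (plain Java, deliberately independent of the Presto operator APIs; the column names a and b, the grouping sets, and the class name are illustrative only) of how a single input row is expanded for GROUPING SETS ((a), (b)), with the groupId appended as a trailing column:

import java.util.Arrays;

public class GroupIdExpansionSketch
{
    public static void main(String[] args)
    {
        // one input row with two grouping columns: a (channel 0) and b (channel 1)
        Object[] row = {1L, 2L};
        // GROUPING SETS ((a), (b)) expressed as column channels
        int[][] groupingSets = {{0}, {1}};

        // channels in the union of all grouping columns
        boolean[] isGroupingColumn = new boolean[row.length];
        for (int[] groupingSet : groupingSets) {
            for (int channel : groupingSet) {
                isGroupingColumn[channel] = true;
            }
        }

        // for each grouping set, emit a copy of the row with NULL in every grouping
        // column that is not part of the current set, plus a trailing groupId column
        for (int groupId = 0; groupId < groupingSets.length; groupId++) {
            boolean[] keep = new boolean[row.length];
            for (int channel : groupingSets[groupId]) {
                keep[channel] = true;
            }
            Object[] output = new Object[row.length + 1];
            for (int channel = 0; channel < row.length; channel++) {
                output[channel] = (isGroupingColumn[channel] && !keep[channel]) ? null : row[channel];
            }
            output[row.length] = (long) groupId;
            System.out.println(Arrays.toString(output));
        }
        // prints: [1, null, 0] then [null, 2, 1]
    }
}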
Raghav Sethi committed Mar 4, 2016
1 parent e223a9e commit fd0a15d
Showing 33 changed files with 972 additions and 121 deletions.
com/facebook/presto/operator/GroupIdOperator.java (new file)
@@ -0,0 +1,225 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator;

import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.RunLengthEncodedBlock;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.google.common.collect.ImmutableList;

import java.util.BitSet;
import java.util.Collection;
import java.util.List;
import java.util.Set;

import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.util.ImmutableCollectors.toImmutableSet;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

public class GroupIdOperator
        implements Operator
{
    public static class GroupIdOperatorFactory
            implements OperatorFactory
    {
        private final int operatorId;
        private final PlanNodeId planNodeId;
        private final List<Type> inputTypes;
        private final List<Type> outputTypes;
        private final List<List<Integer>> groupingSetChannels;

        private boolean closed;

        public GroupIdOperatorFactory(
                int operatorId,
                PlanNodeId planNodeId,
                List<? extends Type> inputTypes,
                List<List<Integer>> groupingSetChannels)
        {
            this.operatorId = operatorId;
            this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
            this.groupingSetChannels = ImmutableList.copyOf(requireNonNull(groupingSetChannels));
            this.inputTypes = ImmutableList.copyOf(requireNonNull(inputTypes));

            // add the groupId channel to the output types
            this.outputTypes = ImmutableList.<Type>builder().addAll(inputTypes).add(BIGINT).build();
        }

        @Override
        public List<Type> getTypes()
        {
            return outputTypes;
        }

        @Override
        public Operator createOperator(DriverContext driverContext)
        {
            checkState(!closed, "Factory is already closed");
            OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, GroupIdOperator.class.getSimpleName());

            Set<Integer> allGroupingColumns = groupingSetChannels.stream()
                    .flatMap(Collection::stream)
                    .collect(toImmutableSet());

            // will have a 'true' for every channel that should be set to null for each grouping set
            BitSet[] groupingSetNullChannels = new BitSet[groupingSetChannels.size()];
            for (int i = 0; i < groupingSetChannels.size(); i++) {
                groupingSetNullChannels[i] = new BitSet(inputTypes.size());
                // first set all grouping columns to true
                for (int groupingColumn : allGroupingColumns) {
                    groupingSetNullChannels[i].set(groupingColumn, true);
                }
                // then set all the columns in this grouping set to false
                for (int nonNullGroupingColumn : groupingSetChannels.get(i)) {
                    groupingSetNullChannels[i].set(nonNullGroupingColumn, false);
                }
            }

            Block[] nullBlocks = new Block[inputTypes.size()];
            for (int i = 0; i < nullBlocks.length; i++) {
                nullBlocks[i] = inputTypes.get(i).createBlockBuilder(new BlockBuilderStatus(), 1)
                        .appendNull()
                        .build();
            }

            Block[] groupIdBlocks = new Block[groupingSetNullChannels.length];
            for (int i = 0; i < groupingSetNullChannels.length; i++) {
                BlockBuilder builder = BIGINT.createBlockBuilder(new BlockBuilderStatus(), 1);
                BIGINT.writeLong(builder, i);
                groupIdBlocks[i] = builder.build();
            }

            return new GroupIdOperator(operatorContext, outputTypes, groupingSetNullChannels, nullBlocks, groupIdBlocks);
        }

        @Override
        public void close()
        {
            closed = true;
        }

        @Override
        public OperatorFactory duplicate()
        {
            return new GroupIdOperatorFactory(operatorId, planNodeId, inputTypes, groupingSetChannels);
        }
    }

    private final OperatorContext operatorContext;
    private final List<Type> types;
    private final BitSet[] groupingSetNullChannels;
    private final Block[] nullBlocks;
    private final Block[] groupIdBlocks;

    private Page currentPage = null;
    private int currentGroupingSet = 0;
    private boolean finishing;

    public GroupIdOperator(
            OperatorContext operatorContext,
            List<Type> types,
            BitSet[] groupingSetNullChannels,
            Block[] nullBlocks,
            Block[] groupIdBlocks)
    {
        this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
        this.types = requireNonNull(types, "types is null");
        this.groupingSetNullChannels = requireNonNull(groupingSetNullChannels, "groupingSetNullChannels is null");
        this.nullBlocks = requireNonNull(nullBlocks);
        checkArgument(nullBlocks.length == (types.size() - 1), "length of nullBlocks must be one less than length of types");
        this.groupIdBlocks = requireNonNull(groupIdBlocks);
        checkArgument(groupIdBlocks.length == groupingSetNullChannels.length, "groupIdBlocks and groupingSetNullChannels must have the same length");
    }

    @Override
    public OperatorContext getOperatorContext()
    {
        return operatorContext;
    }

    @Override
    public List<Type> getTypes()
    {
        return types;
    }

    @Override
    public void finish()
    {
        finishing = true;
    }

    @Override
    public boolean isFinished()
    {
        return finishing && currentPage == null;
    }

    @Override
    public boolean needsInput()
    {
        return !finishing && currentPage == null;
    }

    @Override
    public void addInput(Page page)
    {
        checkState(!finishing, "Operator is already finishing");
        checkState(currentPage == null, "currentPage must be null to add a new page");

        currentPage = requireNonNull(page, "page is null");
    }

    @Override
    public Page getOutput()
    {
        if (currentPage == null) {
            return null;
        }

        return generateNextPage();
    }

    private Page generateNextPage()
    {
        // generate 'n' pages for every input page, where n is the number of grouping sets
        Block[] inputBlocks = currentPage.getBlocks();
        Block[] outputBlocks = new Block[currentPage.getChannelCount() + 1];

        for (int channel = 0; channel < currentPage.getChannelCount(); channel++) {
            if (groupingSetNullChannels[currentGroupingSet].get(channel)) {
                outputBlocks[channel] = new RunLengthEncodedBlock(nullBlocks[channel], currentPage.getPositionCount());
            }
            else {
                outputBlocks[channel] = inputBlocks[channel];
            }
        }

        outputBlocks[outputBlocks.length - 1] = new RunLengthEncodedBlock(groupIdBlocks[currentGroupingSet], currentPage.getPositionCount());
        currentGroupingSet = (currentGroupingSet + 1) % groupingSetNullChannels.length;
        Page outputPage = new Page(currentPage.getPositionCount(), outputBlocks);

        if (currentGroupingSet == 0) {
            currentPage = null;
        }

        return outputPage;
    }
}
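
A usage sketch for the factory above. This is hedged: the operator id, the plan node id string, the column types, and the class name are made up for illustration; only the GroupIdOperatorFactory constructor and getTypes() shown in this file are assumed. A plan with GROUPING SETS ((a), (b)) over two input channels would wire it roughly like this, and every operator the factory creates appends a BIGINT groupId channel:

import com.facebook.presto.operator.GroupIdOperator.GroupIdOperatorFactory;
import com.facebook.presto.operator.OperatorFactory;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.google.common.collect.ImmutableList;

import java.util.List;

import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.VarcharType.VARCHAR;

public class GroupIdFactoryUsageSketch
{
    public static void main(String[] args)
    {
        // input channels: 0 = a (VARCHAR), 1 = b (BIGINT)
        List<Type> inputTypes = ImmutableList.of(VARCHAR, BIGINT);

        // GROUPING SETS ((a), (b)) expressed as input channel indices
        List<List<Integer>> groupingSetChannels = ImmutableList.of(
                ImmutableList.of(0),
                ImmutableList.of(1));

        OperatorFactory factory = new GroupIdOperatorFactory(
                0,                         // operatorId (illustrative)
                new PlanNodeId("groupid"), // planNodeId (illustrative)
                inputTypes,
                groupingSetChannels);

        // the factory appends the BIGINT groupId channel to the output types;
        // prints something like [varchar, bigint, bigint]
        System.out.println(factory.getTypes());
    }
}

The trailing groupId column is presumably what lets the downstream aggregation keep rows that belong to different grouping sets apart once the planner chains a GroupIdNode under the AggregationNode.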
Another changed file (SQL analyzer):
@@ -61,6 +61,7 @@
 import com.facebook.presto.sql.tree.Expression;
 import com.facebook.presto.sql.tree.FrameBound;
 import com.facebook.presto.sql.tree.FunctionCall;
+import com.facebook.presto.sql.tree.GroupBy;
 import com.facebook.presto.sql.tree.GroupingElement;
 import com.facebook.presto.sql.tree.InPredicate;
 import com.facebook.presto.sql.tree.Insert;
@@ -111,6 +112,7 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableMultimap;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Multimap;
@@ -403,7 +405,7 @@ For example, a table with two partition keys (ds, cluster_name)
         Optional.of(logicalAnd(
                 equal(nameReference("table_schema"), new StringLiteral(table.getSchemaName())),
                 equal(nameReference("table_name"), new StringLiteral(table.getObjectName())))),
-        ImmutableList.of(new SimpleGroupBy(ImmutableList.of(nameReference("partition_number")))),
+        Optional.of(new GroupBy(false, ImmutableList.of(new SimpleGroupBy(ImmutableList.of(nameReference("partition_number")))))),
         Optional.empty(),
         ImmutableList.of(),
         Optional.empty());
@@ -412,7 +414,7 @@
         selectAll(wrappedList.build()),
         subquery(query),
         showPartitions.getWhere(),
-        ImmutableList.of(),
+        Optional.empty(),
         Optional.empty(),
         ImmutableList.<SortItem>builder()
                 .addAll(showPartitions.getOrderBy())
@@ -1515,30 +1517,20 @@ else if (expression instanceof LongLiteral) {
 
     private List<List<FieldOrExpression>> analyzeGroupBy(QuerySpecification node, RelationType tupleDescriptor, AnalysisContext context, List<FieldOrExpression> outputExpressions)
     {
-        List<Set<Set<Expression>>> enumeratedGroupingSets = node.getGroupBy().stream()
-                .map(GroupingElement::enumerateGroupingSets)
-                .distinct()
-                .collect(toImmutableList());
-
-        // compute cross product of enumerated grouping sets, if there are any
-        List<List<Expression>> computedGroupingSets = ImmutableList.of();
-        if (!enumeratedGroupingSets.isEmpty()) {
-            computedGroupingSets = Sets.cartesianProduct(enumeratedGroupingSets).stream()
-                    .map(groupingSetList -> groupingSetList.stream()
-                            .flatMap(Collection::stream)
-                            .distinct()
-                            .collect(toImmutableList()))
-                    .distinct()
-                    .collect(toImmutableList());
-        }
-
-        // if there are aggregates, but no grouping columns, create a grand total grouping set
-        if (computedGroupingSets.isEmpty() && !extractAggregates(node).isEmpty()) {
-            computedGroupingSets = ImmutableList.of(ImmutableList.of());
-        }
-
-        if (computedGroupingSets.size() > 1) {
-            throw new SemanticException(NOT_SUPPORTED, node, "Grouping by multiple sets of columns is not yet supported");
-        }
+        List<Set<Expression>> computedGroupingSets = ImmutableList.of(); // empty list = no aggregations
+
+        if (node.getGroupBy().isPresent()) {
+            List<List<Set<Expression>>> enumeratedGroupingSets = node.getGroupBy().get().getGroupingElements().stream()
+                    .map(GroupingElement::enumerateGroupingSets)
+                    .collect(toImmutableList());
+
+            // compute cross product of enumerated grouping sets, if there are any
+            computedGroupingSets = computeGroupingSetsCrossProduct(enumeratedGroupingSets, node.getGroupBy().get().isDistinct());
+            checkState(!computedGroupingSets.isEmpty(), "computed grouping sets cannot be empty");
+        }
+        else if (!extractAggregates(node).isEmpty()) {
+            // if there are aggregates, but no group by, create a grand total grouping set (global aggregation)
+            computedGroupingSets = ImmutableList.of(ImmutableSet.of());
+        }
 
         List<List<FieldOrExpression>> analyzedGroupingSets = computedGroupingSets.stream()
@@ -1549,7 +1541,39 @@ private List<List<FieldOrExpression>> analyzeGroupBy(QuerySpecification node, Re
         return analyzedGroupingSets;
     }
 
-    private List<FieldOrExpression> analyzeGroupingColumns(List<Expression> groupingColumns, QuerySpecification node, RelationType tupleDescriptor, AnalysisContext context, List<FieldOrExpression> outputExpressions)
+    private List<Set<Expression>> computeGroupingSetsCrossProduct(List<List<Set<Expression>>> enumeratedGroupingSets, boolean isDistinct)
+    {
+        checkState(!enumeratedGroupingSets.isEmpty(), "enumeratedGroupingSets cannot be empty");
+
+        List<Set<Expression>> groupingSetsCrossProduct = new ArrayList<>();
+        enumeratedGroupingSets.get(0)
+                .stream()
+                .map(ImmutableSet::copyOf)
+                .forEach(groupingSetsCrossProduct::add);
+
+        for (int i = 1; i < enumeratedGroupingSets.size(); i++) {
+            List<Set<Expression>> groupingSets = enumeratedGroupingSets.get(i);
+            List<Set<Expression>> oldGroupingSetsCrossProduct = ImmutableList.copyOf(groupingSetsCrossProduct);
+            groupingSetsCrossProduct.clear();
+            for (Set<Expression> existingSet : oldGroupingSetsCrossProduct) {
+                for (Set<Expression> groupingSet : groupingSets) {
+                    Set<Expression> concatenatedSet = ImmutableSet.<Expression>builder()
+                            .addAll(existingSet)
+                            .addAll(groupingSet)
+                            .build();
+                    groupingSetsCrossProduct.add(concatenatedSet);
+                }
+            }
+        }
+
+        if (isDistinct) {
+            return ImmutableList.copyOf(ImmutableSet.copyOf(groupingSetsCrossProduct));
+        }
+
+        return groupingSetsCrossProduct;
+    }
+
+    private List<FieldOrExpression> analyzeGroupingColumns(Set<Expression> groupingColumns, QuerySpecification node, RelationType tupleDescriptor, AnalysisContext context, List<FieldOrExpression> outputExpressions)
     {
         ImmutableList.Builder<FieldOrExpression> groupingColumnsBuilder = ImmutableList.builder();
         for (Expression groupingColumn : groupingColumns) {
@@ -1722,25 +1746,24 @@ private void analyzeAggregations(QuerySpecification node,
             }
         }
 
-        ImmutableList<FieldOrExpression> allGroupingColumns = groupingSets.stream()
-                .flatMap(Collection::stream)
-                .distinct()
-                .collect(toImmutableList());
-
         // is this an aggregation query?
         if (!groupingSets.isEmpty()) {
             // ensure SELECT, ORDER BY and HAVING are constant with respect to group
             // e.g, these are all valid expressions:
             //     SELECT f(a) GROUP BY a
             //     SELECT f(a + 1) GROUP BY a + 1
             //     SELECT a + sum(b) GROUP BY a
+            ImmutableList<FieldOrExpression> distinctGroupingColumns = groupingSets.stream()
+                    .flatMap(Collection::stream)
+                    .distinct()
+                    .collect(toImmutableList());
+
             for (FieldOrExpression fieldOrExpression : Iterables.concat(outputExpressions, orderByExpressions)) {
-                verifyAggregations(node, allGroupingColumns, tupleDescriptor, fieldOrExpression, columnReferences);
+                verifyAggregations(node, distinctGroupingColumns, tupleDescriptor, fieldOrExpression, columnReferences);
             }
 
             if (node.getHaving().isPresent()) {
-                verifyAggregations(node, allGroupingColumns, tupleDescriptor, new FieldOrExpression(node.getHaving().get()), columnReferences);
+                verifyAggregations(node, distinctGroupingColumns, tupleDescriptor, new FieldOrExpression(node.getHaving().get()), columnReferences);
             }
         }
     }
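
To make the cross-product step in the new computeGroupingSetsCrossProduct method concrete, here is a minimal sketch of the same idea using plain strings in place of Presto's Expression objects; the query GROUP BY a, CUBE (b, c) and the class name are hypothetical:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

public class GroupingSetsCrossProductSketch
{
    public static void main(String[] args)
    {
        // GROUP BY a, CUBE (b, c):
        // the simple column 'a' enumerates to the single grouping set {a};
        // CUBE (b, c) enumerates to {b, c}, {b}, {c}, {}
        List<List<Set<String>>> enumeratedGroupingSets = Arrays.asList(
                Arrays.asList(setOf("a")),
                Arrays.asList(setOf("b", "c"), setOf("b"), setOf("c"), setOf()));

        // cross product: union one alternative from each grouping element
        List<Set<String>> crossProduct = new ArrayList<>(enumeratedGroupingSets.get(0));
        for (int i = 1; i < enumeratedGroupingSets.size(); i++) {
            List<Set<String>> next = new ArrayList<>();
            for (Set<String> existing : crossProduct) {
                for (Set<String> candidate : enumeratedGroupingSets.get(i)) {
                    Set<String> combined = new LinkedHashSet<>(existing);
                    combined.addAll(candidate);
                    next.add(combined);
                }
            }
            crossProduct = next;
        }

        // four grouping sets: [a, b, c], [a, b], [a, c], [a]
        System.out.println(crossProduct);
    }

    private static Set<String> setOf(String... columns)
    {
        return new LinkedHashSet<>(Arrays.asList(columns));
    }
}

With the DISTINCT set quantifier, duplicate sets produced by the cross product would additionally be collapsed, which is what the isDistinct branch in the method above does.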
Another changed file (split source planning visitor):
@@ -22,6 +22,7 @@
 import com.facebook.presto.sql.planner.plan.DistinctLimitNode;
 import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
 import com.facebook.presto.sql.planner.plan.FilterNode;
+import com.facebook.presto.sql.planner.plan.GroupIdNode;
 import com.facebook.presto.sql.planner.plan.IndexJoinNode;
 import com.facebook.presto.sql.planner.plan.JoinNode;
 import com.facebook.presto.sql.planner.plan.LimitNode;
@@ -178,6 +179,12 @@ public Optional<SplitSource> visitAggregation(AggregationNode node, Void context
             return node.getSource().accept(this, context);
         }
 
+        @Override
+        public Optional<SplitSource> visitGroupId(GroupIdNode node, Void context)
+        {
+            return node.getSource().accept(this, context);
+        }
+
         @Override
         public Optional<SplitSource> visitMarkDistinct(MarkDistinctNode node, Void context)
         {