Add limited support for cube, rollup and grouping set to grammar
- Add CUBE, ROLLUP, and simple GROUPING SETS (i.e., one level of nesting)
to the grammar.
- Add analyzer support for simple grouping sets.
- Add planning and execution support for queries with a single
grouping set (this is essentially a simple GROUP BY).
Raghav Sethi committed Nov 19, 2015
1 parent 9c1fe4b commit ce2ae58
Showing 19 changed files with 794 additions and 74 deletions.
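
For orientation (not part of the commit page itself), these are the query shapes the change targets, written in the style of the analyzer tests further down and run against their test table t1. Only clauses that reduce to a single grouping set are planned and executed by this commit; CUBE, ROLLUP, and multi-set GROUPING SETS parse, but are expected to fail analysis with "Grouping by multiple sets of columns is not yet supported".

analyze("SELECT SUM(b) FROM t1 GROUP BY ()");                      // grand total: one empty grouping set
analyze("SELECT a, SUM(b) FROM t1 GROUP BY GROUPING SETS (a)");    // one grouping set, equivalent to GROUP BY a
analyze("SELECT a, SUM(b) FROM t1 GROUP BY ROLLUP (a)");           // parses, but expands to two grouping sets and is rejected
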
@@ -59,7 +59,7 @@ public class Analysis
private final IdentityHashMap<Expression, Integer> resolvedNames = new IdentityHashMap<>();

private final IdentityHashMap<QuerySpecification, List<FunctionCall>> aggregates = new IdentityHashMap<>();
private final IdentityHashMap<QuerySpecification, List<FieldOrExpression>> groupByExpressions = new IdentityHashMap<>();
private final IdentityHashMap<QuerySpecification, List<List<FieldOrExpression>>> groupByExpressions = new IdentityHashMap<>();
private final IdentityHashMap<Node, Expression> where = new IdentityHashMap<>();
private final IdentityHashMap<QuerySpecification, Expression> having = new IdentityHashMap<>();
private final IdentityHashMap<Node, List<FieldOrExpression>> orderByExpressions = new IdentityHashMap<>();
@@ -172,12 +172,12 @@ public Type getCoercion(Expression expression)
return coercions.get(expression);
}

public void setGroupByExpressions(QuerySpecification node, List<FieldOrExpression> expressions)
public void setGroupingSets(QuerySpecification node, List<List<FieldOrExpression>> expressions)
{
groupByExpressions.put(node, expressions);
}

public List<FieldOrExpression> getGroupByExpressions(QuerySpecification node)
public List<List<FieldOrExpression>> getGroupingSets(QuerySpecification node)
{
return groupByExpressions.get(node);
}
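
The Analysis change above replaces the flat list of GROUP BY expressions with a list of grouping sets. A minimal sketch of the intended shape, using plain strings in place of FieldOrExpression (the class and values below are illustrative, not code from the commit):

import com.google.common.collect.ImmutableList;

import java.util.List;

public class GroupingSetShapes
{
    public static void main(String[] args)
    {
        // GROUP BY a, b: one grouping set containing both columns
        List<List<String>> simpleGroupBy = ImmutableList.of(ImmutableList.of("a", "b"));

        // GROUP BY (), or an aggregation with no GROUP BY at all: one empty (grand total) set
        List<List<String>> grandTotal = ImmutableList.of(ImmutableList.<String>of());

        // GROUP BY GROUPING SETS ((a), (b)): two grouping sets; the analyzer in this
        // commit rejects anything that expands to more than one set
        List<List<String>> multipleSets = ImmutableList.of(ImmutableList.of("a"), ImmutableList.of("b"));

        System.out.println(simpleGroupBy + " " + grandTotal + " " + multipleSets);
    }
}
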
@@ -60,6 +60,7 @@
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FrameBound;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.GroupingElement;
import com.facebook.presto.sql.tree.Insert;
import com.facebook.presto.sql.tree.Intersect;
import com.facebook.presto.sql.tree.Join;
@@ -85,6 +86,7 @@
import com.facebook.presto.sql.tree.ShowSchemas;
import com.facebook.presto.sql.tree.ShowSession;
import com.facebook.presto.sql.tree.ShowTables;
import com.facebook.presto.sql.tree.SimpleGroupBy;
import com.facebook.presto.sql.tree.SingleColumn;
import com.facebook.presto.sql.tree.SortItem;
import com.facebook.presto.sql.tree.Statement;
@@ -388,7 +390,7 @@ For example, a table with two partition keys (ds, cluster_name)
Optional.of(logicalAnd(
equal(nameReference("table_schema"), new StringLiteral(table.getSchemaName())),
equal(nameReference("table_name"), new StringLiteral(table.getTableName())))),
ImmutableList.of(nameReference("partition_number")),
ImmutableList.of(new SimpleGroupBy(ImmutableList.of(nameReference("partition_number")))),
Optional.empty(),
ImmutableList.of(),
Optional.empty());
@@ -933,7 +935,7 @@ protected RelationType visitQuerySpecification(QuerySpecification node, Analysis
node.getWhere().ifPresent(where -> analyzeWhere(node, tupleDescriptor, context, where));

List<FieldOrExpression> outputExpressions = analyzeSelect(node, tupleDescriptor, context);
List<FieldOrExpression> groupByExpressions = analyzeGroupBy(node, tupleDescriptor, context, outputExpressions);
List<List<FieldOrExpression>> groupByExpressions = analyzeGroupBy(node, tupleDescriptor, context, outputExpressions);
List<FieldOrExpression> orderByExpressions = analyzeOrderBy(node, tupleDescriptor, context, outputExpressions);
analyzeHaving(node, tupleDescriptor, context);

@@ -1411,48 +1413,78 @@ else if (expression instanceof LongLiteral) {
return orderByExpressions;
}

private List<FieldOrExpression> analyzeGroupBy(QuerySpecification node, RelationType tupleDescriptor, AnalysisContext context, List<FieldOrExpression> outputExpressions)
private List<List<FieldOrExpression>> analyzeGroupBy(QuerySpecification node, RelationType tupleDescriptor, AnalysisContext context, List<FieldOrExpression> outputExpressions)
{
ImmutableList.Builder<FieldOrExpression> groupByExpressionsBuilder = ImmutableList.builder();
if (!node.getGroupBy().isEmpty()) {
// Translate group by expressions that reference ordinals
for (Expression expression : node.getGroupBy()) {
// first, see if this is an ordinal
FieldOrExpression groupByExpression;
List<Set<Set<Expression>>> enumeratedGroupingSets = node.getGroupBy().stream()
.map(GroupingElement::enumerateGroupingSets)
.distinct()
.collect(toImmutableList());

if (expression instanceof LongLiteral) {
long ordinal = ((LongLiteral) expression).getValue();
if (ordinal < 1 || ordinal > outputExpressions.size()) {
throw new SemanticException(INVALID_ORDINAL, expression, "GROUP BY position %s is not in select list", ordinal);
}
// compute cross product of enumerated grouping sets, if there are any
List<List<Expression>> computedGroupingSets = ImmutableList.of();
if (!enumeratedGroupingSets.isEmpty()) {
computedGroupingSets = Sets.cartesianProduct(enumeratedGroupingSets).stream()
.map(groupingSetList -> groupingSetList.stream()
.flatMap(Collection::stream)
.distinct()
.collect(toImmutableList()))
.distinct()
.collect(toImmutableList());
}

groupByExpression = outputExpressions.get((int) (ordinal - 1));
}
else {
ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, tupleDescriptor, context);
analysis.addInPredicates(node, expressionAnalysis.getSubqueryInPredicates());
groupByExpression = new FieldOrExpression(expression);
}
// if there are aggregates, but no grouping columns, create a grand total grouping set
if (computedGroupingSets.isEmpty() && !extractAggregates(node).isEmpty()) {
computedGroupingSets = ImmutableList.of(ImmutableList.of());
}

Type type;
if (groupByExpression.isExpression()) {
Analyzer.verifyNoAggregatesOrWindowFunctions(metadata, groupByExpression.getExpression(), "GROUP BY");
type = analysis.getType(groupByExpression.getExpression());
}
else {
type = tupleDescriptor.getFieldByIndex(groupByExpression.getFieldIndex()).getType();
}
if (!type.isComparable()) {
throw new SemanticException(TYPE_MISMATCH, node, "%s is not comparable, and therefore cannot be used in GROUP BY", type);
if (computedGroupingSets.size() > 1) {
throw new SemanticException(NOT_SUPPORTED, node, "Grouping by multiple sets of columns is not yet supported");
}

List<List<FieldOrExpression>> analyzedGroupingSets = computedGroupingSets.stream()
.map(groupingSet -> analyzeGroupingColumns(groupingSet, node, tupleDescriptor, context, outputExpressions))
.collect(toImmutableList());

analysis.setGroupingSets(node, analyzedGroupingSets);
return analyzedGroupingSets;
}

private List<FieldOrExpression> analyzeGroupingColumns(List<Expression> groupingColumns, QuerySpecification node, RelationType tupleDescriptor, AnalysisContext context, List<FieldOrExpression> outputExpressions)
{
ImmutableList.Builder<FieldOrExpression> groupingColumnsBuilder = ImmutableList.builder();
for (Expression groupingColumn : groupingColumns) {
// first, see if this is an ordinal
FieldOrExpression groupByExpression;

if (groupingColumn instanceof LongLiteral) {
long ordinal = ((LongLiteral) groupingColumn).getValue();
if (ordinal < 1 || ordinal > outputExpressions.size()) {
throw new SemanticException(INVALID_ORDINAL, groupingColumn, "GROUP BY position %s is not in select list", ordinal);
}

groupByExpressionsBuilder.add(groupByExpression);
groupByExpression = outputExpressions.get((int) (ordinal - 1));
}
else {
ExpressionAnalysis expressionAnalysis = analyzeExpression(groupingColumn, tupleDescriptor, context);
analysis.addInPredicates(node, expressionAnalysis.getSubqueryInPredicates());
groupByExpression = new FieldOrExpression(groupingColumn);
}

Type type;
if (groupByExpression.isExpression()) {
Analyzer.verifyNoAggregatesOrWindowFunctions(metadata, groupByExpression.getExpression(), "GROUP BY");
type = analysis.getType(groupByExpression.getExpression());
}
else {
type = tupleDescriptor.getFieldByIndex(groupByExpression.getFieldIndex()).getType();
}
if (!type.isComparable()) {
throw new SemanticException(TYPE_MISMATCH, node, "%s is not comparable, and therefore cannot be used in GROUP BY", type);
}
}

List<FieldOrExpression> groupByExpressions = groupByExpressionsBuilder.build();
analysis.setGroupByExpressions(node, groupByExpressions);
return groupByExpressions;
groupingColumnsBuilder.add(groupByExpression);
}
return groupingColumnsBuilder.build();
}

private RelationType computeOutputDescriptor(QuerySpecification node, RelationType inputTupleDescriptor)
@@ -1578,7 +1610,7 @@ private RelationType analyzeFrom(QuerySpecification node, AnalysisContext contex

private void analyzeAggregations(QuerySpecification node,
RelationType tupleDescriptor,
List<FieldOrExpression> groupByExpressions,
List<List<FieldOrExpression>> groupingSets,
List<FieldOrExpression> outputExpressions,
List<FieldOrExpression> orderByExpressions,
AnalysisContext context,
@@ -1592,19 +1624,25 @@ private void analyzeAggregations(QuerySpecification node,
}
}

ImmutableList<FieldOrExpression> allGroupingColumns = groupingSets.stream()
.flatMap(Collection::stream)
.distinct()
.collect(toImmutableList());

// is this an aggregation query?
if (!aggregates.isEmpty() || !groupByExpressions.isEmpty()) {
if (!aggregates.isEmpty() || !allGroupingColumns.isEmpty()) {
// ensure SELECT, ORDER BY and HAVING are constant with respect to group
// e.g, these are all valid expressions:
// SELECT f(a) GROUP BY a
// SELECT f(a + 1) GROUP BY a + 1
// SELECT a + sum(b) GROUP BY a

for (FieldOrExpression fieldOrExpression : Iterables.concat(outputExpressions, orderByExpressions)) {
verifyAggregations(node, groupByExpressions, tupleDescriptor, fieldOrExpression, columnReferences);
verifyAggregations(node, allGroupingColumns, tupleDescriptor, fieldOrExpression, columnReferences);
}

if (node.getHaving().isPresent()) {
verifyAggregations(node, groupByExpressions, tupleDescriptor, new FieldOrExpression(node.getHaving().get()), columnReferences);
verifyAggregations(node, allGroupingColumns, tupleDescriptor, new FieldOrExpression(node.getHaving().get()), columnReferences);
}
}
}
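
The reworked analyzeGroupBy above enumerates the grouping sets contributed by each GROUP BY element and combines them with a cross product, flattening and de-duplicating each combination into a single column list. The following self-contained sketch mirrors that combination step, with string column names standing in for Expression and hand-written enumerations standing in for GroupingElement::enumerateGroupingSets; everything below is illustrative rather than code from the commit, and the example query would itself still be rejected by this commit because it expands to more than one grouping set.

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class GroupingSetEnumeration
{
    public static void main(String[] args)
    {
        // GROUP BY a, ROLLUP (b, c):
        // the simple element "a" enumerates to the single set {a};
        // per standard SQL semantics, ROLLUP (b, c) enumerates to {(b, c), (b), ()}
        Set<String> a = ImmutableSet.of("a");
        Set<String> bc = ImmutableSet.of("b", "c");
        Set<String> b = ImmutableSet.of("b");
        Set<String> none = ImmutableSet.of();

        List<Set<Set<String>>> enumeratedGroupingSets = ImmutableList.of(
                ImmutableSet.of(a),
                ImmutableSet.of(bc, b, none));

        // Cross product of the per-element enumerations; each combination is flattened
        // into one de-duplicated column list, as in the stream pipeline above
        List<List<String>> computedGroupingSets = Sets.cartesianProduct(enumeratedGroupingSets).stream()
                .map(combination -> combination.stream()
                        .flatMap(Set::stream)
                        .distinct()
                        .collect(Collectors.toList()))
                .distinct()
                .collect(Collectors.toList());

        // Prints [[a, b, c], [a, b], [a]]: three grouping sets, so this commit's
        // analyzer would raise NOT_SUPPORTED for the example query
        System.out.println(computedGroupingSets);
    }
}
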
@@ -75,6 +75,7 @@
import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;
import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.util.Objects.requireNonNull;

class QueryPlanner
@@ -339,18 +340,24 @@ private PlanBuilder explicitCoercionSymbols(PlanBuilder subPlan, Iterable<Symbol

private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node)
{
if (analysis.getAggregates(node).isEmpty() && analysis.getGroupByExpressions(node).isEmpty()) {
List<List<FieldOrExpression>> groupingSets = analysis.getGroupingSets(node);
if (groupingSets.isEmpty()) {
return subPlan;
}

return aggregateGroupingSet(getOnlyElement(groupingSets), subPlan, node);
}

private PlanBuilder aggregateGroupingSet(List<FieldOrExpression> groupingSet, PlanBuilder subPlan, QuerySpecification node)
{
List<FieldOrExpression> arguments = analysis.getAggregates(node).stream()
.map(FunctionCall::getArguments)
.flatMap(List::stream)
.map(FieldOrExpression::new)
.collect(toImmutableList());

// 1. Pre-project all scalar inputs (arguments and non-trivial group by expressions)
Iterable<FieldOrExpression> inputs = Iterables.concat(analysis.getGroupByExpressions(node), arguments);
Iterable<FieldOrExpression> inputs = Iterables.concat(groupingSet, arguments);
if (!Iterables.isEmpty(inputs)) { // avoid an empty projection if the only aggregation is COUNT (which has no arguments)
subPlan = project(subPlan, inputs);
}
@@ -380,7 +387,7 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node)

// 2.b. Rewrite group by expressions in terms of pre-projected inputs
Set<Symbol> groupBySymbols = new LinkedHashSet<>();
for (FieldOrExpression fieldOrExpression : analysis.getGroupByExpressions(node)) {
for (FieldOrExpression fieldOrExpression : groupingSet) {
Symbol symbol = subPlan.translate(fieldOrExpression);
groupBySymbols.add(symbol);
translations.put(fieldOrExpression, symbol);
@@ -397,7 +404,7 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node)
Symbol aggregateSymbol = translations.get(aggregate);
if (marker == null) {
if (args.size() == 1) {
marker = symbolAllocator.newSymbol(Iterables.getOnlyElement(args), BOOLEAN, "distinct");
marker = symbolAllocator.newSymbol(getOnlyElement(args), BOOLEAN, "distinct");
}
else {
marker = symbolAllocator.newSymbol(aggregateSymbol.getName(), BOOLEAN, "distinct");
@@ -445,7 +452,7 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node)
// Add back the implicit casts that we removed in 2.a
// TODO: this is a hack, we should change type coercions to coerce the inputs to functions/operators instead of coercing the output
if (needPostProjectionCoercion) {
return explicitCoercionFields(subPlan, analysis.getGroupByExpressions(node), analysis.getAggregates(node));
return explicitCoercionFields(subPlan, groupingSet, analysis.getAggregates(node));
}
return subPlan;
}
@@ -630,7 +637,7 @@ private PlanBuilder appendSemiJoin(PlanBuilder subPlan, InPredicate inPredicate)
SubqueryExpression subqueryExpression = (SubqueryExpression) inPredicate.getValueList();
RelationPlanner relationPlanner = new RelationPlanner(analysis, symbolAllocator, idAllocator, metadata, session);
RelationPlan valueListRelation = relationPlanner.process(subqueryExpression.getQuery(), null);
Symbol filteringSourceJoinSymbol = Iterables.getOnlyElement(valueListRelation.getRoot().getOutputSymbols());
Symbol filteringSourceJoinSymbol = getOnlyElement(valueListRelation.getRoot().getOutputSymbols());

Symbol semiJoinOutputSymbol = symbolAllocator.newSymbol("semijoinresult", BOOLEAN);

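
In QueryPlanner above, aggregate now fetches the analyzed grouping sets and, because analysis guarantees at most one set, hands the only element to aggregateGroupingSet. A minimal sketch of that invariant, with strings standing in for FieldOrExpression (illustrative only):

import com.google.common.collect.ImmutableList;

import java.util.List;

import static com.google.common.collect.Iterables.getOnlyElement;

public class SingleGroupingSetPlanning
{
    public static void main(String[] args)
    {
        List<List<String>> groupingSets = ImmutableList.of(ImmutableList.of("a", "b"));

        // No grouping sets: nothing to aggregate, so the sub-plan would be returned unchanged
        if (groupingSets.isEmpty()) {
            System.out.println("no aggregation");
            return;
        }

        // Exactly one grouping set: plan it as an ordinary GROUP BY over its columns.
        // getOnlyElement throws IllegalArgumentException for more than one element,
        // but the analyzer already rejects that case with NOT_SUPPORTED before planning.
        List<String> groupingSet = getOnlyElement(groupingSets);
        System.out.println("group by " + groupingSet);
    }
}
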
@@ -668,6 +668,17 @@ public void testGroupBy()
analyze("SELECT a, SUM(b) FROM t1 GROUP BY a");
}

@Test
public void testSingleGroupingSet()
throws Exception
{
// TODO: validate output
analyze("SELECT SUM(b) FROM t1 GROUP BY ()");
analyze("SELECT SUM(b) FROM t1 GROUP BY GROUPING SETS (())");
analyze("SELECT a, SUM(b) FROM t1 GROUP BY GROUPING SETS (a)");
analyze("SELECT a, SUM(b) FROM t1 GROUP BY GROUPING SETS ((a, b))");
}

@Test
public void testAggregateWithWildcard()
throws Exception
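
A natural companion to the positive tests above would be negative cases for the multi-set forms. The lines below are a hypothetical sketch, not part of this commit, and assume TestAnalyzer has an assertFails(errorCode, query) helper for semantic errors:

// CUBE and ROLLUP expand to more than one grouping set, which this commit's analyzer rejects
assertFails(NOT_SUPPORTED, "SELECT a, SUM(b) FROM t1 GROUP BY ROLLUP (a)");
assertFails(NOT_SUPPORTED, "SELECT a, b, COUNT(*) FROM t1 GROUP BY CUBE (a, b)");
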
@@ -111,10 +111,27 @@ querySpecification
: SELECT setQuantifier? selectItem (',' selectItem)*
(FROM relation (',' relation)*)?
(WHERE where=booleanExpression)?
(GROUP BY groupBy+=expression (',' groupBy+=expression)*)?
(GROUP BY groupingElement (',' groupingElement)*)?
(HAVING having=booleanExpression)?
;

groupingElement
: groupingExpressions #singleGroupingSet
| ROLLUP '(' (qualifiedName (',' qualifiedName)*)? ')' #rollup
| CUBE '(' (qualifiedName (',' qualifiedName)*)? ')' #cube
| GROUPING SETS '(' groupingSet (',' groupingSet)* ')' #multipleGroupingSets
;

groupingExpressions
: '(' (expression (',' expression)*)? ')'
| expression
;

groupingSet
: '(' (qualifiedName (',' qualifiedName)*)? ')'
| qualifiedName
;

namedQuery
: name=identifier (columnAliases)? AS '(' query ')'
;
@@ -371,6 +388,10 @@ DISTINCT: 'DISTINCT';
WHERE: 'WHERE';
GROUP: 'GROUP';
BY: 'BY';
GROUPING: 'GROUPING';
SETS: 'SETS';
CUBE: 'CUBE';
ROLLUP: 'ROLLUP';
ORDER: 'ORDER';
HAVING: 'HAVING';
LIMIT: 'LIMIT';
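
For reference, example clauses exercising each groupingElement alternative above, again written in the style of the analyzer tests (illustrative only; after this commit, only the forms that reduce to a single grouping set pass analysis):

analyze("SELECT a, b, COUNT(*) FROM t1 GROUP BY a, b");                   // singleGroupingSet, one per element
analyze("SELECT a, SUM(b) FROM t1 GROUP BY ROLLUP (a)");                  // rollup: expands to two sets, rejected
analyze("SELECT a, SUM(b) FROM t1 GROUP BY CUBE (a)");                    // cube: expands to two sets, rejected
analyze("SELECT a, SUM(b) FROM t1 GROUP BY GROUPING SETS ((a), ())");     // multipleGroupingSets: two sets, rejected
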
