Skip to content

Commit

Permalink
Add optimizer to aggregate distinct and non-distinct inputs separately
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamtagra authored and martint committed Oct 31, 2016
1 parent 6071aa5 commit 4f06521
Show file tree
Hide file tree
Showing 11 changed files with 865 additions and 3 deletions.
Expand Up @@ -52,6 +52,7 @@ public static class ProcessingOptimization
private boolean optimizeSingleDistinct = true;
private boolean pushTableWriteThroughUnion = true;
private boolean legacyArrayAgg;
private boolean optimizeMixedDistinctAggregations;

private String processingOptimization = ProcessingOptimization.DISABLED;
private boolean dictionaryAggregation;
Expand Down Expand Up @@ -298,4 +299,16 @@ public FeaturesConfig setSpillerThreads(int spillerThreads)
this.spillerThreads = spillerThreads;
return this;
}

public boolean isOptimizeMixedDistinctAggregations()
{
return optimizeMixedDistinctAggregations;
}

@Config("optimizer.optimize-mixed-distinct-aggregations")
public FeaturesConfig setOptimizeMixedDistinctAggregations(boolean value)
{
this.optimizeMixedDistinctAggregations = value;
return this;
}
}
Expand Up @@ -33,6 +33,7 @@
import com.facebook.presto.sql.planner.optimizations.MergeWindows;
import com.facebook.presto.sql.planner.optimizations.MetadataDeleteOptimizer;
import com.facebook.presto.sql.planner.optimizations.MetadataQueryOptimizer;
import com.facebook.presto.sql.planner.optimizations.OptimizeMixedDistinctAggregations;
import com.facebook.presto.sql.planner.optimizations.PartialAggregationPushDown;
import com.facebook.presto.sql.planner.optimizations.PickLayout;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
Expand Down Expand Up @@ -104,6 +105,10 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea
builder.add(new PruneUnreferencedOutputs());
}

if (featuresConfig.isOptimizeMixedDistinctAggregations()) {
builder.add(new OptimizeMixedDistinctAggregations(metadata));
}

if (!forceSingleNode) {
builder.add(new PushTableWriteThroughUnion()); // Must run before AddExchanges
builder.add(new AddExchanges(metadata, sqlParser));
Expand Down

Large diffs are not rendered by default.

Expand Up @@ -225,7 +225,7 @@ public class LocalQueryRunner

public LocalQueryRunner(Session defaultSession)
{
this(defaultSession, new FeaturesConfig(), false);
this(defaultSession, new FeaturesConfig().setOptimizeMixedDistinctAggregations(true), false);
}

public LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig)
Expand Down
Expand Up @@ -53,7 +53,8 @@ public void testDefaults()
.setSpillEnabled(false)
.setOperatorMemoryLimitBeforeSpill(DataSize.valueOf("4MB"))
.setSpillerSpillPath(Paths.get(System.getProperty("java.io.tmpdir"), "presto", "spills").toString())
.setSpillerThreads(4));
.setSpillerThreads(4)
.setOptimizeMixedDistinctAggregations(false));
}

@Test
Expand All @@ -69,6 +70,7 @@ public void testExplicitPropertyMappings()
.put("optimizer.optimize-metadata-queries", "true")
.put("optimizer.optimize-hash-generation", "false")
.put("optimizer.optimize-single-distinct", "false")
.put("optimizer.optimize-mixed-distinct-aggregations", "true")
.put("optimizer.push-table-write-through-union", "false")
.put("optimizer.processing-optimization", "columnar_dictionary")
.put("optimizer.dictionary-aggregation", "true")
Expand All @@ -90,6 +92,7 @@ public void testExplicitPropertyMappings()
.put("optimizer.optimize-metadata-queries", "true")
.put("optimizer.optimize-hash-generation", "false")
.put("optimizer.optimize-single-distinct", "false")
.put("optimizer.optimize-mixed-distinct-aggregations", "true")
.put("optimizer.push-table-write-through-union", "false")
.put("optimizer.processing-optimization", "columnar_dictionary")
.put("optimizer.dictionary-aggregation", "true")
Expand All @@ -111,6 +114,7 @@ public void testExplicitPropertyMappings()
.setOptimizeMetadataQueries(true)
.setOptimizeHashGeneration(false)
.setOptimizeSingleDistinct(false)
.setOptimizeMixedDistinctAggregations(true)
.setPushTableWriteThroughUnion(false)
.setProcessingOptimization(COLUMNAR_DICTIONARY)
.setDictionaryAggregation(true)
Expand Down
@@ -0,0 +1,119 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.assertions;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.tree.FunctionCall;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import static com.google.common.base.MoreObjects.toStringHelper;

public class AggregationMatcher
implements Matcher
{
private final List<FunctionCall> aggregations;
private final Map<Symbol, Symbol> masks;
private final List<List<Symbol>> groupingSets;
private final Optional<Symbol> groupId;

public AggregationMatcher(List<List<Symbol>> groupingSets, List<FunctionCall> aggregations, Map<Symbol, Symbol> masks, Optional<Symbol> groupId)
{
this.aggregations = aggregations;
this.masks = masks;
this.groupingSets = groupingSets;
this.groupId = groupId;
}

@Override
public boolean matches(PlanNode node, Session session, Metadata metadata, ExpressionAliases expressionAliases)
{
if (!(node instanceof AggregationNode)) {
return false;
}

AggregationNode aggregationNode = (AggregationNode) node;

if (groupId.isPresent() != aggregationNode.getGroupIdSymbol().isPresent()) {
return false;
}

if (groupingSets.size() != aggregationNode.getGroupingSets().size()) {
return false;
}

List<Symbol> aggregationsWithMask = aggregationNode.getAggregations()
.entrySet()
.stream()
.filter(entry -> entry.getValue().isDistinct())
.map(entry -> entry.getKey())
.collect(Collectors.toList());

if (aggregationsWithMask.size() != masks.keySet().size()) {
return false;
}

for (Symbol symbol : aggregationsWithMask) {
if (!masks.keySet().contains(symbol)) {
return false;
}
}

for (int i = 0; i < groupingSets.size(); i++) {
if (!matches(groupingSets.get(i), aggregationNode.getGroupingSets().get(i))) {
return false;
}
}

if (!matches(aggregations, aggregationNode.getAggregations().values().stream().collect(Collectors.toList()))) {
return false;
}

return true;
}

static <T> boolean matches(Collection<T> expected, Collection<T> actual)
{
if (expected.size() != actual.size()) {
return false;
}

for (T symbol : expected) {
if (!actual.contains(symbol)) {
return false;
}
}

return true;
}

@Override
public String toString()
{
return toStringHelper(this)
.add("groupingSets", groupingSets)
.add("aggregations", aggregations)
.add("masks", masks)
.add("groudId", groupId)
.toString();
}
}
@@ -0,0 +1,74 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.assertions;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.PlanNode;

import java.util.List;
import java.util.Map;

import static com.google.common.base.MoreObjects.toStringHelper;

public class GroupIdMatcher
implements Matcher
{
private final List<List<Symbol>> groups;
private final Map<Symbol, Symbol> identityMappings;

public GroupIdMatcher(List<List<Symbol>> groups, Map<Symbol, Symbol> identityMappings)
{
this.groups = groups;
this.identityMappings = identityMappings;
}

@Override
public boolean matches(PlanNode node, Session session, Metadata metadata, ExpressionAliases expressionAliases)
{
if (!(node instanceof GroupIdNode)) {
return false;
}

GroupIdNode groudIdNode = (GroupIdNode) node;
List<List<Symbol>> actualGroups = groudIdNode.getGroupingSets();
Map<Symbol, Symbol> actualArgumentMappings = groudIdNode.getArgumentMappings();

if (actualGroups.size() != groups.size()) {
return false;
}

for (int i = 0; i < actualGroups.size(); i++) {
if (!AggregationMatcher.matches(actualGroups.get(i), groups.get(i))) {
return false;
}
}

if (!AggregationMatcher.matches(identityMappings.keySet(), actualArgumentMappings.keySet())) {
return false;
}

return true;
}

@Override
public String toString()
{
return toStringHelper(this)
.add("groups", groups)
.toString();
}
}
Expand Up @@ -17,8 +17,11 @@
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.predicate.Domain;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
Expand All @@ -32,6 +35,7 @@
import com.facebook.presto.sql.tree.Window;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -122,6 +126,16 @@ public static PlanMatchPattern apply(List<String> correlationSymbolAliases, Plan
return node(ApplyNode.class, inputPattern, subqueryPattern).with(new CorrelationMatcher(correlationSymbolAliases));
}

public static PlanMatchPattern groupingSet(List<List<Symbol>> groups, PlanMatchPattern source)
{
return node(GroupIdNode.class, source).with(new GroupIdMatcher(groups, ImmutableMap.of()));
}

public static PlanMatchPattern aggregation(List<List<Symbol>> groupingSets, List<FunctionCall> aggregations, Map<Symbol, Symbol> masks, Optional<Symbol> groupId, PlanMatchPattern source)
{
return node(AggregationNode.class, source).with(new AggregationMatcher(groupingSets, aggregations, masks, groupId));
}

public PlanMatchPattern(List<PlanMatchPattern> sourcePatterns)
{
requireNonNull(sourcePatterns, "sourcePatterns are null");
Expand Down

0 comments on commit 4f06521

Please sign in to comment.