Skip to content

Commit

Permalink
Enable MultipleDistinctAggregationsToSubqueries
Browse files Browse the repository at this point in the history
Make MultipleDistinctAggregationsToSubqueries fire when
distinct_aggregations_strategy=AUTOMATIC and we can be
confident, based on stats, that the rule will be beneficial.
Aggregation source is limited to table scan, filter,
and project.
  • Loading branch information
lukasz-stec authored and Dith3r committed May 10, 2024
1 parent ad0e8c3 commit 0d37869
Show file tree
Hide file tree
Showing 11 changed files with 818 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,6 @@ public PlanOptimizers(
new RemoveRedundantExists(),
new RemoveRedundantWindow(),
new ImplementFilteredAggregations(),
new MultipleDistinctAggregationsToSubqueries(metadata),
new SingleDistinctAggregationToGroupBy(),
new MergeLimitWithDistinct(),
new PruneCountAggregationOverScalar(metadata),
Expand Down Expand Up @@ -685,6 +684,11 @@ public PlanOptimizers(
new RemoveRedundantIdentityProjections(),
new PushAggregationThroughOuterJoin(),
new ReplaceRedundantJoinWithSource(), // Run this after PredicatePushDown optimizer as it inlines filter constants
// Run this after PredicatePushDown and PushProjectionIntoTableScan as it uses stats, and those two rules may reduce the number of partitions
// and columns we need stats for thus reducing the overhead of reading statistics from the metastore.
new MultipleDistinctAggregationsToSubqueries(distinctAggregationController),
// Run SingleDistinctAggregationToGroupBy after MultipleDistinctAggregationsToSubqueries to ensure the single column distinct is optimized
new SingleDistinctAggregationToGroupBy(),
new DistinctAggregationToGroupBy(plannerContext, distinctAggregationController), // Run this after aggregation pushdown so that multiple distinct aggregations can be pushed into a connector
// It is also run before MultipleDistinctAggregationToMarkDistinct to take precedence if enabled
new ImplementFilteredAggregations(), // DistinctAggregationToGroupBy will add filters if fired
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,40 @@
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.inject.Inject;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsProvider;
import io.trino.cost.TaskCountEstimator;
import io.trino.metadata.Metadata;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnionNode;

import java.util.List;
import java.util.Set;

import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.SystemSessionProperties.distinctAggregationsStrategy;
import static io.trino.SystemSessionProperties.getTaskConcurrency;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.AUTOMATIC;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES;
import static io.trino.sql.planner.iterative.rule.DistinctAggregationToGroupBy.canUsePreAggregate;
import static io.trino.sql.planner.iterative.rule.DistinctAggregationToGroupBy.distinctAggregationsUniqueArgumentCount;
import static io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct.canUseMarkDistinct;
import static io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationsToSubqueries.isAggregationCandidateForSplittingToSubqueries;
import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static java.lang.Double.isNaN;
import static java.util.Objects.requireNonNull;

Expand All @@ -38,13 +57,16 @@ public class DistinctAggregationController
{
private static final int MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = 8;
private static final int PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * 8;
private static final double MAX_JOIN_GROUPING_KEYS_SIZE = 100 * 1024 * 1024; // 100 MB

private final TaskCountEstimator taskCountEstimator;
private final Metadata metadata;

@Inject
public DistinctAggregationController(TaskCountEstimator taskCountEstimator)
public DistinctAggregationController(TaskCountEstimator taskCountEstimator, Metadata metadata)
{
this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null");
this.metadata = requireNonNull(metadata, "metadata is null");
}

public boolean shouldAddMarkDistinct(AggregationNode aggregationNode, Rule.Context context)
Expand All @@ -57,6 +79,11 @@ public boolean shouldUsePreAggregate(AggregationNode aggregationNode, Rule.Conte
return chooseMarkDistinctStrategy(aggregationNode, context) == PRE_AGGREGATE;
}

// Tells whether the strategy chosen for the given aggregation is SPLIT_TO_SUBQUERIES,
// i.e. whether the distinct aggregations should be split into separate subqueries.
public boolean shouldSplitToSubqueries(AggregationNode aggregationNode, Rule.Context context)
{
    DistinctAggregationsStrategy chosenStrategy = chooseMarkDistinctStrategy(aggregationNode, context);
    return chosenStrategy == SPLIT_TO_SUBQUERIES;
}

private DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode aggregationNode, Rule.Context context)
{
DistinctAggregationsStrategy distinctAggregationsStrategy = distinctAggregationsStrategy(context.getSession());
Expand All @@ -67,6 +94,9 @@ private DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode
if (distinctAggregationsStrategy == PRE_AGGREGATE && canUsePreAggregate(aggregationNode)) {
return PRE_AGGREGATE;
}
if (distinctAggregationsStrategy == SPLIT_TO_SUBQUERIES && isAggregationCandidateForSplittingToSubqueries(aggregationNode) && isAggregationSourceSupportedForSubqueries(aggregationNode.getSource(), context)) {
return SPLIT_TO_SUBQUERIES;
}
// in case the strategy is chosen by the session property, but we cannot use it, let's fall back to single-step
return SINGLE_STEP;
}
Expand All @@ -86,6 +116,12 @@ private DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode
return SINGLE_STEP;
}

if (isAggregationCandidateForSplittingToSubqueries(aggregationNode) && shouldSplitAggregationToSubqueries(aggregationNode, context)) {
// for simple distinct aggregations on top of table scan it makes sense to split the aggregation into multiple subqueries,
// so they can be handled by the SingleDistinctAggregationToGroupBy and use other single column optimizations
return SPLIT_TO_SUBQUERIES;
}

// mark-distinct is better than pre-aggregate if the number of group-by keys is bigger than 2
// because group-by keys are added to every grouping set and this makes partial aggregation behave badly
if (canUsePreAggregate(aggregationNode) && aggregationNode.getGroupingKeys().size() <= 2) {
Expand Down Expand Up @@ -116,4 +152,103 @@ private double getMinDistinctValueCountEstimate(AggregationNode aggregationNode,
.map(symbol -> sourceStats.getSymbolStatistics(symbol).getDistinctValuesCount())
.max(Double::compareTo).orElse(Double.NaN);
}

// Splitting into sub-queries causes the table to be scanned multiple times, which can degrade performance.
// The checks below are therefore deliberately conservative: we split only when we are confident
// that it brings big benefits.
private boolean shouldSplitAggregationToSubqueries(AggregationNode aggregationNode, Rule.Context context)
{
    PlanNode source = aggregationNode.getSource();

    // only table scan, union, filter and project are supported
    if (!isAggregationSourceSupportedForSubqueries(source, context)) {
        return false;
    }

    // supporting union with auto decision is complex
    boolean containsUnion = searchFrom(source, context.getLookup())
            .whereIsInstanceOfAny(UnionNode.class)
            .findFirst()
            .isPresent();
    if (containsUnion) {
        return false;
    }

    // skip if the source has a filter that isSelective() reports as selective, as the scan and filter
    // can be the main bottleneck in this case, and we want to avoid duplicating this effort.
    boolean hasSelectiveFilter = searchFrom(source, context.getLookup())
            .where(node -> node instanceof FilterNode filterNode && isSelective(filterNode, context.getStatsProvider()))
            .matches();
    if (hasSelectiveFilter) {
        return false;
    }

    if (isAdditionalReadOverheadTooExpensive(aggregationNode, context)) {
        return false;
    }

    // no grouping-keys size check is applied for a single global aggregation
    if (aggregationNode.hasSingleGlobalAggregation()) {
        return true;
    }

    double groupingKeysSizeInBytes = context.getStatsProvider()
            .getStats(aggregationNode)
            .getOutputSizeInBytes(aggregationNode.getGroupingKeys());

    // when the estimated group by result size is unknown or big, both calculating the aggregation
    // multiple times and the join would be inefficient, so split only for a known, small enough size
    return !isNaN(groupingKeysSizeInBytes) && groupingKeysSizeInBytes <= MAX_JOIN_GROUPING_KEYS_SIZE;
}

// Estimates how much more data the split plan would read compared to the single table scan and
// reports true when the estimate is unknown or exceeds the single scan size by more than 50%.
private static boolean isAdditionalReadOverheadTooExpensive(AggregationNode aggregationNode, Rule.Context context)
{
    // symbols directly referenced as arguments of the distinct aggregations
    Set<Symbol> distinctInputs = aggregationNode.getAggregations().values().stream()
            .filter(AggregationNode.Aggregation::isDistinct)
            .flatMap(aggregation -> aggregation.getArguments().stream())
            .filter(Reference.class::isInstance)
            .map(Symbol::from)
            .collect(toImmutableSet());

    // findOnlyElement() is used, so exactly one table scan is expected below the aggregation
    TableScanNode tableScanNode = (TableScanNode) searchFrom(aggregationNode.getSource(), context.getLookup())
            .whereIsInstanceOfAny(TableScanNode.class)
            .findOnlyElement();
    Set<Symbol> scanOutputs = ImmutableSet.copyOf(tableScanNode.getOutputSymbols());
    Set<Symbol> additionalColumns = Sets.difference(scanOutputs, distinctInputs);

    // Columns that are not distinct aggregation inputs (e.g. group by columns) are read by every sub-query,
    // that is N times, where N is the number of sub-queries. Distinct input columns are counted once.
    StatsProvider statsProvider = context.getStatsProvider();
    double singleTableScanDataSize = statsProvider.getStats(tableScanNode).getOutputSizeInBytes(tableScanNode.getOutputSymbols());
    double additionalColumnsDataSize = statsProvider.getStats(tableScanNode).getOutputSizeInBytes(additionalColumns);
    long subqueryCount = distinctAggregationsUniqueArgumentCount(aggregationNode);
    double distinctInputDataSize = singleTableScanDataSize - additionalColumnsDataSize;
    double subqueriesTotalDataSize = additionalColumnsDataSize * subqueryCount + distinctInputDataSize;

    // unknown estimates are treated as too expensive
    if (isNaN(subqueriesTotalDataSize) || isNaN(singleTableScanDataSize)) {
        return true;
    }
    // we would read more than 50% more data
    return subqueriesTotalDataSize / singleTableScanDataSize > 1.5;
}

// Tells whether the filter is estimated to keep less than half of its input rows.
// The decision is based on the ratio of the estimated output row count of the filter
// to the estimated output row count of its source.
// When the ratio cannot be computed (either estimate is NaN, or both are zero), the filter is
// reported as selective: the caller splits the aggregation only when it is confident about the
// benefit, and a plain `NaN < 0.5` comparison would silently report such a filter as non-selective.
private static boolean isSelective(FilterNode filterNode, StatsProvider statsProvider)
{
    double filterOutputRowCount = statsProvider.getStats(filterNode).getOutputRowCount();
    double filterSourceRowCount = statsProvider.getStats(filterNode.getSource()).getOutputRowCount();
    double keptRowsRatio = filterOutputRowCount / filterSourceRowCount;
    // any comparison with NaN is false, so the unknown case has to be handled explicitly
    return isNaN(keptRowsRatio) || keptRowsRatio < 0.5;
}

// The source plan qualifies only if it consists solely of table scan, union, filter and project nodes.
// PlanCopier.copyPlan must support all node types accepted here.
// Additionally, we should split the table scan only if reading single columns is efficient in the given connector,
// so every table scan in the source must be reported as columnar by the metadata.
private boolean isAggregationSourceSupportedForSubqueries(PlanNode source, Rule.Context context)
{
    boolean containsUnsupportedNode = searchFrom(source, context.getLookup())
            .where(node -> !(node instanceof TableScanNode)
                    && !(node instanceof FilterNode)
                    && !(node instanceof ProjectNode)
                    && !(node instanceof UnionNode))
            .matches();
    if (containsUnsupportedNode) {
        return false;
    }

    List<PlanNode> tableScans = searchFrom(source, context.getLookup())
            .whereIsInstanceOfAny(TableScanNode.class)
            .findAll();

    // at least one table scan is expected
    if (tableScans.isEmpty()) {
        return false;
    }

    for (PlanNode node : tableScans) {
        TableScanNode tableScan = (TableScanNode) node;
        if (!metadata.isColumnarTableScan(context.getSession(), tableScan.getTable())) {
            return false;
        }
    }
    return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,19 @@ public static boolean canUsePreAggregate(AggregationNode aggregationNode)
}

// Tells whether the distinct aggregations of the node use more than one unique set of arguments.
public static boolean hasMultipleDistincts(AggregationNode aggregationNode)
{
    long uniqueArgumentSets = distinctAggregationsUniqueArgumentCount(aggregationNode);
    return uniqueArgumentSets > 1;
}

public static long distinctAggregationsUniqueArgumentCount(AggregationNode aggregationNode)
{
return aggregationNode.getAggregations()
.values().stream()
.filter(Aggregation::isDistinct)
.map(Aggregation::getArguments)
.map(HashSet::new)
.distinct()
.count() > 1;
.count();
}

private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregationNode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,11 @@
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.NodeAndMappings;
import io.trino.sql.planner.PlanCopier;
Expand All @@ -29,13 +26,10 @@
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AggregationNode.Aggregation;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinNode.EquiJoinClause;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnionNode;

import java.util.Comparator;
import java.util.HashMap;
Expand All @@ -49,9 +43,6 @@

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.SystemSessionProperties.distinctAggregationsStrategy;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES;
import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.trino.sql.planner.plan.JoinType.INNER;
import static io.trino.sql.planner.plan.Patterns.aggregation;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -89,38 +80,38 @@ public class MultipleDistinctAggregationsToSubqueries
implements Rule<AggregationNode>
{
private static final Pattern<AggregationNode> PATTERN = aggregation()
.matching(
Predicates.and(
// TODO: we could support non-distinct aggregations if SingleDistinctAggregationToGroupBy supports it
SingleDistinctAggregationToGroupBy::allDistinctAggregates,
DistinctAggregationToGroupBy::hasMultipleDistincts,
// if we have more than one grouping set, we can have duplicated grouping sets and handling this is complex
aggregation -> aggregation.getGroupingSetCount() == 1,
// hash symbol is added late in the planning, and handling it here would increase complexity
aggregation -> aggregation.getHashSymbol().isEmpty()));
private final Metadata metadata;

public MultipleDistinctAggregationsToSubqueries(Metadata metadata)
.matching(MultipleDistinctAggregationsToSubqueries::isAggregationCandidateForSplittingToSubqueries);

// In addition to this check, DistinctAggregationController.isAggregationSourceSupportedForSubqueries, which accesses Metadata,
// also needs to pass for this rule to be applicable to the plan.
public static boolean isAggregationCandidateForSplittingToSubqueries(AggregationNode aggregationNode)
{
this.metadata = requireNonNull(metadata, "metadata is null");
// TODO: we could support non-distinct aggregations if SingleDistinctAggregationToGroupBy supports it
return SingleDistinctAggregationToGroupBy.allDistinctAggregates(aggregationNode) &&
DistinctAggregationToGroupBy.hasMultipleDistincts(aggregationNode) &&
// if we have more than one grouping set, we can have duplicated grouping sets and handling this is complex
aggregationNode.getGroupingSetCount() == 1 &&
// hash symbol is added late in the planning, and handling it here would increase complexity
aggregationNode.getHashSymbol().isEmpty();
}

@Override
public Pattern<AggregationNode> getPattern()
private final DistinctAggregationController distinctAggregationController;

public MultipleDistinctAggregationsToSubqueries(DistinctAggregationController distinctAggregationController)
{
return PATTERN;
this.distinctAggregationController = requireNonNull(distinctAggregationController, "distinctAggregationController is null");
}

@Override
public boolean isEnabled(Session session)
public Pattern<AggregationNode> getPattern()
{
return distinctAggregationsStrategy(session).equals(SPLIT_TO_SUBQUERIES);
return PATTERN;
}

@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context)
{
if (!isAggregationSourceSupported(aggregationNode.getSource(), context)) {
if (!distinctAggregationController.shouldSplitToSubqueries(aggregationNode, context)) {
return Result.empty();
}
// group aggregations by arguments
Expand Down Expand Up @@ -212,24 +203,4 @@ private JoinNode buildJoin(PlanNode left, List<Symbol> leftJoinSymbols, PlanNode
ImmutableMap.of(),
Optional.empty());
}

// PlanCopier.copyPlan must support all supported nodes here.
// Additionally, we should split the table scan only if reading single columns is efficient in the given connector.
private boolean isAggregationSourceSupported(PlanNode source, Context context)
{
if (searchFrom(source, context.getLookup())
.where(node -> !(node instanceof TableScanNode
|| node instanceof FilterNode
|| node instanceof ProjectNode
|| node instanceof UnionNode))
.findFirst()
.isPresent()) {
return false;
}

return searchFrom(source, context.getLookup())
.whereIsInstanceOfAny(TableScanNode.class)
.findAll().stream()
.allMatch(tableScanNode -> metadata.isColumnarTableScan(context.getSession(), ((TableScanNode) tableScanNode).getTable()));
}
}

0 comments on commit 0d37869

Please sign in to comment.