Skip to content

Commit

Permalink
Refactor DistinctAggregationController
Browse files Browse the repository at this point in the history
Make different distinct aggregation strategy choices
exclusive, so that order of optimizer rules does not matter.
  • Loading branch information
lukasz-stec authored and Dith3r committed May 10, 2024
1 parent be51792 commit ad0e8c3
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 188 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,18 @@
import com.google.inject.Inject;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.TaskCountEstimator;
import io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;

import static io.trino.SystemSessionProperties.distinctAggregationsStrategy;
import static io.trino.SystemSessionProperties.getTaskConcurrency;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.AUTOMATIC;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP;
import static io.trino.sql.planner.iterative.rule.DistinctAggregationToGroupBy.canUsePreAggregate;
import static io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct.canUseMarkDistinct;
import static java.lang.Double.isNaN;
import static java.util.Objects.requireNonNull;

Expand All @@ -41,42 +49,54 @@ public DistinctAggregationController(TaskCountEstimator taskCountEstimator)

/**
 * Tells the caller whether the mark-distinct rewrite should be applied to the given aggregation.
 *
 * NOTE(review): this span comes from a rendered diff without +/- markers; the two consecutive
 * return statements look like the removed line followed by its replacement - confirm against the commit.
 *
 * @param aggregationNode the aggregation being considered
 * @param context the rule context passed on to the strategy selection
 * @return in the second return statement, {@code true} only when {@code chooseMarkDistinctStrategy} yields {@code MARK_DISTINCT}
 */
public boolean shouldAddMarkDistinct(AggregationNode aggregationNode, Rule.Context context)
{
return !canParallelizeSingleStepDistinctAggregation(aggregationNode, context, MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER);
return chooseMarkDistinctStrategy(aggregationNode, context) == MARK_DISTINCT;
}

/**
 * Tells the caller whether the pre-aggregate rewrite should be applied to the given aggregation.
 *
 * NOTE(review): this span comes from a rendered diff without +/- markers; the statements up to and
 * including the grouping-key size check look like the removed body and the final return like its
 * replacement - confirm against the commit.
 *
 * @param aggregationNode the aggregation being considered
 * @param context the rule context passed on to the strategy selection
 * @return in the final return statement, {@code true} only when {@code chooseMarkDistinctStrategy} yields {@code PRE_AGGREGATE}
 */
public boolean shouldUsePreAggregate(AggregationNode aggregationNode, Rule.Context context)
{
if (canParallelizeSingleStepDistinctAggregation(aggregationNode, context, PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER)) {
return false;
}

// mark-distinct is better than pre-aggregate if the number of group-by keys is bigger than 2
// because group-by keys are added to every grouping set and this makes partial aggregation behave badly
return aggregationNode.getGroupingKeys().size() <= 2;
return chooseMarkDistinctStrategy(aggregationNode, context) == PRE_AGGREGATE;
}

/**
 * Selects one distinct aggregation strategy (SINGLE_STEP, PRE_AGGREGATE or MARK_DISTINCT) for the given node.
 *
 * When the session strategy is not AUTOMATIC, MARK_DISTINCT or PRE_AGGREGATE is returned only if the
 * matching {@code canUse...} check accepts the node; every other case yields SINGLE_STEP.
 * In AUTOMATIC mode SINGLE_STEP is returned when the node has grouping keys, the distinct value estimate
 * is known, and that estimate exceeds the thread-count based thresholds checked below. Otherwise
 * PRE_AGGREGATE is preferred for nodes with at most two grouping keys, then MARK_DISTINCT, and finally
 * SINGLE_STEP when neither alternative is applicable.
 *
 * NOTE(review): this span comes from a rendered diff without +/- markers; the statements returning
 * boolean values look like the removed canParallelizeSingleStepDistinctAggregation body interleaved
 * with the new method - confirm against the commit.
 *
 * @param aggregationNode the aggregation being considered
 * @param context the rule context supplying the session and the inputs for the estimates
 * @return the chosen strategy
 */
private boolean canParallelizeSingleStepDistinctAggregation(AggregationNode aggregationNode, Rule.Context context, int maxOutputRowCountMultiplier)
private DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode aggregationNode, Rule.Context context)
{
if (aggregationNode.getGroupingKeys().isEmpty()) {
// global distinct aggregation is computed using a single thread. MarkDistinct will help parallelize the execution.
return false;
DistinctAggregationsStrategy distinctAggregationsStrategy = distinctAggregationsStrategy(context.getSession());
if (distinctAggregationsStrategy != AUTOMATIC) {
if (distinctAggregationsStrategy == MARK_DISTINCT && canUseMarkDistinct(aggregationNode)) {
return MARK_DISTINCT;
}
if (distinctAggregationsStrategy == PRE_AGGREGATE && canUsePreAggregate(aggregationNode)) {
return PRE_AGGREGATE;
}
// if the strategy is chosen by the session property but cannot be used for this node, fall back to single-step
return SINGLE_STEP;
}
double numberOfDistinctValues = getMinDistinctValueCountEstimate(aggregationNode, context);
if (Double.isNaN(numberOfDistinctValues)) {
// if the estimate is unknown, use MarkDistinct to avoid query failure
return false;
}
int maxNumberOfConcurrentThreadsForAggregation = getMaxNumberOfConcurrentThreadsForAggregation(context);

if (numberOfDistinctValues <= maxOutputRowCountMultiplier * maxNumberOfConcurrentThreadsForAggregation) {
// small numberOfDistinctValues reduces the distinct aggregation parallelism, also because the partitioning may be skewed.
// This makes the query underutilize the cluster CPU and may also concentrate memory on a few nodes.
// MarkDistinct should increase the parallelism at a cost of CPU.
return false;
// use single_step if it can be parallelized
// small numberOfDistinctValues reduces the distinct aggregation parallelism, also because the partitioning may be skewed.
// this makes the query underutilize the cluster CPU and may also concentrate memory on a few nodes.
// single_step alternatives should increase the parallelism at a cost of CPU.
if (!aggregationNode.getGroupingKeys().isEmpty() && // global distinct aggregation is computed using a single thread. Strategies other than single_step will help parallelize the execution.
!isNaN(numberOfDistinctValues) && // if the estimate is unknown, use alternatives to avoid query failure
(numberOfDistinctValues > PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * maxNumberOfConcurrentThreadsForAggregation ||
(numberOfDistinctValues > MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * maxNumberOfConcurrentThreadsForAggregation &&
// if the NDV and the number of grouping keys are small, pre-aggregate is faster than single_step at a cost of CPU
aggregationNode.getGroupingKeys().size() > 2))) {
return SINGLE_STEP;
}

// mark-distinct is better than pre-aggregate if the number of group-by keys is bigger than 2
// because group-by keys are added to every grouping set and this makes partial aggregation behave badly
if (canUsePreAggregate(aggregationNode) && aggregationNode.getGroupingKeys().size() <= 2) {
return PRE_AGGREGATE;
}
else if (canUseMarkDistinct(aggregationNode)) {
return MARK_DISTINCT;
}

// can parallelize single-step, and single-step distinct is more efficient than alternatives
return true;
// if no strategy is applicable, use single_step by default
return SINGLE_STEP;
}

private int getMaxNumberOfConcurrentThreadsForAggregation(Rule.Context context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -83,16 +82,18 @@ public class DistinctAggregationToGroupBy
private static final CatalogSchemaFunctionName APPROX_DISTINCT_NAME = builtinFunctionName("approx_distinct");

// Pattern restricting this rule to aggregation nodes that satisfy the predicate given to matching().
// NOTE(review): this span comes from a rendered diff without +/- markers; the Predicates.and(...) chain
// looks like the removed matcher and the canUsePreAggregate method reference like its replacement - confirm against the commit.
private static final Pattern<AggregationNode> PATTERN = aggregation()
.matching(Predicates.and(
Predicates.or(
// single distinct can be supported in this rule, but it is already supported by SingleDistinctAggregationToGroupBy, which produces simpler plans (without group-id)
DistinctAggregationToGroupBy::hasMultipleDistincts,
DistinctAggregationToGroupBy::hasMixedDistinctAndNonDistincts),
DistinctAggregationToGroupBy::allDistinctAggregationsHaveSingleArgument,
DistinctAggregationToGroupBy::noFilters,
DistinctAggregationToGroupBy::noMasks,
aggregation -> !aggregation.hasOrderings(),
aggregation -> aggregation.getStep().equals(SINGLE)));
.matching(DistinctAggregationToGroupBy::canUsePreAggregate);

/**
 * Checks whether the given aggregation qualifies for the pre-aggregate (group-by based) rewrite.
 *
 * The checks run in a fixed order and stop at the first one that fails.
 *
 * @param aggregationNode the aggregation to inspect
 * @return {@code true} when every precondition below holds
 */
public static boolean canUsePreAggregate(AggregationNode aggregationNode)
{
    // single distinct can be supported in this rule, but it is already supported by SingleDistinctAggregationToGroupBy, which produces simpler plans (without group-id)
    if (!hasMultipleDistincts(aggregationNode) && !hasMixedDistinctAndNonDistincts(aggregationNode)) {
        return false;
    }
    if (!allDistinctAggregationsHaveSingleArgument(aggregationNode)) {
        return false;
    }
    if (!noFilters(aggregationNode)) {
        return false;
    }
    if (!noMasks(aggregationNode)) {
        return false;
    }
    if (aggregationNode.hasOrderings()) {
        return false;
    }
    // the last precondition decides the result: only SINGLE step aggregations are accepted
    return aggregationNode.getStep().equals(SINGLE);
}

public static boolean hasMultipleDistincts(AggregationNode aggregationNode)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
Expand Down Expand Up @@ -67,12 +66,13 @@ public class MultipleDistinctAggregationToMarkDistinct
implements Rule<AggregationNode>
{
// Pattern restricting this rule to aggregation nodes that satisfy the predicate given to matching().
// NOTE(review): this span comes from a rendered diff without +/- markers; the Predicates.and(...) chain
// looks like the removed matcher and the canUseMarkDistinct method reference like its replacement - confirm against the commit.
private static final Pattern<AggregationNode> PATTERN = aggregation()
.matching(
Predicates.and(
MultipleDistinctAggregationToMarkDistinct::hasNoDistinctWithFilterOrMask,
Predicates.or(
MultipleDistinctAggregationToMarkDistinct::hasMultipleDistincts,
MultipleDistinctAggregationToMarkDistinct::hasMixedDistinctAndNonDistincts)));
.matching(MultipleDistinctAggregationToMarkDistinct::canUseMarkDistinct);

/**
 * Checks whether the given aggregation qualifies for the mark-distinct rewrite.
 *
 * @param aggregationNode the aggregation to inspect
 * @return {@code true} when {@code hasNoDistinctWithFilterOrMask} accepts the node and it has either
 *         multiple distincts or a mix of distinct and non-distinct aggregations
 */
public static boolean canUseMarkDistinct(AggregationNode aggregationNode)
{
    // this check runs first; when it fails the remaining checks are not evaluated
    if (!hasNoDistinctWithFilterOrMask(aggregationNode)) {
        return false;
    }
    if (hasMultipleDistincts(aggregationNode)) {
        return true;
    }
    return hasMixedDistinctAndNonDistincts(aggregationNode);
}

private static boolean hasNoDistinctWithFilterOrMask(AggregationNode aggregationNode)
{
Expand Down
Loading

0 comments on commit ad0e8c3

Please sign in to comment.