diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index 15c14af2bb5002..f4c39e15caba1f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -457,7 +457,6 @@ public PlanOptimizers( new RemoveRedundantExists(), new RemoveRedundantWindow(), new ImplementFilteredAggregations(), - new MultipleDistinctAggregationsToSubqueries(metadata), new SingleDistinctAggregationToGroupBy(), new MergeLimitWithDistinct(), new PruneCountAggregationOverScalar(metadata), @@ -685,6 +684,11 @@ public PlanOptimizers( new RemoveRedundantIdentityProjections(), new PushAggregationThroughOuterJoin(), new ReplaceRedundantJoinWithSource(), // Run this after PredicatePushDown optimizer as it inlines filter constants + // Run this after PredicatePushDown and PushProjectionIntoTableScan as it uses stats, and those two rules may reduce the number of partitions + // and columns we need stats for thus reducing the overhead of reading statistics from the metastore. 
+ new MultipleDistinctAggregationsToSubqueries(distinctAggregationController), + // Run SingleDistinctAggregationToGroupBy after MultipleDistinctAggregationsToSubqueries to ensure the single column distinct is optimized + new SingleDistinctAggregationToGroupBy(), new DistinctAggregationToGroupBy(plannerContext, distinctAggregationController), // Run this after aggregation pushdown so that multiple distinct aggregations can be pushed into a connector // It also is run before MultipleDistinctAggregationToMarkDistinct to take precedence f enabled new ImplementFilteredAggregations(), // DistinctAggregationToGroupBy will add filters if fired diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DistinctAggregationController.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DistinctAggregationController.java index d3e565921c4997..2ecf17fbd6cfa1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DistinctAggregationController.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DistinctAggregationController.java @@ -13,21 +13,40 @@ */ package io.trino.sql.planner.iterative.rule; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import com.google.inject.Inject; import io.trino.cost.PlanNodeStatsEstimate; +import io.trino.cost.StatsProvider; import io.trino.cost.TaskCountEstimator; +import io.trino.metadata.Metadata; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy; +import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.AggregationNode; +import io.trino.sql.planner.plan.FilterNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.ProjectNode; +import io.trino.sql.planner.plan.TableScanNode; +import io.trino.sql.planner.plan.UnionNode; +import java.util.List; +import java.util.Set; + +import 
static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.SystemSessionProperties.distinctAggregationsStrategy; import static io.trino.SystemSessionProperties.getTaskConcurrency; import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.AUTOMATIC; import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT; import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE; import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP; +import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES; import static io.trino.sql.planner.iterative.rule.DistinctAggregationToGroupBy.canUsePreAggregate; +import static io.trino.sql.planner.iterative.rule.DistinctAggregationToGroupBy.distinctAggregationsUniqueArgumentCount; import static io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct.canUseMarkDistinct; +import static io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationsToSubqueries.isAggregationCandidateForSplittingToSubqueries; +import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static java.lang.Double.isNaN; import static java.util.Objects.requireNonNull; @@ -38,13 +57,16 @@ public class DistinctAggregationController { private static final int MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = 8; private static final int PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * 8; + private static final double MAX_JOIN_GROUPING_KEYS_SIZE = 100 * 1024 * 1024; // 100 MB private final TaskCountEstimator taskCountEstimator; + private final Metadata metadata; @Inject - public DistinctAggregationController(TaskCountEstimator taskCountEstimator) + public DistinctAggregationController(TaskCountEstimator taskCountEstimator, Metadata metadata) { this.taskCountEstimator = 
requireNonNull(taskCountEstimator, "taskCountEstimator is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); } public boolean shouldAddMarkDistinct(AggregationNode aggregationNode, Rule.Context context) @@ -57,6 +79,11 @@ public boolean shouldUsePreAggregate(AggregationNode aggregationNode, Rule.Conte return chooseMarkDistinctStrategy(aggregationNode, context) == PRE_AGGREGATE; } + public boolean shouldSplitToSubqueries(AggregationNode aggregationNode, Rule.Context context) + { + return chooseMarkDistinctStrategy(aggregationNode, context) == SPLIT_TO_SUBQUERIES; + } + private DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode aggregationNode, Rule.Context context) { DistinctAggregationsStrategy distinctAggregationsStrategy = distinctAggregationsStrategy(context.getSession()); @@ -67,6 +94,9 @@ private DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode if (distinctAggregationsStrategy == PRE_AGGREGATE && canUsePreAggregate(aggregationNode)) { return PRE_AGGREGATE; } + if (distinctAggregationsStrategy == SPLIT_TO_SUBQUERIES && isAggregationCandidateForSplittingToSubqueries(aggregationNode) && isAggregationSourceSupportedForSubqueries(aggregationNode.getSource(), context)) { + return SPLIT_TO_SUBQUERIES; + } // in case strategy is chosen by the session property, but we cannot use it, lets fallback to single-step return SINGLE_STEP; } @@ -86,6 +116,12 @@ private DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode return SINGLE_STEP; } + if (isAggregationCandidateForSplittingToSubqueries(aggregationNode) && shouldSplitAggregationToSubqueries(aggregationNode, context)) { + // for simple distinct aggregations on top of table scan it makes sense to split the aggregation into multiple subqueries, + // so they can be handled by the SingleDistinctAggregationToGroupBy and use other single column optimizations + return SPLIT_TO_SUBQUERIES; + } + // mark-distinct is better than pre-aggregate if 
the number of group-by keys is bigger than 2 // because group-by keys are added to every grouping set and this makes partial aggregation behaves badly if (canUsePreAggregate(aggregationNode) && aggregationNode.getGroupingKeys().size() <= 2) { @@ -116,4 +152,103 @@ private double getMinDistinctValueCountEstimate(AggregationNode aggregationNode, .map(symbol -> sourceStats.getSymbolStatistics(symbol).getDistinctValuesCount()) .max(Double::compareTo).orElse(Double.NaN); } + + // Since, to avoid degradation caused by multiple table scans, we want to split to sub-queries only if we are confident + // it brings big benefits, we are fairly conservative in the decision below. + private boolean shouldSplitAggregationToSubqueries(AggregationNode aggregationNode, Rule.Context context) + { + if (!isAggregationSourceSupportedForSubqueries(aggregationNode.getSource(), context)) { + // only table scan, union, filter and project are supported + return false; + } + + if (searchFrom(aggregationNode.getSource(), context.getLookup()).whereIsInstanceOfAny(UnionNode.class).findFirst().isPresent()) { + // supporting union with auto decision is complex + return false; + } + + // skip if the source has a filter with low selectivity, as the scan and filter can + // be the main bottleneck in this case, and we want to avoid duplicating this effort. 
+ if (searchFrom(aggregationNode.getSource(), context.getLookup()) + .where(node -> node instanceof FilterNode filterNode && isSelective(filterNode, context.getStatsProvider())) + .matches()) { + return false; + } + + if (isAdditionalReadOverheadTooExpensive(aggregationNode, context)) { + return false; + } + + if (aggregationNode.hasSingleGlobalAggregation()) { + return true; + } + + PlanNodeStatsEstimate stats = context.getStatsProvider().getStats(aggregationNode); + double groupingKeysSizeInBytes = stats.getOutputSizeInBytes(aggregationNode.getGroupingKeys()); + + // estimated group by result size is big so that both calculating aggregation multiple times and join would be inefficient + return !(isNaN(groupingKeysSizeInBytes) || groupingKeysSizeInBytes > MAX_JOIN_GROUPING_KEYS_SIZE); + } + + private static boolean isAdditionalReadOverheadTooExpensive(AggregationNode aggregationNode, Rule.Context context) + { + Set distinctInputs = aggregationNode.getAggregations() + .values().stream() + .filter(AggregationNode.Aggregation::isDistinct) + .flatMap(aggregation -> aggregation.getArguments().stream()) + .filter(expression -> expression instanceof Reference) + .map(Symbol::from) + .collect(toImmutableSet()); + + TableScanNode tableScanNode = (TableScanNode) searchFrom(aggregationNode.getSource(), context.getLookup()).whereIsInstanceOfAny(TableScanNode.class).findOnlyElement(); + Set additionalColumns = Sets.difference(ImmutableSet.copyOf(tableScanNode.getOutputSymbols()), distinctInputs); + + // Group by columns need to be read N times, where N is the number of sub-queries. + // Distinct columns are read once. 
+ double singleTableScanDataSize = context.getStatsProvider().getStats(tableScanNode).getOutputSizeInBytes(tableScanNode.getOutputSymbols()); + double additionalColumnsDataSize = context.getStatsProvider().getStats(tableScanNode).getOutputSizeInBytes(additionalColumns); + long subqueryCount = distinctAggregationsUniqueArgumentCount(aggregationNode); + double distinctInputDataSize = singleTableScanDataSize - additionalColumnsDataSize; + double subqueriesTotalDataSize = additionalColumnsDataSize * subqueryCount + distinctInputDataSize; + + return isNaN(subqueriesTotalDataSize) || + isNaN(singleTableScanDataSize) || + // we would read more than 50% more data + subqueriesTotalDataSize / singleTableScanDataSize > 1.5; + } + + private static boolean isSelective(FilterNode filterNode, StatsProvider statsProvider) + { + double filterOutputRowCount = statsProvider.getStats(filterNode).getOutputRowCount(); + double filterSourceRowCount = statsProvider.getStats(filterNode.getSource()).getOutputRowCount(); + return filterOutputRowCount / filterSourceRowCount < 0.5; + } + + // Only table scan, union, filter and project are supported. + // PlanCopier.copyPlan must support all supported nodes here. + // Additionally, we should split the table scan only if reading single columns is efficient in the given connector. 
+ private boolean isAggregationSourceSupportedForSubqueries(PlanNode source, Rule.Context context) + { + if (searchFrom(source, context.getLookup()) + .where(node -> !(node instanceof TableScanNode + || node instanceof FilterNode + || node instanceof ProjectNode + || node instanceof UnionNode)) + .findFirst() + .isPresent()) { + return false; + } + + List tableScans = searchFrom(source, context.getLookup()) + .whereIsInstanceOfAny(TableScanNode.class) + .findAll(); + + if (tableScans.isEmpty()) { + // at least one table scan is expected + return false; + } + + return tableScans.stream() + .allMatch(tableScanNode -> metadata.isColumnarTableScan(context.getSession(), ((TableScanNode) tableScanNode).getTable())); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DistinctAggregationToGroupBy.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DistinctAggregationToGroupBy.java index 936ec01cc179b6..6862f70fe34fc6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DistinctAggregationToGroupBy.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DistinctAggregationToGroupBy.java @@ -96,6 +96,11 @@ public static boolean canUsePreAggregate(AggregationNode aggregationNode) } public static boolean hasMultipleDistincts(AggregationNode aggregationNode) + { + return distinctAggregationsUniqueArgumentCount(aggregationNode) > 1; + } + + public static long distinctAggregationsUniqueArgumentCount(AggregationNode aggregationNode) { return aggregationNode.getAggregations() .values().stream() @@ -103,7 +108,7 @@ public static boolean hasMultipleDistincts(AggregationNode aggregationNode) .map(Aggregation::getArguments) .map(HashSet::new) .distinct() - .count() > 1; + .count(); } private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregationNode) diff --git 
a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationsToSubqueries.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationsToSubqueries.java index 68480429fbdb0a..fb517867ca2dc9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationsToSubqueries.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationsToSubqueries.java @@ -13,14 +13,11 @@ */ package io.trino.sql.planner.iterative.rule; -import com.google.common.base.Predicates; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.trino.Session; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; import io.trino.sql.ir.Expression; import io.trino.sql.planner.NodeAndMappings; import io.trino.sql.planner.PlanCopier; @@ -29,13 +26,10 @@ import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.Assignments; -import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.JoinNode.EquiJoinClause; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.ProjectNode; -import io.trino.sql.planner.plan.TableScanNode; -import io.trino.sql.planner.plan.UnionNode; import java.util.Comparator; import java.util.HashMap; @@ -49,9 +43,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.SystemSessionProperties.distinctAggregationsStrategy; -import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES; -import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import 
static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.plan.Patterns.aggregation; import static java.util.Objects.requireNonNull; @@ -89,38 +80,38 @@ public class MultipleDistinctAggregationsToSubqueries implements Rule { private static final Pattern PATTERN = aggregation() - .matching( - Predicates.and( - // TODO: we could support non-distinct aggregations if SingleDistinctAggregationToGroupBy supports it - SingleDistinctAggregationToGroupBy::allDistinctAggregates, - DistinctAggregationToGroupBy::hasMultipleDistincts, - // if we have more than one grouping set, we can have duplicated grouping sets and handling this is complex - aggregation -> aggregation.getGroupingSetCount() == 1, - // hash symbol is added late in the planning, and handling it here would increase complexity - aggregation -> aggregation.getHashSymbol().isEmpty())); - private final Metadata metadata; - - public MultipleDistinctAggregationsToSubqueries(Metadata metadata) + .matching(MultipleDistinctAggregationsToSubqueries::isAggregationCandidateForSplittingToSubqueries); + + // In addition to this check, DistinctAggregationController.isAggregationSourceSupportedForSubqueries, which accesses Metadata, + // must also pass for the plan to be applicable for this rule. + public static boolean isAggregationCandidateForSplittingToSubqueries(AggregationNode aggregationNode) { - this.metadata = requireNonNull(metadata, "metadata is null"); + // TODO: we could support non-distinct aggregations if SingleDistinctAggregationToGroupBy supports it + return SingleDistinctAggregationToGroupBy.allDistinctAggregates(aggregationNode) && + DistinctAggregationToGroupBy.hasMultipleDistincts(aggregationNode) && + // if we have more than one grouping set, we can have duplicated grouping sets and handling this is complex + aggregationNode.getGroupingSetCount() == 1 && + // hash symbol is added late in the planning, and handling it here would increase complexity + 
aggregationNode.getHashSymbol().isEmpty(); } - @Override - public Pattern getPattern() + private final DistinctAggregationController distinctAggregationController; + + public MultipleDistinctAggregationsToSubqueries(DistinctAggregationController distinctAggregationController) { - return PATTERN; + this.distinctAggregationController = requireNonNull(distinctAggregationController, "distinctAggregationController is null"); } @Override - public boolean isEnabled(Session session) + public Pattern getPattern() { - return distinctAggregationsStrategy(session).equals(SPLIT_TO_SUBQUERIES); + return PATTERN; } @Override public Result apply(AggregationNode aggregationNode, Captures captures, Context context) { - if (!isAggregationSourceSupported(aggregationNode.getSource(), context)) { + if (!distinctAggregationController.shouldSplitToSubqueries(aggregationNode, context)) { return Result.empty(); } // group aggregations by arguments @@ -212,24 +203,4 @@ private JoinNode buildJoin(PlanNode left, List leftJoinSymbols, PlanNode ImmutableMap.of(), Optional.empty()); } - - // PlanCopier.copyPlan must support all supported nodes here. - // Additionally, we should split the table scan only if reading single columns is efficient in the given connector. 
- private boolean isAggregationSourceSupported(PlanNode source, Context context) - { - if (searchFrom(source, context.getLookup()) - .where(node -> !(node instanceof TableScanNode - || node instanceof FilterNode - || node instanceof ProjectNode - || node instanceof UnionNode)) - .findFirst() - .isPresent()) { - return false; - } - - return searchFrom(source, context.getLookup()) - .whereIsInstanceOfAny(TableScanNode.class) - .findAll().stream() - .allMatch(tableScanNode -> metadata.isColumnarTableScan(context.getSession(), ((TableScanNode) tableScanNode).getTable())); - } } diff --git a/core/trino-main/src/main/java/io/trino/testing/PlanTester.java b/core/trino-main/src/main/java/io/trino/testing/PlanTester.java index f2cdfc6a02d42a..9e00f5fc2f0baf 100644 --- a/core/trino-main/src/main/java/io/trino/testing/PlanTester.java +++ b/core/trino-main/src/main/java/io/trino/testing/PlanTester.java @@ -843,7 +843,7 @@ public PlanOptimizersFactory getPlanOptimizersFactory(boolean forceSingleNode) estimatedExchangesCostCalculator, new CostComparator(optimizerConfig), taskCountEstimator, - new DistinctAggregationController(taskCountEstimator), + new DistinctAggregationController(taskCountEstimator, plannerContext.getMetadata()), nodePartitioningManager, new RuleStatsRecorder()); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index be39a108e83dca..67fc9d321b55f7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -95,6 +95,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.MoreCollectors.toOptional; import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.SystemSessionProperties.COST_ESTIMATION_WORKER_COUNT; import static 
io.trino.SystemSessionProperties.DISTINCT_AGGREGATIONS_STRATEGY; import static io.trino.SystemSessionProperties.DISTRIBUTED_SORT; import static io.trino.SystemSessionProperties.FILTERING_SEMI_JOIN_TO_INNER; @@ -142,6 +143,7 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; +import static io.trino.sql.planner.assertions.PlanMatchPattern.groupId; import static io.trino.sql.planner.assertions.PlanMatchPattern.identityProject; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; import static io.trino.sql.planner.assertions.PlanMatchPattern.limit; @@ -159,6 +161,7 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.strictConstrainedTableScan; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictTableScan; +import static io.trino.sql.planner.assertions.PlanMatchPattern.symbol; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.trino.sql.planner.assertions.PlanMatchPattern.topN; import static io.trino.sql.planner.assertions.PlanMatchPattern.topNRanking; @@ -205,6 +208,7 @@ public class TestLogicalPlanner private static final ResolvedFunction LOWER = FUNCTIONS.resolveFunction("lower", fromTypes(VARCHAR)); private static final ResolvedFunction COMBINE_HASH = FUNCTIONS.resolveFunction("combine_hash", fromTypes(BIGINT, BIGINT)); private static final ResolvedFunction HASH_CODE = createTestMetadataManager().resolveOperator(OperatorType.HASH_CODE, ImmutableList.of(BIGINT)); + private static final ResolvedFunction CONCAT = FUNCTIONS.resolveFunction("concat", fromTypes(VARCHAR, VARCHAR)); private static final WindowNode.Frame ROWS_FROM_CURRENT = new WindowNode.Frame( ROWS, @@ -427,6 +431,102 @@ public void testDistinctOverConstants() 
tableScan("orders", ImmutableMap.of("orderstatus", "orderstatus")))))); } + @Test + public void testSingleDistinct() + { + assertPlan("SELECT custkey, orderstatus, COUNT(DISTINCT orderkey) FROM orders GROUP BY custkey, orderstatus", + anyTree( + aggregation( + singleGroupingSet("custkey", "orderstatus"), + ImmutableMap.of("count", aggregationFunction("count", ImmutableList.of("orderkey"))), + aggregation( + singleGroupingSet("custkey", "orderstatus", "orderkey"), + ImmutableMap.of(), + Optional.empty(), + FINAL, + exchange(aggregation( + singleGroupingSet("custkey", "orderstatus", "orderkey"), + ImmutableMap.of(), + Optional.empty(), + PARTIAL, + tableScan( + "orders", + ImmutableMap.of("orderstatus", "orderstatus", "custkey", "custkey", "orderkey", "orderkey")))))))); + } + + @Test + public void testPreAggregateDistinct() + { + assertPlan("SELECT COUNT(DISTINCT orderkey), COUNT(DISTINCT custkey) FROM orders", + anyTree( + aggregation( + singleGroupingSet(), + ImmutableMap.of(Optional.of("count1"), aggregationFunction("count", false, ImmutableList.of(symbol("orderkey"))), + Optional.of("count2"), aggregationFunction("count", false, ImmutableList.of(symbol("custkey")))), + ImmutableList.of(), + ImmutableList.of("gid-filter-0", "gid-filter-1"), + Optional.empty(), + SINGLE, + project( + ImmutableMap.of( + "gid-filter-0", expression(new Comparison(EQUAL, new Reference(BIGINT, "groupId"), new Constant(BIGINT, 0L))), + "gid-filter-1", expression(new Comparison(EQUAL, new Reference(BIGINT, "groupId"), new Constant(BIGINT, 1L)))), + aggregation( + singleGroupingSet("custkey", "orderkey", "groupId"), + ImmutableMap.of(), + Optional.empty(), + FINAL, + exchange(aggregation( + singleGroupingSet("orderkey", "custkey", "groupId"), + ImmutableMap.of(), + Optional.empty(), + PARTIAL, + filter( + new In(new Reference(BIGINT, "groupId"), ImmutableList.of(new Constant(BIGINT, 0L), new Constant(BIGINT, 1L))), + groupId( + ImmutableList.of(ImmutableList.of("orderkey"), 
ImmutableList.of("custkey")), + "groupId", + tableScan( + "orders", + ImmutableMap.of("custkey", "custkey", "orderkey", "orderkey"))))))))))); + } + + @Test + public void testMultipleDistinctUsingMarkDistinct() + { + assertPlan("SELECT orderstatus, orderstatus || '1', orderstatus || '2', COUNT(DISTINCT orderkey), COUNT(DISTINCT custkey) FROM orders GROUP BY 1, 2, 3", + Session.builder(getPlanTester().getDefaultSession()) + .setSystemProperty(COST_ESTIMATION_WORKER_COUNT, "6") + .build(), + anyTree( + aggregation( + singleGroupingSet("orderstatus", "orderstatus1", "orderstatus2"), + ImmutableMap.of(Optional.of("count1"), aggregationFunction("count", false, ImmutableList.of(symbol("custkey"))), + Optional.of("count2"), aggregationFunction("count", false, ImmutableList.of(symbol("orderkey")))), + ImmutableList.of(), + ImmutableList.of("custkey_mask", "orderkey_mask"), + Optional.empty(), + SINGLE, + markDistinct( + "custkey_mask", + ImmutableList.of("orderstatus", "orderstatus1", "orderstatus2", "custkey"), + markDistinct( + "orderkey_mask", + ImmutableList.of("orderstatus", "orderstatus1", "orderstatus2", "orderkey"), + exchange( + project( + ImmutableMap.of( + "orderstatus1", expression(new Call(CONCAT, ImmutableList.of( + new Cast(new Reference(createVarcharType(1), "orderstatus"), VARCHAR), + new Constant(VARCHAR, utf8Slice("1"))))), + "orderstatus2", expression(new Call(CONCAT, ImmutableList.of( + new Cast(new Reference(createVarcharType(1), "orderstatus"), VARCHAR), + new Constant(VARCHAR, utf8Slice("2")))))), + tableScan( + "orders", + ImmutableMap.of("custkey", "custkey", "orderkey", "orderkey", "orderstatus", "orderstatus"))))))))); + } + @Test public void testInnerInequalityJoinNoEquiJoinConjuncts() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDistinctAggregationController.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDistinctAggregationController.java index 062a738ea804d5..07409028e631c1 
100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDistinctAggregationController.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDistinctAggregationController.java @@ -22,19 +22,25 @@ import io.trino.cost.SymbolStatsEstimate; import io.trino.cost.TaskCountEstimator; import io.trino.execution.warnings.WarningCollector; +import io.trino.metadata.Metadata; +import io.trino.metadata.TableHandle; import io.trino.metadata.TestingFunctionResolution; +import io.trino.security.AllowAllAccessControl; import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.PlannerContext; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.iterative.Lookup; import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.iterative.rule.TestMultipleDistinctAggregationsToSubqueries.DelegatingMetadata; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.sql.planner.plan.TableScanNode; -import io.trino.sql.planner.plan.ValuesNode; +import io.trino.transaction.TestingTransactionManager; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.parallel.Execution; @@ -53,10 +59,13 @@ import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT; import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE; import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP; +import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES; +import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static 
io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.testing.TestingHandles.TEST_TABLE_HANDLE; import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.testing.TransactionBuilder.transaction; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; @@ -68,15 +77,34 @@ public class TestDistinctAggregationController private static final int NODE_COUNT = 6; private static final TaskCountEstimator TASK_COUNT_ESTIMATOR = new TaskCountEstimator(() -> NODE_COUNT); private static final TestingFunctionResolution functionResolution = new TestingFunctionResolution(); + private TestingTransactionManager transactionManager; + private Metadata metadata; + + @BeforeAll + public final void setUp() + { + this.transactionManager = new TestingTransactionManager(); + PlannerContext plannerContext = plannerContextBuilder() + .withTransactionManager(transactionManager) + .build(); + this.metadata = new DelegatingMetadata(plannerContext.getMetadata()) + { + @Override + public boolean isColumnarTableScan(Session session, TableHandle tableHandle) + { + return true; + } + }; + } @Test public void testSingleStepPreferredForHighCardinalitySingleGroupByKey() { - DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR); + DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); Symbol groupingKey = symbolAllocator.newSymbol("groupingKey", BIGINT); - ValuesNode source = new ValuesNode(new PlanNodeId("source"), 1_000_000); + PlanNode source = tableScan(); AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(ImmutableList.of(groupingKey), source, 
symbolAllocator); Rule.Context context = context( ImmutableMap.of(source, new PlanNodeStatsEstimate(1_000_000, ImmutableMap.of( @@ -89,13 +117,13 @@ public void testSingleStepPreferredForHighCardinalitySingleGroupByKey() @Test public void testSingleStepPreferredForHighCardinalityMultipleGroupByKeys() { - DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR); + DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); Symbol lowCardinalityGroupingKey = symbolAllocator.newSymbol("lowCardinalityGroupingKey", BIGINT); Symbol highCardinalityGroupingKey = symbolAllocator.newSymbol("highCardinalityGroupingKey", BIGINT); - ValuesNode source = new ValuesNode(new PlanNodeId("source"), 1_000_000); + PlanNode source = tableScan(); AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(ImmutableList.of(lowCardinalityGroupingKey, highCardinalityGroupingKey), source, symbolAllocator); Rule.Context context = context( ImmutableMap.of(source, new PlanNodeStatsEstimate(1_000_000, ImmutableMap.of( @@ -109,21 +137,21 @@ public void testSingleStepPreferredForHighCardinalityMultipleGroupByKeys() @Test public void testPreAggregatePreferredForLowCardinality2GroupByKeys() { - DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR); + DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); List groupingKeys = ImmutableList.of( symbolAllocator.newSymbol("key1", BIGINT), symbolAllocator.newSymbol("key2", BIGINT)); - ValuesNode source = new ValuesNode(new PlanNodeId("source"), 1_000_000); + PlanNode source = tableScan(); AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(groupingKeys, source, symbolAllocator); Rule.Context context = context( 
ImmutableMap.of(source, new PlanNodeStatsEstimate( 1_000_000, groupingKeys.stream().collect(toImmutableMap( Function.identity(), - key -> SymbolStatsEstimate.builder().setDistinctValuesCount(10).build())))), + _ -> SymbolStatsEstimate.builder().setDistinctValuesCount(10).build())))), new SymbolAllocator()); assertThat(controller.shouldUsePreAggregate(aggregationNode, context)).isTrue(); assertThat(controller.shouldAddMarkDistinct(aggregationNode, context)).isFalse(); @@ -132,13 +160,13 @@ public void testPreAggregatePreferredForLowCardinality2GroupByKeys() @Test public void testPreAggregatePreferredForUnknownStatisticsAnd2GroupByKeys() { - DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR); + DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); List groupingKeys = ImmutableList.of( symbolAllocator.newSymbol("key1", BIGINT), symbolAllocator.newSymbol("key2", BIGINT)); - ValuesNode source = new ValuesNode(new PlanNodeId("source"), 1_000_000); + PlanNode source = tableScan(); AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(groupingKeys, source, symbolAllocator); Rule.Context context = context(ImmutableMap.of(), new SymbolAllocator()); assertThat(controller.shouldUsePreAggregate(aggregationNode, context)).isTrue(); @@ -148,11 +176,11 @@ public void testPreAggregatePreferredForUnknownStatisticsAnd2GroupByKeys() @Test public void testPreAggregatePreferredForMediumCardinalitySingleGroupByKey() { - DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR); + DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); Symbol groupingKey = symbolAllocator.newSymbol("groupingKey", BIGINT); - ValuesNode source = new ValuesNode(new 
PlanNodeId("source"), 1_000_000); + PlanNode source = tableScan(); AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(ImmutableList.of(groupingKey), source, symbolAllocator); Rule.Context context = context( ImmutableMap.of(source, new PlanNodeStatsEstimate(NODE_COUNT * getTaskConcurrency(TEST_SESSION) * 10, ImmutableMap.of( @@ -165,44 +193,44 @@ public void testPreAggregatePreferredForMediumCardinalitySingleGroupByKey() @Test public void testSingleStepPreferredForMediumCardinality3GroupByKeys() { - DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR); + DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); List groupingKeys = ImmutableList.of( symbolAllocator.newSymbol("key1", BIGINT), symbolAllocator.newSymbol("key2", BIGINT), symbolAllocator.newSymbol("key3", BIGINT)); - ValuesNode source = new ValuesNode(new PlanNodeId("source"), 1_000_000); + PlanNode source = tableScan(); AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(groupingKeys, source, symbolAllocator); Rule.Context context = context( ImmutableMap.of(source, new PlanNodeStatsEstimate(NODE_COUNT * getTaskConcurrency(TEST_SESSION) * 10, groupingKeys.stream().collect(toImmutableMap( Function.identity(), - key -> SymbolStatsEstimate.builder().setDistinctValuesCount(NODE_COUNT * getTaskConcurrency(TEST_SESSION) * 10).build())))), + _ -> SymbolStatsEstimate.builder().setDistinctValuesCount(NODE_COUNT * getTaskConcurrency(TEST_SESSION) * 10).build())))), symbolAllocator); assertShouldUseSingleStep(controller, aggregationNode, context); } @Test - public void testPreAggregatePreferredForGlobalAggregation() + public void testSplitToSubqueriesPreferredForGlobalAggregation() { - DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR); + DistinctAggregationController 
controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); - ValuesNode source = new ValuesNode(new PlanNodeId("source"), 1_000_000); + PlanNode source = tableScan(); AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(ImmutableList.of(), source, symbolAllocator); - Rule.Context context = context( + assertThat((boolean) inTransaction(session -> controller.shouldSplitToSubqueries(aggregationNode, context( ImmutableMap.of(source, new PlanNodeStatsEstimate(1_000_000, ImmutableMap.of())), - symbolAllocator); - - assertThat(controller.shouldUsePreAggregate(aggregationNode, context)).isTrue(); + session, + symbolAllocator)))) + .isTrue(); } @Test public void testMarkDistinctPreferredForLowCardinality3GroupByKeys() { - DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR); + DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); List groupingKeys = ImmutableList.of( @@ -210,14 +238,14 @@ public void testMarkDistinctPreferredForLowCardinality3GroupByKeys() symbolAllocator.newSymbol("key2", BIGINT), symbolAllocator.newSymbol("key3", BIGINT)); - ValuesNode source = new ValuesNode(new PlanNodeId("source"), 1_000_000); + PlanNode source = tableScan(); AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(groupingKeys, source, symbolAllocator); Rule.Context context = context( ImmutableMap.of(source, new PlanNodeStatsEstimate( 1_000_000, groupingKeys.stream().collect(toImmutableMap( Function.identity(), - key -> SymbolStatsEstimate.builder().setDistinctValuesCount(10).build())))), + _ -> SymbolStatsEstimate.builder().setDistinctValuesCount(10).build())))), new SymbolAllocator()); assertThat(controller.shouldAddMarkDistinct(aggregationNode, context)).isTrue(); } @@ -225,24 +253,24 @@ public void 
testMarkDistinctPreferredForLowCardinality3GroupByKeys() @Test public void testMarkDistinctPreferredForUnknownStatisticsAnd3GroupByKeys() { - DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR); + DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); List groupingKeys = ImmutableList.of( symbolAllocator.newSymbol("key1", BIGINT), symbolAllocator.newSymbol("key2", BIGINT), symbolAllocator.newSymbol("key3", BIGINT)); - ValuesNode source = new ValuesNode(new PlanNodeId("source"), 1_000_000); + PlanNode source = tableScan(); AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(groupingKeys, source, symbolAllocator); - Rule.Context context = context(ImmutableMap.of(), new SymbolAllocator()); - assertThat(controller.shouldAddMarkDistinct(aggregationNode, context)).isTrue(); + assertThat((boolean) inTransaction(session -> controller.shouldAddMarkDistinct(aggregationNode, context(ImmutableMap.of(), session, symbolAllocator)))) + .isTrue(); } @Test public void testChoiceForcedByTheSessionProperty() { int clusterThreadCount = NODE_COUNT * getTaskConcurrency(TEST_SESSION); - DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR); + DistinctAggregationController controller = new DistinctAggregationController(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); Symbol groupingKey = symbolAllocator.newSymbol("groupingKey", BIGINT); @@ -250,18 +278,24 @@ public void testChoiceForcedByTheSessionProperty() AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(ImmutableList.of(groupingKey), source, symbolAllocator); // big NDV, distinct_aggregations_strategy = mark_distinct - assertThat(controller.shouldAddMarkDistinct(aggregationNode, context( - ImmutableMap.of(source, new PlanNodeStatsEstimate(1000 * 
clusterThreadCount, ImmutableMap.of( - groupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(1000 * clusterThreadCount).build()))), + assertThat((boolean) inTransaction( testSessionBuilder().setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, MARK_DISTINCT.name()).build(), - symbolAllocator))).isTrue(); + session -> controller.shouldAddMarkDistinct(aggregationNode, context( + ImmutableMap.of(source, new PlanNodeStatsEstimate(1000 * clusterThreadCount, ImmutableMap.of( + groupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(1000 * clusterThreadCount).build()))), + session, + symbolAllocator)))) + .isTrue(); // big NDV, distinct_aggregations_strategy = pre-aggregate - assertThat(controller.shouldUsePreAggregate(aggregationNode, context( - ImmutableMap.of(source, new PlanNodeStatsEstimate(1000 * clusterThreadCount, ImmutableMap.of( - groupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(1000 * clusterThreadCount).build()))), + assertThat((boolean) inTransaction( testSessionBuilder().setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, PRE_AGGREGATE.name()).build(), - symbolAllocator))).isTrue(); + session -> controller.shouldUsePreAggregate(aggregationNode, context( + ImmutableMap.of(source, new PlanNodeStatsEstimate(1000 * clusterThreadCount, ImmutableMap.of( + groupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(1000 * clusterThreadCount).build()))), + session, + symbolAllocator)))) + .isTrue(); // small NDV, distinct_aggregations_strategy = single_step assertShouldUseSingleStep(controller, aggregationNode, context( @@ -269,6 +303,32 @@ public void testChoiceForcedByTheSessionProperty() groupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(1000 * clusterThreadCount).build()))), testSessionBuilder().setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, SINGLE_STEP.name()).build(), symbolAllocator)); + + // big NDV, distinct_aggregations_strategy = split_to_subqueries + assertThat((boolean) inTransaction( + 
testSessionBuilder().setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, SPLIT_TO_SUBQUERIES.name()).build(), + session -> controller.shouldSplitToSubqueries(aggregationNode, context( + ImmutableMap.of(source, new PlanNodeStatsEstimate(1000 * clusterThreadCount, ImmutableMap.of( + groupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(1000 * clusterThreadCount).build()))), + session, + symbolAllocator)))) + .isTrue(); + } + + private T inTransaction(Function callback) + { + return inTransaction(TEST_SESSION, callback); + } + + private T inTransaction(Session session, Function callback) + { + return transaction(transactionManager, metadata, new AllowAllAccessControl()) + .execute(session, callback); + } + + private static PlanNode tableScan() + { + return new TableScanNode(new PlanNodeId("source"), TEST_TABLE_HANDLE, ImmutableList.of(), ImmutableMap.of(), TupleDomain.all(), Optional.empty(), false, Optional.empty()); } private static AggregationNode aggregationWithTwoDistinctAggregations(List groupingKeys, PlanNode source, SymbolAllocator symbolAllocator) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDistinctAggregationToGroupBy.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDistinctAggregationToGroupBy.java index c02556401085c3..b9e5b1bb5383d3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDistinctAggregationToGroupBy.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDistinctAggregationToGroupBy.java @@ -30,6 +30,7 @@ import static io.trino.SystemSessionProperties.DISTINCT_AGGREGATIONS_STRATEGY; import static io.trino.SystemSessionProperties.TASK_CONCURRENCY; +import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.ir.Comparison.Operator.EQUAL; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; @@ -47,7 
+48,9 @@ public class TestDistinctAggregationToGroupBy extends BaseRuleTest { private static final int NODES_COUNT = 4; - private static final DistinctAggregationController DISTINCT_AGGREGATION_CONTROLLER = new DistinctAggregationController(new TaskCountEstimator(() -> NODES_COUNT)); + private static final DistinctAggregationController DISTINCT_AGGREGATION_CONTROLLER = new DistinctAggregationController( + new TaskCountEstimator(() -> NODES_COUNT), + createTestMetadataManager()); @Test public void testGlobalWithNonDistinct() diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java index 5bbdb2991c9443..38af279dc6478b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java @@ -28,6 +28,7 @@ import java.util.Optional; import static io.trino.SystemSessionProperties.DISTINCT_AGGREGATIONS_STRATEGY; +import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; @@ -45,7 +46,7 @@ public class TestMultipleDistinctAggregationToMarkDistinct { private static final int NODES_COUNT = 4; private static final TaskCountEstimator TASK_COUNT_ESTIMATOR = new TaskCountEstimator(() -> NODES_COUNT); - private static final DistinctAggregationController DISTINCT_AGGREGATION_CONTROLLER = new DistinctAggregationController(TASK_COUNT_ESTIMATOR); + private static final DistinctAggregationController DISTINCT_AGGREGATION_CONTROLLER = new DistinctAggregationController(TASK_COUNT_ESTIMATOR, createTestMetadataManager()); @Test public void 
testNoDistinct() diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationsToSubqueries.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationsToSubqueries.java index 586c762be721a8..65eb88f0590044 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationsToSubqueries.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationsToSubqueries.java @@ -21,6 +21,8 @@ import io.airlift.slice.Slice; import io.trino.Session; import io.trino.cost.PlanNodeStatsEstimate; +import io.trino.cost.SymbolStatsEstimate; +import io.trino.cost.TaskCountEstimator; import io.trino.metadata.AnalyzeMetadata; import io.trino.metadata.AnalyzeTableHandle; import io.trino.metadata.CatalogFunctionMetadata; @@ -44,6 +46,7 @@ import io.trino.metadata.TableProperties; import io.trino.metadata.TableSchema; import io.trino.metadata.TableVersion; +import io.trino.metadata.TestingFunctionResolution; import io.trino.metadata.ViewDefinition; import io.trino.metadata.ViewInfo; import io.trino.plugin.tpch.TpchColumnHandle; @@ -104,8 +107,12 @@ import io.trino.spi.statistics.TableStatisticsMetadata; import io.trino.spi.type.Type; import io.trino.sql.analyzer.TypeSignatureProvider; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Not; import io.trino.sql.ir.Reference; import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.planner.Symbol; @@ -114,6 +121,7 @@ import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.iterative.rule.test.RuleTester; +import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.PlanNodeId; import 
io.trino.testing.PlanTester; import io.trino.testing.TestingTransactionHandle; @@ -133,6 +141,7 @@ import static io.trino.SystemSessionProperties.DISTINCT_AGGREGATIONS_STRATEGY; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; @@ -164,8 +173,13 @@ public class TestMultipleDistinctAggregationsToSubqueries private static final ColumnHandle COLUMN_4_HANDLE = new TpchColumnHandle(COLUMN_4, DATE); private static final String GROUPING_KEY_COLUMN = "suppkey"; private static final ColumnHandle GROUPING_KEY_COLUMN_HANDLE = new TpchColumnHandle(GROUPING_KEY_COLUMN, BIGINT); + private static final String GROUPING_KEY2_COLUMN = "comment"; + private static final ColumnHandle GROUPING_KEY2_COLUMN_HANDLE = new TpchColumnHandle(GROUPING_KEY2_COLUMN, VARCHAR); private static final String TABLE_NAME = "lineitem"; + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT)); + private RuleTester ruleTester = tester(true); @AfterAll @@ -179,7 +193,7 @@ public final void tearDownTester() public void testDoesNotFire() { // no distinct aggregation - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .on(p -> { Symbol inputSymbol = p.symbol("inputSymbol"); @@ -194,7 +208,7 @@ public void testDoesNotFire() .doesNotFire(); // single distinct - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + 
ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .on(p -> { Symbol inputSymbol = p.symbol("inputSymbol", BIGINT); @@ -210,7 +224,7 @@ public void testDoesNotFire() .doesNotFire(); // two distinct on the same input - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .on(p -> { Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); @@ -228,7 +242,7 @@ public void testDoesNotFire() .doesNotFire(); // hash symbol - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .on(p -> { Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); @@ -249,7 +263,7 @@ public void testDoesNotFire() .doesNotFire(); // non-distinct - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .on(p -> { Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); @@ -270,7 +284,7 @@ public void testDoesNotFire() .doesNotFire(); // groupingSetCount > 1 - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .on(p -> { Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); @@ -290,7 +304,7 @@ public void testDoesNotFire() .doesNotFire(); // complex subquery (join) - ruleTester.assertThat(new 
MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .on(p -> { Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); @@ -316,7 +330,7 @@ public void testDoesNotFire() .doesNotFire(); // complex subquery (filter on top of join to test recursion) - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .on(p -> { Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); @@ -346,7 +360,7 @@ public void testDoesNotFire() // connector does not support efficient single column reads RuleTester ruleTesterNotObjectStore = tester(false); - ruleTesterNotObjectStore.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTesterNotObjectStore.getMetadata())) + ruleTesterNotObjectStore.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTesterNotObjectStore)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .on(p -> { Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); @@ -366,7 +380,7 @@ public void testDoesNotFire() .doesNotFire(); // rule not enabled - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .on(p -> { Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); @@ -383,12 +397,412 @@ public void testDoesNotFire() input2Symbol, COLUMN_2_HANDLE)))); }) .doesNotFire(); + + // automatic but single_step is preferred + String aggregationSourceId = "aggregationSourceId"; + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, 
"automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder().addSymbolStatistics( + new Symbol(BIGINT, "groupingKey"), SymbolStatsEstimate.builder().setDistinctValuesCount(1_000_000).build()).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + return p.aggregation(builder -> builder + .singleGroupingSet(p.symbol("groupingKey", BIGINT)) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(input1Symbol, input2Symbol)) + .setAssignments(ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE))))); + }) + .doesNotFire(); + } + + @Test + public void testAutomaticDecisionForAggregationOnTableScan() + { + // automatic but single_step is preferred + String aggregationSourceId = "aggregationSourceId"; + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder().addSymbolStatistics( + new Symbol(BIGINT, "groupingKey"), SymbolStatsEstimate.builder().setDistinctValuesCount(1_000_000).build()).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + return p.aggregation(builder -> builder + .singleGroupingSet(p.symbol("groupingKey")) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, 
"input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(input1Symbol, input2Symbol)) + .setAssignments(ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE))))); + }) + .doesNotFire(); + + // single_step is not preferred, the overhead of groupingKey is not big + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(BIGINT, "groupingKey"), SymbolStatsEstimate.builder().setDistinctValuesCount(10).build()).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + return p.aggregation(builder -> builder + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(input1Symbol, input2Symbol, groupingKey)) + .setAssignments(ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + groupingKey, GROUPING_KEY_COLUMN_HANDLE))))); + }) + .matches(project( + 
ImmutableMap.of( + "final_output1", PlanMatchPattern.expression(new Reference(BIGINT, "output1")), + "final_output2", PlanMatchPattern.expression(new Reference(BIGINT, "output2")), + "group_by_key", PlanMatchPattern.expression(new Reference(BIGINT, "left_groupingKey"))), + join( + INNER, + builder -> builder + .equiCriteria("left_groupingKey", "right_groupingKey") + .left(aggregation( + singleGroupingSet("left_groupingKey"), + ImmutableMap.of(Optional.of("output1"), aggregationFunction("count", true, ImmutableList.of(symbol("input1Symbol")))), + Optional.empty(), + SINGLE, + tableScan( + TABLE_NAME, + ImmutableMap.of( + "input1Symbol", COLUMN_1, + "left_groupingKey", GROUPING_KEY_COLUMN)))) + .right(aggregation( + singleGroupingSet("right_groupingKey"), + ImmutableMap.of(Optional.of("output2"), aggregationFunction("sum", true, ImmutableList.of(symbol("input2Symbol")))), + Optional.empty(), + SINGLE, + tableScan( + TABLE_NAME, + ImmutableMap.of( + "input2Symbol", COLUMN_2, + "right_groupingKey", GROUPING_KEY_COLUMN))))))); + + // single_step is not preferred, the overhead of groupingKeys is bigger than 50% + String aggregationId = "aggregationId"; + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(BIGINT, "groupingKey"), SymbolStatsEstimate.builder().setDistinctValuesCount(10).build()) + .addSymbolStatistics(new Symbol(BIGINT, "groupingKey2"), SymbolStatsEstimate.builder().setAverageRowSize(1_000_000).build()) + .build()) + .overrideStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(10).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + Symbol groupingKey2 = p.symbol("groupingKey2", 
VARCHAR); + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationId)) + .singleGroupingSet(groupingKey, groupingKey2) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(input1Symbol, input2Symbol, groupingKey, groupingKey2)) + .setAssignments(ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + groupingKey, GROUPING_KEY_COLUMN_HANDLE, + groupingKey2, GROUPING_KEY2_COLUMN_HANDLE))))); + }) + .doesNotFire(); + } + + @Test + public void testAutomaticDecisionForAggregationOnProjectedTableScan() + { + String aggregationSourceId = "aggregationSourceId"; + String aggregationId = "aggregationId"; + // the overhead of the projection is bigger than 50% + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(BIGINT, "projectionInput1"), SymbolStatsEstimate.builder().setDistinctValuesCount(10).build()) + .addSymbolStatistics(new Symbol(BIGINT, "projectionInput2"), SymbolStatsEstimate.builder().setAverageRowSize(1_000_000).build()) + .build()) + .overrideStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(10).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + Symbol projectionInput1 = 
p.symbol("projectionInput1", BIGINT); + Symbol projectionInput2 = p.symbol("projectionInput2", VARCHAR); + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationId)) + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.project( + Assignments.builder() + .putIdentity(input1Symbol) + .putIdentity(input2Symbol) + .put(groupingKey, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "projectionInput1"), new Cast(new Reference(BIGINT, "projectionInput2"), BIGINT)))) + .build(), + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(input1Symbol, input2Symbol, projectionInput1, projectionInput2)) + .setAssignments(ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + projectionInput1, GROUPING_KEY_COLUMN_HANDLE, + projectionInput2, GROUPING_KEY2_COLUMN_HANDLE)))))); + }) + .doesNotFire(); + + // the big projection is used as distinct input. 
we could handle this case, but for simplicity's sake, the rule won't fire here + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(BIGINT, "projectionInput1"), SymbolStatsEstimate.builder().setDistinctValuesCount(10).build()) + .addSymbolStatistics(new Symbol(BIGINT, "projectionInput2"), SymbolStatsEstimate.builder().setAverageRowSize(1_000_000).build()) + .build()) + .overrideStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(10).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + Symbol projectionInput1 = p.symbol("projectionInput1", BIGINT); + Symbol projectionInput2 = p.symbol("projectionInput2", VARCHAR); + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationId)) + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.project( + Assignments.builder() + .put(input1Symbol, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "projectionInput1"), new Cast(new Reference(BIGINT, "projectionInput2"), BIGINT)))) + .putIdentity(input2Symbol) + .putIdentity(groupingKey) + .build(), + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(groupingKey, input2Symbol, projectionInput1, projectionInput2)) +
.setAssignments(ImmutableMap.of( + groupingKey, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + projectionInput1, GROUPING_KEY_COLUMN_HANDLE, + projectionInput2, GROUPING_KEY2_COLUMN_HANDLE)))))); + }) + .doesNotFire(); + } + + @Test + public void testAutomaticDecisionForAggregationOnFilteredTableScan() + { + String aggregationSourceId = "aggregationSourceId"; + String aggregationId = "aggregationId"; + String filterId = "filterId"; + // selective filter + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(VARCHAR, "filterInput"), SymbolStatsEstimate.builder().setAverageRowSize(1).build()) + .build()) + .overrideStats(filterId, PlanNodeStatsEstimate.builder().setOutputRowCount(1).build()) + .overrideStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(1).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + Symbol filterInput = p.symbol("filterInput", VARCHAR); + + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationId)) + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.filter( + new PlanNodeId(filterId), + new Not(new IsNull(new Reference(VARCHAR, "filterInput"))), + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + 
.setSymbols(ImmutableList.of(input1Symbol, input2Symbol, groupingKey, filterInput)) + .setAssignments(ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + groupingKey, GROUPING_KEY_COLUMN_HANDLE, + filterInput, GROUPING_KEY2_COLUMN_HANDLE)))))); + }) + .doesNotFire(); + + // non-selective filter + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(VARCHAR, "filterInput"), SymbolStatsEstimate.builder().setAverageRowSize(1).build()) + .build()) + .overrideStats(filterId, PlanNodeStatsEstimate.builder().setOutputRowCount(100).build()) + .overrideStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(100).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + Symbol filterInput = p.symbol("filterInput", VARCHAR); + + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationId)) + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.filter( + new PlanNodeId(filterId), + new Not(new IsNull(new Reference(VARCHAR, "filterInput"))), + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(input1Symbol, input2Symbol, groupingKey, filterInput)) + .setAssignments(ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + 
input2Symbol, COLUMN_2_HANDLE, + groupingKey, GROUPING_KEY_COLUMN_HANDLE, + filterInput, GROUPING_KEY2_COLUMN_HANDLE)))))); + }) + .matches(project( + ImmutableMap.of( + "final_output1", PlanMatchPattern.expression(new Reference(BIGINT, "output1")), + "final_output2", PlanMatchPattern.expression(new Reference(BIGINT, "output2")), + "group_by_key", PlanMatchPattern.expression(new Reference(BIGINT, "left_groupingKey"))), + join( + INNER, + builder -> builder + .equiCriteria("left_groupingKey", "right_groupingKey") + .left(aggregation( + singleGroupingSet("left_groupingKey"), + ImmutableMap.of(Optional.of("output1"), aggregationFunction("count", true, ImmutableList.of(symbol("input1Symbol")))), + Optional.empty(), + SINGLE, + filter( + new Not(new IsNull(new Reference(BIGINT, "left_filterInput"))), + tableScan( + TABLE_NAME, + ImmutableMap.of( + "input1Symbol", COLUMN_1, + "left_groupingKey", GROUPING_KEY_COLUMN, + "left_filterInput", GROUPING_KEY2_COLUMN))))) + .right(aggregation( + singleGroupingSet("right_groupingKey"), + ImmutableMap.of(Optional.of("output2"), aggregationFunction("sum", true, ImmutableList.of(symbol("input2Symbol")))), + Optional.empty(), + SINGLE, + filter( + new Not(new IsNull(new Reference(BIGINT, "right_filterInput"))), + tableScan( + TABLE_NAME, + ImmutableMap.of( + "input2Symbol", COLUMN_2, + "right_groupingKey", GROUPING_KEY_COLUMN, + "right_filterInput", GROUPING_KEY2_COLUMN)))))))); + } + + @Test + public void testAutomaticDecisionForAggregationOnFilteredUnion() + { + String aggregationSourceId = "aggregationSourceId"; + String aggregationId = "aggregationId"; + String filterId = "filterId"; + // union with additional columns to read + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(VARCHAR, "filterInput"), 
SymbolStatsEstimate.builder().setAverageRowSize(1).build()) + .build()) + .overrideStats(filterId, PlanNodeStatsEstimate.builder().setOutputRowCount(100).build()) + .overrideStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(100).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input11Symbol = p.symbol("input1_1Symbol", BIGINT); + Symbol input12Symbol = p.symbol("input1_2Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol input21Symbol = p.symbol("input2_1Symbol", BIGINT); + Symbol input22Symbol = p.symbol("input2_2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + Symbol groupingKey1 = p.symbol("groupingKey1", BIGINT); + Symbol groupingKey2 = p.symbol("groupingKey2", BIGINT); + + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationId)) + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.union( + ImmutableListMultimap.builder() + .put(input1Symbol, input11Symbol) + .put(input1Symbol, input12Symbol) + .put(input2Symbol, input21Symbol) + .put(input2Symbol, input22Symbol) + .put(groupingKey, groupingKey1) + .put(groupingKey, groupingKey2) + .build(), + ImmutableList.of( + p.filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input1_1Symbol"), new Constant(BIGINT, 0L)), + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input11Symbol, input21Symbol, groupingKey1), + ImmutableMap.of( + input11Symbol, COLUMN_1_HANDLE, + input21Symbol, COLUMN_2_HANDLE, + groupingKey1, GROUPING_KEY_COLUMN_HANDLE))), + p.filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input2_2Symbol"), 
new Constant(BIGINT, 2L)), + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input12Symbol, input22Symbol, groupingKey2), + ImmutableMap.of( + input12Symbol, COLUMN_1_HANDLE, + input22Symbol, COLUMN_2_HANDLE, + groupingKey2, GROUPING_KEY_COLUMN_HANDLE))))))); + }) + .doesNotFire(); } @Test public void testGlobalDistinctToSubqueries() { - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .on(p -> { Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); @@ -423,7 +837,7 @@ public void testGlobalDistinctToSubqueries() @Test public void testGlobalWith3DistinctToSubqueries() { - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .on(p -> { Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); @@ -469,7 +883,7 @@ public void testGlobalWith3DistinctToSubqueries() @Test public void testGlobalWith4DistinctToSubqueries() { - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .on(p -> { Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); @@ -524,7 +938,7 @@ public void testGlobalWith4DistinctToSubqueries() @Test public void testGlobal2DistinctOnTheSameInputToSubqueries() { - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .on(p -> { Symbol input1Symbol = 
p.symbol("input1Symbol", BIGINT); @@ -564,7 +978,7 @@ public void testGlobal2DistinctOnTheSameInputToSubqueries() public void testGroupByWithDistinctToSubqueries() { String aggregationNodeId = "aggregationNodeId"; - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .overrideStats(aggregationNodeId, PlanNodeStatsEstimate.builder().setOutputRowCount(100_000).build()) .on(p -> { @@ -620,7 +1034,7 @@ public void testGroupByWithDistinctToSubqueries() public void testGroupByWithDistinctOverUnionToSubqueries() { String aggregationNodeId = "aggregationNodeId"; - ruleTester.assertThat(new MultipleDistinctAggregationsToSubqueries(ruleTester.getMetadata())) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") .overrideStats(aggregationNodeId, PlanNodeStatsEstimate.builder().setOutputRowCount(100_000).build()) .on(p -> { @@ -730,6 +1144,13 @@ public void testGroupByWithDistinctOverUnionToSubqueries() .withAlias("right_groupingKey", new SetOperationOutputMatcher(2))))))); } + private static MultipleDistinctAggregationsToSubqueries newMultipleDistinctAggregationsToSubqueries(RuleTester ruleTester) + { + return new MultipleDistinctAggregationsToSubqueries(new DistinctAggregationController( + new TaskCountEstimator(() -> Integer.MAX_VALUE), + ruleTester.getMetadata())); + } + private static TableHandle testTableHandle(RuleTester ruleTester) { return new TableHandle(ruleTester.getCurrentCatalogHandle(), new TpchTableHandle("sf1", TABLE_NAME, 1.0), TestingTransactionHandle.create()); @@ -756,12 +1177,12 @@ public boolean isColumnarTableScan(Session session, TableHandle tableHandle) return new RuleTester(planTester); } - private static class DelegatingMetadata + public static class 
DelegatingMetadata implements Metadata { private final Metadata metadata; - private DelegatingMetadata(Metadata metadata) + public DelegatingMetadata(Metadata metadata) { this.metadata = requireNonNull(metadata, "metadata is null"); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index 59d440a27a08e1..c85ccadd310ecd 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -620,6 +620,7 @@ public static class TableScanBuilder private Optional statistics = Optional.empty(); private boolean updateTarget; private Optional useConnectorNodePartitioning = Optional.empty(); + private Optional nodeId = Optional.empty(); private TableScanBuilder(PlanNodeIdAllocator idAllocator) { @@ -667,6 +668,12 @@ public TableScanBuilder setUpdateTarget(boolean updateTarget) return this; } + public TableScanBuilder setNodeId(PlanNodeId id) + { + this.nodeId = Optional.of(id); + return this; + } + public TableScanBuilder setUseConnectorNodePartitioning(Optional useConnectorNodePartitioning) { this.useConnectorNodePartitioning = useConnectorNodePartitioning; @@ -676,7 +683,7 @@ public TableScanBuilder setUseConnectorNodePartitioning(Optional useCon public TableScanNode build() { return new TableScanNode( - idAllocator.getNextId(), + nodeId.orElseGet(idAllocator::getNextId), tableHandle, symbols, assignments,