Skip to content

Commit

Permalink
Record estimation stats during query optimization
Browse files Browse the repository at this point in the history
When the optimizer returns an optimized plan, it also returns the estimated stats for each node in the plan.
However, instead of returning the exact stats that were used during optimization, it actually recalculates the stats.
This can be a problem. For example, the CBO currently returns empty stats for an aggregation if the aggregation step is not single.
This means that we will not get any CBO stats for partial and final aggregations, or for any other nodes that are downstream of the aggregation.
This PR records the stats during query optimization. For the same node, later stats override previous ones.
  • Loading branch information
feilong-liu committed May 23, 2024
1 parent 584636f commit 69f8f1a
Show file tree
Hide file tree
Showing 11 changed files with 67 additions and 12 deletions.
15 changes: 15 additions & 0 deletions presto-main/src/main/java/com/facebook/presto/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.common.transaction.TransactionId;
import com.facebook.presto.common.type.TimeZoneKey;
import com.facebook.presto.cost.PlanCostEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.metadata.SessionPropertyManager;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorSession;
Expand All @@ -25,6 +27,7 @@
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.security.AccessControl;
import com.facebook.presto.spi.security.AccessControlContext;
import com.facebook.presto.spi.security.Identity;
Expand Down Expand Up @@ -99,6 +102,8 @@ public final class Session
private final OptimizerInformationCollector optimizerInformationCollector = new OptimizerInformationCollector();
private final OptimizerResultCollector optimizerResultCollector = new OptimizerResultCollector();
private final CTEInformationCollector cteInformationCollector = new CTEInformationCollector();
private final Map<PlanNodeId, PlanNodeStatsEstimate> planNodeStatsMap = new HashMap<>();
private final Map<PlanNodeId, PlanCostEstimate> planNodeCostMap = new HashMap<>();

public Session(
QueryId queryId,
Expand Down Expand Up @@ -337,6 +342,16 @@ public CTEInformationCollector getCteInformationCollector()
return cteInformationCollector;
}

/**
 * Exposes this session's per-node stats estimates, keyed by plan node id.
 * The backing map itself is handed out (no copy is made), so entries put by
 * a caller are visible to every later caller.
 *
 * @return the session-held map from {@link PlanNodeId} to {@link PlanNodeStatsEstimate}
 */
public Map<PlanNodeId, PlanNodeStatsEstimate> getPlanNodeStatsMap()
{
    return this.planNodeStatsMap;
}

/**
 * Exposes this session's per-node cost estimates, keyed by plan node id.
 * The backing map itself is handed out (no copy is made), so entries put by
 * a caller are visible to every later caller.
 *
 * @return the session-held map from {@link PlanNodeId} to {@link PlanCostEstimate}
 */
public Map<PlanNodeId, PlanCostEstimate> getPlanNodeCostMap()
{
    return this.planNodeCostMap;
}

public Session beginTransactionId(TransactionId transactionId, TransactionManager transactionManager, AccessControl accessControl)
{
requireNonNull(transactionId, "transactionId is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ public final class SystemSessionProperties
public static final String SKIP_HASH_GENERATION_FOR_JOIN_WITH_TABLE_SCAN_INPUT = "skip_hash_generation_for_join_with_table_scan_input";
public static final String GENERATE_DOMAIN_FILTERS = "generate_domain_filters";
public static final String REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION = "rewrite_expression_with_constant_expression";
public static final String PRINT_ESTIMATED_STATS_FROM_CACHE = "print_estimated_stats_from_cache";

// TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future.
public static final String NATIVE_SIMPLIFIED_EXPRESSION_EVALUATION_ENABLED = "native_simplified_expression_evaluation_enabled";
Expand Down Expand Up @@ -1915,6 +1916,12 @@ public SystemSessionProperties(
"Rewrite left join with is null check to semi join",
featuresConfig.isRewriteExpressionWithConstantVariable(),
false),
booleanProperty(
PRINT_ESTIMATED_STATS_FROM_CACHE,
"When printing estimated plan stats after optimization is complete, such as in an EXPLAIN query or for logging in a QueryCompletedEvent, " +
"get stats from a cache that was populated during query optimization rather than recalculating the stats on the final plan.",
featuresConfig.isPrintEstimatedStatsFromCache(),
false),
new PropertyMetadata<>(
DEFAULT_VIEW_SECURITY_MODE,
format("Set default view security mode. Options are: %s",
Expand Down Expand Up @@ -3218,4 +3225,9 @@ public static boolean isJoinPrefilterEnabled(Session session)
{
return session.getSystemProperty(JOIN_PREFILTER_BUILD_SIDE, Boolean.class);
}

/**
 * Reads the {@code print_estimated_stats_from_cache} system property from the given session.
 *
 * @param session session whose system properties are consulted
 * @return the value of the property for that session
 */
public static boolean isPrintEstimatedStatsFromCacheEnabled(Session session)
{
    boolean printFromCache = session.getSystemProperty(PRINT_ESTIMATED_STATS_FROM_CACHE, Boolean.class);
    return printFromCache;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,20 @@ public PlanCostEstimate getCost(PlanNode node)

try {
if (node instanceof GroupReference) {
return getGroupCost((GroupReference) node);
PlanCostEstimate result = getGroupCost((GroupReference) node);
session.getPlanNodeCostMap().put(node.getId(), result);
return result;
}

PlanCostEstimate cost = cache.get(node);
if (cost != null) {
session.getPlanNodeCostMap().put(node.getId(), cost);
return cost;
}

cost = calculateCost(node);
verify(cache.put(node, cost) == null, "Cost already set");
session.getPlanNodeCostMap().put(node.getId(), cost);
return cost;
}
catch (RuntimeException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,20 @@ public PlanNodeStatsEstimate getStats(PlanNode node)

try {
if (node instanceof GroupReference) {
return getGroupStats((GroupReference) node);
PlanNodeStatsEstimate result = getGroupStats((GroupReference) node);
session.getPlanNodeStatsMap().put(node.getId(), result);
return result;
}

PlanNodeStatsEstimate stats = cache.get(node);
if (stats != null) {
session.getPlanNodeStatsMap().put(node.getId(), stats);
return stats;
}

stats = statsCalculator.calculateStats(node, this, lookup, session, types);
verify(cache.put(node, stats) == null, "Stats already set");
session.getPlanNodeStatsMap().put(node.getId(), stats);
return stats;
}
catch (RuntimeException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.fasterxml.jackson.annotation.JsonCreator;
Expand All @@ -24,6 +25,7 @@
import java.util.Map;
import java.util.Objects;

import static com.facebook.presto.SystemSessionProperties.isPrintEstimatedStatsFromCacheEnabled;
import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -77,15 +79,16 @@ public StatsAndCosts getForSubplan(PlanNode root)
return new StatsAndCosts(filteredStats.build(), filteredCosts.build());
}

public static StatsAndCosts create(PlanNode root, StatsProvider statsProvider, CostProvider costProvider)
public static StatsAndCosts create(PlanNode root, StatsProvider statsProvider, CostProvider costProvider, Session session)
{
Iterable<PlanNode> planIterator = Traverser.forTree(PlanNode::getSources)
.depthFirstPreOrder(root);
ImmutableMap.Builder<PlanNodeId, PlanNodeStatsEstimate> stats = ImmutableMap.builder();
ImmutableMap.Builder<PlanNodeId, PlanCostEstimate> costs = ImmutableMap.builder();
boolean printStatsFromCache = isPrintEstimatedStatsFromCacheEnabled(session);
for (PlanNode node : planIterator) {
stats.put(node.getId(), statsProvider.getStats(node));
costs.put(node.getId(), costProvider.getCost(node));
stats.put(node.getId(), printStatsFromCache ? session.getPlanNodeStatsMap().getOrDefault(node.getId(), PlanNodeStatsEstimate.unknown()) : statsProvider.getStats(node));
costs.put(node.getId(), printStatsFromCache ? session.getPlanNodeCostMap().getOrDefault(node.getId(), PlanCostEstimate.unknown()) : costProvider.getCost(node));
}
return new StatsAndCosts(stats.build(), costs.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ private StatsAndCosts computeStats(PlanNode root, TypeProvider types)
(node instanceof JoinNode) || (node instanceof SemiJoinNode)).matches()) {
StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types);
CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.empty(), session);
return StatsAndCosts.create(root, statsProvider, costProvider);
return StatsAndCosts.create(root, statsProvider, costProvider, session);
}
return StatsAndCosts.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ public class FeaturesConfig
private long kHyperLogLogAggregationGroupNumberLimit;
private boolean limitNumberOfGroupsForKHyperLogLogAggregations = true;
private boolean generateDomainFilters;
private boolean printEstimatedStatsFromCache;
private CreateView.Security defaultViewSecurityMode = DEFINER;

public enum PartitioningPrecisionStrategy
Expand Down Expand Up @@ -3101,4 +3102,17 @@ public FeaturesConfig setDefaultViewSecurityMode(CreateView.Security securityMod
this.defaultViewSecurityMode = securityMode;
return this;
}

/**
 * @return the current value of the {@code printEstimatedStatsFromCache} flag held by this config
 */
public boolean isPrintEstimatedStatsFromCache()
{
    return printEstimatedStatsFromCache;
}

/**
 * Sets the {@code printEstimatedStatsFromCache} flag, bound to the
 * {@code optimizer.print-estimated-stats-from-cache} configuration property.
 *
 * @param printEstimatedStatsFromCache new value of the flag
 * @return this config instance, so setter calls can be chained
 */
@Config("optimizer.print-estimated-stats-from-cache")
@ConfigDescription("In the end of query optimization, print the estimation stats from cache populated during optimization instead of calculating from ground")
public FeaturesConfig setPrintEstimatedStatsFromCache(boolean printEstimatedStatsFromCache)
{
this.printEstimatedStatsFromCache = printEstimatedStatsFromCache;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ private CostAssertionBuilder assertCostSingleStageFragmentedPlan(
StatsProvider statsProvider = new CachingStatsProvider(statsCalculator(stats), session, typeProvider);
CostProvider costProvider = new TestingCostProvider(costs, costCalculatorUsingExchanges, statsProvider, session);
// Explicitly generate the statsAndCosts, bypass fragment generation and sanity checks for mock plans.
StatsAndCosts statsAndCosts = StatsAndCosts.create(node, statsProvider, costProvider).getForSubplan(node);
StatsAndCosts statsAndCosts = StatsAndCosts.create(node, statsProvider, costProvider, session).getForSubplan(node);
return new CostAssertionBuilder(statsAndCosts.getCosts().getOrDefault(node.getId(), PlanCostEstimate.unknown()));
}

Expand Down Expand Up @@ -807,7 +807,7 @@ private PlanCostEstimate calculateCostFragmentedPlan(PlanNode node, StatsCalcula
TypeProvider typeProvider = TypeProvider.copyOf(types);
StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, typeProvider);
CostProvider costProvider = new CachingCostProvider(costCalculatorUsingExchanges, statsProvider, Optional.empty(), session);
SubPlan subPlan = fragment(new Plan(node, typeProvider, StatsAndCosts.create(node, statsProvider, costProvider)));
SubPlan subPlan = fragment(new Plan(node, typeProvider, StatsAndCosts.create(node, statsProvider, costProvider, session)));
return subPlan.getFragment().getStatsAndCosts().getCosts().getOrDefault(node.getId(), PlanCostEstimate.unknown());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ public void testDefaults()
.setDefaultWriterReplicationCoefficient(3.0)
.setDefaultViewSecurityMode(DEFINER)
.setCteHeuristicReplicationThreshold(4)
.setLegacyJsonCast(true));
.setLegacyJsonCast(true)
.setPrintEstimatedStatsFromCache(false));
}

@Test
Expand Down Expand Up @@ -485,6 +486,7 @@ public void testExplicitPropertyMappings()
.put("optimizer.default-writer-replication-coefficient", "5.0")
.put("default-view-security-mode", INVOKER.name())
.put("cte-heuristic-replication-threshold", "2")
.put("optimizer.print-estimated-stats-from-cache", "true")
.build();

FeaturesConfig expected = new FeaturesConfig()
Expand Down Expand Up @@ -696,7 +698,8 @@ public void testExplicitPropertyMappings()
.setDefaultWriterReplicationCoefficient(5.0)
.setDefaultViewSecurityMode(INVOKER)
.setCteHeuristicReplicationThreshold(2)
.setLegacyJsonCast(false);
.setLegacyJsonCast(false)
.setPrintEstimatedStatsFromCache(true);
assertFullMapping(properties, expected);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public static void assertPlan(Session session, Metadata metadata, StatsProvider
// TODO (Issue #13231) add back printing unresolved plan once we have no need to translate OriginalExpression to RowExpression
if (!matches.isMatch()) {
PlanNode resolvedPlan = resolveGroupReferences(actual.getRoot(), lookup);
String resolvedFormattedPlan = textLogicalPlan(planSanitizer.apply(resolvedPlan), actual.getTypes(), StatsAndCosts.create(resolvedPlan, statsProvider, node -> PlanCostEstimate.unknown()), metadata.getFunctionAndTypeManager(), session, 0);
String resolvedFormattedPlan = textLogicalPlan(planSanitizer.apply(resolvedPlan), actual.getTypes(), StatsAndCosts.create(resolvedPlan, statsProvider, node -> PlanCostEstimate.unknown(), session), metadata.getFunctionAndTypeManager(), session, 0);
throw new AssertionError(format(
"Plan does not match, expected [\n\n%s\n] but found [\n\n%s\n]",
pattern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ private String formatPlan(PlanNode plan, TypeProvider types)
{
StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types);
CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, session);
return inTransaction(session -> textLogicalPlan(plan, types, StatsAndCosts.create(plan, statsProvider, costProvider), metadata.getFunctionAndTypeManager(), session, 2, false, isVerboseOptimizerInfoEnabled(session)));
return inTransaction(session -> textLogicalPlan(plan, types, StatsAndCosts.create(plan, statsProvider, costProvider, session), metadata.getFunctionAndTypeManager(), session, 2, false, isVerboseOptimizerInfoEnabled(session)));
}

private <T> T inTransaction(Function<Session, T> transactionSessionConsumer)
Expand Down

0 comments on commit 69f8f1a

Please sign in to comment.