Skip to content

Commit

Permalink
Rename PlanNodeCostEstimate to PlanCostEstimate
Browse files Browse the repository at this point in the history
  • Loading branch information
findepi authored and wenleix committed Mar 2, 2019
1 parent 232e2c1 commit c8d69e3
Show file tree
Hide file tree
Showing 20 changed files with 187 additions and 187 deletions.
Expand Up @@ -40,7 +40,7 @@ public class CachingCostProvider
private final Session session; private final Session session;
private final TypeProvider types; private final TypeProvider types;


private final Map<PlanNode, PlanNodeCostEstimate> cache = new IdentityHashMap<>(); private final Map<PlanNode, PlanCostEstimate> cache = new IdentityHashMap<>();


public CachingCostProvider(CostCalculator costCalculator, StatsProvider statsProvider, Session session, TypeProvider types) public CachingCostProvider(CostCalculator costCalculator, StatsProvider statsProvider, Session session, TypeProvider types)
{ {
Expand All @@ -57,10 +57,10 @@ public CachingCostProvider(CostCalculator costCalculator, StatsProvider statsPro
} }


@Override @Override
public PlanNodeCostEstimate getCumulativeCost(PlanNode node) public PlanCostEstimate getCumulativeCost(PlanNode node)
{ {
if (!isEnableStatsCalculator(session)) { if (!isEnableStatsCalculator(session)) {
return PlanNodeCostEstimate.unknown(); return PlanCostEstimate.unknown();
} }


requireNonNull(node, "node is null"); requireNonNull(node, "node is null");
Expand All @@ -70,7 +70,7 @@ public PlanNodeCostEstimate getCumulativeCost(PlanNode node)
return getGroupCost((GroupReference) node); return getGroupCost((GroupReference) node);
} }


PlanNodeCostEstimate cumulativeCost = cache.get(node); PlanCostEstimate cumulativeCost = cache.get(node);
if (cumulativeCost != null) { if (cumulativeCost != null) {
return cumulativeCost; return cumulativeCost;
} }
Expand All @@ -82,37 +82,37 @@ public PlanNodeCostEstimate getCumulativeCost(PlanNode node)
catch (RuntimeException e) { catch (RuntimeException e) {
if (isIgnoreStatsCalculatorFailures(session)) { if (isIgnoreStatsCalculatorFailures(session)) {
log.error(e, "Error occurred when computing cost for query %s", session.getQueryId()); log.error(e, "Error occurred when computing cost for query %s", session.getQueryId());
return PlanNodeCostEstimate.unknown(); return PlanCostEstimate.unknown();
} }
throw e; throw e;
} }
} }


private PlanNodeCostEstimate getGroupCost(GroupReference groupReference) private PlanCostEstimate getGroupCost(GroupReference groupReference)
{ {
int group = groupReference.getGroupId(); int group = groupReference.getGroupId();
Memo memo = this.memo.orElseThrow(() -> new IllegalStateException("CachingCostProvider without memo cannot handle GroupReferences")); Memo memo = this.memo.orElseThrow(() -> new IllegalStateException("CachingCostProvider without memo cannot handle GroupReferences"));


Optional<PlanNodeCostEstimate> cost = memo.getCumulativeCost(group); Optional<PlanCostEstimate> cost = memo.getCumulativeCost(group);
if (cost.isPresent()) { if (cost.isPresent()) {
return cost.get(); return cost.get();
} }


PlanNodeCostEstimate cumulativeCost = calculateCumulativeCost(memo.getNode(group)); PlanCostEstimate cumulativeCost = calculateCumulativeCost(memo.getNode(group));
verify(!memo.getCumulativeCost(group).isPresent(), "Group cost already set"); verify(!memo.getCumulativeCost(group).isPresent(), "Group cost already set");
memo.storeCumulativeCost(group, cumulativeCost); memo.storeCumulativeCost(group, cumulativeCost);
return cumulativeCost; return cumulativeCost;
} }


private PlanNodeCostEstimate calculateCumulativeCost(PlanNode node) private PlanCostEstimate calculateCumulativeCost(PlanNode node)
{ {
PlanNodeCostEstimate localCosts = costCalculator.calculateCost(node, statsProvider, session, types); PlanCostEstimate localCosts = costCalculator.calculateCost(node, statsProvider, session, types);


PlanNodeCostEstimate sourcesCost = node.getSources().stream() PlanCostEstimate sourcesCost = node.getSources().stream()
.map(this::getCumulativeCost) .map(this::getCumulativeCost)
.reduce(PlanNodeCostEstimate.zero(), PlanNodeCostEstimate::add); .reduce(PlanCostEstimate.zero(), PlanCostEstimate::add);


PlanNodeCostEstimate cumulativeCost = localCosts.add(sourcesCost); PlanCostEstimate cumulativeCost = localCosts.add(sourcesCost);
return cumulativeCost; return cumulativeCost;
} }
} }
Expand Up @@ -36,7 +36,7 @@ public interface CostCalculator
* @param node The node to compute cost for. * @param node The node to compute cost for.
* @param stats The stats provider for node's stats and child nodes' stats, to be used if stats are needed to compute cost for the {@code node} * @param stats The stats provider for node's stats and child nodes' stats, to be used if stats are needed to compute cost for the {@code node}
*/ */
PlanNodeCostEstimate calculateCost( PlanCostEstimate calculateCost(
PlanNode node, PlanNode node,
StatsProvider stats, StatsProvider stats,
Session session, Session session,
Expand Down
Expand Up @@ -50,7 +50,7 @@
import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateRemoteGatherCost; import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateRemoteGatherCost;
import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateRemoteRepartitionCost; import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateRemoteRepartitionCost;
import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateRemoteReplicateCost; import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateRemoteReplicateCost;
import static com.facebook.presto.cost.PlanNodeCostEstimate.cpuCost; import static com.facebook.presto.cost.PlanCostEstimate.cpuCost;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.FINAL; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.FINAL;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE;
import static java.util.Objects.requireNonNull; import static java.util.Objects.requireNonNull;
Expand All @@ -71,14 +71,14 @@ public CostCalculatorUsingExchanges(TaskCountEstimator taskCountEstimator)
} }


@Override @Override
public PlanNodeCostEstimate calculateCost(PlanNode node, StatsProvider stats, Session session, TypeProvider types) public PlanCostEstimate calculateCost(PlanNode node, StatsProvider stats, Session session, TypeProvider types)
{ {
CostEstimator costEstimator = new CostEstimator(stats, types, taskCountEstimator); CostEstimator costEstimator = new CostEstimator(stats, types, taskCountEstimator);
return node.accept(costEstimator, null); return node.accept(costEstimator, null);
} }


private static class CostEstimator private static class CostEstimator
extends PlanVisitor<PlanNodeCostEstimate, Void> extends PlanVisitor<PlanCostEstimate, Void>
{ {
private final StatsProvider stats; private final StatsProvider stats;
private final TypeProvider types; private final TypeProvider types;
Expand All @@ -92,25 +92,25 @@ private static class CostEstimator
} }


@Override @Override
protected PlanNodeCostEstimate visitPlan(PlanNode node, Void context) protected PlanCostEstimate visitPlan(PlanNode node, Void context)
{ {
return PlanNodeCostEstimate.unknown(); return PlanCostEstimate.unknown();
} }


@Override @Override
public PlanNodeCostEstimate visitGroupReference(GroupReference node, Void context) public PlanCostEstimate visitGroupReference(GroupReference node, Void context)
{ {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }


@Override @Override
public PlanNodeCostEstimate visitAssignUniqueId(AssignUniqueId node, Void context) public PlanCostEstimate visitAssignUniqueId(AssignUniqueId node, Void context)
{ {
return cpuCost(getStats(node).getOutputSizeInBytes(ImmutableList.of(node.getIdColumn()), types)); return cpuCost(getStats(node).getOutputSizeInBytes(ImmutableList.of(node.getIdColumn()), types));
} }


@Override @Override
public PlanNodeCostEstimate visitRowNumber(RowNumberNode node, Void context) public PlanCostEstimate visitRowNumber(RowNumberNode node, Void context)
{ {
List<Symbol> symbols = node.getOutputSymbols(); List<Symbol> symbols = node.getOutputSymbols();
// when maxRowCountPerPartition is set, the RowNumberOperator // when maxRowCountPerPartition is set, the RowNumberOperator
Expand All @@ -124,49 +124,49 @@ public PlanNodeCostEstimate visitRowNumber(RowNumberNode node, Void context)
PlanNodeStatsEstimate stats = getStats(node); PlanNodeStatsEstimate stats = getStats(node);
double cpuCost = stats.getOutputSizeInBytes(symbols, types); double cpuCost = stats.getOutputSizeInBytes(symbols, types);
double memoryCost = node.getPartitionBy().isEmpty() ? 0 : stats.getOutputSizeInBytes(node.getSource().getOutputSymbols(), types); double memoryCost = node.getPartitionBy().isEmpty() ? 0 : stats.getOutputSizeInBytes(node.getSource().getOutputSymbols(), types);
return new PlanNodeCostEstimate(cpuCost, memoryCost, 0); return new PlanCostEstimate(cpuCost, memoryCost, 0);
} }


@Override @Override
public PlanNodeCostEstimate visitOutput(OutputNode node, Void context) public PlanCostEstimate visitOutput(OutputNode node, Void context)
{ {
return PlanNodeCostEstimate.zero(); return PlanCostEstimate.zero();
} }


@Override @Override
public PlanNodeCostEstimate visitTableScan(TableScanNode node, Void context) public PlanCostEstimate visitTableScan(TableScanNode node, Void context)
{ {
// TODO: add network cost, based on input size in bytes? Or let connector provide this cost? // TODO: add network cost, based on input size in bytes? Or let connector provide this cost?
return cpuCost(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types)); return cpuCost(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types));
} }


@Override @Override
public PlanNodeCostEstimate visitFilter(FilterNode node, Void context) public PlanCostEstimate visitFilter(FilterNode node, Void context)
{ {
return cpuCost(getStats(node.getSource()).getOutputSizeInBytes(node.getOutputSymbols(), types)); return cpuCost(getStats(node.getSource()).getOutputSizeInBytes(node.getOutputSymbols(), types));
} }


@Override @Override
public PlanNodeCostEstimate visitProject(ProjectNode node, Void context) public PlanCostEstimate visitProject(ProjectNode node, Void context)
{ {
return cpuCost(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types)); return cpuCost(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types));
} }


@Override @Override
public PlanNodeCostEstimate visitAggregation(AggregationNode node, Void context) public PlanCostEstimate visitAggregation(AggregationNode node, Void context)
{ {
if (node.getStep() != FINAL && node.getStep() != SINGLE) { if (node.getStep() != FINAL && node.getStep() != SINGLE) {
return PlanNodeCostEstimate.unknown(); return PlanCostEstimate.unknown();
} }
PlanNodeStatsEstimate aggregationStats = getStats(node); PlanNodeStatsEstimate aggregationStats = getStats(node);
PlanNodeStatsEstimate sourceStats = getStats(node.getSource()); PlanNodeStatsEstimate sourceStats = getStats(node.getSource());
double cpuCost = sourceStats.getOutputSizeInBytes(node.getSource().getOutputSymbols(), types); double cpuCost = sourceStats.getOutputSizeInBytes(node.getSource().getOutputSymbols(), types);
double memoryCost = aggregationStats.getOutputSizeInBytes(node.getOutputSymbols(), types); double memoryCost = aggregationStats.getOutputSizeInBytes(node.getOutputSymbols(), types);
return new PlanNodeCostEstimate(cpuCost, memoryCost, 0); return new PlanCostEstimate(cpuCost, memoryCost, 0);
} }


@Override @Override
public PlanNodeCostEstimate visitJoin(JoinNode node, Void context) public PlanCostEstimate visitJoin(JoinNode node, Void context)
{ {
return calculateJoinCost( return calculateJoinCost(
node, node,
Expand All @@ -175,39 +175,39 @@ public PlanNodeCostEstimate visitJoin(JoinNode node, Void context)
Objects.equals(node.getDistributionType(), Optional.of(JoinNode.DistributionType.REPLICATED))); Objects.equals(node.getDistributionType(), Optional.of(JoinNode.DistributionType.REPLICATED)));
} }


private PlanNodeCostEstimate calculateJoinCost(PlanNode join, PlanNode probe, PlanNode build, boolean replicated) private PlanCostEstimate calculateJoinCost(PlanNode join, PlanNode probe, PlanNode build, boolean replicated)
{ {
PlanNodeCostEstimate joinInputCost = calculateJoinInputCost( PlanCostEstimate joinInputCost = calculateJoinInputCost(
probe, probe,
build, build,
stats, stats,
types, types,
replicated, replicated,
taskCountEstimator.estimateSourceDistributedTaskCount()); taskCountEstimator.estimateSourceDistributedTaskCount());
PlanNodeCostEstimate joinOutputCost = calculateJoinOutputCost(join); PlanCostEstimate joinOutputCost = calculateJoinOutputCost(join);
return joinInputCost.add(joinOutputCost); return joinInputCost.add(joinOutputCost);
} }


private PlanNodeCostEstimate calculateJoinOutputCost(PlanNode join) private PlanCostEstimate calculateJoinOutputCost(PlanNode join)
{ {
PlanNodeStatsEstimate outputStats = getStats(join); PlanNodeStatsEstimate outputStats = getStats(join);
double joinOutputSize = outputStats.getOutputSizeInBytes(join.getOutputSymbols(), types); double joinOutputSize = outputStats.getOutputSizeInBytes(join.getOutputSymbols(), types);
return cpuCost(joinOutputSize); return cpuCost(joinOutputSize);
} }


@Override @Override
public PlanNodeCostEstimate visitExchange(ExchangeNode node, Void context) public PlanCostEstimate visitExchange(ExchangeNode node, Void context)
{ {
double inputSizeInBytes = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types); double inputSizeInBytes = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
switch (node.getScope()) { switch (node.getScope()) {
case LOCAL: case LOCAL:
switch (node.getType()) { switch (node.getType()) {
case GATHER: case GATHER:
return PlanNodeCostEstimate.zero(); return PlanCostEstimate.zero();
case REPARTITION: case REPARTITION:
return calculateLocalRepartitionCost(inputSizeInBytes); return calculateLocalRepartitionCost(inputSizeInBytes);
case REPLICATE: case REPLICATE:
return PlanNodeCostEstimate.zero(); return PlanCostEstimate.zero();
default: default:
throw new IllegalArgumentException("Unexpected type: " + node.getType()); throw new IllegalArgumentException("Unexpected type: " + node.getType());
} }
Expand All @@ -231,7 +231,7 @@ public PlanNodeCostEstimate visitExchange(ExchangeNode node, Void context)
} }


@Override @Override
public PlanNodeCostEstimate visitSemiJoin(SemiJoinNode node, Void context) public PlanCostEstimate visitSemiJoin(SemiJoinNode node, Void context)
{ {
return calculateJoinCost( return calculateJoinCost(
node, node,
Expand All @@ -241,7 +241,7 @@ public PlanNodeCostEstimate visitSemiJoin(SemiJoinNode node, Void context)
} }


@Override @Override
public PlanNodeCostEstimate visitSpatialJoin(SpatialJoinNode node, Void context) public PlanCostEstimate visitSpatialJoin(SpatialJoinNode node, Void context)
{ {
return calculateJoinCost( return calculateJoinCost(
node, node,
Expand All @@ -251,19 +251,19 @@ public PlanNodeCostEstimate visitSpatialJoin(SpatialJoinNode node, Void context)
} }


@Override @Override
public PlanNodeCostEstimate visitValues(ValuesNode node, Void context) public PlanCostEstimate visitValues(ValuesNode node, Void context)
{ {
return PlanNodeCostEstimate.zero(); return PlanCostEstimate.zero();
} }


@Override @Override
public PlanNodeCostEstimate visitEnforceSingleRow(EnforceSingleRowNode node, Void context) public PlanCostEstimate visitEnforceSingleRow(EnforceSingleRowNode node, Void context)
{ {
return PlanNodeCostEstimate.zero(); return PlanCostEstimate.zero();
} }


@Override @Override
public PlanNodeCostEstimate visitLimit(LimitNode node, Void context) public PlanCostEstimate visitLimit(LimitNode node, Void context)
{ {
// This is just a wild guess. First of all, LimitNode is rather rare except as a top node of a query plan, // This is just a wild guess. First of all, LimitNode is rather rare except as a top node of a query plan,
// so proper cost estimation is not that important. Second, since LimitNode can lead to incomplete evaluation // so proper cost estimation is not that important. Second, since LimitNode can lead to incomplete evaluation
Expand All @@ -273,16 +273,16 @@ public PlanNodeCostEstimate visitLimit(LimitNode node, Void context)
} }


@Override @Override
public PlanNodeCostEstimate visitUnion(UnionNode node, Void context) public PlanCostEstimate visitUnion(UnionNode node, Void context)
{ {
// Cost will be accounted either in CostCalculatorUsingExchanges#CostEstimator#visitExchanged // Cost will be accounted either in CostCalculatorUsingExchanges#CostEstimator#visitExchanged
// or in CostCalculatorWithEstimatedExchanges#CostEstimator#visitUnion // or in CostCalculatorWithEstimatedExchanges#CostEstimator#visitUnion
// This stub is needed just to avoid the cumulative cost being set to unknown // This stub is needed just to avoid the cumulative cost being set to unknown
return PlanNodeCostEstimate.zero(); return PlanCostEstimate.zero();
} }


@Override @Override
public PlanNodeCostEstimate visitSort(SortNode node, Void context) public PlanCostEstimate visitSort(SortNode node, Void context)
{ {
return cpuCost(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types)); return cpuCost(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types));
} }
Expand Down

0 comments on commit c8d69e3

Please sign in to comment.