Skip to content

Commit

Permalink
When scheduling only wait for RemoteSourceNodes
Browse files Browse the repository at this point in the history
Several parts of the scheduler system would use PlanFragment.getSources
to find all nodes to schedule, and this method would look for all leaf nodes
which are not in a black list. This black list was error prone and is missing
ValuesNode.  This missing entry causes the scheduler to wait for ValuesNodes
to be scheduled which never happens since values nodes are not schduled.

The getSources is not really needed as all callers are searching for
RemoteSourceNodes, so there is simply a method for that now.
  • Loading branch information
dain committed Mar 9, 2015
1 parent d07c0b0 commit d191d59
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 65 deletions.
Expand Up @@ -24,7 +24,6 @@
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
Expand Down Expand Up @@ -198,7 +197,7 @@ public synchronized QueryInfo getQueryInfo(StageInfo rootStage)
totalUserTime += stageStats.getTotalUserTime().roundTo(NANOSECONDS);
totalBlockedTime += stageStats.getTotalBlockedTime().roundTo(NANOSECONDS);

if (Iterables.any(stageInfo.getPlan().getSources(), Predicates.instanceOf(TableScanNode.class))) {
if (stageInfo.getPlan().getPartitionedSourceNode() instanceof TableScanNode) {
rawInputDataSize += stageStats.getRawInputDataSize().toBytes();
rawInputPositions += stageStats.getRawInputPositions();

Expand Down
Expand Up @@ -32,7 +32,6 @@
import com.facebook.presto.sql.planner.PlanFragment.PlanDistribution;
import com.facebook.presto.sql.planner.StageExecutionPlan;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.google.common.annotations.VisibleForTesting;
Expand Down Expand Up @@ -64,6 +63,7 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -74,6 +74,8 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.facebook.presto.OutputBuffers.INITIAL_EMPTY_OUTPUT_BUFFERS;
import static com.facebook.presto.spi.StandardErrorCode.NO_NODES_AVAILABLE;
Expand Down Expand Up @@ -107,6 +109,7 @@ public class SqlStageExecution
private final StageId stageId;
private final URI location;
private final PlanFragment fragment;
private final Set<PlanNodeId> allSources;
private final Map<PlanFragmentId, StageExecutionNode> subStages;

private final Multimap<Node, TaskId> localNodeTaskMap = HashMultimap.create();
Expand Down Expand Up @@ -209,6 +212,13 @@ private SqlStageExecution(@Nullable StageExecutionNode parent,
this.initialHashPartitions = initialHashPartitions;
this.executor = executor;

this.allSources = Stream.concat(
Stream.of(fragment.getPartitionedSource()),
fragment.getRemoteSourceNodes().stream()
.map(RemoteSourceNode::getId))
.filter(Objects::nonNull)
.collect(Collectors.toSet());

ImmutableMap.Builder<PlanFragmentId, StageExecutionNode> subStages = ImmutableMap.builder();
for (StageExecutionPlan subStagePlan : plan.getSubStages()) {
PlanFragmentId subStageFragmentId = subStagePlan.getFragment().getId();
Expand Down Expand Up @@ -459,18 +469,15 @@ private Multimap<PlanNodeId, URI> getNewExchangeLocations()
Multimap<PlanNodeId, URI> exchangeLocations = this.exchangeLocations.get();

ImmutableMultimap.Builder<PlanNodeId, URI> newExchangeLocations = ImmutableMultimap.builder();
for (PlanNode planNode : fragment.getSources()) {
if (planNode instanceof RemoteSourceNode) {
RemoteSourceNode remoteSourceNode = (RemoteSourceNode) planNode;
for (PlanFragmentId planFragmentId : remoteSourceNode.getSourceFragmentIds()) {
StageExecutionNode subStage = subStages.get(planFragmentId);
checkState(subStage != null, "Unknown sub stage %s, known stages %s", planFragmentId, subStages.keySet());

// add new task locations
for (URI taskLocation : subStage.getTaskLocations()) {
if (!exchangeLocations.containsEntry(remoteSourceNode.getId(), taskLocation)) {
newExchangeLocations.putAll(remoteSourceNode.getId(), taskLocation);
}
for (RemoteSourceNode remoteSourceNode : fragment.getRemoteSourceNodes()) {
for (PlanFragmentId planFragmentId : remoteSourceNode.getSourceFragmentIds()) {
StageExecutionNode subStage = subStages.get(planFragmentId);
checkState(subStage != null, "Unknown sub stage %s, known stages %s", planFragmentId, subStages.keySet());

// add new task locations
for (URI taskLocation : subStage.getTaskLocations()) {
if (!exchangeLocations.containsEntry(remoteSourceNode.getId(), taskLocation)) {
newExchangeLocations.putAll(remoteSourceNode.getId(), taskLocation);
}
}
}
Expand Down Expand Up @@ -786,7 +793,7 @@ private boolean addNewExchangesAndBuffers()
{
// get new exchanges and update exchange state
Set<PlanNodeId> completeSources = updateCompleteSources();
boolean allSourceComplete = completeSources.containsAll(fragment.getSourceIds());
boolean allSourceComplete = completeSources.containsAll(allSources);
Multimap<PlanNodeId, URI> newExchangeLocations = getNewExchangeLocations();
exchangeLocations.set(ImmutableMultimap.<PlanNodeId, URI>builder()
.putAll(exchangeLocations.get())
Expand Down Expand Up @@ -819,7 +826,7 @@ private synchronized void waitForNewExchangesOrBuffers()
while (!getState().isDone()) {
// if next loop will finish, don't wait
Set<PlanNodeId> completeSources = updateCompleteSources();
boolean allSourceComplete = completeSources.containsAll(fragment.getSourceIds());
boolean allSourceComplete = completeSources.containsAll(allSources);
if (allSourceComplete && getCurrentOutputBuffers().isNoMoreBufferIds()) {
return;
}
Expand All @@ -846,9 +853,8 @@ private synchronized void waitForNewExchangesOrBuffers()

private Set<PlanNodeId> updateCompleteSources()
{
for (PlanNode planNode : fragment.getSources()) {
if (!completeSources.contains(planNode.getId()) && planNode instanceof RemoteSourceNode) {
RemoteSourceNode remoteSourceNode = (RemoteSourceNode) planNode;
for (RemoteSourceNode remoteSourceNode : fragment.getRemoteSourceNodes()) {
if (!completeSources.contains(remoteSourceNode.getId())) {
boolean exchangeFinished = true;
for (PlanFragmentId planFragmentId : remoteSourceNode.getSourceFragmentIds()) {
StageExecutionNode subStage = subStages.get(planFragmentId);
Expand All @@ -860,7 +866,7 @@ private Set<PlanNodeId> updateCompleteSources()
}
}
if (exchangeFinished) {
completeSources.add(planNode.getId());
completeSources.add(remoteSourceNode.getId());
}
}
}
Expand Down
Expand Up @@ -33,6 +33,7 @@
import com.facebook.presto.operator.TaskStats;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -67,6 +68,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
Expand All @@ -80,6 +82,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;

import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.TOO_MANY_REQUESTS_FAILED;
Expand Down Expand Up @@ -407,15 +410,23 @@ private synchronized void scheduleUpdate()

private synchronized List<TaskSource> getSources()
{
ImmutableList.Builder<TaskSource> sources = ImmutableList.builder();
for (PlanNodeId planNodeId : planFragment.getSourceIds()) {
Set<ScheduledSplit> splits = pendingSplits.get(planNodeId);
boolean noMoreSplits = this.noMoreSplits.contains(planNodeId);
if (!splits.isEmpty() || noMoreSplits) {
sources.add(new TaskSource(planNodeId, splits, noMoreSplits));
}
return Stream.concat(Stream.of(planFragment.getPartitionedSourceNode()), planFragment.getRemoteSourceNodes().stream())
.filter(Objects::nonNull)
.map(PlanNode::getId)
.map(this::getSource)
.filter(Objects::nonNull)
.collect(toImmutableList());
}

private TaskSource getSource(PlanNodeId planNodeId)
{
Set<ScheduledSplit> splits = pendingSplits.get(planNodeId);
boolean noMoreSplits = this.noMoreSplits.contains(planNodeId);
TaskSource element = null;
if (!splits.isEmpty() || noMoreSplits) {
element = new TaskSource(planNodeId, splits, noMoreSplits);
}
return sources.build();
return element;
}

@Override
Expand Down
Expand Up @@ -14,10 +14,10 @@
package com.facebook.presto.sql.planner;

import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.plan.IndexSourceNode;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
Expand All @@ -28,8 +28,8 @@

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;
import static com.google.common.base.MoreObjects.toStringHelper;
Expand All @@ -47,7 +47,7 @@ public enum PlanDistribution
COORDINATOR_ONLY
}

public static enum OutputPartitioning
public enum OutputPartitioning
{
NONE,
HASH
Expand All @@ -60,8 +60,8 @@ public static enum OutputPartitioning
private final PlanDistribution distribution;
private final PlanNodeId partitionedSource;
private final List<Type> types;
private final List<PlanNode> sources;
private final Set<PlanNodeId> sourceIds;
private final PlanNode partitionedSourceNode;
private final List<RemoteSourceNode> remoteSourceNodes;
private final OutputPartitioning outputPartitioning;
private final List<Symbol> partitionBy;
private final Optional<Symbol> hash;
Expand Down Expand Up @@ -94,18 +94,11 @@ public PlanFragment(
.map(symbols::get)
.collect(toImmutableList());

ImmutableList.Builder<PlanNode> sources = ImmutableList.builder();
findSources(root, sources, partitionedSource);
this.sources = sources.build();
this.partitionedSourceNode = findSource(root, partitionedSource);

ImmutableSet.Builder<PlanNodeId> sourceIds = ImmutableSet.builder();
for (PlanNode source : this.sources) {
sourceIds.add(source.getId());
}
if (partitionedSource != null) {
sourceIds.add(partitionedSource);
}
this.sourceIds = sourceIds.build();
ImmutableList.Builder<RemoteSourceNode> remoteSourceNodes = ImmutableList.builder();
findRemoteSourceNodes(root, remoteSourceNodes);
this.remoteSourceNodes = remoteSourceNodes.build();

this.outputPartitioning = checkNotNull(outputPartitioning, "outputPartitioning is null");
}
Expand Down Expand Up @@ -169,24 +162,37 @@ public List<Type> getTypes()
return types;
}

public List<PlanNode> getSources()
public PlanNode getPartitionedSourceNode()
{
return sources;
return partitionedSourceNode;
}

public Set<PlanNodeId> getSourceIds()
public List<RemoteSourceNode> getRemoteSourceNodes()
{
return sourceIds;
return remoteSourceNodes;
}

private static PlanNode findSource(PlanNode node, PlanNodeId nodeId)
{
if (node.getId().equals(nodeId)) {
return node;
}

return node.getSources().stream()
.map(source -> findSource(source, nodeId))
.filter(Objects::nonNull)
.findAny()
.orElse(null);
}

private static void findSources(PlanNode node, Builder<PlanNode> builder, PlanNodeId partitionedSource)
private static void findRemoteSourceNodes(PlanNode node, Builder<RemoteSourceNode> builder)
{
for (PlanNode source : node.getSources()) {
findSources(source, builder, partitionedSource);
findRemoteSourceNodes(source, builder);
}

if ((node.getSources().isEmpty() && !(node instanceof IndexSourceNode)) || node.getId().equals(partitionedSource)) {
builder.add(node);
if (node instanceof RemoteSourceNode) {
builder.add((RemoteSourceNode) node);
}
}

Expand Down
Expand Up @@ -67,9 +67,7 @@ public List<PlanFragment> getAllFragments()

public void sanityCheck()
{
Multiset<PlanFragmentId> exchangeIds = fragment.getSources().stream()
.filter(RemoteSourceNode.class::isInstance)
.map(RemoteSourceNode.class::cast)
Multiset<PlanFragmentId> exchangeIds = fragment.getRemoteSourceNodes().stream()
.map(RemoteSourceNode::getSourceFragmentIds)
.flatMap(List::stream)
.collect(toImmutableMultiset());
Expand Down
Expand Up @@ -85,7 +85,6 @@
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.tree.Statement;
import com.facebook.presto.type.TypeRegistry;
import com.facebook.presto.type.TypeUtils;
Expand Down Expand Up @@ -399,13 +398,7 @@ public List<Driver> createDrivers(Session session, @Language("SQL") String sql,
// generate sources
List<TaskSource> sources = new ArrayList<>();
long sequenceId = 0;
for (PlanNode sourceNode : subplan.getFragment().getSources()) {
if (sourceNode instanceof ValuesNode) {
continue;
}

TableScanNode tableScan = (TableScanNode) sourceNode;

for (TableScanNode tableScan : findTableScanNodes(subplan.getFragment().getRoot())) {
SplitSource splitSource = splitManager.getPartitionSplits(tableScan.getTable(), getPartitions(tableScan));

ImmutableSet.Builder<ScheduledSplit> scheduledSplits = ImmutableSet.builder();
Expand Down Expand Up @@ -543,6 +536,24 @@ private Split getLocalQuerySplit(TableHandle tableHandle)
}
}

private static List<TableScanNode> findTableScanNodes(PlanNode node)
{
ImmutableList.Builder<TableScanNode> tableScanNodes = ImmutableList.builder();
findTableScanNodes(node, tableScanNodes);
return tableScanNodes.build();
}

private static void findTableScanNodes(PlanNode node, ImmutableList.Builder<TableScanNode> builder)
{
for (PlanNode source : node.getSources()) {
findTableScanNodes(source, builder);
}

if (node instanceof TableScanNode) {
builder.add((TableScanNode) node);
}
}

private static class HashProjectionFunction
implements ProjectionFunction
{
Expand Down
Expand Up @@ -26,6 +26,7 @@
import com.facebook.presto.sql.planner.TestingColumnHandle;
import com.facebook.presto.sql.planner.TestingTableHandle;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.google.common.collect.HashMultimap;
Expand All @@ -43,10 +44,12 @@
import java.net.URI;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Stream;

import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.execution.StateMachine.StateChangeListener;
Expand Down Expand Up @@ -191,7 +194,13 @@ public void addSplits(PlanNodeId sourceId, Iterable<Split> splits)
public void noMoreSplits(PlanNodeId sourceId)
{
noMoreSplits.add(sourceId);
if (noMoreSplits.containsAll(fragment.getSourceIds())) {

boolean allSourcesComplete = Stream.concat(Stream.of(fragment.getPartitionedSourceNode()), fragment.getRemoteSourceNodes().stream())
.filter(Objects::nonNull)
.map(PlanNode::getId)
.allMatch(noMoreSplits::contains);

if (allSourcesComplete) {
taskStateMachine.finished();
}
}
Expand Down

0 comments on commit d191d59

Please sign in to comment.