Skip to content

Commit

Permalink
Fix leak of failed/aborted queries in coordinator
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Nov 6, 2018
1 parent e3c5097 commit 4d4b0bb
Show file tree
Hide file tree
Showing 2 changed files with 254 additions and 63 deletions.
Expand Up @@ -36,6 +36,7 @@
import com.google.common.collect.Sets;
import io.airlift.units.Duration;

import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;

import java.net.URI;
Expand Down Expand Up @@ -81,8 +82,11 @@ public final class SqlStageExecution

// Remote tasks keyed by node, held in a concurrent map.
// NOTE(review): presumably these are the tasks scheduled for this stage on each node — confirm against scheduleTask.
private final Map<Node, Set<RemoteTask>> tasks = new ConcurrentHashMap<>();
// Monotonic counter; presumably the source of the next task id within this stage — confirm.
private final AtomicInteger nextTaskId = new AtomicInteger();
// Ids of every task of this stage. updateTaskStatus compares finishedTasks and doneTasks
// against this set; where it is populated is not visible in this hunk.
// The set is a concurrent set and is additionally declared as guarded by the instance monitor.
@GuardedBy("this")
private final Set<TaskId> allTasks = newConcurrentHashSet();
// Ids of tasks that have reported TaskState.FINISHED (added in updateTaskStatus).
@GuardedBy("this")
private final Set<TaskId> finishedTasks = newConcurrentHashSet();
// Ids of tasks whose reported state satisfies isDone() (added in updateTaskStatus);
// once this covers allTasks and the stage is done, setAllTasksFinal() is invoked.
@GuardedBy("this")
private final Set<TaskId> doneTasks = newConcurrentHashSet();
// NOTE(review): presumably flips to true once split scheduling has completed — usage is not visible here.
private final AtomicBoolean splitsScheduled = new AtomicBoolean();

Expand Down Expand Up @@ -138,9 +142,11 @@ public SqlStageExecution(StageStateMachine stateMachine, RemoteTaskFactory remot
this.exchangeSources = fragmentToExchangeSource.build();

stateMachine.addStateChangeListener(newState -> {
// when stage transitions to a done state, check if all tasks have final status information
if (newState.isDone() && doneTasks.containsAll(allTasks)) {
stateMachine.setAllTasksFinal();
synchronized (this) {
// when stage transitions to a done state, check if all tasks have final status information
if (newState.isDone() && doneTasks.containsAll(allTasks)) {
stateMachine.setAllTasksFinal();
}
}
});
}
Expand Down Expand Up @@ -440,6 +446,71 @@ private static Split createRemoteSplitFor(TaskId taskId, URI taskLocation)
return new Split(REMOTE_CONNECTOR_ID, new RemoteTransactionHandle(), new RemoteSplit(splitLocation));
}

/**
 * Applies a task status report to this stage: records done/finished tasks and drives the
 * stage state machine accordingly.
 * <p>
 * While the stage is not yet done:
 * <ul>
 * <li>a FAILED task fails the stage with the task's first reported failure (rewritten by
 * {@link #rewriteTransportFailure}), or with a generic error when no failure is attached;</li>
 * <li>an ABORTED task fails the stage, since a task is only expected to be aborted after the stage is done;</li>
 * <li>a FINISHED task is recorded in {@code finishedTasks};</li>
 * <li>if the stage was SCHEDULED or RUNNING, a RUNNING task moves it to running, and it is moved to
 * finished once {@code finishedTasks} covers {@code allTasks}.</li>
 * </ul>
 * On every call, including early returns, the stage is marked as having final task information
 * once it is done and every task in {@code allTasks} is in {@code doneTasks}.
 *
 * @param taskStatus the latest status reported for one task of this stage
 */
private synchronized void updateTaskStatus(TaskStatus taskStatus)
{
    try {
        // always update done tasks before any state transition, to ensure
        // the transition to "final status info" is not missed
        if (taskStatus.getState().isDone()) {
            doneTasks.add(taskStatus.getTaskId());
        }

        // snapshot of the stage state taken before any transition triggered below
        StageState stageState = getState();
        if (stageState.isDone()) {
            return;
        }

        TaskState taskState = taskStatus.getState();
        if (taskState == TaskState.FAILED) {
            RuntimeException failure = taskStatus.getFailures().stream()
                    .findFirst()
                    .map(this::rewriteTransportFailure)
                    .map(ExecutionFailureInfo::toException)
                    // build the fallback lazily: constructing an exception captures a stack trace,
                    // which is wasted work whenever the task did report a failure
                    .orElseGet(() -> new PrestoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason"));
            stateMachine.transitionToFailed(failure);
        }
        else if (taskState == TaskState.ABORTED) {
            // A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED)
            stateMachine.transitionToFailed(new PrestoException(GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stageState));
        }
        else if (taskState == TaskState.FINISHED) {
            finishedTasks.add(taskStatus.getTaskId());
        }

        // deliberately evaluated against the snapshot taken above, not the current state
        if (stageState == StageState.SCHEDULED || stageState == StageState.RUNNING) {
            if (taskState == TaskState.RUNNING) {
                stateMachine.transitionToRunning();
            }
            if (finishedTasks.containsAll(allTasks)) {
                stateMachine.transitionToFinished();
            }
        }
    }
    finally {
        // after updating state, check if all tasks have final status information
        if (stateMachine.getState().isDone() && doneTasks.containsAll(allTasks)) {
            stateMachine.setAllTasksFinal();
        }
    }
}

/**
 * Re-labels a task failure as {@code REMOTE_HOST_GONE} when the failure detector reports
 * the failure's remote host as {@code GONE}.
 *
 * @param executionFailureInfo the failure reported by a task
 * @return a copy of the failure carrying the {@code REMOTE_HOST_GONE} error code when the remote
 * host is known and gone; otherwise the given instance unchanged
 */
private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo)
{
    // the host is only looked up in the failure detector when the failure actually names one
    boolean remoteHostGone = executionFailureInfo.getRemoteHost() != null
            && failureDetector.getState(executionFailureInfo.getRemoteHost()) == GONE;

    if (remoteHostGone) {
        // same failure details, only the error code is replaced
        return new ExecutionFailureInfo(
                executionFailureInfo.getType(),
                executionFailureInfo.getMessage(),
                executionFailureInfo.getCause(),
                executionFailureInfo.getSuppressed(),
                executionFailureInfo.getStack(),
                executionFailureInfo.getErrorLocation(),
                REMOTE_HOST_GONE.toErrorCode(),
                executionFailureInfo.getRemoteHost());
    }

    return executionFailureInfo;
}

@Override
public String toString()
{
Expand All @@ -457,51 +528,11 @@ private class StageTaskListener
public void stateChanged(TaskStatus taskStatus)
{
try {
// always update done tasks before, state transitions to ensure
// the transition to "final status info" is not missed
if (taskStatus.getState().isDone()) {
doneTasks.add(taskStatus.getTaskId());
}

updateMemoryUsage(taskStatus);
updateCompletedDriverGroups(taskStatus);

StageState stageState = getState();
if (stageState.isDone()) {
return;
}

TaskState taskState = taskStatus.getState();
if (taskState == TaskState.FAILED) {
RuntimeException failure = taskStatus.getFailures().stream()
.findFirst()
.map(this::rewriteTransportFailure)
.map(ExecutionFailureInfo::toException)
.orElse(new PrestoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason"));
stateMachine.transitionToFailed(failure);
}
else if (taskState == TaskState.ABORTED) {
// A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED)
stateMachine.transitionToFailed(new PrestoException(GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stageState));
}
else if (taskState == TaskState.FINISHED) {
finishedTasks.add(taskStatus.getTaskId());
}

if (stageState == StageState.SCHEDULED || stageState == StageState.RUNNING) {
if (taskState == TaskState.RUNNING) {
stateMachine.transitionToRunning();
}
if (finishedTasks.containsAll(allTasks)) {
stateMachine.transitionToFinished();
}
}
}
finally {
// after updating state, check if all tasks have final status information
if (stateMachine.getState().isDone() && doneTasks.containsAll(allTasks)) {
stateMachine.setAllTasksFinal();
}
updateTaskStatus(taskStatus);
}
}

Expand Down Expand Up @@ -532,25 +563,6 @@ private synchronized void updateCompletedDriverGroups(TaskStatus taskStatus)
// Making changes to completedDriverGroups will change newlyCompletedDriverGroups.
completedDriverGroups.addAll(newlyCompletedDriverGroups);
}

/**
 * Returns the given failure with its error code swapped for {@code REMOTE_HOST_GONE} if the
 * failure detector considers the failure's remote host {@code GONE}; otherwise returns the
 * failure as-is.
 *
 * @param executionFailureInfo the failure reported by a task
 * @return the rewritten failure, or the same instance when no rewrite applies
 */
private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo)
{
    // no remote host recorded: nothing to look up
    if (executionFailureInfo.getRemoteHost() == null) {
        return executionFailureInfo;
    }

    // host is known but not reported as gone: keep the original error code
    if (failureDetector.getState(executionFailureInfo.getRemoteHost()) != GONE) {
        return executionFailureInfo;
    }

    return new ExecutionFailureInfo(
            executionFailureInfo.getType(),
            executionFailureInfo.getMessage(),
            executionFailureInfo.getCause(),
            executionFailureInfo.getSuppressed(),
            executionFailureInfo.getStack(),
            executionFailureInfo.getErrorLocation(),
            REMOTE_HOST_GONE.toErrorCode(),
            executionFailureInfo.getRemoteHost());
}
}

private static class ListenerManager<T>
Expand Down
@@ -0,0 +1,179 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution;

import com.facebook.presto.client.NodeVersion;
import com.facebook.presto.cost.StatsAndCosts;
import com.facebook.presto.execution.TestSqlTaskManager.MockLocationFactory;
import com.facebook.presto.execution.scheduler.SplitSchedulerStats;
import com.facebook.presto.failureDetector.NoOpFailureDetector;
import com.facebook.presto.metadata.PrestoNode;
import com.facebook.presto.operator.StageExecutionStrategy;
import com.facebook.presto.spi.Node;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.Partitioning;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.facebook.presto.util.FinalizerService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.SettableFuture;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.net.URI;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;

import static com.facebook.presto.OutputBuffers.BufferType.ARBITRARY;
import static com.facebook.presto.OutputBuffers.createInitialEmptyOutputBuffers;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.spi.type.VarcharType.VARCHAR;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static java.util.concurrent.TimeUnit.MINUTES;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

/**
 * Tests for {@link SqlStageExecution}.
 * <p>
 * The only scenario covered checks that the {@link StageInfo} observed by a final-status listener
 * is complete when a stage is aborted while tasks are still being scheduled concurrently.
 */
public class TestSqlStageExecution
{
    private ExecutorService executor;
    private ScheduledExecutorService scheduledExecutor;

    /**
     * Creates the daemon-thread executors shared by all iterations of the test.
     */
    @BeforeClass
    public void setUp()
    {
        executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s"));
        scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s"));
    }

    /**
     * Shuts down both executors, interrupting any work still in flight.
     */
    @AfterClass
    public void tearDown()
    {
        executor.shutdownNow();
        scheduledExecutor.shutdownNow();
    }

    /**
     * Repeats the abort-while-scheduling scenario several times under one global timeout.
     *
     * @throws Exception if any iteration fails or times out
     */
    @Test(timeOut = 2 * 60 * 1000)
    public void testFinalStageInfo()
            throws Exception
    {
        // run test a few times to catch any race conditions
        // this is not done with TestNG invocation count so there can be a global time limit on the test
        for (int iteration = 0; iteration < 10; iteration++) {
            testFinalStageInfoInternal();
        }
    }

    /**
     * Schedules tasks on a stage from a background thread, aborts the stage once some tasks exist,
     * and verifies that the stage info delivered to the final-status listener is complete.
     *
     * @throws Exception if waiting for the tasks or for the final stage info fails
     */
    private void testFinalStageInfoInternal()
            throws Exception
    {
        NodeTaskMap nodeTaskMap = new NodeTaskMap(new FinalizerService());

        StageId stageId = new StageId(new QueryId("query"), 0);
        SqlStageExecution stage = new SqlStageExecution(
                stageId,
                new MockLocationFactory().createStageLocation(stageId),
                createExchangePlanFragment(),
                new MockRemoteTaskFactory(executor, scheduledExecutor),
                TEST_SESSION,
                true,
                nodeTaskMap,
                executor,
                new NoOpFailureDetector(),
                new SplitSchedulerStats());
        stage.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY));

        // add listener that fetches stage info when the final status is available
        SettableFuture<StageInfo> finalStageInfo = SettableFuture.create();
        stage.addFinalStatusListener(value -> finalStageInfo.set(stage.getStageInfo()));

        // in a background thread add a ton of tasks
        CountDownLatch latch = new CountDownLatch(1000);
        Future<?> addTasksTask = executor.submit(() -> {
            try {
                for (int i = 0; i < 1_000_000; i++) {
                    if (Thread.interrupted()) {
                        return;
                    }
                    Node node = new PrestoNode(
                            "source" + i,
                            URI.create("http://10.0.0." + (i / 10_000) + ":" + (i % 10_000)),
                            NodeVersion.UNKNOWN,
                            false);
                    stage.scheduleTask(node, i, OptionalInt.empty());
                    latch.countDown();
                }
            }
            finally {
                // release the waiting test thread even if scheduling stopped early
                while (latch.getCount() > 0) {
                    latch.countDown();
                }
            }
        });

        try {
            // wait for some tasks to be created, and then abort the query;
            // a timeout here is reported explicitly instead of being silently ignored
            assertTrue(latch.await(1, MINUTES), "timed out waiting for tasks to be scheduled");
            assertFalse(stage.getStageInfo().getTasks().isEmpty());
            stage.abort();

            // once the final stage info is available, verify that it is complete
            StageInfo stageInfo = finalStageInfo.get(1, MINUTES);
            assertFalse(stageInfo.getTasks().isEmpty());
            assertTrue(stageInfo.isCompleteInfo());
            assertTrue(stage.getStageInfo().isCompleteInfo());
        }
        finally {
            // always cancel the background thread adding tasks, even when an assertion above fails,
            // so a failed iteration does not leave it scheduling tasks in the shared executor
            addTasksTask.cancel(true);
        }
    }

    /**
     * Builds a minimal plan fragment consisting of a single remote source (exchange) node
     * with one VARCHAR output column.
     *
     * @return the plan fragment used to construct the stage under test
     */
    private static PlanFragment createExchangePlanFragment()
    {
        PlanNode planNode = new RemoteSourceNode(
                new PlanNodeId("exchange"),
                ImmutableList.of(new PlanFragmentId("source")),
                ImmutableList.of(new Symbol("column")),
                Optional.empty(),
                REPARTITION);

        ImmutableMap.Builder<Symbol, Type> types = ImmutableMap.builder();
        for (Symbol symbol : planNode.getOutputSymbols()) {
            types.put(symbol, VARCHAR);
        }
        return new PlanFragment(
                new PlanFragmentId("exchange_fragment_id"),
                planNode,
                types.build(),
                SOURCE_DISTRIBUTION,
                ImmutableList.of(planNode.getId()),
                new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputSymbols()),
                StageExecutionStrategy.ungroupedExecution(),
                StatsAndCosts.empty());
    }
}

0 comments on commit 4d4b0bb

Please sign in to comment.