Skip to content

Commit

Permalink
Send plan in update request only if required
Browse files Browse the repository at this point in the history
We need the plan fragment only for creating the SqlTaskExecution. Plans
for multistage queries can be huge and increase the size of the update
request. Change this so we send the plan only if it is required.
  • Loading branch information
nileema authored and dain committed Feb 14, 2016
1 parent 30cedfc commit 71666f5
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 28 deletions.
Expand Up @@ -34,15 +34,18 @@

import java.net.URI;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

import static com.facebook.presto.util.Failures.toFailures;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.completedFuture;

Expand All @@ -63,6 +66,7 @@ public class SqlTask
private final AtomicLong nextTaskInfoVersion = new AtomicLong(TaskInfo.STARTING_VERSION);

private final AtomicReference<TaskHolder> taskHolderReference = new AtomicReference<>(new TaskHolder());
private final AtomicBoolean needsPlan = new AtomicBoolean(true);

public SqlTask(
TaskId taskId,
Expand Down Expand Up @@ -214,7 +218,8 @@ private TaskInfo createTaskInfo(TaskHolder taskHolder)
sharedBuffer.getInfo(),
noMoreSplits,
taskStats,
failures);
failures,
needsPlan.get());
}

public CompletableFuture<TaskInfo> getTaskInfo(TaskState callersCurrentState)
Expand All @@ -233,7 +238,7 @@ public CompletableFuture<TaskInfo> getTaskInfo(TaskState callersCurrentState)
return futureTaskState.thenApply(input -> getTaskInfo());
}

public TaskInfo updateTask(Session session, PlanFragment fragment, List<TaskSource> sources, OutputBuffers outputBuffers)
public TaskInfo updateTask(Session session, Optional<PlanFragment> fragment, List<TaskSource> sources, OutputBuffers outputBuffers)
{
try {
// assure the task execution is only created once
Expand All @@ -246,8 +251,10 @@ public TaskInfo updateTask(Session session, PlanFragment fragment, List<TaskSour
}
taskExecution = taskHolder.getTaskExecution();
if (taskExecution == null) {
taskExecution = sqlTaskExecutionFactory.create(session, queryContext, taskStateMachine, sharedBuffer, fragment, sources);
checkState(fragment.isPresent(), "fragment must be present");
taskExecution = sqlTaskExecutionFactory.create(session, queryContext, taskStateMachine, sharedBuffer, fragment.get(), sources);
taskHolderReference.compareAndSet(taskHolder, new TaskHolder(taskExecution));
needsPlan.set(false);
}
}

Expand Down
Expand Up @@ -47,6 +47,7 @@

import java.io.Closeable;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
Expand Down Expand Up @@ -251,7 +252,7 @@ public CompletableFuture<TaskInfo> getTaskInfo(TaskId taskId, TaskState currentS
}

@Override
public TaskInfo updateTask(Session session, TaskId taskId, PlanFragment fragment, List<TaskSource> sources, OutputBuffers outputBuffers)
public TaskInfo updateTask(Session session, TaskId taskId, Optional<PlanFragment> fragment, List<TaskSource> sources, OutputBuffers outputBuffers)
{
requireNonNull(session, "session is null");
requireNonNull(taskId, "taskId is null");
Expand Down
Expand Up @@ -59,6 +59,7 @@ public class TaskInfo
private final Set<PlanNodeId> noMoreSplits;
private final TaskStats stats;
private final List<ExecutionFailureInfo> failures;
private final boolean needsPlan;

@JsonCreator
public TaskInfo(@JsonProperty("taskId") TaskId taskId,
Expand All @@ -70,7 +71,8 @@ public TaskInfo(@JsonProperty("taskId") TaskId taskId,
@JsonProperty("outputBuffers") SharedBufferInfo outputBuffers,
@JsonProperty("noMoreSplits") Set<PlanNodeId> noMoreSplits,
@JsonProperty("stats") TaskStats stats,
@JsonProperty("failures") List<ExecutionFailureInfo> failures)
@JsonProperty("failures") List<ExecutionFailureInfo> failures,
@JsonProperty("needsPlan") boolean needsPlan)
{
this.taskId = requireNonNull(taskId, "taskId is null");
this.taskInstanceId = requireNonNull(taskInstanceId, "taskInstanceId is null");
Expand All @@ -89,6 +91,7 @@ public TaskInfo(@JsonProperty("taskId") TaskId taskId,
else {
this.failures = ImmutableList.of();
}
this.needsPlan = needsPlan;
}

@JsonProperty
Expand Down Expand Up @@ -151,9 +154,15 @@ public List<ExecutionFailureInfo> getFailures()
return failures;
}

@JsonProperty
public boolean isNeedsPlan()
{
return needsPlan;
}

public TaskInfo summarize()
{
return new TaskInfo(taskId, taskInstanceId, version, state, self, lastHeartbeat, outputBuffers, noMoreSplits, stats.summarize(), failures);
return new TaskInfo(taskId, taskInstanceId, version, state, self, lastHeartbeat, outputBuffers, noMoreSplits, stats.summarize(), failures, needsPlan);
}

@Override
Expand Down
Expand Up @@ -22,6 +22,7 @@
import io.airlift.units.DataSize;

import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

public interface TaskManager
Expand Down Expand Up @@ -58,7 +59,7 @@ public interface TaskManager
* Updates the task plan, sources and output buffers. If the task does not
* already exist, is is created and then updated.
*/
TaskInfo updateTask(Session session, TaskId taskId, PlanFragment fragment, List<TaskSource> sources, OutputBuffers outputBuffers);
TaskInfo updateTask(Session session, TaskId taskId, Optional<PlanFragment> fragment, List<TaskSource> sources, OutputBuffers outputBuffers);

/**
* Cancels a task. If the task does not already exist, is is created and then
Expand Down
Expand Up @@ -70,6 +70,7 @@
import java.util.List;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CancellationException;
Expand Down Expand Up @@ -154,6 +155,7 @@ public final class HttpRemoteTask
private final RequestErrorTracker getErrorTracker;

private final AtomicBoolean needsUpdate = new AtomicBoolean(true);
private final AtomicBoolean sendPlan = new AtomicBoolean(true);

private final PartitionedSplitCountTracker partitionedSplitCountTracker;

Expand Down Expand Up @@ -228,7 +230,8 @@ public HttpRemoteTask(Session session,
new SharedBufferInfo(BufferState.OPEN, true, true, 0, 0, 0, 0, bufferStates),
ImmutableSet.<PlanNodeId>of(),
taskStats,
ImmutableList.<ExecutionFailureInfo>of()));
ImmutableList.<ExecutionFailureInfo>of(),
true));

long timeout = minErrorDuration.toMillis() / 3;
requestTimeout = new Duration(timeout + refreshMaxWait.toMillis(), MILLISECONDS);
Expand Down Expand Up @@ -441,8 +444,13 @@ private synchronized void scheduleUpdate()
}

List<TaskSource> sources = getSources();

Optional<PlanFragment> fragment = Optional.empty();
if (sendPlan.get()) {
fragment = Optional.of(planFragment);
}
TaskUpdateRequest updateRequest = new TaskUpdateRequest(session.toSessionRepresentation(),
planFragment,
fragment,
sources,
outputBuffers.get());

Expand Down Expand Up @@ -537,7 +545,8 @@ public synchronized void abort()
taskInfo.getOutputBuffers(),
taskInfo.getNoMoreSplits(),
taskInfo.getStats(),
ImmutableList.<ExecutionFailureInfo>of()));
ImmutableList.<ExecutionFailureInfo>of(),
taskInfo.isNeedsPlan()));

// send abort to task and ignore response
Request request = prepareDelete()
Expand Down Expand Up @@ -601,7 +610,8 @@ private void failTask(Throwable cause)
taskInfo.getOutputBuffers(),
taskInfo.getNoMoreSplits(),
taskInfo.getStats(),
ImmutableList.of(toFailure(cause))));
ImmutableList.of(toFailure(cause)),
taskInfo.isNeedsPlan()));
}

@Override
Expand Down Expand Up @@ -629,6 +639,7 @@ public void success(TaskInfo value)
try {
synchronized (HttpRemoteTask.this) {
currentRequest = null;
sendPlan.set(value.isNeedsPlan());
}
updateTaskInfo(value, sources);
updateErrorTracker.requestSucceeded();
Expand Down
Expand Up @@ -22,21 +22,22 @@
import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Optional;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;

public class TaskUpdateRequest
{
private final SessionRepresentation session;
private final PlanFragment fragment;
private final Optional<PlanFragment> fragment;
private final List<TaskSource> sources;
private final OutputBuffers outputIds;

@JsonCreator
public TaskUpdateRequest(
@JsonProperty("session") SessionRepresentation session,
@JsonProperty("fragment") PlanFragment fragment,
@JsonProperty("fragment") Optional<PlanFragment> fragment,
@JsonProperty("sources") List<TaskSource> sources,
@JsonProperty("outputIds") OutputBuffers outputIds)
{
Expand All @@ -58,7 +59,7 @@ public SessionRepresentation getSession()
}

@JsonProperty
public PlanFragment getFragment()
public Optional<PlanFragment> getFragment()
{
return fragment;
}
Expand Down
Expand Up @@ -211,7 +211,8 @@ public TaskInfo getTaskInfo()
sharedBuffer.getInfo(),
ImmutableSet.<PlanNodeId>of(),
taskContext.getTaskStats(),
failures);
failures,
true);
}

public synchronized void finishSplits(int splits)
Expand Down
Expand Up @@ -123,6 +123,6 @@ public static LocalExecutionPlanner createTestingPlanner()

public static TaskInfo updateTask(SqlTask sqlTask, List<TaskSource> taskSources, OutputBuffers outputBuffers)
{
return sqlTask.updateTask(TEST_SESSION, PLAN_FRAGMENT, taskSources, outputBuffers);
return sqlTask.updateTask(TEST_SESSION, Optional.of(PLAN_FRAGMENT), taskSources, outputBuffers);
}
}
Expand Up @@ -34,6 +34,7 @@
import org.testng.annotations.Test;

import java.net.URI;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeoutException;
Expand Down Expand Up @@ -101,7 +102,7 @@ public void testEmptyQuery()
SqlTask sqlTask = createInitialTask();

TaskInfo taskInfo = sqlTask.updateTask(TEST_SESSION,
PLAN_FRAGMENT,
Optional.of(PLAN_FRAGMENT),
ImmutableList.<TaskSource>of(),
INITIAL_EMPTY_OUTPUT_BUFFERS);
assertEquals(taskInfo.getState(), TaskState.RUNNING);
Expand All @@ -110,7 +111,7 @@ public void testEmptyQuery()
assertEquals(taskInfo.getState(), TaskState.RUNNING);

taskInfo = sqlTask.updateTask(TEST_SESSION,
PLAN_FRAGMENT,
Optional.of(PLAN_FRAGMENT),
ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.<ScheduledSplit>of(), true)),
INITIAL_EMPTY_OUTPUT_BUFFERS.withNoMoreBufferIds());
assertEquals(taskInfo.getState(), TaskState.FINISHED);
Expand All @@ -126,7 +127,7 @@ public void testSimpleQuery()
SqlTask sqlTask = createInitialTask();

TaskInfo taskInfo = sqlTask.updateTask(TEST_SESSION,
PLAN_FRAGMENT,
Optional.of(PLAN_FRAGMENT),
ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)),
INITIAL_EMPTY_OUTPUT_BUFFERS.withBuffer(OUT, 0).withNoMoreBufferIds());
assertEquals(taskInfo.getState(), TaskState.RUNNING);
Expand Down Expand Up @@ -161,7 +162,7 @@ public void testCancel()
SqlTask sqlTask = createInitialTask();

TaskInfo taskInfo = sqlTask.updateTask(TEST_SESSION,
PLAN_FRAGMENT,
Optional.of(PLAN_FRAGMENT),
ImmutableList.<TaskSource>of(),
INITIAL_EMPTY_OUTPUT_BUFFERS);
assertEquals(taskInfo.getState(), TaskState.RUNNING);
Expand All @@ -187,7 +188,7 @@ public void testAbort()
SqlTask sqlTask = createInitialTask();

TaskInfo taskInfo = sqlTask.updateTask(TEST_SESSION,
PLAN_FRAGMENT,
Optional.of(PLAN_FRAGMENT),
ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)),
INITIAL_EMPTY_OUTPUT_BUFFERS.withBuffer(OUT, 0).withNoMoreBufferIds());
assertEquals(taskInfo.getState(), TaskState.RUNNING);
Expand Down
Expand Up @@ -35,6 +35,7 @@
import org.testng.annotations.Test;

import java.net.URI;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import static com.facebook.presto.OutputBuffers.INITIAL_EMPTY_OUTPUT_BUFFERS;
Expand Down Expand Up @@ -79,7 +80,7 @@ public void testEmptyQuery()
TaskId taskId = TASK_ID;
TaskInfo taskInfo = sqlTaskManager.updateTask(TEST_SESSION,
taskId,
PLAN_FRAGMENT,
Optional.of(PLAN_FRAGMENT),
ImmutableList.<TaskSource>of(),
INITIAL_EMPTY_OUTPUT_BUFFERS);
assertEquals(taskInfo.getState(), TaskState.RUNNING);
Expand All @@ -89,7 +90,7 @@ public void testEmptyQuery()

taskInfo = sqlTaskManager.updateTask(TEST_SESSION,
taskId,
PLAN_FRAGMENT,
Optional.of(PLAN_FRAGMENT),
ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.<ScheduledSplit>of(), true)),
INITIAL_EMPTY_OUTPUT_BUFFERS.withNoMoreBufferIds());
assertEquals(taskInfo.getState(), TaskState.FINISHED);
Expand All @@ -107,7 +108,7 @@ public void testSimpleQuery()
TaskId taskId = TASK_ID;
TaskInfo taskInfo = sqlTaskManager.updateTask(TEST_SESSION,
taskId,
PLAN_FRAGMENT,
Optional.of(PLAN_FRAGMENT),
ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)),
INITIAL_EMPTY_OUTPUT_BUFFERS.withBuffer(OUT, 0).withNoMoreBufferIds());
assertEquals(taskInfo.getState(), TaskState.RUNNING);
Expand Down Expand Up @@ -143,7 +144,7 @@ public void testCancel()
TaskId taskId = TASK_ID;
TaskInfo taskInfo = sqlTaskManager.updateTask(TEST_SESSION,
taskId,
PLAN_FRAGMENT,
Optional.of(PLAN_FRAGMENT),
ImmutableList.<TaskSource>of(),
INITIAL_EMPTY_OUTPUT_BUFFERS);
assertEquals(taskInfo.getState(), TaskState.RUNNING);
Expand Down Expand Up @@ -171,7 +172,7 @@ public void testAbort()
TaskId taskId = TASK_ID;
TaskInfo taskInfo = sqlTaskManager.updateTask(TEST_SESSION,
taskId,
PLAN_FRAGMENT,
Optional.of(PLAN_FRAGMENT),
ImmutableList.<TaskSource>of(),
INITIAL_EMPTY_OUTPUT_BUFFERS);
assertEquals(taskInfo.getState(), TaskState.RUNNING);
Expand Down Expand Up @@ -199,7 +200,7 @@ public void testAbortResults()
TaskId taskId = TASK_ID;
TaskInfo taskInfo = sqlTaskManager.updateTask(TEST_SESSION,
taskId,
PLAN_FRAGMENT,
Optional.of(PLAN_FRAGMENT),
ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)),
INITIAL_EMPTY_OUTPUT_BUFFERS.withBuffer(OUT, 0).withNoMoreBufferIds());
assertEquals(taskInfo.getState(), TaskState.RUNNING);
Expand All @@ -226,7 +227,7 @@ public void testRemoveOldTasks()

TaskInfo taskInfo = sqlTaskManager.updateTask(TEST_SESSION,
taskId,
PLAN_FRAGMENT,
Optional.of(PLAN_FRAGMENT),
ImmutableList.<TaskSource>of(),
INITIAL_EMPTY_OUTPUT_BUFFERS);
assertEquals(taskInfo.getState(), TaskState.RUNNING);
Expand Down

0 comments on commit 71666f5

Please sign in to comment.