Skip to content

Commit

Permalink
Change StateMachine to use CompletableFuture instead of ListenableFuture
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Aug 18, 2015
1 parent 3d1f505 commit e5f1a75
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 36 deletions.
Expand Up @@ -18,6 +18,8 @@
import com.facebook.presto.metadata.Split;
import com.facebook.presto.sql.planner.plan.PlanNodeId;

import java.util.concurrent.CompletableFuture;

public interface RemoteTask
{
String getNodeId();
Expand All @@ -34,6 +36,8 @@ public interface RemoteTask

void addStateChangeListener(StateChangeListener<TaskInfo> stateChangeListener);

CompletableFuture<TaskInfo> getStateChange(TaskInfo taskInfo);

void cancel();

void abort();
Expand Down
Expand Up @@ -45,6 +45,7 @@
import static com.facebook.presto.util.Failures.toFailures;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.airlift.concurrent.MoreFutures.toListenableFuture;
import static java.util.Objects.requireNonNull;

public class SqlTask
Expand Down Expand Up @@ -231,7 +232,7 @@ public ListenableFuture<TaskInfo> getTaskInfo(TaskState callersCurrentState)
return Futures.immediateFuture(getTaskInfo());
}

ListenableFuture<TaskState> futureTaskState = taskStateMachine.getStateChange(callersCurrentState);
ListenableFuture<TaskState> futureTaskState = toListenableFuture(taskStateMachine.getStateChange(callersCurrentState));
return Futures.transform(futureTaskState, (TaskState input) -> getTaskInfo());
}

Expand Down
Expand Up @@ -17,10 +17,6 @@
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.log.Logger;
import io.airlift.units.Duration;

Expand All @@ -31,6 +27,8 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;

import static com.google.common.base.Preconditions.checkState;
Expand Down Expand Up @@ -59,7 +57,7 @@ public class StateMachine<T>
private final List<StateChangeListener<T>> stateChangeListeners = new ArrayList<>();

@GuardedBy("lock")
private final Set<SettableFuture<T>> futureStateChanges = newIdentityHashSet();
private final Set<CompletableFuture<T>> futureStateChanges = newIdentityHashSet();

/**
* Creates a state machine with the specified initial state and no terminal states.
Expand Down Expand Up @@ -108,7 +106,7 @@ public T set(T newState)
requireNonNull(newState, "newState is null");

T oldState;
ImmutableList<SettableFuture<T>> futureStateChanges;
ImmutableList<CompletableFuture<T>> futureStateChanges;
ImmutableList<StateChangeListener<T>> stateChangeListeners;
synchronized (lock) {
if (state.equals(newState)) {
Expand Down Expand Up @@ -180,7 +178,7 @@ public boolean compareAndSet(T expectedState, T newState)
requireNonNull(expectedState, "expectedState is null");
requireNonNull(newState, "newState is null");

ImmutableList<SettableFuture<T>> futureStateChanges;
ImmutableList<CompletableFuture<T>> futureStateChanges;
ImmutableList<StateChangeListener<T>> stateChangeListeners;
synchronized (lock) {
if (!state.equals(expectedState)) {
Expand Down Expand Up @@ -212,16 +210,16 @@ public boolean compareAndSet(T expectedState, T newState)
return true;
}

private void fireStateChanged(T newState, List<SettableFuture<T>> futureStateChanges, List<StateChangeListener<T>> stateChangeListeners)
private void fireStateChanged(T newState, List<CompletableFuture<T>> futureStateChanges, List<StateChangeListener<T>> stateChangeListeners)
{
checkState(!Thread.holdsLock(lock), "Can not fire state change event while holding the lock");
requireNonNull(newState, "newState is null");

executor.execute(() -> {
checkState(!Thread.holdsLock(lock), "Can not notify while holding the lock");
for (SettableFuture<T> futureStateChange : futureStateChanges) {
for (CompletableFuture<T> futureStateChange : futureStateChanges) {
try {
futureStateChange.set(newState);
futureStateChange.complete(newState);
}
catch (Throwable e) {
log.error(e, "Error setting future state for %s", name);
Expand All @@ -241,32 +239,23 @@ private void fireStateChanged(T newState, List<SettableFuture<T>> futureStateCha
/**
* Gets a future that completes when the state is no longer {@code .equals()} to {@code currentState)}
*/
public ListenableFuture<T> getStateChange(T currentState)
public CompletableFuture<T> getStateChange(T currentState)
{
checkState(!Thread.holdsLock(lock), "Can not wait for state change while holding the lock");
requireNonNull(currentState, "currentState is null");

synchronized (lock) {
// return a completed future if the state has already changed, or we are in a terminal state
if (!isPossibleStateChange(currentState)) {
return Futures.immediateFuture(state);
return CompletableFuture.completedFuture(state);
}

SettableFuture<T> futureStateChange = SettableFuture.create();
CompletableFuture<T> futureStateChange = new CompletableFuture<>();
futureStateChanges.add(futureStateChange);
Futures.addCallback(futureStateChange, new FutureCallback<T>()
{
@Override
public void onSuccess(T result)
{
// no-op. The futureStateChanges list is already cleared before fireStateChanged is called.
}

@Override
public void onFailure(Throwable t)
{
// Remove the Future early, in case it's cancelled.
synchronized (lock) {
futureStateChange.whenComplete((value, throwable) -> {
// Remove the Future early, in case it's cancelled.
if (throwable instanceof CancellationException) {
synchronized (StateMachine.this) {
futureStateChanges.remove(futureStateChange);
}
}
Expand Down Expand Up @@ -347,7 +336,7 @@ synchronized List<StateChangeListener<T>> getStateChangeListeners()
}

@VisibleForTesting
synchronized Set<SettableFuture<T>> getFutureStateChanges()
synchronized Set<CompletableFuture<T>> getFutureStateChanges()
{
return ImmutableSet.copyOf(futureStateChanges);
}
Expand Down
Expand Up @@ -14,14 +14,13 @@
package com.facebook.presto.execution;

import com.facebook.presto.execution.StateMachine.StateChangeListener;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import org.joda.time.DateTime;

import javax.annotation.concurrent.ThreadSafe;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;

Expand Down Expand Up @@ -70,15 +69,15 @@ public TaskState getState()
return taskState.get();
}

public ListenableFuture<TaskState> getStateChange(TaskState currentState)
public CompletableFuture<TaskState> getStateChange(TaskState currentState)
{
checkNotNull(currentState, "currentState is null");
checkArgument(!currentState.isDone(), "Current state is already done");

ListenableFuture<TaskState> future = taskState.getStateChange(currentState);
CompletableFuture<TaskState> future = taskState.getStateChange(currentState);
TaskState state = taskState.get();
if (state.isDone()) {
return Futures.immediateFuture(state);
return CompletableFuture.completedFuture(state);
}
return future;
}
Expand Down
Expand Up @@ -73,6 +73,7 @@
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
Expand Down Expand Up @@ -309,6 +310,12 @@ public void addStateChangeListener(StateChangeListener<TaskInfo> stateChangeList
}
}

@Override
public CompletableFuture<TaskInfo> getStateChange(TaskInfo taskInfo)
{
return this.taskInfo.getStateChange(taskInfo);
}

private synchronized void updateTaskInfo(TaskInfo newValue)
{
updateTaskInfo(newValue, ImmutableList.of());
Expand Down
Expand Up @@ -49,6 +49,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Stream;
Expand Down Expand Up @@ -224,6 +225,12 @@ public void addStateChangeListener(StateChangeListener<TaskInfo> stateChangeList
taskStateMachine.addStateChangeListener(newValue -> stateChangeListener.stateChanged(getTaskInfo()));
}

@Override
public CompletableFuture<TaskInfo> getStateChange(TaskInfo taskInfo)
{
return taskStateMachine.getStateChange(taskInfo.getState()).thenApply(ignored -> getTaskInfo());
}

@Override
public void cancel()
{
Expand Down
Expand Up @@ -474,6 +474,12 @@ public void addStateChangeListener(StateChangeListener<TaskInfo> stateChangeList
taskStateMachine.addStateChangeListener(newValue -> stateChangeListener.stateChanged(getTaskInfo()));
}

@Override
public CompletableFuture<TaskInfo> getStateChange(TaskInfo taskInfo)
{
return taskStateMachine.getStateChange(taskInfo.getState()).thenApply(ignored -> getTaskInfo());
}

@Override
public void cancel()
{
Expand Down
Expand Up @@ -15,7 +15,6 @@

import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.units.Duration;
import org.testng.annotations.AfterClass;
Expand Down Expand Up @@ -228,7 +227,7 @@ private void assertStateChange(StateMachine<State> stateMachine, StateChanger st
throws Exception
{
State initialState = stateMachine.get();
ListenableFuture<State> futureChange = stateMachine.getStateChange(initialState);
Future<State> futureChange = stateMachine.getStateChange(initialState);

SettableFuture<State> listenerChange = SettableFuture.create();
stateMachine.addStateChangeListener(listenerChange::set);
Expand Down Expand Up @@ -267,7 +266,7 @@ private void assertNoStateChange(StateMachine<State> stateMachine, StateChanger
throws Exception
{
State initialState = stateMachine.get();
ListenableFuture<State> futureChange = stateMachine.getStateChange(initialState);
Future<State> futureChange = stateMachine.getStateChange(initialState);

SettableFuture<State> listenerChange = SettableFuture.create();
stateMachine.addStateChangeListener(listenerChange::set);
Expand Down

0 comments on commit e5f1a75

Please sign in to comment.