Skip to content

Commit

Permalink
[SPARK-6980] Added addMessageIfTimeout for when a Future is completed…
Browse files Browse the repository at this point in the history
… with TimeoutException
  • Loading branch information
BryanCutler committed Jun 4, 2015
1 parent 235919b commit 2f94095
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 22 deletions.
50 changes: 42 additions & 8 deletions core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ import java.util.concurrent.TimeoutException

import scala.concurrent.duration.FiniteDuration
import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.concurrent.{Awaitable, Await, Future}
import scala.language.postfixOps

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.util.{RpcUtils, Utils}
import org.apache.spark.util.{ThreadUtils, RpcUtils, Utils}


/**
Expand Down Expand Up @@ -187,6 +187,13 @@ private[spark] object RpcAddress {
}


/**
* An exception thrown if RpcTimeout modifies a [[TimeoutException]].
*/
private[rpc] class RpcTimeoutException(message: String)
extends TimeoutException(message)


/**
* Associates a timeout with a description so that a when a TimeoutException occurs, additional
* context about the timeout can be amended to the exception message.
Expand All @@ -202,17 +209,44 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
def message: String = description

/** Amends the standard message of TimeoutException to include the description */
def amend(te: TimeoutException): TimeoutException = {
new TimeoutException(te.getMessage() + " " + description)
def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = {
new RpcTimeoutException(te.getMessage() + " " + description)
}

/**
* Add a callback to the given Future so that if it completes as failed with a TimeoutException
* then the timeout description is added to the message
*/
def addMessageIfTimeout[T](future: Future[T]): Future[T] = {
future.recover {
// Add a warning message if Future is passed to addMessageIfTimeoutTest more than once
case rte: RpcTimeoutException => throw new RpcTimeoutException(rte.getMessage() +
" (Future has multiple calls to RpcTimeout.addMessageIfTimeoutTest)")
// Any other TimeoutException get converted to a RpcTimeoutException with modified message
case te: TimeoutException => throw createRpcTimeoutException(te)
}(ThreadUtils.sameThread)
}

/** Applies the duration to create future before calling addMessageIfTimeout*/
def addMessageIfTimeout[T](f: FiniteDuration => Future[T]): Future[T] = {
addMessageIfTimeout(f(duration))
}

/** Wait on a future result to catch and amend a TimeoutException */
def awaitResult[T](future: Future[T]): T = {
/**
* Waits for a completed result to catch and amend a TimeoutException message
* @param awaitable the `Awaitable` to be awaited
* @throws RpcTimeoutException if after waiting for the specified time `awaitable`
* is still not ready
*/
def awaitResult[T](awaitable: Awaitable[T]): T = {
try {
Await.result(future, duration)
Await.result(awaitable, duration)
}
catch {
case te: TimeoutException => throw amend(te)
// The exception has already been converted to a RpcTimeoutException so just raise it
case rte: RpcTimeoutException => throw rte
// Any other TimeoutException get converted to a RpcTimeoutException with modified message
case te: TimeoutException => throw createRpcTimeoutException(te)
}
}
}
Expand Down
32 changes: 18 additions & 14 deletions core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,10 @@ private[spark] class AkkaRpcEnv private[akka] (

override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
import actorSystem.dispatcher
actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration).
map(new AkkaRpcEndpointRef(defaultAddress, _, conf))
defaultLookupTimeout.addMessageIfTimeout(
actorSystem.actorSelection(uri).resolveOne(_).
map(new AkkaRpcEndpointRef(defaultAddress, _, conf))
)
}

override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = {
Expand Down Expand Up @@ -295,18 +297,20 @@ private[akka] class AkkaRpcEndpointRef(
}

override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap {
// The function will run in the calling thread, so it should be short and never block.
case msg @ AkkaMessage(message, reply) =>
if (reply) {
logError(s"Receive $msg but the sender cannot reply")
Future.failed(new SparkException(s"Receive $msg but the sender cannot reply"))
} else {
Future.successful(message)
}
case AkkaFailure(e) =>
Future.failed(e)
}(ThreadUtils.sameThread).mapTo[T]
timeout.addMessageIfTimeout(
actorRef.ask(AkkaMessage(message, true))(_).flatMap {
// The function will run in the calling thread, so it should be short and never block.
case msg @ AkkaMessage(message, reply) =>
if (reply) {
logError(s"Receive $msg but the sender cannot reply")
Future.failed(new SparkException(s"Receive $msg but the sender cannot reply"))
} else {
Future.successful(message)
}
case AkkaFailure(e) =>
Future.failed(e)
}(ThreadUtils.sameThread).mapTo[T]
)
}

override def toString: String = s"${getClass.getSimpleName}($actorRef)"
Expand Down

0 comments on commit 2f94095

Please sign in to comment.