diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index acf13b8deb1fa..8efe690e78216 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -26,7 +26,7 @@ import scala.concurrent.{Awaitable, Await, Future} import scala.language.postfixOps import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{ThreadUtils, RpcUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} /** @@ -190,8 +190,8 @@ private[spark] object RpcAddress { /** * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. */ -private[rpc] class RpcTimeoutException(message: String) - extends TimeoutException(message) +private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) + extends TimeoutException(message) { initCause(cause) } /** @@ -209,27 +209,23 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) { def message: String = description /** Amends the standard message of TimeoutException to include the description */ - def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { - new RpcTimeoutException(te.getMessage() + " " + description) + private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { + new RpcTimeoutException(te.getMessage() + " " + description, te) } /** - * Add a callback to the given Future so that if it completes as failed with a TimeoutException - * then the timeout description is added to the message + * PartialFunction to match a TimeoutException and add the timeout description to the message + * + * @note This can be used in the recover callback of a Future to add to a TimeoutException + * Example: + * val timeout = new RpcTimeout(5 millis, "short timeout") + * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) */ - def addMessageIfTimeout[T](future: Future[T]): Future[T] = { - future.recover { - // Add a warning message if Future is passed to addMessageIfTimeoutTest more than once - case rte: RpcTimeoutException => throw new RpcTimeoutException(rte.getMessage() + - " (Future has multiple calls to RpcTimeout.addMessageIfTimeoutTest)") - // Any other TimeoutException get converted to a RpcTimeoutException with modified message - case te: TimeoutException => throw createRpcTimeoutException(te) - }(ThreadUtils.sameThread) - } - - /** Applies the duration to create future before calling addMessageIfTimeout*/ - def addMessageIfTimeout[T](f: FiniteDuration => Future[T]): Future[T] = { - addMessageIfTimeout(f(duration)) + def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { + // The exception has already been converted to a RpcTimeoutException so just raise it + case rte: RpcTimeoutException => throw rte + // Any other TimeoutException get converted to a RpcTimeoutException with modified message + case te: TimeoutException => throw createRpcTimeoutException(te) } /** @@ -241,13 +237,7 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) { def awaitResult[T](awaitable: Awaitable[T]): T = { try { Await.result(awaitable, duration) - } - catch { - // The exception has already been converted to a RpcTimeoutException so just raise it - case rte: RpcTimeoutException => throw rte - // Any other TimeoutException get converted to a RpcTimeoutException with modified message - case te: TimeoutException => throw createRpcTimeoutException(te) - } + } catch addMessageIfTimeout } } @@ -299,13 +289,10 @@ object RpcTimeout { // Find the first set property or use the default value with the first property val itr = timeoutPropList.iterator - var foundProp = None: Option[(String, String)] + var foundProp: Option[(String, String)] = None while (itr.hasNext && foundProp.isEmpty){ val propKey = itr.next() - conf.getOption(propKey) match { - case Some(prop) => foundProp = Some(propKey,prop) - case None => - } + conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } } val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index f8ad6bbf60d61..cd032a3301b95 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -213,10 +213,11 @@ private[spark] class AkkaRpcEnv private[akka] ( override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { import actorSystem.dispatcher - defaultLookupTimeout.addMessageIfTimeout( - actorSystem.actorSelection(uri).resolveOne(_). - map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) - ) + actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration). + map(new AkkaRpcEndpointRef(defaultAddress, _, conf)). + // this is just in case there is a timeout from creating the future in resolveOne, we want the + // exception to indicate the conf that determines the timeout + recover(defaultLookupTimeout.addMessageIfTimeout) } override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { @@ -297,20 +298,19 @@ private[akka] class AkkaRpcEndpointRef( } override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - timeout.addMessageIfTimeout( - actorRef.ask(AkkaMessage(message, true))(_).flatMap { - // The function will run in the calling thread, so it should be short and never block. - case msg @ AkkaMessage(message, reply) => - if (reply) { - logError(s"Receive $msg but the sender cannot reply") - Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) - } else { - Future.successful(message) - } - case AkkaFailure(e) => - Future.failed(e) - }(ThreadUtils.sameThread).mapTo[T] - ) + actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap { + // The function will run in the calling thread, so it should be short and never block. + case msg @ AkkaMessage(message, reply) => + if (reply) { + logError(s"Receive $msg but the sender cannot reply") + Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) + } else { + Future.successful(message) + } + case AkkaFailure(e) => + Future.failed(e) + }(ThreadUtils.sameThread).mapTo[T]. + recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } override def toString: String = s"${getClass.getSimpleName}($actorRef)"