From 2f9409599940e582feeb376f4f93c419d579ee7d Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 4 Jun 2015 16:03:17 -0700 Subject: [PATCH] [SPARK-6980] Added addMessageIfTimeout for when a Future is completed with TimeoutException --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 50 ++++++++++++++++--- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 32 ++++++------ 2 files changed, 60 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 48e315249ffa6..acf13b8deb1fa 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -22,11 +22,11 @@ import java.util.concurrent.TimeoutException import scala.concurrent.duration.FiniteDuration import scala.concurrent.duration._ -import scala.concurrent.{Await, Future} +import scala.concurrent.{Awaitable, Await, Future} import scala.language.postfixOps import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.util.{ThreadUtils, RpcUtils, Utils} /** @@ -187,6 +187,13 @@ private[spark] object RpcAddress { } +/** + * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + */ +private[rpc] class RpcTimeoutException(message: String) + extends TimeoutException(message) + + /** * Associates a timeout with a description so that a when a TimeoutException occurs, additional * context about the timeout can be amended to the exception message. @@ -202,17 +209,44 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) { def message: String = description /** Amends the standard message of TimeoutException to include the description */ - def amend(te: TimeoutException): TimeoutException = { - new TimeoutException(te.getMessage() + " " + description) + def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { + new RpcTimeoutException(te.getMessage() + " " + description) + } + + /** + * Add a callback to the given Future so that if it completes as failed with a TimeoutException + * then the timeout description is added to the message + */ + def addMessageIfTimeout[T](future: Future[T]): Future[T] = { + future.recover { + // Add a warning message if Future is passed to addMessageIfTimeoutTest more than once + case rte: RpcTimeoutException => throw new RpcTimeoutException(rte.getMessage() + + " (Future has multiple calls to RpcTimeout.addMessageIfTimeoutTest)") + // Any other TimeoutException get converted to a RpcTimeoutException with modified message + case te: TimeoutException => throw createRpcTimeoutException(te) + }(ThreadUtils.sameThread) + } + + /** Applies the duration to create future before calling addMessageIfTimeout*/ + def addMessageIfTimeout[T](f: FiniteDuration => Future[T]): Future[T] = { + addMessageIfTimeout(f(duration)) } - /** Wait on a future result to catch and amend a TimeoutException */ - def awaitResult[T](future: Future[T]): T = { + /** + * Waits for a completed result to catch and amend a TimeoutException message + * @param awaitable the `Awaitable` to be awaited + * @throws RpcTimeoutException if after waiting for the specified time `awaitable` + * is still not ready + */ + def awaitResult[T](awaitable: Awaitable[T]): T = { try { - Await.result(future, duration) + Await.result(awaitable, duration) } catch { - case te: TimeoutException => throw amend(te) + // The exception has already been converted to a RpcTimeoutException so just raise it + case rte: RpcTimeoutException => throw rte + // Any other TimeoutException get converted to a RpcTimeoutException with modified message + case te: TimeoutException => throw createRpcTimeoutException(te) } } } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 8bba16874ca2c..f8ad6bbf60d61 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -213,8 +213,10 @@ private[spark] class AkkaRpcEnv private[akka] ( override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { import actorSystem.dispatcher - actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration). - map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) + defaultLookupTimeout.addMessageIfTimeout( + actorSystem.actorSelection(uri).resolveOne(_). + map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) + ) } override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { @@ -295,18 +297,20 @@ private[akka] class AkkaRpcEndpointRef( } override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap { - // The function will run in the calling thread, so it should be short and never block. - case msg @ AkkaMessage(message, reply) => - if (reply) { - logError(s"Receive $msg but the sender cannot reply") - Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) - } else { - Future.successful(message) - } - case AkkaFailure(e) => - Future.failed(e) - }(ThreadUtils.sameThread).mapTo[T] + timeout.addMessageIfTimeout( + actorRef.ask(AkkaMessage(message, true))(_).flatMap { + // The function will run in the calling thread, so it should be short and never block. + case msg @ AkkaMessage(message, reply) => + if (reply) { + logError(s"Receive $msg but the sender cannot reply") + Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) + } else { + Future.successful(message) + } + case AkkaFailure(e) => + Future.failed(e) + }(ThreadUtils.sameThread).mapTo[T] + ) } override def toString: String = s"${getClass.getSimpleName}($actorRef)"