[SPARK-6980] Akka ask timeout description refactored to RPC layer
BryanCutler committed May 15, 2015
1 parent 97dee31 · commit 97523e0
Showing 12 changed files with 126 additions and 51 deletions.
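In short: call sites that previously handed a bare `FiniteDuration` to Akka's `ask` pattern and `Await.result` now go through a new `RpcTimeout` class (introduced in `RpcEnv.scala` below) that pairs the duration with the configuration property controlling it, and appends that property name to any `TimeoutException`. A minimal sketch of the pattern change, condensed from the `Client.scala` hunk that follows (the exception text is illustrative):

```scala
// Before: a bare FiniteDuration. On expiry the caller sees only
// "Futures timed out after [30 seconds]", with no hint of which
// spark.* property to tune.
val statusResponse = Await.result(statusFuture, timeout)

// After: timeout is an RpcTimeout. awaitResult() catches the
// TimeoutException and re-throws it with the description appended,
// e.g. "... This timeout is controlled by spark.rpc.askTimeout".
val statusResponse = timeout.awaitResult(statusFuture)
```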
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -102,9 +102,9 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
println("... waiting before polling master for driver state")
Thread.sleep(5000)
println("... polling master for driver state")
- val statusFuture = (activeMasterActor ? RequestDriverStatus(driverId))(timeout)
+ val statusFuture = (activeMasterActor ? RequestDriverStatus(driverId))(timeout.duration)
.mapTo[DriverStatusResponse]
- val statusResponse = Await.result(statusFuture, timeout)
+ val statusResponse = timeout.awaitResult(statusFuture)
statusResponse.found match {
case false =>
println(s"ERROR: Cluster master did not recognize $driverId")
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -194,8 +194,9 @@ private[spark] class AppClient(
if (actor != null) {
try {
val timeout = RpcUtils.askTimeout(conf)
- val future = actor.ask(StopAppClient)(timeout)
- Await.result(future, timeout)
+ val future = actor.ask(StopAppClient)(timeout.duration)
+ // TODO(bryanc) - RpcTimeout use awaitResult ???
+ Await.result(future, timeout.duration)
} catch {
case e: TimeoutException =>
logInfo("Stop request to Master timed out; it may already be shut down.")
5 changes: 2 additions & 3 deletions core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -23,7 +23,6 @@ import java.text.SimpleDateFormat
import java.util.Date

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
- import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Random
@@ -940,8 +939,8 @@ private[deploy] object Master extends Logging {
val actor = actorSystem.actorOf(
Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName)
val timeout = RpcUtils.askTimeout(conf)
- val portsRequest = actor.ask(BoundPortsRequest)(timeout)
- val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse]
+ val portsRequest = actor.ask(BoundPortsRequest)(timeout.duration)
+ val portsResponse = timeout.awaitResult(portsRequest).asInstanceOf[BoundPortsResponse]
(actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort)
}
}
5 changes: 2 additions & 3 deletions core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -19,7 +19,6 @@ package org.apache.spark.deploy.master.ui

import javax.servlet.http.HttpServletRequest

- import scala.concurrent.Await
import scala.xml.Node

import akka.pattern.ask
@@ -38,8 +37,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
/** Executor details for a particular application */
def render(request: HttpServletRequest): Seq[Node] = {
val appId = request.getParameter("appId")
- val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- val state = Await.result(stateFuture, timeout)
+ val stateFuture = (master ? RequestMasterState)(timeout.duration).mapTo[MasterStateResponse]
+ val state = timeout.awaitResult(stateFuture)
val app = state.activeApps.find(_.id == appId).getOrElse({
state.completedApps.find(_.id == appId).getOrElse(null)
})
5 changes: 2 additions & 3 deletions core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -19,7 +19,6 @@ package org.apache.spark.deploy.master.ui

import javax.servlet.http.HttpServletRequest

- import scala.concurrent.Await
import scala.xml.Node

import akka.pattern.ask
@@ -36,8 +35,8 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
private val timeout = parent.timeout

def getMasterState: MasterStateResponse = {
- val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- Await.result(stateFuture, timeout)
+ val stateFuture = (master ? RequestMasterState)(timeout.duration).mapTo[MasterStateResponse]
+ timeout.awaitResult(stateFuture)
}

override def renderJson(request: HttpServletRequest): JValue = {
9 changes: 4 additions & 5 deletions core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
@@ -17,7 +17,6 @@

package org.apache.spark.deploy.worker.ui

- import scala.concurrent.Await
import scala.xml.Node

import akka.pattern.ask
@@ -36,14 +35,14 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
private val timeout = parent.timeout

override def renderJson(request: HttpServletRequest): JValue = {
- val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
- val workerState = Await.result(stateFuture, timeout)
+ val stateFuture = (workerActor ? RequestWorkerState)(timeout.duration).mapTo[WorkerStateResponse]
+ val workerState = timeout.awaitResult(stateFuture)
JsonProtocol.writeWorkerState(workerState)
}

def render(request: HttpServletRequest): Seq[Node] = {
- val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
- val workerState = Await.result(stateFuture, timeout)
+ val stateFuture = (workerActor ? RequestWorkerState)(timeout.duration).mapTo[WorkerStateResponse]
+ val workerState = timeout.awaitResult(stateFuture)

val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs")
val runningExecutors = workerState.executors
14 changes: 8 additions & 6 deletions core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
@@ -17,8 +17,7 @@

package org.apache.spark.rpc

- import scala.concurrent.{Await, Future}
- import scala.concurrent.duration.FiniteDuration
+ import scala.concurrent.Future
import scala.reflect.ClassTag

import org.apache.spark.util.RpcUtils
@@ -52,7 +51,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
*
* This method only sends the message once and never retries.
*/
- def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T]
+ def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]

/**
* Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to
@@ -91,15 +90,15 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
* @tparam T type of the reply message
* @return the reply message from the corresponding [[RpcEndpoint]]
*/
- def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = {
+ def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
// TODO: Consider removing multiple attempts
var attempts = 0
var lastException: Exception = null
while (attempts < maxRetries) {
attempts += 1
try {
val future = ask[T](message, timeout)
- val result = Await.result(future, timeout)
+ val result = timeout.awaitResult(future)
if (result == null) {
throw new SparkException("Actor returned null")
}
Expand All @@ -110,7 +109,10 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
lastException = e
logWarning(s"Error sending message [message = $message] in $attempts attempts", e)
}
- Thread.sleep(retryWaitMs)
+
+ if (attempts < maxRetries) {
+   Thread.sleep(retryWaitMs)
+ }
}

throw new SparkException(
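Two things change in `askWithRetry`: the timeout parameter type, and a small behavioral fix so the loop no longer sleeps after the final attempt. A hypothetical call site under the new signature (`endpointRef` and `RegisterWorker` are illustrative names, not part of this commit):

```scala
// askWithRetry now takes an RpcTimeout rather than a FiniteDuration, so
// a request that still fails after all retries reports which property
// bounds each attempt.
val askTimeout = RpcUtils.askTimeout(conf)  // returns an RpcTimeout as of this commit
val ok = endpointRef.askWithRetry[Boolean](RegisterWorker("worker-1"), askTimeout)
```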
70 changes: 69 additions & 1 deletion core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -18,7 +18,10 @@
package org.apache.spark.rpc

import java.net.URI
+ import java.util.concurrent.TimeoutException

import scala.concurrent.duration.FiniteDuration
+ import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.language.postfixOps

@@ -94,7 +97,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
* Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action.
*/
def setupEndpointRefByURI(uri: String): RpcEndpointRef = {
- Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout)
+ Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout.duration)
}

/**
@@ -182,3 +185,68 @@ private[spark] object RpcAddress {
RpcAddress(host, port)
}
}


+ /**
+  * Associates a timeout with a configuration property so that a TimeoutException can be
+  * traced back to the controlling property.
+  * @param timeout timeout duration in seconds
+  * @param description description to be displayed in a timeout exception
+  */
+ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
+
+   /** Get the timeout duration */
+   def duration: FiniteDuration = timeout
+
+   /** Get the message associated with this timeout */
+   def message: String = description
+
+   /** Amends the standard message of TimeoutException to include the description */
+   def amend(te: TimeoutException): TimeoutException = {
+     new TimeoutException(te.getMessage() + " " + description)
+   }
+
+   /** Wait on a future result to catch and amend a TimeoutException */
+   def awaitResult[T](future: Future[T]): T = {
+     try {
+       Await.result(future, duration)
+     }
+     catch {
+       case te: TimeoutException =>
+         throw amend(te)
+     }
+   }
+
+   // TODO(bryanc) wrap Await.ready also
+ }
+
+ object RpcTimeout {
+
+   private[this] val messagePrefix = "This timeout is controlled by "
+
+   /**
+    * Lookup the timeout property in the configuration and create
+    * a RpcTimeout with the property key in the description.
+    * @param conf configuration properties containing the timeout
+    * @param timeoutProp property key for the timeout in seconds
+    * @throws NoSuchElementException if property is not set
+    */
+   def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = {
+     val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds }
+     new RpcTimeout(timeout, messagePrefix + timeoutProp)
+   }
+
+   /**
+    * Lookup the timeout property in the configuration and create
+    * a RpcTimeout with the property key in the description.
+    * Uses the given default value if property is not set
+    * @param conf configuration properties containing the timeout
+    * @param timeoutProp property key for the timeout in seconds
+    * @param defaultValue default timeout value in seconds if property not found
+    */
+   def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = {
+     val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds }
+     new RpcTimeout(timeout, messagePrefix + timeoutProp)
+   }
+
+ }
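Taken together, a minimal usage sketch of the new class (the property key, the default, and the exception text are illustrative, and `Future.successful` stands in for any pending RPC future):

```scala
import scala.concurrent.Future

import org.apache.spark.SparkConf
import org.apache.spark.rpc.RpcTimeout

val conf = new SparkConf()
// Look up the timeout, recording the controlling property in the
// description; fall back to 120s when the property is unset.
val askTimeout = RpcTimeout(conf, "spark.rpc.askTimeout", "120s")

// Behaves like Await.result, except that a TimeoutException is
// re-thrown with the description appended, e.g.
//   Futures timed out after [120 seconds] This timeout is controlled
//   by spark.rpc.askTimeout
val reply: String = askTimeout.awaitResult(Future.successful("pong"))
```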
7 changes: 3 additions & 4 deletions core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -20,7 +20,6 @@ package org.apache.spark.rpc.akka
import java.util.concurrent.ConcurrentHashMap

import scala.concurrent.Future
- import scala.concurrent.duration._
import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.control.NonFatal
@@ -212,7 +211,7 @@ private[spark] class AkkaRpcEnv private[akka] (

override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
import actorSystem.dispatcher
- actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout).
+ actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration).
map(new AkkaRpcEndpointRef(defaultAddress, _, conf))
}

@@ -293,9 +292,9 @@ private[akka] class AkkaRpcEndpointRef(
actorRef ! AkkaMessage(message, false)
}

- override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = {
+ override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
import scala.concurrent.ExecutionContext.Implicits.global
- actorRef.ask(AkkaMessage(message, true))(timeout).flatMap {
+ actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap {
case msg @ AkkaMessage(message, reply) =>
if (reply) {
logError(s"Receive $msg but the sender cannot reply")
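Note the division of labor here: only the raw duration crosses into Akka, while the description stays on the Spark side where `awaitResult` can use it. The key line of the bridge, condensed from the hunk above:

```scala
// Akka's ask pattern still receives a plain FiniteDuration via
// timeout.duration; the RpcTimeout wrapper (and its description)
// never leaks into the Akka layer.
val akkaFuture = actorRef.ask(AkkaMessage(message, true))(timeout.duration)
```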
12 changes: 6 additions & 6 deletions core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -17,7 +17,7 @@

package org.apache.spark.storage

- import scala.concurrent.{Await, Future}
+ import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

import org.apache.spark.rpc.RpcEndpointRef
@@ -105,7 +105,7 @@ class BlockManagerMaster(
logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}")
}
if (blocking) {
- Await.result(future, timeout)
+ timeout.awaitResult(future)
}
}

@@ -117,7 +117,7 @@
logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}")
}
if (blocking) {
- Await.result(future, timeout)
+ timeout.awaitResult(future)
}
}

@@ -131,7 +131,7 @@
s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}")
}
if (blocking) {
- Await.result(future, timeout)
+ timeout.awaitResult(future)
}
}

@@ -169,7 +169,7 @@
val response = driverEndpoint.
askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
val (blockManagerIds, futures) = response.unzip
- val result = Await.result(Future.sequence(futures), timeout)
+ val result = timeout.awaitResult(Future.sequence(futures))
if (result == null) {
throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId)
}
@@ -192,7 +192,7 @@
askSlaves: Boolean): Seq[BlockId] = {
val msg = GetMatchingBlockIds(filter, askSlaves)
val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg)
- Await.result(future, timeout)
+ timeout.awaitResult(future)
}

/** Stop the driver endpoint, called only on the Spark driver node */
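One nuance worth highlighting: when the per-BlockManager futures are collapsed with `Future.sequence`, a single `RpcTimeout` bounds the whole batch, as condensed from the block-status hunk above:

```scala
// One timeout for the combined future, not one per block manager; if
// the batch as a whole expires, the amended TimeoutException still
// names the controlling property.
val result = timeout.awaitResult(Future.sequence(futures))
```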
16 changes: 8 additions & 8 deletions core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -17,9 +17,9 @@

package org.apache.spark.util

+ import org.apache.spark.rpc.RpcTimeout

import scala.collection.JavaConversions.mapAsJavaMap
- import scala.concurrent.Await
- import scala.concurrent.duration.FiniteDuration

import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem}
import akka.pattern.ask
@@ -147,7 +147,7 @@ private[spark] object AkkaUtils extends Logging {
def askWithReply[T](
message: Any,
actor: ActorRef,
- timeout: FiniteDuration): T = {
+ timeout: RpcTimeout): T = {
askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout)
}

@@ -160,7 +160,7 @@
actor: ActorRef,
maxAttempts: Int,
retryInterval: Long,
- timeout: FiniteDuration): T = {
+ timeout: RpcTimeout): T = {
// TODO: Consider removing multiple attempts
if (actor == null) {
throw new SparkException(s"Error sending message [message = $message]" +
@@ -171,8 +171,8 @@
while (attempts < maxAttempts) {
attempts += 1
try {
- val future = actor.ask(message)(timeout)
- val result = Await.result(future, timeout)
+ val future = actor.ask(message)(timeout.duration)
+ val result = timeout.awaitResult(future)
if (result == null) {
throw new SparkException("Actor returned null")
}
@@ -200,7 +200,7 @@
val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name)
val timeout = RpcUtils.lookupTimeout(conf)
logInfo(s"Connecting to $name: $url")
- Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+ timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration))
}

def makeExecutorRef(
@@ -214,7 +214,7 @@
val url = address(protocol(actorSystem), executorActorSystemName, host, port, name)
val timeout = RpcUtils.lookupTimeout(conf)
logInfo(s"Connecting to $name: $url")
- Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+ timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration))
}

def protocol(actorSystem: ActorSystem): String = {
(1 of the 12 changed files is not shown here.)
