From d7f1ef2350a380569f191ca11e662bdbff015d3d Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 15 Jul 2015 13:08:20 -0500 Subject: [PATCH] get rid of activeTaskSets --- .../spark/scheduler/TaskSchedulerImpl.scala | 93 +++++++++++-------- .../org/apache/spark/scheduler/TaskSet.scala | 4 +- .../CoarseGrainedSchedulerBackend.scala | 3 +- 3 files changed, 59 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 4eebff8dbb516..ae02ba64020d5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -75,10 +75,9 @@ private[spark] class TaskSchedulerImpl( // TaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. - val activeTaskSets = new HashMap[String, TaskSetManager] val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]] - val taskIdToTaskSetId = new HashMap[Long, String] + val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)] val taskIdToExecutorId = new HashMap[Long, String] @volatile private var hasReceivedTask = false @@ -163,10 +162,9 @@ private[spark] class TaskSchedulerImpl( logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { val manager = createTaskSetManager(taskSet, maxTaskFailures) - activeTaskSets(taskSet.id) = manager val stage = taskSet.stageId val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager]) - stageTaskSets(taskSet.attempt) = manager + stageTaskSets(taskSet.stageAttemptId) = manager val conflictingTaskSet = stageTaskSets.exists { case (_, ts) => ts.taskSet != taskSet && !ts.isZombie } @@ -203,19 +201,21 @@ private[spark] class TaskSchedulerImpl( override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) - activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => - // There are two possible cases here: - // 1. The task set manager has been created and some tasks have been scheduled. - // In this case, send a kill signal to the executors to kill the task and then abort - // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, - // simply abort the stage. - tsm.runningTasksSet.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId, interruptThread) + taskSetsByStage.get(stageId).foreach { attempts => + attempts.foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task and then abort + // the stage. + // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // simply abort the stage. + tsm.runningTasksSet.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId, interruptThread) + } + tsm.abort("Stage %s cancelled".format(stageId)) + logInfo("Stage %d was cancelled".format(stageId)) } - tsm.abort("Stage %s cancelled".format(stageId)) - logInfo("Stage %d was cancelled".format(stageId)) } } @@ -225,9 +225,8 @@ private[spark] class TaskSchedulerImpl( * cleaned up. */ def taskSetFinished(manager: TaskSetManager): Unit = synchronized { - activeTaskSets -= manager.taskSet.id taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage => - taskSetsForStage -= manager.taskSet.attempt + taskSetsForStage -= manager.taskSet.stageAttemptId if (taskSetsForStage.isEmpty) { taskSetsByStage -= manager.taskSet.stageId } @@ -252,7 +251,7 @@ private[spark] class TaskSchedulerImpl( for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { tasks(i) += task val tid = task.taskId - taskIdToTaskSetId(tid) = taskSet.taskSet.id + taskIdToStageIdAndAttempt(tid) = (taskSet.taskSet.stageId, taskSet.taskSet.stageAttemptId) taskIdToExecutorId(tid) = execId executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK @@ -336,26 +335,24 @@ private[spark] class TaskSchedulerImpl( failedExecutor = Some(execId) } } - taskIdToTaskSetId.get(tid) match { - case Some(taskSetId) => + taskSetManagerForTask(tid) match { + case Some(taskSet) => if (TaskState.isFinished(state)) { - taskIdToTaskSetId.remove(tid) + taskIdToStageIdAndAttempt.remove(tid) taskIdToExecutorId.remove(tid) } - activeTaskSets.get(taskSetId).foreach { taskSet => - if (state == TaskState.FINISHED) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) - } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) - } + if (state == TaskState.FINISHED) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) + } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + taskSet.removeRunningTask(tid) + taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) } case None => logError( ("Ignoring update with state %s for TID %s because its task set is gone (this is " + - "likely the result of receiving duplicate task finished status updates)") - .format(state, tid)) + "likely the result of receiving duplicate task finished status updates)") + .format(state, tid)) } } catch { case e: Exception => logError("Exception in statusUpdate", e) @@ -380,9 +377,13 @@ private[spark] class TaskSchedulerImpl( val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized { taskMetrics.flatMap { case (id, metrics) => - taskIdToTaskSetId.get(id) - .flatMap(activeTaskSets.get) - .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics)) + for { + (stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id) + attempts <- taskSetsByStage.get(stageId) + taskSetMgr <- attempts.get(stageAttemptId) + } yield { + (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics) + } } } dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId) @@ -414,9 +415,12 @@ private[spark] class TaskSchedulerImpl( def error(message: String) { synchronized { - if (activeTaskSets.nonEmpty) { + if (taskSetsByStage.nonEmpty) { // Have each task set throw a SparkException with the error - for ((taskSetId, manager) <- activeTaskSets) { + for { + attempts <- taskSetsByStage.values + manager <- attempts.values + } { try { manager.abort(message) } catch { @@ -537,6 +541,21 @@ private[spark] class TaskSchedulerImpl( override def applicationAttemptId(): Option[String] = backend.applicationAttemptId() + private[scheduler] def taskSetManagerForTask(taskId: Long): Option[TaskSetManager] = { + taskIdToStageIdAndAttempt.get(taskId).flatMap{ case (stageId, stageAttemptId) => + taskSetManagerForAttempt(stageId, stageAttemptId) + } + } + + private[scheduler] def taskSetManagerForAttempt(stageId: Int, stageAttemptId: Int): Option[TaskSetManager] = { + for { + attempts <- taskSetsByStage.get(stageId) + manager <- attempts.get(stageAttemptId) + } yield { + manager + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index c3ad325156f53..be8526ba9b94f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -26,10 +26,10 @@ import java.util.Properties private[spark] class TaskSet( val tasks: Array[Task[_]], val stageId: Int, - val attempt: Int, + val stageAttemptId: Int, val priority: Int, val properties: Properties) { - val id: String = stageId + "." + attempt + val id: String = stageId + "." + stageAttemptId override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7c7f70d8a193b..69cea02674388 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -191,8 +191,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp for (task <- tasks.flatten) { val serializedTask = ser.serialize(task) if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { - val taskSetId = scheduler.taskIdToTaskSetId(task.taskId) - scheduler.activeTaskSets.get(taskSetId).foreach { taskSet => + scheduler.taskSetManagerForTask(task.taskId).foreach { taskSet => try { var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +