From 584acd4ce80ad0c8409638c49874ea1e46099bc6 Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Fri, 17 Jul 2015 13:30:52 -0500
Subject: [PATCH] simplify going from taskId to taskSetMgr

---
 .../spark/scheduler/TaskSchedulerImpl.scala   | 25 ++++++-------------
 .../CoarseGrainedSchedulerBackend.scala       |  4 +--
 .../scheduler/TaskSchedulerImplSuite.scala    |  4 +--
 3 files changed, 11 insertions(+), 22 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index a34b67db388f6..1705e7f962de2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(
 
   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
+  private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
 
-  val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]
+  private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
   val taskIdToExecutorId = new HashMap[Long, String]
 
   @volatile private var hasReceivedTask = false
@@ -252,8 +252,7 @@ private[spark] class TaskSchedulerImpl(
              for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
                tasks(i) += task
                val tid = task.taskId
-               taskIdToStageIdAndAttempt(tid) =
-                 (taskSet.taskSet.stageId, taskSet.taskSet.stageAttemptId)
+               taskIdToTaskSetManager(tid) = taskSet
                taskIdToExecutorId(tid) = execId
                executorsByHost(host) += execId
                availableCpus(i) -= CPUS_PER_TASK
@@ -337,10 +336,10 @@ private[spark] class TaskSchedulerImpl(
             failedExecutor = Some(execId)
           }
         }
-        taskSetManagerForTask(tid) match {
+        taskIdToTaskSetManager.get(tid) match {
           case Some(taskSet) =>
             if (TaskState.isFinished(state)) {
-              taskIdToStageIdAndAttempt.remove(tid)
+              taskIdToTaskSetManager.remove(tid)
               taskIdToExecutorId.remove(tid)
             }
             if (state == TaskState.FINISHED) {
@@ -379,12 +378,8 @@ private[spark] class TaskSchedulerImpl(
 
     val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
       taskMetrics.flatMap { case (id, metrics) =>
-        for {
-          (stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
-          attempts <- taskSetsByStageIdAndAttempt.get(stageId)
-          taskSetMgr <- attempts.get(stageAttemptId)
-        } yield {
-          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+        taskIdToTaskSetManager.get(id).map { taskSetMgr =>
+          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
         }
       }
     }
@@ -543,12 +538,6 @@ private[spark] class TaskSchedulerImpl(
 
   override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
 
-  private[scheduler] def taskSetManagerForTask(taskId: Long): Option[TaskSetManager] = {
-    taskIdToStageIdAndAttempt.get(taskId).flatMap{ case (stageId, stageAttemptId) =>
-      taskSetManagerForAttempt(stageId, stageAttemptId)
-    }
-  }
-
   private[scheduler] def taskSetManagerForAttempt(
       stageId: Int,
       stageAttemptId: Int): Option[TaskSetManager] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 69cea02674388..8d2369ef99aef 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -191,14 +191,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
       for (task <- tasks.flatten) {
         val serializedTask = ser.serialize(task)
         if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
-          scheduler.taskSetManagerForTask(task.taskId).foreach { taskSet =>
+          scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
             try {
               var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                 "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
                 "spark.akka.frameSize or using broadcast variables for large values."
               msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
                 AkkaUtils.reservedSizeBytes)
-              taskSet.abort(msg)
+              taskSetMgr.abort(msg)
             } catch {
               case e: Exception => logError("Exception in error callback", e)
             }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index cb0dce44536d1..b734d3ae0be7c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -188,7 +188,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     taskScheduler.submitTasks(attempt2)
     val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
     assert(1 === taskDescriptions3.length)
-    val mgr = taskScheduler.taskSetManagerForTask(taskDescriptions3(0).taskId).get
+    val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get
     assert(mgr.taskSet.stageAttemptId === 1)
   }
 
@@ -232,7 +232,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     assert(10 === taskDescriptions3.length)
 
     taskDescriptions3.foreach{ task =>
-      val mgr = taskScheduler.taskSetManagerForTask(task.taskId).get
+      val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get
       assert(mgr.taskSet.stageAttemptId === 1)
     }
   }
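
Summary of the change, for reviewers: before this patch, resolving a task ID to its TaskSetManager took two hops (taskId -> (stageId, stageAttemptId) via taskIdToStageIdAndAttempt, then (stageId, stageAttemptId) -> TaskSetManager via taskSetsByStageIdAndAttempt), wrapped in the taskSetManagerForTask helper. The patch stores the TaskSetManager directly in taskIdToTaskSetManager, so every caller does a single map lookup and the helper can be deleted. A minimal standalone sketch of the two shapes, using simplified stand-in types rather than the real Spark classes:

    import scala.collection.mutable.HashMap

    // Stand-in for the real org.apache.spark.scheduler.TaskSetManager.
    case class TaskSetManager(stageId: Int, stageAttemptId: Int)

    object LookupSketch {
      // Before this patch: two hops to reach a manager from a task id.
      val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]
      val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]

      def taskSetManagerForTask(taskId: Long): Option[TaskSetManager] = {
        taskIdToStageIdAndAttempt.get(taskId).flatMap { case (stageId, stageAttemptId) =>
          taskSetsByStageIdAndAttempt.get(stageId).flatMap(_.get(stageAttemptId))
        }
      }

      // After this patch: a single direct lookup.
      val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]

      def main(args: Array[String]): Unit = {
        val mgr = TaskSetManager(stageId = 0, stageAttemptId = 1)
        // Register task 42 under both schemes.
        taskIdToStageIdAndAttempt(42L) = (0, 1)
        taskSetsByStageIdAndAttempt.getOrElseUpdate(0, new HashMap[Int, TaskSetManager])(1) = mgr
        taskIdToTaskSetManager(42L) = mgr
        // Both paths resolve to the same manager.
        assert(taskSetManagerForTask(42L) == taskIdToTaskSetManager.get(42L))
      }
    }

The direct map also removes a consistency hazard: with two maps, an entry in taskIdToStageIdAndAttempt could outlive the matching attempt entry in taskSetsByStageIdAndAttempt, making the two-hop lookup silently return None; with one map, the entry is added and removed in one place (resourceOffers and statusUpdate).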