diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0a89761108726..4eebff8dbb516 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -75,9 +75,10 @@ private[spark] class TaskSchedulerImpl( // TaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. - val stageIdToActiveTaskSet = new HashMap[Int, TaskSetManager] + val activeTaskSets = new HashMap[String, TaskSetManager] + val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]] - val taskIdToStageId = new HashMap[Long, Int] + val taskIdToTaskSetId = new HashMap[Long, String] val taskIdToExecutorId = new HashMap[Long, String] @volatile private var hasReceivedTask = false @@ -162,13 +163,17 @@ private[spark] class TaskSchedulerImpl( logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { val manager = createTaskSetManager(taskSet, maxTaskFailures) - stageIdToActiveTaskSet(taskSet.stageId) = manager - val stageId = taskSet.stageId - stageIdToActiveTaskSet.get(stageId).map { activeTaskSet => - throw new IllegalStateException( - s"Active taskSet with id already exists for stage $stageId: ${activeTaskSet.taskSet.id}") + activeTaskSets(taskSet.id) = manager + val stage = taskSet.stageId + val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager]) + stageTaskSets(taskSet.attempt) = manager + val conflictingTaskSet = stageTaskSets.exists { case (_, ts) => + ts.taskSet != taskSet && !ts.isZombie + } + if (conflictingTaskSet) { + throw new IllegalStateException(s"more than one active taskSet for stage $stage:" + + s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}") } - stageIdToActiveTaskSet(stageId) = manager schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) if (!isLocal && !hasReceivedTask) { @@ -198,7 +203,7 @@ private[spark] class TaskSchedulerImpl( override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) - stageIdToActiveTaskSet.get(stageId).map {tsm => + activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => // There are two possible cases here: // 1. The task set manager has been created and some tasks have been scheduled. // In this case, send a kill signal to the executors to kill the task and then abort @@ -220,7 +225,13 @@ private[spark] class TaskSchedulerImpl( * cleaned up. */ def taskSetFinished(manager: TaskSetManager): Unit = synchronized { - stageIdToActiveTaskSet -= manager.stageId + activeTaskSets -= manager.taskSet.id + taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage => + taskSetsForStage -= manager.taskSet.attempt + if (taskSetsForStage.isEmpty) { + taskSetsByStage -= manager.taskSet.stageId + } + } manager.parent.removeSchedulable(manager) logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s" .format(manager.taskSet.id, manager.parent.name)) @@ -241,7 +252,7 @@ private[spark] class TaskSchedulerImpl( for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { tasks(i) += task val tid = task.taskId - taskIdToStageId(tid) = taskSet.taskSet.stageId + taskIdToTaskSetId(tid) = taskSet.taskSet.id taskIdToExecutorId(tid) = execId executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK @@ -325,13 +336,13 @@ private[spark] class TaskSchedulerImpl( failedExecutor = Some(execId) } } - taskIdToStageId.get(tid) match { - case Some(stageId) => + taskIdToTaskSetId.get(tid) match { + case Some(taskSetId) => if (TaskState.isFinished(state)) { - taskIdToStageId.remove(tid) + taskIdToTaskSetId.remove(tid) taskIdToExecutorId.remove(tid) } - stageIdToActiveTaskSet.get(stageId).foreach { taskSet => + activeTaskSets.get(taskSetId).foreach { taskSet => if (state == TaskState.FINISHED) { taskSet.removeRunningTask(tid) taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) @@ -369,8 +380,8 @@ private[spark] class TaskSchedulerImpl( val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized { taskMetrics.flatMap { case (id, metrics) => - taskIdToStageId.get(id) - .flatMap(stageIdToActiveTaskSet.get) + taskIdToTaskSetId.get(id) + .flatMap(activeTaskSets.get) .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics)) } } @@ -403,9 +414,9 @@ private[spark] class TaskSchedulerImpl( def error(message: String) { synchronized { - if (stageIdToActiveTaskSet.nonEmpty) { + if (activeTaskSets.nonEmpty) { // Have each task set throw a SparkException with the error - for ((_, manager) <- stageIdToActiveTaskSet) { + for ((taskSetId, manager) <- activeTaskSets) { try { manager.abort(message) } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f2bd76aaef8ee..7c7f70d8a193b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -191,8 +191,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp for (task <- tasks.flatten) { val serializedTask = ser.serialize(task) if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { - val taskSetId = scheduler.taskIdToStageId(task.taskId) - scheduler.stageIdToActiveTaskSet.get(taskSetId).foreach { taskSet => + val taskSetId = scheduler.taskIdToTaskSetId(task.taskId) + scheduler.activeTaskSets.get(taskSetId).foreach { taskSet => try { var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " + diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 48eda6741b8d6..55be409afcf31 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -144,11 +144,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) } // OK to submit multiple if previous attempts are all zombie - taskScheduler.stageIdToActiveTaskSet(attempt1.stageId).isZombie = true + taskScheduler.activeTaskSets(attempt1.id).isZombie = true taskScheduler.submitTasks(attempt2) val attempt3 = new TaskSet(Array(new FakeTask(0)), 0, 2, 0, null) intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) } - taskScheduler.stageIdToActiveTaskSet(attempt2.stageId).isZombie = true + taskScheduler.activeTaskSets(attempt2.id).isZombie = true taskScheduler.submitTasks(attempt3) }