diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0a9181345add9..014d3c126f70b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -76,6 +76,7 @@ private[spark] class TaskSchedulerImpl( // TaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. val activeTaskSets = new HashMap[String, TaskSetManager] + val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]] val taskIdToTaskSetId = new HashMap[Long, String] val taskIdToExecutorId = new HashMap[Long, String] @@ -164,13 +165,14 @@ private[spark] class TaskSchedulerImpl( val manager = createTaskSetManager(taskSet, maxTaskFailures) activeTaskSets(taskSet.id) = manager val stage = taskSet.stageId - val conflictingTaskSet = activeTaskSets.exists { case (id, ts) => - // if the id matches, it really should be the same taskSet, but in some unit tests - // we add new taskSets with the same id - id != taskSet.id && !ts.isZombie && ts.stageId == stage + val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager]) + stageTaskSets(taskSet.attempt) = manager + val conflictingTaskSet = stageTaskSets.exists { case (_, ts) => + ts.taskSet != taskSet && !ts.isZombie } if (conflictingTaskSet) { - throw new SparkIllegalStateException(s"more than one active taskSet for stage $stage") + throw new SparkIllegalStateException(s"more than one active taskSet for stage $stage:" + + s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}") } schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) @@ -224,6 +226,12 @@ private[spark] class TaskSchedulerImpl( */ def taskSetFinished(manager: TaskSetManager): Unit = synchronized { activeTaskSets -= manager.taskSet.id + taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage => + taskSetsForStage -= manager.taskSet.attempt + if (taskSetsForStage.isEmpty) { + taskSetsByStage -= manager.taskSet.stageId + } + } manager.parent.removeSchedulable(manager) logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s" .format(manager.taskSet.id, manager.parent.name))