From ada7726d4599875355114e59a7d5444f05df4685 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 1 Jul 2015 15:18:28 -0500 Subject: [PATCH] reviewer feedback --- .../apache/spark/scheduler/TaskSchedulerImpl.scala | 13 ++++++++----- .../apache/spark/scheduler/DAGSchedulerSuite.scala | 14 ++++++++------ .../spark/scheduler/TaskSchedulerImplSuite.scala | 10 +++++++++- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 60173e21b64a8..0a9181345add9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -163,11 +163,14 @@ private[spark] class TaskSchedulerImpl( this.synchronized { val manager = createTaskSetManager(taskSet, maxTaskFailures) activeTaskSets(taskSet.id) = manager - val taskSetsPerStage = activeTaskSets.values.filterNot(_.isZombie).groupBy(_.stageId) - taskSetsPerStage.foreach { case (stage, taskSets) => - if (taskSets.size > 1) { - throw new SparkIllegalStateException("more than one active taskSet for stage " + stage) - } + val stage = taskSet.stageId + val conflictingTaskSet = activeTaskSets.exists { case (id, ts) => + // if the id matches, it really should be the same taskSet, but in some unit tests + // we add new taskSets with the same id + id != taskSet.id && !ts.isZombie && ts.stageId == stage + } + if (conflictingTaskSet) { + throw new SparkIllegalStateException(s"more than one active taskSet for stage $stage") } schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index fd60bf5782590..38bf4f79d6bf0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -554,8 +554,10 @@ class DAGSchedulerSuite assert(sparkListener.failedStages.size == 1) } - /** This tests the case where another FetchFailed comes in while the map stage is getting - * re-run. */ + /** + * This tests the case where another FetchFailed comes in while the map stage is getting + * re-run. + */ test("late fetch failures don't cause multiple concurrent attempts for the same map stage") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) @@ -607,15 +609,15 @@ class DAGSchedulerSuite createFakeTaskInfo(), null)) - // Another ResubmitFailedStages event should not result result in another attempt for the map + // Another ResubmitFailedStages event should not result in another attempt for the map // stage being run concurrently. + // NOTE: the actual ResubmitFailedStages may get called at any time during this, shouldn't + // affect anything -- our calling it just makes *SURE* it gets called between the desired event + // and our check. runEvent(ResubmitFailedStages) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(countSubmittedMapStageAttempts() === 2) - // NOTE: the actual ResubmitFailedStages may get called at any time during this, shouldn't - // effect anything -- our calling it just makes *SURE* it gets called between the desired event - // and our check. 
} /** This tests the case where a late FetchFailed comes in after the map stage has finished getting diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 07bdb84cd756f..8af47a0809e0d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -141,7 +141,15 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L val attempt1 = new TaskSet(Array(new FakeTask(0)), 0, 0, 0, null) val attempt2 = new TaskSet(Array(new FakeTask(0)), 0, 1, 0, null) taskScheduler.submitTasks(attempt1) - intercept[SparkIllegalStateException] { taskScheduler.submitTasks(attempt2)} + intercept[SparkIllegalStateException] { taskScheduler.submitTasks(attempt2) } + + // OK to submit multiple if previous attempts are all zombie + taskScheduler.activeTaskSets(attempt1.id).isZombie = true + taskScheduler.submitTasks(attempt2) + val attempt3 = new TaskSet(Array(new FakeTask(0)), 0, 2, 0, null) + intercept[SparkIllegalStateException] { taskScheduler.submitTasks(attempt3) } + taskScheduler.activeTaskSets(attempt2.id).isZombie = true + taskScheduler.submitTasks(attempt3) } }