diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 1df616144743d..7c746faa386df 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -154,7 +154,7 @@ private[spark] class TaskSetManager(
 
   // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
   var myLocalityLevels = computeValidLocalityLevels()
-  val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
+  var localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
 
   // Delay scheduling variables: we keep track of our current locality level and the time we
   // last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
@@ -753,5 +753,6 @@ private[spark] class TaskSetManager(
     logInfo("Re-computing pending task lists.")
     pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.filter(!newLocAvail(_))
     myLocalityLevels = computeValidLocalityLevels()
+    localityWaits = myLocalityLevels.map(getLocalityWait)
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index c92b6dc96c8eb..2320b84653e3b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -77,6 +77,10 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
   override def isExecutorAlive(execId: String): Boolean = executors.contains(execId)
 
   override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host)
+
+  def addExecutor(newExecutors: (String, String)*) {
+    executors ++= newExecutors
+  }
 }
 
 class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
@@ -384,6 +388,43 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     assert(sched.taskSetsFailed.contains(taskSet.id))
   }
 
+  test("new executors get added") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeTaskScheduler(sc)
+    val taskSet = FakeTask.createTaskSet(4,
+      Seq(TaskLocation("host1", "execA")),
+      Seq(TaskLocation("host1", "execB")),
+      Seq(TaskLocation("host2", "execC")),
+      Seq())
+    val clock = new FakeClock
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
+    // All tasks added to no-pref list since no preferred location is available
+    assert(manager.pendingTasksWithNoPrefs.size === 4)
+    // Only ANY is valid
+    assert(manager.myLocalityLevels.size === 1)
+    // Add a new executor
+    sched.addExecutor(("execD", "host1"))
+    manager.executorAdded()
+    // Task 0 and 1 should be removed from no-pref list
+    assert(manager.pendingTasksWithNoPrefs.size === 2)
+    // Valid locality should contain NODE_LOCAL and ANY
+    assert(manager.myLocalityLevels.size === 2)
+    // Offer host1, execD, at PROCESS_LOCAL level: task 0 should be chosen
+    // because PROCESS_LOCAL is not valid at the moment
+    assert(manager.resourceOffer("execD", "host1", PROCESS_LOCAL).get.index === 0)
+    // Add another executor
+    sched.addExecutor(("execC", "host2"))
+    manager.executorAdded()
+    // No-pref list now only contains task 3
+    assert(manager.pendingTasksWithNoPrefs.size === 1)
+    // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY
+    assert(manager.myLocalityLevels.size === 3)
+    // Offer host2, execC, at PROCESS_LOCAL level: task 2 should be chosen
+    assert(manager.resourceOffer("execC", "host2", PROCESS_LOCAL).get.index === 2)
+    // Offer host1, execD again at PROCESS_LOCAL level: task 3 should be chosen
+    assert(manager.resourceOffer("execD", "host1", PROCESS_LOCAL).get.index === 3)
+  }
+
   def createTaskResult(id: Int): DirectTaskResult[Int] = {
     val valueSer = SparkEnv.get.serializer.newInstance()
     new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)