Skip to content

Commit

Permalink
add unit test and fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Rui Li committed Jun 10, 2014
1 parent fff4123 commit 99f843e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ private[spark] class TaskSetManager(

// Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
var myLocalityLevels = computeValidLocalityLevels()
val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
var localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level

// Delay scheduling variables: we keep track of our current locality level and the time we
// last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
Expand Down Expand Up @@ -753,5 +753,6 @@ private[spark] class TaskSetManager(
logInfo("Re-computing pending task lists.")
pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.filter(!newLocAvail(_))
myLocalityLevels = computeValidLocalityLevels()
localityWaits = myLocalityLevels.map(getLocalityWait)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
override def isExecutorAlive(execId: String): Boolean = executors.contains(execId)

override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host)

def addExecutor(newExecutors: (String, String)*) {
executors ++= newExecutors
}
}

class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
Expand Down Expand Up @@ -384,6 +388,43 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
assert(sched.taskSetsFailed.contains(taskSet.id))
}

test("new executors get added") {
sc = new SparkContext("local", "test")
val sched = new FakeTaskScheduler(sc)
val taskSet = FakeTask.createTaskSet(4,
Seq(TaskLocation("host1", "execA")),
Seq(TaskLocation("host1", "execB")),
Seq(TaskLocation("host2", "execC")),
Seq())
val clock = new FakeClock
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
// All tasks added to no-pref list since no preferred location is available
assert(manager.pendingTasksWithNoPrefs.size === 4)
// Only ANY is valid
assert(manager.myLocalityLevels.size === 1)
// Add a new executor
sched.addExecutor(("execD", "host1"))
manager.executorAdded()
// Task 0 and 1 should be removed from no-pref list
assert(manager.pendingTasksWithNoPrefs.size === 2)
// Valid locality should contain NODE_LOCAL and ANY
assert(manager.myLocalityLevels.size === 2)
// Offer host1, execD, at PROCESS_LOCAL level: task 0 should be chosen
// because PROCESS_LOCAL is not valid at the moment
assert(manager.resourceOffer("execD", "host1", PROCESS_LOCAL).get.index === 0)
// Add another executor
sched.addExecutor(("execC", "host2"))
manager.executorAdded()
// No-pref list now only contains task 3
assert(manager.pendingTasksWithNoPrefs.size === 1)
// Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY
assert(manager.myLocalityLevels.size === 3)
// Offer host2, execC, at PROCESS_LOCAL level: task 2 should be chosen
assert(manager.resourceOffer("execC", "host2", PROCESS_LOCAL).get.index === 2)
// Offer host1, execD again at PROCESS_LOCAL level: task 3 should be chosen
assert(manager.resourceOffer("execD", "host1", PROCESS_LOCAL).get.index === 3)
}

def createTaskResult(id: Int): DirectTaskResult[Int] = {
val valueSer = SparkEnv.get.serializer.newInstance()
new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)
Expand Down

0 comments on commit 99f843e

Please sign in to comment.