Skip to content

Commit

Permalink
Revert "Index active task sets by stage Id rather than by task set id"
Browse files Browse the repository at this point in the history
This reverts commit baf46e1.
  • Loading branch information
squito committed Jul 14, 2015
1 parent f025154 commit c0d4d90
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ private[spark] class TaskSchedulerImpl(

// TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
val stageIdToActiveTaskSet = new HashMap[Int, TaskSetManager]
val activeTaskSets = new HashMap[String, TaskSetManager]
val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]]

val taskIdToStageId = new HashMap[Long, Int]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToExecutorId = new HashMap[Long, String]

@volatile private var hasReceivedTask = false
Expand Down Expand Up @@ -162,13 +163,17 @@ private[spark] class TaskSchedulerImpl(
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
stageIdToActiveTaskSet(taskSet.stageId) = manager
val stageId = taskSet.stageId
stageIdToActiveTaskSet.get(stageId).map { activeTaskSet =>
throw new IllegalStateException(
s"Active taskSet with id already exists for stage $stageId: ${activeTaskSet.taskSet.id}")
activeTaskSets(taskSet.id) = manager
val stage = taskSet.stageId
val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
stageTaskSets(taskSet.attempt) = manager
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
ts.taskSet != taskSet && !ts.isZombie
}
if (conflictingTaskSet) {
throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
}
stageIdToActiveTaskSet(stageId) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

if (!isLocal && !hasReceivedTask) {
Expand Down Expand Up @@ -198,7 +203,7 @@ private[spark] class TaskSchedulerImpl(

override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
stageIdToActiveTaskSet.get(stageId).map {tsm =>
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task and then abort
Expand All @@ -220,7 +225,13 @@ private[spark] class TaskSchedulerImpl(
* cleaned up.
*/
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
stageIdToActiveTaskSet -= manager.stageId
activeTaskSets -= manager.taskSet.id
taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
taskSetsForStage -= manager.taskSet.attempt
if (taskSetsForStage.isEmpty) {
taskSetsByStage -= manager.taskSet.stageId
}
}
manager.parent.removeSchedulable(manager)
logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
.format(manager.taskSet.id, manager.parent.name))
Expand All @@ -241,7 +252,7 @@ private[spark] class TaskSchedulerImpl(
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToStageId(tid) = taskSet.taskSet.stageId
taskIdToTaskSetId(tid) = taskSet.taskSet.id
taskIdToExecutorId(tid) = execId
executorsByHost(host) += execId
availableCpus(i) -= CPUS_PER_TASK
Expand Down Expand Up @@ -325,13 +336,13 @@ private[spark] class TaskSchedulerImpl(
failedExecutor = Some(execId)
}
}
taskIdToStageId.get(tid) match {
case Some(stageId) =>
taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) =>
if (TaskState.isFinished(state)) {
taskIdToStageId.remove(tid)
taskIdToTaskSetId.remove(tid)
taskIdToExecutorId.remove(tid)
}
stageIdToActiveTaskSet.get(stageId).foreach { taskSet =>
activeTaskSets.get(taskSetId).foreach { taskSet =>
if (state == TaskState.FINISHED) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
Expand Down Expand Up @@ -369,8 +380,8 @@ private[spark] class TaskSchedulerImpl(

val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
taskMetrics.flatMap { case (id, metrics) =>
taskIdToStageId.get(id)
.flatMap(stageIdToActiveTaskSet.get)
taskIdToTaskSetId.get(id)
.flatMap(activeTaskSets.get)
.map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
}
}
Expand Down Expand Up @@ -403,9 +414,9 @@ private[spark] class TaskSchedulerImpl(

def error(message: String) {
synchronized {
if (stageIdToActiveTaskSet.nonEmpty) {
if (activeTaskSets.nonEmpty) {
// Have each task set throw a SparkException with the error
for ((_, manager) <- stageIdToActiveTaskSet) {
for ((taskSetId, manager) <- activeTaskSets) {
try {
manager.abort(message)
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
for (task <- tasks.flatten) {
val serializedTask = ser.serialize(task)
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
val taskSetId = scheduler.taskIdToStageId(task.taskId)
scheduler.stageIdToActiveTaskSet.get(taskSetId).foreach { taskSet =>
val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }

// OK to submit multiple if previous attempts are all zombie
taskScheduler.stageIdToActiveTaskSet(attempt1.stageId).isZombie = true
taskScheduler.activeTaskSets(attempt1.id).isZombie = true
taskScheduler.submitTasks(attempt2)
val attempt3 = new TaskSet(Array(new FakeTask(0)), 0, 2, 0, null)
intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
taskScheduler.stageIdToActiveTaskSet(attempt2.stageId).isZombie = true
taskScheduler.activeTaskSets(attempt2.id).isZombie = true
taskScheduler.submitTasks(attempt3)
}

Expand Down

0 comments on commit c0d4d90

Please sign in to comment.