Skip to content

Commit

Permalink
simplify going from taskId to taskSetMgr
Browse files: browse the repository at this point in the history
  • Loading branch information
squito committed Jul 17, 2015
1 parent e43ac25 commit 584acd4
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(

// TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]

val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]
private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
val taskIdToExecutorId = new HashMap[Long, String]

@volatile private var hasReceivedTask = false
Expand Down Expand Up @@ -252,8 +252,7 @@ private[spark] class TaskSchedulerImpl(
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToStageIdAndAttempt(tid) =
(taskSet.taskSet.stageId, taskSet.taskSet.stageAttemptId)
taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
executorsByHost(host) += execId
availableCpus(i) -= CPUS_PER_TASK
Expand Down Expand Up @@ -337,10 +336,10 @@ private[spark] class TaskSchedulerImpl(
failedExecutor = Some(execId)
}
}
taskSetManagerForTask(tid) match {
taskIdToTaskSetManager.get(tid) match {
case Some(taskSet) =>
if (TaskState.isFinished(state)) {
taskIdToStageIdAndAttempt.remove(tid)
taskIdToTaskSetManager.remove(tid)
taskIdToExecutorId.remove(tid)
}
if (state == TaskState.FINISHED) {
Expand Down Expand Up @@ -379,12 +378,8 @@ private[spark] class TaskSchedulerImpl(

val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
taskMetrics.flatMap { case (id, metrics) =>
for {
(stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
attempts <- taskSetsByStageIdAndAttempt.get(stageId)
taskSetMgr <- attempts.get(stageAttemptId)
} yield {
(id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
taskIdToTaskSetManager.get(id).map { taskSetMgr =>
(id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
}
}
}
Expand Down Expand Up @@ -543,12 +538,6 @@ private[spark] class TaskSchedulerImpl(

override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()

/**
 * Looks up the TaskSetManager associated with the given task id.
 *
 * The task id is first resolved to its (stageId, stageAttemptId) pair through
 * `taskIdToStageIdAndAttempt`; that pair is then handed to `taskSetManagerForAttempt`.
 *
 * @param taskId id of the task to look up
 * @return the result of `taskSetManagerForAttempt` for the task's stage and attempt, or None
 *         when `taskIdToStageIdAndAttempt` has no entry for `taskId`
 */
private[scheduler] def taskSetManagerForTask(taskId: Long): Option[TaskSetManager] = {
// flatMap collapses the two lookups: a missing task id or a missing manager both give None.
taskIdToStageIdAndAttempt.get(taskId).flatMap{ case (stageId, stageAttemptId) =>
taskSetManagerForAttempt(stageId, stageAttemptId)
}
}

private[scheduler] def taskSetManagerForAttempt(
stageId: Int,
stageAttemptId: Int): Option[TaskSetManager] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
for (task <- tasks.flatten) {
val serializedTask = ser.serialize(task)
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
scheduler.taskSetManagerForTask(task.taskId).foreach { taskSet =>
scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
"spark.akka.frameSize or using broadcast variables for large values."
msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
AkkaUtils.reservedSizeBytes)
taskSet.abort(msg)
taskSetMgr.abort(msg)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
taskScheduler.submitTasks(attempt2)
val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
assert(1 === taskDescriptions3.length)
val mgr = taskScheduler.taskSetManagerForTask(taskDescriptions3(0).taskId).get
val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get
assert(mgr.taskSet.stageAttemptId === 1)
}

Expand Down Expand Up @@ -232,7 +232,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
assert(10 === taskDescriptions3.length)

taskDescriptions3.foreach{ task =>
val mgr = taskScheduler.taskSetManagerForTask(task.taskId).get
val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get
assert(mgr.taskSet.stageAttemptId === 1)
}
}
Expand Down

0 comments on commit 584acd4

Please sign in to comment.