better fix and simpler test case
squito committed Jun 10, 2015
1 parent 28d70aa · commit c443def
Showing 2 changed files with 60 additions and 56 deletions.
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala: 65 changes (34 additions, 31 deletions)
@@ -1102,44 +1102,47 @@ class DAGScheduler(
case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
  val failedStage = stageIdToStage(task.stageId)
  val mapStage = shuffleToMapStage(shuffleId)
  if (failedStage.attemptId - 1 > task.stageAttemptId) {
    logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
      s" ${task.stageAttemptId}, which has already failed")
  } else {
    // It is likely that we receive multiple FetchFailed for a single stage (because we have
    // multiple tasks running concurrently on different executors). In that case, it is possible
    // the fetch failure has already been handled by the scheduler.
    if (runningStages.contains(failedStage)) {
      logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
        s"due to a fetch failure from $mapStage (${mapStage.name})")
      markStageAsFinished(failedStage, Some(failureMessage))
    } else {
      logInfo(s"Ignoring fetch failure from $task as it's from $failedStage, " +
        s"which is no longer running")
    }

    if (disallowStageRetryForTest) {
      abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
    } else if (failedStages.isEmpty) {
      // Don't schedule an event to resubmit failed stages if failed isn't empty, because
      // in that case the event will already have been scheduled.
      // TODO: Cancel running tasks in the stage
      logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
        s"$failedStage (${failedStage.name}) due to fetch failure")
      messageScheduler.schedule(new Runnable {
        override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
      }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
    }
    failedStages += failedStage
    failedStages += mapStage
    // Mark the map whose fetch failed as broken in the map stage
    if (mapId != -1) {
      mapStage.removeOutputLoc(mapId, bmAddress)
      mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
    }

    // TODO: mark the executor as failed only if there were lots of fetch failures on it
    if (bmAddress != null) {
      handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
    }
  }

case commitDenied: TaskCommitDenied =>
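The heart of the scheduler change above is the guard on failedStage.attemptId - 1 > task.stageAttemptId: a FetchFailed reported by a task that belongs to an already-superseded stage attempt is only logged, and none of the bookkeeping in the else branch (marking the stage failed, scheduling ResubmitFailedStages, unregistering map outputs) runs for it. Below is a minimal, self-contained sketch of that comparison in isolation; the object and method names are illustrative only, not Spark APIs, with stageAttemptId standing in for failedStage.attemptId and taskAttemptId for task.stageAttemptId.

// Illustrative sketch only: not part of Spark, just the stale-attempt comparison
// from the FetchFailed case above, pulled out on its own.
object StaleFetchFailureCheck {

  // Returns true when the failure should be handled, false when it comes from an
  // attempt old enough to be ignored.
  def shouldHandle(stageAttemptId: Int, taskAttemptId: Int): Boolean =
    !(stageAttemptId - 1 > taskAttemptId)

  def main(args: Array[String]): Unit = {
    // With the stage on attempt 2, a failure reported by a task from attempt 0 is
    // stale and ignored, while one from attempt 1 is still handled.
    println(shouldHandle(stageAttemptId = 2, taskAttemptId = 0)) // false
    println(shouldHandle(stageAttemptId = 2, taskAttemptId = 1)) // true
  }
}

Read this way, fetch failures from stale attempts no longer mark the stage as failed again or schedule another ResubmitFailedStages event, which is what keeps a single stage from being retried concurrently.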
@@ -26,35 +26,33 @@ import org.apache.spark._

 class DAGSchedulerFailureRecoverySuite extends SparkFunSuite with Logging {
 
-  // TODO we should run this with a matrix of configurations: different shufflers,
-  // external shuffle service, etc. But that is really pushing the question of how to run
-  // such a long test ...
-
-  ignore("no concurrent retries for stage attempts (SPARK-7308)") {
-    // see SPARK-7308 for a detailed description of the conditions this is trying to recreate.
-    // note that this is somewhat convoluted for a test case, but isn't actually very unusual
-    // under a real workload. We only fail the first attempt of stage 2, but that
-    // could be enough to cause havoc.
-
-    (0 until 100).foreach { idx =>
-      println(new Date() + "\ttrial " + idx)
+  test("no concurrent retries for stage attempts (SPARK-8103)") {
+    // make sure that if we get fetch failures after the retry has started, we ignore them,
+    // and so don't end up submitting multiple concurrent attempts for the same stage
+
+    (0 until 20).foreach { idx =>
+      logInfo(new Date() + "\ttrial " + idx)
 
       val conf = new SparkConf().set("spark.executor.memory", "100m")
-      val clusterSc = new SparkContext("local-cluster[5,4,100]", "test-cluster", conf)
+      val clusterSc = new SparkContext("local-cluster[2,2,100]", "test-cluster", conf)
       val bms = ArrayBuffer[BlockManagerId]()
       val stageFailureCount = HashMap[Int, Int]()
+      val stageSubmissionCount = HashMap[Int, Int]()
       clusterSc.addSparkListener(new SparkListener {
         override def onBlockManagerAdded(bmAdded: SparkListenerBlockManagerAdded): Unit = {
           bms += bmAdded.blockManagerId
         }
 
+        override def onStageSubmitted(stageSubmited: SparkListenerStageSubmitted): Unit = {
+          val stage = stageSubmited.stageInfo.stageId
+          stageSubmissionCount(stage) = stageSubmissionCount.getOrElse(stage, 0) + 1
+        }
+
 
         override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
           if (stageCompleted.stageInfo.failureReason.isDefined) {
             val stage = stageCompleted.stageInfo.stageId
             stageFailureCount(stage) = stageFailureCount.getOrElse(stage, 0) + 1
-            val reason = stageCompleted.stageInfo.failureReason.get
-            println("stage " + stage + " failed: " + stageFailureCount(stage))
           }
         }
       })
@@ -66,34 +64,37 @@ class DAGSchedulerFailureRecoverySuite extends SparkFunSuite with Logging {
         // to avoid broadcast failures
         val someBlockManager = bms.filter{!_.isDriver}(0)
 
-        val shuffled = rawData.groupByKey(100).mapPartitionsWithIndex { case (idx, itr) =>
+        val shuffled = rawData.groupByKey(20).mapPartitionsWithIndex { case (idx, itr) =>
+          // we want one failure quickly, and more failures after stage 0 has finished its
+          // second attempt
           val stageAttemptId = TaskContext.get().asInstanceOf[TaskContextImpl].stageAttemptId
           if (stageAttemptId == 0) {
             if (idx == 0) {
               throw new FetchFailedException(someBlockManager, 0, 0, idx,
                 cause = new RuntimeException("simulated fetch failure"))
-            } else if (idx > 0 && math.random < 0.2) {
-              Thread.sleep(5000)
+            } else if (idx == 1) {
+              Thread.sleep(2000)
+              throw new FetchFailedException(someBlockManager, 0, 0, idx,
+                cause = new RuntimeException("simulated fetch failure"))
             } else {
-              // want to make sure plenty of these finish after task 0 fails, and some even finish
-              // after the previous stage is retried and this stage retry is started
               Thread.sleep((500 + math.random * 5000).toLong)
             }
           } else {
-            // just to make sure the second attempt doesn't finish before we trigger more failures
-            // from the first attempt
             Thread.sleep(2000)
           }
           itr.map { x => ((x._1 + 5) % 100) -> x._2 }
         }
-        val data = shuffled.mapPartitions { itr => itr.flatMap(_._2) }.collect()
+        val data = shuffled.mapPartitions { itr =>
+          itr.flatMap(_._2)
+        }.cache().collect()
         val count = data.size
         assert(count === 1e6.toInt)
         assert(data.toSet === (1 to 1e6.toInt).toSet)
 
         assert(stageFailureCount.getOrElse(1, 0) === 0)
-        assert(stageFailureCount.getOrElse(2, 0) == 1)
-        assert(stageFailureCount.getOrElse(3, 0) == 0)
+        assert(stageFailureCount.getOrElse(2, 0) === 1)
+        assert(stageSubmissionCount.getOrElse(1, 0) <= 2)
+        assert(stageSubmissionCount.getOrElse(2, 0) === 2)
       } finally {
         clusterSc.stop()
       }
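The simplified test relies on a listener-based counting pattern: every stage submission and every failed stage completion is tallied, and the assertions at the end check that the shuffle stage was resubmitted exactly once rather than being retried concurrently. Below is a minimal sketch of just the counting part, assuming a plain local SparkContext and a toy job with no injected failures, so every stage should be submitted exactly once; the object name and the job itself are illustrative and not part of the suite above.

import scala.collection.mutable
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}

// Standalone sketch of the submission-counting listener used in the suite above.
object StageSubmissionCounter {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("stage-submission-counter"))
    val stageSubmissionCount = mutable.HashMap[Int, Int]()
    sc.addSparkListener(new SparkListener {
      override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
        val stage = stageSubmitted.stageInfo.stageId
        stageSubmissionCount(stage) = stageSubmissionCount.getOrElse(stage, 0) + 1
      }
    })
    // a simple two-stage shuffle job with no injected failures
    sc.parallelize(1 to 1000).map(x => (x % 10, x)).groupByKey().count()
    sc.stop()
    // with no fetch failures, each stage is expected to have been submitted exactly once
    stageSubmissionCount.foreach { case (stage, n) =>
      println(s"stage $stage submitted $n time(s)")
    }
  }
}

In the failure-recovery test above, the same counters are what make the SPARK-8103 behaviour observable: the submission count for the retried stage must be exactly 2 (the original attempt plus one retry), never more.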
