
Commit

wip
squito committed Jun 10, 2015
1 parent 8f7308f commit a9bf31f
Showing 12 changed files with 46 additions and 17 deletions.
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkException.scala
@@ -30,3 +30,13 @@ class SparkException(message: String, cause: Throwable)
*/
private[spark] class SparkDriverExecutionException(cause: Throwable)
extends SparkException("Execution error", cause)

/**
* Exception indicating an error internal to Spark -- it is in an inconsistent state, not due
* to any error by the user
*/
class SparkIllegalStateException(message: String, cause: Throwable)
extends SparkException(message, cause) {

def this(message: String) = this(message, null)
}
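
For context, a minimal sketch of how internal code might surface this new exception. The helper object and method below are illustrative only; nothing except SparkIllegalStateException itself comes from this change:

import org.apache.spark.SparkIllegalStateException

// Hypothetical helper: fail fast when Spark's own bookkeeping is inconsistent.
// Tripping this check indicates a Spark bug, not a user error.
object SchedulerInvariants {
  def check(condition: Boolean, message: => String): Unit = {
    if (!condition) {
      throw new SparkIllegalStateException(message)
    }
  }
}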
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -30,6 +30,7 @@ private[spark] class TaskContextImpl(
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
val runningLocally: Boolean = false,
val stageAttemptId: Int = 0, // for testing
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
with Logging {
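The default value is what keeps existing constructor calls compiling; only code that cares about stage attempts (Task.run below) needs to pass it. A rough sketch of a fully spelled-out construction, assuming the other parameters keep the names and defaults visible elsewhere in this diff; the null memory manager mirrors the JavaAPISuite change below:

// Sketch only -- assumes it lives in the org.apache.spark package, since
// TaskContextImpl is private[spark].
package org.apache.spark

import org.apache.spark.executor.TaskMetrics

object TaskContextImplSketch {
  def example(): TaskContextImpl = new TaskContextImpl(
    stageId = 0,
    stageAttemptId = 0,        // the new parameter; optional thanks to the default
    partitionId = 0,
    taskAttemptId = 0L,
    attemptNumber = 0,
    taskMemoryManager = null,  // tests pass null here, as in JavaAPISuite below
    runningLocally = false,
    taskMetrics = TaskMetrics.empty)
}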
20 changes: 13 additions & 7 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -834,7 +834,6 @@ class DAGScheduler(
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingTasks.clear()


// First figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = {
stage match {
@@ -894,7 +893,7 @@
partitionsToCompute.map { id =>
val locs = getPreferredLocs(stage.rdd, id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, taskBinary, part, locs)
new ShuffleMapTask(stage.id, stage.attemptId, taskBinary, part, locs)
}

case stage: ResultStage =>
@@ -903,7 +902,7 @@
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = getPreferredLocs(stage.rdd, p)
new ResultTask(stage.id, taskBinary, part, locs, id)
new ResultTask(stage.id, stage.attemptId, taskBinary, part, locs, id)
}
}

@@ -977,6 +976,7 @@
val stageId = task.stageId
val taskType = Utils.getFormattedClassName(task)

// REVIEWERS: does this need special handling for multiple completions of the same task?
outputCommitCoordinator.taskCompleted(stageId, task.partitionId,
event.taskInfo.attempt, event.reason)

@@ -1039,10 +1039,11 @@
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
} else {
shuffleStage.addOutputLoc(smt.partitionId, status)
}

if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
markStageAsFinished(shuffleStage)
logInfo("looking for newly runnable stages")
@@ -1106,9 +1107,14 @@
// multiple tasks running concurrently on different executors). In that case, it is possible
// the fetch failure has already been handled by the scheduler.
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
markStageAsFinished(failedStage, Some(failureMessage))
if (failedStage.attemptId - 1 > task.stageAttemptId) {
logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
s" ${task.stageAttemptId}, which has already failed")
} else {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
markStageAsFinished(failedStage, Some(failureMessage))
}
}

if (disallowStageRetryForTest) {
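Distilled out of the fetch-failure branch above, the new guard reduces to a single comparison between the stage's current attempt counter and the stage attempt the reporting task was launched for. The standalone function and names below are illustrative, not part of the patch:

// A fetch failure is only acted on if it comes from a recent enough attempt of
// the failed stage; reports from attempts that have already failed are logged
// and dropped.
def shouldActOnFetchFailure(failedStageAttemptId: Int, taskStageAttemptId: Int): Boolean =
  failedStageAttemptId - 1 <= taskStageAttemptId

// Example: with the stage on attempt 3, a failure from a task launched for
// attempt 1 is stale -- shouldActOnFetchFailure(3, 1) == false -- while a
// failure from attempt 2 or 3 is still handled.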
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
*/
private[spark] class ResultTask[T, U](
stageId: Int,
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient locs: Seq[TaskLocation],
val outputId: Int)
extends Task[U](stageId, partition.index) with Serializable {
extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {

@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
*/
private[spark] class ShuffleMapTask(
stageId: Int,
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation])
extends Task[MapStatus](stageId, partition.index) with Logging {
extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {

/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
this(0, null, new Partition { override def index: Int = 0 }, null)
this(0, 0, null, new Partition { override def index: Int = 0 }, null)
}

@transient private val preferredLocs: Seq[TaskLocation] = {
6 changes: 5 additions & 1 deletion core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -43,7 +43,10 @@ import org.apache.spark.util.Utils
* @param stageId id of the stage this task belongs to
* @param partitionId index of the number in the RDD
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
var partitionId: Int) extends Serializable {

/**
* Called by [[Executor]] to run this task.
@@ -55,6 +58,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
final def run(taskAttemptId: Long, attemptNumber: Int): T = {
context = new TaskContextImpl(
stageId = stageId,
stageAttemptId = stageAttemptId,
partitionId = partitionId,
taskAttemptId = taskAttemptId,
attemptNumber = attemptNumber,
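Every concrete Task now threads the stage attempt through to this constructor; the test doubles further down (FakeTask, NotSerializableFakeTask, LargeTask) do the same with a hard-coded 0. A minimal, hypothetical subclass for illustration:

package org.apache.spark.scheduler

import org.apache.spark.TaskContext

// Hypothetical no-op task: a task is now identified by stage, stage attempt,
// and partition, while attemptNumber (per-task retries) still arrives via run().
private[spark] class NoopTask(stageId: Int, stageAttemptId: Int, partitionId: Int)
  extends Task[Unit](stageId, stageAttemptId, partitionId) {

  override def runTask(context: TaskContext): Unit = ()

  override def preferredLocations: Seq[TaskLocation] = Nil
}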
@@ -163,6 +163,12 @@ private[spark] class TaskSchedulerImpl(
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
val taskSetsPerStage = activeTaskSets.values.filterNot(_.isZombie).groupBy(_.stageId)
taskSetsPerStage.foreach { case (stage, taskSets) =>
if (taskSets.size > 1) {
throw new SparkIllegalStateException("more than one active taskSet for stage " + stage)
}
}
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

if (!isLocal && !hasReceivedTask) {
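Outside the scheduler, the new check in TaskSchedulerImpl boils down to the grouping below; ActiveSet and the enclosing names are illustrative stand-ins for the TaskSetManager fields the real code reads:

import org.apache.spark.SparkIllegalStateException

object TaskSetInvariant {
  // Stand-in for the two fields the check inspects on each active TaskSetManager.
  case class ActiveSet(stageId: Int, isZombie: Boolean)

  // Distilled version of the new check: at most one non-zombie task set may be
  // active per stage; two live attempts of the same stage means the DAGScheduler
  // and the TaskScheduler have diverged.
  def assertSingleActiveTaskSetPerStage(active: Seq[ActiveSet]): Unit = {
    val taskSetsPerStage = active.filterNot(_.isZombie).groupBy(_.stageId)
    taskSetsPerStage.foreach { case (stage, taskSets) =>
      if (taskSets.size > 1) {
        throw new SparkIllegalStateException("more than one active taskSet for stage " + stage)
      }
    }
  }
}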
2 changes: 1 addition & 1 deletion core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1011,7 +1011,7 @@ public void persist() {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics());
TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, 0, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}

@@ -19,7 +19,7 @@ package org.apache.spark.scheduler

import org.apache.spark.TaskContext

class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
override def runTask(context: TaskContext): Int = 0

override def preferredLocations: Seq[TaskLocation] = prefLocs
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
extends Task[Array[Byte]](stageId, 0) {
extends Task[Array[Byte]](stageId, 0, 0) {

override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
@@ -41,8 +41,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
}
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val task = new ResultTask[String, String](
0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
val task = new ResultTask[String, String](0, 0,
sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
intercept[RuntimeException] {
task.run(0, 0)
}
@@ -135,7 +135,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
/**
* A Task implementation that results in a large serialized task.
*/
class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) {
class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) {
val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
val random = new Random(0)
random.nextBytes(randomBuffer)
