Commit

Code review feedback.
rxin committed Jul 30, 2014
1 parent f8535dc commit f7364db
Showing 6 changed files with 30 additions and 18 deletions.
10 changes: 7 additions & 3 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -115,6 +115,10 @@ class DAGScheduler(
   private val dagSchedulerActorSupervisor =
     env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))
 
+  // A closure serializer that we reuse.
+  // This is only safe because DAGScheduler runs in a single thread.
+  private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
+
   private[scheduler] var eventProcessActor: ActorRef = _
 
   private def initializeEventProcessActor() {
@@ -722,9 +726,9 @@ class DAGScheduler(
       // For ResultTask, serialize and broadcast (rdd, func).
       val taskBinaryBytes: Array[Byte] =
         if (stage.isShuffleMap) {
-          Utils.serializeTaskClosure((stage.rdd, stage.shuffleDep.get) : AnyRef)
+          closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array()
         } else {
-          Utils.serializeTaskClosure((stage.rdd, stage.resultOfJob.get.func) : AnyRef)
+          closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array()
         }
       taskBinary = sc.broadcast(taskBinaryBytes)
     } catch {
@@ -765,7 +769,7 @@ class DAGScheduler(
       // We've already serialized RDDs and closures in taskBinary, but here we check for all other
       // objects such as Partition.
       try {
-        SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
+        closureSerializer.serialize(tasks.head)
       } catch {
         case e: NotSerializableException =>
           abortStage(stage, "Task not serializable: " + e.toString)
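The DAGScheduler changes above amount to one pattern: create a single closure SerializerInstance up front (safe only because the scheduler's event loop runs on one thread), serialize the stage's (rdd, dependency) or (rdd, func) pair once, and broadcast the resulting bytes so every task for the stage fetches the same block instead of carrying its own copy. A minimal sketch of that driver-side flow, under the assumption that SerializerInstance.serialize returns a heap-backed ByteBuffer (so .array() is valid); the object and method names here are illustrative, not Spark API:

```scala
import org.apache.spark.{ShuffleDependency, SparkContext, SparkEnv}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

// Sketch only: `TaskBinarySketch` and `serializeAndBroadcast` are made-up names.
object TaskBinarySketch {
  // One instance, reused; valid only if every call happens on a single thread,
  // since a SerializerInstance is not guaranteed to be thread-safe.
  private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()

  def serializeAndBroadcast(
      sc: SparkContext,
      rdd: RDD[_],
      dep: ShuffleDependency[_, _, _]): Broadcast[Array[Byte]] = {
    // Serialize the (rdd, dep) pair once, then broadcast the bytes so each
    // task for this stage reads one shared block rather than re-shipping it.
    val bytes: Array[Byte] = closureSerializer.serialize((rdd, dep): AnyRef).array()
    sc.broadcast(bytes)
  }
}
```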
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -32,15 +32,16 @@ import org.apache.spark.rdd.RDD
  *
  * @param stageId id of the stage this task belongs to
  * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each
- *                   partition of the given RDD.
+ *                   partition of the given RDD. Once deserialized, the type should be
+ *                   (RDD[T], (TaskContext, Iterator[T]) => U).
  * @param partition partition of the RDD this task is associated with
  * @param locs preferred task execution locations for locality scheduling
  * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
  *                 input RDD's partitions).
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
-    taskBinary: Broadcast[Array[Byte]], // (RDD[T], (TaskContext, Iterator[T]) => U)
+    taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient locs: Seq[TaskLocation],
     val outputId: Int)
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -33,13 +33,14 @@ import org.apache.spark.shuffle.ShuffleWriter
  * See [[org.apache.spark.scheduler.Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
- * @param taskBinary broadcast version of the RDD and the ShuffleDependency
+ * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized,
+ *                   the type should be (RDD[_], ShuffleDependency[_, _, _]).
  * @param partition partition of the RDD this task is associated with
  * @param locs preferred task execution locations for locality scheduling
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
-    taskBinary: Broadcast[Array[Byte]], // (RDD[_], ShuffleDependency[_, _, _])
+    taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient private var locs: Seq[TaskLocation])
   extends Task[MapStatus](stageId, partition.index) with Logging {
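On the executor side, ShuffleMapTask (and ResultTask, analogously) turns the broadcast bytes back into the pair documented above before running. A hedged sketch of that step, not the exact runTask body; it assumes the SerializerInstance.deserialize(ByteBuffer, ClassLoader) overload is available in this Spark version, and the helper name is illustrative:

```scala
import java.nio.ByteBuffer
import org.apache.spark.{ShuffleDependency, SparkEnv}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

object TaskBinaryReadSketch {
  // Unpack the broadcast task binary into the (rdd, dep) pair that
  // ShuffleMapTask documents, using the current thread's context class
  // loader so user classes shipped with the job resolve correctly.
  def readShuffleTaskBinary(
      taskBinary: Broadcast[Array[Byte]]): (RDD[_], ShuffleDependency[_, _, _]) = {
    val ser = SparkEnv.get.closureSerializer.newInstance()
    ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  }
}
```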
6 changes: 0 additions & 6 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -57,12 +57,6 @@ private[spark] object Utils extends Logging {
     new File(sparkHome + File.separator + "bin", which + suffix)
   }
 
-  /** Serialize an object using the closure serializer. */
-  def serializeTaskClosure[T: ClassTag](o: T): Array[Byte] = {
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    ser.serialize(o).array()
-  }
-
   /** Serialize an object using Java serialization */
   def serialize[T](o: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
14 changes: 14 additions & 0 deletions core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -164,6 +164,13 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     rddBuffer.clear()
     runGC()
     postGCTester.assertCleanup()
+
+    // Make sure the broadcasted task closure no longer exists after GC.
+    val taskClosureBroadcastId = broadcastIds.max + 1
+    assert(sc.env.blockManager.master.getMatchingBlockIds({
+      case BroadcastBlockId(`taskClosureBroadcastId`, _) => true
+      case _ => false
+    }, askSlaves = true).isEmpty)
   }
 
   test("automatically cleanup RDD + shuffle + broadcast in distributed mode") {
@@ -195,6 +202,13 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     rddBuffer.clear()
     runGC()
     postGCTester.assertCleanup()
+
+    // Make sure the broadcasted task closure no longer exists after GC.
+    val taskClosureBroadcastId = broadcastIds.max + 1
+    assert(sc.env.blockManager.master.getMatchingBlockIds({
+      case BroadcastBlockId(`taskClosureBroadcastId`, _) => true
+      case _ => false
+    }, askSlaves = true).isEmpty)
   }
 
   //------ Helper functions ------
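One Scala detail in the new ContextCleanerSuite assertions: the backticks around `taskClosureBroadcastId` in the pattern make the case compare against that existing val's value; without them the name would be a fresh binding that matches any block id. A self-contained sketch of the idiom (names are illustrative, unrelated to Spark):

```scala
object BacktickPatternSketch extends App {
  // With backticks the pattern compares against the existing val's value;
  // without them, `expected` would be a new binding that matches anything.
  val expected = 7L

  def isExpected(id: Long): Boolean = id match {
    case `expected` => true
    case _          => false
  }

  assert(isExpected(7L))   // matches: 7 == expected
  assert(!isExpected(8L))  // falls through to the wildcard case
}
```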
8 changes: 3 additions & 5 deletions core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -20,10 +20,7 @@ package org.apache.spark.scheduler
 import org.scalatest.FunSuite
 import org.scalatest.BeforeAndAfter
 
-import org.apache.spark.LocalSparkContext
-import org.apache.spark.Partition
-import org.apache.spark.SparkContext
-import org.apache.spark.TaskContext
+import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
 
@@ -39,9 +36,10 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
         sys.error("failed")
       }
     }
+    val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
     val func = (c: TaskContext, i: Iterator[String]) => i.next()
     val task = new ResultTask[String, String](
-      0, sc.broadcast(Utils.serializeTaskClosure((rdd, func))), rdd.partitions(0), Seq(), 0)
+      0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
     intercept[RuntimeException] {
       task.run(0)
     }
