From bd7dcf1c43227532780705e9fc7b5a457e3be254 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 16 Jul 2015 23:31:02 +0800 Subject: [PATCH] Add an internal flag to Accumulable and send internal accumulator updates to the driver via heartbeats --- .../scala/org/apache/spark/Accumulators.scala | 64 ++++++++----------- .../scala/org/apache/spark/TaskContext.scala | 6 ++ .../org/apache/spark/TaskContextImpl.scala | 19 +++++- .../org/apache/spark/executor/Executor.scala | 6 +- .../apache/spark/executor/TaskMetrics.scala | 13 ++++ .../apache/spark/scheduler/DAGScheduler.scala | 3 +- .../spark/scheduler/DAGSchedulerEvent.scala | 2 +- .../org/apache/spark/scheduler/Task.scala | 7 +- .../apache/spark/scheduler/TaskResult.scala | 8 ++- .../spark/scheduler/TaskSetManagerSuite.scala | 7 +- 10 files changed, 80 insertions(+), 55 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 5a8d17bd99933..d7a90d0e350d3 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -20,7 +20,8 @@ package org.apache.spark import java.io.{ObjectInputStream, Serializable} import scala.collection.generic.Growable -import scala.collection.mutable.Map +import scala.collection.Map +import scala.collection.mutable import scala.ref.WeakReference import scala.reflect.ClassTag @@ -42,22 +43,37 @@ import org.apache.spark.util.Utils * @tparam R the full accumulated data (result type) * @tparam T partial data that can be added in */ -class Accumulable[R, T] ( +class Accumulable[R, T] private[spark] ( @transient initialValue: R, param: AccumulableParam[R, T], - val name: Option[String]) + val name: Option[String], + internal: Boolean) extends Serializable { + private[spark] def this( + @transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = { + this(initialValue, param, None, internal) + } + + def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = + this(initialValue, param, name, false) + def this(@transient initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None) val id: Long = Accumulators.newId - @transient private var value_ = initialValue // Current value on master + @volatile @transient private var value_ : R = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers private var deserialized = false - Accumulators.register(this, true) + Accumulators.register(this) + + /** + * Internal accumulators will be reported via heartbeats. For internal accumulators, `R` must be + * thread safe so that they can be reported correctly. + */ + private[spark] def isInternal: Boolean = internal /** * Add more data to this accumulator / accumulable @@ -132,7 +148,8 @@ class Accumulable[R, T] ( in.defaultReadObject() value_ = zero deserialized = true - Accumulators.register(this, false) + val taskContext = TaskContext.get() + taskContext.registerAccumulator(this) } override def toString: String = if (value_ == null) "null" else value_.toString @@ -284,16 +301,7 @@ private[spark] object Accumulators extends Logging { * It keeps weak references to these objects so that accumulators can be garbage-collected * once the RDDs and user-code that reference them are cleaned up. 
*/ - val originals = Map[Long, WeakReference[Accumulable[_, _]]]() - - /** - * This thread-local map holds per-task copies of accumulators; it is used to collect the set - * of accumulator updates to send back to the driver when tasks complete. After tasks complete, - * this map is cleared by `Accumulators.clear()` (see Executor.scala). - */ - private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() { - override protected def initialValue() = Map[Long, Accumulable[_, _]]() - } + val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]() private var lastId: Long = 0 @@ -302,19 +310,8 @@ private[spark] object Accumulators extends Logging { lastId } - def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { - if (original) { - originals(a.id) = new WeakReference[Accumulable[_, _]](a) - } else { - localAccums.get()(a.id) = a - } - } - - // Clear the local (non-original) accumulators for the current thread - def clear() { - synchronized { - localAccums.get.clear() - } + def register(a: Accumulable[_, _]): Unit = synchronized { + originals(a.id) = new WeakReference[Accumulable[_, _]](a) } def remove(accId: Long) { @@ -323,15 +320,6 @@ private[spark] object Accumulators extends Logging { } } - // Get the values of the local accumulators for the current thread (by ID) - def values: Map[Long, Any] = synchronized { - val ret = Map[Long, Any]() - for ((id, accum) <- localAccums.get) { - ret(id) = accum.localValue - } - return ret - } - // Add values to the original accumulators with some given IDs def add(values: Map[Long, Any]): Unit = synchronized { for ((id, value) <- values) { diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 248339148d9b7..11f19489fa725 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -152,4 +152,10 @@ abstract class TaskContext extends Serializable { * Returns the manager for this task's managed memory. 
*/ private[spark] def taskMemoryManager(): TaskMemoryManager + + private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit + + private[spark] def collectInternalAccumulators(): Map[Long, Any] + + private[spark] def collectAccumulators(): Map[Long, Any] } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index b4d572cb52313..6e394f1b12445 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -17,12 +17,12 @@ package org.apache.spark +import scala.collection.mutable.{ArrayBuffer, HashMap} + import org.apache.spark.executor.TaskMetrics import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} -import scala.collection.mutable.ArrayBuffer - private[spark] class TaskContextImpl( val stageId: Int, val partitionId: Int, @@ -94,5 +94,18 @@ private[spark] class TaskContextImpl( override def isRunningLocally(): Boolean = runningLocally override def isInterrupted(): Boolean = interrupted -} + @transient private val accumulators = new HashMap[Long, Accumulable[_, _]] + + private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized { + accumulators(a.id) = a + } + + private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized { + accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap + } + + private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized { + accumulators.mapValues(_.localValue).toMap + } +} diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 1a02051c87f19..9087debde8c41 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -209,7 +209,7 @@ private[spark] class Executor( // Run the actual task and measure its runtime. 
taskStart = System.currentTimeMillis() - val value = try { + val (value, accumUpdates) = try { task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) } finally { // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread; @@ -247,7 +247,6 @@ private[spark] class Executor( m.setResultSerializationTime(afterSerialization - beforeSerialization) } - val accumUpdates = Accumulators.values val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) val serializedDirectResult = ser.serialize(directResult) val resultSize = serializedDirectResult.limit @@ -314,8 +313,6 @@ private[spark] class Executor( env.shuffleMemoryManager.releaseMemoryForThisThread() // Release memory used by this thread for unrolling blocks env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() - // Release memory used by this thread for accumulators - Accumulators.clear() runningTasks.remove(taskId) } } @@ -424,6 +421,7 @@ private[spark] class Executor( metrics.updateShuffleReadMetrics() metrics.updateInputMetrics() metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) + metrics.updateAccumulators() if (isLocal) { // JobProgressListener will hold an reference of it during diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index e80feeeab4142..93d44492c0e3f 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -223,6 +223,19 @@ class TaskMetrics extends Serializable { // overhead. _hostname = TaskMetrics.getCachedHostName(_hostname) } + + private var _accumulatorUpdates: Map[Long, Any] = Map.empty + @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null + + private[spark] def updateAccumulators(): Unit = synchronized { + _accumulatorUpdates = _accumulatorsUpdater() + } + + def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates + + private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = { + _accumulatorsUpdater = accumulatorsUpdater + } } private[spark] object TaskMetrics { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f3d87ee5c4fd1..09ee5fccbedab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -22,7 +22,8 @@ import java.util.Properties import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} +import scala.collection.Map +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack} import scala.concurrent.duration._ import scala.language.existentials import scala.language.postfixOps diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 2b6f7e4205c32..a927eae2b04be 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import scala.collection.mutable.Map +import scala.collection.Map import scala.language.existentials import org.apache.spark._ diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 15101c64f0503..12bc4ee1e17bc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -45,6 +45,8 @@ import org.apache.spark.util.Utils */ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { + type AccumulatorUpdates = Map[Long, Any] + /** * Called by [[Executor]] to run this task. * @@ -52,7 +54,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex * @param attemptNumber how many times this task has been attempted (0 for the first attempt) * @return the result of the task */ - final def run(taskAttemptId: Long, attemptNumber: Int): T = { + final def run(taskAttemptId: Long, attemptNumber: Int): (T, AccumulatorUpdates) = { context = new TaskContextImpl( stageId = stageId, partitionId = partitionId, @@ -62,12 +64,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex runningLocally = false) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) + context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators) taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) } try { - runTask(context) + (runTask(context), context.collectAccumulators()) } finally { context.markTaskCompleted() TaskContext.unset() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 8b2a742b96988..b82c7f3fa54f8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -20,7 +20,8 @@ package org.apache.spark.scheduler import java.io._ import java.nio.ByteBuffer -import scala.collection.mutable.Map +import scala.collection.Map +import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.executor.TaskMetrics @@ -69,10 +70,11 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long if (numUpdates == 0) { accumUpdates = null } else { - accumUpdates = Map() + val _accumUpdates = mutable.Map[Long, Any]() for (i <- 0 until numUpdates) { - accumUpdates(in.readLong()) = in.readObject() + _accumUpdates(in.readLong()) = in.readObject() } + accumUpdates = _accumUpdates } metrics = in.readObject().asInstanceOf[TaskMetrics] valueObjectDeserialized = false diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 0060f3396dcde..cdae0d83d01dc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.scheduler import java.util.Random -import scala.collection.mutable.ArrayBuffer +import scala.collection.Map import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.util.ManualClock class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) extends DAGScheduler(sc) { @@ -37,7 +38,7 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) task: Task[_], 
reason: TaskEndReason, result: Any, - accumUpdates: mutable.Map[Long, Any], + accumUpdates: Map[Long, Any], taskInfo: TaskInfo, taskMetrics: TaskMetrics) { taskScheduler.endedTasks(taskInfo.index) = reason
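
Note (illustrative, not part of the patch): the sketch below is a minimal, self-contained model of the data flow this change introduces. The `Sketch*` classes are stand-ins written for this note, not Spark's real `Accumulable`, `TaskContextImpl`, or `TaskMetrics`; they only mirror the shape of the new APIs (the `internal` flag, `registerAccumulator`, `collectInternalAccumulators`/`collectAccumulators`, and the `setAccumulatorsUpdater`/`updateAccumulators`/`accumulatorUpdates` hook) so the register-then-report path is easy to follow in isolation.

// Illustrative sketch only -- simplified stand-ins, not Spark classes.
import scala.collection.mutable

// Stand-in for Accumulable with the new `internal` flag.
class SketchAccumulable[T](initial: T, add: (T, T) => T, val isInternal: Boolean) {
  val id: Long = SketchAccumulable.newId()
  @volatile private var value: T = initial
  def +=(term: T): Unit = synchronized { value = add(value, term) }
  def localValue: T = value
}

object SketchAccumulable {
  private var lastId = 0L
  def newId(): Long = synchronized { lastId += 1; lastId }
}

// Stand-in for TaskContextImpl: a per-task registry of accumulators.
class SketchTaskContext {
  private val accumulators = mutable.HashMap.empty[Long, SketchAccumulable[_]]

  def registerAccumulator(a: SketchAccumulable[_]): Unit = synchronized {
    accumulators(a.id) = a
  }

  // Only internal accumulators ride on heartbeats (so they must be thread-safe).
  def collectInternalAccumulators(): Map[Long, Any] = synchronized {
    accumulators.iterator
      .filter { case (_, acc) => acc.isInternal }
      .map { case (id, acc) => id -> (acc.localValue: Any) }
      .toMap
  }

  // Everything, internal and user accumulators alike, goes back with the task result.
  def collectAccumulators(): Map[Long, Any] = synchronized {
    accumulators.iterator
      .map { case (id, acc) => id -> (acc.localValue: Any) }
      .toMap
  }
}

// Stand-in for the TaskMetrics hook: the task installs an updater closure,
// and the heartbeat path snapshots it before the metrics are sent to the driver.
class SketchTaskMetrics {
  @volatile private var _accumulatorUpdates: Map[Long, Any] = Map.empty
  private var accumulatorsUpdater: () => Map[Long, Any] = () => Map.empty

  def setAccumulatorsUpdater(updater: () => Map[Long, Any]): Unit = {
    accumulatorsUpdater = updater
  }

  // Called on the heartbeat path, mirroring metrics.updateAccumulators() in the patch.
  def updateAccumulators(): Unit = synchronized {
    _accumulatorUpdates = accumulatorsUpdater()
  }

  def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates
}

object SketchFlow {
  def main(args: Array[String]): Unit = {
    val context = new SketchTaskContext
    val metrics = new SketchTaskMetrics

    // What Task.run does in the patch: wire the metrics updater to the internal-only collector.
    metrics.setAccumulatorsUpdater(() => context.collectInternalAccumulators())

    // Hypothetical accumulators: one internal, one user-defined.
    val internalBytesRead = new SketchAccumulable[Long](0L, _ + _, isInternal = true)
    val userRecordCount   = new SketchAccumulable[Long](0L, _ + _, isInternal = false)
    context.registerAccumulator(internalBytesRead)
    context.registerAccumulator(userRecordCount)

    internalBytesRead += 4096L
    userRecordCount += 7L

    // Heartbeat: snapshot only the internal accumulators.
    metrics.updateAccumulators()
    println(s"heartbeat carries:   ${metrics.accumulatorUpdates()}")

    // Task completion: Task.run now returns (result, context.collectAccumulators()).
    println(s"task result carries: ${context.collectAccumulators()}")
  }
}

The split follows the constraint stated in the patch's own doc comment: only internal accumulators are reported via heartbeats, and their value type must be thread-safe because the heartbeat thread reads them while the task thread updates them. User accumulators keep their existing semantics and are delivered once, with the DirectTaskResult, now as the second element of the pair returned by Task.run instead of via the removed thread-local map in the Accumulators object.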