Merge remote-tracking branch 'upstream/master' into UDAF

yhuai committed Jul 17, 2015
2 parents 6edb5ac + 812b63b commit b2e358e

Showing 44 changed files with 908 additions and 346 deletions.
68 changes: 30 additions & 38 deletions core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -20,7 +20,8 @@ package org.apache.spark
import java.io.{ObjectInputStream, Serializable}

import scala.collection.generic.Growable
import scala.collection.mutable.Map
import scala.collection.Map
import scala.collection.mutable
import scala.ref.WeakReference
import scala.reflect.ClassTag

@@ -39,25 +40,44 @@ import org.apache.spark.util.Utils
* @param initialValue initial value of accumulator
* @param param helper object defining how to add elements of type `R` and `T`
* @param name human-readable name for use in Spark's web UI
* @param internal if this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported
* to the driver via heartbeats. For internal [[Accumulable]]s, `R` must be
* thread safe so that they can be reported correctly.
* @tparam R the full accumulated data (result type)
* @tparam T partial data that can be added in
*/
class Accumulable[R, T] (
class Accumulable[R, T] private[spark] (
@transient initialValue: R,
param: AccumulableParam[R, T],
val name: Option[String])
val name: Option[String],
internal: Boolean)
extends Serializable {

private[spark] def this(
@transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = {
this(initialValue, param, None, internal)
}

def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) =
this(initialValue, param, name, false)

def this(@transient initialValue: R, param: AccumulableParam[R, T]) =
this(initialValue, param, None)

val id: Long = Accumulators.newId

@transient private var value_ = initialValue // Current value on master
@volatile @transient private var value_ : R = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
private var deserialized = false

Accumulators.register(this, true)
Accumulators.register(this)

/**
* If this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported to the driver
* via heartbeats. For internal [[Accumulable]]s, `R` must be thread safe so that they can be
* reported correctly.
*/
private[spark] def isInternal: Boolean = internal

/**
* Add more data to this accumulator / accumulable
@@ -132,7 +152,8 @@ class Accumulable[R, T] (
in.defaultReadObject()
value_ = zero
deserialized = true
Accumulators.register(this, false)
val taskContext = TaskContext.get()
taskContext.registerAccumulator(this)
}

override def toString: String = if (value_ == null) "null" else value_.toString
@@ -284,16 +305,7 @@ private[spark] object Accumulators extends Logging {
* It keeps weak references to these objects so that accumulators can be garbage-collected
* once the RDDs and user-code that reference them are cleaned up.
*/
val originals = Map[Long, WeakReference[Accumulable[_, _]]]()

/**
* This thread-local map holds per-task copies of accumulators; it is used to collect the set
* of accumulator updates to send back to the driver when tasks complete. After tasks complete,
* this map is cleared by `Accumulators.clear()` (see Executor.scala).
*/
private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
override protected def initialValue() = Map[Long, Accumulable[_, _]]()
}
val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()

private var lastId: Long = 0

@@ -302,19 +314,8 @@ private[spark] object Accumulators extends Logging {
lastId
}

def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
if (original) {
originals(a.id) = new WeakReference[Accumulable[_, _]](a)
} else {
localAccums.get()(a.id) = a
}
}

// Clear the local (non-original) accumulators for the current thread
def clear() {
synchronized {
localAccums.get.clear()
}
def register(a: Accumulable[_, _]): Unit = synchronized {
originals(a.id) = new WeakReference[Accumulable[_, _]](a)
}

def remove(accId: Long) {
@@ -323,15 +324,6 @@
}
}

// Get the values of the local accumulators for the current thread (by ID)
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
for ((id, accum) <- localAccums.get) {
ret(id) = accum.localValue
}
return ret
}

// Add values to the original accumulators with some given IDs
def add(values: Map[Long, Any]): Unit = synchronized {
for ((id, value) <- values) {
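With the thread-local localAccums map and its clear()/values helpers gone, the Accumulators object above keeps only the driver-side registry of weak references keyed by accumulator id; per-task bookkeeping moves to the task context in the files below. A minimal, self-contained sketch of that registry pattern, using illustrative names rather than Spark's classes:

import scala.collection.mutable
import scala.ref.WeakReference

// Standalone sketch of the driver-side registry kept above: weak references let unused
// accumulators be garbage-collected, and updates for ids that are no longer referenced
// are silently dropped. Names here are illustrative, not Spark's.
object AccumulatorRegistrySketch {
  final class Acc(val id: Long) { @volatile var value: Long = 0L }

  private val originals = mutable.Map[Long, WeakReference[Acc]]()

  def register(a: Acc): Unit = synchronized {
    originals(a.id) = new WeakReference(a)
  }

  // Merge per-task updates (id -> delta) into the surviving originals.
  def add(updates: Map[Long, Long]): Unit = synchronized {
    for ((id, delta) <- updates; ref <- originals.get(id); acc <- ref.get) {
      acc.value += delta
    }
  }

  def main(args: Array[String]): Unit = {
    val a = new Acc(id = 1L)
    register(a)
    add(Map(1L -> 5L, 2L -> 3L))   // id 2 was never registered, so it is ignored
    println(a.value)               // prints 5
  }
}
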
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -181,7 +181,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
// Asynchronously kill the executor to avoid blocking the current thread
killExecutorThread.submit(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
sc.killExecutor(executorId)
// Note: we want to get an executor back after expiring this one,
// so do not simply call `sc.killExecutor` here (SPARK-8119)
sc.killAndReplaceExecutor(executorId)
}
})
}
40 changes: 38 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1419,6 +1419,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/**
* :: DeveloperApi ::
* Request that the cluster manager kill the specified executors.
*
* Note: This is an indication to the cluster manager that the application wishes to adjust
* its resource usage downwards. If the application wishes to replace the executors it kills
* through this method with new ones, it should follow up explicitly with a call to
* {{SparkContext#requestExecutors}}.
*
* This is currently only supported in YARN mode. Return whether the request is received.
*/
@DeveloperApi
@@ -1436,12 +1442,42 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli

/**
* :: DeveloperApi ::
* Request that cluster manager the kill the specified executor.
* This is currently only supported in Yarn mode. Return whether the request is received.
* Request that the cluster manager kill the specified executor.
*
* Note: This is an indication to the cluster manager that the application wishes to adjust
* its resource usage downwards. If the application wishes to replace the executor it kills
* through this method with a new one, it should follow up explicitly with a call to
* {{SparkContext#requestExecutors}}.
*
* This is currently only supported in YARN mode. Return whether the request is received.
*/
@DeveloperApi
override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId)

/**
* Request that the cluster manager kill the specified executor without adjusting the
* application resource requirements.
*
* The effect is that a new executor will be launched in place of the one killed by
* this request. This assumes the cluster manager will automatically and eventually
* fulfill all missing application resource requests.
*
* Note: The replace is by no means guaranteed; another application on the same cluster
* can steal the window of opportunity and acquire this application's resources in the
* mean time.
*
* This is currently only supported in YARN mode. Return whether the request is received.
*/
private[spark] def killAndReplaceExecutor(executorId: String): Boolean = {
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.killExecutors(Seq(executorId), replace = true)
case _ =>
logWarning("Killing executors is only supported in coarse-grained mode")
false
}
}

/** The version of Spark on which this application is running. */
def version: String = SPARK_VERSION

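Per the doc comments above, killExecutor is a request to shrink the application's resource usage, while the new private[spark] killAndReplaceExecutor kills without lowering the target so the cluster manager should eventually supply a replacement. A hedged usage sketch of the public route the docs describe (the helper below is illustrative, not part of Spark):

import org.apache.spark.SparkContext

// Usage sketch based on the doc comments above: a caller that wants a killed executor
// replaced must explicitly restore the resource target, since killExecutor alone is a
// request to shrink. Both methods are DeveloperApi and, per the docs, YARN-only here.
object KillAndReplaceUsageSketch {
  def killThenReplace(sc: SparkContext, executorId: String): Boolean = {
    val killed = sc.killExecutor(executorId)   // lowers the resource target by one
    if (killed) {
      sc.requestExecutors(1)                   // explicitly ask for a replacement, per the docs
    }
    killed
  }
}
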
18 changes: 18 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -152,4 +152,22 @@ abstract class TaskContext extends Serializable {
* Returns the manager for this task's managed memory.
*/
private[spark] def taskMemoryManager(): TaskMemoryManager

/**
* Register an accumulator that belongs to this task. Accumulators must call this method when
* deserializing in executors.
*/
private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit

/**
* Return the local values of internal accumulators that belong to this task. The key of the Map
* is the accumulator id and the value of the Map is the latest accumulator local value.
*/
private[spark] def collectInternalAccumulators(): Map[Long, Any]

/**
* Return the local values of accumulators that belong to this task. The key of the Map is the
* accumulator id and the value of the Map is the latest accumulator local value.
*/
private[spark] def collectAccumulators(): Map[Long, Any]
}
19 changes: 16 additions & 3 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,12 +17,12 @@

package org.apache.spark

import scala.collection.mutable.{ArrayBuffer, HashMap}

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}

import scala.collection.mutable.ArrayBuffer

private[spark] class TaskContextImpl(
val stageId: Int,
val partitionId: Int,
@@ -94,5 +94,18 @@ private[spark] class TaskContextImpl(
override def isRunningLocally(): Boolean = runningLocally

override def isInterrupted(): Boolean = interrupted
}

@transient private val accumulators = new HashMap[Long, Accumulable[_, _]]

private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized {
accumulators(a.id) = a
}

private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized {
accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap
}

private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized {
accumulators.mapValues(_.localValue).toMap
}
}
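Combined with the Accumulable.readObject change earlier in this diff, the flow is: an accumulator deserialized on an executor registers itself with the current task's context, and the task snapshots the local values of everything registered when it finishes. A simplified standalone sketch of that flow, with illustrative names rather than Spark's classes:

import scala.collection.mutable

// Simplified sketch of the per-task flow: registration happens when an accumulator is
// created (standing in for deserialization), and collection snapshots the local values
// that are shipped back with the task result.
object TaskAccumulatorFlowSketch {
  final class SketchTaskContext {
    private val accums = mutable.HashMap[Long, SketchAccumulator]()
    def registerAccumulator(a: SketchAccumulator): Unit = synchronized { accums(a.id) = a }
    def collectAccumulators(): Map[Long, Any] = synchronized {
      accums.mapValues(_.localValue).toMap
    }
  }

  final class SketchAccumulator(val id: Long, ctx: SketchTaskContext) {
    @volatile var localValue: Long = 0L
    ctx.registerAccumulator(this)          // mirrors readObject() registering with the TaskContext
    def add(v: Long): Unit = { localValue += v }
  }

  def main(args: Array[String]): Unit = {
    val ctx = new SketchTaskContext
    val bytesRead = new SketchAccumulator(id = 7L, ctx)
    bytesRead.add(3); bytesRead.add(4)
    println(ctx.collectAccumulators())     // Map(7 -> 7)
  }
}
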
6 changes: 2 additions & 4 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -209,7 +209,7 @@ private[spark] class Executor(

// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
val value = try {
val (value, accumUpdates) = try {
task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
} finally {
// Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread;
@@ -247,7 +247,6 @@
m.setResultSerializationTime(afterSerialization - beforeSerialization)
}

val accumUpdates = Accumulators.values
val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit
@@ -314,8 +313,6 @@
env.shuffleMemoryManager.releaseMemoryForThisThread()
// Release memory used by this thread for unrolling blocks
env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
// Release memory used by this thread for accumulators
Accumulators.clear()
runningTasks.remove(taskId)
}
}
@@ -424,6 +421,7 @@
metrics.updateShuffleReadMetrics()
metrics.updateInputMetrics()
metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
metrics.updateAccumulators()

if (isLocal) {
// JobProgressListener will hold an reference of it during
16 changes: 16 additions & 0 deletions core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -223,6 +223,22 @@ class TaskMetrics extends Serializable {
// overhead.
_hostname = TaskMetrics.getCachedHostName(_hostname)
}

private var _accumulatorUpdates: Map[Long, Any] = Map.empty
@transient private var _accumulatorsUpdater: () => Map[Long, Any] = null

private[spark] def updateAccumulators(): Unit = synchronized {
_accumulatorUpdates = _accumulatorsUpdater()
}

/**
* Return the latest updates of accumulators in this task.
*/
def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates

private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = {
_accumulatorsUpdater = accumulatorsUpdater
}
}

private[spark] object TaskMetrics {
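The addition to TaskMetrics above is a transient snapshot closure: the task supplies a () => Map[Long, Any] that reads its internal accumulators, updateAccumulators() captures the latest values (for example from the heartbeat path in Executor above), and only the captured map survives serialization. A stripped-down sketch of the same pattern, not the real class:

// Stripped-down sketch of the updater-closure pattern: the closure itself is @transient,
// so only the most recent snapshot travels with the serialized metrics.
class MetricsSnapshotSketch extends Serializable {
  private var accumulatorUpdates: Map[Long, Any] = Map.empty
  @transient private var accumulatorsUpdater: () => Map[Long, Any] = null

  def setAccumulatorsUpdater(updater: () => Map[Long, Any]): Unit = {
    accumulatorsUpdater = updater
  }

  // Called periodically (for example from a heartbeat) to capture the latest local values.
  def updateAccumulators(): Unit = synchronized {
    accumulatorUpdates = accumulatorsUpdater()
  }

  def latestAccumulatorUpdates: Map[Long, Any] = accumulatorUpdates
}

In this diff the wiring is metrics.setAccumulatorsUpdater(context.collectInternalAccumulators), done in Task.run below.
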
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -22,7 +22,8 @@ import java.util.Properties
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
import scala.collection.Map
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack}
import scala.concurrent.duration._
import scala.language.existentials
import scala.language.postfixOps
@@ -556,6 +557,9 @@ class DAGScheduler(
case JobFailed(exception: Exception) =>
logInfo("Job %d failed: %s, took %f s".format
(waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
// SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
val callerStackTrace = Thread.currentThread().getStackTrace.tail
exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
throw exception
}
}
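The SPARK-8644 change above rethrows, on the calling thread, an exception that originated on the scheduler's event loop, so appending the caller's frames makes the resulting stack trace point back at the user code that submitted the job. A self-contained sketch of the same technique:

// Self-contained sketch of the stack-trace merging above: frames from the rethrowing
// (caller) thread are appended to an exception produced on another thread.
object StackTraceMergeSketch {
  def rethrowWithCallerTrace(exception: Exception): Nothing = {
    // tail drops the getStackTrace frame itself, as in the DAGScheduler change
    val callerStackTrace = Thread.currentThread().getStackTrace.tail
    exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
    throw exception
  }

  def main(args: Array[String]): Unit = {
    try rethrowWithCallerTrace(new RuntimeException("job failed on a scheduler thread"))
    catch { case e: RuntimeException => e.printStackTrace() }   // trace now includes main()
  }
}
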
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler

import java.util.Properties

import scala.collection.mutable.Map
import scala.collection.Map
import scala.language.existentials

import org.apache.spark._
13 changes: 10 additions & 3 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,14 +45,20 @@ import org.apache.spark.util.Utils
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {

/**
* The key of the Map is the accumulator id and the value of the Map is the latest accumulator
* local value.
*/
type AccumulatorUpdates = Map[Long, Any]

/**
* Called by [[Executor]] to run this task.
*
* @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
* @param attemptNumber how many times this task has been attempted (0 for the first attempt)
* @return the result of the task
* @return the result of the task along with updates of Accumulators.
*/
final def run(taskAttemptId: Long, attemptNumber: Int): T = {
final def run(taskAttemptId: Long, attemptNumber: Int): (T, AccumulatorUpdates) = {
context = new TaskContextImpl(
stageId = stageId,
partitionId = partitionId,
@@ -62,12 +68,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
runningLocally = false)
TaskContext.setTaskContext(context)
context.taskMetrics.setHostname(Utils.localHostName())
context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
taskThread = Thread.currentThread()
if (_killed) {
kill(interruptThread = false)
}
try {
runTask(context)
(runTask(context), context.collectAccumulators())
} finally {
context.markTaskCompleted()
TaskContext.unset()
@@ -20,7 +20,8 @@ package org.apache.spark.scheduler
import java.io._
import java.nio.ByteBuffer

import scala.collection.mutable.Map
import scala.collection.Map
import scala.collection.mutable

import org.apache.spark.SparkEnv
import org.apache.spark.executor.TaskMetrics
@@ -69,10 +70,11 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
if (numUpdates == 0) {
accumUpdates = null
} else {
accumUpdates = Map()
val _accumUpdates = mutable.Map[Long, Any]()
for (i <- 0 until numUpdates) {
accumUpdates(in.readLong()) = in.readObject()
_accumUpdates(in.readLong()) = in.readObject()
}
accumUpdates = _accumUpdates
}
metrics = in.readObject().asInstanceOf[TaskMetrics]
valueObjectDeserialized = false
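The readExternal change above follows from typing accumUpdates as the read-only scala.collection.Map: the method now fills a local mutable map and assigns it to the field once it is complete. A minimal sketch of that read pattern; the helper name and the ObjectInput parameter are illustrative assumptions, not Spark's API:

import java.io.ObjectInput
import scala.collection.Map
import scala.collection.mutable

// Minimal sketch of the pattern above: build into a mutable map, expose it through the
// read-only scala.collection.Map type once fully populated.
object ReadAccumUpdatesSketch {
  def readUpdates(in: ObjectInput, numUpdates: Int): Map[Long, Any] = {
    if (numUpdates == 0) {
      null                                   // matches the empty case handled above
    } else {
      val updates = mutable.Map[Long, Any]()
      for (_ <- 0 until numUpdates) {
        updates(in.readLong()) = in.readObject()
      }
      updates                                // mutable.Map conforms to collection.Map
    }
  }
}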