Removed use of BoundedHashMap, and made BlockManagerSlaveActor clean up shuffle metadata in MapOutputTrackerWorker.
tdas committed Mar 18, 2014
1 parent a7260d3 commit 892b952
Showing 17 changed files with 196 additions and 147 deletions.
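
Only five of the 17 changed files are rendered below; the headline change, BlockManagerSlaveActor dropping shuffle metadata from the worker-side MapOutputTracker, lives in one of the files not shown. The following is a hedged sketch of the flow the commit message describes, with hypothetical type and message names standing in for the real ones:

    import akka.actor.Actor

    // Hypothetical stand-ins for the real Spark components; names are illustrative.
    trait ShuffleBlockStore { def removeShuffle(shuffleId: Int): Unit }
    trait WorkerSideTracker { def unregisterShuffle(shuffleId: Int): Unit }
    case class RemoveShuffle(shuffleId: Int) // assumed driver-to-executor message

    // After this commit, the slave actor cleans up *both* the shuffle blocks
    // and the shuffle's map output metadata held by the worker-side tracker.
    class SlaveActorSketch(blocks: ShuffleBlockStore, tracker: WorkerSideTracker)
      extends Actor {
      def receive = {
        case RemoveShuffle(shuffleId) =>
          blocks.removeShuffle(shuffleId)      // delete shuffle files/blocks
          tracker.unregisterShuffle(shuffleId) // new: drop map output statuses too
          sender ! true
      }
    }
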
core/src/main/scala/org/apache/spark/ContextCleaner.scala (13 additions, 6 deletions)

@@ -20,6 +20,7 @@ package org.apache.spark
 import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
 
 import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
+import org.apache.spark.storage.StorageLevel
 
 /** Listener class used for testing when any item has been cleaned by the Cleaner class */
 private[spark] trait CleanerListener {
@@ -61,19 +62,19 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }
 
   /**
-   * Clean RDD data. Do not perform any time or resource intensive
+   * Schedule cleanup of RDD data. Do not perform any time or resource intensive
    * computation in this function as this is called from a finalize() function.
    */
-  def cleanRDD(rddId: Int) {
+  def scheduleRDDCleanup(rddId: Int) {
     enqueue(CleanRDD(rddId))
     logDebug("Enqueued RDD " + rddId + " for cleaning up")
   }
 
   /**
-   * Clean shuffle data. Do not perform any time or resource intensive
+   * Schedule cleanup of shuffle data. Do not perform any time or resource intensive
    * computation in this function as this is called from a finalize() function.
    */
-  def cleanShuffle(shuffleId: Int) {
+  def scheduleShuffleCleanup(shuffleId: Int) {
     enqueue(CleanShuffle(shuffleId))
     logDebug("Enqueued shuffle " + shuffleId + " for cleaning up")
   }
@@ -83,6 +84,13 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
     listeners += listener
   }
 
+  /** Unpersists RDD and remove all blocks for it from memory and disk. */
+  def unpersistRDD(rddId: Int, blocking: Boolean) {
+    logDebug("Unpersisted RDD " + rddId)
+    sc.env.blockManager.master.removeRdd(rddId, blocking)
+    sc.persistentRdds.remove(rddId)
+  }
+
   /**
    * Enqueue a cleaning task. Do not perform any time or resource intensive
    * computation in this function as this is called from a finalize() function.
@@ -115,8 +123,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   private def doCleanRDD(rddId: Int) {
     try {
       logDebug("Cleaning RDD " + rddId)
-      blockManagerMaster.removeRdd(rddId, false)
-      sc.persistentRdds.remove(rddId)
+      unpersistRDD(rddId, false)
       listeners.foreach(_.rddCleaned(rddId))
       logInfo("Cleaned RDD " + rddId)
     } catch {
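
The renamed methods make the contract explicit: a finalizer only enqueues work, and a separate thread performs it. As a rough illustration of the pattern (not Spark's actual ContextCleaner, with the cleanup actions stubbed out):

    import java.util.concurrent.LinkedBlockingQueue

    sealed trait CleanupTask
    case class CleanRDD(rddId: Int) extends CleanupTask
    case class CleanShuffle(shuffleId: Int) extends CleanupTask

    class CleanerSketch {
      private val queue = new LinkedBlockingQueue[CleanupTask]

      // Cheap and non-blocking, hence safe to call from finalize().
      def scheduleRDDCleanup(rddId: Int) { queue.put(CleanRDD(rddId)) }
      def scheduleShuffleCleanup(shuffleId: Int) { queue.put(CleanShuffle(shuffleId)) }

      // A daemon thread drains the queue and does the expensive work.
      private val cleaningThread = new Thread("cleaner") {
        override def run() {
          while (true) {
            queue.take() match {
              case CleanRDD(id)     => println("cleaning RDD " + id)     // e.g. unpersistRDD
              case CleanShuffle(id) => println("cleaning shuffle " + id) // e.g. drop statuses
            }
          }
        }
      }
      cleaningThread.setDaemon(true)
      cleaningThread.start()
    }
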
core/src/main/scala/org/apache/spark/Dependency.scala (2 additions, 2 deletions)

@@ -56,15 +56,15 @@ class ShuffleDependency[K, V](
   override def finalize() {
     try {
       if (rdd != null) {
-        rdd.sparkContext.cleaner.cleanShuffle(shuffleId)
+        rdd.sparkContext.cleaner.scheduleShuffleCleanup(shuffleId)
       }
     } catch {
       case t: Throwable =>
         // Paranoia - If logError throws error as well, report to stderr.
         try {
           logError("Error in finalize", t)
         } catch {
-          case _ =>
+          case _ : Throwable =>
             System.err.println("Error in finalize (and could not write to logError): " + t)
         }
     } finally {
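
Besides the rename, the catch-all pattern changes from `case _ =>` to `case _ : Throwable =>`. Both match everything at runtime; the explicit type makes the breadth of the handler visible at a glance. A minimal illustration of the preferred form:

    try sys.error("boom") catch {
      case _: Throwable => System.err.println("caught") // preferred over a bare `case _ =>`
    }
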
core/src/main/scala/org/apache/spark/MapOutputTracker.scala (55 additions, 51 deletions)

@@ -20,7 +20,7 @@ package org.apache.spark
 import java.io._
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
-import scala.collection.mutable.{HashSet, Map}
+import scala.collection.mutable.{HashSet, HashMap, Map}
 import scala.concurrent.Await
 
 import akka.actor._
@@ -34,6 +34,7 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
   extends MapOutputTrackerMessage
 private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
 
+/** Actor class for MapOutputTrackerMaster */
 private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
   extends Actor with Logging {
   def receive = {
@@ -50,28 +51,35 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
 }
 
 /**
- * Class that keeps track of the location of the location of the map output of
+ * Class that keeps track of the location of the map output of
  * a stage. This is abstract because different versions of MapOutputTracker
  * (driver and worker) use different HashMap to store its metadata.
  */
 private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
 
   private val timeout = AkkaUtils.askTimeout(conf)
 
-  // Set to the MapOutputTrackerActor living on the driver
+  /** Set to the MapOutputTrackerActor living on the driver */
   var trackerActor: ActorRef = _
 
+  /** This HashMap needs to have different storage behavior for driver and worker */
   protected val mapStatuses: Map[Int, Array[MapStatus]]
 
-  // Incremented every time a fetch fails so that client nodes know to clear
-  // their cache of map output locations if this happens.
+  /**
+   * Incremented every time a fetch fails so that client nodes know to clear
+   * their cache of map output locations if this happens.
+   */
   protected var epoch: Long = 0
   protected val epochLock = new java.lang.Object
 
-  // Send a message to the trackerActor and get its result within a default timeout, or
-  // throw a SparkException if this fails.
-  private def askTracker(message: Any): Any = {
+  /** Remembers which map output locations are currently being fetched on a worker */
+  private val fetching = new HashSet[Int]
+
+  /**
+   * Send a message to the trackerActor and get its result within a default timeout, or
+   * throw a SparkException if this fails.
+   */
+  protected def askTracker(message: Any): Any = {
     try {
       val future = trackerActor.ask(message)(timeout)
       Await.result(future, timeout)
@@ -81,17 +89,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
     }
   }
 
-  // Send a one-way message to the trackerActor, to which we expect it to reply with true.
-  private def communicate(message: Any) {
+  /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */
+  protected def sendTracker(message: Any) {
     if (askTracker(message) != true) {
       throw new SparkException("Error reply received from MapOutputTracker")
     }
   }
 
-  // Remembers which map output locations are currently being fetched on a worker
-  private val fetching = new HashSet[Int]
-
-  // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
+  /**
+   * Called from executors to get the server URIs and
+   * output sizes of the map outputs of a given shuffle
+   */
   def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
     val statuses = mapStatuses.get(shuffleId).orNull
     if (statuses == null) {
@@ -150,22 +158,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
     }
   }
 
-  def stop() {
-    communicate(StopMapOutputTracker)
-    mapStatuses.clear()
-    trackerActor = null
-  }
-
-  // Called to get current epoch number
+  /** Called to get current epoch number */
  def getEpoch: Long = {
     epochLock.synchronized {
       return epoch
     }
   }
 
-  // Called on workers to update the epoch number, potentially clearing old outputs
-  // because of a fetch failure. (Each worker task calls this with the latest epoch
-  // number on the master at the time it was created.)
+  /**
+   * Called from executors to update the epoch number, potentially clearing old outputs
+   * because of a fetch failure. Each worker task calls this with the latest epoch
+   * number on the master at the time it was created.
+   */
   def updateEpoch(newEpoch: Long) {
     epochLock.synchronized {
       if (newEpoch > epoch) {
@@ -175,24 +179,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
       }
     }
   }
-}
 
-/**
- * MapOutputTracker for the workers. This uses BoundedHashMap to keep track of
- * a limited number of most recently used map output information.
- */
-private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
+  /** Unregister shuffle data */
+  def unregisterShuffle(shuffleId: Int) {
+    mapStatuses.remove(shuffleId)
+  }
 
-  /**
-   * Bounded HashMap for storing serialized statuses in the worker. This allows
-   * the HashMap stay bounded in memory-usage. Things dropped from this HashMap will be
-   * automatically repopulated by fetching them again from the driver. Its okay to
-   * keep the cache size small as it unlikely that there will be a very large number of
-   * stages active simultaneously in the worker.
-   */
-  protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](
-    conf.getInt("spark.mapOutputTracker.cacheSize", 100), true
-  )
+  def stop() {
+    sendTracker(StopMapOutputTracker)
+    mapStatuses.clear()
+    trackerActor = null
+  }
 }
@@ -202,7 +199,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
 private[spark] class MapOutputTrackerMaster(conf: SparkConf)
   extends MapOutputTracker(conf) {
 
-  // Cache a serialized version of the output statuses for each shuffle to send them out faster
+  /** Cache a serialized version of the output statuses for each shuffle to send them out faster */
   private var cacheEpoch = epoch
@@ -211,7 +208,6 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
    * by TTL-based cleaning (if set). Other than these two
    * scenarios, nothing should be dropped from this HashMap.
    */
-
   protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()
   private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]()
@@ -232,13 +228,15 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
     }
   }
 
+  /** Register multiple map output information for the given shuffle */
   def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
     mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
     if (changeEpoch) {
       incrementEpoch()
     }
   }
 
+  /** Unregister map output information of the given shuffle, mapper and block manager */
   def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
     val arrayOpt = mapStatuses.get(shuffleId)
     if (arrayOpt.isDefined && arrayOpt.get != null) {
@@ -254,11 +252,17 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
     }
   }
 
-  def unregisterShuffle(shuffleId: Int) {
+  /** Unregister shuffle data */
+  override def unregisterShuffle(shuffleId: Int) {
     mapStatuses.remove(shuffleId)
     cachedSerializedStatuses.remove(shuffleId)
   }
 
+  /** Check if the given shuffle is being tracked */
+  def containsShuffle(shuffleId: Int): Boolean = {
+    cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
+  }
+
   def incrementEpoch() {
     epochLock.synchronized {
       epoch += 1
@@ -295,26 +299,26 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
     bytes
   }
 
-  def contains(shuffleId: Int): Boolean = {
-    cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
-  }
-
   override def stop() {
     super.stop()
     metadataCleaner.cancel()
     cachedSerializedStatuses.clear()
   }
 
   override def updateEpoch(newEpoch: Long) {
     // This might be called on the MapOutputTrackerMaster if we're running in local mode.
   }
 
   protected def cleanup(cleanupTime: Long) {
     mapStatuses.clearOldValues(cleanupTime)
     cachedSerializedStatuses.clearOldValues(cleanupTime)
   }
 }
 
+/**
+ * MapOutputTracker for the workers, which fetches map output information from the driver's
+ * MapOutputTrackerMaster.
+ */
+private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
+  protected val mapStatuses = new HashMap[Int, Array[MapStatus]]
+}
+
 private[spark] object MapOutputTracker {
   private val LOG_BASE = 1.1
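
The net effect of this file's changes: the abstract MapOutputTracker now owns all shared logic (including a worker-style unregisterShuffle and stop), the driver subclass keeps its TTL-cleaned TimeStampedHashMaps, and the worker subclass shrinks to a plain unbounded HashMap whose entries are removed explicitly instead of being evicted by the deleted BoundedHashMap. A condensed sketch of the resulting shape (MapStatus stubbed as String; the real classes have many more members):

    import scala.collection.mutable

    abstract class TrackerSketch {
      // Storage strategy is left to the subclasses.
      protected val mapStatuses: mutable.Map[Int, Array[String]]

      /** Default worker-side behavior: forget one shuffle on demand. */
      def unregisterShuffle(shuffleId: Int) { mapStatuses.remove(shuffleId) }
    }

    class MasterSketch extends TrackerSketch {
      // The real code uses TimeStampedHashMap so a MetadataCleaner can expire entries.
      protected val mapStatuses = new mutable.HashMap[Int, Array[String]]
      override def unregisterShuffle(shuffleId: Int) {
        mapStatuses.remove(shuffleId)
        // the real override also drops the cached serialized statuses
      }
    }

    class WorkerSketch extends TrackerSketch {
      // Unbounded map; cleaned via unregisterShuffle (driven by the slave actor).
      protected val mapStatuses = new mutable.HashMap[Int, Array[String]]
    }
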
core/src/main/scala/org/apache/spark/SparkEnv.scala (13 additions, 12 deletions)

@@ -165,18 +165,6 @@ object SparkEnv extends Logging {
       }
     }
 
-    val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
-      "BlockManagerMaster",
-      new BlockManagerMasterActor(isLocal, conf)), conf)
-    val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
-      serializer, conf, securityManager)
-
-    val connectionManager = blockManager.connectionManager
-
-    val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
-
-    val cacheManager = new CacheManager(blockManager)
-
     // Have to assign trackerActor after initialization as MapOutputTrackerActor
     // requires the MapOutputTracker itself
     val mapOutputTracker = if (isDriver) {
@@ -188,6 +176,19 @@ object SparkEnv extends Logging {
       "MapOutputTracker",
       new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
 
+    val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
+      "BlockManagerMaster",
+      new BlockManagerMasterActor(isLocal, conf)), conf)
+
+    val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
+      serializer, conf, securityManager, mapOutputTracker)
+
+    val connectionManager = blockManager.connectionManager
+
+    val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
+
+    val cacheManager = new CacheManager(blockManager)
+
     val shuffleFetcher = instantiateClass[ShuffleFetcher](
       "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
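
The construction order flips because BlockManager now takes the MapOutputTracker as a constructor argument (so its slave actor can clean shuffle metadata), which means the tracker must exist before the block manager. The registerOrLookup helper used above is defined elsewhere in SparkEnv and is not part of this diff; the following is a hedged sketch of what such a helper does, where the Akka resolution API, path format, and timeout are assumptions rather than the commit's code:

    import scala.concurrent.Await
    import scala.concurrent.duration._
    import akka.actor.{Actor, ActorRef, ActorSystem, Props}

    // On the driver, create and register the actor; on an executor, look up
    // the driver's instance by path. URL format and timeout are made up.
    def registerOrLookup(system: ActorSystem, isDriver: Boolean, driverUrl: String,
                         name: String)(newActor: => Actor): ActorRef = {
      if (isDriver) {
        system.actorOf(Props(newActor), name = name)
      } else {
        val sel = system.actorSelection(driverUrl + "/user/" + name)
        Await.result(sel.resolveOne(30.seconds), 30.seconds)
      }
    }
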
core/src/main/scala/org/apache/spark/rdd/RDD.scala (3 additions, 12 deletions)

@@ -165,8 +165,7 @@ abstract class RDD[T: ClassTag](
    */
   def unpersist(blocking: Boolean = true): RDD[T] = {
     logInfo("Removing RDD " + id + " from persistence list")
-    sc.env.blockManager.master.removeRdd(id, blocking)
-    sc.persistentRdds.remove(id)
+    sc.cleaner.unpersistRDD(id, blocking)
     storageLevel = StorageLevel.NONE
     this
   }
@@ -1025,14 +1024,6 @@ abstract class RDD[T: ClassTag](
     checkpointData.flatMap(_.getCheckpointFile)
   }
 
-  def cleanup() {
-    logInfo("Cleanup called on RDD " + id)
-    sc.cleaner.cleanRDD(id)
-    dependencies.filter(_.isInstanceOf[ShuffleDependency[_, _]])
-      .map(_.asInstanceOf[ShuffleDependency[_, _]].shuffleId)
-      .foreach(sc.cleaner.cleanShuffle)
-  }
-
   // =======================================================================
   // Other internal methods and fields
   // =======================================================================
@@ -1114,14 +1105,14 @@ abstract class RDD[T: ClassTag](
 
   override def finalize() {
     try {
-      cleanup()
+      sc.cleaner.scheduleRDDCleanup(id)
     } catch {
       case t: Throwable =>
         // Paranoia - If logError throws error as well, report to stderr.
         try {
           logError("Error in finalize", t)
         } catch {
-          case _ =>
+          case _ : Throwable =>
             System.err.println("Error in finalize (and could not write to logError): " + t)
         }
     } finally {
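
For callers, RDD.unpersist() behaves as before; only the plumbing moved into ContextCleaner, and finalize() now merely schedules cleanup instead of doing it inline. A usage sketch, assuming a live SparkContext named sc:

    val rdd = sc.parallelize(1 to 1000).cache()
    rdd.count()                      // materializes the cached blocks
    rdd.unpersist(blocking = false)  // now delegates to sc.cleaner.unpersistRDD
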
(Diffs of the remaining 12 changed files are not shown.)
