Commit

Added ContextCleaner to automatically clean RDDs and shuffles when they fall out of scope. Also replaced TimeStampedHashMap with BoundedHashMaps and TimeStampedWeakValueHashMap for the necessary hashmap behavior.
tdas committed Feb 14, 2014
1 parent 3a9d82c commit e427a9e
Showing 22 changed files with 946 additions and 60 deletions.
126 changes: 126 additions & 0 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}

import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}

import org.apache.spark.rdd.RDD

/** Listener class used for testing when any item has been cleaned by the Cleaner class */
private[spark] trait CleanerListener {
def rddCleaned(rddId: Int)
def shuffleCleaned(shuffleId: Int)
}

/**
* Cleans RDDs and shuffle data. This should be instantiated only on the driver.
*/
private[spark] class ContextCleaner(env: SparkEnv) extends Logging {

/** Classes to represent cleaning tasks */
private sealed trait CleaningTask
private case class CleanRDD(sc: SparkContext, id: Int) extends CleaningTask
private case class CleanShuffle(id: Int) extends CleaningTask
// TODO: add CleanBroadcast

private val QUEUE_CAPACITY = 1000
private val queue = new ArrayBlockingQueue[CleaningTask](QUEUE_CAPACITY)

protected val listeners = new ArrayBuffer[CleanerListener]
with SynchronizedBuffer[CleanerListener]

private val cleaningThread = new Thread() { override def run() { keepCleaning() }}

private var stopped = false

/** Start the cleaner */
def start() {
cleaningThread.setDaemon(true)
cleaningThread.start()
}

/** Stop the cleaner */
def stop() {
synchronized { stopped = true }
cleaningThread.interrupt()
}

/** Clean all data and metadata related to an RDD, including shuffle files and metadata */
def cleanRDD(rdd: RDD[_]) {
enqueue(CleanRDD(rdd.sparkContext, rdd.id))
logDebug("Enqueued RDD " + rdd + " for cleaning up")
}

def cleanShuffle(shuffleId: Int) {
enqueue(CleanShuffle(shuffleId))
logDebug("Enqueued shuffle " + shuffleId + " for cleaning up")
}

def attachListener(listener: CleanerListener) {
listeners += listener
}
/** Enqueue a cleaning task */
private def enqueue(task: CleaningTask) {
queue.put(task)
}

/** Keep cleaning RDDs and shuffle data */
private def keepCleaning() {
try {
while (!isStopped) {
val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS))
if (taskOpt.isDefined) {
logDebug("Got cleaning task " + taskOpt.get)
taskOpt.get match {
case CleanRDD(sc, rddId) => doCleanRDD(sc, rddId)
case CleanShuffle(shuffleId) => doCleanShuffle(shuffleId)
}
}
}
} catch {
case ie: java.lang.InterruptedException =>
if (!isStopped) logWarning("Cleaning thread interrupted")
}
}

/** Perform RDD cleaning */
private def doCleanRDD(sc: SparkContext, rddId: Int) {
logDebug("Cleaning rdd "+ rddId)
sc.env.blockManager.master.removeRdd(rddId, false)
sc.persistentRdds.remove(rddId)
listeners.foreach(_.rddCleaned(rddId))
logInfo("Cleaned rdd "+ rddId)
}

/** Perform shuffle cleaning */
private def doCleanShuffle(shuffleId: Int) {
logDebug("Cleaning shuffle "+ shuffleId)
mapOutputTrackerMaster.unregisterShuffle(shuffleId)
blockManager.master.removeShuffle(shuffleId)
listeners.foreach(_.shuffleCleaned(shuffleId))
logInfo("Cleaned shuffle " + shuffleId)
}

private def mapOutputTrackerMaster = env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]

private def blockManager = env.blockManager

private def isStopped = synchronized { stopped }
}
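Note (not part of the commit): since CleanerListener is intended for testing, a minimal usage sketch could look like the following. It assumes the example lives in the org.apache.spark package (both SparkContext.cleaner and CleanerListener are private[spark]) and that a local-mode SparkContext is acceptable; the object name ContextCleanerExample is hypothetical.

package org.apache.spark

import java.util.concurrent.{CountDownLatch, TimeUnit}

// Hypothetical sketch: attach a listener, then explicitly ask the cleaner to clean an RDD.
object ContextCleanerExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "cleaner-example")
    val rddCleanedLatch = new CountDownLatch(1)

    sc.cleaner.attachListener(new CleanerListener {
      def rddCleaned(rddId: Int) { rddCleanedLatch.countDown() }
      def shuffleCleaned(shuffleId: Int) { }
    })

    val rdd = sc.parallelize(1 to 100).cache()
    rdd.count()

    sc.cleaner.cleanRDD(rdd)                    // enqueues a CleanRDD task
    rddCleanedLatch.await(10, TimeUnit.SECONDS) // cleaning thread fires rddCleaned()
    sc.stop()
  }
}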
6 changes: 6 additions & 0 deletions core/src/main/scala/org/apache/spark/Dependency.scala
@@ -52,6 +52,12 @@ class ShuffleDependency[K, V](
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {

val shuffleId: Int = rdd.context.newShuffleId()

override def finalize() {
if (rdd != null) {
rdd.sparkContext.cleaner.cleanShuffle(shuffleId)
}
}
}


64 changes: 45 additions & 19 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -17,19 +17,19 @@

package org.apache.spark

import scala.Some
import scala.collection.mutable.{HashSet, Map}
import scala.concurrent.Await

import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.mutable.HashSet
import scala.concurrent.Await
import scala.concurrent.duration._

import akka.actor._
import akka.pattern.ask

import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
import org.apache.spark.util._

private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int)
@@ -51,23 +51,21 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
}
}

private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {

private val timeout = AkkaUtils.askTimeout(conf)

// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _

protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
/** This HashMap needs to have different storage behavior for driver and worker */
protected val mapStatuses: Map[Int, Array[MapStatus]]

// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
protected var epoch: Long = 0
protected val epochLock = new java.lang.Object

private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)

// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
private def askTracker(message: Any): Any = {
@@ -138,8 +136,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
}
else {
} else {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
@@ -151,13 +148,12 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
}

protected def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
mapStatuses.asInstanceOf[TimeStampedHashMap[_, _]].clearOldValues(cleanupTime)
}

def stop() {
communicate(StopMapOutputTracker)
mapStatuses.clear()
metadataCleaner.cancel()
trackerActor = null
}

@@ -182,15 +178,42 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
}
}

private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {

/**
* Bounded HashMap for storing map output statuses in the worker. This allows
* the HashMap to stay bounded in memory usage. Things dropped from this HashMap will be
* automatically repopulated by fetching them again from the driver.
*/
protected val MAX_MAP_STATUSES = 100
protected val mapStatuses = new BoundedHashMap[Int, Array[MapStatus]](MAX_MAP_STATUSES, true)
}


private[spark] class MapOutputTrackerMaster(conf: SparkConf)
extends MapOutputTracker(conf) {

// Cache a serialized version of the output statuses for each shuffle to send them out faster
private var cacheEpoch = epoch
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]

/**
* Timestamp based HashMap for storing mapStatuses in the master, so that statuses are dropped
* only by explicit deregistering or by ttl-based cleaning (if set). Other than these two
* scenarios, nothing should be dropped from this HashMap.
*/
protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()

/**
* Bounded HashMap for storing serialized statuses in the master. This allows
* the HashMap to stay bounded in memory usage. Things dropped from this HashMap will be
* automatically repopulated by serializing the lost statuses again.
*/
protected val MAX_SERIALIZED_STATUSES = 100
private val cachedSerializedStatuses =
new BoundedHashMap[Int, Array[Byte]](MAX_SERIALIZED_STATUSES, true)

def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}
@@ -224,6 +247,10 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
}

def unregisterShuffle(shuffleId: Int) {
mapStatuses.remove(shuffleId)
}

def incrementEpoch() {
epochLock.synchronized {
epoch += 1
@@ -260,9 +287,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
bytes
}

protected override def cleanup(cleanupTime: Long) {
super.cleanup(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
def contains(shuffleId: Int): Boolean = {
mapStatuses.contains(shuffleId)
}

override def stop() {
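Note (not part of the commit): BoundedHashMap itself is not shown in this diff. As a rough sketch of the behavior the comments above describe -- a size-bounded map that silently drops old entries on the assumption they can be re-fetched or recomputed -- something like the following could be built on java.util.LinkedHashMap. The class name BoundedLruMap is hypothetical, and treating the second constructor argument (the `true` passed in the diff) as an access-order/LRU flag is an assumption, not Spark's actual implementation.

import java.util.{LinkedHashMap => JLinkedHashMap}
import java.util.Map.{Entry => JEntry}
import scala.collection.JavaConverters._
import scala.collection.mutable

// Illustrative sketch only: a thread-safe map capped at maxEntries, evicting the
// least-recently-used entry (or least-recently-inserted if accessOrder is false).
class BoundedLruMap[K, V](maxEntries: Int, accessOrder: Boolean) extends mutable.Map[K, V] {

  private val underlying = new JLinkedHashMap[K, V](16, 0.75f, accessOrder) {
    override def removeEldestEntry(eldest: JEntry[K, V]): Boolean = this.size() > maxEntries
  }

  override def get(key: K): Option[V] = synchronized { Option(underlying.get(key)) }

  override def iterator: Iterator[(K, V)] = synchronized {
    underlying.entrySet.asScala.map(e => (e.getKey, e.getValue)).toList.iterator
  }

  override def += (kv: (K, V)): this.type = synchronized { underlying.put(kv._1, kv._2); this }

  override def -= (key: K): this.type = synchronized { underlying.remove(key); this }
}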
12 changes: 9 additions & 3 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -48,8 +48,10 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me
import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{Utils, TimeStampedHashMap, MetadataCleaner, MetadataCleanerType,
ClosureCleaner}
import org.apache.spark.util._
import scala.Some
import org.apache.spark.storage.RDDInfo
import org.apache.spark.storage.StorageStatus

/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -150,7 +152,7 @@ class SparkContext(
private[spark] val addedJars = HashMap[String, Long]()

// Keeps track of all persisted RDDs
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]]
private[spark] val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf)

@@ -202,6 +204,9 @@ class SparkContext(
@volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()

private[spark] val cleaner = new ContextCleaner(env)
cleaner.start()

ui.start()

/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
@@ -784,6 +789,7 @@ class SparkContext(
dagScheduler = null
if (dagSchedulerCopy != null) {
metadataCleaner.cancel()
cleaner.stop()
dagSchedulerCopy.stop()
taskScheduler = null
// TODO: Cache.stop()?
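Note (not part of the commit): TimeStampedWeakValueHashMap, now used for persistentRdds above, is also defined outside this excerpt. The behavior the commit message implies -- values held weakly so that RDDs which fall out of scope can be garbage-collected out of the map, plus a timestamp for TTL-based cleanup -- could be sketched as follows. The class WeakValueTimeStampedMap and its API are illustrative assumptions, not the actual Spark class.

import java.lang.ref.WeakReference
import scala.collection.mutable

// Illustrative sketch: values are held through WeakReferences and stamped with an
// insertion time, so entries disappear either when the value is garbage-collected
// or when TTL-based cleanup drops them.
class WeakValueTimeStampedMap[K, V <: AnyRef] {
  private case class Stamped(ref: WeakReference[V], timestamp: Long)
  private val internal = new mutable.HashMap[K, Stamped]()

  def put(key: K, value: V): Unit = synchronized {
    internal(key) = Stamped(new WeakReference(value), System.currentTimeMillis)
  }

  def get(key: K): Option[V] = synchronized {
    internal.get(key).flatMap(s => Option(s.ref.get))  // None if already collected
  }

  def remove(key: K): Unit = synchronized { internal.remove(key) }

  /** Drop entries older than cutoffTime as well as entries whose values were collected. */
  def clearOldValues(cutoffTime: Long): Unit = synchronized {
    internal.retain { case (_, s) => s.timestamp >= cutoffTime && s.ref.get != null }
  }
}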
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -181,7 +181,7 @@ object SparkEnv extends Logging {
val mapOutputTracker = if (isDriver) {
new MapOutputTrackerMaster(conf)
} else {
new MapOutputTracker(conf)
new MapOutputTrackerWorker(conf)
}
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1012,6 +1012,13 @@ abstract class RDD[T: ClassTag](
checkpointData.flatMap(_.getCheckpointFile)
}

def cleanup() {
sc.cleaner.cleanRDD(this)
dependencies.filter(_.isInstanceOf[ShuffleDependency[_, _]])
.map(_.asInstanceOf[ShuffleDependency[_, _]].shuffleId)
.foreach(sc.cleaner.cleanShuffle)
}

// =======================================================================
// Other internal methods and fields
// =======================================================================
@@ -1091,4 +1098,7 @@ abstract class RDD[T: ClassTag](
new JavaRDD(this)(elementClassTag)
}

override def finalize() {
cleanup()
}
}
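Note (not part of the commit): with finalize() delegating to cleanup(), an RDD that falls out of scope should eventually be cleaned once the JVM finalizes it. The sketch below only illustrates the intended lifecycle; garbage collection and finalization are non-deterministic, so System.gc() and the sleep are hints rather than guarantees, and the object name is hypothetical.

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._   // implicit pair-RDD functions for groupByKey
import org.apache.spark.rdd.RDD

// Illustrative only: shows the intended out-of-scope cleanup flow, not a deterministic test.
object OutOfScopeCleanupExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "out-of-scope-example")

    // A cached RDD with a shuffle dependency.
    var rdd: RDD[Int] = sc.parallelize(1 to 1000)
      .map(i => (i % 10, i))
      .groupByKey()
      .map(_._2.size)
      .cache()
    rdd.count()

    // Drop the last strong reference. Once the RDD and its ShuffleDependency are
    // finalized, their finalize() methods enqueue cleanup work on the driver's ContextCleaner.
    rdd = null
    System.gc()          // a hint only; finalization timing is up to the JVM

    Thread.sleep(1000)   // give the daemon cleaning thread a chance to run
    sc.stop()
  }
}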
12 changes: 4 additions & 8 deletions core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -21,20 +21,16 @@ import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
import org.apache.spark.rdd.{RDD, RDDCheckpointData}
import org.apache.spark.util.BoundedHashMap

private[spark] object ResultTask {

// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]

// TODO: This object shouldn't have global variables
val metadataCleaner = new MetadataCleaner(
MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues, new SparkConf)
val MAX_CACHE_SIZE = 100
val serializedInfoCache = new BoundedHashMap[Int, Array[Byte]](MAX_CACHE_SIZE, true)

def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
synchronized {