
Commit a7260d3
Added try-catch in context cleaner and null value cleaning in TimeStampedWeakValueHashMap.
tdas committed Mar 17, 2014
1 parent e61daa0 commit a7260d3
Showing 3 changed files with 64 additions and 34 deletions.
core/src/main/scala/org/apache/spark/ContextCleaner.scala (30 additions, 20 deletions)
@@ -50,6 +50,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   /** Start the cleaner */
   def start() {
     cleaningThread.setDaemon(true)
+    cleaningThread.setName("ContextCleaner")
     cleaningThread.start()
   }
 
@@ -60,7 +61,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }
 
   /**
-   * Clean (unpersist) RDD data. Do not perform any time or resource intensive
+   * Clean RDD data. Do not perform any time or resource intensive
    * computation in this function as this is called from a finalize() function.
    */
   def cleanRDD(rddId: Int) {
@@ -92,39 +93,48 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
 
   /** Keep cleaning RDDs and shuffle data */
   private def keepCleaning() {
-    try {
-      while (!isStopped) {
+    while (!isStopped) {
+      try {
         val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS))
-        taskOpt.foreach(task => {
+        taskOpt.foreach { task =>
           logDebug("Got cleaning task " + taskOpt.get)
           task match {
-            case CleanRDD(rddId) => doCleanRDD(sc, rddId)
+            case CleanRDD(rddId) => doCleanRDD(rddId)
             case CleanShuffle(shuffleId) => doCleanShuffle(shuffleId)
           }
-        })
+        }
+      } catch {
+        case ie: InterruptedException =>
+          if (!isStopped) logWarning("Cleaning thread interrupted")
+        case t: Throwable => logError("Error in cleaning thread", t)
       }
-    } catch {
-      case ie: InterruptedException =>
-        if (!isStopped) logWarning("Cleaning thread interrupted")
     }
   }
 
   /** Perform RDD cleaning */
-  private def doCleanRDD(sc: SparkContext, rddId: Int) {
-    logDebug("Cleaning rdd " + rddId)
-    blockManagerMaster.removeRdd(rddId, false)
-    sc.persistentRdds.remove(rddId)
-    listeners.foreach(_.rddCleaned(rddId))
-    logInfo("Cleaned rdd " + rddId)
+  private def doCleanRDD(rddId: Int) {
+    try {
+      logDebug("Cleaning RDD " + rddId)
+      blockManagerMaster.removeRdd(rddId, false)
+      sc.persistentRdds.remove(rddId)
+      listeners.foreach(_.rddCleaned(rddId))
+      logInfo("Cleaned RDD " + rddId)
+    } catch {
+      case t: Throwable => logError("Error cleaning RDD " + rddId, t)
+    }
   }
 
   /** Perform shuffle cleaning */
   private def doCleanShuffle(shuffleId: Int) {
-    logDebug("Cleaning shuffle " + shuffleId)
-    mapOutputTrackerMaster.unregisterShuffle(shuffleId)
-    blockManagerMaster.removeShuffle(shuffleId)
-    listeners.foreach(_.shuffleCleaned(shuffleId))
-    logInfo("Cleaned shuffle " + shuffleId)
+    try {
+      logDebug("Cleaning shuffle " + shuffleId)
+      mapOutputTrackerMaster.unregisterShuffle(shuffleId)
+      blockManagerMaster.removeShuffle(shuffleId)
+      listeners.foreach(_.shuffleCleaned(shuffleId))
+      logInfo("Cleaned shuffle " + shuffleId)
+    } catch {
+      case t: Throwable => logError("Error cleaning shuffle " + shuffleId, t)
+    }
   }
 
   private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
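The important change in keepCleaning() is that the try-catch moved inside the while loop and gained a catch-all Throwable case: a single failed cleanup task is now logged and the loop keeps polling, whereas previously any exception other than an InterruptedException escaped the loop and silently killed the cleaning thread. A minimal, self-contained sketch of this resilient-loop pattern (WorkerLoop, stopped, and the queue of () => Unit tasks are illustrative stand-ins, not Spark's classes):

import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

object WorkerLoop {
  @volatile private var stopped = false
  private val queue = new LinkedBlockingQueue[() => Unit]()

  private val workerThread = new Thread() {
    override def run() { keepWorking() }
  }

  def start() {
    workerThread.setDaemon(true)
    workerThread.setName("WorkerLoop")  // named threads make stack traces readable
    workerThread.start()
  }

  private def keepWorking() {
    while (!stopped) {
      try {
        // Poll with a timeout so the loop can re-check `stopped` regularly.
        val taskOpt = Option(queue.poll(100, TimeUnit.MILLISECONDS))
        taskOpt.foreach(task => task())
      } catch {
        case ie: InterruptedException =>
          if (!stopped) println("worker interrupted")
        case t: Throwable =>
          // The catch sits inside the loop, so one bad task is logged
          // and the thread returns to polling instead of dying.
          println("error in worker: " + t)
      }
    }
  }
}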
core/src/main/scala/org/apache/spark/MapOutputTracker.scala (1 deletion)
@@ -20,7 +20,6 @@ package org.apache.spark
 import java.io._
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
-import scala.Some
 import scala.collection.mutable.{HashSet, Map}
 import scala.concurrent.Await
 
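The removed import scala.Some was redundant: Some is always in scope through the default scala package import.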
core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala (34 additions, 13 deletions)
@@ -24,6 +24,7 @@ import java.lang.ref.WeakReference
 import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.Logging
+import java.util.concurrent.atomic.AtomicInteger
 
 private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: WeakReference[T]) {
   def this(timestamp: Long, value: T) = this(timestamp, new WeakReference[T](value))
@@ -44,6 +45,12 @@ private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: WeakReference[T]) {
 private[spark] class TimeStampedWeakValueHashMap[A, B]()
   extends WrappedJavaHashMap[A, B, A, TimeStampedWeakValue[B]] with Logging {
 
+  /** Number of inserts after which keys whose weak ref values are null will be cleaned */
+  private val CLEANUP_INTERVAL = 1000
+
+  /** Counter for counting the number of inserts */
+  private val insertCounts = new AtomicInteger(0)
+
   protected[util] val internalJavaMap: util.Map[A, TimeStampedWeakValue[B]] = {
     new ConcurrentHashMap[A, TimeStampedWeakValue[B]]()
   }
@@ -52,11 +59,21 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
     new TimeStampedWeakValueHashMap[K1, V1]()
   }
 
+  override def +=(kv: (A, B)): this.type = {
+    // Cleanup null value at certain intervals
+    if (insertCounts.incrementAndGet() % CLEANUP_INTERVAL == 0) {
+      cleanNullValues()
+    }
+    super.+=(kv)
+  }
+
   override def get(key: A): Option[B] = {
     Option(internalJavaMap.get(key)) match {
       case Some(weakValue) =>
         val value = weakValue.weakValue.get
-        if (value == null) cleanupKey(key)
+        if (value == null) {
+          internalJavaMap.remove(key)
+        }
         Option(value)
       case None =>
         None
@@ -72,16 +89,10 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
   }
 
   override def iterator: Iterator[(A, B)] = {
-    val jIterator = internalJavaMap.entrySet().iterator()
-    JavaConversions.asScalaIterator(jIterator).flatMap(kv => {
-      val key = kv.getKey
-      val value = kv.getValue.weakValue.get
-      if (value == null) {
-        cleanupKey(key)
-        Seq.empty
-      } else {
-        Seq((key, value))
-      }
+    val iterator = internalJavaMap.entrySet().iterator()
+    JavaConversions.asScalaIterator(iterator).flatMap(kv => {
+      val (key, value) = (kv.getKey, kv.getValue.weakValue.get)
+      if (value != null) Seq((key, value)) else Seq.empty
     })
   }
 
@@ -104,8 +115,18 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
     }
   }
 
-  private def cleanupKey(key: A) {
-    // TODO: Consider cleaning up keys to empty weak ref values automatically in future.
+  /**
+   * Removes keys whose weak referenced values have become null.
+   */
+  private def cleanNullValues() {
+    val iterator = internalJavaMap.entrySet().iterator()
+    while (iterator.hasNext) {
+      val entry = iterator.next()
+      if (entry.getValue.weakValue.get == null) {
+        logDebug("Removing key " + entry.getKey)
+        iterator.remove()
+      }
+    }
   }
 
   private def currentTime = System.currentTimeMillis()
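Taken together, the map now reclaims entries whose values have been garbage-collected in two ways: get() eagerly removes a stale key whenever it happens to be read, and every CLEANUP_INTERVAL-th insert (here, every 1000th) triggers a full sweep in cleanNullValues(), so keys that are never read again still get purged. A self-contained miniature of the same technique (MiniWeakValueMap and its members are hypothetical names for illustration, not Spark's API):

import java.lang.ref.WeakReference
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

class MiniWeakValueMap[A, B <: AnyRef](cleanupInterval: Int = 1000) {
  private val map = new ConcurrentHashMap[A, WeakReference[B]]()
  private val insertCount = new AtomicInteger(0)

  def put(key: A, value: B) {
    // Amortize the full sweep over many inserts, as the += override does.
    if (insertCount.incrementAndGet() % cleanupInterval == 0) {
      cleanNullValues()
    }
    map.put(key, new WeakReference[B](value))
  }

  def get(key: A): Option[B] = {
    Option(map.get(key)) match {
      case Some(ref) =>
        val value = ref.get
        // Eagerly drop the entry if the GC has already cleared its value.
        if (value == null) map.remove(key)
        Option(value)
      case None => None
    }
  }

  // Remove every entry whose weakly referenced value has been collected.
  private def cleanNullValues() {
    val iterator = map.entrySet().iterator()
    while (iterator.hasNext) {
      if (iterator.next().getValue.get == null) {
        iterator.remove()
      }
    }
  }
}

Sweeping on every insert would make put O(n); amortizing the sweep over cleanupInterval inserts keeps the common path O(1) while bounding how long dead entries can linger.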
