Skip to content

Commit

Permalink
Make TimeStampedWeakValueHashMap a wrapper of TimeStampedHashMap
Browse files Browse the repository at this point in the history
This allows us to get rid of WrappedJavaHashMap without much duplicate code.
  • Loading branch information
andrewor14 committed Mar 29, 2014
1 parent fbfeec8 commit 88904a3
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 479 deletions.
1 change: 0 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.mesos.MesosNativeLibrary

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ private[spark] class BlockManager(
}
}

/** Return the status of the block identified by the given ID, if it exists. */
/** Get the BlockStatus for the block identified by the given ID, if it exists.*/
def getStatus(blockId: BlockId): Option[BlockStatus] = {
blockInfo.get(blockId).map { info =>
val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
Expand Down Expand Up @@ -635,9 +635,10 @@ private[spark] class BlockManager(
diskStore.putValues(blockId, iterator, level, askForBytes)
case ArrayBufferValues(array) =>
diskStore.putValues(blockId, array, level, askForBytes)
case ByteBufferValues(bytes) =>
case ByteBufferValues(bytes) => {
bytes.rewind()
diskStore.putBytes(blockId, bytes, level)
}
}
size = res.size
res.data match {
Expand Down Expand Up @@ -872,7 +873,7 @@ private[spark] class BlockManager(
}

private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) {
val iterator = blockInfo.internalMap.entrySet().iterator()
val iterator = blockInfo.getEntrySet.iterator
while (iterator.hasNext) {
val entry = iterator.next()
val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp)
Expand Down
117 changes: 81 additions & 36 deletions core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,64 +17,108 @@

package org.apache.spark.util

import java.util.Set
import java.util.Map.Entry
import java.util.concurrent.ConcurrentHashMap

import scala.collection.{immutable, JavaConversions, mutable}

import org.apache.spark.Logging

private[util] case class TimeStampedValue[T](timestamp: Long, value: T)
private[spark] case class TimeStampedValue[V](value: V, timestamp: Long)

/**
* A map that stores the timestamp of when a key was inserted along with the value. If specified,
* the timestamp of each pair can be updated every time it is accessed.
* Key-value pairs whose timestamps are older than a particular
* threshold time can then be removed using the clearOldValues method. It exposes a
* scala.collection.mutable.Map interface to allow it to be a drop-in replacement for Scala
* HashMaps.
*
* Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe.
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
* timestamp along with each key-value pair. If specified, the timestamp of each pair can be
* updated every time it is accessed. Key-value pairs whose timestamp are older than a particular
* threshold time can then be removed using the clearOldValues method. This is intended to
* be a drop-in replacement of scala.collection.mutable.HashMap.
*
* @param updateTimeStampOnGet When enabled, the timestamp of a pair will be
* updated when it is accessed
* @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed
*/
private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
extends WrappedJavaHashMap[A, B, A, TimeStampedValue[B]] with Logging {
extends mutable.Map[A, B]() with Logging {

private[util] val internalJavaMap = new ConcurrentHashMap[A, TimeStampedValue[B]]()
private val internalMap = new ConcurrentHashMap[A, TimeStampedValue[B]]()

private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = {
new TimeStampedHashMap[K1, V1]()
def get(key: A): Option[B] = {
val value = internalMap.get(key)
if (value != null && updateTimeStampOnGet) {
internalMap.replace(key, value, TimeStampedValue(value.value, currentTime))
}
Option(value).map(_.value)
}

def internalMap = internalJavaMap
def iterator: Iterator[(A, B)] = {
val jIterator = getEntrySet.iterator()
JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value))
}

override def get(key: A): Option[B] = {
val timeStampedValue = internalMap.get(key)
if (updateTimeStampOnGet && timeStampedValue != null) {
internalJavaMap.replace(key, timeStampedValue,
TimeStampedValue(currentTime, timeStampedValue.value))
}
Option(timeStampedValue).map(_.value)
def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet()

override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = {
val newMap = new TimeStampedHashMap[A, B1]
val oldInternalMap = this.internalMap.asInstanceOf[ConcurrentHashMap[A, TimeStampedValue[B1]]]
newMap.internalMap.putAll(oldInternalMap)
kv match { case (a, b) => newMap.internalMap.put(a, TimeStampedValue(b, currentTime)) }
newMap
}
@inline override protected def externalValueToInternalValue(v: B): TimeStampedValue[B] = {
new TimeStampedValue(currentTime, v)

override def - (key: A): mutable.Map[A, B] = {
val newMap = new TimeStampedHashMap[A, B]
newMap.internalMap.putAll(this.internalMap)
newMap.internalMap.remove(key)
newMap
}

override def += (kv: (A, B)): this.type = {
kv match { case (a, b) => internalMap.put(a, TimeStampedValue(b, currentTime)) }
this
}

override def -= (key: A): this.type = {
internalMap.remove(key)
this
}

override def update(key: A, value: B) {
this += ((key, value))
}

@inline override protected def internalValueToExternalValue(iv: TimeStampedValue[B]): B = {
iv.value
override def apply(key: A): B = {
val value = internalMap.get(key)
Option(value).map(_.value).getOrElse { throw new NoSuchElementException() }
}

/** Atomically put if a key is absent. This exposes the existing API of ConcurrentHashMap. */
override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = {
JavaConversions.mapAsScalaConcurrentMap(internalMap)
.map { case (k, TimeStampedValue(v, t)) => (k, v) }
.filter(p)
}

override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]()

override def size: Int = internalMap.size

override def foreach[U](f: ((A, B)) => U) {
val iterator = getEntrySet.iterator()
while(iterator.hasNext) {
val entry = iterator.next()
val kv = (entry.getKey, entry.getValue.value)
f(kv)
}
}

// Should we return previous value directly or as Option?
def putIfAbsent(key: A, value: B): Option[B] = {
val prev = internalJavaMap.putIfAbsent(key, TimeStampedValue(currentTime, value))
val prev = internalMap.putIfAbsent(key, TimeStampedValue(value, currentTime))
Option(prev).map(_.value)
}

/**
* Removes old key-value pairs that have timestamp earlier than `threshTime`,
* calling the supplied function on each such entry before removing.
*/
def toMap: immutable.Map[A, B] = iterator.toMap

def clearOldValues(threshTime: Long, f: (A, B) => Unit) {
val iterator = internalJavaMap.entrySet().iterator()
val iterator = getEntrySet.iterator()
while (iterator.hasNext) {
val entry = iterator.next()
if (entry.getValue.timestamp < threshTime) {
Expand All @@ -86,11 +130,12 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa
}

/**
* Removes old key-value pairs that have timestamp earlier than `threshTime`
* Removes old key-value pairs that have timestamp earlier than `threshTime`.
*/
def clearOldValues(threshTime: Long) {
clearOldValues(threshTime, (_, _) => ())
}

private def currentTime: Long = System.currentTimeMillis()
private def currentTime: Long = System.currentTimeMillis

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,113 +18,115 @@
package org.apache.spark.util

import java.lang.ref.WeakReference
import java.util
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConversions

import org.apache.spark.Logging

private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: WeakReference[T]) {
def this(timestamp: Long, value: T) = this(timestamp, new WeakReference[T](value))
}
import scala.collection.{immutable, mutable}

/**
* A map that stores the timestamp of when a key was inserted along with the value,
* while ensuring that the values are weakly referenced. If the value is garbage collected and
* the weak reference is null, get() operation returns the key be non-existent. However,
* the key is actually not removed in the current implementation. Key-value pairs whose
* timestamps are older than a particular threshold time can then be removed using the
* clearOldValues method. It exposes a scala.collection.mutable.Map interface to allow it to be a
* drop-in replacement for Scala HashMaps.
* A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped.
*
* If the value is garbage collected and the weak reference is null, get() operation returns
* a non-existent value. However, the corresponding key is actually not removed in the current
* implementation. Key-value pairs whose timestamps are older than a particular threshold time
* can then be removed using the clearOldValues method. It exposes a scala.collection.mutable.Map
* interface to allow it to be a drop-in replacement for Scala HashMaps.
*
* Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe.
*
* @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed.
*/
private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false)
extends mutable.Map[A, B]() {

import TimeStampedWeakValueHashMap._

private[spark] class TimeStampedWeakValueHashMap[A, B]()
extends WrappedJavaHashMap[A, B, A, TimeStampedWeakValue[B]] with Logging {
private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet)

/** Number of inserts after which keys whose weak ref values are null will be cleaned */
private val CLEANUP_INTERVAL = 1000
def get(key: A): Option[B] = internalMap.get(key)

/** Counter for counting the number of inserts */
private val insertCounts = new AtomicInteger(0)
def iterator: Iterator[(A, B)] = internalMap.iterator

override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = {
val newMap = new TimeStampedWeakValueHashMap[A, B1]
newMap.internalMap += kv
newMap
}

private[util] val internalJavaMap: util.Map[A, TimeStampedWeakValue[B]] = {
new ConcurrentHashMap[A, TimeStampedWeakValue[B]]()
override def - (key: A): mutable.Map[A, B] = {
val newMap = new TimeStampedWeakValueHashMap[A, B]
newMap.internalMap -= key
newMap
}

private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = {
new TimeStampedWeakValueHashMap[K1, V1]()
override def += (kv: (A, B)): this.type = {
internalMap += kv
this
}

override def +=(kv: (A, B)): this.type = {
// Cleanup null value at certain intervals
if (insertCounts.incrementAndGet() % CLEANUP_INTERVAL == 0) {
cleanNullValues()
}
super.+=(kv)
override def -= (key: A): this.type = {
internalMap -= key
this
}

override def get(key: A): Option[B] = {
Option(internalJavaMap.get(key)).flatMap { weakValue =>
val value = weakValue.weakValue.get
if (value == null) {
internalJavaMap.remove(key)
}
Option(value)
}
override def update(key: A, value: B) = this += ((key, value))

override def apply(key: A): B = internalMap.apply(key)

override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = internalMap.filter(p)

override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]()

override def size: Int = internalMap.size

override def foreach[U](f: ((A, B)) => U) = internalMap.foreach(f)

def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value)

def toMap: immutable.Map[A, B] = iterator.toMap

/**
* Remove old key-value pairs that have timestamp earlier than `threshTime`.
*/
def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime)

}

/**
* Helper methods for converting to and from WeakReferences.
*/
private[spark] object TimeStampedWeakValueHashMap {

/* Implicit conversion methods to WeakReferences */

implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v)

implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = {
kv match { case (k, v) => (k, toWeakReference(v)) }
}

@inline override protected def externalValueToInternalValue(v: B): TimeStampedWeakValue[B] = {
new TimeStampedWeakValue(currentTime, v)
implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = {
(kv: (K, WeakReference[V])) => p(kv)
}

@inline override protected def internalValueToExternalValue(iv: TimeStampedWeakValue[B]): B = {
iv.weakValue.get
/* Implicit conversion methods from WeakReferences */

implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get

implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = {
v.map(fromWeakReference)
}

override def iterator: Iterator[(A, B)] = {
val iterator = internalJavaMap.entrySet().iterator()
JavaConversions.asScalaIterator(iterator).flatMap(kv => {
val (key, value) = (kv.getKey, kv.getValue.weakValue.get)
if (value != null) Seq((key, value)) else Seq.empty
})
implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = {
kv match { case (k, v) => (k, fromWeakReference(v)) }
}

/**
* Removes old key-value pairs that have timestamp earlier than `threshTime`,
* calling the supplied function on each such entry before removing.
*/
def clearOldValues(threshTime: Long, f: (A, B) => Unit = null) {
val iterator = internalJavaMap.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
if (entry.getValue.timestamp < threshTime) {
val value = entry.getValue.weakValue.get
if (f != null && value != null) {
f(entry.getKey, value)
}
logDebug("Removing key " + entry.getKey)
iterator.remove()
}
}
implicit def fromWeakReferenceIterator[K, V](
it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = {
it.map(fromWeakReferenceTuple)
}

/**
* Removes keys whose weak referenced values have become null.
*/
private def cleanNullValues() {
val iterator = internalJavaMap.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
if (entry.getValue.weakValue.get == null) {
logDebug("Removing key " + entry.getKey)
iterator.remove()
}
}
implicit def fromWeakReferenceMap[K, V](
map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = {
mutable.Map(map.mapValues(fromWeakReference).toSeq: _*)
}

private def currentTime = System.currentTimeMillis()
}
Loading

0 comments on commit 88904a3

Please sign in to comment.