Skip to content

Commit

Permalink
Merge pull request #3318 from rklaehn/issue/6196
Browse files Browse the repository at this point in the history
SI-6196 - Set should implement filter
  • Loading branch information
Ichoran committed Jan 15, 2014
2 parents 9cb5ed8 + 47a91d7 commit dddf1f5
Show file tree
Hide file tree
Showing 4 changed files with 379 additions and 0 deletions.
122 changes: 122 additions & 0 deletions src/library/scala/collection/immutable/HashMap.scala
Expand Up @@ -40,6 +40,8 @@ class HashMap[A, +B] extends AbstractMap[A, B]
with Serializable
with CustomParallelizable[(A, B), ParHashMap[A, B]]
{
import HashMap.{nullToEmpty, bufferSize}

/** Number of entries; this base implementation returns 0. */
override def size: Int = 0

/** Returns the empty map obtained from the companion (`HashMap.empty`). */
override def empty = HashMap.empty[A, B]
Expand All @@ -63,6 +65,18 @@ class HashMap[A, +B] extends AbstractMap[A, B]
/**
 * Returns a map without `key`.
 *
 * The key is hashed once with `computeHash` and the removal is delegated to `removed0`, starting at level 0.
 */
def - (key: A): HashMap[A, B] = {
  val keyHash = computeHash(key)
  removed0(key, keyHash, 0)
}

/**
 * Returns the map containing only the entries for which `p` holds.
 *
 * One scratch array, sized by `bufferSize(size)`, is allocated up front and handed to `filter0`;
 * a `null` result from `filter0` is turned into the empty map.
 */
override def filter(p: ((A, B)) => Boolean) = {
  val scratch = new Array[HashMap[A, B]](bufferSize(size))
  val filtered = filter0(p, negate = false, level = 0, buffer = scratch, offset0 = 0)
  nullToEmpty(filtered)
}

/**
 * Returns the map containing only the entries for which `p` does not hold.
 *
 * Same as `filter`, except that `filter0` is invoked with `negate = true`.
 */
override def filterNot(p: ((A, B)) => Boolean) = {
  val scratch = new Array[HashMap[A, B]](bufferSize(size))
  val filtered = filter0(p, negate = true, level = 0, buffer = scratch, offset0 = 0)
  nullToEmpty(filtered)
}

/**
 * Internal filtering hook used by `filter` and `filterNot`.
 *
 * This base implementation always returns `null`, which the public methods convert to the empty map
 * through `nullToEmpty`.
 *
 * @param p       the predicate applied to entries
 * @param negate  `false` when called from `filter`, `true` when called from `filterNot`
 * @param level   the trie level; the public methods pass 0
 * @param buffer  scratch array allocated by the public methods
 * @param offset0 first index of `buffer` this call may use; the public methods pass 0
 */
protected def filter0(p: ((A, B)) => Boolean, negate: Boolean, level: Int, buffer: Array[HashMap[A, B @uV]], offset0: Int): HashMap[A, B] = null

/** Hash code used for a key: the key's `##` hash. */
protected def elemHashCode(key: A) = key.##

protected final def improve(hcode: Int) = {
Expand Down Expand Up @@ -200,6 +214,9 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {
/**
 * Removal for this node: when either the hash or the key differs from this node's own, the node is
 * returned unchanged; otherwise the result is the empty map.
 */
override def removed0(key: A, hash: Int, level: Int): HashMap[A, B] =
  if (hash != this.hash || key != this.key) this
  else HashMap.empty[A,B]

/**
 * Filters this node: it survives when the predicate result for `ensurePair` differs from `negate`
 * (the predicate holds for `filter`, or fails for `filterNot`); otherwise `null` is returned.
 * `level`, `buffer` and `offset0` are not used here.
 */
override protected def filter0(p: ((A, B)) => Boolean, negate: Boolean, level: Int, buffer: Array[HashMap[A, B @uV]], offset0: Int): HashMap[A, B] = {
  val keep = negate != p(ensurePair)
  if (keep) this else null
}

/** Iterates over the single pair produced by `ensurePair`. */
override def iterator: Iterator[(A,B)] = Iterator(ensurePair)
/** Applies `f` once, to the pair produced by `ensurePair`. */
override def foreach[U](f: ((A, B)) => U): Unit = f(ensurePair)
// this method may be called multiple times in a multithreaded environment, but that's ok
Expand Down Expand Up @@ -241,6 +258,21 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {
new HashMapCollision1(hash, kvs1)
} else this

/**
 * Filters the colliding entries held in `kvs`.
 *
 * Depending on how many entries survive, the result is:
 *  - none: `null`
 *  - exactly one: a fresh `HashMap1` for that entry, reusing the existing pair
 *  - all of them: this node itself
 *  - otherwise: a new `HashMapCollision1` holding the survivors
 *
 * `level`, `buffer` and `offset0` are not used here.
 */
override protected def filter0(p: ((A, B)) => Boolean, negate: Boolean, level: Int, buffer: Array[HashMap[A, B @uV]], offset0: Int): HashMap[A, B] = {
  val survivors = if (negate) kvs.filterNot(p) else kvs.filter(p)
  val remaining = survivors.size
  if (remaining == 0) null
  else if (remaining == 1) {
    val entry = survivors.head
    new HashMap1(entry._1, hash, entry._2, entry)
  }
  else if (remaining == kvs.size) this
  else new HashMapCollision1(hash, survivors)
}

/** Iterates over the entries held in `kvs`. */
override def iterator: Iterator[(A,B)] = kvs.iterator
/** Applies `f` to every entry held in `kvs`. */
override def foreach[U](f: ((A, B)) => U): Unit = kvs.foreach(f)
override def split: Seq[HashMap[A, B]] = {
Expand Down Expand Up @@ -336,6 +368,52 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {
}
}

/**
 * Filters every child in `elems` recursively and assembles the surviving children into a result node.
 *
 * Surviving children are collected in `buffer`, starting at `offset0`. Each recursive call is given the
 * current write position as its own `offset0`, so it only touches slots at or beyond that position;
 * the child's result is stored only after the recursive call has returned.
 *
 * @return `null` when nothing survives, `this` when the surviving entry count equals `size0`,
 *         the single surviving child when it is not itself a `HashTrieMap`,
 *         and otherwise a new `HashTrieMap` built from the surviving children
 */
override protected def filter0(p: ((A, B)) => Boolean, negate: Boolean, level: Int, buffer: Array[HashMap[A, B @uV]], offset0: Int): HashMap[A, B] = {
  // next free slot in the shared buffer
  var next = offset0
  // number of entries contained in the surviving children
  var survivingEntries = 0
  // bit n is set iff the n-th child of this node survived
  var keptMask = 0
  var idx = 0
  while (idx < elems.length) {
    val child = elems(idx).filter0(p, negate, level + 5, buffer, next)
    if (child ne null) {
      buffer(next) = child
      next += 1
      survivingEntries += child.size
      keptMask |= (1 << idx)
    }
    idx += 1
  }
  val keptChildren = next - offset0
  if (keptChildren == 0) {
    // every child was filtered away
    null
  } else if (survivingEntries == size0) {
    // nothing was removed, so the existing node can be shared
    this
  } else if (keptChildren == 1 && !buffer(offset0).isInstanceOf[HashTrieMap[A, B]]) {
    // a lone non-trie child is returned directly instead of being wrapped
    buffer(offset0)
  } else {
    val elems1 = new Array[HashMap[A, B]](keptChildren)
    System.arraycopy(buffer, offset0, elems1, 0, keptChildren)
    // when no child was dropped the original bitmap still applies; otherwise restrict it to the kept children
    val bitmap1 = if (keptChildren == elems.length) bitmap else keepBits(bitmap, keptMask)
    new HashTrieMap(bitmap1, elems1, survivingEntries)
  }
}

/**
 * Iterates over all entries by walking `elems` with a `TrieIterator`; every element handed to `getElem`
 * is cast to `HashMap1` and contributes its `ensurePair`.
 */
override def iterator: Iterator[(A, B)] = new TrieIterator[(A, B)](elems.asInstanceOf[Array[Iterable[(A, B)]]]) {
final override def getElem(cc: AnyRef): (A, B) = cc.asInstanceOf[HashMap1[A, B]].ensurePair
}
Expand Down Expand Up @@ -439,6 +517,50 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {
}
}

/**
 * Calculates the maximum buffer size given the maximum possible total size of the trie-based collection
 *
 * The result is `size + 6`, capped at `32 * 7`. The cap is checked before the addition, so a `size` within 6 of
 * `Int.MaxValue` cannot wrap around into a negative array length.
 * NOTE(review): the cap presumably stands for 32 children per node times 7 trie levels — confirm.
 *
 * @param size the maximum size of the collection to be generated
 * @return the maximum buffer size
 */
@inline private def bufferSize(size: Int): Int = {
  val cap = 32 * 7
  if (size >= cap - 6) cap else size + 6
}

/**
 * Converts the internal `null` representation of an empty result into a real empty map.
 *
 * Internal operations return `null` instead of the empty map for performance reasons; public methods
 * pass their result through this method before handing it to callers.
 */
@inline private def nullToEmpty[A, B](m: HashMap[A, B]): HashMap[A, B] =
  if (m ne null) m
  else empty[A, B]

/**
 * Utility method to keep a subset of all bits in a given bitmap
 *
 * Bit `n` of `keep` decides whether the `n`-th lowest one-bit of `bitmap` is retained.
 *
 * Example
 *    bitmap (binary): 00000001000000010000000100000001
 *    keep (binary):                               1010
 *    result (binary): 00000001000000000000000100000000
 *
 * @param bitmap the bitmap
 * @param keep a bitmask containing which bits to keep
 * @return the original bitmap with all bits where keep is not 1 set to 0
 */
private def keepBits(bitmap: Int, keep: Int): Int = {
  @scala.annotation.tailrec
  def loop(remaining: Int, mask: Int, acc: Int): Int =
    if (mask == 0) acc
    else {
      // isolate the lowest one-bit still present in the bitmap (0 once the bitmap is exhausted)
      val lowest = remaining & -remaining
      val acc1 = if ((mask & 1) != 0) acc | lowest else acc
      // drop that bit from the bitmap and advance to the next bit of the mask
      loop(remaining ^ lowest, mask >>> 1, acc1)
    }
  loop(bitmap, keep, 0)
}

@SerialVersionUID(2L)
private class SerializationProxy[A,B](@transient private var orig: HashMap[A, B]) extends Serializable {
private def writeObject(out: java.io.ObjectOutputStream) {
Expand Down
121 changes: 121 additions & 0 deletions src/library/scala/collection/immutable/HashSet.scala
Expand Up @@ -38,6 +38,8 @@ class HashSet[A] extends AbstractSet[A]
with CustomParallelizable[A, ParHashSet[A]]
with Serializable
{
import HashSet.{nullToEmpty, bufferSize}

/** The factory companion of this collection: the `HashSet` object. */
override def companion: GenericCompanion[HashSet] = HashSet

//class HashSet[A] extends Set[A] with SetLike[A, HashSet[A]] {
Expand Down Expand Up @@ -86,6 +88,18 @@ class HashSet[A] extends AbstractSet[A]
/**
 * Returns a set without the element `e`.
 *
 * The element is hashed once with `computeHash` and the removal is delegated to `removed0`, starting at level 0.
 */
def - (e: A): HashSet[A] = {
  val elemHash = computeHash(e)
  removed0(e, elemHash, 0)
}

/**
 * Returns the set containing only the elements for which `p` holds.
 *
 * One scratch array, sized by `bufferSize(size)`, is allocated up front and handed to `filter0`;
 * a `null` result from `filter0` is turned into the empty set.
 */
override def filter(p: A => Boolean) = {
  val scratch = new Array[HashSet[A]](bufferSize(size))
  val filtered = filter0(p, negate = false, level = 0, buffer = scratch, offset0 = 0)
  nullToEmpty(filtered)
}

/**
 * Returns the set containing only the elements for which `p` does not hold.
 *
 * Same as `filter`, except that `filter0` is invoked with `negate = true`.
 */
override def filterNot(p: A => Boolean) = {
  val scratch = new Array[HashSet[A]](bufferSize(size))
  val filtered = filter0(p, negate = true, level = 0, buffer = scratch, offset0 = 0)
  nullToEmpty(filtered)
}

/**
 * Internal filtering hook used by `filter` and `filterNot`.
 *
 * This base implementation always returns `null`, which the public methods convert to the empty set
 * through `nullToEmpty`.
 *
 * @param p       the predicate applied to elements
 * @param negate  `false` when called from `filter`, `true` when called from `filterNot`
 * @param level   the trie level; the public methods pass 0
 * @param buffer  scratch array allocated by the public methods
 * @param offset0 first index of `buffer` this call may use; the public methods pass 0
 */
protected def filter0(p: A => Boolean, negate: Boolean, level: Int, buffer: Array[HashSet[A]], offset0: Int): HashSet[A] = null

/** Hash code used for an element: the element's `##` hash. */
protected def elemHashCode(key: A) = key.##

protected final def improve(hcode: Int) = {
Expand Down Expand Up @@ -179,6 +193,9 @@ object HashSet extends ImmutableSetFactory[HashSet] {
/**
 * Removal for this node: when either the hash or the key differs from this node's own, the node is
 * returned unchanged; otherwise the result is the empty set.
 */
override def removed0(key: A, hash: Int, level: Int): HashSet[A] =
  if (hash != this.hash || key != this.key) this
  else HashSet.empty[A]

/**
 * Filters this node: it survives when the predicate result for `key` differs from `negate`
 * (the predicate holds for `filter`, or fails for `filterNot`); otherwise `null` is returned.
 * `level`, `buffer` and `offset0` are not used here.
 */
override protected def filter0(p: A => Boolean, negate: Boolean, level: Int, buffer: Array[HashSet[A]], offset0: Int): HashSet[A] = {
  val keep = negate != p(key)
  if (keep) this else null
}

/** Iterates over the single element `key`. */
override def iterator: Iterator[A] = Iterator(key)
/** Applies `f` once, to `key`. */
override def foreach[U](f: A => U): Unit = f(key)
}
Expand Down Expand Up @@ -214,6 +231,20 @@ object HashSet extends ImmutableSetFactory[HashSet] {
new HashSetCollision1(hash, ks1)
} else this

/**
 * Filters the colliding elements held in `ks`.
 *
 * Depending on how many elements survive, the result is:
 *  - none: `null`
 *  - exactly one: a fresh `HashSet1` for that element
 *  - all of them: this node itself
 *  - otherwise: a new `HashSetCollision1` holding the survivors
 *
 * `level`, `buffer` and `offset0` are not used here.
 */
override protected def filter0(p: A => Boolean, negate: Boolean, level: Int, buffer: Array[HashSet[A]], offset0: Int): HashSet[A] = {
  val survivors = if (negate) ks.filterNot(p) else ks.filter(p)
  val remaining = survivors.size
  if (remaining == 0) null
  else if (remaining == 1) new HashSet1(survivors.head, hash)
  else if (remaining == ks.size) this
  else new HashSetCollision1(hash, survivors)
}

/** Iterates over the elements held in `ks`. */
override def iterator: Iterator[A] = ks.iterator
/** Applies `f` to every element held in `ks`. */
override def foreach[U](f: A => U): Unit = ks.foreach(f)

Expand Down Expand Up @@ -392,6 +423,52 @@ object HashSet extends ImmutableSetFactory[HashSet] {
false
}

/**
 * Filters every child in `elems` recursively and assembles the surviving children into a result node.
 *
 * Surviving children are collected in `buffer`, starting at `offset0`. Each recursive call is given the
 * current write position as its own `offset0`, so it only touches slots at or beyond that position;
 * the child's result is stored only after the recursive call has returned.
 *
 * @return `null` when nothing survives, `this` when the surviving element count equals `size0`,
 *         the single surviving child when it is not itself a `HashTrieSet`,
 *         and otherwise a new `HashTrieSet` built from the surviving children
 */
override protected def filter0(p: A => Boolean, negate: Boolean, level: Int, buffer: Array[HashSet[A]], offset0: Int): HashSet[A] = {
  // next free slot in the shared buffer
  var writePos = offset0
  // number of elements contained in the surviving children
  var survivingSize = 0
  // bit n is set iff the n-th child of this node survived
  var keptMask = 0
  var childIdx = 0
  while (childIdx < elems.length) {
    val filtered = elems(childIdx).filter0(p, negate, level + 5, buffer, writePos)
    if (filtered ne null) {
      buffer(writePos) = filtered
      writePos += 1
      survivingSize += filtered.size
      keptMask |= (1 << childIdx)
    }
    childIdx += 1
  }
  val keptChildren = writePos - offset0
  if (keptChildren == 0) {
    // every child was filtered away
    null
  } else if (survivingSize == size0) {
    // nothing was removed, so the existing node can be shared
    this
  } else if (keptChildren == 1 && !buffer(offset0).isInstanceOf[HashTrieSet[A]]) {
    // a lone non-trie child is returned directly instead of being wrapped
    buffer(offset0)
  } else {
    val elems1 = new Array[HashSet[A]](keptChildren)
    System.arraycopy(buffer, offset0, elems1, 0, keptChildren)
    // when no child was dropped the original bitmap still applies; otherwise restrict it to the kept children
    val bitmap1 = if (keptChildren == elems.length) bitmap else keepBits(bitmap, keptMask)
    new HashTrieSet(bitmap1, elems1, survivingSize)
  }
}

/**
 * Iterates over all elements by walking `elems` with a `TrieIterator`; every element handed to `getElem`
 * is cast to `HashSet1` and contributes its `key`.
 */
override def iterator = new TrieIterator[A](elems.asInstanceOf[Array[Iterable[A]]]) {
final override def getElem(cc: AnyRef): A = cc.asInstanceOf[HashSet1[A]].key
}
Expand All @@ -405,6 +482,50 @@ object HashSet extends ImmutableSetFactory[HashSet] {
}
}

/**
 * Calculates the maximum buffer size given the maximum possible total size of the trie-based collection
 *
 * The result is `size + 6`, capped at `32 * 7`. The cap is checked before the addition, so a `size` within 6 of
 * `Int.MaxValue` cannot wrap around into a negative array length.
 * NOTE(review): the cap presumably stands for 32 children per node times 7 trie levels — confirm.
 *
 * @param size the maximum size of the collection to be generated
 * @return the maximum buffer size
 */
@inline private def bufferSize(size: Int): Int = {
  val cap = 32 * 7
  if (size >= cap - 6) cap else size + 6
}

/**
 * Converts the internal `null` representation of an empty result into a real empty set.
 *
 * Internal operations return `null` instead of the empty set for performance reasons; public methods
 * pass their result through this method before handing it to callers.
 */
@inline private def nullToEmpty[A](s: HashSet[A]): HashSet[A] =
  if (s ne null) s
  else empty[A]

/**
 * Utility method to keep a subset of all bits in a given bitmap
 *
 * Bit `n` of `keep` decides whether the `n`-th lowest one-bit of `bitmap` is retained.
 *
 * Example
 *    bitmap (binary): 00000001000000010000000100000001
 *    keep (binary):                               1010
 *    result (binary): 00000001000000000000000100000000
 *
 * @param bitmap the bitmap
 * @param keep a bitmask containing which bits to keep
 * @return the original bitmap with all bits where keep is not 1 set to 0
 */
private def keepBits(bitmap: Int, keep: Int): Int = {
  @scala.annotation.tailrec
  def go(bitsLeft: Int, selector: Int, out: Int): Int =
    if (selector == 0) out
    else {
      // isolate the lowest one-bit still present in the bitmap (0 once the bitmap is exhausted)
      val lowBit = bitsLeft & -bitsLeft
      val out1 = if ((selector & 1) != 0) out | lowBit else out
      // drop that bit from the bitmap and advance to the next bit of the selector
      go(bitsLeft ^ lowBit, selector >>> 1, out1)
    }
  go(bitmap, keep, 0)
}

@SerialVersionUID(2L) private class SerializationProxy[A,B](@transient private var orig: HashSet[A]) extends Serializable {
private def writeObject(out: java.io.ObjectOutputStream) {
val s = orig.size
Expand Down
68 changes: 68 additions & 0 deletions test/files/run/t6196.scala
@@ -0,0 +1,68 @@
import scala.collection.immutable.HashSet

/**
 * Regression test for SI-6196: `HashSet.filter` / `HashSet.filterNot`.
 *
 * Checks that the results are correct (with and without hash collisions), that an unchanged set is
 * returned as the very same instance, and that filtering calls neither `hashCode` nor `equals`.
 */
object Test extends App {

  // Keys with a deliberately coarse hash code: five consecutive values share one hash,
  // which forces the collision handling of HashSet to be exercised.
  case class Collision(value: Int) extends Ordered[Collision] {
    def compare(that: Collision) = value compare that.value

    override def hashCode = value / 5
  }

  // Builds a set of n keys and, for every split point i, checks that filter keeps exactly the keys
  // below mkKey(i) and filterNot keeps exactly the remaining ones.
  def testCorrectness[T : Ordering](n: Int, mkKey: Int => T): Unit = {
    val o = implicitly[Ordering[T]]
    val s = HashSet.empty[T] ++ (0 until n).map(mkKey)
    for (i <- 0 until n) {
      val ki = mkKey(i)
      val a = s.filter(o.lt(_, ki))
      val b = s.filterNot(o.lt(_, ki))
      require(a.size == i && (0 until i).forall(j => a.contains(mkKey(j))))
      require(b.size == n - i && (i until n).forall(j => b.contains(mkKey(j))))
    }
  }

  // this tests the structural sharing of the new filter
  // I could not come up with a simple test that tests structural sharing when only parts are reused, but
  // at least this fails with the old and passes with the new implementation
  def testSharing(): Unit = {
    val s = HashSet.empty[Int] ++ (0 until 100)
    require(s.filter(_ => true) eq s)
    require(s.filterNot(_ => false) eq s)
  }

  // this tests that neither hashCode nor equals are called during filter
  def testNoHashing(): Unit = {
    var hashCount = 0
    var equalsCount = 0
    case class HashCounter(value: Int) extends Ordered[HashCounter] {
      def compare(that: HashCounter) = value compare that.value

      override def hashCode = {
        hashCount += 1
        value
      }

      // compare against the argument (not against `this`), so that equality is a real value comparison
      override def equals(that: Any) = {
        equalsCount += 1
        that match {
          case HashCounter(otherValue) => this.value == otherValue
          case _ => false
        }
      }
    }

    val s = HashSet.empty[HashCounter] ++ (0 until 100).map(HashCounter(_))
    val hashCount0 = hashCount
    val equalsCount0 = equalsCount
    val t = s.filter(_ < HashCounter(50))
    require(hashCount == hashCount0)
    require(equalsCount == equalsCount0)
    // the filter must also have produced the expected result, not merely avoided hashing
    require(t.size == 50)
  }

  // this tests correctness of filter and filterNot for integer keys
  testCorrectness[Int](100, identity _)
  // this tests correctness of filter and filterNot for keys with lots of collisions
  // this is necessary because usually collisions are rare so the collision-related code is not thoroughly tested
  testCorrectness[Collision](100, Collision.apply _)
  testSharing()
  testNoHashing()
}

0 comments on commit dddf1f5

Please sign in to comment.