Skip to content

Commit

Permalink
Implements specialized subsetOf for HashSet
Browse files Browse the repository at this point in the history
Fixes SI-7326. This also adds a basic test for subsetOf that was missing before.
  • Loading branch information
rklaehn committed Jan 14, 2014
1 parent 9f0594c commit 24a227d
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 1 deletion.
85 changes: 84 additions & 1 deletion src/library/scala/collection/immutable/HashSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ package scala
package collection
package immutable

import scala.annotation.unchecked.{ uncheckedVariance => uV }
import generic._
import scala.collection.parallel.immutable.ParHashSet
import scala.collection.GenSet

/** This class implements immutable sets using a hash trie.
*
Expand Down Expand Up @@ -54,6 +54,30 @@ class HashSet[A] extends AbstractSet[A]

// Membership test: hash the element once, then look it up starting at level 0.
def contains(e: A): Boolean = {
  val elemHash = computeHash(e)
  get0(e, elemHash, 0)
}

/** Tests whether every element of this set is also contained in `that`.
  *
  * If `that` is itself a `HashSet`, the check is delegated to `subsetOf0`;
  * for any other kind of set the inherited generic implementation is used.
  *
  * @param that the set to test against
  * @return true if all elements of this set are contained in `that`
  */
override def subsetOf(that: GenSet[A]): Boolean = that match {
  case that: HashSet[A] =>
    // both this and that are top-level hash sets, so the specialized comparison starts at level 0
    subsetOf0(that, 0)
  case _ =>
    // not a HashSet: fall back to the generic implementation
    super.subsetOf(that)
}

/**
  * A specialized implementation of subsetOf for when both this and that are HashSet[A] and we can take advantage
  * of the tree structure of both operands and the precalculated hashcodes of the HashSet1 instances.
  *
  * @param that the other set
  * @param level the level of this and that hashset.
  *              The purpose of level is to keep track of how deep we are in the tree.
  *              We need this information for when we arrive at a leaf and have to call get0 on that.
  *              The value of level is 0 for a top-level HashSet and grows in increments of 5.
  * @return true if all elements of this set are contained in that set
  */
protected def subsetOf0(that: HashSet[A], level: Int): Boolean = {
  // The default implementation is for the empty set and returns true
  // because the empty set is a subset of all sets.
  true
}

// Adds a single element: the hash is computed once and the update starts at level 0.
override def + (e: A): HashSet[A] = {
  val elemHash = computeHash(e)
  updated0(e, elemHash, 0)
}

override def + (elem1: A, elem2: A, elems: A*): HashSet[A] =
Expand Down Expand Up @@ -136,6 +160,14 @@ object HashSet extends ImmutableSetFactory[HashSet] {
// A lookup succeeds only when both the given hash and the given key
// equal the ones stored in this node; level is not consulted.
override def get0(key: A, hash: Int, level: Int): Boolean =
  if (hash != this.hash) false
  else key == this.key

/** This node is a subset of `that` exactly when `that` contains this node's key.
  *
  * @param that  the other set (possibly a sub-node rather than a top-level set)
  * @param level the depth at which `that` is looked up
  * @return true if `that.get0` finds this node's key
  */
override def subsetOf0(that: HashSet[A], level: Int): Boolean = {
  // check if that contains this.key
  // we use get0 with our key and hash at the correct level instead of calling contains,
  // which would not work since that might not be a top-level HashSet
  // and in any case would be inefficient because it would require recalculating the hash code
  that.get0(key, hash, level)
}

override def updated0(key: A, hash: Int, level: Int): HashSet[A] =
if (hash == this.hash && key == this.key) this
else {
Expand All @@ -162,6 +194,14 @@ object HashSet extends ImmutableSetFactory[HashSet] {
// Only a matching hash can hit this node; in that case the key is looked up
// among the stored keys ks. level is not consulted.
override def get0(key: A, hash: Int, level: Int): Boolean =
  hash == this.hash && ks.contains(key)

/** This node is a subset of `that` when every key stored in `ks` is found in `that`.
  *
  * @param that  the other set (possibly a sub-node rather than a top-level set)
  * @param level the depth at which `that` is looked up
  * @return true if `that.get0` succeeds for all keys of this node
  */
override def subsetOf0(that: HashSet[A], level: Int): Boolean = {
  // we have to check each element
  // we use get0 with our hash at the correct level instead of calling contains,
  // which would not work since that might not be a top-level HashSet
  // and in any case would be inefficient because it would require recalculating the hash code
  ks.forall(key => that.get0(key, hash, level))
}

// A key with a different hash is combined with this node via makeHashTrieSet;
// a key with the same hash yields a new collision node whose key set is ks + key.
override def updated0(key: A, hash: Int, level: Int): HashSet[A] =
  if (hash != this.hash) {
    makeHashTrieSet(this.hash, this, hash, new HashSet1(key, hash), level)
  } else {
    new HashSetCollision1(hash, ks + key)
  }
Expand Down Expand Up @@ -279,6 +319,49 @@ object HashSet extends ImmutableSetFactory[HashSet] {
}
}

/**
* Structural subset check for a trie node.
* Returns true immediately when that is the same instance as this. Otherwise this can only be a subset
* if that is also a HashTrieSet with at least as many elements (size0) and with a bitmap that covers
* every bit set in this node's bitmap; in that case the children at matching bit positions are compared
* recursively, one level (5) deeper.
* @param that the other set or sub-node
* @param level the depth of this and that in their tries
* @return true if all elements below this node are contained in that
*/
override def subsetOf0(that: HashSet[A], level: Int): Boolean = if (that eq this) true else that match {
case that: HashTrieSet[A] if this.size0 <= that.size0 =>
// local working copies: the bitmaps are consumed bit by bit in the loop below,
// ai and bi are the current indices into the child arrays a and b
var abm = this.bitmap
val a = this.elems
var ai = 0
val b = that.elems
var bbm = that.bitmap
var bi = 0
// every bit set in abm must also be set in bbm
if ((abm & bbm) == abm) {
// I tried rewriting this using tail recursion, but the generated java byte code was less than optimal
while(abm!=0) {
// lowest remaining one bit in abm (x & (x - 1) clears the lowest one bit, the xor isolates it)
val alsb = abm ^ (abm & (abm - 1))
// lowest remaining one bit in bbm
val blsb = bbm ^ (bbm & (bbm - 1))
// if both trees have a bit set at the same position, we need to check the subtrees
if (alsb == blsb) {
// we are doing a comparison of a child of this with a child of that,
// so we have to increase the level by 5 to keep track of how deep we are in the tree
if (!a(ai).subsetOf0(b(bi), level + 5))
return false
// clear lowest remaining one bit in abm and increase the a index
abm &= ~alsb; ai += 1
}
// clear lowest remaining one bit in bbm and increase the b index
// we must do this in any case: either the child of that was just compared,
// or it sits at a position where this has no child and is skipped
// (presumably elems holds the children in ascending bit order, which the index stepping relies on)
bbm &= ~blsb; bi += 1
}
true
} else {
// the bitmap of this has a one bit at a position where the bitmap of that has none,
// so this can not possibly be a subset of that
false
}
case _ =>
// if the other set is a HashTrieSet but has fewer elements than this, this can not be a subset of it
// if the other set is a HashSet1, we can not be a subset of it because we are a HashTrieSet with at least two children
// (see assertion; NOTE(review): that assertion is not visible in this excerpt)
// if the other set is a HashSetCollision1, we can not be a subset of it because we are a HashTrieSet with at least two different hash codes
// if the other set is the empty set, we are not a subset of it because we are not empty
false
}

// Builds a TrieIterator over this node's children (elems, cast to Array[Iterable[A]]).
override def iterator = new TrieIterator[A](elems.asInstanceOf[Array[Iterable[A]]]) {
// casts the object it is given to HashSet1[A] and returns that node's key
// (presumably TrieIterator only passes single-element leaves here; confirm against TrieIterator)
final override def getElem(cc: AnyRef): A = cc.asInstanceOf[HashSet1[A]].key
}
Expand Down
64 changes: 64 additions & 0 deletions test/files/run/t7326.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import scala.collection.immutable.ListSet
import scala.collection.immutable.HashSet

/** Regression test for the specialized `HashSet.subsetOf`.
  *
  * Runs a correctness check over several set combinations and a check that
  * `subsetOf` between two hash sets does not invoke `hashCode` on the keys.
  * Any violated expectation fails with an `IllegalArgumentException` from `require`.
  */
object Test extends App {

  /** Verifies `subsetOf` for HashSet/HashSet and HashSet/ListSet pairs,
    * with ordinary Int keys and with keys that produce many hash collisions.
    */
  def testCorrectness(): Unit = {
    // a key that has many hashCode collisions
    case class Collision(i: Int) { override def hashCode: Int = i / 5 }

    /** Checks subset relations between sets built from the first i keys (a) and the first k <= i keys (b).
      *
      * @param emptyA empty set used to build the superset candidates
      * @param emptyB empty set used to build the subset candidates
      * @param mkKey  creates the key for a given index
      * @param n      largest number of elements to test
      */
    def subsetTest[T](emptyA: Set[T], emptyB: Set[T], mkKey: Int => T, n: Int): Unit = {
      // a key that is never part of a
      val outside = mkKey(n + 1)
      for (i <- 0 to n) {
        val a = emptyA ++ (0 until i).map(mkKey)
        // every set must be a subset of itself
        require(a.subsetOf(a), "A set must be the subset of itself")
        for (k <- 0 to i) {
          // k <= i, so b is definitely a subset
          val b = emptyB ++ (0 until k).map(mkKey)
          // c contains a value that is not in a, so it is not a subset of a,
          // but that is not immediately obvious from the sizes alone
          val c = b + outside
          require(b.subsetOf(a), s"$b must be a subset of $a")
          require(!c.subsetOf(a), s"$c must not be a subset of $a")
        }
      }
    }

    // test the HashSet/HashSet case
    subsetTest(HashSet.empty[Int], HashSet.empty[Int], identity, 100)

    // test the HashSet/other set case
    subsetTest(HashSet.empty[Int], ListSet.empty[Int], identity, 100)

    // test the HashSet/HashSet case for Collision keys
    subsetTest(HashSet.empty[Collision], HashSet.empty[Collision], Collision, 100)

    // test the HashSet/other set case for Collision keys
    subsetTest(HashSet.empty[Collision], ListSet.empty[Collision], Collision, 100)
  }

  /**
    * A main performance benefit of the new subsetOf is that we do not have to call hashCode during subsetOf
    * since we already have the hash codes in the HashSet1 nodes.
    *
    * @return the result of the subset check, which is also required to be true
    */
  def testNoHashCodeInvocationsDuringSubsetOf(): Boolean = {
    // number of hashCode invocations observed so far
    var count = 0

    case class HashCodeCounter(i: Int) {
      override def hashCode: Int = {
        count += 1
        i
      }
    }

    val a = HashSet.empty[HashCodeCounter] ++ (0 until 100).map(HashCodeCounter)
    val b = HashSet.empty[HashCodeCounter] ++ (0 until 50).map(HashCodeCounter)
    // building the sets calls hashCode; only calls made after this point count
    val count0 = count
    val result = b.subsetOf(a)
    require(count == count0, "key.hashCode must not be called during subsetOf of two HashSets")
    // b holds the first 50 of a's 100 keys, so the answer itself must be true as well
    require(result, "b must be a subset of a")
    result
  }

  testCorrectness()
  testNoHashCodeInvocationsDuringSubsetOf()
}

0 comments on commit 24a227d

Please sign in to comment.