Skip to content

Commit

Permalink
Custom implementations of drop/take/slice.
Browse files Browse the repository at this point in the history
This mainly helps performance when comparing keys is expensive.
  • Loading branch information
erikrozendaal committed Jan 23, 2012
1 parent 7824dbd commit 78374f3
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 12 deletions.
39 changes: 38 additions & 1 deletion src/library/scala/collection/immutable/RedBlackTree.scala
Expand Up @@ -56,6 +56,10 @@ object RedBlackTree {
def to[A: Ordering, B](tree: Tree[A, B], to: A): Tree[A, B] = blacken(doTo(tree, to))
def until[A: Ordering, B](tree: Tree[A, B], key: A): Tree[A, B] = blacken(doUntil(tree, key))

def drop[A: Ordering, B](tree: Tree[A, B], n: Int): Tree[A, B] = blacken(doDrop(tree, n))
def take[A: Ordering, B](tree: Tree[A, B], n: Int): Tree[A, B] = blacken(doTake(tree, n))
def slice[A: Ordering, B](tree: Tree[A, B], from: Int, until: Int): Tree[A, B] = blacken(doSlice(tree, from, until))

def smallest[A, B](tree: Tree[A, B]): Tree[A, B] = {
if (tree eq null) throw new NoSuchElementException("empty map")
var result = tree
Expand Down Expand Up @@ -86,7 +90,7 @@ object RedBlackTree {

@tailrec
def nth[A, B](tree: Tree[A, B], n: Int): Tree[A, B] = {
val count = RedBlackTree.count(tree.left)
val count = this.count(tree.left)
if (n < count) nth(tree.left, n)
else if (n > count) nth(tree.right, n - count - 1)
else tree
Expand Down Expand Up @@ -243,6 +247,39 @@ object RedBlackTree {
else rebalance(tree, newLeft, newRight)
}

private[this] def doDrop[A: Ordering, B](tree: Tree[A, B], n: Int): Tree[A, B] = {
if (n <= 0) return tree
if (n >= this.count(tree)) return null
val count = this.count(tree.left)
if (n > count) return doDrop(tree.right, n - count - 1)
val newLeft = doDrop(tree.left, n)
if (newLeft eq tree.left) tree
else if (newLeft eq null) upd(tree.right, tree.key, tree.value)
else rebalance(tree, newLeft, tree.right)
}
private[this] def doTake[A: Ordering, B](tree: Tree[A, B], n: Int): Tree[A, B] = {
if (n <= 0) return null
if (n >= this.count(tree)) return tree
val count = this.count(tree.left)
if (n <= count) return doTake(tree.left, n)
val newRight = doTake(tree.right, n - count - 1)
if (newRight eq tree.right) tree
else if (newRight eq null) upd(tree.left, tree.key, tree.value)
else rebalance(tree, tree.left, newRight)
}
private[this] def doSlice[A: Ordering, B](tree: Tree[A, B], from: Int, until: Int): Tree[A, B] = {
if (tree eq null) return null
val count = this.count(tree.left)
if (from > count) return doSlice(tree.right, from - count - 1, until - count - 1)
if (until <= count) return doSlice(tree.left, from, until)
val newLeft = doDrop(tree.left, from)
val newRight = doTake(tree.right, until - count - 1)
if ((newLeft eq tree.left) && (newRight eq tree.right)) tree
else if (newLeft eq null) upd(newRight, tree.key, tree.value)
else if (newRight eq null) upd(newLeft, tree.key, tree.value)
else rebalance(tree, newLeft, newRight)
}

// The zipper returned might have been traversed left-most (always the left child)
// or right-most (always the right child). Left trees are traversed right-most,
// and right trees are traversed leftmost.
Expand Down
6 changes: 3 additions & 3 deletions src/library/scala/collection/immutable/TreeMap.scala
Expand Up @@ -89,20 +89,20 @@ class TreeMap[A, +B] private (tree: RB.Tree[A, B])(implicit val ordering: Orderi
override def drop(n: Int) = {
if (n <= 0) this
else if (n >= size) empty
else from(RB.nth(tree, n).key)
else new TreeMap(RB.drop(tree, n))
}

override def take(n: Int) = {
if (n <= 0) empty
else if (n >= size) this
else until(RB.nth(tree, n).key)
else new TreeMap(RB.take(tree, n))
}

override def slice(from: Int, until: Int) = {
if (until <= from) empty
else if (from <= 0) take(until)
else if (until >= size) drop(from)
else range(RB.nth(tree, from).key, RB.nth(tree, until).key)
else new TreeMap(RB.slice(tree, from, until))
}

override def dropRight(n: Int) = take(size - n)
Expand Down
6 changes: 3 additions & 3 deletions src/library/scala/collection/immutable/TreeSet.scala
Expand Up @@ -66,20 +66,20 @@ class TreeSet[A] private (tree: RB.Tree[A, Unit])(implicit val ordering: Orderin
override def drop(n: Int) = {
if (n <= 0) this
else if (n >= size) empty
else from(RB.nth(tree, n).key)
else newSet(RB.drop(tree, n))
}

override def take(n: Int) = {
if (n <= 0) empty
else if (n >= size) this
else until(RB.nth(tree, n).key)
else newSet(RB.take(tree, n))
}

override def slice(from: Int, until: Int) = {
if (until <= from) empty
else if (from <= 0) take(until)
else if (until >= size) drop(from)
else range(RB.nth(tree, from).key, RB.nth(tree, until).key)
else newSet(RB.slice(tree, from, until))
}

override def dropRight(n: Int) = take(size - n)
Expand Down
18 changes: 15 additions & 3 deletions test/files/scalacheck/treemap.scala
Expand Up @@ -7,11 +7,12 @@ import util._
import Buildable._

object Test extends Properties("TreeMap") {
implicit def arbTreeMap[A : Arbitrary : Ordering, B : Arbitrary]: Arbitrary[TreeMap[A, B]] =
Arbitrary(for {
def genTreeMap[A: Arbitrary: Ordering, B: Arbitrary]: Gen[TreeMap[A, B]] =
for {
keys <- listOf(arbitrary[A])
values <- listOfN(keys.size, arbitrary[B])
} yield TreeMap(keys zip values: _*))
} yield TreeMap(keys zip values: _*)
implicit def arbTreeMap[A : Arbitrary : Ordering, B : Arbitrary] = Arbitrary(genTreeMap[A, B])

property("foreach/iterator consistency") = forAll { (subject: TreeMap[Int, String]) =>
val it = subject.iterator
Expand Down Expand Up @@ -96,6 +97,17 @@ object Test extends Properties("TreeMap") {
prefix == subject.take(n) && suffix == subject.drop(n)
}

def genSliceParms = for {
tree <- genTreeMap[Int, String]
from <- choose(0, tree.size)
until <- choose(from, tree.size)
} yield (tree, from, until)

property("slice") = forAll(genSliceParms) { case (subject, from, until) =>
val slice = subject.slice(from, until)
slice.size == until - from && subject.toSeq == subject.take(from).toSeq ++ slice ++ subject.drop(until)
}

property("takeWhile") = forAll { (subject: TreeMap[Int, String]) =>
val result = subject.takeWhile(_._1 < 0)
result.forall(_._1 < 0) && result == subject.take(result.size)
Expand Down
18 changes: 16 additions & 2 deletions test/files/scalacheck/treeset.scala
Expand Up @@ -6,8 +6,11 @@ import Arbitrary._
import util._

object Test extends Properties("TreeSet") {
implicit def arbTreeSet[A : Arbitrary : Ordering]: Arbitrary[TreeSet[A]] =
Arbitrary(listOf(arbitrary[A]) map (elements => TreeSet(elements: _*)))
def genTreeSet[A: Arbitrary: Ordering]: Gen[TreeSet[A]] =
for {
elements <- listOf(arbitrary[A])
} yield TreeSet(elements: _*)
implicit def arbTreeSet[A : Arbitrary : Ordering]: Arbitrary[TreeSet[A]] = Arbitrary(genTreeSet)

property("foreach/iterator consistency") = forAll { (subject: TreeSet[Int]) =>
val it = subject.iterator
Expand Down Expand Up @@ -92,6 +95,17 @@ object Test extends Properties("TreeSet") {
prefix == subject.take(n) && suffix == subject.drop(n)
}

def genSliceParms = for {
tree <- genTreeSet[Int]
from <- choose(0, tree.size)
until <- choose(from, tree.size)
} yield (tree, from, until)

property("slice") = forAll(genSliceParms) { case (subject, from, until) =>
val slice = subject.slice(from, until)
slice.size == until - from && subject.toSeq == subject.take(from).toSeq ++ slice ++ subject.drop(until)
}

property("takeWhile") = forAll { (subject: TreeSet[Int]) =>
val result = subject.takeWhile(_ < 0)
result.forall(_ < 0) && result == subject.take(result.size)
Expand Down

0 comments on commit 78374f3

Please sign in to comment.