Skip to content
This repository
Browse code

SI-6642 Refactor mutable.TreeSet to use RedBlackTree instead of AVL

There was no reason to have mutable.TreeSet use AVLTree while
immutable.TreeSet and immutable.TreeMap used RedBlackTree. In
particular that would have meant duplicating the iteratorFrom logic
unnecessarily. So this commit refactors mutable.TreeSet to use
RedBlackTree for everything, including iteratorFrom. It also adds
a test to make sure TreeSet works as expected.

AVLTree should be dead code since it's private[scala.collection.mutable]
and only used by mutable.TreeSet, but to be safe it's only deprecated
in this commit.
  • Loading branch information...
commit 39037798c94e6e862f39dacffc5e65bb08b78d6a 1 parent 62bc99d
James Iry authored February 12, 2013
21  src/library/scala/collection/immutable/RedBlackTree.scala
@@ -24,7 +24,7 @@ import scala.annotation.meta.getter
24 24
  *
25 25
  *  @since 2.10
26 26
  */
27  
-private[immutable]
  27
+private[collection]
28 28
 object RedBlackTree {
29 29
 
30 30
   def isEmpty(tree: Tree[_, _]): Boolean = tree eq null
@@ -44,6 +44,25 @@ object RedBlackTree {
44 44
   }
45 45
 
46 46
   def count(tree: Tree[_, _]) = if (tree eq null) 0 else tree.count
  47
+  /**
  48
+   * Count all the nodes with keys greater than or equal to the lower bound and less than the upper bound.
  49
+   * The two bounds are optional.
  50
+   */
  51
+  def countInRange[A, _](tree: Tree[A, _], from: Option[A], to:Option[A])(implicit ordering: Ordering[A]) : Int = 
  52
+    if (tree eq null) 0 else
  53
+    (from, to) match {
  54
+      // with no bounds use this node's count
  55
+      case (None, None) => tree.count
  56
+      // if node is less than the lower bound, try the tree on the right, it might be in range
  57
+      case (Some(lb), _) if ordering.lt(tree.key, lb) => countInRange(tree.right, from, to)
  58
+      // if node is greater than or equal to the upper bound, try the tree on the left, it might be in range
  59
+      case (_, Some(ub)) if ordering.gteq(tree.key, ub) => countInRange(tree.left, from, to)
  60
+      // node is in range so the tree on the left will all be less than the upper bound and the tree on the
  61
+      // right will all be greater than or equal to the lower bound. So 1 for this node plus
  62
+      // count the subtrees by stripping off the bounds that we don't need any more
  63
+      case _ => 1 + countInRange(tree.left, from, None) + countInRange(tree.right, None, to)
  64
+    
  65
+    }
47 66
   def update[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1, overwrite: Boolean)(implicit ordering: Ordering[A]): Tree[A, B1] = blacken(upd(tree, k, v, overwrite))
48 67
   def delete[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = blacken(del(tree, k))
49 68
   def rangeImpl[A: Ordering, B](tree: Tree[A, B], from: Option[A], until: Option[A]): Tree[A, B] = (from, until) match {
11  src/library/scala/collection/mutable/AVLTree.scala
@@ -15,7 +15,7 @@ package mutable
15 15
  * An immutable AVL Tree implementation used by mutable.TreeSet
16 16
  *
17 17
  * @author Lucien Pereira
18  
- *
  18
+ * @deprecated AVLTree and its related classes are being removed from the standard library since they're not different enough from RedBlackTree to justify keeping them. (Since version 2.11)
19 19
  */
20 20
 private[mutable] sealed trait AVLTree[+A] extends Serializable {
21 21
   def balance: Int
@@ -65,12 +65,18 @@ private[mutable] sealed trait AVLTree[+A] extends Serializable {
65 65
   def doubleRightRotation[B >: A]: Node[B] = sys.error("Should not happen.")
66 66
 }
67 67
 
  68
+/**
  69
+ * @deprecated AVLTree and its related classes are being removed from the standard library since they're not different enough from RedBlackTree to justify keeping them. (Since version 2.11)
  70
+ */
68 71
 private case object Leaf extends AVLTree[Nothing] {
69 72
   override val balance: Int = 0
70 73
 
71 74
   override val depth: Int = -1
72 75
 }
73 76
 
  77
+/**
  78
+ * @deprecated AVLTree and its related classes are being removed from the standard library since they're not different enough from RedBlackTree to justify keeping them. (Since version 2.11)
  79
+ */
74 80
 private case class Node[A](val data: A, val left: AVLTree[A], val right: AVLTree[A]) extends AVLTree[A] {
75 81
   override val balance: Int = right.depth - left.depth
76 82
 
@@ -205,6 +211,9 @@ private case class Node[A](val data: A, val left: AVLTree[A], val right: AVLTree
205 211
   }
206 212
 }
207 213
 
  214
+/**
  215
+ * @deprecated AVLTree and its related classes are being removed from the standard library since they're not different enough from RedBlackTree to justify keeping them. (Since version 2.11)
  216
+ */
208 217
 private class AVLIterator[A](root: Node[A]) extends Iterator[A] {
209 218
   val stack = mutable.ArrayStack[Node[A]](root)
210 219
   diveLeft()
125  src/library/scala/collection/mutable/TreeSet.scala
@@ -10,6 +10,8 @@ package scala.collection
10 10
 package mutable
11 11
 
12 12
 import generic._
  13
+import scala.collection.immutable.{RedBlackTree => RB}
  14
+import scala.runtime.ObjectRef
13 15
 
14 16
 /**
15 17
  * @define Coll `mutable.TreeSet`
@@ -29,112 +31,81 @@ object TreeSet extends MutableSortedSetFactory[TreeSet] {
29 31
 }
30 32
 
31 33
 /**
32  
- * A mutable SortedSet using an immutable AVL Tree as underlying data structure.
  34
+ * A mutable SortedSet using an immutable red-black tree as its underlying data structure.
33 35
  *
34 36
  * @author Lucien Pereira
35 37
  *
36 38
  */
37  
-class TreeSet[A](implicit val ordering: Ordering[A]) extends SortedSet[A] with SetLike[A, TreeSet[A]]
  39
+class TreeSet[A] private (treeRef: ObjectRef[RB.Tree[A, Null]], from: Option[A], until: Option[A])(implicit val ordering: Ordering[A]) 
  40
+  extends SortedSet[A] with SetLike[A, TreeSet[A]]
38 41
   with SortedSetLike[A, TreeSet[A]] with Set[A] with Serializable {
39 42
 
40  
-  // Projection constructor
41  
-  private def this(base: Option[TreeSet[A]], from: Option[A], until: Option[A])(implicit ordering: Ordering[A]) {
42  
-    this();
43  
-    this.base = base
44  
-    this.from = from
45  
-    this.until = until
46  
-  }
47  
-
48  
-  private var base: Option[TreeSet[A]] = None
49  
-
50  
-  private var from: Option[A] = None
51  
-
52  
-  private var until: Option[A] = None
53  
-
54  
-  private var avl: AVLTree[A] = Leaf
55  
-
56  
-  private var cardinality: Int = 0
  43
+  def this()(implicit ordering: Ordering[A]) = this(new ObjectRef(null), None, None)
57 44
 
58  
-  def resolve: TreeSet[A] = base.getOrElse(this)
59  
-
60  
-  private def isLeftAcceptable(from: Option[A], ordering: Ordering[A])(a: A): Boolean =
61  
-    from.map(x => ordering.gteq(a, x)).getOrElse(true)
62  
-
63  
-  private def isRightAcceptable(until: Option[A], ordering: Ordering[A])(a: A): Boolean =
64  
-    until.map(x => ordering.lt(a, x)).getOrElse(true)
65  
-
66  
-  /**
67  
-   * Cardinality store the set size, unfortunately a
68  
-   * set view (given by rangeImpl)
69  
-   * cannot take advantage of this optimisation
70  
-   *
71  
-   */
72  
-  override def size: Int = base.map(_ => super.size).getOrElse(cardinality)
  45
+  override def size: Int = RB.countInRange(treeRef.elem, from, until)
73 46
 
74 47
   override def stringPrefix = "TreeSet"
75 48
 
76 49
   override def empty: TreeSet[A] = TreeSet.empty
77 50
 
78  
-  override def rangeImpl(from: Option[A], until: Option[A]): TreeSet[A] = new TreeSet(Some(this), from, until)
  51
+  private def pickBound(comparison: (A, A) => A, oldBound: Option[A], newBound: Option[A]) = (newBound, oldBound) match {
  52
+    case (Some(newB), Some(oldB)) => Some(comparison(newB, oldB))
  53
+    case (None, _) => oldBound
  54
+    case _ => newBound
  55
+  }      
  56
+    
  57
+  override def rangeImpl(fromArg: Option[A], untilArg: Option[A]): TreeSet[A] = {
  58
+    val newFrom = pickBound(ordering.max, fromArg, from)
  59
+    val newUntil = pickBound(ordering.min, untilArg, until)
  60
+    
  61
+    new TreeSet(treeRef, newFrom, newUntil) 
  62
+  }
79 63
 
80 64
   override def -=(elem: A): this.type = {
81  
-    try {
82  
-      resolve.avl = resolve.avl.remove(elem, ordering)
83  
-      resolve.cardinality = resolve.cardinality - 1
84  
-    } catch {
85  
-      case e: NoSuchElementException => ()
86  
-    }
  65
+    treeRef.elem = RB.delete(treeRef.elem, elem)
87 66
     this
88 67
   }
89 68
 
90 69
   override def +=(elem: A): this.type = {
91  
-    try {
92  
-      resolve.avl = resolve.avl.insert(elem, ordering)
93  
-      resolve.cardinality = resolve.cardinality + 1
94  
-    } catch {
95  
-      case e: IllegalArgumentException => ()
96  
-    }
  70
+    treeRef.elem = RB.update(treeRef.elem, elem, null, false)
97 71
     this
98 72
   }
99 73
 
100 74
   /**
101 75
    * Thanks to the immutable nature of the
102  
-   * underlying AVL Tree, we can share it with
  76
+   * underlying Tree, we can share it with
103 77
    * the clone. So clone complexity in time is O(1).
104 78
    *
105 79
    */
106  
-  override def clone(): TreeSet[A] = {
107  
-    val clone = new TreeSet[A](base, from, until)
108  
-    clone.avl = resolve.avl
109  
-    clone.cardinality = resolve.cardinality
110  
-    clone
111  
-  }
  80
+  override def clone(): TreeSet[A] = 
  81
+    new TreeSet[A](new ObjectRef(treeRef.elem), from, until)
  82
+    
  83
+  private val notProjection = !(from.isDefined || until.isDefined)
112 84
 
113 85
   override def contains(elem: A): Boolean = {
114  
-    isLeftAcceptable(from, ordering)(elem) &&
115  
-    isRightAcceptable(until, ordering)(elem) &&
116  
-    resolve.avl.contains(elem, ordering)
  86
+    def leftAcceptable: Boolean = from match {
  87
+      case Some(lb) => ordering.gteq(elem, lb)
  88
+      case _ => true
  89
+    }
  90
+
  91
+    def rightAcceptable: Boolean = until match {
  92
+      case Some(ub) => ordering.lt(elem, ub)
  93
+      case _ => true
  94
+    }    
  95
+    
  96
+    (notProjection || (leftAcceptable && rightAcceptable)) &&
  97
+      RB.contains(treeRef.elem, elem)
117 98
   }
118 99
 
119  
-  // TODO see the discussion on keysIteratorFrom
120  
-  override def iterator: Iterator[A] = resolve.avl.iterator
121  
-    .dropWhile(e => !isLeftAcceptable(from, ordering)(e))
122  
-      .takeWhile(e => isRightAcceptable(until, ordering)(e))
  100
+  override def iterator: Iterator[A] = iteratorFrom(None)
123 101
   
124  
-  // TODO because TreeSets are potentially ranged views into other TreeSets
125  
-  // what this really needs to do is walk the whole stack of tree sets, find
126  
-  // the highest "from", and then do a tree walk of the underlying avl tree
127  
-  // to find that spot in max(O(stack depth), O(log tree.size)) time which
128  
-  // should effectively be O(log size) since ranged views are rare and
129  
-  // even more rarely deep. With the following implementation it's
130  
-  // O(N log N) to get an iterator from a start point.  
131  
-  // But before engaging that endeavor I think mutable.TreeSet should be
132  
-  // based on the same immutable RedBlackTree that immutable.TreeSet is
133  
-  // based on. There's no good reason to have these two collections based
134  
-  // on two different balanced binary trees. That'll save
135  
-  // having to duplicate logic for finding the starting point of a
136  
-  // sorted binary tree iterator, logic that has already been
137  
-  // baked into RedBlackTree.
138  
-  override def keysIteratorFrom(start: A) = from(start).iterator
139  
-
  102
+  override def keysIteratorFrom(start: A) = iteratorFrom(Some(start))
  103
+  
  104
+  private def iteratorFrom(start: Option[A]) = {
  105
+    val it = RB.keysIterator(treeRef.elem, pickBound(ordering.max, from, start))
  106
+    until match {
  107
+      case None => it
  108
+      case Some(ub) => it takeWhile (k => ordering.lt(k, ub))
  109
+    }
  110
+  }
140 111
 }
145  test/files/run/mutable-treeset.scala
... ...
@@ -0,0 +1,145 @@
  1
+import scala.collection.mutable.TreeSet
  2
+
  3
+object Test extends App {
  4
+  val list = List(6,5,4,3,2,1,1,2,3,4,5,6,6,5,4,3,2,1)
  5
+  val distinct = list.distinct
  6
+  val sorted = distinct.sorted
  7
+
  8
+  // sublist stuff for a single level of slicing
  9
+  val min = list.min
  10
+  val max = list.max
  11
+  val nonlist = ((min - 10) until (max + 20) filterNot list.contains).toList
  12
+  val sublist = list filter {x => x >=(min + 1) && x < max} 
  13
+  val distinctSublist = sublist.distinct 
  14
+  val subnonlist = min :: max :: nonlist
  15
+  val subsorted = distinctSublist.sorted
  16
+
  17
+  // subsublist for a 2nd level of slicing
  18
+  val almostmin = sublist.min
  19
+  val almostmax = sublist.max
  20
+  val subsublist = sublist filter {x => x >=(almostmin + 1) && x < almostmax} 
  21
+  val distinctSubsublist = subsublist.distinct 
  22
+  val subsubnonlist = almostmin :: almostmax :: subnonlist
  23
+  val subsubsorted = distinctSubsublist.sorted
  24
+
  25
+  def testSize {
  26
+    def check(set : TreeSet[Int], list: List[Int]) { 
  27
+      assert(set.size == list.size, s"$set had size ${set.size} while $list had size ${list.size}")
  28
+    }
  29
+
  30
+    check(TreeSet[Int](), List[Int]())
  31
+    val set = TreeSet(list:_*)
  32
+    check(set, distinct)
  33
+    check(set.clone, distinct)
  34
+
  35
+    val subset = set from (min + 1) until max
  36
+    check(subset, distinctSublist)
  37
+    check(subset.clone, distinctSublist)
  38
+
  39
+    val subsubset = subset from (almostmin + 1) until almostmax
  40
+    check(subsubset, distinctSubsublist)
  41
+    check(subsubset.clone, distinctSubsublist)
  42
+  }
  43
+
  44
+  def testContains {
  45
+    def check(set : TreeSet[Int], list: List[Int], nonlist: List[Int]) {
  46
+      assert(list forall set.apply, s"$set did not contain all elements of $list using apply")
  47
+      assert(list forall set.contains, s"$set did not contain all elements of $list using contains")
  48
+      assert(!(nonlist exists set.apply), s"$set had an element from $nonlist using apply")
  49
+      assert(!(nonlist exists set.contains), s"$set had an element from $nonlist using contains")      
  50
+    }
  51
+
  52
+    val set = TreeSet(list:_*)
  53
+    check(set, list, nonlist)
  54
+    check(set.clone, list, nonlist)
  55
+
  56
+    val subset = set from (min + 1) until max
  57
+    check(subset, sublist, subnonlist)
  58
+    check(subset.clone, sublist, subnonlist)
  59
+
  60
+    val subsubset = subset from (almostmin + 1) until almostmax
  61
+    check(subsubset, subsublist, subsubnonlist)
  62
+    check(subsubset.clone, subsublist, subsubnonlist)
  63
+  }
  64
+
  65
+  def testAdd {
  66
+    def check(set : TreeSet[Int], list: List[Int], nonlist: List[Int]) {
  67
+      var builtList = List[Int]()
  68
+      for (x <- list) {
  69
+        set += x
  70
+        builtList = (builtList :+ x).distinct.sorted filterNot nonlist.contains
  71
+        assert(builtList forall set.apply, s"$set did not contain all elements of $builtList using apply")
  72
+        assert(builtList.size == set.size, s"$set had size ${set.size} while $builtList had size ${builtList.size}")
  73
+      }
  74
+      assert(!(nonlist exists set.apply), s"$set had an element from $nonlist using apply")
  75
+      assert(!(nonlist exists set.contains), s"$set had an element from $nonlist using contains")      
  76
+    }
  77
+
  78
+    val set = TreeSet[Int]()
  79
+    val clone = set.clone    
  80
+    val subset = set.clone from (min + 1) until max
  81
+    val subclone = subset.clone
  82
+    val subsubset = subset.clone from (almostmin + 1) until almostmax
  83
+    val subsubclone = subsubset.clone
  84
+
  85
+    check(set, list, nonlist)
  86
+    check(clone, list, nonlist)
  87
+
  88
+    check(subset, list, subnonlist)
  89
+    check(subclone, list, subnonlist)
  90
+
  91
+    check(subsubset, list, subsubnonlist)
  92
+    check(subsubclone, list, subsubnonlist)
  93
+  }
  94
+
  95
+  def testRemove {
  96
+    def check(set: TreeSet[Int], sorted: List[Int]) {
  97
+      var builtList = sorted
  98
+      for (x <- list) {
  99
+        set remove x
  100
+        builtList = builtList filterNot (_ == x)
  101
+        assert(builtList forall set.apply, s"$set did not contain all elements of $builtList using apply")
  102
+        assert(builtList.size == set.size, s"$set had size $set.size while $builtList had size $builtList.size")
  103
+      }   
  104
+    }
  105
+    val set = TreeSet(list:_*)
  106
+    val clone = set.clone    
  107
+    val subset = set.clone from (min + 1) until max
  108
+    val subclone = subset.clone
  109
+    val subsubset = subset.clone from (almostmin + 1) until almostmax
  110
+    val subsubclone = subsubset.clone
  111
+
  112
+    check(set, sorted)
  113
+    check(clone, sorted)
  114
+
  115
+    check(subset, subsorted)
  116
+    check(subclone, subsorted)
  117
+    
  118
+    check(subsubset, subsubsorted)
  119
+    check(subsubclone, subsubsorted)
  120
+  }
  121
+
  122
+  def testIterator {
  123
+    def check(set: TreeSet[Int], list: List[Int]) {
  124
+      val it = set.iterator.toList
  125
+      assert(it == list, s"$it did not equal $list")
  126
+    }
  127
+    val set = TreeSet(list: _*)
  128
+    check(set, sorted)
  129
+    check(set.clone, sorted)
  130
+
  131
+    val subset = set from (min + 1) until max
  132
+    check(subset, subsorted)
  133
+    check(subset.clone, subsorted)
  134
+
  135
+    val subsubset = subset from (almostmin + 1) until almostmax
  136
+    check(subsubset, subsubsorted)
  137
+    check(subsubset.clone, subsubsorted)
  138
+  }
  139
+
  140
+  testSize
  141
+  testContains
  142
+  testAdd
  143
+  testRemove
  144
+  testIterator
  145
+}

0 notes on commit 3903779

Please sign in to comment.
Something went wrong with that request. Please try again.