Skip to content
This repository
Browse code

Iterator.++ no longer blows the stack.

To my chagrin we still hadn't gotten this one. I took a new
approach which seems like a winner to me. Here's a benchmark:

/** Micro-benchmark for long left-nested chains of `Iterator.++`. */
object Test {
  /** Concatenates `n` single-element iterators with `++` via a left fold,
   *  then prints the sum of the resulting iterator.
   *
   *  Every appended iterator yields exactly one `1`, so the printed value is `n`.
   *
   *  @param n number of single-element iterators to chain together
   */
  def run(n: Int): Unit = {
    val chained = (1 to n).foldLeft(Iterator.empty: Iterator[Int])((res, _) => res ++ Iterator(1))
    // Method-call syntax rather than postfix `chained sum`: postfix operators
    // require the `scala.language.postfixOps` import.
    println(chained.sum)
  }

  /** Entry point: the first command-line argument is parsed as the chain length `n`. */
  def main(args: Array[String]): Unit = run(args(0).toInt)
}

Runtime before this commit for various n:

  500   0.403 real
  1000  0.911 real
  1500  2.351 real
  2000  5.298 real
  2500 10.184 real

Runtime after this commit, same n:

  500  0.346 real
  1000 0.359 real
  1500 0.368 real
  2000 0.379 real
  2500 0.390 real

In the test case I dial it up to 100000.
  • Loading branch information...
commit e3ddb2d7dff859c9fb81d34d1c9687f72321a713 1 parent 59d4998
Paul Phillips authored March 26, 2013
54  src/library/scala/collection/Iterator.scala
@@ -161,6 +161,41 @@ object Iterator {
161 161
     def hasNext = true
162 162
     def next = elem
163 163
   }
  164
+
  165
+  /** Avoid stack overflows when applying ++ to lots of iterators by
  166
+   *  flattening the unevaluated iterators out into a vector of closures.
  167
+   */
  168
+  private[scala] final class ConcatIterator[+A](initial: Vector[() => Iterator[A]]) extends Iterator[A] {
  169
+    // current set to null when all iterators are exhausted
  170
+    private[this] var current: Iterator[A] = Iterator.empty
  171
+    private[this] var queue: Vector[() => Iterator[A]] = initial
  172
+    // Advance current to the next non-empty iterator
  173
+    private[this] def advance(): Boolean = {
  174
+      if (queue.isEmpty) {
  175
+        current = null
  176
+        false
  177
+      }
  178
+      else {
  179
+        current = queue.head()
  180
+        queue = queue.tail
  181
+        current.hasNext || advance()
  182
+      }
  183
+    }
  184
+    def hasNext = (current ne null) && (current.hasNext || advance())
  185
+    def next()  = if (hasNext) current.next else Iterator.empty.next
  186
+
  187
+    override def ++[B >: A](that: => GenTraversableOnce[B]): Iterator[B] =
  188
+      new ConcatIterator(queue :+ (() => that.toIterator))
  189
+  }
  190
+
  191
+  private[scala] final class JoinIterator[+A](lhs: Iterator[A], that: => GenTraversableOnce[A]) extends Iterator[A] {
  192
+    private[this] lazy val rhs: Iterator[A] = that.toIterator
  193
+    def hasNext = lhs.hasNext || rhs.hasNext
  194
+    def next    = if (lhs.hasNext) lhs.next else rhs.next
  195
+
  196
+    override def ++[B >: A](that: => GenTraversableOnce[B]) =
  197
+      new ConcatIterator(Vector(() => this, () => that.toIterator))
  198
+  }
164 199
 }
165 200
 
166 201
 import Iterator.empty
@@ -338,24 +373,7 @@ trait Iterator[+A] extends TraversableOnce[A] {
338 373
    *  @usecase def ++(that: => Iterator[A]): Iterator[A]
339 374
    *    @inheritdoc
340 375
    */
341  
-  def ++[B >: A](that: => GenTraversableOnce[B]): Iterator[B] = new AbstractIterator[B] {
342  
-    // optimize a little bit to prevent n log n behavior.
343  
-    private var cur : Iterator[B] = self
344  
-    private var selfExhausted : Boolean = false
345  
-    // since that is by-name, make sure it's only referenced once -
346  
-    // if "val it = that" is inside the block, then hasNext on an empty
347  
-    // iterator will continually reevaluate it.  (ticket #3269)
348  
-    lazy val it = that.toIterator
349  
-    // the eq check is to avoid an infinite loop on "x ++ x"
350  
-    def hasNext = cur.hasNext || (!selfExhausted && {
351  
-      it.hasNext && {
352  
-        cur = it
353  
-        selfExhausted = true
354  
-        true
355  
-      }
356  
-    })
357  
-    def next() = { hasNext; cur.next() }
358  
-  }
  376
+  def ++[B >: A](that: => GenTraversableOnce[B]): Iterator[B] = new Iterator.JoinIterator(self, that)
359 377
 
360 378
   /** Creates a new iterator by applying a function to all values produced by this iterator
361 379
    *  and concatenating the results.
4  test/files/run/iterator-concat.check
... ...
@@ -0,0 +1,4 @@
  1
+100
  2
+1000
  3
+10000
  4
+100000
15  test/files/run/iterator-concat.scala
... ...
@@ -0,0 +1,15 @@
  1
+object Test {
  2
+  // Create `size` Function0s, each of which evaluates to an Iterator
  3
+  // which produces 1. Then fold them over ++ to get a single iterator,
  4
+  // which should sum to "size".
  5
+  def mk(size: Int): Iterator[Int] = {
  6
+    val closures = (1 to size).toList.map(x => (() => Iterator(1)))
  7
+    closures.foldLeft(Iterator.empty: Iterator[Int])((res, f) => res ++ f())
  8
+  }
  9
+  def main(args: Array[String]): Unit = {
  10
+    println(mk(100).sum)
  11
+    println(mk(1000).sum)
  12
+    println(mk(10000).sum)
  13
+    println(mk(100000).sum)
  14
+  }
  15
+}

0 notes on commit e3ddb2d

Please sign in to comment.
Something went wrong with that request. Please try again.