Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SI-6584, Stream#distinct uses too much memory. #1535

Merged
merged 2 commits into from Nov 1, 2012
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 36 additions & 9 deletions src/library/scala/collection/immutable/Stream.scala
Expand Up @@ -181,6 +181,7 @@ import scala.language.implicitConversions
* @define coll stream
* @define orderDependent
* @define orderDependentFold
* @define willTerminateInf Note: lazily evaluated; will terminate for infinite-sized collections.
*/
abstract class Stream[+A] extends AbstractSeq[A]
with LinearSeq[A]
Expand Down Expand Up @@ -286,9 +287,8 @@ self =>
len
}

/** It's an imperfect world, but at least we can bottle up the
* imperfection in a capsule.
*/
// It's an imperfect world, but at least we can bottle up the
// imperfection in a capsule.
@inline private def asThat[That](x: AnyRef): That = x.asInstanceOf[That]
@inline private def asStream[B](x: AnyRef): Stream[B] = x.asInstanceOf[Stream[B]]
@inline private def isStreamBuilder[B, That](bf: CanBuildFrom[Stream[A], B, That]) =
Expand Down Expand Up @@ -725,10 +725,15 @@ self =>
* // produces: "5, 6, 7, 8, 9"
* }}}
*/
override def take(n: Int): Stream[A] =
override def take(n: Int): Stream[A] = (
// Note that the n == 1 condition appears redundant but is not.
// It prevents "tail" from being referenced (and its head being evaluated)
// when obtaining the last element of the result. Such are the challenges
// of working with a lazy-but-not-really sequence.
if (n <= 0 || isEmpty) Stream.empty
else if (n == 1) cons(head, Stream.empty)
else cons(head, tail take n-1)
)

@tailrec final override def drop(n: Int): Stream[A] =
if (n <= 0 || isEmpty) this
Expand Down Expand Up @@ -784,8 +789,23 @@ self =>
these
}

// there's nothing we can do about dropRight, so we just keep the definition
// in LinearSeq
/**
* @inheritdoc
* $willTerminateInf
*/
override def dropRight(n: Int): Stream[A] = {
// We make dropRight work for possibly infinite streams by carrying
// a buffer of the dropped size. As long as the buffer is full and the
// rest is non-empty, we can feed elements off the buffer head. When
// the rest becomes empty, the full buffer is the dropped elements.
def advance(stub0: List[A], stub1: List[A], rest: Stream[A]): Stream[A] = {
if (rest.isEmpty) Stream.empty
else if (stub0.isEmpty) advance(stub1.reverse, Nil, rest)
else cons(stub0.head, advance(stub0.tail, rest.head :: stub1, rest.tail))
}
if (n <= 0) this
else advance((this take n).toList, Nil, this drop n)
}

/** Returns the longest prefix of this `Stream` whose elements satisfy the
* predicate `p`.
Expand Down Expand Up @@ -841,9 +861,16 @@ self =>
* // produces: "1, 2, 3, 4, 5, 6"
* }}}
*/
override def distinct: Stream[A] =
if (isEmpty) this
else cons(head, tail.filter(head != _).distinct)
override def distinct: Stream[A] = {
// This should use max memory proportional to N, whereas
// recursively calling distinct on the tail is N^2.
def loop(seen: Set[A], rest: Stream[A]): Stream[A] = {
if (rest.isEmpty) rest
else if (seen(rest.head)) loop(seen, rest.tail)
else cons(rest.head, loop(seen + rest.head, rest.tail))
}
loop(Set(), this)
}

/** Returns a new sequence of given length containing the elements of this
* sequence followed by zero or more occurrences of given elements.
Expand Down
1 change: 1 addition & 0 deletions test/files/run/streams.check
Expand Up @@ -23,3 +23,4 @@ Stream(100001, ?)
true
true
705082704
6
5 changes: 4 additions & 1 deletion test/files/run/streams.scala
Expand Up @@ -29,7 +29,7 @@ object Test extends App {
def powers(x: Int) = if ((x&(x-1)) == 0) Some(x) else None
println(s3.flatMap(powers).reverse.head)

// large enough to generate StackOverflows (on most systems)
// large enough to generate StackOverflows (on most systems)
// unless the following methods are tail call optimized.
val size = 100000

Expand All @@ -43,4 +43,7 @@ object Test extends App {
println(Stream.from(1).take(size).foldLeft(0)(_ + _))
val arr = new Array[Int](size)
Stream.from(1).take(size).copyToArray(arr, 0)

// dropRight terminates
println(Stream from 1 dropRight 1000 take 3 sum)
}
8 changes: 8 additions & 0 deletions test/files/run/t6584.check
@@ -0,0 +1,8 @@
Array: 102400
Vector: 102400
List: 102400
Stream: 102400
Array: 102400
Vector: 102400
List: 102400
Stream: 102400
16 changes: 16 additions & 0 deletions test/files/run/t6584.scala
@@ -0,0 +1,16 @@
object Test {
def main(args: Array[String]): Unit = {
val size = 100 * 1024
val doubled = (1 to size) ++ (1 to size)

println("Array: " + Array.tabulate(size)(x => x).distinct.size)
println("Vector: " + Vector.tabulate(size)(x => x).distinct.size)
println("List: " + List.tabulate(size)(x => x).distinct.size)
println("Stream: " + Stream.tabulate(size)(x => x).distinct.size)

println("Array: " + doubled.toArray.distinct.size)
println("Vector: " + doubled.toVector.distinct.size)
println("List: " + doubled.toList.distinct.size)
println("Stream: " + doubled.toStream.distinct.size)
}
}