Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/library/scala/collection/immutable/LazyList.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import java.lang.{StringBuilder => JStringBuilder}

import scala.annotation.tailrec
import scala.collection.generic.SerializeEnd
import scala.collection.mutable.{ArrayBuffer, Builder, ReusableBuilder, StringBuilder}
import scala.collection.mutable.{Builder, ReusableBuilder, StringBuilder}
import scala.language.implicitConversions
import scala.runtime.Statics

Expand Down Expand Up @@ -1353,7 +1353,7 @@ object LazyList extends SeqFactory[LazyList] {
private[this] def writeObject(out: ObjectOutputStream): Unit = {
out.defaultWriteObject()
var these = coll
while(these.knownNonEmpty) {
while (these.knownNonEmpty) {
out.writeObject(these.head)
these = these.tail
}
Expand All @@ -1363,14 +1363,17 @@ object LazyList extends SeqFactory[LazyList] {

private[this] def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
val init = new ArrayBuffer[A]
val init = new mutable.ListBuffer[A]
var initRead = false
while (!initRead) in.readObject match {
case SerializeEnd => initRead = true
case a => init += a.asInstanceOf[A]
}
val tail = in.readObject().asInstanceOf[LazyList[A]]
coll = init ++: tail
// scala/scala#10118: caution that no code path can evaluate `tail.state`
// before the resulting LazyList is returned
val it = init.toList.iterator
coll = newLL(stateFromIteratorConcatSuffix(it)(tail.state))
}

private[this] def readResolve(): Any = coll
Expand Down
73 changes: 72 additions & 1 deletion test/junit/scala/collection/immutable/LazyListTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,79 @@ import org.junit.Assert._

import scala.annotation.unused
import scala.collection.mutable.{Builder, ListBuffer}
import scala.tools.testkit.AssertUtil
import scala.tools.testkit.{AssertUtil, ReflectUtil}
import scala.util.Try

@RunWith(classOf[JUnit4])
class LazyListTest {

@Test
def serialization(): Unit = {
import java.io._

def serialize(obj: AnyRef): Array[Byte] = {
val buffer = new ByteArrayOutputStream
val out = new ObjectOutputStream(buffer)
out.writeObject(obj)
buffer.toByteArray
}

def deserialize(a: Array[Byte]): AnyRef = {
val in = new ObjectInputStream(new ByteArrayInputStream(a))
in.readObject
}

def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T]

val l = LazyList.from(10)

val ld1 = serializeDeserialize(l)
assertEquals(l.take(10).toList, ld1.take(10).toList)

l.tail.head
val ld2 = serializeDeserialize(l)
assertEquals(l.take(10).toList, ld2.take(10).toList)

LazyListTest.serializationForceCount = 0
val u = LazyList.from(10).map(x => { LazyListTest.serializationForceCount += 1; x })

@unused def printDiff(): Unit = {
val a = serialize(u)
ReflectUtil.getFieldAccessible[LazyList[_]]("scala$collection$immutable$LazyList$$stateEvaluated").setBoolean(u, true)
val b = serialize(u)
val i = a.zip(b).indexWhere(p => p._1 != p._2)
println("difference: ")
println(s"val from = ${a.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
println(s"val to = ${b.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
}

// to update this test, comment-out `LazyList.writeReplace` and run `printDiff`
// printDiff()

val from = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 0, 115, 114, 0, 33, 106, 97, 118, 97, 46)
val to = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 1, 115, 114, 0, 33, 106, 97, 118, 97, 46)

assertEquals(LazyListTest.serializationForceCount, 0)

u.head
assertEquals(LazyListTest.serializationForceCount, 1)

val data = serialize(u)
var i = data.indexOfSlice(from)
to.foreach(x => {data(i) = x; i += 1})

val ud1 = deserialize(data).asInstanceOf[LazyList[Int]]

// this check failed before scala/scala#10118, deserialization triggered evaluation
assertEquals(LazyListTest.serializationForceCount, 1)

ud1.tail.head
assertEquals(LazyListTest.serializationForceCount, 2)

u.tail.head
assertEquals(LazyListTest.serializationForceCount, 3)
}

@Test
def t6727_and_t6440_and_8627(): Unit = {
assertTrue(LazyList.continually(()).filter(_ => true).take(2) == Seq((), ()))
Expand Down Expand Up @@ -378,3 +445,7 @@ class LazyListTest {
assertEquals(1, count)
}
}

object LazyListTest {
var serializationForceCount = 0
}