From 0fb2d9d78f5ebaf887e42cd8db767a2edad33fef Mon Sep 17 00:00:00 2001 From: Lukas Rytz Date: Tue, 21 Jun 2022 13:57:50 +0200 Subject: [PATCH] Prevent Function0 execution during LazyList deserialization This PR ensures that LazyList deserialization will not execute an arbitrary Function0 when being passed a forged serialization stream. See the PR description for a detailed explanation. --- .../scala/collection/immutable/LazyList.scala | 9 +++++++- .../collection/immutable/LazyListTest.scala | 23 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/library/scala/collection/immutable/LazyList.scala b/src/library/scala/collection/immutable/LazyList.scala index dde413bd91ce..53ba85dfdac9 100644 --- a/src/library/scala/collection/immutable/LazyList.scala +++ b/src/library/scala/collection/immutable/LazyList.scala @@ -249,6 +249,13 @@ final class LazyList[+A] private(private[this] var lazyState: () => LazyList.Sta @inline private def stateDefined: Boolean = stateEvaluated private[this] var midEvaluation = false + private def withNullLazyState[T](f: => T): T = { + val saved = lazyState + lazyState = null + try f + finally lazyState = saved + } + private lazy val state: State[A] = { // if it's already mid-evaluation, we're stuck in an infinite // self-referential loop (also it's empty) @@ -1370,7 +1377,7 @@ object LazyList extends SeqFactory[LazyList] { case a => init += a.asInstanceOf[A] } val tail = in.readObject().asInstanceOf[LazyList[A]] - coll = init ++: tail + coll = tail.withNullLazyState(tail.prependedAll(init)) } private[this] def readResolve(): Any = coll diff --git a/test/junit/scala/collection/immutable/LazyListTest.scala b/test/junit/scala/collection/immutable/LazyListTest.scala index 58798ec4cb9d..947ba788cff5 100644 --- a/test/junit/scala/collection/immutable/LazyListTest.scala +++ b/test/junit/scala/collection/immutable/LazyListTest.scala @@ -14,6 +14,29 @@ import scala.util.Try @RunWith(classOf[JUnit4]) class LazyListTest { + @Test + def serialization(): Unit = { + import java.io._ + + def serialize(obj: AnyRef): Array[Byte] = { + val buffer = new ByteArrayOutputStream + val out = new ObjectOutputStream(buffer) + out.writeObject(obj) + buffer.toByteArray + } + def deserialize(a: Array[Byte]): AnyRef = { + val in = new ObjectInputStream(new ByteArrayInputStream(a)) + in.readObject + } + + def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T] + + val l = LazyList.from(10); l.tail.head + val ld = serializeDeserialize(l) + ld.tail.tail.head + println(ld) + } + @Test def t6727_and_t6440_and_8627(): Unit = { assertTrue(LazyList.continually(()).filter(_ => true).take(2) == Seq((), ()))