From 366d7a1a49993db80ef1759f1a395196b3fc7920 Mon Sep 17 00:00:00 2001 From: Marissa | April <7505383+NthPortal@users.noreply.github.com> Date: Thu, 10 Nov 2022 09:15:23 -0500 Subject: [PATCH] Fix CVE-2022-36944 for `LazyList` Backport fix for CVE-2022-36944 from 2.13. Code copy-pasted in a browser. --- .../compat/immutable/LazyList.scala | 13 ++- .../scala/collection/LazyListGCTest.scala | 82 +++++++++++++++++++ 2 files changed, 88 insertions(+), 7 deletions(-) diff --git a/compat/src/main/scala-2.11_2.12/scala/collection/compat/immutable/LazyList.scala b/compat/src/main/scala-2.11_2.12/scala/collection/compat/immutable/LazyList.scala index 3e4d73ff..e3d9398d 100644 --- a/compat/src/main/scala-2.11_2.12/scala/collection/compat/immutable/LazyList.scala +++ b/compat/src/main/scala-2.11_2.12/scala/collection/compat/immutable/LazyList.scala @@ -33,7 +33,7 @@ import scala.collection.generic.{ SeqFactory } import scala.collection.immutable.{LinearSeq, NumericRange} -import scala.collection.mutable.{ArrayBuffer, Builder, StringBuilder} +import scala.collection.mutable.{Builder, StringBuilder} import scala.language.implicitConversions /** This class implements an immutable linked list that evaluates elements @@ -516,10 +516,6 @@ final class LazyList[+A] private (private[this] var lazyState: () => LazyList.St else newLL(stateFromIteratorConcatSuffix(prefix.toIterator)(state)) } else super.++:(prefix)(bf) - private def prependedAllToLL[B >: A](prefix: Traversable[B]): LazyList[B] = - if (knownIsEmpty) LazyList.from(prefix) - else newLL(stateFromIteratorConcatSuffix(prefix.toIterator)(state)) - /** @inheritdoc * * $preservesLaziness @@ -1512,14 +1508,17 @@ object LazyList extends SeqFactory[LazyList] { private[this] def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject() - val init = new ArrayBuffer[A] + val init = new mutable.ListBuffer[A] var initRead = false while (!initRead) in.readObject match { case SerializeEnd => initRead = true case a => init += a.asInstanceOf[A] } val tail = in.readObject().asInstanceOf[LazyList[A]] - coll = tail.prependedAllToLL(init) + // scala/scala#10118: caution that no code path can evaluate `tail.state` + // before the resulting LazyList is returned + val it = init.toList.iterator + coll = newLL(stateFromIteratorConcatSuffix(it)(tail.state)) } private[this] def readResolve(): Any = coll diff --git a/compat/src/test/scala-jvm/test/scala/collection/LazyListGCTest.scala b/compat/src/test/scala-jvm/test/scala/collection/LazyListGCTest.scala index 6b936739..d10d8172 100644 --- a/compat/src/test/scala-jvm/test/scala/collection/LazyListGCTest.scala +++ b/compat/src/test/scala-jvm/test/scala/collection/LazyListGCTest.scala @@ -125,4 +125,86 @@ class LazyListGCTest { def tapEach_takeRight_headOption_allowsGC(): Unit = { assertLazyListOpAllowsGC(_.tapEach(_).takeRight(2).headOption, _ => ()) } + + @Test + def serialization(): Unit = + if (scala.util.Properties.releaseVersion.exists(_.startsWith("2.12"))) { + import java.io._ + + def serialize(obj: AnyRef): Array[Byte] = { + val buffer = new ByteArrayOutputStream + val out = new ObjectOutputStream(buffer) + out.writeObject(obj) + buffer.toByteArray + } + + def deserialize(a: Array[Byte]): AnyRef = { + val in = new ObjectInputStream(new ByteArrayInputStream(a)) + in.readObject + } + + def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T] + + val l = LazyList.from(10) + + val ld1 = serializeDeserialize(l) + assertEquals(l.take(10).toList, ld1.take(10).toList) + + l.tail.head + val ld2 = serializeDeserialize(l) + assertEquals(l.take(10).toList, ld2.take(10).toList) + + LazyListGCTest.serializationForceCount = 0 + val u = LazyList + .from(10) + .map(x => { + LazyListGCTest.serializationForceCount += 1; x + }) + + def printDiff(): Unit = { + val a = serialize(u) + classOf[LazyList[_]] + .getDeclaredField("scala$collection$compat$immutable$LazyList$$stateEvaluated") + .setBoolean(u, true) + val b = serialize(u) + val i = a.zip(b).indexWhere(p => p._1 != p._2) + println("difference: ") + println(s"val from = ${a.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}") + println(s"val to = ${b.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}") + } + + // to update this test, comment-out `LazyList.writeReplace` and run `printDiff` + // printDiff() + + val from = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 0, 115, 114, 0, 33, 106, 97, + 118, 97, 46) + val to = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 1, 115, 114, 0, 33, 106, 97, + 118, 97, 46) + + assertEquals(LazyListGCTest.serializationForceCount, 0) + + u.head + assertEquals(LazyListGCTest.serializationForceCount, 1) + + val data = serialize(u) + var i = data.indexOfSlice(from) + to.foreach(x => { + data(i) = x; i += 1 + }) + + val ud1 = deserialize(data).asInstanceOf[LazyList[Int]] + + // this check failed before scala/scala#10118, deserialization triggered evaluation + assertEquals(LazyListGCTest.serializationForceCount, 1) + + ud1.tail.head + assertEquals(LazyListGCTest.serializationForceCount, 2) + + u.tail.head + assertEquals(LazyListGCTest.serializationForceCount, 3) + } +} + +object LazyListGCTest { + var serializationForceCount = 0 }