From 8f6e522778f23e5b9a51d0b938ed6d20ea0f0d85 Mon Sep 17 00:00:00 2001 From: NthPortal Date: Wed, 2 Sep 2020 14:42:40 -0400 Subject: [PATCH] [bug#12009] Make ListBuffer's iterator fail-fast Make `ListBuffer`'s iterator fail-fast when the buffer is mutated after the iterator's creation. Co-authored-by: Jason Zaugg --- build.sbt | 5 + .../nsc/transform/async/AnfTransform.scala | 10 +- .../scala/collection/mutable/ListBuffer.scala | 41 +++- .../collection/mutable/MutationTracker.scala | 78 +++++++ .../mutable/ConstructionBenchmark.scala | 49 ++++ .../mutable/ListBufferBenchmark.scala | 14 ++ test/files/run/t8153.check | 1 - test/files/run/t8153.scala | 14 -- .../mutable/MutationTrackerTest.scala | 32 +++ .../mutable/MutationTrackingTest.scala | 211 ++++++++++++++++++ .../tools/nsc/scaladoc/HtmlFactoryTest.scala | 21 +- 11 files changed, 435 insertions(+), 41 deletions(-) create mode 100644 src/library/scala/collection/mutable/MutationTracker.scala create mode 100644 test/benchmarks/src/main/scala/scala/collection/mutable/ConstructionBenchmark.scala delete mode 100644 test/files/run/t8153.check delete mode 100644 test/files/run/t8153.scala create mode 100644 test/junit/scala/collection/mutable/MutationTrackerTest.scala create mode 100644 test/junit/scala/collection/mutable/MutationTrackingTest.scala diff --git a/build.sbt b/build.sbt index 7be330920e52..c82dbe148dd1 100644 --- a/build.sbt +++ b/build.sbt @@ -451,6 +451,11 @@ val mimaFilterSettings = Seq { // this is safe because the default cannot be used; instead the single-param overload in // `IterableOnceOps` is chosen (https://github.com/scala/scala/pull/9232#discussion_r501554458) ProblemFilters.exclude[DirectMissingMethodProblem]("scala.collection.immutable.ArraySeq.copyToArray$default$2"), + + // Fix for scala/bug#12009 + ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.MutationTracker"), + ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.MutationTracker$"), + ProblemFilters.exclude[MissingClassProblem]("scala.collection.mutable.MutationTracker$CheckedIterator"), ), } diff --git a/src/compiler/scala/tools/nsc/transform/async/AnfTransform.scala b/src/compiler/scala/tools/nsc/transform/async/AnfTransform.scala index 9489de6464b3..3d6a550b08d0 100644 --- a/src/compiler/scala/tools/nsc/transform/async/AnfTransform.scala +++ b/src/compiler/scala/tools/nsc/transform/async/AnfTransform.scala @@ -136,19 +136,15 @@ private[async] trait AnfTransform extends TransformUtils { } case ValDef(mods, name, tpt, rhs) => atOwner(tree.symbol) { - // Capture current cursor of a non-empty `stats` buffer so we can efficiently restrict the + // Capture size of `stats` buffer so we can efficiently restrict the // `changeOwner` to the newly added items... - var statsIterator = if (currentStats.isEmpty) null else currentStats.iterator + val oldItemsCount = currentStats.length val expr = atOwner(currentOwner.owner)(transform(rhs)) - // But, ListBuffer.empty.iterator doesn't reflect later mutation. Luckily we can just start - // from the beginning of the buffer - if (statsIterator == null) statsIterator = currentStats.iterator - // Definitions within stats lifted out of the `ValDef` rhs should no longer be owned by the // the ValDef. - statsIterator.foreach(_.changeOwner((currentOwner, currentOwner.owner))) + currentStats.iterator.drop(oldItemsCount).foreach(_.changeOwner((currentOwner, currentOwner.owner))) val expr1 = if (isUnitType(expr.tpe)) { currentStats += expr literalBoxedUnit diff --git a/src/library/scala/collection/mutable/ListBuffer.scala b/src/library/scala/collection/mutable/ListBuffer.scala index 395086a00ef9..21211e8ac11a 100644 --- a/src/library/scala/collection/mutable/ListBuffer.scala +++ b/src/library/scala/collection/mutable/ListBuffer.scala @@ -35,6 +35,7 @@ import scala.runtime.Statics.releaseFence * @define mayNotTerminateInf * @define willNotTerminateInf */ +@SerialVersionUID(-8428291952499836345L) class ListBuffer[A] extends AbstractBuffer[A] with SeqOps[A, ListBuffer, ListBuffer[A]] @@ -42,6 +43,7 @@ class ListBuffer[A] with ReusableBuilder[A, immutable.List[A]] with IterableFactoryDefaults[A, ListBuffer] with DefaultSerializable { + @transient private[this] var mutationCount: Int = 0 private var first: List[A] = Nil private var last0: ::[A] = null @@ -50,7 +52,7 @@ class ListBuffer[A] private type Predecessor[A0] = ::[A0] /*| Null*/ - def iterator = first.iterator + def iterator: Iterator[A] = new MutationTracker.CheckedIterator(first.iterator, mutationCount) override def iterableFactory: SeqFactory[ListBuffer] = ListBuffer @@ -69,7 +71,12 @@ class ListBuffer[A] aliased = false } - private def ensureUnaliased() = if (aliased) copyElems() + // we only call this before mutating things, so it's + // a good place to track mutations for the iterator + private def ensureUnaliased(): Unit = { + mutationCount += 1 + if (aliased) copyElems() + } // Avoids copying where possible. override def toList: List[A] = { @@ -97,6 +104,7 @@ class ListBuffer[A] } def clear(): Unit = { + mutationCount += 1 first = Nil len = 0 last0 = null @@ -301,15 +309,17 @@ class ListBuffer[A] } def mapInPlace(f: A => A): this.type = { - ensureUnaliased() + mutationCount += 1 val buf = new ListBuffer[A] for (elem <- this) buf += f(elem) first = buf.first last0 = buf.last0 + aliased = false // we just assigned from a new instance this } def flatMapInPlace(f: A => IterableOnce[A]): this.type = { + mutationCount += 1 var src = first var dst: List[A] = null last0 = null @@ -325,6 +335,7 @@ class ListBuffer[A] src = src.tail } first = if(dst eq null) Nil else dst + aliased = false // we just rebuilt a fresh, unaliased instance this } @@ -348,12 +359,24 @@ class ListBuffer[A] } def patchInPlace(from: Int, patch: collection.IterableOnce[A], replaced: Int): this.type = { - val i = math.min(math.max(from, 0), length) - val n = math.min(math.max(replaced, 0), length) - ensureUnaliased() - val p = locate(i) - removeAfter(p, math.min(n, len - i)) - insertAfter(p, patch.iterator) + val _len = len + val _from = math.max(from, 0) // normalized + val _replaced = math.max(replaced, 0) // normalized + val it = patch.iterator + + val nonEmptyPatch = it.hasNext + val nonEmptyReplace = (_from < _len) && (_replaced > 0) + + // don't want to add a mutation or check aliasing (potentially expensive) + // if there's no patching to do + if (nonEmptyPatch || nonEmptyReplace) { + ensureUnaliased() + val i = math.min(_from, _len) + val n = math.min(_replaced, _len) + val p = locate(i) + removeAfter(p, math.min(n, _len - i)) + insertAfter(p, it) + } this } diff --git a/src/library/scala/collection/mutable/MutationTracker.scala b/src/library/scala/collection/mutable/MutationTracker.scala new file mode 100644 index 000000000000..e98536d0dad5 --- /dev/null +++ b/src/library/scala/collection/mutable/MutationTracker.scala @@ -0,0 +1,78 @@ +/* + * Scala (https://www.scala-lang.org) + * + * Copyright EPFL and Lightbend, Inc. + * + * Licensed under Apache License 2.0 + * (http://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package scala +package collection +package mutable + +import java.util.ConcurrentModificationException + +/** + * Utilities to check that mutations to a client that tracks + * its mutations have not occurred since a given point. + * [[Iterator `Iterator`]]s that perform this check automatically + * during iteration can be created by wrapping an `Iterator` + * in a [[MutationTracker.CheckedIterator `CheckedIterator`]], + * or by manually using the [[MutationTracker.checkMutations() `checkMutations`]] + * and [[MutationTracker.checkMutationsForIteration() `checkMutationsForIteration`]] + * methods. + */ +private object MutationTracker { + + /** + * Checks whether or not the actual mutation count differs from + * the expected one, throwing an exception, if it does. + * + * @param expectedCount the expected mutation count + * @param actualCount the actual mutation count + * @param message the exception message in case of mutations + * @throws ConcurrentModificationException if the expected and actual + * mutation counts differ + */ + @throws[ConcurrentModificationException] + def checkMutations(expectedCount: Int, actualCount: Int, message: String): Unit = { + if (actualCount != expectedCount) throw new ConcurrentModificationException(message) + } + + /** + * Checks whether or not the actual mutation count differs from + * the expected one, throwing an exception, if it does. This method + * produces an exception message saying that it was called because a + * backing collection was mutated during iteration. + * + * @param expectedCount the expected mutation count + * @param actualCount the actual mutation count + * @throws ConcurrentModificationException if the expected and actual + * mutation counts differ + */ + @throws[ConcurrentModificationException] + @inline def checkMutationsForIteration(expectedCount: Int, actualCount: Int): Unit = + checkMutations(expectedCount, actualCount, "mutation occurred during iteration") + + /** + * An iterator wrapper that checks if the underlying collection has + * been mutated. + * + * @param underlying the underlying iterator + * @param mutationCount a by-name provider of the current mutation count + * @tparam A the type of the iterator's elements + */ + final class CheckedIterator[A](underlying: Iterator[A], mutationCount: => Int) extends AbstractIterator[A] { + private[this] val expectedCount = mutationCount + + def hasNext: Boolean = { + checkMutationsForIteration(expectedCount, mutationCount) + underlying.hasNext + } + def next(): A = underlying.next() + } +} diff --git a/test/benchmarks/src/main/scala/scala/collection/mutable/ConstructionBenchmark.scala b/test/benchmarks/src/main/scala/scala/collection/mutable/ConstructionBenchmark.scala new file mode 100644 index 000000000000..4771f8efc829 --- /dev/null +++ b/test/benchmarks/src/main/scala/scala/collection/mutable/ConstructionBenchmark.scala @@ -0,0 +1,49 @@ +/* + * Scala (https://www.scala-lang.org) + * + * Copyright EPFL and Lightbend, Inc. + * + * Licensed under Apache License 2.0 + * (http://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package scala.collection +package mutable + +import java.util.concurrent.TimeUnit + +import org.openjdk.jmh.annotations._ +import org.openjdk.jmh.infra._ + +@BenchmarkMode(Array(Mode.AverageTime)) +@Fork(2) +@Threads(1) +@Warmup(iterations = 20) +@Measurement(iterations = 20) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@State(Scope.Benchmark) +class ConstructionBenchmark { + @Param(Array("0", "1", "10", "100")) + var size: Int = _ + + var values: Range = _ + + @Setup(Level.Trial) def init(): Unit = { + values = 1 to size + } + + @Benchmark def listBuffer_new: Any = { + new ListBuffer ++= values + } + + @Benchmark def listBuffer_from: Any = { + ListBuffer from values + } + + @Benchmark def listBuffer_to: Any = { + values to ListBuffer + } +} diff --git a/test/benchmarks/src/main/scala/scala/collection/mutable/ListBufferBenchmark.scala b/test/benchmarks/src/main/scala/scala/collection/mutable/ListBufferBenchmark.scala index 8d4a6e1fc62a..a2e0999c939f 100644 --- a/test/benchmarks/src/main/scala/scala/collection/mutable/ListBufferBenchmark.scala +++ b/test/benchmarks/src/main/scala/scala/collection/mutable/ListBufferBenchmark.scala @@ -98,4 +98,18 @@ class ListBufferBenchmark { b.flatMapInPlace { _ => seq } bh.consume(b) } + + @Benchmark def iteratorA(bh: Blackhole): Unit = { + val b = ref.clone() + var n = 0 + for (x <- b.iterator) n += x + bh.consume(n) + bh.consume(b) + } + + @Benchmark def iteratorB(bh: Blackhole): Unit = { + val b = ref.clone() + bh.consume(b.iterator.toVector) + bh.consume(b) + } } diff --git a/test/files/run/t8153.check b/test/files/run/t8153.check deleted file mode 100644 index 0cfbf08886fc..000000000000 --- a/test/files/run/t8153.check +++ /dev/null @@ -1 +0,0 @@ -2 diff --git a/test/files/run/t8153.scala b/test/files/run/t8153.scala deleted file mode 100644 index f3063bdc7bfc..000000000000 --- a/test/files/run/t8153.scala +++ /dev/null @@ -1,14 +0,0 @@ -object Test { - def f() = { - val lb = scala.collection.mutable.ListBuffer[Int](1, 2) - val it = lb.iterator - if (it.hasNext) it.next() - val xs = lb.toList - lb += 3 - it.mkString - } - - def main(args: Array[String]): Unit = { - println(f()) - } -} diff --git a/test/junit/scala/collection/mutable/MutationTrackerTest.scala b/test/junit/scala/collection/mutable/MutationTrackerTest.scala new file mode 100644 index 000000000000..ccdf8712c9ed --- /dev/null +++ b/test/junit/scala/collection/mutable/MutationTrackerTest.scala @@ -0,0 +1,32 @@ +/* + * Scala (https://www.scala-lang.org) + * + * Copyright EPFL and Lightbend, Inc. + * + * Licensed under Apache License 2.0 + * (http://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package scala.collection.mutable + +import java.util.ConcurrentModificationException + +import org.junit.Test + +import scala.tools.testkit.AssertUtil.assertThrows + +class MutationTrackerTest { + @Test + def checkedIterator(): Unit = { + var mutationCount = 0 + def it = new MutationTracker.CheckedIterator(List(1, 2, 3).iterator, mutationCount) + val it1 = it + it1.toList // does not throw + val it2 = it + mutationCount += 1 + assertThrows[ConcurrentModificationException](it2.toList, _ contains "iteration") + } +} diff --git a/test/junit/scala/collection/mutable/MutationTrackingTest.scala b/test/junit/scala/collection/mutable/MutationTrackingTest.scala new file mode 100644 index 000000000000..9f565b13a3e6 --- /dev/null +++ b/test/junit/scala/collection/mutable/MutationTrackingTest.scala @@ -0,0 +1,211 @@ +/* + * Scala (https://www.scala-lang.org) + * + * Copyright EPFL and Lightbend, Inc. + * + * Licensed under Apache License 2.0 + * (http://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package scala.collection +package mutable + +import java.util.ConcurrentModificationException + +import org.junit.Test + +import scala.annotation.nowarn +import scala.tools.testkit.AssertUtil.assertThrows + +abstract class MutationTrackingTest[+C <: Iterable[_]](factory: Factory[Int, C]) { + private def runOp(op: C => Any, viewOrIterator: C => IterableOnceOps[_, AnyConstr, _]): Unit = { + val coll = (factory.newBuilder += 1 += 2 += 3 += 4).result() + val it = viewOrIterator(coll) + op(coll) + it.foreach(_ => ()) + } + + private def runOpMaybeThrowing(op: C => Any, + throws: Boolean, + viewOrIterator: C => IterableOnceOps[_, AnyConstr, _]): Unit = { + if (throws) assertThrows[ConcurrentModificationException](runOp(op, viewOrIterator), _ contains "iteration") + else runOp(op, viewOrIterator) + } + + private def runOpForViewAndIterator(op: C => Any, throws: Boolean): Unit = { + runOp(op, _.view) // never throws + runOpMaybeThrowing(op, throws, _.iterator) + runOpMaybeThrowing(op, throws, _.view.iterator) + } + + /** Checks that no exception is thrown by an operation. */ + def checkFine(op: C => Any): Unit = runOpForViewAndIterator(op, throws = false) + + /** Checks that an exception is thrown by an operation. */ + def checkThrows(op: C => Any): Unit = runOpForViewAndIterator(op, throws = true) + + @Test + def nop(): Unit = checkFine { _ => () } + + @Test + def knownSize(): Unit = checkFine { _.knownSize } +} + +// mixins +object MutationTrackingTest { + type I = Iterable[_] + + trait ClearableTest { self: MutationTrackingTest[Clearable with I] => + @Test + def clear(): Unit = checkThrows { _.clear() } + } + + trait GrowableTest extends ClearableTest { self: MutationTrackingTest[Growable[Int] with I] => + @Test + def addOne(): Unit = checkThrows { _ += 1 } + + @Test + def addAll(): Unit = { + checkThrows { _ ++= Seq(2, 3) } + checkFine { _ ++= Nil } + } + } + + trait ShrinkableTest { self: MutationTrackingTest[Shrinkable[Int] with I] => + @Test + def subtractOne(): Unit = checkThrows { _ -= 1 } + + @Test + def subtractAll(): Unit = { + checkThrows { _ --= Seq(1, 2) } + checkFine { _ --= Nil } + } + } + + trait SeqTest { self: MutationTrackingTest[Seq[Int]] => + @Test + def update(): Unit = checkThrows { _(1) = 5 } + + @Test + @nowarn("cat=deprecation") + def transform(): Unit = checkThrows { _.transform(_ + 1) } + } + + trait BufferTest extends GrowableTest with ShrinkableTest with SeqTest { self: MutationTrackingTest[Buffer[Int]] => + @Test + def insert(): Unit = checkThrows { _.insert(0, 5) } + + @Test + def insertAll(): Unit = { + checkThrows { _.insertAll(1, Seq(1, 2, 3)) } + checkFine { _.insertAll(1, Nil) } + } + + @Test + def remove1(): Unit = checkThrows { _ remove 1 } + + @Test + def remove2(): Unit = { + checkThrows { _.remove(1, 2) } + checkFine { _.remove(1, 0) } + } + + @Test + def prepend(): Unit = checkThrows { _ prepend 5 } + + @Test + def prependAll(): Unit = { + checkThrows { _ prependAll Seq(1, 2, 3) } + checkFine { _ prependAll Nil } + } + + @Test + def patchInPlace(): Unit = { + checkThrows { _.patchInPlace(1, Seq(2, 3), 3) } + checkThrows { _.patchInPlace(1, Seq(2, 3), 0) } + checkThrows { _.patchInPlace(4, Seq(2, 3), 3) } + checkThrows { _.patchInPlace(1, Nil, 3) } + checkFine { _.patchInPlace(1, Nil, 0) } + checkFine { _.patchInPlace(4, Nil, 3) } + } + + @Test + def dropInPlace(): Unit = { + checkThrows { _ dropInPlace 2 } + checkFine { _ dropInPlace 0 } + } + + @Test + def dropRightInPlace(): Unit = { + checkThrows { _ dropRightInPlace 2 } + checkFine { _ dropRightInPlace 0 } + } + + @Test + def trimStart(): Unit = { + checkThrows { _ trimStart 2 } + checkFine { _ trimStart 0 } + } + + @Test + def trimEnd(): Unit = { + checkThrows { _ trimEnd 2 } + checkFine { _ trimEnd 0 } + } + + @Test + def takeInPlace(): Unit = { + checkThrows { _ takeInPlace 2 } + checkFine { _ takeInPlace 4 } + } + + @Test + def takeRightInPlace(): Unit = { + checkThrows { _ takeRightInPlace 2 } + checkFine { _ takeRightInPlace 4 } + } + + @Test + def sliceInPlace(): Unit = { + checkThrows { _.sliceInPlace(1, 3) } + checkFine { _.sliceInPlace(0, 4) } + } + + @Test + def dropWhileInPlace(): Unit = { + checkThrows { _.dropWhileInPlace(_ < 3) } + checkFine { _.dropWhileInPlace(_ => false) } + } + + @Test + def takeWhileInPlace(): Unit = { + checkThrows { _.takeWhileInPlace(_ < 3) } + checkFine { _.takeWhileInPlace(_ => true) } + } + + @Test + def padToInPlace(): Unit = { + checkThrows { _.padToInPlace(5, 0) } + checkFine { _.padToInPlace(2, 0) } + } + } +} + +// concrete tests +package MutationTrackingTestImpl { + import scala.collection.mutable.MutationTrackingTest._ + + class ListBufferTest extends MutationTrackingTest(ListBuffer) with BufferTest { + @Test + def mapInPlace(): Unit = checkThrows { _.mapInPlace(_ + 1) } + + @Test + def flatMapInPlace(): Unit = checkThrows { _.flatMapInPlace(i => (i + 1) :: Nil) } + + @Test + def filterInPlace(): Unit = checkThrows { _.filterInPlace(_ => true) } + } +} diff --git a/test/scalacheck/scala/tools/nsc/scaladoc/HtmlFactoryTest.scala b/test/scalacheck/scala/tools/nsc/scaladoc/HtmlFactoryTest.scala index 8002a6f8c85e..2673c7ea4d70 100644 --- a/test/scalacheck/scala/tools/nsc/scaladoc/HtmlFactoryTest.scala +++ b/test/scalacheck/scala/tools/nsc/scaladoc/HtmlFactoryTest.scala @@ -537,16 +537,17 @@ object HtmlFactoryTest extends Properties("HtmlFactory") { val files = createTemplates("basic.scala") //println(files) - property("class") = files.get("com/example/p1/Clazz.html").exists { page => - val html = toHtml(page) - - property("implicit conversion") = html contains """implicit """ - - property("gt4s") = html contains """title="gt4s: $colon$colon"""" - - property("gt4s of a deprecated method") = html contains """title="gt4s: $colon$colon$colon$colon. Deprecated: """ - - true + // class + { + val html = files.get("com/example/p1/Clazz.html") + .map(page => { lazy val s = toHtml(page); () => s }) + property("class") = html.map(_()).isDefined + def ifExists(op: String => Prop) = html.map(_()).fold(undecided)(op) + property("class:implicit conversion") = + ifExists(_ contains """implicit """) + property("class:gt4s") = ifExists(_ contains """title="gt4s: $colon$colon"""") + property("class:gt4s of a deprecated method") = + ifExists(_ contains """title="gt4s: $colon$colon$colon$colon. Deprecated: """) } property("package") = files.contains("com/example/p1/index.html")