From 5a2376b89f549e5f36275f9a1dc75be6b194776b Mon Sep 17 00:00:00 2001 From: Lukas Rytz Date: Tue, 4 Apr 2023 15:31:51 +0200 Subject: [PATCH] custom LinkedHashSet subclass for match analysis --- .../tools/nsc/transform/patmat/Logic.scala | 54 +++++++++++++------ .../collection/mutable/LinkedHashSet.scala | 5 ++ .../nsc/transform/patmat/SolvingTest.scala | 3 +- 3 files changed, 46 insertions(+), 16 deletions(-) diff --git a/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala b/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala index ec2c739f0c62..b5ea1d2f91d1 100644 --- a/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala +++ b/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala @@ -13,11 +13,11 @@ package scala package tools.nsc.transform.patmat -import scala.collection.mutable import scala.collection.immutable.ArraySeq -import scala.collection.IterableOps +import scala.collection.{IterableOps, mutable} import scala.reflect.internal.util.Collections._ import scala.reflect.internal.util.HashSet +import scala.tools.nsc.transform.patmat.Logic.LogicLinkedHashSet trait Logic extends Debugging { import global._ @@ -114,22 +114,21 @@ trait Logic extends Debugging { def implications: List[(Sym, List[Sym], List[Sym])] } - // The error message of t7020 assumes the ops are ordered implicitly - // However, ListSet is slow (concatenate cost?), so use - // scala.collection.mutable.LinkedHashSet (which grantees "the order in which elements were inserted into the set") + // Using LogicLinkedHashSet (a custom mutable.LinkedHashSet subclass) to ensure deterministic exhaustivity + // messages. immutable.ListSet was too slow (concatenate cost? scala/bug#12499). // would be nice to statically check whether a prop is equational or pure, // but that requires typing relations like And(x: Tx, y: Ty) : (if(Tx == PureProp && Ty == PureProp) PureProp else Prop) - final case class And(ops: mutable.LinkedHashSet[Prop]) extends Prop + final case class And(ops: LogicLinkedHashSet[Prop]) extends Prop object And { def apply(ps: Prop*) = create(ps) - def create(ps: Iterable[Prop]) = new And(ps.to(mutable.LinkedHashSet)) + def create(ps: Iterable[Prop]) = new And(ps.to(LogicLinkedHashSet)) } - final case class Or(ops: mutable.LinkedHashSet[Prop]) extends Prop + final case class Or(ops: LogicLinkedHashSet[Prop]) extends Prop object Or { def apply(ps: Prop*) = create(ps) - def create(ps: Iterable[Prop]) = new Or(ps.to(mutable.LinkedHashSet)) + def create(ps: Iterable[Prop]) = new Or(ps.to(LogicLinkedHashSet)) } final case class Not(a: Prop) extends Prop @@ -286,7 +285,7 @@ trait Logic extends Debugging { def simplifyAnd(ps: Iterable[Prop]): Prop = { // recurse for nested And (pulls all Ands up) // build up Set in order to remove duplicates - val props = mutable.LinkedHashSet.empty[Prop] + val props = LogicLinkedHashSet.empty[Prop] for (prop <- ps) { simplifyProp(prop) match { case True => // ignore `True` @@ -302,7 +301,7 @@ trait Logic extends Debugging { def simplifyOr(ps: Iterable[Prop]): Prop = { // recurse for nested Or (pulls all Ors up) // build up Set in order to remove duplicates - val props = mutable.LinkedHashSet.empty[Prop] + val props = LogicLinkedHashSet.empty[Prop] for (prop <- ps) { simplifyProp(prop) match { case False => // ignore `False` @@ -343,7 +342,7 @@ trait Logic extends Debugging { } def gatherVariables(p: Prop): collection.Set[Var] = { - val vars = new mutable.LinkedHashSet[Var]() + val vars = new LogicLinkedHashSet[Var]() (new PropTraverser { override def applyVar(v: Var) = vars += v })(p) @@ -351,7 +350,7 @@ trait Logic extends Debugging { } def gatherSymbols(p: Prop): collection.Set[Sym] = { - val syms = new mutable.LinkedHashSet[Sym]() + val syms = new LogicLinkedHashSet[Sym]() (new PropTraverser { override def applySymbol(s: Sym) = syms += s })(p) @@ -409,7 +408,7 @@ trait Logic extends Debugging { def removeVarEq(props: List[Prop], modelNull: Boolean = false): (Prop, List[Prop]) = { val start = if (settings.areStatisticsEnabled) statistics.startTimer(statistics.patmatAnaVarEq) else null - val vars = new mutable.LinkedHashSet[Var] + val vars = new LogicLinkedHashSet[Var] object gatherEqualities extends PropTraverser { override def apply(p: Prop) = p match { @@ -516,6 +515,31 @@ trait Logic extends Debugging { } } +object Logic { + import scala.annotation.nowarn + import scala.collection.mutable.{Growable, GrowableBuilder, SetOps} + import scala.collection.{IterableFactory, IterableFactoryDefaults, StrictOptimizedIterableOps} + + // Local subclass because we can't override `addAll` in the collections (bin compat), see PR scala/scala#10361 + @nowarn("msg=inheritance from class LinkedHashSet") + class LogicLinkedHashSet[A] extends mutable.LinkedHashSet[A] + with SetOps[A, LogicLinkedHashSet, LogicLinkedHashSet[A]] + with StrictOptimizedIterableOps[A, LogicLinkedHashSet, LogicLinkedHashSet[A]] + with IterableFactoryDefaults[A, LogicLinkedHashSet] { + override def iterableFactory: IterableFactory[LogicLinkedHashSet] = LogicLinkedHashSet + override def addAll(xs: IterableOnce[A]): this.type = { + sizeHint(xs.knownSize) + super.addAll(xs) + } + } + + object LogicLinkedHashSet extends IterableFactory[LogicLinkedHashSet] { + override def from[A](source: IterableOnce[A]): LogicLinkedHashSet[A] = Growable.from(empty[A], source) + override def empty[A]: LogicLinkedHashSet[A] = new LogicLinkedHashSet[A] + override def newBuilder[A]: mutable.Builder[A, LogicLinkedHashSet[A]] = new GrowableBuilder(empty[A]) + } +} + trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis { trait TreesAndTypesDomain extends PropositionalLogic with CheckableTreeAndTypeAnalysis { type Type = global.Type @@ -733,8 +757,8 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis { } - import global.{ConstantType, SingletonType, Literal, Ident, singleType, TypeBounds, NoSymbol} import global.definitions._ + import global.{ConstantType, Ident, Literal, NoSymbol, SingletonType, TypeBounds, singleType} // all our variables range over types diff --git a/src/library/scala/collection/mutable/LinkedHashSet.scala b/src/library/scala/collection/mutable/LinkedHashSet.scala index 684a80cbc978..41724de9a661 100644 --- a/src/library/scala/collection/mutable/LinkedHashSet.scala +++ b/src/library/scala/collection/mutable/LinkedHashSet.scala @@ -82,6 +82,11 @@ class LinkedHashSet[A] def contains(elem: A): Boolean = findEntry(elem) ne null + override def sizeHint(size: Int): Unit = { + val target = tableSizeFor(((size + 1).toDouble / LinkedHashSet.defaultLoadFactor).toInt) + if (target > table.length) growTable(target) + } + override def add(elem: A): Boolean = { if (contentSize + 1 >= threshold) growTable(table.length * 2) val hash = computeHash(elem) diff --git a/test/junit/scala/tools/nsc/transform/patmat/SolvingTest.scala b/test/junit/scala/tools/nsc/transform/patmat/SolvingTest.scala index 40643b2af887..fa7ce9091d96 100644 --- a/test/junit/scala/tools/nsc/transform/patmat/SolvingTest.scala +++ b/test/junit/scala/tools/nsc/transform/patmat/SolvingTest.scala @@ -5,6 +5,7 @@ import org.junit.Test import scala.collection.mutable import scala.reflect.internal.util.Position +import scala.tools.nsc.transform.patmat.Logic.LogicLinkedHashSet import scala.tools.nsc.{Global, Settings} object TestSolver extends Logic with Solving { @@ -579,7 +580,7 @@ class SolvingTest { def pairWiseEncoding(ops: List[Sym]) = { And(ops.combinations(2).collect { case a :: b :: Nil => Or(Not(a), Not(b)): Prop - }.to(mutable.LinkedHashSet)) + }.to(LogicLinkedHashSet)) } @Test