Skip to content

Commit

Permalink
custom LinkedHashSet subclass for match analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
lrytz committed Apr 17, 2023
1 parent 2fa3ee2 commit 5a2376b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 16 deletions.
54 changes: 39 additions & 15 deletions src/compiler/scala/tools/nsc/transform/patmat/Logic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
package scala
package tools.nsc.transform.patmat

import scala.collection.mutable
import scala.collection.immutable.ArraySeq
import scala.collection.IterableOps
import scala.collection.{IterableOps, mutable}
import scala.reflect.internal.util.Collections._
import scala.reflect.internal.util.HashSet
import scala.tools.nsc.transform.patmat.Logic.LogicLinkedHashSet

trait Logic extends Debugging {
import global._
Expand Down Expand Up @@ -114,22 +114,21 @@ trait Logic extends Debugging {
def implications: List[(Sym, List[Sym], List[Sym])]
}

// The error message of t7020 assumes the ops are ordered implicitly
// However, ListSet is slow (concatenate cost?), so use
// scala.collection.mutable.LinkedHashSet (which grantees "the order in which elements were inserted into the set")
// Using LogicLinkedHashSet (a custom mutable.LinkedHashSet subclass) to ensure deterministic exhaustivity
// messages. immutable.ListSet was too slow (concatenate cost? scala/bug#12499).

// would be nice to statically check whether a prop is equational or pure,
// but that requires typing relations like And(x: Tx, y: Ty) : (if(Tx == PureProp && Ty == PureProp) PureProp else Prop)
final case class And(ops: mutable.LinkedHashSet[Prop]) extends Prop
final case class And(ops: LogicLinkedHashSet[Prop]) extends Prop
object And {
def apply(ps: Prop*) = create(ps)
def create(ps: Iterable[Prop]) = new And(ps.to(mutable.LinkedHashSet))
def create(ps: Iterable[Prop]) = new And(ps.to(LogicLinkedHashSet))
}

final case class Or(ops: mutable.LinkedHashSet[Prop]) extends Prop
final case class Or(ops: LogicLinkedHashSet[Prop]) extends Prop
object Or {
def apply(ps: Prop*) = create(ps)
def create(ps: Iterable[Prop]) = new Or(ps.to(mutable.LinkedHashSet))
def create(ps: Iterable[Prop]) = new Or(ps.to(LogicLinkedHashSet))
}

final case class Not(a: Prop) extends Prop
Expand Down Expand Up @@ -286,7 +285,7 @@ trait Logic extends Debugging {
def simplifyAnd(ps: Iterable[Prop]): Prop = {
// recurse for nested And (pulls all Ands up)
// build up Set in order to remove duplicates
val props = mutable.LinkedHashSet.empty[Prop]
val props = LogicLinkedHashSet.empty[Prop]
for (prop <- ps) {
simplifyProp(prop) match {
case True => // ignore `True`
Expand All @@ -302,7 +301,7 @@ trait Logic extends Debugging {
def simplifyOr(ps: Iterable[Prop]): Prop = {
// recurse for nested Or (pulls all Ors up)
// build up Set in order to remove duplicates
val props = mutable.LinkedHashSet.empty[Prop]
val props = LogicLinkedHashSet.empty[Prop]
for (prop <- ps) {
simplifyProp(prop) match {
case False => // ignore `False`
Expand Down Expand Up @@ -343,15 +342,15 @@ trait Logic extends Debugging {
}

def gatherVariables(p: Prop): collection.Set[Var] = {
val vars = new mutable.LinkedHashSet[Var]()
val vars = new LogicLinkedHashSet[Var]()
(new PropTraverser {
override def applyVar(v: Var) = vars += v
})(p)
vars
}

def gatherSymbols(p: Prop): collection.Set[Sym] = {
val syms = new mutable.LinkedHashSet[Sym]()
val syms = new LogicLinkedHashSet[Sym]()
(new PropTraverser {
override def applySymbol(s: Sym) = syms += s
})(p)
Expand Down Expand Up @@ -409,7 +408,7 @@ trait Logic extends Debugging {
def removeVarEq(props: List[Prop], modelNull: Boolean = false): (Prop, List[Prop]) = {
val start = if (settings.areStatisticsEnabled) statistics.startTimer(statistics.patmatAnaVarEq) else null

val vars = new mutable.LinkedHashSet[Var]
val vars = new LogicLinkedHashSet[Var]

object gatherEqualities extends PropTraverser {
override def apply(p: Prop) = p match {
Expand Down Expand Up @@ -516,6 +515,31 @@ trait Logic extends Debugging {
}
}

object Logic {
import scala.annotation.nowarn
import scala.collection.mutable.{Growable, GrowableBuilder, SetOps}
import scala.collection.{IterableFactory, IterableFactoryDefaults, StrictOptimizedIterableOps}

// Local subclass because we can't override `addAll` in the collections (bin compat), see PR scala/scala#10361
@nowarn("msg=inheritance from class LinkedHashSet")
class LogicLinkedHashSet[A] extends mutable.LinkedHashSet[A]
with SetOps[A, LogicLinkedHashSet, LogicLinkedHashSet[A]]
with StrictOptimizedIterableOps[A, LogicLinkedHashSet, LogicLinkedHashSet[A]]
with IterableFactoryDefaults[A, LogicLinkedHashSet] {
override def iterableFactory: IterableFactory[LogicLinkedHashSet] = LogicLinkedHashSet
override def addAll(xs: IterableOnce[A]): this.type = {
sizeHint(xs.knownSize)
super.addAll(xs)
}
}

object LogicLinkedHashSet extends IterableFactory[LogicLinkedHashSet] {
override def from[A](source: IterableOnce[A]): LogicLinkedHashSet[A] = Growable.from(empty[A], source)
override def empty[A]: LogicLinkedHashSet[A] = new LogicLinkedHashSet[A]
override def newBuilder[A]: mutable.Builder[A, LogicLinkedHashSet[A]] = new GrowableBuilder(empty[A])
}
}

trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
trait TreesAndTypesDomain extends PropositionalLogic with CheckableTreeAndTypeAnalysis {
type Type = global.Type
Expand Down Expand Up @@ -733,8 +757,8 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
}


import global.{ConstantType, SingletonType, Literal, Ident, singleType, TypeBounds, NoSymbol}
import global.definitions._
import global.{ConstantType, Ident, Literal, NoSymbol, SingletonType, TypeBounds, singleType}


// all our variables range over types
Expand Down
5 changes: 5 additions & 0 deletions src/library/scala/collection/mutable/LinkedHashSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ class LinkedHashSet[A]

def contains(elem: A): Boolean = findEntry(elem) ne null

override def sizeHint(size: Int): Unit = {
val target = tableSizeFor(((size + 1).toDouble / LinkedHashSet.defaultLoadFactor).toInt)
if (target > table.length) growTable(target)
}

override def add(elem: A): Boolean = {
if (contentSize + 1 >= threshold) growTable(table.length * 2)
val hash = computeHash(elem)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import org.junit.Test

import scala.collection.mutable
import scala.reflect.internal.util.Position
import scala.tools.nsc.transform.patmat.Logic.LogicLinkedHashSet
import scala.tools.nsc.{Global, Settings}

object TestSolver extends Logic with Solving {
Expand Down Expand Up @@ -579,7 +580,7 @@ class SolvingTest {
def pairWiseEncoding(ops: List[Sym]) = {
And(ops.combinations(2).collect {
case a :: b :: Nil => Or(Not(a), Not(b)): Prop
}.to(mutable.LinkedHashSet))
}.to(LogicLinkedHashSet))
}

@Test
Expand Down

0 comments on commit 5a2376b

Please sign in to comment.