Skip to content

Commit

Permalink
Fix bad performance on complex patmat
Browse files Browse the repository at this point in the history
AnalysisBudget.maxDPLLdepth is already working to limit the initial SAT
solving.  But given enough unassigned symbols, like the test case, the
compiler can end up spending the rest of eternity and all its memory
expanding the model.  So apply the limit where it hurts most (the
cartesian product part).

The ordered sets and various sortings, instead, are to stabilise the
results.
  • Loading branch information
dwijnand committed Mar 3, 2021
1 parent eb8cb18 commit a63a30a
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 262 deletions.
133 changes: 67 additions & 66 deletions src/compiler/scala/tools/nsc/transform/patmat/Logic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import scala.collection.immutable.ArraySeq
import scala.reflect.internal.util.Collections._
import scala.reflect.internal.util.{HashSet, StatisticsStatics}

trait Logic extends Debugging {
trait Logic extends Debugging {
import global._

private def max(xs: Seq[Int]) = if (xs.isEmpty) 0 else xs.max
Expand Down Expand Up @@ -117,12 +117,20 @@ trait Logic extends Debugging {
// but that requires typing relations like And(x: Tx, y: Ty) : (if(Tx == PureProp && Ty == PureProp) PureProp else Prop)
final case class And(ops: Set[Prop]) extends Prop
object And {
def apply(ops: Prop*) = new And(ops.toSet)
def apply(ps: Prop*) = create(ps)
def create(ps: Iterable[Prop]) = ps match {
case ps: Set[Prop] => new And(ps)
case _ => new And(ps.to(scala.collection.immutable.ListSet))
}
}

final case class Or(ops: Set[Prop]) extends Prop
object Or {
def apply(ops: Prop*) = new Or(ops.toSet)
def apply(ps: Prop*) = create(ps)
def create(ps: Iterable[Prop]) = ps match {
case ps: Set[Prop] => new Or(ps)
case _ => new Or(ps.to(scala.collection.immutable.ListSet))
}
}

final case class Not(a: Prop) extends Prop
Expand Down Expand Up @@ -161,8 +169,17 @@ trait Logic extends Debugging {
implicit val SymOrdering: Ordering[Sym] = Ordering.by(_.id)
}

def /\(props: Iterable[Prop]) = if (props.isEmpty) True else And(props.toSeq: _*)
def \/(props: Iterable[Prop]) = if (props.isEmpty) False else Or(props.toSeq: _*)
def /\(props: Iterable[Prop]) = props match {
case _ if props.isEmpty => True
case _ if props.sizeIs == 1 => props.head
case _ => And.create(props)
}

def \/(props: Iterable[Prop]) = props match {
case _ if props.isEmpty => False
case _ if props.sizeIs == 1 => props.head
case _ => Or.create(props)
}

/**
* Simplifies propositional formula according to the following rules:
Expand Down Expand Up @@ -267,61 +284,44 @@ trait Logic extends Debugging {
| (_: AtMostOne) => p
}

def simplifyProp(p: Prop): Prop = p match {
case And(fv) =>
// recurse for nested And (pulls all Ands up)
// build up Set in order to remove duplicates
val opsFlattenedBuilder = collection.immutable.Set.newBuilder[Prop]
for (prop <- fv) {
val simplified = simplifyProp(prop)
if (simplified != True) { // ignore `True`
simplified match {
case And(fv) => fv.foreach(opsFlattenedBuilder += _)
case f => opsFlattenedBuilder += f
}
}
}
val opsFlattened = opsFlattenedBuilder.result()

if (opsFlattened.contains(False) || hasImpureAtom(opsFlattened)) {
False
} else {
opsFlattened.size match {
case 0 => True
case 1 => opsFlattened.head
case _ => new And(opsFlattened)
}
def simplifyAnd(ps: Set[Prop]): Prop = {
// recurse for nested And (pulls all Ands up)
// build up Set in order to remove duplicates
val props = mutable.HashSet.empty[Prop]
for (prop <- ps) {
simplifyProp(prop) match {
case True => // ignore `True`
case And(fv) => fv.foreach(props += _)
case f => props += f
}
case Or(fv) =>
// recurse for nested Or (pulls all Ors up)
// build up Set in order to remove duplicates
val opsFlattenedBuilder = collection.immutable.Set.newBuilder[Prop]
for (prop <- fv) {
val simplified = simplifyProp(prop)
if (simplified != False) { // ignore `False`
simplified match {
case Or(fv) => fv.foreach(opsFlattenedBuilder += _)
case f => opsFlattenedBuilder += f
}
}
}
val opsFlattened = opsFlattenedBuilder.result()

if (opsFlattened.contains(True) || hasImpureAtom(opsFlattened)) {
True
} else {
opsFlattened.size match {
case 0 => False
case 1 => opsFlattened.head
case _ => new Or(opsFlattened)
}
}

if (props.contains(False) || hasImpureAtom(props)) False
else /\(props)
}

def simplifyOr(ps: Set[Prop]): Prop = {
// recurse for nested Or (pulls all Ors up)
// build up Set in order to remove duplicates
val props = mutable.HashSet.empty[Prop]
for (prop <- ps) {
simplifyProp(prop) match {
case False => // ignore `False`
case Or(fv) => props ++= fv
case f => props += f
}
case Not(Not(a)) =>
simplify(a)
case Not(p) =>
Not(simplify(p))
case p =>
p
}

if (props.contains(True) || hasImpureAtom(props)) True
else \/(props)
}

def simplifyProp(p: Prop): Prop = p match {
case And(ps) => simplifyAnd(ps)
case Or(ps) => simplifyOr(ps)
case Not(Not(a)) => simplify(a)
case Not(p) => Not(simplify(p))
case p => p
}

val nnf = negationNormalForm(f)
Expand All @@ -344,15 +344,15 @@ trait Logic extends Debugging {
}

def gatherVariables(p: Prop): collection.Set[Var] = {
val vars = new mutable.HashSet[Var]()
val vars = new mutable.LinkedHashSet[Var]()
(new PropTraverser {
override def applyVar(v: Var) = vars += v
})(p)
vars
}

def gatherSymbols(p: Prop): collection.Set[Sym] = {
val syms = new mutable.HashSet[Sym]()
val syms = new mutable.LinkedHashSet[Sym]()
(new PropTraverser {
override def applySymbol(s: Sym) = syms += s
})(p)
Expand Down Expand Up @@ -511,7 +511,7 @@ trait Logic extends Debugging {

final case class Solution(model: Model, unassigned: List[Sym])

def findModelFor(solvable: Solvable): Model
def hasModel(solvable: Solvable): Boolean

def findAllModelsFor(solvable: Solvable, sym: Symbol = NoSymbol): List[Solution]
}
Expand Down Expand Up @@ -562,7 +562,7 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
val subConsts =
enumerateSubtypes(staticTp, grouped = false)
.headOption.map { tps =>
tps.toSet[Type].map{ tp =>
tps.to(scala.collection.immutable.ListSet).map { tp =>
val domainC = TypeConst(tp)
registerEquality(domainC)
domainC
Expand All @@ -583,7 +583,7 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
val subtypes = enumerateSubtypes(staticTp, grouped = true)
subtypes.map {
subTypes =>
val syms = subTypes.flatMap(tpe => symForEqualsTo.get(TypeConst(tpe))).toSet
val syms = subTypes.flatMap(tpe => symForEqualsTo.get(TypeConst(tpe))).to(scala.collection.immutable.ListSet)
if (mayBeNull) syms + symForEqualsTo(NullConst) else syms
}.filter(_.nonEmpty)
}
Expand Down Expand Up @@ -719,13 +719,14 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
lazy val symForStaticTp: Option[Sym] = symForEqualsTo.get(TypeConst(staticTpCheckable))

// don't access until all potential equalities have been registered using registerEquality
private lazy val equalitySyms = {observed(); symForEqualsTo.values.toList}
private lazy val equalitySyms = {observed(); symForEqualsTo.values.toList.sortBy(_.toString) }

// don't call until all equalities have been registered and registerNull has been called (if needed)
def describe = {
val consts = symForEqualsTo.keys.toSeq.sortBy(_.toString)
def domain_s = domain match {
case Some(d) => d.mkString(" ::= ", " | ", "// "+ symForEqualsTo.keys)
case _ => symForEqualsTo.keys.mkString(" ::= ", " | ", " | ...")
case Some(d) => d.mkString(" ::= ", " | ", "// " + consts)
case _ => consts.mkString(" ::= ", " | ", " | ...")
}
s"$this: ${staticTp}${domain_s} // = $path"
}
Expand Down
46 changes: 21 additions & 25 deletions src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -498,13 +498,8 @@ trait MatchAnalysis extends MatchApproximation {
else {
prefix += prefHead
current = current.tail
val and = And((current.head +: prefix).toIndexedSeq: _*)
val model = findModelFor(eqFreePropToSolvable(and))

// debug.patmat("trying to reach:\n"+ cnfString(current.head) +"\nunder prefix:\n"+ cnfString(prefix))
// if (NoModel ne model) debug.patmat("reached: "+ modelString(model))

reachable = NoModel ne model
val and = And((current.head +: prefix).toIndexedSeq: _*)
reachable = hasModel(eqFreePropToSolvable(and))
}
}

Expand Down Expand Up @@ -573,13 +568,9 @@ trait MatchAnalysis extends MatchApproximation {
val matchFailModels = findAllModelsFor(propToSolvable(matchFails), prevBinder)

val scrutVar = Var(prevBinderTree)
val counterExamples = {
matchFailModels.flatMap {
model =>
val varAssignments = expandModel(model)
varAssignments.flatMap(modelToCounterExample(scrutVar) _)
}
}
val counterExamples = matchFailModels.iterator.flatMap { model =>
expandModel(model).flatMap(modelToCounterExample(scrutVar))
}.take(AnalysisBudget.maxDPLLdepth).toList

// sorting before pruning is important here in order to
// keep neg/t7020.scala stable
Expand Down Expand Up @@ -658,16 +649,18 @@ trait MatchAnalysis extends MatchApproximation {
case object WildcardExample extends CounterExample { override def toString = "_" }
case object NoExample extends CounterExample { override def toString = "??" }

type VarAssignment = Map[Var, (Seq[Const], Seq[Const])]

// returns a mapping from variable to
// equal and notEqual symbols
def modelToVarAssignment(model: Model): Map[Var, (Seq[Const], Seq[Const])] =
def modelToVarAssignment(model: Model): VarAssignment =
model.toSeq.groupBy(_._1.variable).view.mapValues{ xs =>
val (trues, falses) = xs.partition(_._2)
(trues map (_._1.const), falses map (_._1.const))
// should never be more than one value in trues...
}.to(Map)

def varAssignmentString(varAssignment: Map[Var, (Seq[Const], Seq[Const])]) =
def varAssignmentString(varAssignment: VarAssignment) =
varAssignment.toSeq.sortBy(_._1.toString).map { case (v, (trues, falses)) =>
s"$v(=${v.path}: ${v.staticTpCheckable}) == ${trues.mkString("(", ", ", ")")} != (${falses.mkString(", ")})"
}.mkString("\n")
Expand Down Expand Up @@ -702,7 +695,7 @@ trait MatchAnalysis extends MatchApproximation {
* Only one of these symbols can be set to true,
* since `V2` can at most be equal to one of {2,6,5,4,7}.
*/
def expandModel(solution: Solution): List[Map[Var, (Seq[Const], Seq[Const])]] = {
def expandModel(solution: Solution): List[VarAssignment] = {

val model = solution.model

Expand All @@ -719,7 +712,7 @@ trait MatchAnalysis extends MatchApproximation {
val groupedByVar: Map[Var, List[Sym]] = solution.unassigned.groupBy(_.variable)

val expanded = for {
(variable, syms) <- groupedByVar.toList
(variable, syms) <- groupedByVar.toList.sortBy(_._1.toString)
} yield {

val (equal, notEqual) = varAssignment.getOrElse(variable, Nil -> Nil)
Expand Down Expand Up @@ -753,11 +746,14 @@ trait MatchAnalysis extends MatchApproximation {
// we need the Cartesian product here,
// since we want to report all missing cases
// (i.e., combinations)
val cartesianProd = expanded.reduceLeft((xs, ys) =>
for {map1 <- xs
map2 <- ys} yield {
map1 ++ map2
})
@tailrec def loop(acc: List[VarAssignment], in: List[List[VarAssignment]]): List[VarAssignment] = {
if (acc.sizeIs > AnalysisBudget.maxDPLLdepth) acc.take(AnalysisBudget.maxDPLLdepth)
else in match {
case vs :: vss => loop(for (map1 <- acc; map2 <- vs) yield map1 ++ map2, vss)
case _ => acc
}
}
val cartesianProd = loop(Nil, expanded)

// add expanded variables
// note that we can just use `++`
Expand All @@ -774,7 +770,7 @@ trait MatchAnalysis extends MatchApproximation {
// (the variables don't take into account type information derived from other variables,
// so, naively, you might try to construct a counter example like _ :: Nil(_ :: _, _ :: _),
// since we didn't realize the tail of the outer cons was a Nil)
def modelToCounterExample(scrutVar: Var)(varAssignment: Map[Var, (Seq[Const], Seq[Const])]): Option[CounterExample] = {
def modelToCounterExample(scrutVar: Var)(varAssignment: VarAssignment): Option[CounterExample] = {
val strict = !settings.nonStrictPatmatAnalysis.value

// chop a path into a list of symbols
Expand Down Expand Up @@ -919,7 +915,7 @@ trait MatchAnalysis extends MatchApproximation {
}

// slurp in information from other variables
varAssignment.keys.foreach{ v => if (v != scrutVar) VariableAssignment(v) }
varAssignment.keys.toSeq.sortBy(_.toString).foreach(v => if (v != scrutVar) VariableAssignment(v))

// this is the variable we want a counter example for
VariableAssignment(scrutVar).toCounterExample()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
val testss = approximateMatchConservative(prevBinder, cases)

// interpret:
val dependencies = new mutable.LinkedHashMap[Test, Set[Prop]]
val tested = new mutable.HashSet[Prop]
val dependencies = new mutable.LinkedHashMap[Test, mutable.LinkedHashSet[Prop]]
val tested = new mutable.LinkedHashSet[Prop]
val reusesMap = new mutable.LinkedHashMap[Int, Test]
val reusesTest = { (test: Test) => reusesMap.get(test.id) }
val registerReuseBy = { (priorTest: Test, later: Test) =>
Expand All @@ -57,32 +57,32 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
val cond = test.prop

def simplify(c: Prop): Set[Prop] = c match {
case And(ops) => ops.toSet flatMap simplify
case And(ops) => ops flatMap simplify
case Or(ops) => Set(False) // TODO: make more precise
case Not(Eq(Var(_), NullConst)) => Set(True) // not worth remembering
case Not(Eq(Var(_), NullConst)) => Set.empty // not worth remembering
case True => Set.empty // same
case _ => Set(c)
}
val conds = simplify(cond)

if (conds(False)) false // stop when we encounter a definite "no" or a "not sure"
else {
val nonTrivial = conds - True
if (!nonTrivial.isEmpty) {
tested ++= nonTrivial
if (!conds.isEmpty) {
tested ++= conds

// is there an earlier test that checks our condition and whose dependencies are implied by ours?
dependencies find {
case (priorTest, deps) =>
((simplify(priorTest.prop) == nonTrivial) || // our conditions are implied by priorTest if it checks the same thing directly
(nonTrivial subsetOf deps) // or if it depends on a superset of our conditions
) && (deps subsetOf tested) // the conditions we've tested when we are here in the match satisfy the prior test, and hence what it tested
((simplify(priorTest.prop) == conds) || // our conditions are implied by priorTest if it checks the same thing directly
(conds subsetOf deps) // or if it depends on a superset of our conditions
) && (deps subsetOf tested) // the conditions we've tested when we are here in the match satisfy the prior test, and hence what it tested
} foreach {
case (priorTest, _) =>
// if so, note the dependency in both tests
registerReuseBy(priorTest, test)
}

dependencies(test) = tested.toSet // copies
dependencies(test) = tested.clone()
}
true
}
Expand All @@ -108,7 +108,7 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
val collapsed = testss map { tests =>
// map tests to the equivalent list of treemakers, replacing shared prefixes by a reusing treemaker
// if there's no sharing, simply map to the tree makers corresponding to the tests
var currDeps = Set[Prop]()
var currDeps = mutable.LinkedHashSet.empty[Prop]
val (sharedPrefix, suffix) = tests span { test =>
(test.prop == True) || (for(
reusedTest <- reusesTest(test);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ trait MatchTreeMaking extends MatchCodeGen with Debugging {
// mutable case class fields need to be stored regardless (scala/bug#5158, scala/bug#6070) -- see override in ProductExtractorTreeMaker
// sub patterns bound to wildcard (_) are never stored as they can't be referenced
// dirty debuggers will have to get dirty to see the wildcards
lazy val storedBinders: Set[Symbol] =
private lazy val storedBinders: Set[Symbol] =
(if (debugInfoEmitVars) subPatBinders.toSet else Set.empty) ++ extraStoredBinders diff ignoredSubPatBinders

// e.g., mutable fields of a case class in ProductExtractorTreeMaker
Expand Down

0 comments on commit a63a30a

Please sign in to comment.