Skip to content

Commit

Permalink
Fix bad performance on complex patmat
Browse files Browse the repository at this point in the history
AnalysisBudget.maxDPLLdepth is already working to limit the initial SAT
solving.  But given enough unassigned symbols, like the test case, the
compiler can end up spending the rest of eternity and all its memory
expanding the model.  So apply the limit where it hurts most (the
cartesian product part).

The ordered sets and various sortings, instead, are to stabilise the
results.
  • Loading branch information
dwijnand committed Mar 3, 2021
1 parent eb8cb18 commit 6eaf217
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 277 deletions.
133 changes: 67 additions & 66 deletions src/compiler/scala/tools/nsc/transform/patmat/Logic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import scala.collection.immutable.ArraySeq
import scala.reflect.internal.util.Collections._
import scala.reflect.internal.util.{HashSet, StatisticsStatics}

trait Logic extends Debugging {
trait Logic extends Debugging {
import global._

private def max(xs: Seq[Int]) = if (xs.isEmpty) 0 else xs.max
Expand Down Expand Up @@ -117,12 +117,20 @@ trait Logic extends Debugging {
// but that requires typing relations like And(x: Tx, y: Ty) : (if(Tx == PureProp && Ty == PureProp) PureProp else Prop)
final case class And(ops: Set[Prop]) extends Prop
object And {
def apply(ops: Prop*) = new And(ops.toSet)
def apply(ps: Prop*) = create(ps)
def create(ps: Iterable[Prop]) = ps match {
case ps: Set[Prop] => new And(ps)
case _ => new And(ps.to(scala.collection.immutable.ListSet))
}
}

final case class Or(ops: Set[Prop]) extends Prop
object Or {
def apply(ops: Prop*) = new Or(ops.toSet)
def apply(ps: Prop*) = create(ps)
def create(ps: Iterable[Prop]) = ps match {
case ps: Set[Prop] => new Or(ps)
case _ => new Or(ps.to(scala.collection.immutable.ListSet))
}
}

final case class Not(a: Prop) extends Prop
Expand Down Expand Up @@ -161,8 +169,17 @@ trait Logic extends Debugging {
implicit val SymOrdering: Ordering[Sym] = Ordering.by(_.id)
}

def /\(props: Iterable[Prop]) = if (props.isEmpty) True else And(props.toSeq: _*)
def \/(props: Iterable[Prop]) = if (props.isEmpty) False else Or(props.toSeq: _*)
def /\(props: Iterable[Prop]) = props match {
case _ if props.isEmpty => True
case _ if props.sizeIs == 1 => props.head
case _ => And.create(props)
}

def \/(props: Iterable[Prop]) = props match {
case _ if props.isEmpty => False
case _ if props.sizeIs == 1 => props.head
case _ => Or.create(props)
}

/**
* Simplifies propositional formula according to the following rules:
Expand Down Expand Up @@ -267,61 +284,44 @@ trait Logic extends Debugging {
| (_: AtMostOne) => p
}

def simplifyProp(p: Prop): Prop = p match {
case And(fv) =>
// recurse for nested And (pulls all Ands up)
// build up Set in order to remove duplicates
val opsFlattenedBuilder = collection.immutable.Set.newBuilder[Prop]
for (prop <- fv) {
val simplified = simplifyProp(prop)
if (simplified != True) { // ignore `True`
simplified match {
case And(fv) => fv.foreach(opsFlattenedBuilder += _)
case f => opsFlattenedBuilder += f
}
}
}
val opsFlattened = opsFlattenedBuilder.result()

if (opsFlattened.contains(False) || hasImpureAtom(opsFlattened)) {
False
} else {
opsFlattened.size match {
case 0 => True
case 1 => opsFlattened.head
case _ => new And(opsFlattened)
}
def simplifyAnd(ps: Set[Prop]): Prop = {
// recurse for nested And (pulls all Ands up)
// build up Set in order to remove duplicates
val props = mutable.HashSet.empty[Prop]
for (prop <- ps) {
simplifyProp(prop) match {
case True => // ignore `True`
case And(fv) => fv.foreach(props += _)
case f => props += f
}
case Or(fv) =>
// recurse for nested Or (pulls all Ors up)
// build up Set in order to remove duplicates
val opsFlattenedBuilder = collection.immutable.Set.newBuilder[Prop]
for (prop <- fv) {
val simplified = simplifyProp(prop)
if (simplified != False) { // ignore `False`
simplified match {
case Or(fv) => fv.foreach(opsFlattenedBuilder += _)
case f => opsFlattenedBuilder += f
}
}
}
val opsFlattened = opsFlattenedBuilder.result()

if (opsFlattened.contains(True) || hasImpureAtom(opsFlattened)) {
True
} else {
opsFlattened.size match {
case 0 => False
case 1 => opsFlattened.head
case _ => new Or(opsFlattened)
}
}

if (props.contains(False) || hasImpureAtom(props)) False
else /\(props)
}

def simplifyOr(ps: Set[Prop]): Prop = {
// recurse for nested Or (pulls all Ors up)
// build up Set in order to remove duplicates
val props = mutable.HashSet.empty[Prop]
for (prop <- ps) {
simplifyProp(prop) match {
case False => // ignore `False`
case Or(fv) => props ++= fv
case f => props += f
}
case Not(Not(a)) =>
simplify(a)
case Not(p) =>
Not(simplify(p))
case p =>
p
}

if (props.contains(True) || hasImpureAtom(props)) True
else \/(props)
}

def simplifyProp(p: Prop): Prop = p match {
case And(ps) => simplifyAnd(ps)
case Or(ps) => simplifyOr(ps)
case Not(Not(a)) => simplify(a)
case Not(p) => Not(simplify(p))
case p => p
}

val nnf = negationNormalForm(f)
Expand All @@ -344,15 +344,15 @@ trait Logic extends Debugging {
}

def gatherVariables(p: Prop): collection.Set[Var] = {
val vars = new mutable.HashSet[Var]()
val vars = new mutable.LinkedHashSet[Var]()
(new PropTraverser {
override def applyVar(v: Var) = vars += v
})(p)
vars
}

def gatherSymbols(p: Prop): collection.Set[Sym] = {
val syms = new mutable.HashSet[Sym]()
val syms = new mutable.LinkedHashSet[Sym]()
(new PropTraverser {
override def applySymbol(s: Sym) = syms += s
})(p)
Expand Down Expand Up @@ -511,7 +511,7 @@ trait Logic extends Debugging {

final case class Solution(model: Model, unassigned: List[Sym])

def findModelFor(solvable: Solvable): Model
def hasModel(solvable: Solvable): Boolean

def findAllModelsFor(solvable: Solvable, sym: Symbol = NoSymbol): List[Solution]
}
Expand Down Expand Up @@ -562,7 +562,7 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
val subConsts =
enumerateSubtypes(staticTp, grouped = false)
.headOption.map { tps =>
tps.toSet[Type].map{ tp =>
tps.to(scala.collection.immutable.ListSet).map { tp =>
val domainC = TypeConst(tp)
registerEquality(domainC)
domainC
Expand All @@ -583,7 +583,7 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
val subtypes = enumerateSubtypes(staticTp, grouped = true)
subtypes.map {
subTypes =>
val syms = subTypes.flatMap(tpe => symForEqualsTo.get(TypeConst(tpe))).toSet
val syms = subTypes.flatMap(tpe => symForEqualsTo.get(TypeConst(tpe))).to(scala.collection.immutable.ListSet)
if (mayBeNull) syms + symForEqualsTo(NullConst) else syms
}.filter(_.nonEmpty)
}
Expand Down Expand Up @@ -719,13 +719,14 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
lazy val symForStaticTp: Option[Sym] = symForEqualsTo.get(TypeConst(staticTpCheckable))

// don't access until all potential equalities have been registered using registerEquality
private lazy val equalitySyms = {observed(); symForEqualsTo.values.toList}
private lazy val equalitySyms = {observed(); symForEqualsTo.values.toList.sortBy(_.toString) }

// don't call until all equalities have been registered and registerNull has been called (if needed)
def describe = {
val consts = symForEqualsTo.keys.toSeq.sortBy(_.toString)
def domain_s = domain match {
case Some(d) => d.mkString(" ::= ", " | ", "// "+ symForEqualsTo.keys)
case _ => symForEqualsTo.keys.mkString(" ::= ", " | ", " | ...")
case Some(d) => d.mkString(" ::= ", " | ", "// " + consts)
case _ => consts.mkString(" ::= ", " | ", " | ...")
}
s"$this: ${staticTp}${domain_s} // = $path"
}
Expand Down
71 changes: 31 additions & 40 deletions src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -498,13 +498,8 @@ trait MatchAnalysis extends MatchApproximation {
else {
prefix += prefHead
current = current.tail
val and = And((current.head +: prefix).toIndexedSeq: _*)
val model = findModelFor(eqFreePropToSolvable(and))

// debug.patmat("trying to reach:\n"+ cnfString(current.head) +"\nunder prefix:\n"+ cnfString(prefix))
// if (NoModel ne model) debug.patmat("reached: "+ modelString(model))

reachable = NoModel ne model
val and = And((current.head +: prefix).toIndexedSeq: _*)
reachable = hasModel(eqFreePropToSolvable(and))
}
}

Expand Down Expand Up @@ -573,13 +568,9 @@ trait MatchAnalysis extends MatchApproximation {
val matchFailModels = findAllModelsFor(propToSolvable(matchFails), prevBinder)

val scrutVar = Var(prevBinderTree)
val counterExamples = {
matchFailModels.flatMap {
model =>
val varAssignments = expandModel(model)
varAssignments.flatMap(modelToCounterExample(scrutVar) _)
}
}
val counterExamples = matchFailModels.iterator.flatMap { model =>
expandModel(model).flatMap(modelToCounterExample(scrutVar))
}.take(AnalysisBudget.maxDPLLdepth).toList

// sorting before pruning is important here in order to
// keep neg/t7020.scala stable
Expand Down Expand Up @@ -658,16 +649,18 @@ trait MatchAnalysis extends MatchApproximation {
case object WildcardExample extends CounterExample { override def toString = "_" }
case object NoExample extends CounterExample { override def toString = "??" }

type VarAssignment = Map[Var, (Seq[Const], Seq[Const])]

// returns a mapping from variable to
// equal and notEqual symbols
def modelToVarAssignment(model: Model): Map[Var, (Seq[Const], Seq[Const])] =
def modelToVarAssignment(model: Model): VarAssignment =
model.toSeq.groupBy(_._1.variable).view.mapValues{ xs =>
val (trues, falses) = xs.partition(_._2)
(trues map (_._1.const), falses map (_._1.const))
// should never be more than one value in trues...
}.to(Map)

def varAssignmentString(varAssignment: Map[Var, (Seq[Const], Seq[Const])]) =
def varAssignmentString(varAssignment: VarAssignment) =
varAssignment.toSeq.sortBy(_._1.toString).map { case (v, (trues, falses)) =>
s"$v(=${v.path}: ${v.staticTpCheckable}) == ${trues.mkString("(", ", ", ")")} != (${falses.mkString(", ")})"
}.mkString("\n")
Expand Down Expand Up @@ -702,7 +695,7 @@ trait MatchAnalysis extends MatchApproximation {
* Only one of these symbols can be set to true,
* since `V2` can at most be equal to one of {2,6,5,4,7}.
*/
def expandModel(solution: Solution): List[Map[Var, (Seq[Const], Seq[Const])]] = {
def expandModel(solution: Solution): List[VarAssignment] = {

val model = solution.model

Expand All @@ -719,7 +712,7 @@ trait MatchAnalysis extends MatchApproximation {
val groupedByVar: Map[Var, List[Sym]] = solution.unassigned.groupBy(_.variable)

val expanded = for {
(variable, syms) <- groupedByVar.toList
(variable, syms) <- groupedByVar.toList.sortBy(_._1.toString)
} yield {

val (equal, notEqual) = varAssignment.getOrElse(variable, Nil -> Nil)
Expand All @@ -735,7 +728,7 @@ trait MatchAnalysis extends MatchApproximation {
// a list counter example could contain wildcards: e.g. `List(_,_)`
val allEqual = addVarAssignment(syms.map(_.const), Nil)

if(equal.isEmpty) {
if (equal.isEmpty) {
val oneHot = for {
s <- syms
} yield {
Expand All @@ -747,34 +740,32 @@ trait MatchAnalysis extends MatchApproximation {
}
}

if (expanded.isEmpty) {
List(varAssignment)
} else {
// we need the Cartesian product here,
// since we want to report all missing cases
// (i.e., combinations)
val cartesianProd = expanded.reduceLeft((xs, ys) =>
for {map1 <- xs
map2 <- ys} yield {
map1 ++ map2
})

// add expanded variables
// note that we can just use `++`
// since the Maps have disjoint keySets
for {
m <- cartesianProd
} yield {
varAssignment ++ m
// we need the Cartesian product here,
// since we want to report all missing cases
// (i.e., combinations)
@tailrec def loop(acc: List[VarAssignment], in: List[List[VarAssignment]]): List[VarAssignment] = {
if (acc.sizeIs > AnalysisBudget.maxDPLLdepth) acc.take(AnalysisBudget.maxDPLLdepth)
else in match {
case vs :: vss => loop(for (map1 <- acc; map2 <- vs) yield map1 ++ map2, vss)
case _ => acc
}
}
expanded match {
case head :: tail =>
val cartesianProd = loop(head, tail)
// add expanded variables
// note that we can just use `++`
// since the Maps have disjoint keySets
for (m <- cartesianProd) yield varAssignment ++ m
case _ => List(varAssignment)
}
}

// return constructor call when the model is a true counter example
// (the variables don't take into account type information derived from other variables,
// so, naively, you might try to construct a counter example like _ :: Nil(_ :: _, _ :: _),
// since we didn't realize the tail of the outer cons was a Nil)
def modelToCounterExample(scrutVar: Var)(varAssignment: Map[Var, (Seq[Const], Seq[Const])]): Option[CounterExample] = {
def modelToCounterExample(scrutVar: Var)(varAssignment: VarAssignment): Option[CounterExample] = {
val strict = !settings.nonStrictPatmatAnalysis.value

// chop a path into a list of symbols
Expand Down Expand Up @@ -919,7 +910,7 @@ trait MatchAnalysis extends MatchApproximation {
}

// slurp in information from other variables
varAssignment.keys.foreach{ v => if (v != scrutVar) VariableAssignment(v) }
varAssignment.keys.toSeq.sortBy(_.toString).foreach(v => if (v != scrutVar) VariableAssignment(v))

// this is the variable we want a counter example for
VariableAssignment(scrutVar).toCounterExample()
Expand Down
Loading

0 comments on commit 6eaf217

Please sign in to comment.