Skip to content

Commit

Permalink
proper optimization of OrElseValue
Browse files Browse the repository at this point in the history
- bugfix: simplifyAssociative didn't simplify arguments
- bugfix: == and != don't work as patterns
- improvement: don't use DeMorgan's law for now
  • Loading branch information
stefanbohne committed Feb 2, 2017
1 parent 1687c2c commit a5f6cb1
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 40 deletions.
99 changes: 61 additions & 38 deletions org.bynar.versailles/src/org/bynar/versailles/Simplifier.scala
Expand Up @@ -106,6 +106,7 @@ class Simplifier {
case b2 =>
val (s2, ctx3) = simplify1(s, forward, ctx2, leaveDefs)
(b2, s2) match {
case (_, s2@Undefined()) => (s2, ctx3)
case (Let(BooleanLiteral(true), c), v: OrElseValue) =>
simplify1(v.copy(expr.copy(b2, v.first), expr.copy(b2.deepCopy(), v.second)), forward, ctx3, leaveDefs)
case (b2, Block(b3, s2)) =>
Expand Down Expand Up @@ -140,20 +141,19 @@ class Simplifier {
context,
leaveDefs)
case expr@OrElseValue(a, b) =>
val (a2, ctx2) = simplify1(a, forward, context, leaveDefs)
val (b2, ctx3) = simplify1(b, forward, ctx2, leaveDefs)
(a2, b2) match {
case (Undefined(), b2) => (b2, ctx3)
case (a2, Undefined()) => (a2, ctx3)
case (a2, _) if isDefined(a2) => (a2, ctx3)
case (a2@Block(Let(BooleanLiteral(true), c1), v1), Block(Let(BooleanLiteral(true), c2), v2)) if isDefined(v1) && c1 == c2 =>
(a2, ctx3)
case (a2@Block(Let(BooleanLiteral(true), c1), v1), OrElseValue(Block(Let(BooleanLiteral(true), c2), v2), v3)) if isDefined(v1) && c1 == c2 =>
(expr.copy(a2, v3), ctx3)
case (a2@OrElseValue(v0, Block(Let(BooleanLiteral(true), c1), v1)), Block(Let(BooleanLiteral(true), c2), v2)) if isDefined(v1) && c1 == c2 =>
(a2, ctx3)
case (a2, b2) => (expr.copy(a2, b2), ctx3)
}
simplifyAssociative(new Pattern2[Expression, Expression, Expression] {
def unapply(e: Expression) = e match { case OrElseValue(a, b) => Some(a, b) case _ => None }
def apply(a: Expression, b: Expression) = OrElseValue(a, b)
}, false, false, expr, forward, context, leaveDefs){
case (a, Undefined()) => Seq(a)
case (Undefined(), b) => Seq(b)
case (a, _) if isDefined(a) => Seq(a)
case (a@Block(Let(BooleanLiteral(true), ac), av), b) if av == b => Seq(a)
case (a@Block(Let(BooleanLiteral(true), ac), _), Block(Let(BooleanLiteral(true), bc), _)) if ac == bc =>
Seq(a)
case (Block(Let(BooleanLiteral(true), ac), av), Block(Let(BooleanLiteral(true), bc), bv)) if av == bv =>
Seq(Block(Let(BooleanLiteral(true), ac || bc), av))
}{ _ => () }

case expr@Application(f, a) =>
val (a1, ctx1) = simplify1(a, forward, context, leaveDefs)
Expand All @@ -167,24 +167,31 @@ class Simplifier {
case expr => (expr, context)
}

def simplifyAssociative[S](function: Literal, symmetric: Boolean, reverse: Boolean, expr: Expression, forward: Boolean, context: Map[VariableIdentity, Expression], leaveDefs: Boolean)(simplifier: PartialFunction[(Expression, Expression), Seq[Expression]])(order: Expression => S)(implicit ord: math.Ordering[S]): (Expression, Map[VariableIdentity, Expression]) = {
trait Pattern2[A, B, C] extends Function2[A, B, C] {
def unapply(c: C): Option[(A, B)]
}
def simplifyAssociative[S](pattern: Pattern2[Expression, Expression, Expression], symmetric: Boolean, reverse: Boolean, expr: Expression, forward: Boolean, context: Map[VariableIdentity, Expression], leaveDefs: Boolean)(simplifier: PartialFunction[(Expression, Expression), Seq[Expression]])(order: Expression => S)(implicit ord: math.Ordering[S]): (Expression, Map[VariableIdentity, Expression]) = {
val result = mutable.Buffer[Expression]()
def collectSymmetricArguments(expr: Expression): Unit =
var ctx = context
def collectAssociativeArguments(expr: Expression): Unit =
expr match {
case Application(Application(`function`, a), b) =>
case pattern(a, b) =>
if (reverse)
collectSymmetricArguments(b)
collectSymmetricArguments(a)
collectAssociativeArguments(b)
collectAssociativeArguments(a)
if (!reverse)
collectSymmetricArguments(b)
case expr => result += expr
collectAssociativeArguments(b)
case expr =>
val (e, ctx2) = simplify1(expr, forward, ctx, leaveDefs)
result += e
ctx = ctx2
}
collectSymmetricArguments(expr)
collectAssociativeArguments(expr)
var found = false
var i = 0
while (i < result.size) {
var j = if (symmetric) 0 else i + 1
while (i < result.size && j < result.size) {
while (i < result.size && j < result.size && (symmetric || j < i + 2)) {
if (i != j)
simplifier.lift(result(i), result(j)) match {
case Some(m) =>
Expand All @@ -199,13 +206,13 @@ class Simplifier {
i += 1
}
val result2 = if (reverse)
result.sortBy(order).reduceLeft((a, b) => function.deepCopy()(b)(a))
result.sortBy(order).reduceLeft((a, b) => pattern(b, a))
else
result.sortBy(order).reduceLeft((a, b) => function.deepCopy()(a)(b))
result.sortBy(order).reduceLeft((a, b) => pattern(a, b))
if (!found)
(result2, context)
(result2, ctx)
else
simplify1(result2, forward, context, leaveDefs)
simplify1(result2, forward, ctx, leaveDefs)
}


Expand Down Expand Up @@ -336,20 +343,20 @@ class Simplifier {
// (a, ctx2)
// case (Application(Or(), a), Application(Application(Or(), b), c)) =>
// (Or()(Or()(a)(b))(c), ctx2)
// case (and@Application(And(), a), or1@Application(or2@Application(Or(), b), c)) =>
// simplify1(or1.copy(or2.copy(argument = Application(and, b)), Application(and.deepCopy(), c)), forward, ctx2, leaveDefs)
// case (and@Application(And(), or1@Application(or2@Application(Or(), a), b)), c) =>
// simplify1(or1.copy(or2.copy(argument = and.copy(argument=a)(c)), and.copy(and.function.deepCopy(), b)(c)), forward, ctx2, leaveDefs)
case (and@Application(And(), a), or1@Application(or2@Application(Or(), b), c)) =>
simplify1(or1.copy(or2.copy(argument = Application(and, b)), Application(and.deepCopy(), c)), forward, ctx2, leaveDefs)
case (and@Application(And(), or1@Application(or2@Application(Or(), a), b)), c) =>
simplify1(or1.copy(or2.copy(argument = and.copy(argument=a)(c)), and.copy(and.function.deepCopy(), b)(c)), forward, ctx2, leaveDefs)
case (Not(), BooleanLiteral(b)) =>
(BooleanLiteral(!b), ctx2)
case (Not(), a && b) =>
simplify1(!a || !b, forward, ctx2, leaveDefs)
case (Not(), a || b) =>
simplify1(!a && !b, forward, ctx2, leaveDefs)
case (Not(), a == b) =>
simplify1(a != b, forward, ctx2, leaveDefs)
simplify1(a neq b, forward, ctx2, leaveDefs)
case (Not(), a != b) =>
simplify1(a == b, forward, ctx2, leaveDefs)
simplify1(a equ b, forward, ctx2, leaveDefs)
case (Not(), a < b) =>
simplify1(a >= b, forward, ctx2, leaveDefs)
case (Not(), a <= b) =>
Expand Down Expand Up @@ -461,7 +468,10 @@ class Simplifier {
simplify1(b.copy(scope=Application(b.scope, a)), forward, ctx2, leaveDefs)

case (Application(op@Plus(), _), _) =>
simplifyAssociative(op, true, true, app, forward, ctx2, leaveDefs){
simplifyAssociative(new Pattern2[Expression, Expression, Expression] {
def unapply(e: Expression) = e match { case Application(Application(Plus(), b), a) => Some(a, b) case _ => None }
def apply(a: Expression, b: Expression) = Application(Application(op, b), a)
}, true, true, app, forward, ctx2, leaveDefs){
case (NumberLiteral(a), NumberLiteral(b)) => Seq(NumberLiteral(a + b))
case (a, NumberLiteral(z)) if z == BigDecimal(0) => Seq(a)
case (a, b) if a == b => Seq(a * 2)
Expand All @@ -471,7 +481,10 @@ class Simplifier {
Seq(a * lit.copy(k + k2))
}{ _.isInstanceOf[NumberLiteral] }
case (Application(op@Times(), _), _) =>
simplifyAssociative(op, true, true, app, forward, ctx2, leaveDefs){
simplifyAssociative(new Pattern2[Expression, Expression, Expression] {
def unapply(e: Expression) = e match { case Application(Application(Times(), b), a) => Some(a, b) case _ => None }
def apply(a: Expression, b: Expression) = Application(Application(op, b), a)
}, true, true, app, forward, ctx2, leaveDefs){
case (a, b@NumberLiteral(z)) if z == BigDecimal(0) => Seq(b)
case (a@NumberLiteral(z), b) if z == BigDecimal(0) => Seq(a)
case (a, NumberLiteral(o)) if o == BigDecimal(1) => Seq(a)
Expand All @@ -484,7 +497,10 @@ class Simplifier {
Seq(pow.copy(pow2.copy(argument = lit.copy(k + k2))))
}{ _.isInstanceOf[NumberLiteral] }
case (Application(op@And(), _), _) =>
simplifyAssociative(op, true, false, app, forward, ctx2, leaveDefs){
simplifyAssociative(new Pattern2[Expression, Expression, Expression] {
def unapply(e: Expression) = e match { case Application(Application(And(), a), b) => Some(a, b) case _ => None }
def apply(a: Expression, b: Expression) = Application(Application(op, a), b)
}, true, false, app, forward, ctx2, leaveDefs){
case (BooleanLiteral(a), BooleanLiteral(b)) => Seq(BooleanLiteral(a && b))
case (a, b@BooleanLiteral(false)) => Seq(b)
case (a, b@BooleanLiteral(true)) => Seq(a)
Expand All @@ -500,10 +516,17 @@ class Simplifier {
case (l@(a < NumberLiteral(b)), c < NumberLiteral(d)) if a == c && b <= d => Seq(l)
case (l@(a <= NumberLiteral(b)), c <= NumberLiteral(d)) if a == c && b <= d => Seq(l)
case (l@(a < NumberLiteral(b)), c <= NumberLiteral(d)) if a == c && b <= d => Seq(l)
case (a < NumberLiteral(b), c <= NumberLiteral(d)) if a == c && b >= d => Seq(a < NumberLiteral(d))
case (l@(a <= NumberLiteral(b)), c < NumberLiteral(d)) if a == c && b < d => Seq(l)
case (l@(a > NumberLiteral(b)), c > NumberLiteral(d)) if a == c && b >= d => Seq(l)
case (l@(a >= NumberLiteral(b)), c >= NumberLiteral(d)) if a == c && b >= d => Seq(l)
case (l@(a > NumberLiteral(b)), c >= NumberLiteral(d)) if a == c && b >= d => Seq(l)
case (l@(a >= NumberLiteral(b)), c > NumberLiteral(d)) if a == c && b > d => Seq(l)
}{ _ => () }
case (Application(op@Or(), _), _) =>
simplifyAssociative(op, true, false, app, forward, ctx2, leaveDefs){
simplifyAssociative(new Pattern2[Expression, Expression, Expression] {
def unapply(e: Expression) = e match { case Application(Application(Or(), a), b) => Some(a, b) case _ => None }
def apply(a: Expression, b: Expression) = Application(Application(op, a), b)
}, true, false, app, forward, ctx2, leaveDefs){
case (BooleanLiteral(a), BooleanLiteral(b)) => Seq(BooleanLiteral(a || b))
case (a, b@BooleanLiteral(false)) => Seq(a)
case (a, b@BooleanLiteral(true)) => Seq(b)
Expand Down
Expand Up @@ -15,9 +15,9 @@ object TermImplicits {
Application(Application(Times(), that), term)
def /(that: Expression) =
Application(Application(Divide(), that), term)
def ==(that: Expression) =
def equ(that: Expression) =
Application(Application(Equals(), term), that)
def !=(that: Expression) =
def neq(that: Expression) =
Application(Application(NotEquals(), term), that)
def <(that: Expression) =
Application(Application(Less(), term), that)
Expand Down

0 comments on commit a5f6cb1

Please sign in to comment.