Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvement in satask: Reduce the need of costly rcall() on SymPy expression #17379

Merged
merged 11 commits into from Aug 23, 2019
Merged
38 changes: 34 additions & 4 deletions sympy/assumptions/cnf.py
Expand Up @@ -16,10 +16,12 @@ class Literal(object):
"""

def __new__(cls, lit, is_Not=False):
obj = super(Literal, cls).__new__(cls)
if isinstance(lit, Not):
lit = lit.args[0]
is_Not = True
elif isinstance(lit, (AND, OR, Literal)):
return ~lit if is_Not else lit
obj = super(Literal, cls).__new__(cls)
obj.lit = lit
obj.is_Not = is_Not
return obj
Expand All @@ -28,6 +30,16 @@ def __new__(cls, lit, is_Not=False):
def arg(self):
return self.lit

def rcall(self, expr):
if callable(self.lit):
lit = self.lit(expr)
else:
try:
lit = self.lit.apply(expr)
except AttributeError:
lit = self.lit.rcall(expr)
return type(self)(lit, self.is_Not)

def __invert__(self):
is_Not = not self.is_Not
return Literal(self.lit, is_Not)
Expand Down Expand Up @@ -56,6 +68,11 @@ def __init__(self, *args):
def args(self):
return sorted(self._args, key=str)

def rcall(self, expr):
return type(self)(*[arg.rcall(expr)
for arg in self._args
])

def __invert__(self):
return AND(*[~arg for arg in self._args])

Expand Down Expand Up @@ -86,6 +103,11 @@ def __invert__(self):
def args(self):
return sorted(self._args, key=str)

def rcall(self, expr):
return type(self)(*[arg.rcall(expr)
for arg in self._args
])

def __hash__(self):
return hash((type(self).__name__,) + tuple(self.args))

Expand All @@ -104,8 +126,6 @@ def to_NNF(expr):
Generates the Negation Normal Form of any boolean expression in terms
of AND, OR, and Literal objects.
"""
if not isinstance(expr, BooleanFunction):
return Literal(expr)

if isinstance(expr, Not):
arg = expr.args[0]
Expand Down Expand Up @@ -163,7 +183,7 @@ def to_NNF(expr):
return AND(OR(~L, M), OR(L, R))

else:
raise NotImplementedError('NNF conversion not implemented for %s class' % type(expr).__name__)
return Literal(expr)


def distribute_AND_over_OR(expr):
Expand Down Expand Up @@ -262,6 +282,16 @@ def _not(self):
ll = ll._or(CNF(p))
return ll

def rcall(self, expr):
clause_list = list()
for clause in self.clauses:
lits = [arg.rcall(expr) for arg in clause]
clause_list.append(OR(*lits))
expr = AND(*clause_list)
return distribute_AND_over_OR(expr)



@classmethod
def all_or(cls, *cnfs):
b = cnfs[0].copy()
Expand Down
14 changes: 7 additions & 7 deletions sympy/assumptions/satask.py
Expand Up @@ -99,10 +99,11 @@ def find_symbols(pred):

for expr in exprs:
for fact in fact_registry[expr.func]:
newfact = fact.rcall(expr)
relevant_facts.add(newfact)
newexprs |= set([key.args[0] for key in
newfact.atoms(AppliedPredicate)])
cnf_fact = CNF.to_CNF(fact)
newfact = cnf_fact.rcall(expr)
relevant_facts = relevant_facts._and(newfact)
newexprs |= set([key.args[0] for key in newfact.all_predicates()
if isinstance(key, AppliedPredicate)])

return newexprs - exprs, relevant_facts

Expand All @@ -114,7 +115,7 @@ def get_all_relevant_facts(proposition, assumptions=True,
# we stop getting new things. Hopefully this strategy won't lead to an
# infinite loop in the future.
i = 0
relevant_facts = set()
relevant_facts = CNF()
exprs = None
all_exprs = set()
while exprs != set():
Expand Down Expand Up @@ -152,7 +153,6 @@ def translate_data(data, delta):
else:
ctx = EncodedCNF()

for e in relevant_facts:
ctx.add_prop(e)
ctx.add_from_cnf(relevant_facts)

return ctx
69 changes: 40 additions & 29 deletions sympy/assumptions/sathandlers.py
Expand Up @@ -4,20 +4,19 @@

from sympy.assumptions.ask import Q
from sympy.assumptions.assume import Predicate, AppliedPredicate
from sympy.assumptions.cnf import AND, OR, to_NNF
from sympy.core import (Add, Mul, Pow, Integer, Number, NumberSymbol,)
from sympy.core.compatibility import MutableMapping
from sympy.core.numbers import ImaginaryUnit
from sympy.core.logic import fuzzy_or, fuzzy_and
from sympy.core.rules import Transform
from sympy.core.sympify import _sympify
from sympy.functions.elementary.complexes import Abs
from sympy.logic.boolalg import (Equivalent, Implies, And, Or,
BooleanFunction, Not)
from sympy.logic.boolalg import (Equivalent, Implies, BooleanFunction)
from sympy.matrices.expressions import MatMul

# APIs here may be subject to change

# XXX: Better name?

class UnevaluatedOnFree(BooleanFunction):
"""
Represents a Boolean function that remains unevaluated on free predicates
Expand Down Expand Up @@ -68,13 +67,19 @@ def __new__(cls, arg):
obj.expr = predicate_args.pop()
obj.pred = arg.xreplace(Transform(lambda e: e.func, lambda e:
isinstance(e, AppliedPredicate)))
applied = obj.apply()
applied = obj.apply(obj.expr)
if applied is None:
return obj
return applied

def apply(self):
return
def apply(self, expr=None):
if expr is None:
return
pred = to_NNF(self.pred)
return self._eval_apply(expr, pred)

def _eval_apply(self, expr, pred):
return None


class AllArgs(UnevaluatedOnFree):
Expand All @@ -89,19 +94,20 @@ class AllArgs(UnevaluatedOnFree):

Example
=======

>>> from sympy.assumptions.sathandlers import AllArgs
>>> from sympy import symbols, Q
>>> x, y = symbols('x y')
>>> a = AllArgs(Q.positive | Q.negative)
>>> a
AllArgs(Q.negative | Q.positive)
>>> a.rcall(x*y)
(Q.negative(x) | Q.positive(x)) & (Q.negative(y) | Q.positive(y))
((Literal(Q.negative(x), False) | Literal(Q.positive(x), False)) & (Literal(Q.negative(y), False) | \
Literal(Q.positive(y), False)))

"""

def apply(self):
return And(*[self.pred.rcall(arg) for arg in self.expr.args])
def _eval_apply(self, expr, pred):
return AND(*[pred.rcall(arg) for arg in expr.args])


class AnyArgs(UnevaluatedOnFree):
Expand All @@ -116,19 +122,20 @@ class AnyArgs(UnevaluatedOnFree):

Example
=======

>>> from sympy.assumptions.sathandlers import AnyArgs
>>> from sympy import symbols, Q
>>> x, y = symbols('x y')
>>> a = AnyArgs(Q.positive & Q.negative)
>>> a
AnyArgs(Q.negative & Q.positive)
>>> a.rcall(x*y)
(Q.negative(x) & Q.positive(x)) | (Q.negative(y) & Q.positive(y))
((Literal(Q.negative(x), False) & Literal(Q.positive(x), False)) | (Literal(Q.negative(y), False) & \
Literal(Q.positive(y), False)))

"""

def apply(self):
return Or(*[self.pred.rcall(arg) for arg in self.expr.args])
def _eval_apply(self, expr, pred):
return OR(*[pred.rcall(arg) for arg in expr.args])


class ExactlyOneArg(UnevaluatedOnFree):
Expand All @@ -144,25 +151,26 @@ class ExactlyOneArg(UnevaluatedOnFree):

Example
=======

>>> from sympy.assumptions.sathandlers import ExactlyOneArg
>>> from sympy import symbols, Q
>>> x, y = symbols('x y')
>>> a = ExactlyOneArg(Q.positive)
>>> a
ExactlyOneArg(Q.positive)
>>> a.rcall(x*y)
(Q.positive(x) & ~Q.positive(y)) | (Q.positive(y) & ~Q.positive(x))
((Literal(Q.positive(x), False) & Literal(Q.positive(y), True)) | (Literal(Q.positive(x), True) & \
Literal(Q.positive(y), False)))

"""
def apply(self):
expr = self.expr
pred = self.pred

def _eval_apply(self, expr, pred):
pred_args = [pred.rcall(arg) for arg in expr.args]
# Technically this is xor, but if one term in the disjunction is true,
# it is not possible for the remainder to be true, so regular or is
# fine in this case.
return Or(*[And(pred_args[i], *map(Not, pred_args[:i] +
pred_args[i+1:])) for i in range(len(pred_args))])
res = OR(*[AND(pred_args[i], *[~lit for lit in pred_args[:i] +
pred_args[i+1:]]) for i in range(len(pred_args))])
return res
# Note: this is the equivalent cnf form. The above is more efficient
# as the first argument of an implication, since p >> q is the same as
# q | ~p, so the the ~ will convert the Or to and, and one just needs
Expand Down Expand Up @@ -226,15 +234,18 @@ def evaluate_old_assump(pred):


class CheckOldAssump(UnevaluatedOnFree):
def apply(self):
return Equivalent(self.args[0], evaluate_old_assump(self.args[0]))
def apply(self, expr=None, is_Not=False):
arg = self.args[0](expr) if callable(self.args[0]) else self.args[0]
res = Equivalent(arg, evaluate_old_assump(arg))
return to_NNF(res)


class CheckIsPrime(UnevaluatedOnFree):
def apply(self):
def apply(self, expr=None, is_Not=False):
from sympy import isprime
return Equivalent(self.args[0], isprime(self.expr))

arg = self.args[0](expr) if callable(self.args[0]) else self.args[0]
res = Equivalent(arg, isprime(expr))
return to_NNF(res)

class CustomLambda(object):
"""
Expand All @@ -245,8 +256,8 @@ class CustomLambda(object):
def __init__(self, lamda):
self.lamda = lamda

def rcall(self, *args):
return self.lamda(*args)
def apply(self, *args):
return to_NNF(self.lamda(*args))


class ClassFactRegistry(MutableMapping):
Expand Down
2 changes: 1 addition & 1 deletion sympy/assumptions/tests/test_query.py
Expand Up @@ -568,6 +568,7 @@ def test_I():
assert ask(Q.real(z)) is True



def test_bounded():
x, y, z = symbols('x,y,z')
assert ask(Q.finite(x)) is None
Expand Down Expand Up @@ -798,7 +799,6 @@ def test_bounded():
Q.finite(a), Q.finite(x) & ~Q.finite(y) & Q.positive(z)) is None
assert ask(Q.finite(a), Q.finite(x) & Q.positive(y) &
~Q.finite(y) & Q.positive(z) & ~Q.finite(z)) is False

assert ask(Q.finite(a), Q.finite(x) &
Q.positive(y) & ~Q.finite(y) & Q.negative(z)) is None
assert ask(
Expand Down
29 changes: 15 additions & 14 deletions sympy/assumptions/tests/test_sathandlers.py
@@ -1,4 +1,5 @@
from sympy import Mul, Basic, Q, Expr, And, symbols, Equivalent, Or
from sympy.assumptions.cnf import to_NNF

from sympy.assumptions.sathandlers import (ClassFactRegistry, AllArgs,
UnevaluatedOnFree, AnyArgs, CheckOldAssump, ExactlyOneArg)
Expand Down Expand Up @@ -39,7 +40,7 @@ def test_UnevaluatedOnFree():
Q.negative(y)))

class MyUnevaluatedOnFree(UnevaluatedOnFree):
def apply(self):
def apply(self, expr=None):
return self.args[0]

a = MyUnevaluatedOnFree(Q.positive)
Expand All @@ -56,15 +57,15 @@ def apply(self):
def test_AllArgs():
a = AllArgs(Q.zero)
b = AllArgs(Q.positive | Q.negative)
assert a.rcall(x*y) == And(Q.zero(x), Q.zero(y))
assert b.rcall(x*y) == And(Q.positive(x) | Q.negative(x), Q.positive(y) | Q.negative(y))
assert a.rcall(x*y) == to_NNF(And(Q.zero(x), Q.zero(y)))
assert b.rcall(x*y) == to_NNF(And(Q.positive(x) | Q.negative(x), Q.positive(y) | Q.negative(y)))


def test_AnyArgs():
a = AnyArgs(Q.zero)
b = AnyArgs(Q.positive & Q.negative)
assert a.rcall(x*y) == Or(Q.zero(x), Q.zero(y))
assert b.rcall(x*y) == Or(Q.positive(x) & Q.negative(x), Q.positive(y) & Q.negative(y))
assert a.rcall(x*y) == to_NNF(Or(Q.zero(x), Q.zero(y)))
assert b.rcall(x*y) == to_NNF(Or(Q.positive(x) & Q.negative(x), Q.positive(y) & Q.negative(y)))


def test_CheckOldAssump():
Expand All @@ -90,19 +91,19 @@ def _eval_is_extended_negative(self):
# We can't say if it's positive or negative in the old assumptions without
# bounded. Remember, True means "no new knowledge", and
# Q.positive(t2) means "t2 is positive."
assert CheckOldAssump(Q.positive(t1)) == True
assert CheckOldAssump(Q.negative(t1)) == ~Q.negative(t1)
assert CheckOldAssump(Q.positive(t1)) == to_NNF(True)
assert CheckOldAssump(Q.negative(t1)) == to_NNF(~Q.negative(t1))

assert CheckOldAssump(Q.positive(t2)) == Q.positive(t2)
assert CheckOldAssump(Q.negative(t2)) == ~Q.negative(t2)
assert CheckOldAssump(Q.positive(t2)) == to_NNF(Q.positive(t2))
assert CheckOldAssump(Q.negative(t2)) == to_NNF(~Q.negative(t2))


def test_ExactlyOneArg():
a = ExactlyOneArg(Q.zero)
b = ExactlyOneArg(Q.positive | Q.negative)
assert a.rcall(x*y) == Or(Q.zero(x) & ~Q.zero(y), Q.zero(y) & ~Q.zero(x))
assert a.rcall(x*y*z) == Or(Q.zero(x) & ~Q.zero(y) & ~Q.zero(z), Q.zero(y)
& ~Q.zero(x) & ~Q.zero(z), Q.zero(z) & ~Q.zero(x) & ~Q.zero(y))
assert b.rcall(x*y) == Or((Q.positive(x) | Q.negative(x)) &
assert a.rcall(x*y) == to_NNF(Or(Q.zero(x) & ~Q.zero(y), Q.zero(y) & ~Q.zero(x)))
assert a.rcall(x*y*z) == to_NNF(Or(Q.zero(x) & ~Q.zero(y) & ~Q.zero(z), Q.zero(y)
& ~Q.zero(x) & ~Q.zero(z), Q.zero(z) & ~Q.zero(x) & ~Q.zero(y)))
assert b.rcall(x*y) == to_NNF(Or((Q.positive(x) | Q.negative(x)) &
~(Q.positive(y) | Q.negative(y)), (Q.positive(y) | Q.negative(y)) &
~(Q.positive(x) | Q.negative(x)))
~(Q.positive(x) | Q.negative(x))))