# Sympy Expansion with First Order Logic

In [1]:
from sympy import cacheit, Symbol
from sympy.core.function import Function, UndefinedFunction, Application
import itertools

class FOLConstant(Symbol):
    is_Constant = True


class FOLVariable(Symbol):
    is_Constant = False


class FOLFunction(Function):
    def __new__(cls, *args, **options):
        if cls is FOLFunction:
            options['bases'] = (AppliedFOLFunction,)
            # a predicate is considered as an UndefinedFunction
            res = UndefinedFunction(*args, **options)
            return res
        return super(FOLFunction, cls).__new__(cls, *args, **options)

    def ground(self, var_name, constant):
        fun = FOLFunction(self.name)
        variables = []
        for i in range(0, len(self.args)):
          if self.args[i].is_Symbol and self.args[i].name == var_name:
            variables.append(constant)
          elif self.args[i].is_Symbol:
            variables.append(self.args[i])
          else:
            variables.append(self.args[i].ground(var_name, constant))
        return fun(*variables)


class AppliedFOLFunction(FOLFunction):
    pass


class Formula(Function):
    is_Predicate = False
    is_Quantifier = False

    # called when operation is not allowed (as in sympy library)
    def _noop(self, other=None):
        raise TypeError('First order logic term not allowed in this context.')

    # redefinition of and
    def __and__(self, other):
        """Overloading for & operator"""
        return FOLAnd(self, other)

    __rand__ = __and__

    # redefinition of or
    def __or__(self, other):
        """Overloading for |"""
        return FOLOr(self, other)

    __ror__ = __or__

    # redefinition of not
    def __invert__(self):
        """Overloading for ~"""
        return FOLNot(self)

    # redefinition of implies
    def __rshift__(self, other):
        """Overloading for >>"""
        return FOLImplies(self, other)

    # redefinition of is implied
    def __lshift__(self, other):
        """Overloading for <<"""
        return FOLImplies(other, self)

    __rrshift__ = __lshift__
    __rlshift__ = __rshift__

    # redefinition of is xor
    def __xor__(self, other):
        return FOLOr(FOLAnd(self, FOLNot(other)), FOLAnd(FOLNot(self), other))

    __rxor__ = __xor__

    __add__ = _noop
    __radd__ = _noop
    __sub__ = _noop
    __rsub__ = _noop
    __mul__ = _noop
    __rmul__ = _noop
    __pow__ = _noop
    __rpow__ = _noop
    __rdiv__ = _noop
    __truediv__ = _noop
    __div__ = _noop
    __rtruediv__ = _noop
    __mod__ = _noop
    __rmod__ = _noop
    _eval_power = _noop

    def _apply_not(self):
        raise NotImplementedError("Implementation Error")

    # negation normal form
    def to_nnf(self):
        raise NotImplementedError("Implementation Error")

    # skolem normal form
    def to_snf(self, skolem_function_symbol= None):
        from sympy.first_order_logic.algorithm import to_snf
        return to_snf(self, skolem_function_symbol)

    # prenex normal form
    def to_pnf(self):
        from sympy.first_order_logic.algorithm import to_pnf
        return to_pnf(self)

    def to_list_clauses(self):
        from sympy.first_order_logic.algorithm import to_list_clauses
        return to_list_clauses(self)

    def is_satisfiable(self, algorithm=None):
        f = self.to_nnf()
        f = f.to_pnf()
        f = f.to_snf()
        if algorithm is None:
            from sympy.first_order_logic.algorithm import Resolution
            algorithm = Resolution()
        c = f.to_list_clauses()
        return algorithm(*c)


"""
def func(self):
        return self.__class__
"""

class FOLLogicOperator(Formula):
    def to_nnf(self):
        # every argument of the formula must have its own to_nnf() method
        return self.func(*[a.to_nnf() for a in self.args])


class FOLAnd(FOLLogicOperator):
    nargs = 2

    # not(A and B) = not A or not B
    def _apply_not(self):
        return FOLOr(self.args[0]._apply_not(), self.args[1]._apply_not())
    
    def _latex(self,printer,*args):        
        if self.args[0].is_Predicate or self.args[0].is_Quantifier or self.args[0].func == FOLNot:
            if self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == FOLNot:
                return r'%s \wedge %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
            else:
                return r'%s \wedge (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
        elif self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == FOLNot:
            return r'(%s) \wedge %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])

        return r'(%s) \wedge (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
    
    def ground(self, var_name, constant):
        return FOLAnd(self.args[0].ground(var_name, constant),self.args[1].ground(var_name, constant))


class FOLOr(FOLLogicOperator):
    nargs = 2
    
    # not(A or B) = not A and not B
    def _apply_not(self):
        return FOLAnd(self.args[0]._apply_not(), self.args[1]._apply_not())
    
    def _latex(self,printer,*args):
        if self.args[0].is_Predicate or self.args[0].is_Quantifier or self.args[0].func == FOLNot:
            if self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == FOLNot:
                return r'%s \vee %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
            else:
                return r'%s \vee (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
        elif self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == FOLNot:
            return r'(%s) \vee %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])

        return r'(%s) \vee (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])

    def ground(self, var_name, constant):
        return FOLOr(self.args[0].ground(var_name, constant),self.args[1].ground(var_name, constant))
    
    
class FOLNot(FOLLogicOperator):
    nargs = 1

    # in NNF the logic operator not must only be in front of predicates
    def to_nnf(self):
        expr = self.args[0]
        if (expr.is_Predicate):
            return self

        res = self.args[0]._apply_not()
        return res.to_nnf()
    
    # not (not A) = A
    def _apply_not(self):
        return self.args[0]

    def _latex(self,printer,*args):
        if self.args[0].is_Predicate or self.args[0].is_Quantifier or self.args[0].func == FOLNot:
            return r'\lnot %s'%tuple([self.args[0]._latex(printer,*args)])
        return r'\lnot (%s)'%tuple([self.args[0]._latex(printer,*args)])

    def ground(self, var_name, constant):
        return FOLNot(self.args[0].ground(var_name, constant))


class FOLImplies(FOLLogicOperator):
    nargs = 2

    # A -> B = not A or B
    def to_nnf(self):
        return FOLOr(self.args[0]._apply_not().to_nnf(), self.args[1].to_nnf())

    # not (A -> B) = A and not B
    def _apply_not(self):
        return FOLAnd(self.args[0], self.args[1]._apply_not())
    
    def _latex(self,printer,*args):
        if self.args[0].is_Predicate or self.args[0].is_Quantifier or self.args[0].func == FOLNot:
            if self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == FOLNot:
                return r'%s \rightarrow %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
            else:
                return r'%s \rightarrow (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
        elif self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == FOLNot:
            return r'(%s) \rightarrow %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
        return r'(%s) \rightarrow (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])

    def ground(self, var_name, constant):
        return FOLImplies(self.args[0].ground(var_name, constant),self.args[1].ground(var_name, constant))


class Predicate(Formula):
    is_Atom = True
    is_Predicate = True
    is_Equality = False

    @cacheit
    def __new__(cls, *args, **options):
        if cls is Predicate:
            options['bases'] = (AppliedPredicate,)
            # a predicate is considered as an UndefinedFunction
            res = UndefinedFunction(*args, **options)
            return res
        return super(Predicate, cls).__new__(cls, *args, **options)

    def to_nnf(self):
        return self

    def _apply_not(self):
        return FOLNot(self)

    def _latex(self,printer,*args):
        return r'%s'% str(self)
    
    def ground(self, var_name, constant):
        vec = []
        for i in range(0,len(self.args)):
          if self.args[i].name == var_name:
            vec.append(i)
        pred = Predicate(self.name)

        variables = []
        for i in range(0,len(self.args)):
          if i in vec:
            variables.append(constant)
          else:
            if self.args[i].is_Symbol:
              variables.append(self.args[i])
            else:
              variables.append(self.args[i].ground(var_name,constant))
        return pred(*variables)


class AppliedPredicate(Predicate):
    pass


class Quantifier(Formula):
    nargs = 1
    is_Quantifier = True
    is_Forall = False
    is_Exists = False
    _symbol = None

    def __new__(cls, *args, **options):
        if cls in (Forall, Exists):
            argslist = list(args)
            symbol = argslist[0] # variable which is not free
            argslist[0] = cls.name + symbol.name
            args = tuple(argslist)
            options['bases'] = (cls._getAppliedClass(),)
            res = UndefinedFunction(*args, **options)
            res._symbol = symbol
            return res
        return super(Quantifier, cls).__new__(cls, *args, **options)

    @property
    def symbol(self):
        return self._symbol

    @property
    def formula(self):
        return self.args[0]

    def to_nnf(self):
        return self.func(self.formula.to_nnf())


class Forall(Quantifier):
    is_Forall = True
    name = 'Forall_'
    
    # not forall = exist not
    def _apply_not(self):
        exists = Exists(self.symbol)
        return exists(self.formula._apply_not())

    @staticmethod
    def _getAppliedClass():
        return AppliedForall
    
    def _latex(self,printer,*args):
        if self.formula.is_Predicate or self.args[0].func == FOLNot:
            return r'\forall %s %s'%tuple([printer._print(self.symbol,*args), self.formula._latex(printer,*args)])
        return r'\forall %s (%s)'%tuple([printer._print(self.symbol,*args), self.formula._latex(printer,*args)])
    
    def ground(self, set_const):
        formula = self.formula.ground(self.symbol.name, set_const[0])
        for i in range(1, len(set_const)):
          formula = FOLAnd(formula, self.formula.ground(self.symbol.name, set_const[i]))
        return formula

class AppliedForall(Forall):
    pass


class Exists(Quantifier):
    is_Exists = True
    name = 'Exists_'

    # not exists = forall not
    def _apply_not(self):
        forall = Forall(self.symbol)
        return forall(self.formula._apply_not())

    @staticmethod
    def _getAppliedClass():
        return AppliedExists
    
    def _latex(self,printer,*args):
        if self.formula.is_Predicate or self.args[0].func == FOLNot:
            return r'\exists %s %s'%tuple([printer._print(self.symbol,*args), self.formula._latex(printer,*args)])
        return r'\exists %s (%s)'%tuple([printer._print(self.symbol,*args), self.formula._latex(printer,*args)])
    
    def ground(self, set_const):
        formula = self.formula.ground(self.symbol.name, set_const[0])
        for i in range(1, len(set_const)):
          formula = FOLOr(formula, self.formula.ground(self.symbol.name, set_const[i]))
        return formula

class AppliedExists(Exists):
    pass


def formula_to_list(formula, delimiter):
    if isinstance(formula, delimiter):
        l = []
        l.extend(formula_to_list(formula.args[0], delimiter))
        l.extend(formula_to_list(formula.args[1], delimiter))
        return l
    else:
        return [formula]

def remove_quantifiers(formula):
    quantifiers = []
    if formula.is_Predicate:
        return formula, quantifiers
    func = formula.func
    if formula.is_Quantifier:
        quantifiers.append(func)
        ret, q = remove_quantifiers(formula.formula)
        quantifiers.extend(q)
        return ret, quantifiers
    else:
        new_args = []
        args = formula.args
        for arg in args:
            f, q = remove_quantifiers(arg)
            new_args.append(f)
            quantifiers.extend(q)
        return func(*new_args), quantifiers


class Clause(Application):
    _formula = None

    def __new__(cls, *args, **kwargs):
        args_list = list(args)
        if([] in args_list):
            args_list.remove([])
        if len(args_list) == 1 and not args_list[0].is_Predicate and args_list[0].func is not FOLNot:
            f,_ = remove_quantifiers(args_list[0])
            args = formula_to_list(f, FOLOr)
        return super(Clause, cls).__new__(cls, *args, **kwargs)

    def substitute(self, sub):
        f = self.formula.xreplace(sub)
        return Clause(f)


    @property
    def list(self):
        return list(self.args).copy()

    @property
    def formula(self):
        if self._formula is None:
            self._to_formula()

        return self._formula

    def _to_formula(self):
        f = None
        for a in self.args:
            if f is None:
                f = a
            else:
                f = FOLOr(f, a)

        for s in self._get_symbols(f).keys():
            f = Forall(s)(f)

        self._formula = f

    @classmethod
    def _get_symbols(cls, e):
        if e.func is Symbol and not e.is_Constant:
            return {e: True}

        dic = {}
        for a in e.args:
            dic.update(cls._get_symbols(a))
        return dic

# FOL Algorithms

In [2]:
# ---------
# GROUNDING
# ---------

def to_pnf(formula):
    # Assume to have a different symbol for each quantifier
    pnf_formula, quantifiers = remove_quantifiers(formula)
    quantifiers.reverse()
    for q in quantifiers:
        pnf_formula = q(pnf_formula)
    return pnf_formula


def grounding(formula, constants_set):
    # Assume to have the formula in NNF and PNF
    formula, quantifiers = remove_quantifiers(formula)
    for quant in quantifiers:
      formula = quant(formula).ground(constants_set)
    return formula


# -----------
# UNIFICATION
# -----------

def unification(literal1, literal2, check_equality=False):
    if literal1.func is not literal2.func:
        return {}
    l1_args = literal1.args
    l2_args = literal2.args
    if len(l1_args) != len(l2_args):
      return {}
    sub = {}
    for i in range(0, len(l1_args)):
        s = _unification_term(l1_args[i], l2_args[i], sub)
        if s is None:
            return {}
        else:
            sub.update(s)
    return sub


def _unification_term(t1, t2, sub):
    if t1 in sub.keys():
      t1 = sub.get(t1)
    if t2 in sub.keys():
      t2 = sub.get(t2)
    if t1.is_Symbol and t2.is_Symbol and t1.name == t2.name:
        return sub
    if t2.is_Symbol and not t2.is_Constant:
        t3 = t1
        t1 = t2
        t2 = t3
    if t1.is_Symbol and not t1.is_Constant:
        if t2.is_Symbol or not _contains_term_variable(t2, t1):
            sub.update({t1: t2})
            return sub
    if t1.is_Function and t2.is_Function:
        if t1.func is t2.func:
            t1_args = t1.args
            t2_args = t2.args
            for i in range(0, len(t1_args)):
                s = _unification_term(t1_args[i], t2_args[i], sub)
                if s is None:
                    return None
                else:
                    sub.update(s)
            return sub
    return None


def _contains_term_variable(term, var):
    if term.is_Symbol:
        if (term.is_Constant):
            return False
        else:
            return term is var
    res = False
    for a in term.args:
        res = res | _contains_term_variable(a, var)
    return res


# ----------------------
# BINARY RESOLUTION RULE
# ----------------------


def binary_res(clause1, clause2):
    cl1 = clause1.list
    cl2 = clause2.list
    sub = None

    for l1, l2 in itertools.product(cl1, cl2):
        if l1.func is FOLNot and issubclass(l2.func, Predicate):
            sub = unification(l1.args[0], l2)

        if l2.func is FOLNot and issubclass(l1.func, Predicate):
            sub = unification(l2.args[0], l1)

        if sub is not None:
            cl1.remove(l1)
            cl2.remove(l2)
            break

    if sub is None:
        return None

    if (len(cl1) == 0 and len(cl2) == 0):
        return False

    cl1.extend(cl2)

    return Clause(*cl1).substitute(sub)


In [3]:
from sympy.core import Function, Symbol

x = FOLVariable('x')
y = FOLVariable('y')
z = FOLVariable('z')

a = FOLConstant('a')
b = FOLConstant('b')
c = FOLConstant('c')
d = FOLConstant('d')

f = FOLFunction('f')
g = FOLFunction('g')
h = FOLFunction('h')

P = Predicate('P')
Q = Predicate('Q')
R = Predicate('R')


# Grounding: Simple Example

In [4]:
formula = ~Forall(x) (P(x) & Exists(y)(P(y)))
formula

FOLNot(Forall_x(FOLAnd(P(x), Exists_y(P(y)))))

Negation Normal Form


In [5]:
formula.nnf = formula.to_nnf()
formula.nnf

Exists_x(FOLOr(FOLNot(P(x)), Forall_y(FOLNot(P(y)))))

Prenex Normal Form

In [6]:
formula.pnf = to_pnf(formula.nnf)
formula.pnf

Exists_x(Forall_y(FOLOr(FOLNot(P(x)), FOLNot(P(y)))))

Grounding

In [7]:
Delta = [a,b]
grounding(formula.pnf,Delta)

FOLAnd(FOLOr(FOLOr(FOLNot(P(a)), FOLNot(P(a))), FOLOr(FOLNot(P(b)), FOLNot(P(a)))), FOLOr(FOLOr(FOLNot(P(a)), FOLNot(P(b))), FOLOr(FOLNot(P(b)), FOLNot(P(b)))))

# Grounding: More Complex Example

In [8]:
formula = ~Forall(x) (P(x) & (Exists (y) (~P(x,y) & Q(f(y))) >> Exists (z) (~P(x) & Q(f(x),a,z))))
formula

FOLNot(Forall_x(FOLAnd(P(x), FOLImplies(Exists_y(FOLAnd(FOLNot(P(x, y)), Q(f(y)))), Exists_z(FOLAnd(FOLNot(P(x)), Q(f(x), a, z)))))))

Negation Normal Form

In [9]:
formula.nnf = formula.to_nnf()
formula.nnf

Exists_x(FOLOr(FOLNot(P(x)), FOLAnd(Exists_y(FOLAnd(FOLNot(P(x, y)), Q(f(y)))), Forall_z(FOLOr(P(x), FOLNot(Q(f(x), a, z)))))))

Prenex Normal Form

In [10]:
formula.pnf = to_pnf(formula.nnf)
formula.pnf

Exists_x(Exists_y(Forall_z(FOLOr(FOLNot(P(x)), FOLAnd(FOLAnd(FOLNot(P(x, y)), Q(f(y))), FOLOr(P(x), FOLNot(Q(f(x), a, z))))))))

Grounding

In [11]:
grounding(formula.pnf,[a,b])

FOLAnd(FOLOr(FOLOr(FOLOr(FOLNot(P(a)), FOLAnd(FOLAnd(FOLNot(P(a, a)), Q(f(a))), FOLOr(P(a), FOLNot(Q(f(a), a, a))))), FOLOr(FOLNot(P(b)), FOLAnd(FOLAnd(FOLNot(P(b, a)), Q(f(a))), FOLOr(P(b), FOLNot(Q(f(b), a, a)))))), FOLOr(FOLOr(FOLNot(P(a)), FOLAnd(FOLAnd(FOLNot(P(a, b)), Q(f(b))), FOLOr(P(a), FOLNot(Q(f(a), a, a))))), FOLOr(FOLNot(P(b)), FOLAnd(FOLAnd(FOLNot(P(b, b)), Q(f(b))), FOLOr(P(b), FOLNot(Q(f(b), a, a))))))), FOLOr(FOLOr(FOLOr(FOLNot(P(a)), FOLAnd(FOLAnd(FOLNot(P(a, a)), Q(f(a))), FOLOr(P(a), FOLNot(Q(f(a), a, b))))), FOLOr(FOLNot(P(b)), FOLAnd(FOLAnd(FOLNot(P(b, a)), Q(f(a))), FOLOr(P(b), FOLNot(Q(f(b), a, b)))))), FOLOr(FOLOr(FOLNot(P(a)), FOLAnd(FOLAnd(FOLNot(P(a, b)), Q(f(b))), FOLOr(P(a), FOLNot(Q(f(a), a, b))))), FOLOr(FOLNot(P(b)), FOLAnd(FOLAnd(FOLNot(P(b, b)), Q(f(b))), FOLOr(P(b), FOLNot(Q(f(b), a, b))))))))

# Unification

In [12]:
unification(P(x), P(a))

{x: a}

In [13]:
unification(P(f(x)), P(y))

{y: f(x)}

In [14]:
unification(P(f(x)), P(f(y)))

{y: x}

In [15]:
unification(g(g(x)),g(y))

{y: g(x)}

In [16]:
unification(P(a,b,c),P(x,y,z))

{x: a, y: b, z: c}

In [17]:
unification(P(f(g(x,a)),x),P(z,b))

{x: b, z: f(g(x, a))}

In [18]:
unification(P(f(x)), P(x))

{}

In [19]:
unification(P(x,x),P(c,d))

{}

# Resolution

In [20]:
binary_res(Clause(P(x)), Clause(~P(x)))

False

In [21]:
binary_res(Clause(P(x)), Clause(P(x)))

In [22]:
binary_res(Clause(P(x) | Q(y)), Clause(P(x) | ~Q(y)))

Clause(Q(y), P(x))

In [23]:
binary_res(Clause(P(x) | Q(x)), Clause(P(x) | ~Q(y)))

Clause(Q(x), P(x))