# Sympy Expansion with First Order Logic

In [1]:
from sympy import cacheit, Symbol
from sympy.core.function import Function, UndefinedFunction, Application
import itertools

'''
All classes that require the implementation of arithmentic operators, should
derive from class Expr. Expr class is the base class of any algebraic expression.
Function is a subclass of Expr.
'''

class Constant_FOL(Symbol):
    is_Constant = True


class Variable_FOL(Symbol):
    is_Constant = False


class FunctionalSymbol(Function):
    def __new__(cls, *args, **options):
        if cls is FunctionalSymbol:
            # to make the object callable
            options['bases'] = (AppliedFunctionalSymbol,)
            # a functional symbol is considered as an UndefinedFunction
            fun_symbol = UndefinedFunction(*args, **options)
            return fun_symbol
        return super(FunctionalSymbol, cls).__new__(cls, *args, **options)

    def ground(self, var_name, constant):
        fun = FunctionalSymbol(self.name)
        variables = []
        # len(self.args) = arity of the functional symbol
        for i in range(0, len(self.args)):
          if self.args[i].is_Symbol and self.args[i].name == var_name:
            variables.append(constant)
          elif self.args[i].is_Symbol:
            variables.append(self.args[i])
          else:
            variables.append(self.args[i].ground(var_name, constant))
        return fun(*variables)


class AppliedFunctionalSymbol(FunctionalSymbol):
    """
    Base class for expressions resulting from the application of a functional
    symbol.
    """
    pass


class Formula(Function):
    is_Predicate = False
    is_Quantifier = False
    
    # redefinition of and
    def __and__(self, other):
        """Overloading for &."""
        return And_FOL(self, other)

    # redefinition of or
    def __or__(self, other):
        """Overloading for |."""
        return Or_FOL(self, other)

    # redefinition of not
    def __invert__(self):
        """Overloading for ~."""
        return Not_FOL(self)

    # redefinition of implies
    def __rshift__(self, other):
        """Overloading for >>."""
        return Implies_FOL(self, other)

    # redefinition of is implied
    def __lshift__(self, other):
        """Overloading for <<."""
        return Implies_FOL(other, self)

    # reverse operations
    __rand__ = __and__
    __ror__ = __or__
    __rrshift__ = __lshift__
    __rlshift__ = __rshift__


"""
func returns the class name.
def func(self):
  return self.__class__
"""

class LogicalConnective(Formula):
    
    # In NNF the logical connective 'not' must only be in front of predicates.
    # Furthermore, the logical connective 'implies' should be expressed as a
    # conjunction.
    def to_nnf(self):
        # every argument of the formula must have its own to_nnf() method
        return self.func(*[a.to_nnf() for a in self.args])


class And_FOL(LogicalConnective):

    # not(A and B) = not A or not B
    def _apply_not(self):
        return Or_FOL(self.args[0]._apply_not(), self.args[1]._apply_not())
    
    def _latex(self,printer,*args):        
        if self.args[0].is_Predicate or self.args[0].is_Quantifier or self.args[0].func == Not_FOL:
            if self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == Not_FOL:
                return r'%s \wedge %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
            else:
                return r'%s \wedge (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
        elif self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == Not_FOL:
            return r'(%s) \wedge %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])

        return r'(%s) \wedge (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
    
    def ground(self, var_name, constant):
        return And_FOL(self.args[0].ground(var_name, constant),self.args[1].ground(var_name, constant))


class Or_FOL(LogicalConnective):

    # not(A or B) = not A and not B
    def _apply_not(self):
        return And_FOL(self.args[0]._apply_not(), self.args[1]._apply_not())
    
    def _latex(self,printer,*args):
        if self.args[0].is_Predicate or self.args[0].is_Quantifier or self.args[0].func == Not_FOL:
            if self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == Not_FOL:
                return r'%s \vee %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
            else:
                return r'%s \vee (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
        elif self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == Not_FOL:
            return r'(%s) \vee %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])

        return r'(%s) \vee (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])

    def ground(self, var_name, constant):
        return Or_FOL(self.args[0].ground(var_name, constant),self.args[1].ground(var_name, constant))
    
    
class Not_FOL(LogicalConnective):

    def to_nnf(self):
        if self.args[0].is_Predicate:
            return self
        # if we have a formula preceeded by the not operator, then we call ._apply_not()
        neg = self.args[0]._apply_not()
        return neg.to_nnf()
    
    # not (not A) = A
    def _apply_not(self):
        return self.args[0]

    def _latex(self,printer,*args):
        if self.args[0].is_Predicate or self.args[0].is_Quantifier or self.args[0].func == Not_FOL:
            return r'\lnot %s'%tuple([self.args[0]._latex(printer,*args)])
        return r'\lnot (%s)'%tuple([self.args[0]._latex(printer,*args)])

    def ground(self, var_name, constant):
        return Not_FOL(self.args[0].ground(var_name, constant))


class Implies_FOL(LogicalConnective):

    # A -> B = not A or B
    def to_nnf(self):
        return Or_FOL(self.args[0]._apply_not().to_nnf(), self.args[1].to_nnf())

    # not (A -> B) = A and not B
    def _apply_not(self):
        return And_FOL(self.args[0], self.args[1]._apply_not())
    
    def _latex(self,printer,*args):
        if self.args[0].is_Predicate or self.args[0].is_Quantifier or self.args[0].func == Not_FOL:
            if self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == Not_FOL:
                return r'%s \rightarrow %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
            else:
                return r'%s \rightarrow (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
        elif self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == Not_FOL:
            return r'(%s) \rightarrow %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
        return r'(%s) \rightarrow (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])

    def ground(self, var_name, constant):
        return Implies_FOL(self.args[0].ground(var_name, constant),self.args[1].ground(var_name, constant))


class IsEquivalent_FOL(LogicalConnective):

    # A <-> B = (A->B) and (B->A)
    def to_nnf(self):
        return And_FOL(Implies_FOL(self.args[0], self.args[1]).to_nnf(), Implies_FOL(self.args[1], self.args[0]).to_nnf())

    # not (A <-> B) = not(A->B) or not(B->A)
    def _apply_not(self):
        return Or_FOL(Implies_FOL(self.args[0], self.args[1])._apply_not(), Implies_FOL(self.args[1], self.args[0])._apply_not())
    
    def _latex(self,printer,*args):
        if self.args[0].is_Predicate or self.args[0].is_Quantifier or self.args[0].func == Not_FOL:
            if self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == Not_FOL:
                return r'%s \equiv %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
            else:
                return r'%s \equiv (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
        elif self.args[1].is_Predicate or self.args[1].is_Quantifier or self.args[1].func == Not_FOL:
            return r'(%s) \equiv %s'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])
        return r'(%s) \equiv (%s)'%tuple([self.args[0]._latex(printer,*args), self.args[1]._latex(printer,*args)])

    def ground(self, var_name, constant):
        return IsEquivalent_FOL(self.args[0].ground(var_name, constant), self.args[1].ground(var_name, constant))


class Predicate(Formula):
    is_Atom = True
    is_Predicate = True

    def __new__(cls, *args, **options):
      if cls is Predicate:
          options['bases'] = (AppliedPredicate,)
          # a predicate is considered as an UndefinedFunction
          pred = UndefinedFunction(*args, **options)
          return pred
      return super(Predicate, cls).__new__(cls, *args, **options)
     
    def to_nnf(self):
        return self

    def _apply_not(self):
        return Not_FOL(self)

    def _latex(self,printer,*args):
        return r'%s'% str(self)
    
    def ground(self, var_name, constant):

        # len(self.args) = arity of the predicate
        pred = Predicate(self.name)
        variables = []
        for i in range(0,len(self.args)):
          if self.args[i].name == var_name:
            variables.append(constant)
          else:
            if self.args[i].is_Symbol:
              variables.append(self.args[i])
            else:
              variables.append(self.args[i].ground(var_name,constant))
        return pred(*variables)


class AppliedPredicate(Predicate):
  """
  Base class for expressions resulting from the application of a predicate.
  """
  pass


class Quantifier(Formula):
    is_Quantifier = True
    is_Forall = False
    is_Exists = False
    symbol = None

    def __new__(cls, *args, **options):
        if cls in (Forall, Exists):
            symbol = args[0] # variable which is not free
            argslist = list(args)
            argslist[0] = cls.name + symbol.name
            args = tuple(argslist)
            options['bases'] = (cls._getAppliedClass(),)
            quantifier = UndefinedFunction(*args, **options)
            quantifier.symbol = symbol
            return quantifier
        return super(Quantifier, cls).__new__(cls, *args, **options)

    # quantifier.formula
    @property
    def formula(self):
        return self.args[0]

    def to_nnf(self):
        return self.func(self.formula.to_nnf())

class Forall(Quantifier):
    is_Forall = True
    name = 'Forall_'
    
    # not forall = exist not
    def _apply_not(self):
        exists = Exists(self.symbol)
        return exists(self.formula._apply_not())

    # it works in the same way for any object of the Forall class
    @staticmethod
    def _getAppliedClass():
        return AppliedForall
    
    def _latex(self,printer,*args):
        if self.formula.is_Predicate or self.args[0].func == Not_FOL:
            return r'\forall %s %s'%tuple([printer._print(self.symbol,*args), self.formula._latex(printer,*args)])
        return r'\forall %s (%s)'%tuple([printer._print(self.symbol,*args), self.formula._latex(printer,*args)])
    
    def ground(self, set_const):
        formula = self.formula.ground(self.symbol.name, set_const[0])
        for i in range(1, len(set_const)):
          formula = And_FOL(formula, self.formula.ground(self.symbol.name, set_const[i]))
        return formula


class AppliedForall(Forall):
  """
  Base class for expressions resulting from the application of a for all.
  """
  pass


class Exists(Quantifier):
    is_Exists = True
    name = 'Exists_'

    # not exists = forall not
    def _apply_not(self):
        forall = Forall(self.symbol)
        return forall(self.formula._apply_not())

    # it works in the same way for any object of the Exists class
    @staticmethod
    def _getAppliedClass():
        return AppliedExists
    
    def _latex(self,printer,*args):
        if self.formula.is_Predicate or self.args[0].func == Not_FOL:
            return r'\exists %s %s'%tuple([printer._print(self.symbol,*args), self.formula._latex(printer,*args)])
        return r'\exists %s (%s)'%tuple([printer._print(self.symbol,*args), self.formula._latex(printer,*args)])
    
    def ground(self, set_const):
        formula = self.formula.ground(self.symbol.name, set_const[0])
        for i in range(1, len(set_const)):
          formula = Or_FOL(formula, self.formula.ground(self.symbol.name, set_const[i]))
        return formula


class AppliedExists(Exists):
  """
  Base class for expressions resulting from the application of an exists.
  """
  pass


def formula_to_list(formula, delimiter):
    if isinstance(formula, delimiter):
        l = []
        l.extend(formula_to_list(formula.args[0], delimiter))
        l.extend(formula_to_list(formula.args[1], delimiter))
        return l
    else:
        return [formula]

def extract_quantifiers(formula):
    if formula.is_Predicate:
        return formula, []
    func = formula.func
    if formula.is_Quantifier:
        # for example: fun = Exists_x
        quant = [func]
        f, q = extract_quantifiers(formula.formula)
        quant.extend(q)
        return f, quant
    else:
        formulas = []
        args = formula.args
        quant = [] 
        for arg in args:
            f, q = extract_quantifiers(arg)
            formulas.append(f)
            quant.extend(q)
        return func(*formulas), quant


class Clause(Application):
    _formula = None
    def __new__(cls, *args, **kwargs):
        args_list = list(args)
        if([] in args_list):
            args_list.remove([])
        if len(args_list) == 1 and not args_list[0].is_Predicate and args_list[0].func is not Not_FOL:
            f,_ = extract_quantifiers(args_list[0])
            args = formula_to_list(f, Or_FOL)
        return super(Clause, cls).__new__(cls, *args, **kwargs)

    @property
    def formula(self):
        if self._formula is None:
          for a in self.args:
            if self._formula is None:
                self._formula = a
            else:
                self._formula = Or_FOL(self._formula, a)

        for s in self._get_symbols(self._formula).keys():
            self._formula = Forall(s)(self._formula)
        return self._formula

    @classmethod
    def _get_symbols(cls, e):
        if e.func is Symbol and not e.is_Constant:
            return {e: True}
        dic = {}
        for a in e.args:
            dic.update(cls._get_symbols(a))
        return dic

    def substitute(self, sub):
        f = self.formula.xreplace(sub)
        return Clause(f)


# FOL Algorithms

In [2]:
# ---------
# GROUNDING
# ---------

# Prenex Normal Form
def to_pnf(formula):
    # Assume to have a different symbol for each quantifier
    pnf_formula, quantifiers = extract_quantifiers(formula)
    quantifiers.reverse()
    for q in quantifiers:
        pnf_formula = q(pnf_formula)
    return pnf_formula


def grounding(formula, constants_set):
    # Assume to have the formula in NNF
    # Assume to have a different symbol for each quantifier
    formula, quantifiers = extract_quantifiers(formula)
    for quant in quantifiers:
      formula = quant(formula).ground(constants_set)
    return formula


# -----------
# UNIFICATION
# -----------

def unification(first_atom, second_atom):
    if first_atom.func != second_atom.func:
      return {}
    if len(first_atom.args) != len(second_atom.args):
      return {}
    sub = {}
    for i in range(0, len(first_atom.args)):
        s = terms_unification(first_atom.args[i], second_atom.args[i], sub)
        if s is None:
            return {}
        else:
            sub.update(s)
    return sub


def terms_unification(first_term, second_term, sub):
    # apply the previous transformations to the 2 terms if needed
    if first_term in sub.keys():
      first_term = sub.get(first_term)
    if second_term in sub.keys():
      second_term = sub.get(second_term)
    if first_term.is_Symbol and second_term.is_Symbol and first_term.name == second_term.name:
        return sub
   
    if first_term.is_Symbol and not first_term.is_Constant:
      # first_term is variable
      if second_term.is_Symbol or not is_variable_in_term(second_term, first_term):
          sub.update({first_term: second_term})
          return sub
    elif second_term.is_Symbol and not second_term.is_Constant:
      # first_term is not variable and second_term is variable
      if first_term.is_Symbol or not is_variable_in_term(first_term, second_term):
          sub.update({second_term: first_term})
          return sub

    if first_term.is_Function and second_term.is_Function:
      if first_term.func is second_term.func:
          for i in range(0, len(first_term.args)):
              s = terms_unification(first_term.args[i], second_term.args[i], sub)
              if s is None:
                  return None
              else:
                  sub.update(s)
          return sub
    # in all the other cases
    return None


def is_variable_in_term(term, var):
    if term.is_Symbol:
        if term.is_Constant:
            return False
        else:
            return term is var
    res = False
    for a in term.args:
        res = res | is_variable_in_term(a, var)
    return res

# ----------------------
# BINARY RESOLUTION RULE
# ----------------------

def binary_resolution(clause1, clause2):
    cl1 = list(clause1.args).copy()
    cl2 = list(clause2.args).copy()
    sub = None
    b = False
    for l1, l2 in itertools.product(cl1, cl2):
        if l1.func == Not_FOL and l2.is_Predicate:
          if l1.args[0] == l2:
            cl1.remove(l1)
            cl2.remove(l2)
            b = True
            break
          sub = unification(l1.args[0], l2)

        if l2.func == Not_FOL and l1.is_Predicate:
          if l2.args[0] == l1:
            cl1.remove(l1)
            cl2.remove(l2)
            b = True
            break
          sub = unification(l2.args[0], l1)
        

        if sub != None and sub != {}:
            cl1.remove(l1)
            cl2.remove(l2)
            break

    if ((sub == None) or (sub == {})) and (b == False):
        return "Binary Resolution does not apply"

    if (len(cl1) == 0 and len(cl2) == 0):
        return "Empty Clause"
    
    result = binary_resolution(Clause(*cl1).substitute(sub),Clause(*cl2).substitute(sub))
    if result == "Binary Resolution does not apply" or result == "Empty Clause":
      cl1.extend(cl2)
      return Clause(*cl1).substitute(sub)
    else:
      return result 


In [3]:
from sympy.core import Function, Symbol

x = Variable_FOL('x')
y = Variable_FOL('y')
z = Variable_FOL('z')

a = Constant_FOL('a')
b = Constant_FOL('b')
c = Constant_FOL('c')
d = Constant_FOL('d')

f = FunctionalSymbol('f')
g = FunctionalSymbol('g')
h = FunctionalSymbol('h')

P = Predicate('P')
Q = Predicate('Q')
R = Predicate('R')
S = Predicate('S')

# Grounding: Simple Example

In [4]:
formula = ~Forall(x) (P(x) & Exists(y)(P(y)))
formula

Not_FOL(Forall_x(And_FOL(P(x), Exists_y(P(y)))))

Negation Normal Form


In [5]:
formula.nnf = formula.to_nnf()
formula.nnf

Exists_x(Or_FOL(Not_FOL(P(x)), Forall_y(Not_FOL(P(y)))))

Prenex Normal Form

In [6]:
formula.pnf = to_pnf(formula.nnf)
formula.pnf

Exists_x(Forall_y(Or_FOL(Not_FOL(P(x)), Not_FOL(P(y)))))

Grounding

In [7]:
Delta = [a,b]
grounding(formula.pnf,Delta)

And_FOL(Or_FOL(Or_FOL(Not_FOL(P(a)), Not_FOL(P(a))), Or_FOL(Not_FOL(P(b)), Not_FOL(P(a)))), Or_FOL(Or_FOL(Not_FOL(P(a)), Not_FOL(P(b))), Or_FOL(Not_FOL(P(b)), Not_FOL(P(b)))))

# Grounding: More Complex Example

In [8]:
formula = ~Forall(x) (P(x) & (Exists (y) (~P(x,y) & Q(f(y))) >> Exists (z) (~P(x) & Q(f(x),a,z))))
formula

Not_FOL(Forall_x(And_FOL(P(x), Implies_FOL(Exists_y(And_FOL(Not_FOL(P(x, y)), Q(f(y)))), Exists_z(And_FOL(Not_FOL(P(x)), Q(f(x), a, z)))))))

Negation Normal Form

In [9]:
formula.nnf = formula.to_nnf()
formula.nnf

Exists_x(Or_FOL(Not_FOL(P(x)), And_FOL(Exists_y(And_FOL(Not_FOL(P(x, y)), Q(f(y)))), Forall_z(Or_FOL(P(x), Not_FOL(Q(f(x), a, z)))))))

Prenex Normal Form

In [10]:
formula.pnf = to_pnf(formula.nnf)
formula.pnf

Exists_x(Exists_y(Forall_z(Or_FOL(Not_FOL(P(x)), And_FOL(And_FOL(Not_FOL(P(x, y)), Q(f(y))), Or_FOL(P(x), Not_FOL(Q(f(x), a, z))))))))

Grounding

In [11]:
grounding(formula.pnf,[a,b])

And_FOL(Or_FOL(Or_FOL(Or_FOL(Not_FOL(P(a)), And_FOL(And_FOL(Not_FOL(P(a, a)), Q(f(a))), Or_FOL(P(a), Not_FOL(Q(f(a), a, a))))), Or_FOL(Not_FOL(P(b)), And_FOL(And_FOL(Not_FOL(P(b, a)), Q(f(a))), Or_FOL(P(b), Not_FOL(Q(f(b), a, a)))))), Or_FOL(Or_FOL(Not_FOL(P(a)), And_FOL(And_FOL(Not_FOL(P(a, b)), Q(f(b))), Or_FOL(P(a), Not_FOL(Q(f(a), a, a))))), Or_FOL(Not_FOL(P(b)), And_FOL(And_FOL(Not_FOL(P(b, b)), Q(f(b))), Or_FOL(P(b), Not_FOL(Q(f(b), a, a))))))), Or_FOL(Or_FOL(Or_FOL(Not_FOL(P(a)), And_FOL(And_FOL(Not_FOL(P(a, a)), Q(f(a))), Or_FOL(P(a), Not_FOL(Q(f(a), a, b))))), Or_FOL(Not_FOL(P(b)), And_FOL(And_FOL(Not_FOL(P(b, a)), Q(f(a))), Or_FOL(P(b), Not_FOL(Q(f(b), a, b)))))), Or_FOL(Or_FOL(Not_FOL(P(a)), And_FOL(And_FOL(Not_FOL(P(a, b)), Q(f(b))), Or_FOL(P(a), Not_FOL(Q(f(a), a, b))))), Or_FOL(Not_FOL(P(b)), And_FOL(And_FOL(Not_FOL(P(b, b)), Q(f(b))), Or_FOL(P(b), Not_FOL(Q(f(b), a, b))))))))

# Unification

In [12]:
unification(P(x), P(a))

{x: a}

In [13]:
unification(P(f(x)), P(y))

{y: f(x)}

In [14]:
unification(P(f(x)), P(f(y)))

{x: y}

In [15]:
unification(g(g(x)),g(y))

{y: g(x)}

In [16]:
unification(P(a,b,c),P(x,y,z))

{x: a, y: b, z: c}

In [17]:
unification(P(f(g(x,a)),x),P(z,b))

{x: b, z: f(g(x, a))}

In [18]:
unification(P(f(x)), P(x))

{}

In [19]:
unification(P(x,x),P(c,d))

{}

# Resolution

In [20]:
binary_resolution(Clause(P(x)), Clause(~P(x)))

'Empty Clause'

In [21]:
binary_resolution(Clause(P(x)), Clause(P(x)))

'Binary Resolution does not apply'

In [22]:
binary_resolution(Clause(S(x) | Q(y)), Clause(P(x) | ~Q(y)))

Clause(S(x), P(x))

In [23]:
binary_resolution(Clause(S(x) | Q(f(x))), Clause(P(x) | ~Q(y)))

Clause(S(x), P(x))

In [24]:
binary_resolution(Clause(P(x) | S(z) | Q(x) | R(z) |S(f(g(x)))), Clause(R(z) | ~Q(y) | ~R(z)| ~S(z)) )

Clause(P(x), S(f(g(x))), R(z))