Skip to content

Commit

Permalink
Merge pull request #38 from pysmt/getting_rid_of_generic_number
Browse files Browse the repository at this point in the history
Getting rid of generic number
  • Loading branch information
marcogario committed Mar 9, 2015
2 parents 4d7dccd + 5babe2c commit 0c8495e
Show file tree
Hide file tree
Showing 13 changed files with 125 additions and 603 deletions.
3 changes: 3 additions & 0 deletions docs/CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ General:

* Expressions: Added Min/Max operators.

* SMT-LIB: Substantially improved parser performances.


0.2.2 2015-02-07 -- BDDs
------------------------

Expand Down
7 changes: 0 additions & 7 deletions pysmt/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,3 @@ def __exit__(self, exc_type, exc_val, exc_tb):
pysmt.shortcuts.pop_env()

# EOC Environment



class TypeUnsafeEnvironment(Environment):
FormulaManagerClass = pysmt.formula.TypeUnsafeFormulaManager

#EOC TypeUnsafeFormulaManager
27 changes: 2 additions & 25 deletions pysmt/formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ def ToReal(self, formula):
# Ignore casting of a Real
return formula
elif t == types.INT:
if formula.is_int_constant():
return self.Real(formula.constant_value())
return self.create_node(node_type=op.TOREAL,
args=(formula,))
else:
Expand Down Expand Up @@ -535,28 +537,3 @@ def __contains__(self, node):
return False

#EOC FormulaManager

class TypeUnsafeFormulaManager(FormulaManager):
"""Subclass of FormulaManager in which type-checking is disabled.
TypeUnsafeFormulaManager makes it possible to build expressions
that are incorrect: e.g., True + 1. This is used mainly to avoid
the overhead of having to check each expression for type. For
example, during parsing we post-pone the type-check after the
whole expression has been built.
This should be used with caution.
"""

def __init__(self, env=None):
FormulaManager.__init__(self, env)

def _do_type_check(self, formula):
pass

def ToReal(self, formula):
""" Cast a formula to real type. """
return self.create_node(node_type=op.TOREAL,
args=(formula,))

#EOC TypeUnsafeFormulaManager
87 changes: 22 additions & 65 deletions pysmt/smtlib/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from pysmt.typing import BOOL, REAL, INT, FunctionType
from pysmt.logics import get_logic_by_name, UndefinedLogicError
from pysmt.exceptions import UnknownSmtLibCommandError
from pysmt.utils.generic_number import GenericNumber, disambiguate
from pysmt.environment import TypeUnsafeEnvironment
from pysmt.smtlib.script import SmtLibCommand, SmtLibScript


Expand Down Expand Up @@ -225,34 +223,22 @@ def __init__(self, environment=None):
'to_real':mgr.ToReal,
})

def _is_unknown_constant_type(self):
"""
Returns true if the logic at hand allows for bot Real and Integer
constants
"""
return self.logic is None or \
(self.logic.theory.integer_arithmetic and
self.logic.theory.real_arithmetic)

def _minus_or_uminus(self, *args):
"""Utility function that handles both unary and binary minus"""
mgr = self._current_env.formula_manager
if len(args) == 1:
if self._is_unknown_constant_type():
if type(args[0]) == GenericNumber:
return GenericNumber(-1 * args[0].value)
return mgr.Times(GenericNumber(-1), args[0])
lty = self._current_env.stc.get_type(args[0])
mult = None
if lty == INT:
if args[0].is_int_constant():
return mgr.Int(-1 * args[0].constant_value())
mult = mgr.Int(-1)
else:
if self.logic.theory.real_arithmetic:
if args[0].is_real_constant():
return mgr.Real(-1 * args[0].constant_value())
return mgr.Times(mgr.Real(-1), args[0])
else:
assert self.logic.theory.integer_arithmetic
if args[0].is_int_constant():
return mgr.Int(-1 * args[0].constant_value())
return mgr.Times(mgr.Int(-1), args[0])
if args[0].is_real_constant():
return mgr.Real(-1 * args[0].constant_value())
mult = mgr.Real(-1)
return mgr.Times(mult, args[0])
else:
assert len(args) == 2
return mgr.Minus(args[0], args[1])
Expand All @@ -261,20 +247,16 @@ def _minus_or_uminus(self, *args):
def _equals_or_iff(self, left, right):
"""Utility function that treats = between booleans as <->"""
mgr = self._current_env.formula_manager
if self._is_unknown_constant_type():
return mgr.Equals(left, right)

lty = self._current_env.stc.get_type(left)
if lty == BOOL:
return mgr.Iff(left, right)
else:
lty = self._current_env.stc.get_type(left)
if lty == BOOL:
return mgr.Iff(left, right)
else:
return mgr.Equals(left, right)
return mgr.Equals(left, right)

def _division(self, left, right):
"""Utility function that builds a division"""
mgr = self._current_env.formula_manager
if left.is_real_constant() and right.is_real_constant():
if left.is_constant() and right.is_constant():
return mgr.Real(Fraction(left.constant_value()) / \
Fraction(right.constant_value()))
return mgr.Div(left, right)
Expand Down Expand Up @@ -319,16 +301,11 @@ def atom(self, token, mgr):
else:
iterm = int(token)
# We found an integer, depending on the logic this can be
# an Int, a Real, or an unknown GenericNumber.
if self._is_unknown_constant_type():
res = GenericNumber(iterm)
elif self.logic.theory.real_arithmetic:
res = mgr.Real(iterm)
else:
assert self.logic.theory.integer_arithmetic, \
"Integer constant found in a logic that does not " \
"support arithmetic"
# an Int or a Real
if self.logic is None or self.logic.theory.integer_arithmetic:
res = mgr.Int(iterm)
else:
res = mgr.Real(iterm)
self.cache.bind(token, res)
return res

Expand Down Expand Up @@ -367,26 +344,6 @@ def _reset_env(self, env):
self._current_env = env


def get_expression(self, tokens):
"""
Returns the pysmt representation of the given parsed expression
"""
tu_env = None
if self._is_unknown_constant_type():
old_env = self._current_env
tu_env = TypeUnsafeEnvironment()
self._use_env(tu_env)

r = self._do_get_expression(tokens)

dis = disambiguate(tu_env, r, fix_equals=True)
self._reset_env(old_env)
return self.pysmt_env.formula_manager.normalize(dis)

else:
return self._do_get_expression(tokens)


def _handle_let(self, varlist, bdy):
""" Cleans the execution environment when we exit the scope of a 'let' """
for k in varlist:
Expand Down Expand Up @@ -423,9 +380,9 @@ def _handle_annotation(self, pyterm, attrs):
return pyterm


def _do_get_expression(self, tokens):
def get_expression(self, tokens):
"""
Iteratively parse the token stream
Returns the pysmt representation of the given parsed expression
"""
mgr = self._current_env.formula_manager
stack = []
Expand All @@ -448,7 +405,7 @@ def _do_get_expression(self, tokens):
if current != "(":
raise SyntaxError("Expected '(' in let binding")
vname = self.parse_atom(tokens, "expression")
expr = self._do_get_expression(tokens)
expr = self.get_expression(tokens)
newvals[vname] = expr
self.consume_closing(tokens, "expression")
current = next(tokens)
Expand Down
4 changes: 2 additions & 2 deletions pysmt/smtlib/printers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def walk_real_constant(self, formula):
if d != 1:
res = template % ( "(/ " + str(n) + " " + str(d) + ")" )
else:
res = template % str(n)
res = template % (str(n) + ".0")

self.write(res)

Expand Down Expand Up @@ -221,7 +221,7 @@ def walk_real_constant(self, formula, args, **kwargs):
if d != 1:
return template % ( "(/ " + str(n) + " " + str(d) + ")" )
else:
return template % str(n)
return template % (str(n) + ".0")

def walk_bool_constant(self, formula, args, **kwargs):
if formula.constant_value():
Expand Down
99 changes: 91 additions & 8 deletions pysmt/solvers/msat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
from pysmt.exceptions import SolverReturnedUnknownResultError
from pysmt.exceptions import InternalSolverError
from pysmt.decorators import clear_pending_pop
from pysmt.environment import TypeUnsafeEnvironment
from pysmt.utils.generic_number import GenericNumber, disambiguate


class MathSAT5Solver(Solver, SmtLibBasicSolver, SmtLibIgnoreMixin):
Expand Down Expand Up @@ -221,11 +219,87 @@ def __init__(self, environment, msat_env):


def back(self, expr):
tu_env = TypeUnsafeEnvironment()
tu_res = self._walk_back(expr, tu_env.formula_manager)
tu_f = disambiguate(tu_env, tu_res, create_toreal_on_demand=True)
return self.env.formula_manager.normalize(tu_f)
return self._walk_back(expr, self.mgr)

def _most_generic(self, ty1, ty2):
"""Returns teh most generic, yet compatible type between ty1 and ty2"""
if ty1 == ty2:
return ty1

assert ty1 in [types.REAL, types.INT]
assert ty2 in [types.REAL, types.INT]
return types.REAL


def _get_signature(self, term, args):
"""Returns the signature of the given term.
For example:
- a term x & y returns a function type Bool -> Bool -> Bool,
- a term 14 returns Int
- a term x ? 13 : 15.0 returns Bool -> Real -> Real -> Real
"""
res = None

if mathsat.msat_term_is_true(self.msat_env, term) or \
mathsat.msat_term_is_false(self.msat_env, term) or \
mathsat.msat_term_is_boolean_constant(self.msat_env, term):
res = types.BOOL

elif mathsat.msat_term_is_number(self.msat_env, term):
ty = mathsat.msat_term_get_type(term)
if mathsat.msat_is_integer_type(self.msat_env, ty):
res = types.INT
elif mathsat.msat_is_rational_type(self.msat_env, ty):
res = types.REAL
else:
raise NotImplementedError

elif mathsat.msat_term_is_and(self.msat_env, term) or \
mathsat.msat_term_is_or(self.msat_env, term) or \
mathsat.msat_term_is_iff(self.msat_env, term):
res = types.FunctionType(types.BOOL, [types.BOOL, types.BOOL])

elif mathsat.msat_term_is_not(self.msat_env, term):
res = types.FunctionType(types.BOOL, [types.BOOL])

elif mathsat.msat_term_is_term_ite(self.msat_env, term):
t1 = self.env.stc.get_type(args[1])
t2 = self.env.stc.get_type(args[2])
t = self._most_generic(t1, t2)
res = types.FunctionType(t, [types.BOOL, t, t])

elif mathsat.msat_term_is_equal(self.msat_env, term) or \
mathsat.msat_term_is_leq(self.msat_env, term):
t1 = self.env.stc.get_type(args[0])
t2 = self.env.stc.get_type(args[1])
t = self._most_generic(t1, t2)
res = types.FunctionType(types.BOOL, [t, t])

elif mathsat.msat_term_is_plus(self.msat_env, term) or \
mathsat.msat_term_is_times(self.msat_env, term):
t1 = self.env.stc.get_type(args[0])
t2 = self.env.stc.get_type(args[1])
t = self._most_generic(t1, t2)
res = types.FunctionType(t, [t, t])

elif mathsat.msat_term_is_constant(self.msat_env, term):
ty = mathsat.msat_term_get_type(term)
if mathsat.msat_is_rational_type(self.msat_env, ty):
res = types.REAL
elif mathsat.msat_is_integer_type(self.msat_env, ty):
res = types.INT
else:
raise NotImplementedError("Unsupported variable type found")

elif mathsat.msat_term_is_uf(self.msat_env, term):
d = mathsat.msat_term_get_decl(term)
fun = self.get_symbol_from_declaration(d)
res = fun.symbol_type()

else:
raise TypeError("Unsupported expression:",
mathsat.msat_term_repr(term))
return res

def _back_single_term(self, term, mgr, args):
"""Builds the pysmt formula given a term and the list of formulae
Expand Down Expand Up @@ -258,7 +332,7 @@ def _back_single_term(self, term, mgr, args):
elif mathsat.msat_term_is_number(self.msat_env, term):
ty = mathsat.msat_term_get_type(term)
if mathsat.msat_is_integer_type(self.msat_env, ty):
res = GenericNumber(int(mathsat.msat_term_repr(term)))
res = mgr.Int(int(mathsat.msat_term_repr(term)))
elif mathsat.msat_is_rational_type(self.msat_env, ty):
res = mgr.Real(Fraction(mathsat.msat_term_repr(term)))
else:
Expand Down Expand Up @@ -323,6 +397,7 @@ def _back_single_term(self, term, mgr, args):
return res



def get_symbol_from_declaration(self, decl):
return self.decl_to_symbol[mathsat.msat_decl_id(decl)]

Expand All @@ -341,7 +416,15 @@ def _walk_back(self, term, mgr):
elif self.back_memoization[current] is None:
args=[self.back_memoization[mathsat.msat_term_get_arg(current,i)]
for i in xrange(arity)]
res = self._back_single_term(current, mgr, args)

signature = self._get_signature(current, args)
new_args = []
for i, a in enumerate(args):
t = self.env.stc.get_type(a)
if t != signature.param_types[i]:
a = mgr.ToReal(a)
new_args.append(a)
res = self._back_single_term(current, mgr, new_args)
self.back_memoization[current] = res
else:
# we already visited the node, nothing else to do
Expand Down
2 changes: 1 addition & 1 deletion pysmt/test/smtlib/small_set/QF_LIRA/lira1.smt2
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
(declare-fun y_1 () Int)
(declare-fun y_2 () Int)

(assert (> (ite (= x_0 x_1) 2 5) x_2))
(assert (> (ite (= x_0 x_1) 2.0 5.0) x_2))
(assert (> (ite (= y_0 y_1) 2 8) y_2))
(assert (> (ite (= y_0 y_1) 2 4) 1))
(check-sat)
Expand Down
2 changes: 1 addition & 1 deletion pysmt/test/smtlib/small_set/QF_LIRA/prp-20-46.smt2

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pysmt/test/smtlib/test_smtlibscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_get_strict_formula(self):
(declare-fun x () Bool)
(declare-fun y () Bool)
(declare-fun r () Real)
(assert (> r 0))
(assert (> r 0.0))
(assert x)
(check-sat)
"""
Expand Down
4 changes: 2 additions & 2 deletions pysmt/test/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_real(self):
Symbol("y", REAL)])

f_string = self.print_to_string(f)
self.assertEquals(f_string, "(+ 1 x y)")
self.assertEquals(f_string, "(+ 1.0 x y)")

def test_boolean(self):
x, y, z = Symbol("x"), Symbol("y"), Symbol("z")
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_constant(self):
r3_string = self.print_to_string(r3)

self.assertEquals(r1_string, "(/ 11 2)")
self.assertEquals(r2_string, "5")
self.assertEquals(r2_string, "5.0")
self.assertEquals(r3_string, "(- (/ 11 2))")

i1_string = self.print_to_string(i1)
Expand Down

0 comments on commit 0c8495e

Please sign in to comment.