Skip to content

Commit

Permalink
Fixed various bugs and added some basic examples on Arrays
Browse files Browse the repository at this point in the history
* Changed types printing to be clearer and optimized hashing and equality that was string-based
  • Loading branch information
mikand committed May 23, 2016
1 parent 923184b commit 9d7e158
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 47 deletions.
5 changes: 4 additions & 1 deletion pysmt/formula.py
Expand Up @@ -906,7 +906,10 @@ def Array(self, idx_type, default, assigned_values=None):
for k in sorted(assigned_values):
if not k.is_constant():
raise ValueError("Array initialization indexes must be constants")
args.append(k, assigned_values[k])
# It is useless to represent assignments equal to the default
if assigned_values[k] != default:
args.append(k)
args.append(assigned_values[k])
n = self.create_node(node_type=op.ARRAY_VALUE, args=tuple(args),
payload=idx_type)
return n
Expand Down
14 changes: 6 additions & 8 deletions pysmt/printers.py
Expand Up @@ -28,8 +28,8 @@ class HRPrinter(TreeWalker):
E.g., Implies(And(Symbol(x), Symbol(y)), Symbol(z)) ~> '(x * y) -> z'
"""

def __init__(self, stream):
TreeWalker.__init__(self)
def __init__(self, stream, env=None):
TreeWalker.__init__(self, env=env)
self.stream = stream
self.write = self.stream.write

Expand Down Expand Up @@ -271,19 +271,17 @@ def walk_array_store(self, formula):
self.write("]")

def walk_array_value(self, formula):
self.write("Array{")
self.write(formula.array_value_index_type())
self.write("}")
self.write(str(self.env.stc.get_type(formula)))
self.write("(")
self.walk(formula.array_value_default())
self.write(")")
assign = formula.array_value_assigned_values_map()
for k, v in iteritems(assign):
self.write("[")
self.walk(k)
self.write(" := ")
self.walk(v)
self.write("]")
self.write("[* := ")
self.walk(formula.array_value_default())
self.write("]")


class HRSerializer(object):
Expand Down
7 changes: 4 additions & 3 deletions pysmt/solvers/msat.py
Expand Up @@ -1041,13 +1041,14 @@ def walk_array_store(self, formula, args, **kwargs):
args[0], args[1], args[2])

def walk_array_value(self, formula, args, **kwargs):
idx_type = formula.array_value_index_type()
arr_type = self.env.stc.get_type(formula)
rval = mathsat.msat_make_array_const(self.msat_env(),
self._type_to_msat(idx_type),
self._type_to_msat(arr_type),
args[0])
assert not mathsat.MSAT_ERROR_TERM(rval)
for i,c in enumerate(args[1::2]):
rval = mathsat.msat_make_array_write(self.msat_env(), rval,
c, args[i+1])
c, args[(i*2)+2])
return rval

def _type_to_msat(self, tp):
Expand Down
4 changes: 2 additions & 2 deletions pysmt/solvers/z3.py
Expand Up @@ -493,9 +493,9 @@ def _back_single_term(self, expr, args, model=None):
elif z3.is_array_store(expr):
res = self.mgr.Store(args[0], args[1], args[2])
elif z3.is_const_array(expr):
idx_ty = self._z3_to_type(expr.sort())
arr_ty = self._z3_to_type(expr.sort())
k = args[0]
res = self.mgr.Array(idx_ty, k)
res = self.mgr.Array(arr_ty.index_type, k)
if res is None:
raise ConvertExpressionError(message=("Unsupported expression: %s" %
str(expr)),
Expand Down
8 changes: 7 additions & 1 deletion pysmt/test/examples.py
Expand Up @@ -34,7 +34,7 @@
BVLShl, BVLShr,BVRol, BVRor,
BVZExt, BVSExt, BVSub, BVComp, BVAShr, BVSLE,
BVSLT, BVSGT, BVSGE, BVSDiv, BVSRem,
Store, Select)
Store, Select, Array)

from pysmt.typing import REAL, BOOL, INT, FunctionType, BV8, BV16, ARRAY_INT_INT

Expand Down Expand Up @@ -584,6 +584,12 @@ def get_example_formulae(environment=None):
is_valid=True,
is_sat=True,
logic=pysmt.logics.QF_ALIA),
# Array<Int,Int>(0)[1 := 1] = aii & aii[1] = 0
Example(expr=And(Equals(Array(INT, Int(0), {Int(1) : Int(1)}), aii), Equals(Select(aii, Int(1)), Int(0))),
is_valid=False,
is_sat=False,
logic=pysmt.logics.QF_ALIA),

]
return result

Expand Down
7 changes: 7 additions & 0 deletions pysmt/test/test_formula.py
Expand Up @@ -952,6 +952,13 @@ def test_typing(self):
self.assertTrue(self.ftype.is_function_type())
self.assertFalse(self.ftype.is_int_type())

def test_array_value(self):
a1 = self.mgr.Array(INT, self.mgr.Int(0))
a2 = self.mgr.Array(INT, self.mgr.Int(0),
{self.mgr.Int(12) : self.mgr.Int(0)})
self.assertEquals(a1, a2)


class TestShortcuts(TestCase):

def test_shortcut_is_using_global_env(self):
Expand Down
3 changes: 2 additions & 1 deletion pysmt/test/test_printing.py
Expand Up @@ -266,7 +266,8 @@ def test_smart_serialize(self):
"""(! (((ToReal(...) = r) & (ToReal(...) = r)) -> ((p < ...(..., ...)) | (...(..., ...) <= p))))""",
"""("Did you know that any string works? #yolo" & "10" & "|#somesolverskeepthe||" & " ")""",
"""((q = 0) -> (aii[0 := 0] = aii[0 := q]))""",
"""(aii[0 := 0][0] = 0)"""
"""(aii[0 := 0][0] = 0)""",
"""((Array<Int, Int>(0)[1 := 1] = aii) & (aii[1] = 0))"""
]


Expand Down
10 changes: 10 additions & 0 deletions pysmt/test/test_simplify.py
Expand Up @@ -19,6 +19,9 @@
from pysmt.test import TestCase, skipIfSolverNotAvailable, main
from pysmt.test.examples import get_example_formulae
from pysmt.environment import get_env
from pysmt.shortcuts import Array, Store, Int
from pysmt.typing import INT



class TestSimplify(TestCase):
Expand Down Expand Up @@ -52,5 +55,12 @@ def test_simplify_q(self):
"result:\n f= %s\n sf = %s" % (f, sf))


def test_array_value(self):
a1 = Array(INT, Int(0), {Int(12) : Int(10)})
a2 = Store(Array(INT, Int(0)), Int(12), Int(10))
self.assertEquals(a1, a2.simplify())



if __name__ == '__main__':
main()
46 changes: 15 additions & 31 deletions pysmt/typing.py
Expand Up @@ -52,7 +52,7 @@ def is_int_type(self):
def is_real_type(self):
return False

def is_bv_type(self):
def is_bv_type(self, width=None):
return False

def is_function_type(self):
Expand All @@ -70,9 +70,7 @@ def __eq__(self, other):
return self.type_id == other.type_id

def __ne__(self, other):
if other is None:
return True
return self.type_id != other.type_id
return not self.__eq__(other)


class BooleanType(PySMTType):
Expand Down Expand Up @@ -164,7 +162,7 @@ def as_smtlib(self, funstyle=True):
return "(_ BitVec %d)" % self.width

def __str__(self):
return "BV%d" % self.width
return "BV<%d>" % self.width

def __eq__(self, other):
if other is None:
Expand All @@ -176,13 +174,7 @@ def __eq__(self, other):
return True

def __ne__(self, other):
if other is None:
return True
if self.type_id != other.type_id:
return True
if self.width != other.width:
return True
return False
return not self.__eq__(other)

def __hash__(self):
return hash(self.type_id + self.width)
Expand Down Expand Up @@ -223,7 +215,7 @@ def __init__(self, return_type, param_types):
PySMTType.__init__(self, type_id = 4)
self._return_type = return_type
self._param_types = param_types
self._hash = hash(str(self))
self._hash = hash(return_type) + sum(hash(p) for p in param_types)
return

@property
Expand Down Expand Up @@ -269,16 +261,12 @@ def __eq__(self, other):
return False
if id(self) == id(other):
return True
return str(self) == str(other)
if self.return_type != other.return_type:
return False
return self.param_types == other.param_types

def __ne__(self, other):
if other is None:
return True
if self.type_id != other.type_id:
return True
if id(self) == id(other):
return False
return str(self) != str(other)
return not self.__eq__(other)

def __hash__(self):
return self._hash
Expand Down Expand Up @@ -311,7 +299,7 @@ def __init__(self, index_type, elem_type):
PySMTType.__init__(self, type_id = 5)
self._index_type = index_type
self._elem_type = elem_type
self._hash = hash(str(self))
self._hash = hash(index_type) + hash(elem_type)
return

@property
Expand Down Expand Up @@ -342,7 +330,7 @@ def as_smtlib(self, funstyle=True):
return "(Array %s %s)" % (itype, etype)

def __str__(self):
return "ARR[%s -> %s]" % (self.index_type, self.elem_type)
return "Array<%s, %s>" % (self.index_type, self.elem_type)

def is_array_type(self):
return True
Expand All @@ -354,16 +342,12 @@ def __eq__(self, other):
return False
if id(self) == id(other):
return True
return str(self) == str(other)
if self.index_type != other.index_type:
return False
return self.elem_type == other.elem_type

def __ne__(self, other):
if other is None:
return True
if self.type_id != other.type_id:
return True
if id(self) == id(other):
return False
return str(self) != str(other)
return not self.__eq__(other)

def __hash__(self):
return self._hash
Expand Down

0 comments on commit 9d7e158

Please sign in to comment.