Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support RandomSymbol and assumptions in SMT-Lib Printer. #24406

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion sympy/functions/special/error_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,7 +1294,7 @@ class expint(Function):
>>> expint(nu, z).diff(nu)
-z**(nu - 1)*meijerg(((), (1, 1)), ((0, 0, 1 - nu), ()), z)

At non-postive integer orders, the exponential integral reduces to the
At non-positive integer orders, the exponential integral reduces to the
exponential function:

>>> expint(0, z)
Expand Down
144 changes: 97 additions & 47 deletions sympy/printing/smtlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sympy
from sympy.core import Add, Mul
from sympy.core import Symbol, Expr, Float, Rational, Integer, Basic
from sympy.core.assumptions import assumptions
from sympy.core.function import UndefinedFunction, Function
from sympy.core.relational import Relational, Unequality, Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan
from sympy.functions.elementary.complexes import Abs
Expand Down Expand Up @@ -70,8 +71,7 @@ class SMTLibPrinter(Printer):

symbol_table: dict

def __init__(self, settings: typing.Optional[dict] = None,
symbol_table=None):
def __init__(self, settings: typing.Optional[dict] = None, symbol_table=None):
settings = settings or {}
self.symbol_table = symbol_table or {}
Printer.__init__(self, settings)
Expand All @@ -80,15 +80,18 @@ def __init__(self, settings: typing.Optional[dict] = None,
self._known_constants = dict(self._settings['known_constants'])
self._known_functions = dict(self._settings['known_functions'])

for _ in self._known_types.values(): assert self._is_legal_name(_)
for _ in self._known_constants.values(): assert self._is_legal_name(_)
# for _ in self._known_functions.values(): assert self._is_legal_name(_) # +, *, <, >, etc.
for _ in self._known_types.values(): self._check_is_legal_name(_)
for _ in self._known_constants.values(): self._check_is_legal_name(_)
# for _ in self._known_functions.values(): self._check_is_legal_name(_) # +, *, <, >, etc.

def _is_legal_name(self, s: str):
if not s: return False
if s[0].isnumeric(): return False
return all(_.isalnum() or _ == '_' for _ in s)

def _check_is_legal_name(self, s: str):
assert self._is_legal_name(s), f"Name `{s}` may not be legal in SMT-Lib."

def _s_expr(self, op: str, args: typing.Union[list, tuple]) -> str:
args_str = ' '.join(
a if isinstance(a, str)
Expand Down Expand Up @@ -130,7 +133,7 @@ def _print_Piecewise(self, e: Piecewise):
def _print_Piecewise_recursive(args: typing.Union[list, tuple]):
e, c = args[0]
if len(args) == 1:
assert (c is True) or isinstance(c, BooleanTrue)
assert (c is True) or isinstance(c, BooleanTrue), "Piecewise expression must end in (expr, True) statement."
return self._print(e)
else:
ite = self._known_functions[ITE]
Expand Down Expand Up @@ -185,15 +188,19 @@ def _print_int(self, x: int):
return str(x)

def _print_Symbol(self, x: Symbol):
assert self._is_legal_name(x.name)
self._check_is_legal_name(x.name)
return x.name

def _print_RandomSymbol(self, x):
self._check_is_legal_name(x.name)
return x.name

def _print_NumberSymbol(self, x):
name = self._known_constants.get(x)
return name if name else self._print_Float(x)

def _print_UndefinedFunction(self, x):
assert self._is_legal_name(x.name)
self._check_is_legal_name(x.name)
return x.name

def _print_Exp1(self, x):
Expand Down Expand Up @@ -306,7 +313,7 @@ def smtlib_code(

if not symbol_table: symbol_table = {}
symbol_table = _auto_infer_smtlib_types(
*expr, symbol_table=symbol_table
expr, symbol_table, auto_assert
)
# See [FALLBACK RULES]
# Need SMTLibPrinter to populate known_functions and known_constants first.
Expand All @@ -332,39 +339,39 @@ def smtlib_code(

# [FALLBACK RULES]
for e in expr:
for sym in e.atoms(Symbol, Function):
for sym in _atoms_symbols_preserve_rv(e):
if (
sym.is_Symbol and
sym.is_symbol and
sym not in p._known_constants and
sym not in p.symbol_table
):
log_warn(f"Could not infer type of `{sym}`. Defaulting to float.")
p.symbol_table[sym] = float
for fun in e.atoms(Function):
if (
sym.is_Function and
type(sym) not in p._known_functions and
type(sym) not in p.symbol_table and
not sym.is_Piecewise
fun.is_Function and
type(fun) not in p._known_functions and
type(fun) not in p.symbol_table and
not fun.is_Piecewise
): raise TypeError(
f"Unknown type of undefined function `{sym}`. "
f"Unknown type of undefined function `{fun}`. "
f"Must be mapped to ``str`` in known_functions or mapped to ``Callable[..]`` in symbol_table."
)

declarations = []
if auto_declare:
constants = {sym.name: sym for e in expr for sym in e.free_symbols
if sym not in p._known_constants}
functions = {fnc.name: fnc for e in expr for fnc in e.atoms(Function)
if type(fnc) not in p._known_functions and not fnc.is_Piecewise}
declarations = \
[
_auto_declare_smtlib(sym, p, log_warn)
for sym in constants.values()
] + [
_auto_declare_smtlib(fnc, p, log_warn)
for fnc in functions.values()
]
declarations = [decl for decl in declarations if decl]
declarations = {}
declarations.update({
sym.name: sym for e in expr for sym in e.free_symbols # .free_symbols preserves random variables
if sym not in p._known_constants
})
declarations.update({
fun.name: fun for e in expr for fun in e.atoms(Function)
if type(fun) not in p._known_functions and not fun.is_Piecewise
})
declarations = [
_auto_declare_smtlib(sym_or_func, p, log_warn) for sym_or_func in declarations.values()
]

if auto_assert:
expr = [_auto_assert_smtlib(e, p, log_warn) for e in expr]
Expand All @@ -378,7 +385,7 @@ def smtlib_code(
],

# ';; DECLARATIONS',
*sorted(e for e in declarations),
*[line for block in sorted(declarations, key=lambda b: b[0]) for line in block],

# ';; EXPRESSIONS',
*[
Expand All @@ -394,12 +401,45 @@ def smtlib_code(
])


def _atoms_symbols_preserve_rv(expr):
from sympy.stats.rv import RandomSymbol
random_symbols = expr.atoms(RandomSymbol)
simple_symbols = set(expr.atoms(Symbol)) - {_.symbol for _ in random_symbols}
return list(simple_symbols) + list(random_symbols)


def _auto_declare_smtlib(sym: typing.Union[Symbol, Function], p: SMTLibPrinter, log_warn: typing.Callable[[str], None]):
if sym.is_Symbol:
if sym.is_symbol:
type_signature = p.symbol_table[sym]
assert isinstance(type_signature, type)
type_signature = p._known_types[type_signature]
return p._s_expr('declare-const', [sym, type_signature])

from sympy.stats.rv import RandomSymbol
current_assumptions = assumptions(sym)
current_assumptions['__is_random_symbol'] = isinstance(sym, RandomSymbol)
unsupported_assumptions = {
'infinite', 'antihermitian', 'transcendental', 'imaginary', 'irrational', 'noninteger', 'even', 'odd', 'prime', 'composite'
}
supported_assumptions = {
'zero': lambda: Equality(sym, 0, evaluate=False),
'nonzero': lambda: Unequality(sym, 0, evaluate=False),
'positive': lambda: StrictGreaterThan(sym, 0, evaluate=False),
'nonnegative': lambda: GreaterThan(sym, 0, evaluate=False),
'negative': lambda: StrictLessThan(sym, 0, evaluate=False),
'nonpositive': lambda: LessThan(sym, 0, evaluate=False),
# mypy isn't smart enough to understand this lambda is only invoked if `isinstance(sym, RandomSymbol)`
'__is_random_symbol': lambda: sym.pspace.domain.as_boolean().simplify() # type: ignore[union-attr]
}
for a in unsupported_assumptions:
if current_assumptions.get(a) and not current_assumptions.get('zero'): # zero checks pretty much everything
raise ValueError(f"Cannot automatically assert '{a}'-ness when declaring `{sym}`. Please assert explicitly.")
return [
p._s_expr('declare-const', [sym, type_signature])
] + [
p._s_expr('assert', [predicate()])
for a, predicate in supported_assumptions.items()
if current_assumptions.get(a)
]

elif sym.is_Function:
type_signature = p.symbol_table[type(sym)]
Expand All @@ -408,11 +448,10 @@ def _auto_declare_smtlib(sym: typing.Union[Symbol, Function], p: SMTLibPrinter,
assert len(type_signature) > 0
params_signature = f"({' '.join(type_signature[:-1])})"
return_signature = type_signature[-1]
return p._s_expr('declare-fun', [type(sym), params_signature, return_signature])
return [p._s_expr('declare-fun', [type(sym), params_signature, return_signature])]

else:
log_warn(f"Non-Symbol/Function `{sym}` will not be declared.")
return None
raise ValueError(f"Non-Symbol/Function `{sym}` will not be declared.")


def _auto_assert_smtlib(e: Expr, p: SMTLibPrinter, log_warn: typing.Callable[[str], None]):
Expand All @@ -430,15 +469,17 @@ def _auto_assert_smtlib(e: Expr, p: SMTLibPrinter, log_warn: typing.Callable[[st


def _auto_infer_smtlib_types(
*exprs: Basic,
symbol_table: typing.Optional[dict] = None
exprs: typing.List[Basic],
symbol_table: typing.Optional[dict] = None,
auto_assert: bool = True
) -> dict:
# [TYPE INFERENCE RULES]
# X is alone in an expr => X is bool
# X is alone in an expr and auto-assert => X is bool
# X in BooleanFunction.args => X is bool
# X matches to a bool param of a symbol_table function => X is bool
# X matches to an int param of a symbol_table function => X is int
# X.is_integer => X is int
# X is random and X.pspace is continuous => X is float
# X == Y, where X is T => Y is T

# [FALLBACK RULES]
Expand All @@ -450,23 +491,23 @@ def _auto_infer_smtlib_types(

def safe_update(syms: set, inf):
for s in syms:
assert s.is_Symbol
assert s.is_symbol
if (old_type := _symbols.setdefault(s, inf)) != inf:
raise TypeError(f"Could not infer type of `{s}`. Apparently both `{old_type}` and `{inf}`?")

# EXPLICIT TYPES
safe_update({
e
for e in exprs
if e.is_Symbol
if e.is_symbol and auto_assert
}, bool)

safe_update({
symbol
for e in exprs
for boolfunc in e.atoms(BooleanFunction)
for symbol in boolfunc.args
if symbol.is_Symbol
if symbol.is_symbol
}, bool)

safe_update({
Expand All @@ -475,7 +516,7 @@ def safe_update(syms: set, inf):
for boolfunc in e.atoms(Function)
if type(boolfunc) in _symbols
for symbol, param in zip(boolfunc.args, _symbols[type(boolfunc)].__args__)
if symbol.is_Symbol and param == bool
if symbol.is_symbol and param == bool
}, bool)

safe_update({
Expand All @@ -484,29 +525,38 @@ def safe_update(syms: set, inf):
for intfunc in e.atoms(Function)
if type(intfunc) in _symbols
for symbol, param in zip(intfunc.args, _symbols[type(intfunc)].__args__)
if symbol.is_Symbol and param == int
if symbol.is_symbol and param == int
}, int)

safe_update({
symbol
for e in exprs
for symbol in e.atoms(Symbol)
for symbol in _atoms_symbols_preserve_rv(e)
if symbol.is_integer
}, int)

safe_update({
symbol
for e in exprs
for symbol in e.atoms(Symbol)
for symbol in _atoms_symbols_preserve_rv(e)
if symbol.is_real and not symbol.is_integer
}, float)

# CONTINUOUS RANDOM VARIABLE RULE
from sympy.stats.rv import RandomSymbol
safe_update({
rv
for e in exprs
for rv in e.atoms(RandomSymbol)
if rv.pspace.is_Continuous
}, float)

# EQUALITY RELATION RULE
rels = [rel for expr in exprs for rel in expr.atoms(Equality)]
rels = [
(rel.lhs, rel.rhs) for rel in rels if rel.lhs.is_Symbol
(rel.lhs, rel.rhs) for rel in rels if rel.lhs.is_symbol
] + [
(rel.rhs, rel.lhs) for rel in rels if rel.rhs.is_Symbol
(rel.rhs, rel.lhs) for rel in rels if rel.rhs.is_symbol
]
for infer, reltd in rels:
inference = (
Expand Down