diff --git a/sympy/logic/boolalg.py b/sympy/logic/boolalg.py index 9da8a80b6117..93535adb9b11 100644 --- a/sympy/logic/boolalg.py +++ b/sympy/logic/boolalg.py @@ -980,6 +980,10 @@ def to_nnf(self, simplify=True): def _eval_derivative(self, x): return self.func(self.args[0], *[a.diff(x) for a in self.args[1:]]) + def _eval_rewrite_as_Piecewise(self, *args): + from sympy.functions import Piecewise + return Piecewise((args[1], args[0]), (args[2], True)) + # the diff method below is copied from Expr class def diff(self, *symbols, **assumptions): new_symbols = list(map(sympify, symbols)) # e.g. x, 2, y, z diff --git a/sympy/logic/tests/test_boolalg.py b/sympy/logic/tests/test_boolalg.py index 388927a6f806..faed19ee2b95 100644 --- a/sympy/logic/tests/test_boolalg.py +++ b/sympy/logic/tests/test_boolalg.py @@ -5,6 +5,7 @@ from sympy.core.relational import Equality from sympy.core.singleton import S from sympy.core.symbol import (Dummy, symbols) +from sympy.functions import Piecewise from sympy.sets.sets import (EmptySet, Interval, Union) from sympy.simplify.simplify import simplify from sympy.logic.boolalg import ( @@ -504,6 +505,10 @@ def test_ITE(): assert ITE(x, A, B) == Not(x) assert ITE(x, B, A) == x +def test_ITE_rewrite_Piecewise(): + + assert ITE(A, B, C).rewrite(Piecewise) == Piecewise((B, A), (C, True)) + def test_ITE_diff(): # analogous to Piecewise.diff diff --git a/sympy/printing/__init__.py b/sympy/printing/__init__.py index a6882d83e7cc..a4d7f9630685 100644 --- a/sympy/printing/__init__.py +++ b/sympy/printing/__init__.py @@ -5,6 +5,7 @@ from .latex import latex, print_latex from .mathml import mathml, print_mathml from .python import python, print_python +from .pycode import pycode from .ccode import ccode, print_ccode from .glsl import glsl_code, print_glsl from .cxxcode import cxxcode diff --git a/sympy/printing/codeprinter.py b/sympy/printing/codeprinter.py index f9d3c19b80b8..e873cda507c5 100644 --- a/sympy/printing/codeprinter.py +++ b/sympy/printing/codeprinter.py @@ -31,11 +31,14 @@ class CodePrinter(StrPrinter): 'not': '!', } - _default_settings = {'order': None, - 'full_prec': 'auto', - 'error_on_reserved': False, - 'reserved_word_suffix': '_', - 'human': True} + _default_settings = { + 'order': None, + 'full_prec': 'auto', + 'error_on_reserved': False, + 'reserved_word_suffix': '_', + 'human': True, + 'inline': False + } def __init__(self, settings=None): @@ -337,11 +340,14 @@ def _print_Function(self, expr): _print_Expr = _print_Function def _print_NumberSymbol(self, expr): - # A Number symbol that is not implemented here or with _printmethod - # is registered and evaluated - self._number_symbols.add((expr, - self._print(expr.evalf(self._settings["precision"])))) - return str(expr) + if self._settings.get("inline", False): + return self._print(expr.evalf(self._settings["precision"])) + else: + # A Number symbol that is not implemented here or with _printmethod + # is registered and evaluated + self._number_symbols.add((expr, + self._print(expr.evalf(self._settings["precision"])))) + return str(expr) def _print_Catalan(self, expr): return self._print_NumberSymbol(expr) diff --git a/sympy/printing/julia.py b/sympy/printing/julia.py index a895b61a7a48..8b4a220988f8 100644 --- a/sympy/printing/julia.py +++ b/sympy/printing/julia.py @@ -246,14 +246,6 @@ def _print_GoldenRatio(self, expr): return super(JuliaCodePrinter, self)._print_NumberSymbol(expr) - def _print_NumberSymbol(self, expr): - if self._settings["inline"]: - return self._print(expr.evalf(self._settings["precision"])) - else: - # assign to a variable, perhaps more readable for longer program - return super(JuliaCodePrinter, self)._print_NumberSymbol(expr) - - def _print_Assignment(self, expr): from sympy.functions.elementary.piecewise import Piecewise from sympy.tensor.indexed import IndexedBase diff --git a/sympy/printing/lambdarepr.py b/sympy/printing/lambdarepr.py index a20d483ef253..7f2d12b258ee 100644 --- a/sympy/printing/lambdarepr.py +++ b/sympy/printing/lambdarepr.py @@ -1,7 +1,11 @@ from __future__ import print_function, division from .str import StrPrinter -from .pycode import PythonCodePrinter +from .pycode import ( + PythonCodePrinter, + MpmathPrinter, # MpmathPrinter is imported for backward compatibility + NumPyPrinter # NumPyPrinter is imported for backward compatibility +) from sympy.utilities import default_sort_key @@ -10,48 +14,8 @@ class LambdaPrinter(PythonCodePrinter): This printer converts expressions into strings that can be used by lambdify. """ + printmethod = "_lambdacode" - def _print_MatrixBase(self, expr): - return "%s(%s)" % (expr.__class__.__name__, - self._print(expr.tolist())) - - _print_SparseMatrix = \ - _print_MutableSparseMatrix = \ - _print_ImmutableSparseMatrix = \ - _print_Matrix = \ - _print_DenseMatrix = \ - _print_MutableDenseMatrix = \ - _print_ImmutableMatrix = \ - _print_ImmutableDenseMatrix = \ - _print_MatrixBase - - def _print_Piecewise(self, expr): - result = [] - i = 0 - for arg in expr.args: - e = arg.expr - c = arg.cond - result.append('((') - result.append(self._print(e)) - result.append(') if (') - result.append(self._print(c)) - result.append(') else (') - i += 1 - result = result[:-1] - result.append(') else None)') - result.append(')'*(2*i - 2)) - return ''.join(result) - - def _print_Sum(self, expr): - loops = ( - 'for {i} in range({a}, {b}+1)'.format( - i=self._print(i), - a=self._print(a), - b=self._print(b)) - for i, a, b in expr.limits) - return '(builtins.sum({function} {loops}))'.format( - function=self._print(expr.function), - loops=' '.join(loops)) def _print_And(self, expr): result = ['('] @@ -98,6 +62,7 @@ class TensorflowPrinter(LambdaPrinter): Tensorflow printer which handles vectorized piecewise functions, logical operators, max/min, and relational operators. """ + printmethod = "_tensorflowcode" def _print_And(self, expr): "Logical And printer" @@ -171,121 +136,13 @@ def _print_Relational(self, expr): return super(TensorflowPrinter, self)._print_Relational(expr) -class NumPyPrinter(LambdaPrinter): - """ - Numpy printer which handles vectorized piecewise functions, - logical operators, etc. - """ - - def _print_seq(self, seq, delimiter=', '): - "General sequence printer: converts to tuple" - # Print tuples here instead of lists because numba supports - # tuples in nopython mode. - return '({},)'.format(delimiter.join(self._print(item) for item in seq)) - - def _print_MatMul(self, expr): - "Matrix multiplication printer" - return '({0})'.format(').dot('.join(self._print(i) for i in expr.args)) - - def _print_DotProduct(self, expr): - # DotProduct allows any shape order, but numpy.dot does matrix - # multiplication, so we have to make sure it gets 1 x n by n x 1. - arg1, arg2 = expr.args - if arg1.shape[0] != 1: - arg1 = arg1.T - if arg2.shape[1] != 1: - arg2 = arg2.T - - return "dot(%s, %s)" % (self._print(arg1), self._print(arg2)) - - def _print_Piecewise(self, expr): - "Piecewise function printer" - exprs = '[{0}]'.format(','.join(self._print(arg.expr) for arg in expr.args)) - conds = '[{0}]'.format(','.join(self._print(arg.cond) for arg in expr.args)) - # If [default_value, True] is a (expr, cond) sequence in a Piecewise object - # it will behave the same as passing the 'default' kwarg to select() - # *as long as* it is the last element in expr.args. - # If this is not the case, it may be triggered prematurely. - return 'select({0}, {1}, default=nan)'.format(conds, exprs) - - def _print_Relational(self, expr): - "Relational printer for Equality and Unequality" - op = { - '==' :'equal', - '!=' :'not_equal', - '<' :'less', - '<=' :'less_equal', - '>' :'greater', - '>=' :'greater_equal', - } - if expr.rel_op in op: - lhs = self._print(expr.lhs) - rhs = self._print(expr.rhs) - return '{op}({lhs}, {rhs})'.format(op=op[expr.rel_op], - lhs=lhs, - rhs=rhs) - return super(NumPyPrinter, self)._print_Relational(expr) - - def _print_And(self, expr): - "Logical And printer" - # We have to override LambdaPrinter because it uses Python 'and' keyword. - # If LambdaPrinter didn't define it, we could use StrPrinter's - # version of the function and add 'logical_and' to NUMPY_TRANSLATIONS. - return '{0}({1})'.format('logical_and', ','.join(self._print(i) for i in expr.args)) - - def _print_Or(self, expr): - "Logical Or printer" - # We have to override LambdaPrinter because it uses Python 'or' keyword. - # If LambdaPrinter didn't define it, we could use StrPrinter's - # version of the function and add 'logical_or' to NUMPY_TRANSLATIONS. - return '{0}({1})'.format('logical_or', ','.join(self._print(i) for i in expr.args)) - - def _print_Not(self, expr): - "Logical Not printer" - # We have to override LambdaPrinter because it uses Python 'not' keyword. - # If LambdaPrinter didn't define it, we would still have to define our - # own because StrPrinter doesn't define it. - return '{0}({1})'.format('logical_not', ','.join(self._print(i) for i in expr.args)) - - def _print_Min(self, expr): - return '{0}(({1}))'.format('amin', ','.join(self._print(i) for i in expr.args)) - - def _print_Max(self, expr): - return '{0}(({1}))'.format('amax', ','.join(self._print(i) for i in expr.args)) - - def _print_Pow(self, expr): - if expr.exp == 0.5: - return '{0}({1})'.format('sqrt', self._print(expr.base)) - else: - return super(NumPyPrinter, self)._print_Pow(expr) - - def _print_log10(self, expr): # log10 in C89, but type-generic macro in C99 - return 'log10({0})'.format(self._print(expr.args[0])) - - def _print_Sqrt(self, expr): - return 'sqrt({0})'.format(self._print(expr.args[0])) - - def _print_hypot(self, expr): - return 'hypot({0}, {1})'.format(*map(self._print, expr.args)) - - def _print_expm1(self, expr): - return 'expm1({0})'.format(self._print(expr.args[0])) - - def _print_log1p(self, expr): - return 'log1p({0})'.format(self._print(expr.args[0])) - - def _print_exp2(self, expr): - return 'exp2({0})'.format(self._print(expr.args[0])) - - def _print_log2(self, expr): - return 'log2({0})'.format(self._print(expr.args[0])) - - # numexpr works by altering the string passed to numexpr.evaluate # rather than by populating a namespace. Thus a special printer... class NumExprPrinter(LambdaPrinter): # key, value pairs correspond to sympy name and numexpr name # functions not appearing in this dict will raise a TypeError + printmethod = "_numexprcode" + _numexpr_functions = { 'sin' : 'sin', 'cos' : 'cos', @@ -363,27 +220,9 @@ def doprint(self, expr): lstr = super(NumExprPrinter, self).doprint(expr) return "evaluate('%s', truediv=True)" % lstr -class MpmathPrinter(LambdaPrinter): - """ - Lambda printer for mpmath which maintains precision for floats - """ - def _print_Integer(self, e): - return 'mpf(%d)' % e - - def _print_Float(self, e): - # XXX: This does not handle setting mpmath.mp.dps. It is assumed that - # the caller of the lambdified function will have set it to sufficient - # precision to match the Floats in the expression. - - # Remove 'mpz' if gmpy is installed. - args = str(tuple(map(int, e._mpf_))) - return 'mpf(%s)' % args - - def _print_uppergamma(self,e): #printer for the uppergamma function - return "gammainc({0}, {1}, inf)".format(self._print(e.args[0]), self._print(e.args[1])) - def _print_lowergamma(self,e): #printer for the lowergamma functioin - return "gammainc({0}, 0, {1})".format(self._print(e.args[0]), self._print(e.args[1])) +for k in NumExprPrinter._numexpr_functions: + setattr(NumExprPrinter, '_print_%s' % k, NumExprPrinter._print_Function) def lambdarepr(expr, **settings): """ diff --git a/sympy/printing/octave.py b/sympy/printing/octave.py index cfcae81dae9c..ff5007781027 100644 --- a/sympy/printing/octave.py +++ b/sympy/printing/octave.py @@ -231,14 +231,6 @@ def _print_GoldenRatio(self, expr): return "(1+sqrt(5))/2" - def _print_NumberSymbol(self, expr): - if self._settings["inline"]: - return self._print(expr.evalf(self._settings["precision"])) - else: - # assign to a variable, perhaps more readable for longer program - return super(OctaveCodePrinter, self)._print_NumberSymbol(expr) - - def _print_Assignment(self, expr): from sympy.functions.elementary.piecewise import Piecewise from sympy.tensor.indexed import IndexedBase diff --git a/sympy/printing/pycode.py b/sympy/printing/pycode.py index bb9226e39f01..e68f6595cdc8 100644 --- a/sympy/printing/pycode.py +++ b/sympy/printing/pycode.py @@ -1,3 +1,7 @@ +from collections import defaultdict +from functools import wraps +from itertools import chain +from sympy.core import sympify from .precedence import precedence from .codeprinter import CodePrinter @@ -11,23 +15,99 @@ _kw_only_py3 = {'False', 'nonlocal', 'True'} _known_functions = { - 'Abs': 'abs' + 'Abs': 'abs', } +_known_functions_math = { + 'acos': 'acos', + 'acosh': 'acosh', + 'asin': 'asin', + 'asinh': 'asinh', + 'atan': 'atan', + 'atan2': 'atan2', + 'atanh': 'atanh', + 'ceiling': 'ceil', + 'cos': 'cos', + 'cosh': 'cosh', + 'erf': 'erf', + 'erfc': 'erfc', + 'exp': 'exp', + 'expm1': 'expm1', + 'factorial': 'factorial', + 'floor': 'floor', + 'gamma': 'gamma', + 'hypot': 'hypot', + 'loggamma': 'lgamma', + 'log': 'log', + 'log10': 'log10', + 'log1p': 'log1p', + 'log2': 'log2', + 'sin': 'sin', + 'sinh': 'sinh', + 'Sqrt': 'sqrt', + 'tan': 'tan', + 'tanh': 'tanh' +} # Not used from ``math``: [copysign isclose isfinite isinf isnan ldexp frexp pow modf +# radians trunc fmod fsum gcd degrees fabs] +_known_constants_math = { + 'Exp1': 'e', + 'Pi': 'pi', + # Only in python >= 3.5: + # 'Infinity': 'inf', + # 'NaN': 'nan' +} + +def _print_known_func(self, expr): + known = self.known_functions[expr.__class__.__name__] + return '{name}({args})'.format(name=self._module_format(known), + args=', '.join(map(self._print, expr.args))) + + +def _print_known_const(self, expr): + known = self.known_constants[expr.__class__.__name__] + return self._module_format(known) + class PythonCodePrinter(CodePrinter): printmethod = "_pythoncode" language = "Python" standard = "python3" reserved_words = _kw_py2and3.union(_kw_only_py3) + modules = None # initialized to a set in __init__ tab = ' ' - _kf = _known_functions + _kf = dict(chain( + _known_functions.items(), + [(k, 'math.' + v) for k, v in _known_functions_math.items()] + )) + _kc = {k: 'math.'+v for k, v in _known_constants_math.items()} _operators = {'and': 'and', 'or': 'or', 'not': 'not'} - _default_settings = dict(CodePrinter._default_settings, precision=15) + _default_settings = dict( + CodePrinter._default_settings, + user_functions={}, + precision=17, + inline=True, + fully_qualified_modules=True + ) def __init__(self, settings=None): super(PythonCodePrinter, self).__init__(settings) + self.module_imports = defaultdict(set) self.known_functions = dict(self._kf, **(settings or {}).get( 'user_functions', {})) + self.known_constants = dict(self._kc, **(settings or {}).get( + 'user_constants', {})) + + def _declare_number_const(self, name, value): + return "%s = %s" % (name, value) + + def _module_format(self, fqn, register=True): + parts = fqn.split('.') + if register and len(parts) > 1: + self.module_imports['.'.join(parts[:-1])].add(parts[-1]) + + if self._settings['fully_qualified_modules']: + return fqn + else: + return fqn.split('(')[0].split('[')[0].split('.')[-1] def _format_code(self, lines): return lines @@ -35,22 +115,316 @@ def _format_code(self, lines): def _get_comment(self, text): return " # {0}".format(text) + def _print_NaN(self, expr): + return "float('nan')" + + def _print_Infinity(self, expr): + return "float('inf')" + def _print_Mod(self, expr): PREC = precedence(expr) return ('{0} % {1}'.format(*map(lambda x: self.parenthesize(x, PREC), expr.args))) def _print_Piecewise(self, expr): - lines = [] - for i, (e, c) in enumerate(expr.args): - if i == 0: - lines.append("if %s:" % self._print(c)) - elif i == len(expr.args) - 1 and c == True: - lines.append('else:') - else: - lines.append('elif %s:' % self._print(c)) - lines.append(self.tab + 'return ' + self._print(e)) - if i == len(expr.args) - 1 and c != True: - lines.append('else:') - lines.append('%sraise NotImplementedError("Unhandled condition in: %s")' % ( - self.tab, expr)) - return '\n'.join(lines) + result = [] + i = 0 + for arg in expr.args: + e = arg.expr + c = arg.cond + result.append('((') + result.append(self._print(e)) + result.append(') if (') + result.append(self._print(c)) + result.append(') else (') + i += 1 + result = result[:-1] + result.append(') else None)') + result.append(')'*(2*i - 2)) + return ''.join(result) + + def _print_ITE(self, expr): + from sympy.functions.elementary.piecewise import Piecewise + return self._print(expr.rewrite(Piecewise)) + + def _print_Sum(self, expr): + loops = ( + 'for {i} in range({a}, {b}+1)'.format( + i=self._print(i), + a=self._print(a), + b=self._print(b)) + for i, a, b in expr.limits) + return '(builtins.sum({function} {loops}))'.format( + function=self._print(expr.function), + loops=' '.join(loops)) + + def _print_ImaginaryUnit(self, expr): + return '1j' + + def _print_MatrixBase(self, expr): + name = expr.__class__.__name__ + func = self.known_functions.get(name, name) + return "%s(%s)" % (func, self._print(expr.tolist())) + + _print_SparseMatrix = \ + _print_MutableSparseMatrix = \ + _print_ImmutableSparseMatrix = \ + _print_Matrix = \ + _print_DenseMatrix = \ + _print_MutableDenseMatrix = \ + _print_ImmutableMatrix = \ + _print_ImmutableDenseMatrix = \ + lambda self, expr: self._print_MatrixBase(expr) + + +for k in PythonCodePrinter._kf: + setattr(PythonCodePrinter, '_print_%s' % k, _print_known_func) + +for k in _known_constants_math: + setattr(PythonCodePrinter, '_print_%s' % k, _print_known_const) + + +def pycode(expr, **settings): + return PythonCodePrinter(settings).doprint(expr) + + +_not_in_mpmath = 'log1p log2'.split() +_in_mpmath = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_mpmath] +_known_functions_mpmath = dict(_in_mpmath) +_known_constants_mpmath = { + 'Pi': 'pi' +} + + +class MpmathPrinter(PythonCodePrinter): + """ + Lambda printer for mpmath which maintains precision for floats + """ + printmethod = "_mpmathcode" + + _kf = dict(chain( + PythonCodePrinter._kf.items(), + [(k, 'mpmath.' + v) for k, v in _known_functions_mpmath.items()] + )) + + def doprint(self, expr, **kwargs): + from sympy.functions import log + from sympy.codegen.cfunctions import log2, log1p + expr = expr.replace(log2, lambda arg: arg.rewrite(log)) + expr = expr.replace(log1p, lambda arg: arg.rewrite(log)) + return super(MpmathPrinter, self).doprint(expr, **kwargs) + + def _print_Integer(self, e): + return '%s(%d)' % (self._module_format('mpmath.mpf'), e) + + def _print_Float(self, e): + # XXX: This does not handle setting mpmath.mp.dps. It is assumed that + # the caller of the lambdified function will have set it to sufficient + # precision to match the Floats in the expression. + + # Remove 'mpz' if gmpy is installed. + args = str(tuple(map(int, e._mpf_))) + return '{func}({args})'.format(func=self._module_format('mpmath.mpf'), args=args) + + + def _print_uppergamma(self,e): #printer for the uppergamma function + return "{0}({1}, {2}, {3})".format( + self._module_format('mpmath.gammainc'), self._print(e.args[0]), self._print(e.args[1]), + self._module_format('mpmath.inf')) + + def _print_lowergamma(self,e): #printer for the lowergamma functioin + return "{0}({1}, 0, {2})".format( + self._module_format('mpmath.gammainc'), self._print(e.args[0]), self._print(e.args[1])) + +for k in MpmathPrinter._kf: + setattr(MpmathPrinter, '_print_%s' % k, _print_known_func) + +for k in _known_constants_mpmath: + setattr(MpmathPrinter, '_print_%s' % k, _print_known_const) + + +_not_in_numpy = 'erf erfc factorial gamma lgamma'.split() +_in_numpy = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_numpy] +_known_functions_numpy = dict(_in_numpy, **{ + 'acos': 'arccos', + 'acosh': 'arccosh', + 'asin': 'arcsin', + 'asinh': 'arcsinh', + 'atan': 'arctan', + 'atan2': 'arctan2', + 'atanh': 'arctanh', + 'exp2': 'exp2', +}) + + +class NumPyPrinter(PythonCodePrinter): + """ + Numpy printer which handles vectorized piecewise functions, + logical operators, etc. + """ + printmethod = "_numpycode" + + _kf = dict(chain( + PythonCodePrinter._kf.items(), + [(k, 'numpy.' + v) for k, v in _known_functions_numpy.items()] + )) + _kc = {k: 'numpy.'+v for k, v in _known_constants_math.items()} + + + def _print_seq(self, seq, delimiter=', '): + "General sequence printer: converts to tuple" + # Print tuples here instead of lists because numba supports + # tuples in nopython mode. + return '({},)'.format(delimiter.join(self._print(item) for item in seq)) + + def _print_MatMul(self, expr): + "Matrix multiplication printer" + return '({0})'.format(').dot('.join(self._print(i) for i in expr.args)) + + def _print_DotProduct(self, expr): + # DotProduct allows any shape order, but numpy.dot does matrix + # multiplication, so we have to make sure it gets 1 x n by n x 1. + arg1, arg2 = expr.args + if arg1.shape[0] != 1: + arg1 = arg1.T + if arg2.shape[1] != 1: + arg2 = arg2.T + + return "%s(%s, %s)" % (self._module_format('numpy.dot'), self._print(arg1), self._print(arg2)) + + def _print_Piecewise(self, expr): + "Piecewise function printer" + exprs = '[{0}]'.format(','.join(self._print(arg.expr) for arg in expr.args)) + conds = '[{0}]'.format(','.join(self._print(arg.cond) for arg in expr.args)) + # If [default_value, True] is a (expr, cond) sequence in a Piecewise object + # it will behave the same as passing the 'default' kwarg to select() + # *as long as* it is the last element in expr.args. + # If this is not the case, it may be triggered prematurely. + return '{0}({1}, {2}, default=numpy.nan)'.format(self._module_format('numpy.select'), conds, exprs) + + def _print_Relational(self, expr): + "Relational printer for Equality and Unequality" + op = { + '==' :'equal', + '!=' :'not_equal', + '<' :'less', + '<=' :'less_equal', + '>' :'greater', + '>=' :'greater_equal', + } + if expr.rel_op in op: + lhs = self._print(expr.lhs) + rhs = self._print(expr.rhs) + return '{op}({lhs}, {rhs})'.format(op=self._module_format('numpy.'+op[expr.rel_op]), + lhs=lhs, rhs=rhs) + return super(NumPyPrinter, self)._print_Relational(expr) + + def _print_And(self, expr): + "Logical And printer" + # We have to override LambdaPrinter because it uses Python 'and' keyword. + # If LambdaPrinter didn't define it, we could use StrPrinter's + # version of the function and add 'logical_and' to NUMPY_TRANSLATIONS. + return '{0}({1})'.format(self._module_format('numpy.logical_and'), ','.join(self._print(i) for i in expr.args)) + + def _print_Or(self, expr): + "Logical Or printer" + # We have to override LambdaPrinter because it uses Python 'or' keyword. + # If LambdaPrinter didn't define it, we could use StrPrinter's + # version of the function and add 'logical_or' to NUMPY_TRANSLATIONS. + return '{0}({1})'.format(self._module_format('numpy.logical_or'), ','.join(self._print(i) for i in expr.args)) + + def _print_Not(self, expr): + "Logical Not printer" + # We have to override LambdaPrinter because it uses Python 'not' keyword. + # If LambdaPrinter didn't define it, we would still have to define our + # own because StrPrinter doesn't define it. + return '{0}({1})'.format(self._module_format('numpy.logical_not'), ','.join(self._print(i) for i in expr.args)) + + def _print_Min(self, expr): + return '{0}(({1}))'.format(self._module_format('numpy.amin'), ','.join(self._print(i) for i in expr.args)) + + def _print_Max(self, expr): + return '{0}(({1}))'.format(self._module_format('numpy.amax'), ','.join(self._print(i) for i in expr.args)) + + def _print_Pow(self, expr): + if expr.exp == 0.5: + return '{0}({1})'.format(self._module_format('numpy.sqrt'), self._print(expr.base)) + else: + return super(NumPyPrinter, self)._print_Pow(expr) + + def _print_arg(self, expr): + return "%s(%s)" % (self._module_format('numpy.angle'), self._print(expr.args[0])) + + def _print_im(self, expr): + return "%s(%s)" % (self._module_format('numpy.imag', self._print(expr.args[0]))) + + def _print_Mod(self, expr): + return "%s(%s)" % (self._module_format('numpy.mod'), ', '.join(map(self._print, expr.args))) + + def _print_re(self, expr): + return "%s(%s)" % (self._module_format('numpy.real'), self._print(expr.args[0])) + + def _print_MatrixBase(self, expr): + func = self.known_functions.get(expr.__class__.__name__, None) + if func is None: + func = self._module_format('numpy.array') + return "%s(%s)" % (func, self._print(expr.tolist())) + + +for k in NumPyPrinter._kf: + setattr(NumPyPrinter, '_print_%s' % k, _print_known_func) + +for k in NumPyPrinter._kc: + setattr(NumPyPrinter, '_print_%s' % k, _print_known_const) + + +_known_functions_scipy_special = { + 'erf': 'erf', + 'erfc': 'erfc', + 'gamma': 'gamma', + 'loggamma': 'gammaln' +} +_known_constants_scipy_constants = { + 'GoldenRatio': 'golden_ratio' +} + +class SciPyPrinter(NumPyPrinter): + + _kf = dict(chain( + NumPyPrinter._kf.items(), + [(k, 'scipy.special.' + v) for k, v in _known_functions_scipy_special.items()] + )) + _kc = {k: 'scipy.constants.' + v for k, v in _known_constants_scipy_constants.items()} + + def _print_SparseMatrix(self, expr): + i, j, data = [], [], [] + for (r, c), v in expr._smat.items(): + i.append(r) + j.append(c) + data.append(v) + + return "{name}({data}, ({i}, {j}), shape={shape})".format( + name=self._module_format('scipy.sparse.coo_matrix'), + data=data, i=i, j=j, shape=expr.shape + ) + + _print_ImmutableSparseMatrix = _print_SparseMatrix + + +for k in SciPyPrinter._kf: + setattr(SciPyPrinter, '_print_%s' % k, _print_known_func) + +for k in SciPyPrinter._kc: + setattr(SciPyPrinter, '_print_%s' % k, _print_known_const) + + +class SymPyPrinter(PythonCodePrinter): + + _kf = dict([(k, 'sympy.' + v) for k, v in chain( + _known_functions.items(), + _known_functions_math.items() + )]) + + def _print_Function(self, expr): + mod = expr.func.__module__ or '' + return '%s(%s)' % (self._module_format(mod + ('.' if mod else '') + expr.func.__name__), + ', '.join(map(self._print, expr.args))) diff --git a/sympy/printing/tests/test_lambdarepr.py b/sympy/printing/tests/test_lambdarepr.py index 4c436a4740d6..8da54af1eb11 100644 --- a/sympy/printing/tests/test_lambdarepr.py +++ b/sympy/printing/tests/test_lambdarepr.py @@ -1,7 +1,8 @@ -from sympy import symbols, sin, Matrix, Interval, Piecewise, Sum, lambdify +from sympy import symbols, sin, Matrix, Interval, Piecewise, Sum, lambdify,Expr from sympy.utilities.pytest import raises -from sympy.printing.lambdarepr import lambdarepr +from sympy.printing.lambdarepr import lambdarepr, LambdaPrinter, TensorflowPrinter, NumExprPrinter + x, y, z = symbols("x,y,z") i, a, b = symbols("i,a,b") @@ -147,25 +148,27 @@ def test_piecewise(): "(True) else None)))))))))))" -def test_sum(): +def test_sum__1(): # In each case, test eval() the lambdarepr() to make sure that # it evaluates to the same results as the symbolic expression - s = Sum(x ** i, (i, a, b)) - l = lambdarepr(s) assert l == "(builtins.sum(x**i for i in range(a, b+1)))" - assert (lambdify((x, a, b), s)(2, 3, 8) == - s.subs([(x, 2), (a, 3), (b, 8)]).doit()) + args = x, a, b + f = lambdify(args, s) + v = 2, 3, 8 + assert f(*v) == s.subs(zip(args, v)).doit() +def test_sum__2(): s = Sum(i * x, (i, a, b)) - l = lambdarepr(s) assert l == "(builtins.sum(i*x for i in range(a, b+1)))" - assert (lambdify((x, a, b), s)(2, 3, 8) == - s.subs([(x, 2), (a, 3), (b, 8)]).doit()) + args = x, a, b + f = lambdify(args, s) + v = 2, 3, 8 + assert f(*v) == s.subs(zip(args, v)).doit() def test_multiple_sums(): @@ -174,8 +177,40 @@ def test_multiple_sums(): l = lambdarepr(s) assert l == "(builtins.sum(i*x + j for i in range(a, b+1) for j in range(c, d+1)))" - assert (lambdify((x, a, b, c, d), s)(2, 3, 4, 5, 6) == - s.subs([(x, 2), (a, 3), (b, 4), (c, 5), (d, 6)]).doit()) + args = x, a, b, c, d + f = lambdify(args, s) + vals = 2, 3, 4, 5, 6 + f_ref = s.subs(zip(args, vals)).doit() + f_res = f(*vals) + assert f_res == f_ref + def test_settings(): raises(TypeError, lambda: lambdarepr(sin(x), method="garbage")) + + +class CustomPrintedObject(Expr): + def _lambdacode(self, printer): + return 'lambda' + + def _tensorflowcode(self, printer): + return 'tensorflow' + + def _numpycode(self, printer): + return 'numpy' + + def _numexprcode(self, printer): + return 'numexpr' + + def _mpmathcode(self, printer): + return 'mpmath' + + +def test_printmethod(): + # In each case, printmethod is called to test + # its working + + obj = CustomPrintedObject() + assert LambdaPrinter().doprint(obj) == 'lambda' + assert TensorflowPrinter().doprint(obj) == 'tensorflow' + assert NumExprPrinter().doprint(obj) == "evaluate('numexpr', truediv=True)" diff --git a/sympy/printing/tests/test_numpy.py b/sympy/printing/tests/test_numpy.py index 3f7a68456edd..f206e84aa5a8 100644 --- a/sympy/printing/tests/test_numpy.py +++ b/sympy/printing/tests/test_numpy.py @@ -15,7 +15,7 @@ def test_numpy_piecewise_regression(): See gh-9747 and gh-9749 for details. """ p = Piecewise((1, x < 0), (0, True)) - assert NumPyPrinter().doprint(p) == 'select([less(x, 0),True], [1,0], default=nan)' + assert NumPyPrinter().doprint(p) == 'numpy.select([numpy.less(x, 0),True], [1,0], default=numpy.nan)' def test_sum(): diff --git a/sympy/printing/tests/test_pycode.py b/sympy/printing/tests/test_pycode.py index 641cc3ba892b..5e632c7fe0b0 100644 --- a/sympy/printing/tests/test_pycode.py +++ b/sympy/printing/tests/test_pycode.py @@ -1,31 +1,59 @@ # -*- coding: utf-8 -*- from __future__ import (absolute_import, division, print_function) -from sympy.core import Mod, symbols +from sympy.core import Expr, Mod, symbols +from sympy.core.numbers import pi from sympy.logic import And, Or -from sympy.functions import Piecewise -from sympy.printing.pycode import PythonCodePrinter +from sympy.functions import acos +from sympy.matrices import SparseMatrix +from sympy.printing.pycode import ( + MpmathPrinter, NumPyPrinter, PythonCodePrinter, pycode, SciPyPrinter +) +from sympy.utilities.pytest import raises x, y, z = symbols('x y z') def test_PythonCodePrinter(): prntr = PythonCodePrinter() + assert not prntr.module_imports assert prntr.doprint(x**y) == 'x**y' assert prntr.doprint(Mod(x, 2)) == 'x % 2' assert prntr.doprint(And(x, y)) == 'x and y' assert prntr.doprint(Or(x, y)) == 'x or y' - assert prntr.doprint(Piecewise((x, x > 1), (y, True))) == ( - 'if x > 1:\n' - ' return x\n' - 'else:\n' - ' return y' - ) - pw = Piecewise((x, x > 1), (y, x > 0)) - assert prntr.doprint(pw) == ( - 'if x > 1:\n' - ' return x\n' - 'elif x > 0:\n' - ' return y\n' - 'else:\n' - ' raise NotImplementedError("Unhandled condition in: %s")' % pw - ) + assert not prntr.module_imports + assert prntr.doprint(pi) == 'math.pi' + assert prntr.module_imports == {'math': {'pi'}} + assert prntr.doprint(acos(x)) == 'math.acos(x)' + + +def test_SciPyPrinter(): + p = SciPyPrinter() + expr = acos(x) + assert 'numpy' not in p.module_imports + assert p.doprint(expr) == 'numpy.arccos(x)' + assert 'numpy' in p.module_imports + assert not any(m.startswith('scipy') for m in p.module_imports) + smat = SparseMatrix(2, 5, {(0, 1): 3}) + assert p.doprint(smat) == 'scipy.sparse.coo_matrix([3], ([0], [1]), shape=(2, 5))' + assert 'scipy.sparse' in p.module_imports + + +def test_pycode_reserved_words(): + s1, s2 = symbols('if else') + raises(ValueError, lambda: pycode(s1 + s2, error_on_reserved=True)) + py_str = pycode(s1 + s2) + assert py_str in ('else_ + if_', 'if_ + else_') + + +class CustomPrintedObject(Expr): + def _numpycode(self, printer): + return 'numpy' + + def _mpmathcode(self, printer): + return 'mpmath' + + +def test_printmethod(): + obj = CustomPrintedObject() + assert NumPyPrinter().doprint(obj) == 'numpy' + assert MpmathPrinter().doprint(obj) == 'mpmath' diff --git a/sympy/utilities/lambdify.py b/sympy/utilities/lambdify.py index 6d962a954358..5056cc60c2af 100644 --- a/sympy/utilities/lambdify.py +++ b/sympy/utilities/lambdify.py @@ -66,28 +66,7 @@ "Ci": "ci" } -NUMPY_TRANSLATIONS = { - "acos": "arccos", - "acosh": "arccosh", - "arg": "angle", - "asin": "arcsin", - "asinh": "arcsinh", - "atan": "arctan", - "atan2": "arctan2", - "atanh": "arctanh", - "ceiling": "ceil", - "E": "e", - "im": "imag", - "ln": "log", - "Mod": "mod", - "oo": "inf", - "re": "real", - "SparseMatrix": "array", - "ImmutableSparseMatrix": "array", - "Matrix": "array", - "MutableDenseMatrix": "array", - "ImmutableDenseMatrix": "array", -} +NUMPY_TRANSLATIONS = {} TENSORFLOW_TRANSLATIONS = { "Abs": "abs", @@ -105,7 +84,7 @@ MODULES = { "math": (MATH, MATH_DEFAULT, MATH_TRANSLATIONS, ("from math import *",)), "mpmath": (MPMATH, MPMATH_DEFAULT, MPMATH_TRANSLATIONS, ("from mpmath import *",)), - "numpy": (NUMPY, NUMPY_DEFAULT, NUMPY_TRANSLATIONS, ("import_module('numpy')",)), + "numpy": (NUMPY, NUMPY_DEFAULT, NUMPY_TRANSLATIONS, ("import numpy; from numpy import *",)), "tensorflow": (TENSORFLOW, TENSORFLOW_DEFAULT, TENSORFLOW_TRANSLATIONS, ("import_module('tensorflow')",)), "sympy": (SYMPY, SYMPY_DEFAULT, {}, ( "from sympy.functions import *", @@ -386,21 +365,26 @@ def lambdify(args, expr, modules=None, printer=None, use_imps=True, for term in syms: namespace.update({str(term): term}) - if _module_present('mpmath',namespaces) and printer is None: - #XXX: This has to be done here because of circular imports - from sympy.printing.lambdarepr import MpmathPrinter as printer - - if _module_present('numpy',namespaces) and printer is None: - #XXX: This has to be done here because of circular imports - from sympy.printing.lambdarepr import NumPyPrinter as printer - - if _module_present('numexpr',namespaces) and printer is None: - #XXX: This has to be done here because of circular imports - from sympy.printing.lambdarepr import NumExprPrinter as printer - - if _module_present('tensorflow',namespaces) and printer is None: - #XXX: This has to be done here because of circular imports - from sympy.printing.lambdarepr import TensorflowPrinter as printer + if printer is None: + if _module_present('mpmath', namespaces): + from sympy.printing.pycode import MpmathPrinter as Printer + elif _module_present('numpy', namespaces): + from sympy.printing.pycode import NumPyPrinter as Printer + elif _module_present('numexpr', namespaces): + from sympy.printing.lambdarepr import NumExprPrinter as Printer + elif _module_present('tensorflow', namespaces): + from sympy.printing.lambdarepr import TensorflowPrinter as Printer + elif _module_present('sympy', namespaces): + from sympy.printing.pycode import SymPyPrinter as Printer + else: + from sympy.printing.pycode import PythonCodePrinter as Printer + user_functions = {} + for m in namespaces[::-1]: + if isinstance(m, dict): + for k in m: + user_functions[k] = k + printer = Printer({'fully_qualified_modules': False, 'inline': True, + 'user_functions': user_functions}) # Get the names of the args, for creating a docstring if not iterable(args): @@ -424,6 +408,13 @@ def lambdify(args, expr, modules=None, printer=None, use_imps=True, # Create lambda function. lstr = lambdastr(args, expr, printer=printer, dummify=dummify) flat = '__flatten_args__' + imp_mod_lines = [] + for mod, keys in (getattr(printer, 'module_imports', None) or {}).items(): + for k in keys: + if k not in namespace: + imp_mod_lines.append("from %s import %s" % (mod, k)) + for ln in imp_mod_lines: + exec_(ln, {}, namespace) if flat in lstr: namespace.update({flat: flatten}) @@ -450,8 +441,16 @@ def wrapper(*argsx, **kwargsx): expr_str = str(expr) if len(expr_str) > 78: expr_str = textwrap.wrap(expr_str, 75)[0] + '...' - func.__doc__ = ("Created with lambdify. Signature:\n\n{sig}\n\n" - "Expression:\n\n{expr}").format(sig=sig, expr=expr_str) + func.__doc__ = ( + "Created with lambdify. Signature:\n\n" + "{sig}\n\n" + "Expression:\n\n" + "{expr}\n\n" + "Source code:\n\n" + "{src}\n\n" + "Imported modules:\n\n" + "{imp_mods}" + ).format(sig=sig, expr=expr_str, src=lstr, imp_mods='\n'.join(imp_mod_lines)) return func def _module_present(modname, modlist): diff --git a/sympy/utilities/tests/test_lambdify.py b/sympy/utilities/tests/test_lambdify.py index c92cc1dfac1c..2c5d5fee525b 100644 --- a/sympy/utilities/tests/test_lambdify.py +++ b/sympy/utilities/tests/test_lambdify.py @@ -53,13 +53,21 @@ def test_str_args(): raises(TypeError, lambda: f(0)) -def test_own_namespace(): +def test_own_namespace_1(): myfunc = lambda x: 1 f = lambdify(x, sin(x), {"sin": myfunc}) assert f(0.1) == 1 assert f(100) == 1 +def test_own_namespace_2(): + def myfunc(x): + return 1 + f = lambdify(x, sin(x), {'sin': myfunc}) + assert f(0.1) == 1 + assert f(100) == 1 + + def test_own_module(): f = lambdify(x, sin(x), math) assert f(0) == 0.0 @@ -702,19 +710,23 @@ def test_python_keywords(): def test_lambdify_docstring(): func = lambdify((w, x, y, z), w + x + y + z) - assert func.__doc__ == ( - "Created with lambdify. Signature:\n\n" - "func(w, x, y, z)\n\n" - "Expression:\n\n" - "w + x + y + z") + ref = ( + "Created with lambdify. Signature:\n\n" + "func(w, x, y, z)\n\n" + "Expression:\n\n" + "w + x + y + z" + ).splitlines() + assert func.__doc__.splitlines()[:len(ref)] == ref syms = symbols('a1:26') func = lambdify(syms, sum(syms)) - assert func.__doc__ == ( - "Created with lambdify. Signature:\n\n" - "func(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15,\n" - " a16, a17, a18, a19, a20, a21, a22, a23, a24, a25)\n\n" - "Expression:\n\n" - "a1 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 + a2 + a20 +...") + ref = ( + "Created with lambdify. Signature:\n\n" + "func(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15,\n" + " a16, a17, a18, a19, a20, a21, a22, a23, a24, a25)\n\n" + "Expression:\n\n" + "a1 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 + a2 + a20 +..." + ).splitlines() + assert func.__doc__.splitlines()[:len(ref)] == ref #================== Test special printers ==========================