Skip to content

Commit

Permalink
Merge pull request #13046 from bjodah/pythonprinter2
Browse files Browse the repository at this point in the history
Further work on PythonCodePrinter
  • Loading branch information
bjodah committed Aug 9, 2017
2 parents accb35d + d4a6198 commit 8ff744d
Show file tree
Hide file tree
Showing 13 changed files with 585 additions and 298 deletions.
4 changes: 4 additions & 0 deletions sympy/logic/boolalg.py
Expand Up @@ -980,6 +980,10 @@ def to_nnf(self, simplify=True):
def _eval_derivative(self, x):
return self.func(self.args[0], *[a.diff(x) for a in self.args[1:]])

def _eval_rewrite_as_Piecewise(self, *args):
from sympy.functions import Piecewise
return Piecewise((args[1], args[0]), (args[2], True))

# the diff method below is copied from Expr class
def diff(self, *symbols, **assumptions):
new_symbols = list(map(sympify, symbols)) # e.g. x, 2, y, z
Expand Down
5 changes: 5 additions & 0 deletions sympy/logic/tests/test_boolalg.py
Expand Up @@ -5,6 +5,7 @@
from sympy.core.relational import Equality
from sympy.core.singleton import S
from sympy.core.symbol import (Dummy, symbols)
from sympy.functions import Piecewise
from sympy.sets.sets import (EmptySet, Interval, Union)
from sympy.simplify.simplify import simplify
from sympy.logic.boolalg import (
Expand Down Expand Up @@ -504,6 +505,10 @@ def test_ITE():
assert ITE(x, A, B) == Not(x)
assert ITE(x, B, A) == x

def test_ITE_rewrite_Piecewise():

assert ITE(A, B, C).rewrite(Piecewise) == Piecewise((B, A), (C, True))


def test_ITE_diff():
# analogous to Piecewise.diff
Expand Down
1 change: 1 addition & 0 deletions sympy/printing/__init__.py
Expand Up @@ -5,6 +5,7 @@
from .latex import latex, print_latex
from .mathml import mathml, print_mathml
from .python import python, print_python
from .pycode import pycode
from .ccode import ccode, print_ccode
from .glsl import glsl_code, print_glsl
from .cxxcode import cxxcode
Expand Down
26 changes: 16 additions & 10 deletions sympy/printing/codeprinter.py
Expand Up @@ -31,11 +31,14 @@ class CodePrinter(StrPrinter):
'not': '!',
}

_default_settings = {'order': None,
'full_prec': 'auto',
'error_on_reserved': False,
'reserved_word_suffix': '_',
'human': True}
_default_settings = {
'order': None,
'full_prec': 'auto',
'error_on_reserved': False,
'reserved_word_suffix': '_',
'human': True,
'inline': False
}

def __init__(self, settings=None):

Expand Down Expand Up @@ -337,11 +340,14 @@ def _print_Function(self, expr):
_print_Expr = _print_Function

def _print_NumberSymbol(self, expr):
# A Number symbol that is not implemented here or with _printmethod
# is registered and evaluated
self._number_symbols.add((expr,
self._print(expr.evalf(self._settings["precision"]))))
return str(expr)
if self._settings.get("inline", False):
return self._print(expr.evalf(self._settings["precision"]))
else:
# A Number symbol that is not implemented here or with _printmethod
# is registered and evaluated
self._number_symbols.add((expr,
self._print(expr.evalf(self._settings["precision"]))))
return str(expr)

def _print_Catalan(self, expr):
return self._print_NumberSymbol(expr)
Expand Down
8 changes: 0 additions & 8 deletions sympy/printing/julia.py
Expand Up @@ -246,14 +246,6 @@ def _print_GoldenRatio(self, expr):
return super(JuliaCodePrinter, self)._print_NumberSymbol(expr)


def _print_NumberSymbol(self, expr):
if self._settings["inline"]:
return self._print(expr.evalf(self._settings["precision"]))
else:
# assign to a variable, perhaps more readable for longer program
return super(JuliaCodePrinter, self)._print_NumberSymbol(expr)


def _print_Assignment(self, expr):
from sympy.functions.elementary.piecewise import Piecewise
from sympy.tensor.indexed import IndexedBase
Expand Down
183 changes: 11 additions & 172 deletions sympy/printing/lambdarepr.py
@@ -1,7 +1,11 @@
from __future__ import print_function, division

from .str import StrPrinter
from .pycode import PythonCodePrinter
from .pycode import (
PythonCodePrinter,
MpmathPrinter, # MpmathPrinter is imported for backward compatibility
NumPyPrinter # NumPyPrinter is imported for backward compatibility
)
from sympy.utilities import default_sort_key


Expand All @@ -10,48 +14,8 @@ class LambdaPrinter(PythonCodePrinter):
This printer converts expressions into strings that can be used by
lambdify.
"""
printmethod = "_lambdacode"

def _print_MatrixBase(self, expr):
return "%s(%s)" % (expr.__class__.__name__,
self._print(expr.tolist()))

_print_SparseMatrix = \
_print_MutableSparseMatrix = \
_print_ImmutableSparseMatrix = \
_print_Matrix = \
_print_DenseMatrix = \
_print_MutableDenseMatrix = \
_print_ImmutableMatrix = \
_print_ImmutableDenseMatrix = \
_print_MatrixBase

def _print_Piecewise(self, expr):
result = []
i = 0
for arg in expr.args:
e = arg.expr
c = arg.cond
result.append('((')
result.append(self._print(e))
result.append(') if (')
result.append(self._print(c))
result.append(') else (')
i += 1
result = result[:-1]
result.append(') else None)')
result.append(')'*(2*i - 2))
return ''.join(result)

def _print_Sum(self, expr):
loops = (
'for {i} in range({a}, {b}+1)'.format(
i=self._print(i),
a=self._print(a),
b=self._print(b))
for i, a, b in expr.limits)
return '(builtins.sum({function} {loops}))'.format(
function=self._print(expr.function),
loops=' '.join(loops))

def _print_And(self, expr):
result = ['(']
Expand Down Expand Up @@ -98,6 +62,7 @@ class TensorflowPrinter(LambdaPrinter):
Tensorflow printer which handles vectorized piecewise functions,
logical operators, max/min, and relational operators.
"""
printmethod = "_tensorflowcode"

def _print_And(self, expr):
"Logical And printer"
Expand Down Expand Up @@ -171,121 +136,13 @@ def _print_Relational(self, expr):
return super(TensorflowPrinter, self)._print_Relational(expr)


class NumPyPrinter(LambdaPrinter):
"""
Numpy printer which handles vectorized piecewise functions,
logical operators, etc.
"""

def _print_seq(self, seq, delimiter=', '):
"General sequence printer: converts to tuple"
# Print tuples here instead of lists because numba supports
# tuples in nopython mode.
return '({},)'.format(delimiter.join(self._print(item) for item in seq))

def _print_MatMul(self, expr):
"Matrix multiplication printer"
return '({0})'.format(').dot('.join(self._print(i) for i in expr.args))

def _print_DotProduct(self, expr):
# DotProduct allows any shape order, but numpy.dot does matrix
# multiplication, so we have to make sure it gets 1 x n by n x 1.
arg1, arg2 = expr.args
if arg1.shape[0] != 1:
arg1 = arg1.T
if arg2.shape[1] != 1:
arg2 = arg2.T

return "dot(%s, %s)" % (self._print(arg1), self._print(arg2))

def _print_Piecewise(self, expr):
"Piecewise function printer"
exprs = '[{0}]'.format(','.join(self._print(arg.expr) for arg in expr.args))
conds = '[{0}]'.format(','.join(self._print(arg.cond) for arg in expr.args))
# If [default_value, True] is a (expr, cond) sequence in a Piecewise object
# it will behave the same as passing the 'default' kwarg to select()
# *as long as* it is the last element in expr.args.
# If this is not the case, it may be triggered prematurely.
return 'select({0}, {1}, default=nan)'.format(conds, exprs)

def _print_Relational(self, expr):
"Relational printer for Equality and Unequality"
op = {
'==' :'equal',
'!=' :'not_equal',
'<' :'less',
'<=' :'less_equal',
'>' :'greater',
'>=' :'greater_equal',
}
if expr.rel_op in op:
lhs = self._print(expr.lhs)
rhs = self._print(expr.rhs)
return '{op}({lhs}, {rhs})'.format(op=op[expr.rel_op],
lhs=lhs,
rhs=rhs)
return super(NumPyPrinter, self)._print_Relational(expr)

def _print_And(self, expr):
"Logical And printer"
# We have to override LambdaPrinter because it uses Python 'and' keyword.
# If LambdaPrinter didn't define it, we could use StrPrinter's
# version of the function and add 'logical_and' to NUMPY_TRANSLATIONS.
return '{0}({1})'.format('logical_and', ','.join(self._print(i) for i in expr.args))

def _print_Or(self, expr):
"Logical Or printer"
# We have to override LambdaPrinter because it uses Python 'or' keyword.
# If LambdaPrinter didn't define it, we could use StrPrinter's
# version of the function and add 'logical_or' to NUMPY_TRANSLATIONS.
return '{0}({1})'.format('logical_or', ','.join(self._print(i) for i in expr.args))

def _print_Not(self, expr):
"Logical Not printer"
# We have to override LambdaPrinter because it uses Python 'not' keyword.
# If LambdaPrinter didn't define it, we would still have to define our
# own because StrPrinter doesn't define it.
return '{0}({1})'.format('logical_not', ','.join(self._print(i) for i in expr.args))

def _print_Min(self, expr):
return '{0}(({1}))'.format('amin', ','.join(self._print(i) for i in expr.args))

def _print_Max(self, expr):
return '{0}(({1}))'.format('amax', ','.join(self._print(i) for i in expr.args))

def _print_Pow(self, expr):
if expr.exp == 0.5:
return '{0}({1})'.format('sqrt', self._print(expr.base))
else:
return super(NumPyPrinter, self)._print_Pow(expr)

def _print_log10(self, expr): # log10 in C89, but type-generic macro in C99
return 'log10({0})'.format(self._print(expr.args[0]))

def _print_Sqrt(self, expr):
return 'sqrt({0})'.format(self._print(expr.args[0]))

def _print_hypot(self, expr):
return 'hypot({0}, {1})'.format(*map(self._print, expr.args))

def _print_expm1(self, expr):
return 'expm1({0})'.format(self._print(expr.args[0]))

def _print_log1p(self, expr):
return 'log1p({0})'.format(self._print(expr.args[0]))

def _print_exp2(self, expr):
return 'exp2({0})'.format(self._print(expr.args[0]))

def _print_log2(self, expr):
return 'log2({0})'.format(self._print(expr.args[0]))


# numexpr works by altering the string passed to numexpr.evaluate
# rather than by populating a namespace. Thus a special printer...
class NumExprPrinter(LambdaPrinter):
# key, value pairs correspond to sympy name and numexpr name
# functions not appearing in this dict will raise a TypeError
printmethod = "_numexprcode"

_numexpr_functions = {
'sin' : 'sin',
'cos' : 'cos',
Expand Down Expand Up @@ -363,27 +220,9 @@ def doprint(self, expr):
lstr = super(NumExprPrinter, self).doprint(expr)
return "evaluate('%s', truediv=True)" % lstr

class MpmathPrinter(LambdaPrinter):
"""
Lambda printer for mpmath which maintains precision for floats
"""
def _print_Integer(self, e):
return 'mpf(%d)' % e

def _print_Float(self, e):
# XXX: This does not handle setting mpmath.mp.dps. It is assumed that
# the caller of the lambdified function will have set it to sufficient
# precision to match the Floats in the expression.

# Remove 'mpz' if gmpy is installed.
args = str(tuple(map(int, e._mpf_)))
return 'mpf(%s)' % args

def _print_uppergamma(self,e): #printer for the uppergamma function
return "gammainc({0}, {1}, inf)".format(self._print(e.args[0]), self._print(e.args[1]))

def _print_lowergamma(self,e): #printer for the lowergamma functioin
return "gammainc({0}, 0, {1})".format(self._print(e.args[0]), self._print(e.args[1]))
for k in NumExprPrinter._numexpr_functions:
setattr(NumExprPrinter, '_print_%s' % k, NumExprPrinter._print_Function)

def lambdarepr(expr, **settings):
"""
Expand Down
8 changes: 0 additions & 8 deletions sympy/printing/octave.py
Expand Up @@ -231,14 +231,6 @@ def _print_GoldenRatio(self, expr):
return "(1+sqrt(5))/2"


def _print_NumberSymbol(self, expr):
if self._settings["inline"]:
return self._print(expr.evalf(self._settings["precision"]))
else:
# assign to a variable, perhaps more readable for longer program
return super(OctaveCodePrinter, self)._print_NumberSymbol(expr)


def _print_Assignment(self, expr):
from sympy.functions.elementary.piecewise import Piecewise
from sympy.tensor.indexed import IndexedBase
Expand Down

0 comments on commit 8ff744d

Please sign in to comment.