Skip to content

Commit

Permalink
Adds NumPyPrinter to lambdify for dealing with piecewise functions, m…
Browse files Browse the repository at this point in the history
…atrix multiplication and logical operators. Ends deprecation cycle started by sympy/sympy#7853 so that ImmutableMatrix now maps to numpy.array by default.
  • Loading branch information
richardotis authored and skirpichev committed Sep 30, 2015
1 parent de21708 commit 3e4fbed
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 46 deletions.
54 changes: 54 additions & 0 deletions sympy/printing/lambdarepr.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,60 @@ def _print_BooleanTrue(self, expr):
def _print_BooleanFalse(self, expr):
return "False"


class NumPyPrinter(LambdaPrinter):
"""
Numpy printer which handles vectorized piecewise functions,
logical operators, etc.
"""
_default_settings = {
"order": "none",
"full_prec": "auto",
}

def _print_seq(self, seq, delimiter=', '):
"General sequence printer: converts to tuple"
# Print tuples here instead of lists because numba supports
# tuples in nopython mode.
return '({},)'.format(delimiter.join(self._print(item) for item in seq))

def _print_MatMul(self, expr):
"Matrix multiplication printer"
return '({0})'.format(').dot('.join(self._print(i) for i in expr.args))

def _print_Piecewise(self, expr):
"Piecewise function printer"
# Print tuples here instead of lists because numba may add support
# for select in nopython mode; see numba#1313 on github.
exprs = '({0},)'.format(','.join(self._print(arg.expr) for arg in expr.args))
conds = '({0},)'.format(','.join(self._print(arg.cond) for arg in expr.args))
# If (default_value, True) is a (expr, cond) tuple in a Piecewise object
# it will behave the same as passing the 'default' kwarg to select()
# *as long as* it is the last element in expr.args.
# If this is not the case, it may be triggered prematurely.
return 'select({0}, {1})'.format(conds, exprs)

def _print_And(self, expr):
"Logical And printer"
# We have to override LambdaPrinter because it uses Python 'and' keyword.
# If LambdaPrinter didn't define it, we could use StrPrinter's
# version of the function and add 'logical_and' to NUMPY_TRANSLATIONS.
return '{0}({1})'.format('logical_and', ','.join(self._print(i) for i in expr.args))

def _print_Or(self, expr):
"Logical Or printer"
# We have to override LambdaPrinter because it uses Python 'or' keyword.
# If LambdaPrinter didn't define it, we could use StrPrinter's
# version of the function and add 'logical_or' to NUMPY_TRANSLATIONS.
return '{0}({1})'.format('logical_or', ','.join(self._print(i) for i in expr.args))

def _print_Not(self, expr):
"Logical Not printer"
# We have to override LambdaPrinter because it uses Python 'not' keyword.
# If LambdaPrinter didn't define it, we would still have to define our
# own because StrPrinter doesn't define it.
return '{0}({1})'.format('logical_not', ','.join(self._print(i) for i in expr.args))

# numexpr works by altering the string passed to numexpr.evaluate
# rather than by populating a namespace. Thus a special printer...

Expand Down
60 changes: 17 additions & 43 deletions sympy/utilities/lambdify.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from sympy.external import import_module
from sympy.core.compatibility import exec_, is_sequence, iterable, string_types, range
from sympy.utilities.decorator import doctest_depends_on
from sympy.utilities.exceptions import SymPyDeprecationWarning

# These are the namespaces the lambda functions will use.
MATH = {}
Expand Down Expand Up @@ -79,12 +78,16 @@
"E": "e",
"im": "imag",
"ln": "log",
"MutableDenseMatrix": "matrix",
"ImmutableMatrix": "matrix",
"Max": "amax",
"Min": "amin",
"oo": "inf",
"re": "real",
"SparseMatrix": "array",
"ImmutableSparseMatrix": "array",
"Matrix": "array",
"MutableDenseMatrix": "array",
"ImmutableMatrix": "array",
"ImmutableDenseMatrix": "array",
}

NUMEXPR_TRANSLATIONS = {}
Expand Down Expand Up @@ -182,23 +185,19 @@ def lambdify(args, expr, modules=None, printer=None, use_imps=True,
functions can be found at:
https://github.com/pydata/numexpr#supported-functions
Deprecation Warnings
====================
In previous releases ``lambdify`` replaced ``Matrix`` with ``numpy.matrix``
by default. As of release 0.7.6 ``numpy.array`` is being transitioned to
the default. In release 0.7.7 this transition will be complete. For now, to
use the new default behavior you must pass in ``[{'ImmutableMatrix':
numpy.array}, 'numpy']`` to the ``modules`` kwarg.
by default. As of release 0.7.7 ``numpy.array`` is the default.
To get the old default behavior you must pass in ``[{'ImmutableMatrix':
numpy.matrix}, 'numpy']`` to the ``modules`` kwarg.
>>> from sympy import lambdify, Matrix
>>> from sympy.abc import x, y
>>> import numpy
>>> mat2array = [{'ImmutableMatrix': numpy.array}, 'numpy']
>>> f = lambdify((x, y), Matrix([x, y]), modules=mat2array)
>>> array2mat = [{'ImmutableMatrix': numpy.matrix}, 'numpy']
>>> f = lambdify((x, y), Matrix([x, y]), modules=array2mat)
>>> f(1, 2)
array([[1],
[2]])
matrix([[1],
[2]])
Usage
=====
Expand Down Expand Up @@ -333,7 +332,6 @@ def lambdify(args, expr, modules=None, printer=None, use_imps=True,
for m in namespaces[::-1]:
buf = _get_namespace(m)
namespace.update(buf)
_issue_7853_dep_check(namespaces, namespace, expr)

if hasattr(expr, "atoms"):
# Try if you can extract symbols from the expression.
Expand All @@ -342,6 +340,10 @@ def lambdify(args, expr, modules=None, printer=None, use_imps=True,
for term in syms:
namespace.update({str(term): term})

if _module_present('numpy', namespaces) and printer is None:
# XXX: This has to be done here because of circular imports
from sympy.printing.lambdarepr import NumPyPrinter as printer

if _module_present('numexpr', namespaces) and printer is None:
# XXX: This has to be done here because of circular imports
from sympy.printing.lambdarepr import NumExprPrinter as printer
Expand Down Expand Up @@ -383,34 +385,6 @@ def lambdify(args, expr, modules=None, printer=None, use_imps=True,
return func


def _issue_7853_dep_check(namespaces, namespace, expr):
"""Used for checking things passed into modules kwarg for deprecation
issue #7853. This function and the call to it in lambdify should be
deleted once the cycle has ended."""

# If some module changed `ImmutableMatrix` to be something else
mat = namespace.get('ImmutableMatrix', False)
if not mat or 'numpy' not in namespaces or ('%s.%s' % (mat.__module__,
mat.__name__) == 'numpy.matrixlib.defmatrix.matrix'):
return
dicts = [m for m in namespaces if isinstance(m, dict)]

def test(expr):
return hasattr(expr, 'is_Matrix') and expr.is_Matrix

if test(expr) and not [d for d in dicts if 'ImmutableMatrix' in d]:
SymPyDeprecationWarning(
"Currently, `sympy.Matrix` is replaced with `numpy.matrix` if "
"the NumPy package is utilized in lambdify. In future versions "
"of SymPy (> 0.7.6), we will default to replacing "
"`sympy.Matrix` with `numpy.array`. To use the future "
"behavior now, supply the kwarg "
"`modules=[{'ImmutableMatrix': numpy.array}, 'numpy']`. "
"The old behavior can be retained in future versions by "
"supplying `modules=[{'ImmutableMatrix': numpy.matrix}, "
"'numpy']`.", issue=7853).warn()


def _module_present(modname, modlist):
if modname in modlist:
return True
Expand Down
69 changes: 66 additions & 3 deletions sympy/utilities/tests/test_lambdify.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from sympy.utilities.pytest import XFAIL, raises
from sympy import (
symbols, lambdify, sqrt, sin, cos, tan, pi, acos, acosh, Rational,
Float, Matrix, Lambda, exp, Integral, oo, I, Abs, Function, true, false)
Float, Matrix, Lambda, Piecewise, exp, Integral, oo, I, Abs, Function,
true, false, And, Or, Not)
from sympy.printing.lambdarepr import LambdaPrinter
import mpmath
from sympy.utilities.lambdify import implemented_function
Expand Down Expand Up @@ -303,13 +304,75 @@ def test_numpy_matrix():
skip("numpy not installed.")
A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])
# Lambdify array first, to ensure return to matrix as default
f = lambdify((x, y, z), A, [{'ImmutableMatrix': numpy.array}, 'numpy'])
# Lambdify array first, to ensure return to array as default
f = lambdify((x, y, z), A, ['numpy'])
numpy.testing.assert_allclose(f(1, 2, 3), sol_arr)
# Check that the types are arrays and matrices
assert isinstance(f(1, 2, 3), numpy.ndarray)


def test_numpy_transpose():
if not numpy:
skip("numpy not installed.")
A = Matrix([[1, x], [0, 1]])
f = lambdify((x), A.T, modules="numpy")
numpy.testing.assert_array_equal(f(2), numpy.array([[1, 0], [2, 1]]))


def test_numpy_inverse():
if not numpy:
skip("numpy not installed.")
A = Matrix([[1, x], [0, 1]])
f = lambdify((x), A**-1, modules="numpy")
numpy.testing.assert_array_equal(f(2), numpy.array([[1, -2], [0, 1]]))


def test_numpy_old_matrix():
if not numpy:
skip("numpy not installed.")
A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])
f = lambdify((x, y, z), A, [{'ImmutableMatrix': numpy.matrix}, 'numpy'])
numpy.testing.assert_allclose(f(1, 2, 3), sol_arr)
assert isinstance(f(1, 2, 3), numpy.matrix)


def test_numpy_piecewise():
if not numpy:
skip("numpy not installed.")
pieces = Piecewise((x, x < 3), (x**2, x > 5), (0, True))
f = lambdify(x, pieces, modules="numpy")
numpy.testing.assert_array_equal(f(numpy.arange(10)),
numpy.array([0, 1, 2, 0, 0, 0, 36, 49, 64, 81]))


def test_numpy_logical_ops():
if not numpy:
skip("numpy not installed.")
and_func = lambdify((x, y), And(x, y), modules="numpy")
or_func = lambdify((x, y), Or(x, y), modules="numpy")
not_func = lambdify((x), Not(x), modules="numpy")
arr1 = numpy.array([True, True])
arr2 = numpy.array([False, True])
numpy.testing.assert_array_equal(and_func(arr1, arr2), numpy.array([False, True]))
numpy.testing.assert_array_equal(or_func(arr1, arr2), numpy.array([True, True]))
numpy.testing.assert_array_equal(not_func(arr2), numpy.array([True, False]))


def test_numpy_matmul():
if not numpy:
skip("numpy not installed.")
xmat = Matrix([[x, y], [z, 1+z]])
ymat = Matrix([[x**2], [Abs(x)]])
mat_func = lambdify((x, y, z), xmat*ymat, modules="numpy")
numpy.testing.assert_array_equal(mat_func(0.5, 3, 4), numpy.array([[1.625], [3.5]]))
numpy.testing.assert_array_equal(mat_func(-0.5, 3, 4), numpy.array([[1.375], [3.5]]))
# Multiple matrices chained together in multiplication
f = lambdify((x, y, z), xmat*xmat*xmat, modules="numpy")
numpy.testing.assert_array_equal(f(0.5, 3, 4), numpy.array([[72.125, 119.25],
[159, 251]]))


def test_numpy_numexpr():
if not numpy:
skip("numpy not installed.")
Expand Down

0 comments on commit 3e4fbed

Please sign in to comment.