Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions symengine/lib/symengine.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -960,17 +960,17 @@ cdef extern from "<symengine/eval_double.h>" namespace "SymEngine":
cdef extern from "<symengine/lambda_double.h>" namespace "SymEngine":
cdef cppclass LambdaRealDoubleVisitor:
LambdaRealDoubleVisitor() nogil
void init(const vec_basic &x, const vec_basic &b) nogil except +
void init(const vec_basic &x, const vec_basic &b, bool cse) nogil except +
void call(double *r, const double *x) nogil
cdef cppclass LambdaComplexDoubleVisitor:
LambdaComplexDoubleVisitor() nogil
void init(const vec_basic &x, const vec_basic &b) nogil except +
void init(const vec_basic &x, const vec_basic &b, bool cse) nogil except +
void call(double complex *r, const double complex *x) nogil

cdef extern from "<symengine/llvm_double.h>" namespace "SymEngine":
cdef cppclass LLVMDoubleVisitor:
LLVMDoubleVisitor() nogil
void init(const vec_basic &x, const vec_basic &b) nogil except +
void init(const vec_basic &x, const vec_basic &b, bool cse) nogil except +
void call(double *r, const double *x) nogil

cdef extern from "<symengine/series.h>" namespace "SymEngine":
Expand Down
77 changes: 17 additions & 60 deletions symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4369,7 +4369,7 @@ cdef class _Lambdify(object):
cdef vector[int] accum_out_sizes
cdef object numpy_dtype

def __init__(self, args, *exprs, cppbool real=True, order='C'):
def __init__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False):
cdef:
Basic e_
size_t ri, ci, nr, nc
Expand Down Expand Up @@ -4409,9 +4409,9 @@ cdef class _Lambdify(object):
for e in np.ravel(curr_expr, order=self.order):
e_ = _sympify(e)
outs_.push_back(e_.thisptr)
self._init(args_, outs_)
self._init(args_, outs_, cse)

cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_):
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
raise ValueError("Not supported")

cpdef unsafe_real(self,
Expand Down Expand Up @@ -4590,13 +4590,13 @@ cdef class LambdaDouble(_Lambdify):
cdef vector[symengine.LambdaRealDoubleVisitor] lambda_double
cdef vector[symengine.LambdaComplexDoubleVisitor] lambda_double_complex

cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_):
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
if self.real:
self.lambda_double.resize(1)
self.lambda_double[0].init(args_, outs_)
self.lambda_double[0].init(args_, outs_, cse)
else:
self.lambda_double_complex.resize(1)
self.lambda_double_complex[0].init(args_, outs_)
self.lambda_double_complex[0].init(args_, outs_, cse)

cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
Expand All @@ -4621,9 +4621,9 @@ IF HAVE_SYMENGINE_LLVM:

cdef vector[symengine.LLVMDoubleVisitor] lambda_double

cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_):
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
self.lambda_double.resize(1)
self.lambda_double[0].init(args_, outs_)
self.lambda_double[0].init(args_, outs_, cse)

cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
Expand All @@ -4640,7 +4640,7 @@ IF HAVE_SYMENGINE_LLVM:
return create_low_level_callable(self, addr1, addr2)


def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=False):
def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=False, cse=False):
"""
Lambdify instances are callbacks that numerically evaluate their symbolic
expressions from user provided input (real or complex) into (possibly user
Expand All @@ -4666,6 +4666,9 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
as_scipy : bool
return a SciPy LowLevelCallable which can be used in SciPy's integrate
methods
cse : bool
Run Common Subexpression Elimination on the output before generating
the callback.

Returns
-------
Expand All @@ -4687,7 +4690,7 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
backend = os.getenv('SYMENGINE_LAMBDIFY_BACKEND', "lambda")
if backend == "llvm":
IF HAVE_SYMENGINE_LLVM:
ret = LLVMDouble(args, *exprs, real=real, order=order)
ret = LLVMDouble(args, *exprs, real=real, order=order, cse=cse)
if as_scipy:
return ret.as_scipy_low_level_callable()
return ret
Expand All @@ -4698,63 +4701,17 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
pass
else:
warnings.warn("Unknown SymEngine backend: %s\nUsing backend='lambda'" % backend)
ret = LambdaDouble(args, *exprs, real=real, order=order)
ret = LambdaDouble(args, *exprs, real=real, order=order, cse=cse)
if as_scipy:
return ret.as_scipy_low_level_callable()
return ret


def LambdifyCSE(args, *exprs, cse=None, order='C', **kwargs):
def LambdifyCSE(args, *exprs, order='C', **kwargs):
""" Analogous with Lambdify but performs common subexpression elimination.

See docstring of Lambdify.

Parameters
----------
args: iterable of symbols
exprs: iterable of expressions (with symbols from args)
cse: callback (default: None)
defaults to sympy.cse (see SymPy documentation)
order : str
order (passed to numpy.ravel and numpy.reshape).
\\*\\*kwargs: Keyword arguments passed onto Lambdify

"""
if cse is None:
from sympy import cse
_exprs = [np.asanyarray(e) for e in exprs]
_args = np.ravel(args, order=order)
from sympy import sympify as s_sympify
flat_exprs = list(itertools.chain(*[np.ravel(e, order=order) for e in _exprs]))
subs, flat_new_exprs = cse([s_sympify(expr) for expr in flat_exprs])

if subs:
explicit_subs = {}
for k, v in subs:
explicit_subs[k] = v.xreplace(explicit_subs)

cse_symbs, cse_exprs = zip(*subs)
new_exprs = []
n_taken = 0
for expr in _exprs:
new_exprs.append(np.reshape(flat_new_exprs[n_taken:n_taken+expr.size],
expr.shape, order=order))
n_taken += expr.size
new_lmb = Lambdify(tuple(_args) + cse_symbs, *new_exprs, order=order, **kwargs)
cse_lambda = Lambdify(_args, [ce.xreplace(explicit_subs) for ce in cse_exprs], **kwargs)
def cb(inp, *, out=None, **kw):
_inp = np.asanyarray(inp)
cse_vals = cse_lambda(_inp, **kw)
if order == 'C':
new_inp = np.concatenate((_inp[(Ellipsis,) + (np.newaxis,)*(cse_vals.ndim - _inp.ndim)],
cse_vals), axis=-1)
else:
new_inp = np.concatenate((_inp[(np.newaxis,)*(cse_vals.ndim - _inp.ndim) + (Ellipsis,)],
cse_vals), axis=0)
return new_lmb(new_inp, out=out, **kw)
return cb
else:
return Lambdify(args, *exprs, **kwargs)
warnings.warn("LambdifyCSE is deprecated. Use Lambdify(..., cse=True)", DeprecationWarning)
return Lambdify(args, *exprs, cse=True, order=order, **kwargs)


def ccode(expr):
Expand Down
24 changes: 6 additions & 18 deletions symengine/tests/test_lambdify.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,32 +193,28 @@ def test_broadcast_multiple_extra_dimensions():
assert abs(out[-1, -1, 1] - 11**3) < 1e-14


@unittest.skipUnless(have_sympy, "SymPy not installed")
def _get_cse_exprs():
import sympy as sp
args = x, y = sp.symbols('x y')
args = x, y = se.symbols('x y')
exprs = [x*x + y, y/(x*x), y*x*x+x]
inp = [11, 13]
ref = [121+13, 13/121, 13*121 + 11]
return args, exprs, inp, ref


@unittest.skipUnless(have_numpy, "Numpy not installed")
@unittest.skipUnless(have_sympy, "SymPy not installed")
def test_cse():
args, exprs, inp, ref = _get_cse_exprs()
lmb = se.LambdifyCSE(args, exprs)
lmb = se.Lambdify(args, exprs, cse=True)
out = lmb(inp)
assert allclose(out, ref)


@unittest.skipUnless(have_numpy, "Numpy not installed")
@unittest.skipUnless(have_sympy, "SymPy not installed")
def test_cse_gh174():
x = se.symbols('x')
funcs = [se.cos(x)**i for i in range(5)]
f_lmb = se.Lambdify([x], funcs)
f_cse = se.LambdifyCSE([x], funcs)
f_cse = se.Lambdify([x], funcs, cse=True)
a = np.array([1, 2, 3])
assert np.allclose(f_lmb(a), f_cse(a))

Expand Down Expand Up @@ -250,10 +246,9 @@ def _get_cse_exprs_big():


@unittest.skipUnless(have_numpy, "Numpy not installed")
@unittest.skipUnless(have_sympy, "SymPy not installed")
def test_cse_big():
args, exprs, inp = _get_cse_exprs_big()
lmb = se.LambdifyCSE(args, exprs)
lmb = se.Lambdify(args, exprs, cse=True)
out = lmb(inp)
ref = [expr.xreplace(dict(zip(args, inp))) for expr in exprs]
assert allclose(out, ref)
Expand Down Expand Up @@ -526,12 +521,6 @@ def test_Lambdify_heterogeneous_output():
_Lambdify_heterogeneous_output(se.Lambdify)


@unittest.skipUnless(have_numpy, "Numpy not installed")
@unittest.skipUnless(have_sympy, "SymPy not installed")
def test_LambdifyCSE_heterogeneous_output():
_Lambdify_heterogeneous_output(se.LambdifyCSE)


def _sympy_lambdify_heterogeneous_output(cb, Mtx):
x, y = se.symbols('x, y')
args = Mtx(2, 1, [x, y])
Expand Down Expand Up @@ -600,11 +589,10 @@ def test_Lambdify_scalar_vector_matrix():
_test_Lambdify_scalar_vector_matrix(lambda *args: se.Lambdify(*args, backend='llvm'))


@unittest.skipUnless(have_sympy, "SymPy not installed")
def test_Lambdify_scalar_vector_matrix_cse():
_test_Lambdify_scalar_vector_matrix(lambda *args: se.LambdifyCSE(*args, backend='lambda'))
_test_Lambdify_scalar_vector_matrix(lambda *args: se.Lambdify(*args, backend='lambda', cse=True))
if se.have_llvm:
_test_Lambdify_scalar_vector_matrix(lambda *args: se.LambdifyCSE(*args, backend='llvm'))
_test_Lambdify_scalar_vector_matrix(lambda *args: se.Lambdify(*args, backend='llvm', cse=True))


@unittest.skipUnless(have_numpy, "Numpy not installed")
Expand Down
2 changes: 1 addition & 1 deletion symengine_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
f6eac46122d5da4519bc612d4c5203cbb3aa46b0
2f5ff9db9ff511ee243438a85ea8e2da2d05af39