Skip to content

Commit

Permalink
Merge pull request #300 from ichumuh/master
Browse files Browse the repository at this point in the history
exposed llvm lambdify opt_level in python api
  • Loading branch information
bjodah committed Sep 20, 2019
2 parents ef36dc2 + 602563b commit 7dd26e0
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 7 deletions.
2 changes: 1 addition & 1 deletion symengine/lib/symengine.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,7 @@ cdef extern from "<symengine/lambda_double.h>" namespace "SymEngine":
cdef extern from "<symengine/llvm_double.h>" namespace "SymEngine":
cdef cppclass LLVMDoubleVisitor:
LLVMDoubleVisitor() nogil
void init(const vec_basic &x, const vec_basic &b, bool cse) nogil except +
void init(const vec_basic &x, const vec_basic &b, bool cse, int opt_level) nogil except +
void call(double *r, const double *x) nogil
const string& dumps() nogil
void loads(const string&) nogil
Expand Down
1 change: 1 addition & 0 deletions symengine/lib/symengine_wrapper.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ cdef class LambdaDouble(_Lambdify):

IF HAVE_SYMENGINE_LLVM:
cdef class LLVMDouble(_Lambdify):
cdef int opt_level
cdef vector[symengine.LLVMDoubleVisitor] lambda_double
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
cdef _load(self, const string &s)
Expand Down
17 changes: 12 additions & 5 deletions symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4469,7 +4469,7 @@ def has_symbol(obj, symbol=None):


cdef class _Lambdify(object):
def __init__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False):
def __init__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, **kwargs):
cdef:
Basic e_
size_t ri, ci, nr, nc
Expand Down Expand Up @@ -4706,6 +4706,10 @@ def create_low_level_callable(lambdify, *args):


cdef class LambdaDouble(_Lambdify):
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False):
# reject additional arguments
pass

cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
if self.real:
self.lambda_double.resize(1)
Expand Down Expand Up @@ -4751,9 +4755,12 @@ cdef class LambdaDouble(_Lambdify):

IF HAVE_SYMENGINE_LLVM:
cdef class LLVMDouble(_Lambdify):
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, opt_level=3):
self.opt_level = opt_level

cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
self.lambda_double.resize(1)
self.lambda_double[0].init(args_, outs_, cse)
self.lambda_double[0].init(args_, outs_, cse, self.opt_level)

cdef _load(self, const string &s):
self.lambda_double.resize(1)
Expand Down Expand Up @@ -4801,7 +4808,7 @@ IF HAVE_SYMENGINE_LLVM:
def llvm_loading_func(*args):
return LLVMDouble(args, _load=True)

def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=False, cse=False):
def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=False, cse=False, **kwargs):
"""
Lambdify instances are callbacks that numerically evaluate their symbolic
expressions from user provided input (real or complex) into (possibly user
Expand Down Expand Up @@ -4851,7 +4858,7 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
backend = os.getenv('SYMENGINE_LAMBDIFY_BACKEND', "lambda")
if backend == "llvm":
IF HAVE_SYMENGINE_LLVM:
ret = LLVMDouble(args, *exprs, real=real, order=order, cse=cse)
ret = LLVMDouble(args, *exprs, real=real, order=order, cse=cse, **kwargs)
if as_scipy:
return ret.as_scipy_low_level_callable()
return ret
Expand All @@ -4862,7 +4869,7 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
pass
else:
warnings.warn("Unknown SymEngine backend: %s\nUsing backend='lambda'" % backend)
ret = LambdaDouble(args, *exprs, real=real, order=order, cse=cse)
ret = LambdaDouble(args, *exprs, real=real, order=order, cse=cse, **kwargs)
if as_scipy:
return ret.as_scipy_low_level_callable()
return ret
Expand Down
18 changes: 17 additions & 1 deletion symengine/tests/test_lambdify.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def test_Lambdify():
assert allclose(L(range(n, n+len(args))),
[3*n+3, n**2, -1/(n+2), n*(n+1)*(n+2)])

@unittest.skipUnless(have_numpy, "Numpy not installed")
def test_Lambdify_with_opt_level():
args = x, y, z = se.symbols('x y z')
raises(TypeError, lambda: se.Lambdify(args, [x+y+z, x**2, (x-y)/z, x*y*z], backend='lambda', opt_level=0))

def _test_Lambdify_Piecewise(Lambdify):
x = se.symbols('x')
Expand All @@ -91,7 +95,6 @@ def test_Lambdify_Piecewise():
if se.have_llvm:
_test_Lambdify_Piecewise(lambda *args: se.Lambdify(*args, backend='llvm'))


@unittest.skipUnless(have_numpy, "Numpy not installed")
def test_Lambdify_LLVM():
n = 7
Expand All @@ -105,6 +108,19 @@ def test_Lambdify_LLVM():
assert allclose(L(range(n, n+len(args))),
[3*n+3, n**2, -1/(n+2), n*(n+1)*(n+2)])

@unittest.skipUnless(have_numpy, "Numpy not installed")
def test_Lambdify_LLVM_with_opt_level():
for opt_level in range(4):
n = 7
args = x, y, z = se.symbols('x y z')
if not se.have_llvm:
raises(ValueError, lambda: se.Lambdify(args, [x+y+z, x**2,
(x-y)/z, x*y*z],
backend='llvm', opt_level=opt_level))
raise SkipTest("No LLVM support")
L = se.Lambdify(args, [x+y+z, x**2, (x-y)/z, x*y*z], backend='llvm', opt_level=opt_level)
assert allclose(L(range(n, n+len(args))),
[3*n+3, n**2, -1/(n+2), n*(n+1)*(n+2)])

def _get_2_to_2by2():
args = x, y = se.symbols('x y')
Expand Down

0 comments on commit 7dd26e0

Please sign in to comment.