Skip to content

as_ctypes function for Lambdify #292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions symengine/lib/symengine_wrapper.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,9 @@ cdef class _Lambdify(object):

cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
cdef _load(self, const string &s)
cdef void unsafe_real_ptr(self, double *inp, double *out) nogil
cpdef unsafe_real(self,
double[::1] inp, double[::1] out,
int inp_offset=*, int out_offset=*)
cdef void unsafe_complex_ptr(self, double complex *inp, double complex *out) nogil
cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out,
int inp_offset=*, int out_offset=*)
cpdef eval_real(self, inp, out)
Expand All @@ -53,17 +51,16 @@ cdef class LambdaDouble(_Lambdify):
cdef vector[symengine.LambdaRealDoubleVisitor] lambda_double
cdef vector[symengine.LambdaComplexDoubleVisitor] lambda_double_complex
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
cdef void unsafe_real_ptr(self, double *inp, double *out) nogil
cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=*, int out_offset=*)
cdef void unsafe_complex_ptr(self, double complex *inp, double complex *out) nogil
cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out, int inp_offset=*, int out_offset=*)
cpdef as_scipy_low_level_callable(self)
cpdef as_ctypes(self)

IF HAVE_SYMENGINE_LLVM:
cdef class LLVMDouble(_Lambdify):
cdef vector[symengine.LLVMDoubleVisitor] lambda_double
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
cdef _load(self, const string &s)
cdef void unsafe_real_ptr(self, double *inp, double *out) nogil
cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=*, int out_offset=*)
cpdef as_scipy_low_level_callable(self)
cpdef as_scipy_low_level_callable(self)
cpdef as_ctypes(self)
66 changes: 45 additions & 21 deletions symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4507,19 +4507,11 @@ cdef class _Lambdify(object):
cdef _load(self, const string &s):
raise ValueError("Not supported")

cdef void unsafe_real_ptr(self, double *inp, double *out) nogil:
with gil:
raise ValueError("Not supported")

cpdef unsafe_real(self,
double[::1] inp, double[::1] out,
int inp_offset=0, int out_offset=0):
raise ValueError("Not supported")

cdef void unsafe_complex_ptr(self, double complex *inp, double complex *out) nogil:
with gil:
raise ValueError("Not supported")

cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out,
int inp_offset=0, int out_offset=0):
raise ValueError("Not supported")
Expand Down Expand Up @@ -4670,6 +4662,9 @@ cdef double _scipy_callback_lambda_real(int n, double *x, void *user_data) nogil
deref(lamb).call(&result, x)
return result

cdef void _ctypes_callback_lambda_real(double *output, const double *input, void *user_data) nogil:
cdef symengine.LambdaRealDoubleVisitor* lamb = <symengine.LambdaRealDoubleVisitor *>user_data
deref(lamb).call(output, input)

IF HAVE_SYMENGINE_LLVM:
cdef double _scipy_callback_llvm_real(int n, double *x, void *user_data) nogil:
Expand All @@ -4678,6 +4673,10 @@ IF HAVE_SYMENGINE_LLVM:
deref(lamb).call(&result, x)
return result

cdef void _ctypes_callback_llvm_real(double *output, const double *input, void *user_data) nogil:
cdef symengine.LLVMDoubleVisitor* lamb = <symengine.LLVMDoubleVisitor *>user_data
deref(lamb).call(output, input)


def create_low_level_callable(lambdify, *args):
from scipy import LowLevelCallable
Expand All @@ -4698,17 +4697,11 @@ cdef class LambdaDouble(_Lambdify):
self.lambda_double_complex.resize(1)
self.lambda_double_complex[0].init(args_, outs_, cse)

cdef void unsafe_real_ptr(self, double *inp, double *out) nogil:
self.lambda_double[0].call(out, inp)

cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
self.unsafe_real_ptr(&inp[inp_offset], &out[out_offset])

cdef void unsafe_complex_ptr(self, double complex *inp, double complex *out) nogil:
self.lambda_double_complex[0].call(out, inp)
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])

cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out, int inp_offset=0, int out_offset=0):
self.unsafe_complex_ptr(&inp[inp_offset], &out[out_offset])
self.lambda_double_complex[0].call(&out[out_offset], &inp[inp_offset])

cpdef as_scipy_low_level_callable(self):
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
Expand All @@ -4721,6 +4714,23 @@ cdef class LambdaDouble(_Lambdify):
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
return create_low_level_callable(self, addr1, addr2)

cpdef as_ctypes(self):
"""
Returns a tuple with first element being a ctypes function with signature

void func(double * output, const double *input, void *user_data)

and second element being a ctypes void pointer. This void pointer needs to be
passed as input to the function as the third argument `user_data`.
"""
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
if not self.real:
raise RuntimeError("Lambda function has to be real")
addr1 = cast(<size_t>&_ctypes_callback_lambda_real,
CFUNCTYPE(c_void_p, POINTER(c_double), POINTER(c_double), c_void_p))
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
return addr1, addr2


IF HAVE_SYMENGINE_LLVM:
cdef class LLVMDouble(_Lambdify):
Expand All @@ -4740,23 +4750,37 @@ IF HAVE_SYMENGINE_LLVM:
return llvm_loading_func, (self.args_size, self.tot_out_size, self.out_shapes, self.real, \
self.n_exprs, self.order, self.accum_out_sizes, self.numpy_dtype, s)

cdef void unsafe_real_ptr(self, double *inp, double *out) nogil:
self.lambda_double[0].call(out, inp)

cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
self.unsafe_real_ptr(&inp[inp_offset], &out[out_offset])
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])

cpdef as_scipy_low_level_callable(self):
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
if not self.real:
raise RuntimeError("Lambda function has to be real")
if self.tot_out_size > 1:
raise RuntimeError("SciPy LowLevelCallable supports only functions with 1 output")
addr1 = cast(<size_t>&_scipy_callback_lambda_real,
addr1 = cast(<size_t>&_scipy_callback_llvm_real,
CFUNCTYPE(c_double, c_int, POINTER(c_double), c_void_p))
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
return create_low_level_callable(self, addr1, addr2)

cpdef as_ctypes(self):
"""
Returns a tuple with first element being a ctypes function with signature

void func(double * output, const double *input, void *user_data)

and second element being a ctypes void pointer. This void pointer needs to be
passed as input to the function as the third argument `user_data`.
"""
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
if not self.real:
raise RuntimeError("Lambda function has to be real")
addr1 = cast(<size_t>&_ctypes_callback_llvm_real,
CFUNCTYPE(c_void_p, POINTER(c_double), POINTER(c_double), c_void_p))
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
return addr1, addr2

def llvm_loading_func(*args):
return LLVMDouble(args, _load=True)

Expand Down
13 changes: 13 additions & 0 deletions symengine/tests/test_lambdify.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,3 +806,16 @@ def test_scipy():
lmb = se.Lambdify(args, [se.exp(-x*t)/t**5], as_scipy=True)
res = integrate.nquad(lmb, [[1, np.inf], [0, np.inf]])
assert abs(res[0] - 0.2) < 1e-7


@unittest.skipUnless(have_numpy, "Numpy not installed")
def test_as_ctypes():
import numpy as np
import ctypes
x, y, z = se.symbols('x, y, z')
l = se.Lambdify([x, y, z], [x+y+z, x*y*z+1])
addr1, addr2 = l.as_ctypes()
inp = np.array([1,2,3], dtype=np.double)
out = np.array([0, 0], dtype=np.double)
addr1(out.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), inp.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), addr2)
assert np.all(out == [6, 7])