complex argument support for erfc #1
Comments
Wrapping complex functions from scipy.special

There are basically two issues that make it less than straightforward to wrap functions with complex dtypes.

Cython wrapper

cimport scipy.special.cython_special
cimport numpy as np

cdef api erfc(double* out_real, double* out_imag, double in_real, double in_imag):
    # Rebuild the complex argument from its real and imaginary parts
    cdef double complex input
    input.real = in_real
    input.imag = in_imag
    cdef double complex output = scipy.special.cython_special.erfc(input)
    # Hand the result back through the two output pointers
    out_real[0] = output.real
    out_imag[0] = output.imag

Intrinsic to overcome the passing scalar by reference problem

from numba import types
from numba.extending import intrinsic
from numba.core import cgutils

@intrinsic
def val_to_double_ptr(typingctx, data):
    def impl(context, builder, signature, args):
        # Allocate a stack slot holding the value and return a pointer to it
        ptr = cgutils.alloca_once_value(builder, args[0])
        return ptr
    sig = types.CPointer(types.float64)(types.float64)
    return sig, impl

@intrinsic
def double_ptr_to_val(typingctx, data):
    def impl(context, builder, signature, args):
        # Load the float64 that the pointer refers to
        val = builder.load(args[0])
        return val
    sig = types.float64(types.CPointer(types.float64))
    return sig, impl

Numba wrapper

import ctypes
import numpy as np
import numba as nb
from numba.extending import get_cython_function_address

double = ctypes.c_double
double_p = ctypes.POINTER(double)

# Look up the compiled Cython function and give it a ctypes signature
addr = get_cython_function_address("special", "erfc")
functype = ctypes.CFUNCTYPE(None, double_p, double_p, double, double)
erfc_fn_complex = functype(addr)

# Fast version using intrinsic functions
# Note: np.complex was removed in NumPy 1.24; the builtin complex(...) replaces it
@nb.njit("complex128(complex128)")
def nb_erc_complex(val):
    out_real_p = val_to_double_ptr(0.)
    out_imag_p = val_to_double_ptr(0.)
    erfc_fn_complex(out_real_p, out_imag_p, np.real(val), np.imag(val))
    out_real = double_ptr_to_val(out_real_p)
    out_imag = double_ptr_to_val(out_imag_p)
    return np.complex(out_real + 1.j * out_imag)

# Slow version using arrays with one element
@nb.njit("complex128(complex128)")
def nb_erc_complex_2(val):
    out_real = np.empty(1, dtype=np.float64)
    out_imag = np.empty(1, dtype=np.float64)
    erfc_fn_complex(out_real.ctypes, out_imag.ctypes, np.real(val), np.imag(val))
    return np.complex(out_real[0] + 1.j * out_imag[0])

Example and Timings

from scipy import special
val = np.random.rand(10_000) + 1j * np.random.rand(10_000)

@nb.njit()
def numba_Test(val):
    out = np.empty_like(val)
    for i in nb.prange(out.shape[0]):
        out[i] = nb_erc_complex(val[i])
    return out

print(np.allclose(special.erfc(val), numba_Test(val)))
#True

%timeit special.erfc(val)
#1.53 ms ± 5.22 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit numba_Test(val)
#with nb_erc_complex using intrinsic
#1.53 ms ± 7.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
#with nb_erc_complex_2
#3.12 ms ± 36.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
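The out-pointer calling convention that the Cython wrapper relies on can be tried in plain ctypes, without compiling anything. The sketch below is only an illustration: `_square_py` is a hypothetical stand-in for the compiled `erfc` (it squares its argument, since the standard library has no complex `erfc`), and `call_with_out_params` is a helper name introduced here. Only the function signature matches the wrapper above.

```python
import ctypes

double = ctypes.c_double
double_p = ctypes.POINTER(double)

# Same C signature as the Cython wrapper:
# void f(double* out_real, double* out_imag, double in_real, double in_imag)
functype = ctypes.CFUNCTYPE(None, double_p, double_p, double, double)


def _square_py(out_real, out_imag, in_real, in_imag):
    # Hypothetical stand-in for erfc: writes z*z through the two pointers
    z = complex(in_real, in_imag)
    w = z * z
    out_real[0] = w.real
    out_imag[0] = w.imag


# Wrap the Python function as a C callback; keep this reference alive
square_fn = functype(_square_py)


def call_with_out_params(fn, z):
    # Allocate two doubles and pass them by reference as output slots
    out_real = double(0.0)
    out_imag = double(0.0)
    fn(ctypes.byref(out_real), ctypes.byref(out_imag), z.real, z.imag)
    return complex(out_real.value, out_imag.value)


result = call_with_out_params(square_fn, 1 + 2j)
print(result)
```

Inside an `njit` function there is no `ctypes.byref`, which is why the comment above uses either the two intrinsics or one-element arrays to obtain the pointers.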
I'm trying to use the fast version above, but I get the following error:

Where are the needed cgutils in the present Numba release?
I updated the comment; cgutils is now in numba.core.cgutils.
Thanks for the prompt answer. Trying to figure out how to use your "fast version", I assumed that your Cython wrapper should be treated as Cython code, so I went to the Cython website to find out how to deal with that. I created a file special.pyx containing the Cython wrapper code and a Python setup file containing:

from setuptools import setup

When running python3 setup.py build_ext --inplace I get the error "special.c:1097:10: fatal error: 'numpy/arrayobject.h' file not found".

In special.c I found include statements for several header files, which reside in the site-packages/numpy/core/include/numpy directory. In what directory should I run the build? If I run the build in the include directory, I suppose the files will be created there. Would it be fine to then move them to my working directory, or should I replace --inplace with something else?

Edit: I tried to run the build in the include directory, but then "numpy/npy_math.h" is missing:

In file included from special.c:1107:
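A common way to resolve a missing 'numpy/arrayobject.h' is to pass NumPy's include path to the extension instead of moving the build into NumPy's include directory. The following is a minimal sketch of such a setup file, not the setup file used in this thread: it assumes the wrapper is saved as special.pyx next to setup.py and that Cython, NumPy and SciPy are installed. The module name "special" is chosen to match the get_cython_function_address("special", "erfc") call in the Numba wrapper.

```python
# setup.py (sketch): assumes special.pyx sits in the same directory
import numpy as np
from Cython.Build import cythonize
from setuptools import Extension, setup

ext = Extension(
    "special",                        # import name of the compiled module
    sources=["special.pyx"],
    include_dirs=[np.get_include()],  # lets the compiler find numpy/arrayobject.h
)

setup(ext_modules=cythonize([ext]))
```

With this in place, python3 setup.py build_ext --inplace can be run from the working directory that contains special.pyx, and the compiled module is written there.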
I solved the problem by adding a link, in Python's header directory, to the directory with NumPy's header files. I then ran into problems related to complex types, which I fixed by changing the code in the "Numba wrapper": the return statements had an "np.complex(...)" call which had to be changed to "complex(...)". Now the code works fine.
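For context, the np.complex alias was deprecated in NumPy 1.20 and removed in NumPy 1.24, and the builtin complex is its replacement. A minimal sketch of the change in plain Python, where the helper name to_complex is hypothetical; the same complex(real, imag) call is what would go into the return statements of the Numba wrapper, on the assumption that the installed Numba version supports the builtin in nopython mode:

```python
def to_complex(out_real, out_imag):
    # Builtin constructor; replaces np.complex(out_real + 1.j * out_imag)
    return complex(out_real, out_imag)


z = to_complex(0.5, -0.25)
print(z)
```

Building the value from its two parts also avoids the intermediate multiplication by 1.j.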
I need a Numba-compiled version of erfc for complex arguments, but it appears that this is not supported. How difficult would it be to implement this feature? If you don't have time to implement it, maybe you can point me to where to start working on it, or to an alternative to a Numba-compiled version for performant computation of erfc for complex arguments? Thank you.