complex argument support for erfc #1

Open
tcrensink opened this issue Nov 30, 2019 · 5 comments

tcrensink commented Nov 30, 2019

I need a Numba-compiled version of erfc for complex arguments, but it appears that this is not supported. How difficult would it be to implement this feature? If you don't have time to implement it, could you point me to where to start working on it, and/or to an alternative to a Numba-compiled version for fast computation of erfc with a complex argument? Thank you.

tcrensink changed the title from "erfc does not support complex argument" to "complex argument support for erfc" on Nov 30, 2019

max9111 commented Dec 22, 2019

Wrapping complex functions from scipy.special

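There are basically two reasons why wrapping functions with complex dtypes isn't straightforward. First, the complex argument and result can't be passed directly through the ctypes interface used here, so a small Cython wrapper splits them into real and imaginary doubles. Second, that wrapper returns its results through pointers, and passing a scalar by reference from Numba-compiled code needs some extra work.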
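The calling convention behind this workaround can be tried out without Numba or Cython. Below is a stdlib-only sketch: a ctypes callback with the same `void (double*, double*, double, double)` signature stands in for the compiled Cython function. The stand-in `square_split` and the helper `call_split` are hypothetical names, and the stand-in computes z*z rather than erfc(z), only so the result is easy to check. `ctypes.byref` supplies the output pointers, which is the part that needs extra work inside Numba.

```python
import ctypes

c_double_p = ctypes.POINTER(ctypes.c_double)
# Same C signature as the Cython wrapper: void f(double*, double*, double, double)
functype = ctypes.CFUNCTYPE(None, c_double_p, c_double_p, ctypes.c_double, ctypes.c_double)


@functype
def square_split(out_real, out_imag, in_real, in_imag):
    # Hypothetical stand-in for the compiled erfc wrapper: computes z*z.
    z = complex(in_real, in_imag)
    w = z * z
    # Results are written through the two output pointers, as in the wrapper.
    out_real[0] = w.real
    out_imag[0] = w.imag


def call_split(fn, z):
    """Call a split-complex C function and reassemble the complex result."""
    out_real = ctypes.c_double()
    out_imag = ctypes.c_double()
    fn(ctypes.byref(out_real), ctypes.byref(out_imag), z.real, z.imag)
    return complex(out_real.value, out_imag.value)


print(call_split(square_split, 1 + 2j))  # (-3+4j)
```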

Cython wrapper

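```cython
# special.pyx: the module name must match the first argument of
# get_cython_function_address in the Numba wrapper.
cimport scipy.special.cython_special
cimport numpy as np  # not used below; cimporting numpy means the build needs NumPy's headers

# `api` exports the function through the module's __pyx_capi__.
# The explicit `void` matters: without a return type, a cdef function returns a Python object.
cdef api void erfc(double* out_real, double* out_imag, double in_real, double in_imag):
    cdef double complex input
    input.real = in_real
    input.imag = in_imag

    cdef double complex output = scipy.special.cython_special.erfc(input)

    # Return the result through the two output pointers
    out_real[0] = output.real
    out_imag[0] = output.imag
```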

Intrinsics to work around the problem of passing scalars by reference

```python
from numba import types
from numba.extending import intrinsic
from numba.core import cgutils

@intrinsic
def val_to_double_ptr(typingctx, data):
    # Store a float64 in a stack slot (alloca) and return a pointer to it,
    # so it can be passed to a C function that expects a double*.
    def impl(context, builder, signature, args):
        ptr = cgutils.alloca_once_value(builder, args[0])
        return ptr
    sig = types.CPointer(types.float64)(types.float64)
    return sig, impl

@intrinsic
def double_ptr_to_val(typingctx, data):
    # Load the float64 a pointer points to (e.g. after the C call wrote to it).
    def impl(context, builder, signature, args):
        val = builder.load(args[0])
        return val
    sig = types.float64(types.CPointer(types.float64))
    return sig, impl
```

Numba wrapper

```python
import ctypes
import numpy as np
import numba as nb
from numba.extending import get_cython_function_address

double = ctypes.c_double
double_p = ctypes.POINTER(double)

# "special" is the name of the compiled Cython module (special.pyx)
addr = get_cython_function_address("special", "erfc")
# ctypes signature: returns nothing, takes (double*, double*, double, double)
functype = ctypes.CFUNCTYPE(None, double_p, double_p, double, double)
erfc_fn_complex = functype(addr)

# Fast version using the intrinsics (stack-allocated output slots)
@nb.njit("complex128(complex128)")
def nb_erc_complex(val):
    out_real_p = val_to_double_ptr(0.)
    out_imag_p = val_to_double_ptr(0.)

    erfc_fn_complex(out_real_p, out_imag_p, np.real(val), np.imag(val))

    out_real = double_ptr_to_val(out_real_p)
    out_imag = double_ptr_to_val(out_imag_p)
    # np.complex was removed in NumPy 1.24; use the builtin complex instead
    return complex(out_real, out_imag)

# Slow version using one-element arrays as output buffers
@nb.njit("complex128(complex128)")
def nb_erc_complex_2(val):
    out_real = np.empty(1, dtype=np.float64)
    out_imag = np.empty(1, dtype=np.float64)

    erfc_fn_complex(out_real.ctypes, out_imag.ctypes, np.real(val), np.imag(val))

    return complex(out_real[0], out_imag[0])
```

Example and Timings

```python
# IPython session (%timeit is an IPython magic)
from scipy import special

val = np.random.rand(10_000) + 1j * np.random.rand(10_000)

@nb.njit()
def numba_Test(val):
    out = np.empty_like(val)
    # prange behaves like range here, since parallel=True is not set
    for i in nb.prange(out.shape[0]):
        out[i] = nb_erc_complex(val[i])
    return out

print(np.allclose(special.erfc(val), numba_Test(val)))
# True
%timeit special.erfc(val)
# 1.53 ms ± 5.22 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit numba_Test(val)
# with nb_erc_complex using intrinsic
# 1.53 ms ± 7.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# with nb_erc_complex_2
# 3.12 ms ± 36.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```


perarve commented Aug 7, 2023

I'm trying to use the fast version above, but I get the following error:
`ImportError: cannot import name 'cgutils' from 'numba' (/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/numba/__init__.py)`

Where is cgutils in the current Numba release?


max9111 commented Aug 7, 2023


I updated the comment above: cgutils now lives in numba.core.cgutils, so it is imported with `from numba.core import cgutils`.



perarve commented Aug 7, 2023

Thanks for the prompt answer.

Trying to figure out how to use your "fast version", I assumed that your Cython wrapper should be treated as Cython code, so I went to the Cython website to find out how to build it. I created a file special.pyx containing the Cython wrapper code and a Python setup file containing:

```python
from setuptools import setup
from Cython.Build import cythonize

setup(
    name='Complex valued erfc for numba',
    ext_modules=cythonize("special.pyx"),
)
```

I get an error "special.c:1097:10: fatal error: 'numpy/arrayobject.h' file not found" when running

python3 setup.py build_ext --inplace

special.c contains include statements for several header files, which I found in the site-packages/numpy/core/include/numpy directory.

In what directory should I run the build? If I run it in the include directory, I suppose the generated files will be created there. Would it be fine to move them to my working directory afterwards, or should I replace --inplace with something else?

Edit: I tried running the build in the include directory, but then it cannot find "numpy/npy_math.h".
Edit 2: That file is actually in the include/numpy directory, so it is odd that it can't be found. The error report contained:

```
In file included from special.c:1107:
./numpy/ufuncobject.h:4:10: fatal error: 'numpy/npy_math.h' file not found
#include <numpy/npy_math.h>
```


perarve commented Aug 18, 2023

I solved the problem by adding a link, in Python's header directory, to the directory containing NumPy's header files.
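An alternative to linking header directories is to give the extension NumPy's include directory, which `numpy.get_include()` returns. A sketch of a `setup.py` along those lines, assuming the wrapper is saved as `special.pyx` next to it; running `python3 setup.py build_ext --inplace` from that same directory should place the compiled `special` module there as well:

```python
# setup.py: build special.pyx against NumPy's C headers
import numpy as np
from Cython.Build import cythonize
from setuptools import Extension, setup

extensions = [
    Extension(
        "special",  # module name later passed to get_cython_function_address
        ["special.pyx"],
        # Directory that contains numpy/arrayobject.h, numpy/npy_math.h, ...
        include_dirs=[np.get_include()],
    )
]

setup(
    name="Complex valued erfc for numba",
    ext_modules=cythonize(extensions),
)
```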

After that, I ran into problems related to complex types, which I fixed by changing the code in the "Numba wrapper": the return statements called np.complex(...), which had to be changed to complex(...) (np.complex was only an alias of the builtin complex and was removed in NumPy 1.24). Now the code works fine.
