Skip to content

Commit

Permalink
Add overloads for float signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
person142 committed Jul 9, 2019
1 parent 1b36e26 commit 0ee76b3
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 0 deletions.
2 changes: 2 additions & 0 deletions generate_signatures_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@

SPECIAL_DOC_TO_CTYPES = {
'double': 'c_double',
'float': 'c_float',
'long': 'c_long'
}

SPECIAL_DOC_TO_CYTHON_SPECIALIZATION = {
'double': 'double',
'float': 'float',
'long': 'long'
}

Expand Down
2 changes: 2 additions & 0 deletions numba_special/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@
- :data:`scipy.special.expit`::
float64 expit(float64)
float32 expit(float32)
- :data:`scipy.special.expm1`::
Expand Down Expand Up @@ -514,6 +515,7 @@
- :data:`scipy.special.logit`::
float64 logit(float64)
float32 logit(float32)
- :data:`scipy.special.lpmv`::
Expand Down
2 changes: 2 additions & 0 deletions numba_special/function_pointers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ functions = {
'exp2': PyLong_FromVoidPtr(<void *>sc.exp2),
'expi[double]': PyLong_FromVoidPtr(<void *>sc.expi[double]),
'expit[double]': PyLong_FromVoidPtr(<void *>sc.expit[double]),
'expit[float]': PyLong_FromVoidPtr(<void *>sc.expit[float]),
'expm1[double]': PyLong_FromVoidPtr(<void *>sc.expm1[double]),
'expn[double]': PyLong_FromVoidPtr(<void *>sc.expn[double]),
'expn[long]': PyLong_FromVoidPtr(<void *>sc.expn[long]),
Expand Down Expand Up @@ -145,6 +146,7 @@ functions = {
'log_ndtr[double]': PyLong_FromVoidPtr(<void *>sc.log_ndtr[double]),
'loggamma[double]': PyLong_FromVoidPtr(<void *>sc.loggamma[double]),
'logit[double]': PyLong_FromVoidPtr(<void *>sc.logit[double]),
'logit[float]': PyLong_FromVoidPtr(<void *>sc.logit[float]),
'lpmv': PyLong_FromVoidPtr(<void *>sc.lpmv),
'mathieu_a': PyLong_FromVoidPtr(<void *>sc.mathieu_a),
'mathieu_b': PyLong_FromVoidPtr(<void *>sc.mathieu_b),
Expand Down
2 changes: 2 additions & 0 deletions numba_special/generate_overloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ def {FUNCTION}(*args):"""

CTYPES_TO_NUMBA_TYPES = {
'c_double': 'numba.types.float64',
'c_float': 'numba.types.float32',
'c_long': 'numba_long'
}

CTYPES_TO_SHORT_NUMBA_TYPES = {
'c_double': 'float64',
'c_float': 'float32',
'c_long': 'long'
}

Expand Down
10 changes: 10 additions & 0 deletions numba_special/numba_overloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
'exp2': ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double),
'expi[double]': ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double),
'expit[double]': ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double),
'expit[float]': ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float),
'expm1[double]': ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double),
'expn[double]': ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double),
'expn[long]': ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_long, ctypes.c_double),
Expand Down Expand Up @@ -154,6 +155,7 @@
'log_ndtr[double]': ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double),
'loggamma[double]': ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double),
'logit[double]': ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double),
'logit[float]': ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float),
'lpmv': ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double),
'mathieu_a': ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double),
'mathieu_b': ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double),
Expand Down Expand Up @@ -738,6 +740,10 @@ def expit(*args):
f = get_scalar_function('expit[double]')
return lambda *args: f(*args)

if args == (numba.types.float32,):
f = get_scalar_function('expit[float]')
return lambda *args: f(*args)


@numba.extending.overload(sc.expm1)
def expm1(*args):
Expand Down Expand Up @@ -1145,6 +1151,10 @@ def logit(*args):
f = get_scalar_function('logit[double]')
return lambda *args: f(*args)

if args == (numba.types.float32,):
f = get_scalar_function('logit[float]')
return lambda *args: f(*args)


@numba.extending.overload(sc.lpmv)
def lpmv(*args):
Expand Down
8 changes: 8 additions & 0 deletions numba_special/signatures.json
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,10 @@
"double": [
"c_double",
"c_double"
],
"float": [
"c_float",
"c_float"
]
},
"expm1": {
Expand Down Expand Up @@ -929,6 +933,10 @@
"double": [
"c_double",
"c_double"
],
"float": [
"c_float",
"c_float"
]
},
"lpmv": {
Expand Down
4 changes: 4 additions & 0 deletions numba_special/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

CTYPES_TO_TEST_POINTS = {
'c_double': [-100.0, -10.0, -1.0, -0.1, 0.0, 0.1, 1.0, 10.0, 100.0],
'c_float': np.array(
[-100.0, -10.0, -1.0, -0.1, 0.0, 0.1, 1.0, 10.0, 100.0],
dtype=np.float32
),
'c_long': [-100, -10, -1, 0, 1, 10, 100],
}

Expand Down

0 comments on commit 0ee76b3

Please sign in to comment.