Skip to content

Commit

Permalink
Fix hardcoded long int type
Browse files Browse the repository at this point in the history
Before the type was incorrectly determined during the code generation
step, which prevented it from working cross-platform. Change that to
instead determine it when `numba_overloads.py` is imported.
  • Loading branch information
person142 committed Jul 9, 2019
1 parent 7e88a50 commit 1b36e26
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 30 deletions.
9 changes: 8 additions & 1 deletion numba_special/generate_overloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
from . import function_pointers
bytes_to_int_type = {{
2: numba.types.int16,
4: numba.types.int32,
8: numba.types.int64
}}
numba_long = bytes_to_int_type[ctypes.sizeof(ctypes.c_long)]
functions = {{
{FUNCTIONS}
Expand Down Expand Up @@ -71,7 +78,7 @@ def {FUNCTION}(*args):"""

CTYPES_TO_NUMBA_TYPES = {
'c_double': 'numba.types.float64',
'c_long': 'numba.types.int{}'.format(8 * ctypes.sizeof(ctypes.c_long))
'c_long': 'numba_long'
}

CTYPES_TO_SHORT_NUMBA_TYPES = {
Expand Down
65 changes: 36 additions & 29 deletions numba_special/numba_overloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@

from . import function_pointers

bytes_to_int_type = {
2: numba.types.int16,
4: numba.types.int32,
8: numba.types.int64
}

numba_long = bytes_to_int_type[ctypes.sizeof(ctypes.c_long)]

functions = {
'agm': ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double),
Expand Down Expand Up @@ -233,7 +240,7 @@ def bdtr(*args):
f = get_scalar_function('bdtr[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.int64, numba.types.float64,):
if args == (numba_long, numba_long, numba.types.float64,):
f = get_scalar_function('bdtr[long]')
return lambda *args: f(*args)

Expand All @@ -244,7 +251,7 @@ def bdtrc(*args):
f = get_scalar_function('bdtrc[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.int64, numba.types.float64,):
if args == (numba_long, numba_long, numba.types.float64,):
f = get_scalar_function('bdtrc[long]')
return lambda *args: f(*args)

Expand All @@ -255,7 +262,7 @@ def bdtri(*args):
f = get_scalar_function('bdtri[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.int64, numba.types.float64,):
if args == (numba_long, numba_long, numba.types.float64,):
f = get_scalar_function('bdtri[long]')
return lambda *args: f(*args)

Expand Down Expand Up @@ -546,7 +553,7 @@ def eval_chebyc(*args):
f = get_scalar_function('eval_chebyc[double, double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('eval_chebyc[long, double]')
return lambda *args: f(*args)

Expand All @@ -557,7 +564,7 @@ def eval_chebys(*args):
f = get_scalar_function('eval_chebys[double, double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('eval_chebys[long, double]')
return lambda *args: f(*args)

Expand All @@ -568,7 +575,7 @@ def eval_chebyt(*args):
f = get_scalar_function('eval_chebyt[double, double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('eval_chebyt[long, double]')
return lambda *args: f(*args)

Expand All @@ -579,7 +586,7 @@ def eval_chebyu(*args):
f = get_scalar_function('eval_chebyu[double, double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('eval_chebyu[long, double]')
return lambda *args: f(*args)

Expand All @@ -590,7 +597,7 @@ def eval_gegenbauer(*args):
f = get_scalar_function('eval_gegenbauer[double, double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64, numba.types.float64,):
if args == (numba_long, numba.types.float64, numba.types.float64,):
f = get_scalar_function('eval_gegenbauer[long, double]')
return lambda *args: f(*args)

Expand All @@ -601,21 +608,21 @@ def eval_genlaguerre(*args):
f = get_scalar_function('eval_genlaguerre[double, double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64, numba.types.float64,):
if args == (numba_long, numba.types.float64, numba.types.float64,):
f = get_scalar_function('eval_genlaguerre[long, double]')
return lambda *args: f(*args)


@numba.extending.overload(sc.eval_hermite)
def eval_hermite(*args):
if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('eval_hermite')
return lambda *args: f(*args)


@numba.extending.overload(sc.eval_hermitenorm)
def eval_hermitenorm(*args):
if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('eval_hermitenorm')
return lambda *args: f(*args)

Expand All @@ -626,7 +633,7 @@ def eval_jacobi(*args):
f = get_scalar_function('eval_jacobi[double, double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64, numba.types.float64, numba.types.float64,):
if args == (numba_long, numba.types.float64, numba.types.float64, numba.types.float64,):
f = get_scalar_function('eval_jacobi[long, double]')
return lambda *args: f(*args)

Expand All @@ -637,7 +644,7 @@ def eval_laguerre(*args):
f = get_scalar_function('eval_laguerre[double, double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('eval_laguerre[long, double]')
return lambda *args: f(*args)

Expand All @@ -648,7 +655,7 @@ def eval_legendre(*args):
f = get_scalar_function('eval_legendre[double, double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('eval_legendre[long, double]')
return lambda *args: f(*args)

Expand All @@ -659,7 +666,7 @@ def eval_sh_chebyt(*args):
f = get_scalar_function('eval_sh_chebyt[double, double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('eval_sh_chebyt[long, double]')
return lambda *args: f(*args)

Expand All @@ -670,7 +677,7 @@ def eval_sh_chebyu(*args):
f = get_scalar_function('eval_sh_chebyu[double, double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('eval_sh_chebyu[long, double]')
return lambda *args: f(*args)

Expand All @@ -681,7 +688,7 @@ def eval_sh_jacobi(*args):
f = get_scalar_function('eval_sh_jacobi[double, double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64, numba.types.float64, numba.types.float64,):
if args == (numba_long, numba.types.float64, numba.types.float64, numba.types.float64,):
f = get_scalar_function('eval_sh_jacobi[long, double]')
return lambda *args: f(*args)

Expand All @@ -692,7 +699,7 @@ def eval_sh_legendre(*args):
f = get_scalar_function('eval_sh_legendre[double, double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('eval_sh_legendre[long, double]')
return lambda *args: f(*args)

Expand Down Expand Up @@ -745,7 +752,7 @@ def expn(*args):
f = get_scalar_function('expn[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('expn[long]')
return lambda *args: f(*args)

Expand Down Expand Up @@ -1078,7 +1085,7 @@ def kn(*args):
f = get_scalar_function('kn[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('kn[long]')
return lambda *args: f(*args)

Expand Down Expand Up @@ -1173,7 +1180,7 @@ def nbdtr(*args):
f = get_scalar_function('nbdtr[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.int64, numba.types.float64,):
if args == (numba_long, numba_long, numba.types.float64,):
f = get_scalar_function('nbdtr[long]')
return lambda *args: f(*args)

Expand All @@ -1184,7 +1191,7 @@ def nbdtrc(*args):
f = get_scalar_function('nbdtrc[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.int64, numba.types.float64,):
if args == (numba_long, numba_long, numba.types.float64,):
f = get_scalar_function('nbdtrc[long]')
return lambda *args: f(*args)

Expand All @@ -1195,7 +1202,7 @@ def nbdtri(*args):
f = get_scalar_function('nbdtri[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.int64, numba.types.float64,):
if args == (numba_long, numba_long, numba.types.float64,):
f = get_scalar_function('nbdtri[long]')
return lambda *args: f(*args)

Expand Down Expand Up @@ -1325,7 +1332,7 @@ def pdtr(*args):
f = get_scalar_function('pdtr[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('pdtr[long]')
return lambda *args: f(*args)

Expand All @@ -1336,7 +1343,7 @@ def pdtrc(*args):
f = get_scalar_function('pdtrc[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('pdtrc[long]')
return lambda *args: f(*args)

Expand All @@ -1347,7 +1354,7 @@ def pdtri(*args):
f = get_scalar_function('pdtri[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('pdtri[long]')
return lambda *args: f(*args)

Expand Down Expand Up @@ -1428,7 +1435,7 @@ def smirnov(*args):
f = get_scalar_function('smirnov[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('smirnov[long]')
return lambda *args: f(*args)

Expand All @@ -1439,7 +1446,7 @@ def smirnovi(*args):
f = get_scalar_function('smirnovi[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('smirnovi[long]')
return lambda *args: f(*args)

Expand Down Expand Up @@ -1527,7 +1534,7 @@ def yn(*args):
f = get_scalar_function('yn[double]')
return lambda *args: f(*args)

if args == (numba.types.int64, numba.types.float64,):
if args == (numba_long, numba.types.float64,):
f = get_scalar_function('yn[long]')
return lambda *args: f(*args)

Expand Down

0 comments on commit 1b36e26

Please sign in to comment.