Skip to content

Commit

Permalink
add more numeric type mappings
Browse files Browse the repository at this point in the history
  • Loading branch information
wlav committed Jun 26, 2022
1 parent d05db88 commit 47b67ea
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
10 changes: 8 additions & 2 deletions python/cppyy/numba_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ def _init_extension():
_numba2cpp = {
nb_types.void : 'void',
nb_types.voidptr : 'void*',
nb_types.long_ : 'long',
nb_types.int_ : 'int',
nb_types.int32 : 'int32_t',
nb_types.int64 : 'int64_t',
nb_types.long_ : 'long',
nb_types.float32 : 'float',
nb_types.float64 : 'double',
}
Expand All @@ -48,9 +49,10 @@ def numba2cpp(val):
_cpp2numba = {
'void' : nb_types.void,
'void*' : nb_types.voidptr,
'long' : nb_types.long_,
'int' : nb_types.intc,
'int32_t' : nb_types.int32,
'int64_t' : nb_types.int64,
'long' : nb_types.long_,
'float' : nb_types.float32,
'double' : nb_types.float64,
}
Expand All @@ -61,7 +63,11 @@ def cpp2numba(val):
return _cpp2numba[val]

_cpp2ir = {
'int' : irType.int(nb_types.intc.bitwidth),
'int32_t' : irType.int(32),
'int64_t' : irType.int(64),
'float' : irType.float(),
'double' : irType.double(),
}

def cpp2ir(val):
Expand Down
29 changes: 29 additions & 0 deletions test/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,32 @@ def go_fast(a, d):
assert((go_fast(x, d) == go_slow(x, d)).all())
assert self.compare(go_slow, go_fast, 10000, x, d)

def test05_datatype_mapping(self):
"""Numba-JITing of various data types"""

import cppyy

@numba.jit(nopython=True)
def access_field(d):
return d.fField

code = """\
namespace NumbaDTT {
struct M%d { M%d(%s f) : fField(f) {};
%s buf, fField;
}; }"""

cppyy.cppdef("namespace NumbaDTT { }")
ns = cppyy.gbl.NumbaDTT

types = (
'int', 'int32_t', 'int64_t',
'float', 'double',
)

nl = cppyy.gbl.std.numeric_limits
for i, ntype in enumerate(types):
cppyy.cppdef(code % (i, i, ntype, ntype))
for m in ('min', 'max'):
val = getattr(nl[ntype], m)()
assert access_field(getattr(ns, 'M%d'%i)(val)) == val

0 comments on commit 47b67ea

Please sign in to comment.