-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support for free (templated) functions in Numba
- Loading branch information
Showing
3 changed files
with
202 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
""" cppyy extensions for numba | ||
""" | ||
|
||
import cppyy | ||
import cppyy.types | ||
|
||
import numba | ||
import numba.extending as nb_ext | ||
import numba.core.datamodel.models as nb_models | ||
import numba.core.imputils as nb_iutils | ||
import numba.core.types as nb_types | ||
import numba.core.registry as nb_reg | ||
|
||
from llvmlite import ir | ||
|
||
|
||
numba2cpp = { | ||
nb_types.long_ : 'long', | ||
nb_types.int64 : 'int64_t', | ||
nb_types.float64 : 'double', | ||
} | ||
|
||
cpp2numba = { | ||
'long' : nb_types.long_, | ||
'int64_t' : nb_types.int64, | ||
'double' : nb_types.float64, | ||
} | ||
|
||
|
||
class CppyyFunctionNumbaType(nb_types.Callable): | ||
targetdescr = nb_reg.cpu_target | ||
requires_gil = False | ||
|
||
def __init__(self, func): | ||
super(CppyyFunctionNumbaType, self).__init__(str(func)) | ||
|
||
self.sig = None | ||
self._func = func | ||
|
||
self._signatures = list() | ||
self._impl_keys = dict() | ||
|
||
def is_precise(self): | ||
return True # by definition | ||
|
||
def get_call_type(self, context, args, kwds): | ||
ol = CppyyFunctionNumbaType(self._func.__overload__(*(numba2cpp[x] for x in args))) | ||
|
||
ol.sig = numba.core.typing.Signature( | ||
return_type=cpp2numba[ol._func.__cpp_rettype__], | ||
args=args, | ||
recvr=None) # this pointer | ||
|
||
self._impl_keys[args] = ol | ||
|
||
@nb_iutils.lower_builtin(ol, *args) | ||
def lower_external_call(context, builder, sig, args, | ||
ty=nb_types.ExternalFunctionPointer(ol.sig, ol.get_pointer), pyval=self._func): | ||
ptrty = context.get_function_pointer_type(ty) | ||
ptrval = context.add_dynamic_addr( | ||
builder, ty.get_pointer(pyval), info=str(pyval)) | ||
fptr = builder.bitcast(ptrval, ptrty) | ||
return context.call_function_pointer(builder, fptr, args) | ||
|
||
return ol.sig | ||
|
||
def get_call_signatures(self): | ||
return list(self._signatures), False | ||
|
||
def get_impl_key(self, sig): | ||
return self._impl_keys[sig.args] | ||
|
||
def get_pointer(self, func): | ||
ol = func.__overload__(*(numba2cpp[x] for x in self.sig.args)) | ||
address = cppyy.addressof(ol) | ||
if not address: | ||
raise RuntimeError("unresolved address for %s" % str(ol)) | ||
return address | ||
|
||
@property | ||
def key(self): | ||
return self._func | ||
|
||
|
||
@nb_ext.register_model(CppyyFunctionNumbaType) | ||
class CppyyFunctionModel(nb_models.PrimitiveModel): | ||
def __init__(self, dmm, fe_type): | ||
# the function pointer of this overload can not be exactly typed, but | ||
# only the storage size is relevant, so simply use a void* | ||
be_type = ir.PointerType(dmm.lookup(nb_types.void).get_value_type()) | ||
super(CppyyFunctionModel, self).__init__(dmm, fe_type, be_type) | ||
|
||
@nb_iutils.lower_constant(CppyyFunctionNumbaType) | ||
def constant_function_pointer(context, builder, ty, pyval): | ||
# TODO: needs to exist for the proper flow, but why? The lowering of the | ||
# actual overloads is handled dynamically. | ||
return | ||
|
||
@nb_ext.typeof_impl.register(cppyy.types.Scope) | ||
def typeof_scope(val, c): | ||
if 'namespace' in repr(val): | ||
return numba.types.Module(val) | ||
return CppyyScopeNumbaType(val) | ||
|
||
@nb_ext.typeof_impl.register(cppyy.types.Function) | ||
def typeof_function(val, c): | ||
return CppyyFunctionNumbaType(val) | ||
|
||
@nb_ext.typeof_impl.register(cppyy.types.Template) | ||
def typeof_template(val, c): | ||
if hasattr(val, '__overload__'): | ||
return CppyyFunctionNumbaType(val) | ||
raise RuntimeError("only function templates supported") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import py, os, sys | ||
import math, time | ||
from pytest import mark, raises | ||
from .support import setup_make | ||
|
||
try: | ||
import numba | ||
has_numba = True | ||
except ImportError: | ||
has_numba = False | ||
|
||
|
||
@mark.skipif(has_numba == False, reason="numba not found") | ||
class TestNUMBA: | ||
def setup_class(cls): | ||
import cppyy | ||
import cppyy.numba_ext | ||
|
||
def compare(self, go_slow, go_fast, N, *args): | ||
t0 = time.time() | ||
for i in range(N): | ||
go_slow(*args) | ||
slow_time = time.time() - t0 | ||
|
||
t0 = time.time() | ||
for i in range(N): | ||
go_fast(*args) | ||
fast_time = time.time() - t0 | ||
|
||
return fast_time < slow_time | ||
|
||
def test01_compiled_free_func(self): | ||
"""Test Numba-JITing of a compiled free function""" | ||
|
||
import cppyy | ||
import numpy as np | ||
|
||
def go_slow(a): | ||
trace = 0.0 | ||
for i in range(a.shape[0]): | ||
trace += math.tanh(a[i, i]) | ||
return a + trace | ||
|
||
@numba.jit(nopython=True) | ||
def go_fast(a): | ||
trace = 0.0 | ||
for i in range(a.shape[0]): | ||
trace += cppyy.gbl.tanh(a[i, i]) | ||
return a + trace | ||
|
||
x = np.arange(100, dtype=np.float64).reshape(10, 10) | ||
|
||
assert (go_fast(x) == go_slow(x)).all() | ||
assert self.compare(go_slow, go_fast, 300000, x) | ||
|
||
def test02_JITed_template_free_func(self): | ||
"""Test Numba-JITing of Cling-JITed templated free function""" | ||
|
||
import cppyy | ||
import numpy as np | ||
|
||
cppyy.cppdef(r"""\ | ||
template<class T> | ||
T add42(T t) { | ||
return T(t+42); | ||
}""") | ||
|
||
|
||
def add42(t): | ||
return type(t)(t+42) | ||
|
||
def go_slow(a): | ||
trace = 0.0 | ||
for i in range(a.shape[0]): | ||
trace += add42(a[i, i]) + add42(int(a[i, i])) | ||
return a + trace | ||
|
||
@numba.jit(nopython=True) | ||
def go_fast(a): | ||
trace = 0.0 | ||
for i in range(a.shape[0]): | ||
trace += cppyy.gbl.add42(a[i, i]) + cppyy.gbl.add42(int(a[i, i])) | ||
return a + trace | ||
|
||
x = np.arange(100, dtype=np.float64).reshape(10, 10) | ||
|
||
assert (go_fast(x) == go_slow(x)).all() | ||
assert self.compare(go_slow, go_fast, 100000, x) |