Skip to content

Commit

Permalink
support for free (templated) functions in Numba
Browse files Browse the repository at this point in the history
  • Loading branch information
wlav committed Apr 29, 2022
1 parent 2365d39 commit dd7dbbb
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 0 deletions.
1 change: 1 addition & 0 deletions doc/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ PyPy support lags CPython support.
master: 2.4.0
-------------

* Support for free (templated) functions in Numba
* Support for globally overloaded ordering operators
* Special cases for __repr__/__str__ returning C++ stringy types
* Fix lookup of templates of function with template args
Expand Down
113 changes: 113 additions & 0 deletions python/cppyy/numba_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
""" cppyy extensions for numba
"""

import cppyy
import cppyy.types

import numba
import numba.extending as nb_ext
import numba.core.datamodel.models as nb_models
import numba.core.imputils as nb_iutils
import numba.core.types as nb_types
import numba.core.registry as nb_reg

from llvmlite import ir


# Numba scalar type -> name of the C++ type used to select an overload.
# NOTE(review): if nb_types.long_ and nb_types.int64 are the same type object
# on a platform, the later entry replaces the earlier one -- confirm intended.
numba2cpp = {
    nb_types.long_: 'long',
    nb_types.int64: 'int64_t',
    nb_types.float64: 'double',
}

# C++ type name (as reported for a function's return type) -> Numba scalar type.
cpp2numba = {
    'long': nb_types.long_,
    'int64_t': nb_types.int64,
    'double': nb_types.float64,
}


class CppyyFunctionNumbaType(nb_types.Callable):
    """Numba type wrapping a cppyy-bound C++ function or function template.

    An instance is created for the generic cppyy callable. Each distinct
    tuple of Numba argument types seen during typing is resolved to a C++
    overload, which is wrapped in its own ``CppyyFunctionNumbaType`` carrying
    the resolved signature in ``sig``. These per-overload types are kept in
    ``_impl_keys`` and double as the keys under which lowering is registered.
    """

    targetdescr = nb_reg.cpu_target
    requires_gil = False

    def __init__(self, func):
        """Wrap ``func``; the Numba type name is ``str(func)``."""
        super(CppyyFunctionNumbaType, self).__init__(str(func))

        self.sig = None             # set once an overload has been selected
        self._func = func

        self._signatures = list()   # never filled in by this class
        self._impl_keys = dict()    # argument types -> per-overload type

    def is_precise(self):
        return True                 # by definition

    def get_call_type(self, context, args, kwds):
        """Return the Numba signature for a call with argument types ``args``.

        The argument types are translated through ``numba2cpp`` to select the
        C++ overload, the return type is translated back through
        ``cpp2numba``, and a lowering function for the selected overload is
        registered. ``kwds`` is not used.

        Raises:
            KeyError: if an argument type or the C++ return type has no entry
                in the translation tables.
        """
        # Typing may ask about the same argument types more than once. Reuse
        # the overload resolved the first time rather than creating a new
        # type and registering yet another lowering function for it.
        known = self._impl_keys.get(args)
        if known is not None:
            return known.sig

        ol = CppyyFunctionNumbaType(
            self._func.__overload__(*(numba2cpp[x] for x in args)))

        ol.sig = numba.core.typing.Signature(
            return_type=cpp2numba[ol._func.__cpp_rettype__],
            args=args,
            recvr=None)             # this pointer

        self._impl_keys[args] = ol

        # The default arguments bind the function pointer type and the
        # generic cppyy function at definition time.
        @nb_iutils.lower_builtin(ol, *args)
        def lower_external_call(context, builder, sig, args,
                ty=nb_types.ExternalFunctionPointer(ol.sig, ol.get_pointer),
                pyval=self._func):
            ptrty = context.get_function_pointer_type(ty)
            ptrval = context.add_dynamic_addr(
                builder, ty.get_pointer(pyval), info=str(pyval))
            fptr = builder.bitcast(ptrval, ptrty)
            return context.call_function_pointer(builder, fptr, args)

        return ol.sig

    def get_call_signatures(self):
        """Return a copy of the stored signatures and the flag ``False``."""
        return list(self._signatures), False

    def get_impl_key(self, sig):
        """Return the per-overload type recorded for ``sig.args``."""
        return self._impl_keys[sig.args]

    def get_pointer(self, func):
        """Return the address of the overload of ``func`` matching ``self.sig``.

        Raises:
            RuntimeError: if cppyy reports a null address for the overload.
        """
        ol = func.__overload__(*(numba2cpp[x] for x in self.sig.args))
        address = cppyy.addressof(ol)
        if not address:
            raise RuntimeError("unresolved address for %s" % str(ol))
        return address

    @property
    def key(self):
        """The wrapped cppyy function, used as this type's key."""
        return self._func


@nb_ext.register_model(CppyyFunctionNumbaType)
class CppyyFunctionModel(nb_models.PrimitiveModel):
    """Data model registered for ``CppyyFunctionNumbaType`` values."""

    def __init__(self, dmm, fe_type):
        # An overload's function pointer cannot be typed exactly at this
        # point; only its storage size matters, so a pointer to the value
        # type of void (i.e. a void*) is used as the backend type.
        void_value_type = dmm.lookup(nb_types.void).get_value_type()
        super(CppyyFunctionModel, self).__init__(
            dmm, fe_type, ir.PointerType(void_value_type))

@nb_iutils.lower_constant(CppyyFunctionNumbaType)
def constant_function_pointer(context, builder, ty, pyval):
    """Constant lowering hook for cppyy functions; produces no value."""
    # TODO: this hook has to be registered for the flow to work, but it is
    # unclear why: the actual overloads are lowered dynamically.
    return None

@nb_ext.typeof_impl.register(cppyy.types.Scope)
def typeof_scope(val, c):
    """Return the Numba type for a cppyy scope object.

    Scopes whose ``repr`` contains the word 'namespace' are typed as Numba
    modules, so their attributes are looked up like module globals.

    Raises:
        RuntimeError: for any other scope. This module defines no Numba type
            for them, so the previous fall-through hit an undefined name.
    """
    # NOTE(review): substring matching on repr() would also match a class
    # whose name happens to contain 'namespace' -- confirm this is acceptable.
    if 'namespace' in repr(val):
        return numba.types.Module(val)
    raise RuntimeError("only namespace scopes supported")

@nb_ext.typeof_impl.register(cppyy.types.Function)
def typeof_function(val, c):
    """Type a cppyy function object as a ``CppyyFunctionNumbaType``."""
    function_type = CppyyFunctionNumbaType(val)
    return function_type

@nb_ext.typeof_impl.register(cppyy.types.Template)
def typeof_template(val, c):
    """Type a cppyy template; only those exposing ``__overload__`` qualify.

    Raises:
        RuntimeError: if ``val`` has no ``__overload__`` attribute.
    """
    if not hasattr(val, '__overload__'):
        raise RuntimeError("only function templates supported")
    return CppyyFunctionNumbaType(val)
88 changes: 88 additions & 0 deletions test/test_numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import py, os, sys
import math, time
from pytest import mark, raises
from .support import setup_make

try:
import numba
has_numba = True
except ImportError:
has_numba = False


@mark.skipif(not has_numba, reason="numba not found")
class TestNUMBA:
    """Check that cppyy-bound C++ functions can be called from Numba-JITed code."""

    def setup_class(cls):
        # Importing the extension module registers the cppyy types with Numba.
        import cppyy
        import cppyy.numba_ext

    def compare(self, go_slow, go_fast, N, *args):
        """Call each function ``N`` times; return True if ``go_fast`` took less time.

        Uses ``time.perf_counter``, which is meant for measuring intervals
        and is not affected by system clock adjustments.
        """
        t0 = time.perf_counter()
        for _ in range(N):
            go_slow(*args)
        slow_time = time.perf_counter() - t0

        t0 = time.perf_counter()
        for _ in range(N):
            go_fast(*args)
        fast_time = time.perf_counter() - t0

        return fast_time < slow_time

    def test01_compiled_free_func(self):
        """Test Numba-JITing of a compiled free function"""

        import cppyy
        import numpy as np

        def go_slow(a):
            trace = 0.0
            for i in range(a.shape[0]):
                trace += math.tanh(a[i, i])
            return a + trace

        @numba.jit(nopython=True)
        def go_fast(a):
            trace = 0.0
            for i in range(a.shape[0]):
                trace += cppyy.gbl.tanh(a[i, i])
            return a + trace

        x = np.arange(100, dtype=np.float64).reshape(10, 10)

        # The equality check also triggers the JIT compilation, so the timing
        # comparison below only measures already-compiled calls.
        assert (go_fast(x) == go_slow(x)).all()
        assert self.compare(go_slow, go_fast, 300000, x)

    def test02_JITed_template_free_func(self):
        """Test Numba-JITing of Cling-JITed templated free function"""

        import cppyy
        import numpy as np

        cppyy.cppdef(r"""\
template<class T>
T add42(T t) {
return T(t+42);
}""")

        # Pure-Python reference with the same semantics as the C++ template.
        def add42(t):
            return type(t)(t+42)

        def go_slow(a):
            trace = 0.0
            for i in range(a.shape[0]):
                trace += add42(a[i, i]) + add42(int(a[i, i]))
            return a + trace

        @numba.jit(nopython=True)
        def go_fast(a):
            trace = 0.0
            for i in range(a.shape[0]):
                trace += cppyy.gbl.add42(a[i, i]) + cppyy.gbl.add42(int(a[i, i]))
            return a + trace

        x = np.arange(100, dtype=np.float64).reshape(10, 10)

        assert (go_fast(x) == go_slow(x)).all()
        assert self.compare(go_slow, go_fast, 100000, x)

0 comments on commit dd7dbbb

Please sign in to comment.