Skip to content

Commit

Permalink
add examples used in Numba support documentation to the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wlav committed Jul 1, 2022
1 parent f8dadc9 commit 784fe1d
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions test/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,92 @@ def go_fast(a):

assert((go_fast(x) == go_slow(x)).all())
assert self.compare(go_slow, go_fast, 100000, x)


@mark.skipif(has_numba == False, reason="numba not found")
class TestNUMBA_DOC:
def setup_class(cls):
import cppyy
import cppyy.numba_ext

def test01_templated_freefunction(self):
"""Numba support documentation example: free templated function"""

import cppyy
import numba
import numpy as np

cppyy.cppdef("""
namespace NumbaSupportExample {
template<typename T>
T square(T t) { return t*t; }
}""")

@numba.jit(nopython=True)
def tsa(a):
total = type(a[0])(0)
for i in range(len(a)):
total += cppyy.gbl.NumbaSupportExample.square(a[i])
return total

a = np.array(range(10), dtype=np.float32)
assert type(tsa(a)) == float
assert tsa(a) == 285.0

a = np.array(range(10), dtype=np.int64)
assert type(tsa(a)) == int
assert tsa(a) == 285

def test02_class_features(self):
"""Numba support documentation example: class features"""

import cppyy
import numba
import numpy as np

cppyy.cppdef("""\
namespace NumbaSupportExample {
class MyData {
public:
MyData(int i, int j) : fField1(i), fField2(j) {}
public:
int get_field1() { return fField1; }
int get_field2() { return fField2; }
MyData copy() { return *this; }
public:
int fField1;
int fField2;
}; }""")

@numba.jit(nopython=True)
def tsdf(a, d):
total = type(a[0])(0)
for i in range(len(a)):
total += a[i] + d.fField1 + d.fField2
return total

d = cppyy.gbl.NumbaSupportExample.MyData(5, 6)
a = np.array(range(10), dtype=np.int32)

assert tsdf(a, d) == 155

@numba.jit(nopython=True)
def tsdm(a, d):
total = type(a[0])(0)
for i in range(len(a)):
total += a[i] + d.get_field1() + d.get_field2()
return total

assert tsdm(a, d) == 155

@numba.jit(nopython=True)
def tsdcm(a, d):
total = type(a[0])(0)
for i in range(len(a)):
total += a[i] + d.copy().fField1 + d.get_field2()
return total

assert tsdcm(a, d) == 155

0 comments on commit 784fe1d

Please sign in to comment.