Skip to content

Commit

Permalink
test object return within a trace
Browse files Browse the repository at this point in the history
  • Loading branch information
wlav committed Jun 27, 2022
1 parent 16565fd commit b5b6c80
Showing 1 changed file with 39 additions and 6 deletions.
45 changes: 39 additions & 6 deletions test/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ def test03_proxy_argument_for_field(self):
import numpy as np

cppyy.cppdef(r"""\
struct MyNumbaData1 {
MyNumbaData1(int64_t i1, int64_t i2) : fField1(i1), fField2(i2) {}
struct MyNumbaData03 {
MyNumbaData03(int64_t i1, int64_t i2) : fField1(i1), fField2(i2) {}
int64_t fField1;
int64_t fField2;
};""")
Expand All @@ -178,7 +178,7 @@ def go_fast(a, d):

# note: need a sizable array to outperform given the unboxing overhead
x = np.arange(10000, dtype=np.float64).reshape(100, 100)
d = cppyy.gbl.MyNumbaData1(42, 27)
d = cppyy.gbl.MyNumbaData03(42, 27)

assert((go_fast(x, d) == go_slow(x, d)).all())
assert self.compare(go_slow, go_fast, 10000, x, d)
Expand All @@ -190,8 +190,8 @@ def test04_proxy_argument_for_method(self):
import numpy as np

cppyy.cppdef(r"""\
struct MyNumbaData2 {
MyNumbaData2(int64_t i) : fField(i) {}
struct MyNumbaData04 {
MyNumbaData04(int64_t i) : fField(i) {}
int64_t get_field() { return fField; }
int64_t fField;
};""")
Expand All @@ -211,7 +211,7 @@ def go_fast(a, d):

# note: need a sizable array to outperform given the unboxing overhead
x = np.arange(10000, dtype=np.float64).reshape(100, 100)
d = cppyy.gbl.MyNumbaData2(42)
d = cppyy.gbl.MyNumbaData04(42)

assert((go_fast(x, d) == go_slow(x, d)).all())
assert self.compare(go_slow, go_fast, 10000, x, d)
Expand Down Expand Up @@ -248,3 +248,36 @@ def access_field(d):
for m in ('min', 'max'):
val = getattr(nl[ntype], m)()
assert access_field(getattr(ns, 'M%d'%i)(val)) == val

def test06_object_returns(self):
"""Numba-JITing of a function that returns an object"""

import cppyy
import numpy as np

cppyy.cppdef(r"""\
struct MyNumbaData06 {
MyNumbaData06(int64_t i1) : fField(i1) {}
int64_t fField;
};
MyNumbaData06 get_numba_data_06() { return MyNumbaData06(42); }
""")

def go_slow(a):
trace = 0.0
for i in range(a.shape[0]):
trace += cppyy.gbl.get_numba_data_06().fField
return a + trace

@numba.jit(nopython=True)
def go_fast(a):
trace = 0.0
for i in range(a.shape[0]):
trace += cppyy.gbl.get_numba_data_06().fField
return a + trace

x = np.arange(100, dtype=np.float64).reshape(10, 10)

assert((go_fast(x) == go_slow(x)).all())
assert self.compare(go_slow, go_fast, 100000, x)

0 comments on commit b5b6c80

Please sign in to comment.