Skip to content

Commit

Permalink
Improve test
Browse files Browse the repository at this point in the history
  • Loading branch information
sklam committed Aug 17, 2018
1 parent 67fb239 commit 7f9a097
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions numba/tests/test_looplifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,10 @@ def foo(x):


class TestLoopLiftingInAction(MemoryLeakMixin, TestCase):
def assert_has_lifted(self, jitted, loopcount):
lifted = jitted.overloads[jitted.signatures[0]].lifted
self.assertEqual(len(lifted), loopcount)

def test_issue_734(self):
from numba import jit, void, int32, double

Expand Down Expand Up @@ -369,7 +373,8 @@ def lift_issue2368(a, b):
self.assertTrue(loopcres.fndesc.native)

def test_no_iteration_w_redef(self):
# redefinition of res in the loop with no use of res prevents lifting
# redefinition of res in the loop with no use of res should not
# prevent lifting
from numba import jit

@jit(forceobj=True)
Expand All @@ -379,11 +384,12 @@ def test(n):
res = i
return res

# loop count = 0, loop won't lift or execute
# loop count = 1, loop lift but loop body not execute
self.assertEqual(test.py_func(-1), test(-1))

# loop count = 0, loop won't lift but will execute
self.assert_has_lifted(test, loopcount=1)
# loop count = 1, loop won't lift and will execute
self.assertEqual(test.py_func(1), test(1))
self.assert_has_lifted(test, loopcount=1)

def test_no_iteration(self):
from numba import jit
Expand All @@ -395,11 +401,12 @@ def test(n):
res += i
return res

# loop count = 0
# loop count = 1
self.assertEqual(test.py_func(-1), test(-1))

self.assert_has_lifted(test, loopcount=1)
# loop count = 1
self.assertEqual(test.py_func(1), test(1))
self.assert_has_lifted(test, loopcount=1)

def test_define_in_loop_body(self):
# tests a definition in a loop that leaves the loop is liftable
Expand All @@ -413,7 +420,7 @@ def test(n):

# loop count = 1
self.assertEqual(test.py_func(1), test(1))

self.assert_has_lifted(test, loopcount=1)

def test_invalid_argument(self):
"""Test a problem caused by invalid discovery of loop argument
Expand Down Expand Up @@ -443,7 +450,6 @@ def test(arg):
arg = np.arange(10)
self.assertEqual(test.py_func(arg), test(arg))


def test_conditionally_defined_in_loop(self):
from numba import jit
@jit(forceobj=True)
Expand All @@ -454,7 +460,10 @@ def test():
if i > 0:
x = 6
y += x
return y, x

self.assertEqual(test.py_func(), test())
self.assert_has_lifted(test, loopcount=1)

def test_stack_offset_error_when_has_no_return(self):
from numba import jit
Expand Down Expand Up @@ -524,6 +533,5 @@ def foo(x, y):
self.assertEqual(len(lifted.signatures), 2)



if __name__ == '__main__':
unittest.main()

0 comments on commit 7f9a097

Please sign in to comment.