Skip to content

Commit

Permalink
Make literal_unroll function work as a freevar.
Browse files Browse the repository at this point in the history
As title. For context see comment on numba#5626.
  • Loading branch information
stuartarchibald committed Apr 28, 2020
1 parent 82c7d37 commit 36cc840
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
5 changes: 3 additions & 2 deletions numba/core/untyped_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def run_pass(self, state):
calls = [_ for _ in blk.find_exprs('call')]
for call in calls:
glbl = guard(get_definition, func_ir, call.func)
if glbl and isinstance(glbl, ir.Global):
if glbl and isinstance(glbl, (ir.Global, ir.FreeVar)):
# find a literal_unroll
if glbl.value is literal_unroll:
if len(call.args) > 1:
Expand Down Expand Up @@ -1285,7 +1285,8 @@ def assess_loop(self, loop, func_ir, partial_typemap=None):
return False
func_var = guard(get_definition, func_ir, call.func)
func = guard(get_definition, func_ir, func_var)
if func is None or not isinstance(func, ir.Global):
if func is None or not isinstance(func,
(ir.Global, ir.FreeVar)):
return False
if (func.value is None or
func.value not in self._accepted_calls):
Expand Down
13 changes: 13 additions & 0 deletions numba/tests/test_mixed_tuple_unroller.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,6 +1801,19 @@ def foo(cont):
['a', '25', '0.23', 'None'],
)

def test_unroll_freevar_tuple(self):
mixed = (np.ones((1,)), np.ones((1, 1)), np.ones((1, 1, 1)))
from numba import literal_unroll as freevar_unroll

@njit
def foo():
out = 0
for i in freevar_unroll(mixed):
out += i.ndim
return out

self.assertEqual(foo(), foo.py_func())


def capture(real_pass):
""" Returns a compiler pass that captures the mutation state reported
Expand Down

0 comments on commit 36cc840

Please sign in to comment.