Skip to content

Commit

Permalink
Handle cval as a np attr in stencil generation.
Browse files Browse the repository at this point in the history
As title.

Fixes numba#7286
  • Loading branch information
stuartarchibald committed Aug 6, 2021
1 parent 38ab89d commit 4398f0b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
10 changes: 8 additions & 2 deletions numba/stencils/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,11 @@ def _stencil_wrapper(self, result, sigret, return_type, typemap, calltypes, *arg
# If we have to allocate the output array (the out argument was not used)
# then us numpy.full if the user specified a cval stencil decorator option
# or np.zeros if they didn't to allocate the array.
def cval_as_str(cval):
if getattr(np, str(cval), False) is cval:
return f"np.{cval}"
else:
return str(cval)
if result is None:
return_type_name = numpy_support.as_dtype(
return_type.dtype).type.__name__
Expand All @@ -594,7 +599,8 @@ def _stencil_wrapper(self, result, sigret, return_type, typemap, calltypes, *arg
raise ValueError(
"cval type does not match stencil return type.")
out_init ="{} = np.full({}, {}, dtype=np.{})\n".format(
out_name, shape_name, cval, return_type_name)
out_name, shape_name, cval_as_str(cval),
return_type_name)
else:
out_init ="{} = np.zeros({}, dtype=np.{})\n".format(
out_name, shape_name, return_type_name)
Expand All @@ -606,7 +612,7 @@ def _stencil_wrapper(self, result, sigret, return_type, typemap, calltypes, *arg
if not self._typingctx.can_convert(cval_ty, return_type.dtype):
msg = "cval type does not match stencil return type."
raise ValueError(msg)
out_init = "{}[:] = {}\n".format(out_name, cval)
out_init = "{}[:] = {}\n".format(out_name, cval_as_str(cval))
func_text += " " + out_init

offset = 1
Expand Down
37 changes: 37 additions & 0 deletions numba/tests/test_stencils.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,34 @@ def wrapped():
else:
raise AssertionError("Expected error was not raised")

@skip_unsupported
def test_out_kwarg_w_cval_np_attr(self):
""" Tests issue #7286 where cval is an attr on the np module"""
def kernel(a):
return (a[0, 0] - a[1, 0])

stencil_fn = numba.stencil(kernel, cval=np.inf)

def wrapped():
A = np.arange(12.).reshape((3, 4))
ret = np.ones_like(A)
stencil_fn(A, out=ret)
return ret

# stencil function case
A = np.arange(12.).reshape((3, 4))
expected = np.full_like(A, -4)
expected[-1, :] = np.inf
ret = np.ones_like(A)
stencil_fn(A, out=ret)
np.testing.assert_almost_equal(ret, expected)

# wrapped function case, check njit, then njit(parallel=True)
impls = self.compile_all(wrapped,)
for impl in impls:
got = impl.entry_point()
np.testing.assert_almost_equal(got, expected)


class pyStencilGenerator:
"""
Expand Down Expand Up @@ -2795,5 +2823,14 @@ def kernel(a):
a = np.arange(20, dtype=np.uint32).reshape(4, 5)
self.check(kernel, a)

def test_basic98(self):
""" Test issue #7286 where the cval is a np attr"""
def kernel(a):
return a[0, 0]
a = np.arange(6.).reshape((2, 3))
self.check(kernel, a, options={'neighborhood': ((-1, 1), (-1, 1),),
'cval':np.nan})


if __name__ == "__main__":
unittest.main()

0 comments on commit 4398f0b

Please sign in to comment.