Commit

Merge pull request numba#3199 from ehsantn/stencil_index_const
Support inferring stencil index as constant in simple unary expressions
stuartarchibald committed Aug 22, 2018
2 parents 0fb2c6b + 1268eaf commit 252957c
Showing 5 changed files with 194 additions and 21 deletions.
40 changes: 25 additions & 15 deletions docs/source/user/stencil.rst
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
Using the ``@stencil`` decorator
================================

Stencils are a common computational pattern in which array elements
are updated according to some fixed pattern called the stencil kernel.
Numba provides the ``@stencil`` decorator so that users may
easily specify a stencil kernel and Numba then generates the looping
@@ -21,7 +21,7 @@ Basic usage
===========

An example use of the ``@stencil`` decorator::

from numba import stencil

@stencil
@@ -38,8 +38,8 @@ Conceptually, the stencil kernel is run once for each element in the
output array. The return value from the stencil kernel is the value
written into the output array for that particular element.

The parameter ``a`` represents the input array over which the
kernel is applied.
Indexing into this array takes place with respect to the current element
of the output array being processed. For example, if element ``(x, y)``
is being processed then ``a[0, 0]`` in the stencil kernel corresponds to
@@ -48,9 +48,9 @@ kernel corresponds to ``a[x - 1, y + 1]`` in the input array.

Depending on the specified kernel, the kernel may not be applicable to the
borders of the output array as this may cause the input array to be
accessed out-of-bounds. The way in which the stencil decorator handles
this situation is dependent upon which :ref:`stencil-mode` is selected.
The default mode is for the stencil decorator to set the border elements
of the output array to zero.
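
The relative indexing and the default border behaviour described above can be illustrated without Numba. The following is a minimal pure-Python sketch for the one-dimensional case only; ``RelativeView`` and ``apply_stencil_1d`` are hypothetical names invented for this illustration and are not part of Numba's API, and the loop is not how Numba generates code.

```python
class RelativeView:
    """Read-only view whose index 0 is the element currently being processed."""

    def __init__(self, data, center):
        self._data = data
        self._center = center

    def __getitem__(self, offset):
        # a[offset] in the kernel means data[center + offset] in the input.
        return self._data[self._center + offset]


def apply_stencil_1d(kernel, data, lo, hi, cval=0.0):
    """Run `kernel` once per interior element; border elements keep `cval`.

    `lo` and `hi` are the smallest and largest relative offsets the kernel
    reads, for example -1 and 1 for a three-point kernel.
    """
    n = len(data)
    out = [cval] * n
    # Visit only positions where every offset in [lo, hi] stays in bounds.
    for i in range(max(0, -lo), n - max(0, hi)):
        out[i] = kernel(RelativeView(data, i))
    return out


def three_point(a):
    # Hypothetical kernel: mean of the left neighbour, the element, and the right neighbour.
    return (a[-1] + a[0] + a[1]) / 3


result = apply_stencil_1d(three_point, [0.0, 3.0, 6.0, 9.0, 12.0], lo=-1, hi=1)
print(result)  # [0.0, 3.0, 6.0, 9.0, 0.0]
```

The first and last output elements stay at zero because the kernel would read outside the input there.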

To invoke a stencil on an input array, call the stencil as if it were
@@ -105,13 +105,13 @@ all such input array arguments.
Kernel shape inference and border handling
==========================================

In the above example and in most cases, the array indexing in the
stencil kernel will exclusively use ``Integer`` literals.
In such cases, the stencil decorator is able to analyze the stencil
kernel to determine its size. In the above example, the stencil
decorator determines that the kernel is ``3 x 3`` in shape since indices
``-1`` to ``1`` are used for both the first and second dimensions. Note that
the stencil decorator also correctly handles non-symmetric and
non-square stencil kernels.

Based on the size of the stencil kernel, the stencil decorator is
@@ -122,11 +122,21 @@ of the output array. In the above example, points ``-1`` and ``+1`` are
accessed in each dimension and thus the output array has a border
of size one in all dimensions.
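
The shape inference described above amounts to taking, per dimension, the minimum and maximum literal offset used in the kernel. The sketch below assumes the bounds start at ``(0, 0)`` so that the border never becomes negative; ``infer_neighborhood``, ``kernel_shape``, and the ``accesses`` list are hypothetical and written only for this illustration, not taken from Numba.

```python
def infer_neighborhood(index_tuples):
    """Return [(min_offset, max_offset), ...] with one pair per dimension."""
    ndim = len(index_tuples[0])
    # Assumption: start from (0, 0) so the current element is always covered.
    bounds = [[0, 0] for _ in range(ndim)]
    for idx in index_tuples:
        for dim, offset in enumerate(idx):
            bounds[dim][0] = min(bounds[dim][0], offset)
            bounds[dim][1] = max(bounds[dim][1], offset)
    return [tuple(b) for b in bounds]


def kernel_shape(neighborhood):
    """Extent of the kernel in each dimension."""
    return tuple(hi - lo + 1 for lo, hi in neighborhood)


# Hypothetical set of relative accesses made by a 2-D kernel.
accesses = [(0, 1), (1, 0), (0, -1), (-1, 0)]
neighborhood = infer_neighborhood(accesses)
print(neighborhood)                # [(-1, 1), (-1, 1)]
print(kernel_shape(neighborhood))  # (3, 3)

# A non-symmetric 1-D kernel that only looks backwards.
print(infer_neighborhood([(-2,), (-1,), (0,)]))  # [(-2, 0)]
```

A border of width ``-min_offset`` at the start and ``max_offset`` at the end of each dimension then follows directly from these pairs.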

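In parallel mode (``parallel=True``), Numba can also infer a kernel index as
a constant when it is a simple unary or binary expression of constants, such
as ``-c + 1`` where ``c`` is a constant in the enclosing function. Outside
parallel mode such an index raises an error unless the ``neighborhood``
option is given. For example::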

    from numba import njit, stencil

    @njit(parallel=True)
    def stencil_test(A):
        c = 2
        B = stencil(
            lambda a, c: 0.3 * (a[-c+1] + a[0] + a[c-1]))(A, c)
        return B


Stencil decorator options
=========================

While the stencil decorator may be augmented in the future to
provide additional mechanisms for border handling, it currently
supports only one option.

@@ -138,7 +148,7 @@ the stencil decorator currently supports only one option.
Sometimes it may be inconvenient to write the stencil kernel
exclusively with ``Integer`` literals. For example, let us say we
would like to compute the trailing 30-day moving average of a
time series of data. One could write
``(a[-29] + a[-28] + ... + a[-1] + a[0]) / 30`` but the stencil
decorator offers a more concise form using the ``neighborhood``
option::
@@ -176,7 +186,7 @@ to a constant value, as specified by the ``cval`` parameter.

The optional cval parameter defaults to zero but can be set to any
desired value, which is then used for the border of the output array
if the mode parameter is set to ``constant``. The cval parameter is
ignored in all other modes. The type of the cval parameter must match
the return type of the stencil kernel. If the user wishes the output
array to be constructed from a particular type then they should ensure
@@ -206,7 +216,7 @@ The stencil decorator returns a callable object of type ``StencilFunc``.
``StencilFunc`` objects contain a number of attributes but the only one of
potential interest to users is the ``neighborhood`` attribute.
If the ``neighborhood`` option was passed to the stencil decorator then
the provided neighborhood is stored in this attribute. Otherwise, upon
first execution or compilation, the system calculates the neighborhood
as described above and then stores the computed neighborhood into this
attribute. A user may then inspect the attribute if they wish to verify
@@ -226,8 +236,8 @@ also include the following optional parameter.
-------

The optional ``out`` parameter is added to every stencil function
generated by Numba. If specified, the ``out`` parameter tells
Numba that the user is providing their own pre-allocated array
to be used for the output of the stencil. In this case, the
stencil function will not allocate its own output array.
Users should ensure that the return type of the stencil kernel can
7 changes: 7 additions & 0 deletions numba/ir_utils.py
@@ -1569,6 +1569,13 @@ def __init__(self, f_ir):
self.calltypes = None
rewrites.rewrite_registry.apply('before-inference',
DummyPipeline(ir), ir)
# call inline pass to handle cases like stencils and comprehensions
inline_pass = numba.inline_closurecall.InlineClosureCallPass(
ir, numba.targets.cpu.ParallelOptions(False))
inline_pass.run()
from numba import postproc
post_proc = postproc.PostProcessor(ir)
post_proc.run()
return ir

def replace_arg_nodes(block, args):
7 changes: 4 additions & 3 deletions numba/stencil.py
@@ -182,8 +182,8 @@ def add_indices_to_kernel(self, kernel, index_names, ndim,
elif stmt_index_var.name in const_dict:
kernel_consts += [const_dict[stmt_index_var.name]]
else:
- raise ValueError("Non-constant specified for "
- "stencil kernel index.")
+ raise ValueError("stencil kernel index is not "
+ "constant, 'neighborhood' option required")

if ndim == 1:
# Single dimension always has index variable 'index0'.
@@ -261,7 +261,8 @@ def add_indices_to_kernel(self, kernel, index_names, ndim,
neighborhood[i][1] = max(neighborhood[i][1], te)
else:
raise ValueError(
- "Non-constant used as stencil index.")
+ "stencil kernel index is not constant, "
+ "'neighborhood' option required")
index_len = len(index)
elif isinstance(index, int):
neighborhood[0][0] = min(neighborhood[0][0], index)
80 changes: 77 additions & 3 deletions numba/stencilparfor.py
@@ -17,7 +17,7 @@
from numba import ir_utils, ir, utils, config, typing
from numba.ir_utils import (get_call_table, mk_unique_var,
compile_to_numba_ir, replace_arg_nodes, guard,
- find_callname)
+ find_callname, require, find_const, GuardException)
from numba.six import exec_


@@ -166,7 +166,7 @@ def _mk_stencil_parfor(self, label, in_args, out_arr, stencil_ir,
parfor_vars.append(parfor_var)

start_lengths, end_lengths = self._replace_stencil_accesses(
- stencil_blocks, parfor_vars, in_args, index_offsets, stencil_func,
+ stencil_ir, parfor_vars, in_args, index_offsets, stencil_func,
arg_to_arr_dict)

if config.DEBUG_ARRAY_OPT == 1:
@@ -350,12 +350,13 @@ def get_start_ind(s_length):
ret_var = block.body[-2].value.value
return ret_var

- def _replace_stencil_accesses(self, stencil_blocks, parfor_vars, in_args,
+ def _replace_stencil_accesses(self, stencil_ir, parfor_vars, in_args,
index_offsets, stencil_func, arg_to_arr_dict):
""" Convert relative indexing in the stencil kernel to standard indexing
by adding the loop index variables to the corresponding dimensions
of the array index tuples.
"""
stencil_blocks = stencil_ir.blocks
in_arr = in_args[0]
in_arg_names = [x.name for x in in_args]

@@ -426,6 +427,12 @@ def _replace_stencil_accesses(self, stencil_blocks, parfor_vars, in_args,
else:
if hasattr(index_list, 'name') and index_list.name in tuple_table:
index_list = tuple_table[index_list.name]
# indices can be inferred as constant in simple expressions
# like -c where c is constant
# handled here since this is a common stencil index pattern
stencil_ir._definitions = ir_utils.build_definitions(stencil_blocks)
index_list = [_get_const_index_expr(
stencil_ir, self.func_ir, v) for v in index_list]
if index_offsets:
index_list = self._add_index_offsets(index_list,
list(index_offsets), new_body, scope, loc)
@@ -674,3 +681,70 @@ def __init__(self, typingctx, targetctx, args, f_ir):
self.typemap = None
self.return_type = None
self.calltypes = None


def _get_const_index_expr(stencil_ir, func_ir, index_var):
    """
    infer index_var as constant if it is of an expression form like c-1 where c
    is a constant in the outer function.
    index_var is assumed to be inside stencil kernel
    """
    const_val = guard(
        _get_const_index_expr_inner, stencil_ir, func_ir, index_var)
    if const_val is not None:
        return const_val
    return index_var

def _get_const_index_expr_inner(stencil_ir, func_ir, index_var):
    """inner constant inference function that calls constant, unary and binary
    cases.
    """
    require(isinstance(index_var, ir.Var))
    # case where the index is a const itself in outer function
    var_const = guard(_get_const_two_irs, stencil_ir, func_ir, index_var)
    if var_const is not None:
        return var_const
    # get index definition
    index_def = ir_utils.get_definition(stencil_ir, index_var)
    # match inner_var = unary(index_var)
    var_const = guard(
        _get_const_unary_expr, stencil_ir, func_ir, index_def)
    if var_const is not None:
        return var_const
    # match inner_var = arg1 + arg2
    var_const = guard(
        _get_const_binary_expr, stencil_ir, func_ir, index_def)
    if var_const is not None:
        return var_const
    raise GuardException

def _get_const_two_irs(ir1, ir2, var):
    """get constant in either of two IRs if available
    otherwise, throw GuardException
    """
    var_const = guard(find_const, ir1, var)
    if var_const is not None:
        return var_const
    var_const = guard(find_const, ir2, var)
    if var_const is not None:
        return var_const
    raise GuardException
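
The helpers above lean on a "try a pattern, fall through on failure" idiom built from ``guard``, ``require``, and ``GuardException``. The sketch below shows that idiom in isolation. The three definitions are local stand-ins written for this illustration and are only assumed to behave like the ``numba.ir_utils`` helpers of the same name; ``lookup_const`` and ``lookup_in_two_tables`` are hypothetical and use plain dictionaries in place of Numba IR.

```python
class GuardException(Exception):
    """Signals that a pattern did not match."""


def require(cond):
    # Abort the current match attempt when a precondition fails.
    if not cond:
        raise GuardException


def guard(func, *args, **kwargs):
    # Run a matcher; translate a failed match into None.
    try:
        return func(*args, **kwargs)
    except GuardException:
        return None


def lookup_const(table, name):
    require(name in table)
    return table[name]


def lookup_in_two_tables(first, second, name):
    # Same shape as a two-source lookup: try one source, then the other.
    value = guard(lookup_const, first, name)
    if value is not None:
        return value
    value = guard(lookup_const, second, name)
    if value is not None:
        return value
    raise GuardException


inner = {"k": 3, "zero": 0}
outer = {"c": 2}
found = guard(lookup_in_two_tables, inner, outer, "c")
falsy = guard(lookup_in_two_tables, inner, outer, "zero")
missing = guard(lookup_in_two_tables, inner, outer, "unknown")
print(found, falsy, missing)  # 2 0 None
```

Comparing against ``None`` rather than testing truthiness matters here: a constant index of ``0`` is a valid result and must not be treated as a failed match.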

def _get_const_unary_expr(stencil_ir, func_ir, index_def):
    """evaluate constant unary expr if possible
    otherwise, raise GuardException
    """
    require(isinstance(index_def, ir.Expr) and index_def.op == 'unary')
    inner_var = index_def.value
    # return -c as constant
    const_val = _get_const_index_expr_inner(stencil_ir, func_ir, inner_var)
    return eval("{}{}".format(index_def.fn, const_val))

def _get_const_binary_expr(stencil_ir, func_ir, index_def):
    """evaluate constant binary expr if possible
    otherwise, raise GuardException
    """
    require(isinstance(index_def, ir.Expr) and index_def.op == 'binop')
    arg1 = _get_const_index_expr_inner(stencil_ir, func_ir, index_def.lhs)
    arg2 = _get_const_index_expr_inner(stencil_ir, func_ir, index_def.rhs)
    return eval("{}{}{}".format(arg1, index_def.fn, arg2))
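
The recursive folding performed by these helpers can be mimicked on ordinary Python source. The sketch below is not the commit's implementation: it walks a Python ``ast`` tree instead of Numba IR, and it applies operators through a lookup table from the ``operator`` module rather than through ``eval``. ``fold_index`` and ``_fold`` are hypothetical names, and the set of supported operators is an arbitrary choice for the illustration.

```python
import ast
import operator

# Operators the sketch knows how to fold; anything else is "not constant".
_UNARY = {ast.USub: operator.neg, ast.UAdd: operator.pos}
_BINARY = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul}


def _fold(node, consts):
    if isinstance(node, ast.Constant) and isinstance(node.value, int):
        return node.value
    if isinstance(node, ast.Name):
        return consts[node.id]  # KeyError (a LookupError) if not a known constant
    if isinstance(node, ast.UnaryOp):
        return _UNARY[type(node.op)](_fold(node.operand, consts))
    if isinstance(node, ast.BinOp):
        left = _fold(node.left, consts)
        right = _fold(node.right, consts)
        return _BINARY[type(node.op)](left, right)
    raise LookupError("unsupported expression")


def fold_index(expr, consts):
    """Return the integer value of `expr`, or None if it is not constant."""
    try:
        return _fold(ast.parse(expr, mode="eval").body, consts)
    except LookupError:
        return None


print(fold_index("-c+1", {"c": 2}))          # -1
print(fold_index("c-1", {"c": 2}))           # 1
print(fold_index("-c+d", {"c": 2, "d": 1}))  # -1
print(fold_index("i+1", {"c": 2}))           # None
```

Returning ``None`` for anything outside the supported forms mirrors the fallback in ``_get_const_index_expr``, which leaves the original index variable untouched when no constant can be inferred.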
81 changes: 81 additions & 0 deletions numba/tests/test_stencils.py
@@ -427,6 +427,87 @@ def test_impl_seq(n):
        n = 100
        self.check(test_impl_seq, test_impl, n)

    @skip_unsupported
    @tag('important')
    def test_stencil_call_const(self):
        """Tests numba.stencil call that has an index that can be inferred as
        constant from a unary expr. Otherwise, this would raise an error since
        neighborhood length is not specified.
        """
        def test_impl1(n):
            A = np.arange(n)
            B = np.zeros(n)
            c = 1
            numba.stencil(lambda a,c : 0.3 * (a[-c] + a[0] + a[c]))(
                A, c, out=B)
            return B

        def test_impl2(n):
            A = np.arange(n)
            B = np.zeros(n)
            c = 2
            numba.stencil(lambda a,c : 0.3 * (a[1-c] + a[0] + a[c-1]))(
                A, c, out=B)
            return B

        # recursive expr case
        def test_impl3(n):
            A = np.arange(n)
            B = np.zeros(n)
            c = 2
            numba.stencil(lambda a,c : 0.3 * (a[-c+1] + a[0] + a[c-1]))(
                A, c, out=B)
            return B

        # multi-constant case
        def test_impl4(n):
            A = np.arange(n)
            B = np.zeros(n)
            d = 1
            c = 2
            numba.stencil(lambda a,c,d : 0.3 * (a[-c+d] + a[0] + a[c-d]))(
                A, c, d, out=B)
            return B

        def test_impl_seq(n):
            A = np.arange(n)
            B = np.zeros(n)
            c = 1
            for i in range(1, n - 1):
                B[i] = 0.3 * (A[i - c] + A[i] + A[i + c])
            return B

        n = 100
        # constant inference is only possible in parallel path
        cpfunc1 = self.compile_parallel(test_impl1, (types.intp,))
        cpfunc2 = self.compile_parallel(test_impl2, (types.intp,))
        cpfunc3 = self.compile_parallel(test_impl3, (types.intp,))
        cpfunc4 = self.compile_parallel(test_impl4, (types.intp,))
        expected = test_impl_seq(n)
        # parfor result
        parfor_output1 = cpfunc1.entry_point(n)
        parfor_output2 = cpfunc2.entry_point(n)
        parfor_output3 = cpfunc3.entry_point(n)
        parfor_output4 = cpfunc4.entry_point(n)
        np.testing.assert_almost_equal(parfor_output1, expected, decimal=3)
        np.testing.assert_almost_equal(parfor_output2, expected, decimal=3)
        np.testing.assert_almost_equal(parfor_output3, expected, decimal=3)
        np.testing.assert_almost_equal(parfor_output4, expected, decimal=3)

        # check error in regular Python path
        with self.assertRaises(ValueError) as e:
            test_impl4(4)

        self.assertIn("stencil kernel index is not constant, "
                      "'neighborhood' option required", str(e.exception))
        # check error in njit path
        # TODO: ValueError should be thrown instead of LoweringError
        with self.assertRaises(LoweringError) as e:
            njit(test_impl4)(4)

        self.assertIn("stencil kernel index is not constant, "
                      "'neighborhood' option required", str(e.exception))
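
All four test kernels are compared against the same sequential result because each index expression reduces to the offsets ``(-1, 0, 1)`` for the constants chosen in the test. The sketch below checks that arithmetic in plain Python and restates the sequential reference with lists in place of NumPy arrays; the ``offsets_impl*`` and ``reference`` functions are hypothetical helpers written for this illustration and are not part of the test suite.

```python
def offsets_impl1(c):
    return (-c, 0, c)


def offsets_impl2(c):
    return (1 - c, 0, c - 1)


def offsets_impl3(c):
    return (-c + 1, 0, c - 1)


def offsets_impl4(c, d):
    return (-c + d, 0, c - d)


# Constants as set in test_impl1 through test_impl4.
all_offsets = [
    offsets_impl1(1),
    offsets_impl2(2),
    offsets_impl3(2),
    offsets_impl4(2, 1),
]
print(all_offsets)


def reference(n):
    """List-based restatement of the sequential three-point computation."""
    a = list(range(n))
    b = [0.0] * n
    for i in range(1, n - 1):
        b[i] = 0.3 * (a[i - 1] + a[i] + a[i + 1])
    return b


print(reference(5))
```

Every entry of ``all_offsets`` is the same triple, so a single expected array suffices for the four parallel variants.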

    @skip_unsupported
    @tag('important')
    def test_stencil_parallel_off(self):