Skip to content

Commit

Permalink
Fixes literally not forcing re-dispatch for inline='always'
Browse files Browse the repository at this point in the history
As title.

Fixes numba#5887
  • Loading branch information
stuartarchibald committed Jun 19, 2020
1 parent 92df8df commit 5583a3e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
17 changes: 15 additions & 2 deletions numba/core/typing/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from types import MethodType, FunctionType

import numba
from numba.core import types, utils
from numba.core import types, utils, errors
from numba.core.errors import TypingError, InternalError
from numba.core.cpu_options import InlineOptions

Expand Down Expand Up @@ -517,7 +517,20 @@ def generic(self, args, kws):
inline_worker = InlineWorker(tyctx, tgctx, fcomp.locals,
compiler_inst, flags, None,)

ir = inline_worker.run_untyped_passes(disp_type.dispatcher.py_func)
# If the inlinee contains something to trigger literal arg dispatch
# then the pipeline will unconditionally fail, but the
# ForceLiteralArg instance that causes the failure will not have
# been through the path that adds `fold_arguments` to the exception,
# which means there's nothing to call if this exception propagates
# up to the dispatcher. As a result, it's best to pretend it didn't
# happen and continue to `resolve` as that will trigger a
# ForceLiteralArg exception instance that is appropriately formed
# for consumption in the dispatcher.
try:
ir = inline_worker.run_untyped_passes(
disp_type.dispatcher.py_func)
except errors.ForceLiteralArg:
pass
resolve = disp_type.dispatcher.get_call_template
template, pysig, folded_args, kws = resolve(new_args, kws)

Expand Down
44 changes: 43 additions & 1 deletion numba/tests/test_ir_inlining.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from itertools import product
import numpy as np

from numba import njit, typeof
from numba import njit, typeof, literally
from numba.core import types, ir, ir_utils, cgutils
from numba.core.extending import (
overload,
Expand Down Expand Up @@ -816,6 +816,48 @@ def impl():
for val in consts:
self.assertNotEqual(val.value, 1234)

def test_overload_inline_always_with_literally_in_inlinee(self):
# See issue #5887

def foo_ovld(dtype):

if not isinstance(dtype, types.StringLiteral):
def foo_noop(dtype):
return literally(dtype)
return foo_noop

if dtype.literal_value == 'str':
def foo_as_str_impl(dtype):
return 10
return foo_as_str_impl

if dtype.literal_value in ('int64', 'float64'):
def foo_as_num_impl(dtype):
return 20
return foo_as_num_impl

# define foo for literal str 'str'
def foo(dtype):
return 10

overload(foo, inline='always')(foo_ovld)

def test_impl(dtype):
return foo(dtype)

# check literal dispatch on 'str'
dtype = 'str'
self.check(test_impl, dtype, inline_expect={'foo': True})

# redefine foo to be correct for literal str 'int64'
def foo(dtype):
return 20
overload(foo, inline='always')(foo_ovld)

# check literal dispatch on 'int64'
dtype = 'int64'
self.check(test_impl, dtype, inline_expect={'foo': True})


class TestOverloadMethsAttrsInlining(InliningBase):
def setUp(self):
Expand Down

0 comments on commit 5583a3e

Please sign in to comment.