Skip to content

Commit

Permalink
explicit declaration of ref type
Browse files Browse the repository at this point in the history
  • Loading branch information
stuartarchibald committed Mar 8, 2019
1 parent 2f40dd9 commit 83ae051
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 22 deletions.
13 changes: 7 additions & 6 deletions numba/targets/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
lower_setattr_generic,
lower_cast, lower_constant,
iternext_impl, impl_ret_borrowed,
impl_ret_new_ref, impl_ret_untracked)
impl_ret_new_ref, impl_ret_untracked,
RefType)
from numba.typing import signature
from numba.extending import register_jitable, overload, overload_method
from . import quicksort, mergesort, slicing
Expand Down Expand Up @@ -277,7 +278,7 @@ def _getitem_array1d(context, builder, arrayty, array, idx, wraparound):
return load_item(context, builder, arrayty, ptr)

@lower_builtin('iternext', types.ArrayIterator)
@iternext_impl()
@iternext_impl(RefType.BORROWED)
def iternext_array(context, builder, sig, args, result):
[iterty] = sig.args
[iter] = args
Expand Down Expand Up @@ -2976,7 +2977,7 @@ def make_array_flatiter(context, builder, arrty, arr):


@lower_builtin('iternext', types.NumpyFlatType)
@iternext_impl()
@iternext_impl(RefType.BORROWED)
def iternext_numpy_flatiter(context, builder, sig, args, result):
[flatiterty] = sig.args
[flatiter] = args
Expand Down Expand Up @@ -3054,7 +3055,7 @@ def make_array_ndenumerate(context, builder, sig, args):


@lower_builtin('iternext', types.NumpyNdEnumerateType)
@iternext_impl()
@iternext_impl(RefType.BORROWED)
def iternext_numpy_nditer(context, builder, sig, args, result):
[nditerty] = sig.args
[nditer] = args
Expand Down Expand Up @@ -3106,7 +3107,7 @@ def make_array_ndindex(context, builder, sig, args):
return impl_ret_borrowed(context, builder, sig.return_type, res)

@lower_builtin('iternext', types.NumpyNdIndexType)
@iternext_impl()
@iternext_impl(RefType.BORROWED)
def iternext_numpy_ndindex(context, builder, sig, args, result):
[nditerty] = sig.args
[nditer] = args
Expand Down Expand Up @@ -3137,7 +3138,7 @@ def make_array_nditer(context, builder, sig, args):
return impl_ret_borrowed(context, builder, nditerty, res)

@lower_builtin('iternext', types.NumpyNdIterType)
@iternext_impl()
@iternext_impl(RefType.BORROWED)
def iternext_numpy_ndindex(context, builder, sig, args, result):
[nditerty] = sig.args
[nditer] = args
Expand Down
32 changes: 28 additions & 4 deletions numba/targets/imputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import contextlib
import inspect
import functools
from enum import Enum

from .. import typing, cgutils, types, utils
from .. typing.templates import BaseRegistryLoader
Expand Down Expand Up @@ -234,7 +235,7 @@ def wrapper(cls):
# These are unbound methods
iternext = cls.iternext

@iternext_impl()
@iternext_impl(RefType.BORROWED)
def iternext_wrapper(context, builder, sig, args, result):
(value,) = args
iterobj = cls(context, builder, value)
Expand Down Expand Up @@ -293,8 +294,24 @@ def yielded_value(self):
"""
return self._pairobj.first

class RefType(Enum):
"""
Enumerate the reference type
"""
"""
A new reference
"""
NEW = 1
"""
A borrowed reference
"""
BORROWED = 2
"""
An untracked reference
"""
UNTRACKED = 3

def iternext_impl(new_ref=False):
def iternext_impl(ref_type=None):
"""
Wrap the given iternext() implementation so that it gets passed
an _IternextResult() object easing the returning of the iternext()
Expand All @@ -305,16 +322,23 @@ def iternext_impl(new_ref=False):
The wrapped function will be called with the following signature:
(context, builder, sig, args, iternext_result)
"""
if ref_type is None:
raise ValueError("ref_type must be an enum member of imputils.RefType")

def outer(func):
def wrapper(context, builder, sig, args):
pair_type = sig.return_type
pairobj = context.make_helper(builder, pair_type)
func(context, builder, sig, args,
_IternextResult(context, builder, pairobj))
if new_ref:
if ref_type == RefType.NEW:
impl_ret = impl_ret_new_ref
else:
elif ref_type == RefType.BORROWED:
impl_ret = impl_ret_borrowed
elif ref_type == RefType.UNTRACKED:
impl_ret = impl_ret_untracked
else:
raise ValueError("Unknown ref_type encountered")
return impl_ret(context, builder,
pair_type, pairobj._getvalue())
return wrapper
Expand Down
8 changes: 4 additions & 4 deletions numba/targets/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numba import types, cgutils
from numba.targets.imputils import (
lower_builtin, iternext_impl, call_iternext, call_getiter,
impl_ret_borrowed, impl_ret_new_ref)
impl_ret_borrowed, impl_ret_new_ref, RefType)



Expand Down Expand Up @@ -44,7 +44,7 @@ def make_enumerate_object(context, builder, sig, args):
return impl_ret_new_ref(context, builder, sig.return_type, res)

@lower_builtin('iternext', types.EnumerateType)
@iternext_impl()
@iternext_impl(RefType.BORROWED)
def iternext_enumerate(context, builder, sig, args, result):
[enumty] = sig.args
[enum] = args
Expand Down Expand Up @@ -87,7 +87,7 @@ def make_zip_object(context, builder, sig, args):
return impl_ret_new_ref(context, builder, sig.return_type, res)

@lower_builtin('iternext', types.ZipType)
@iternext_impl()
@iternext_impl(RefType.BORROWED)
def iternext_zip(context, builder, sig, args, result):
[zip_type] = sig.args
[zipobj] = args
Expand Down Expand Up @@ -125,7 +125,7 @@ def iternext_zip(context, builder, sig, args, result):
# generator implementation

@lower_builtin('iternext', types.Generator)
@iternext_impl()
@iternext_impl(RefType.BORROWED)
def iternext_zip(context, builder, sig, args, result):
genty, = sig.args
gen, = args
Expand Down
5 changes: 3 additions & 2 deletions numba/targets/listobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from numba import types, cgutils, typing
from numba.targets.imputils import (lower_builtin, lower_cast,
iternext_impl, impl_ret_borrowed,
impl_ret_new_ref, impl_ret_untracked)
impl_ret_new_ref, impl_ret_untracked,
RefType)
from numba.utils import cached_property
from . import quicksort, slicing

Expand Down Expand Up @@ -487,7 +488,7 @@ def getiter_list(context, builder, sig, args):
return impl_ret_borrowed(context, builder, sig.return_type, inst.value)

@lower_builtin('iternext', types.ListIter)
@iternext_impl()
@iternext_impl(RefType.BORROWED)
def iternext_listiter(context, builder, sig, args, result):
inst = ListIterInstance(context, builder, sig.args[0], args[0])

Expand Down
4 changes: 2 additions & 2 deletions numba/targets/setobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from numba.targets.imputils import (lower_builtin, lower_cast,
iternext_impl, impl_ret_borrowed,
impl_ret_new_ref, impl_ret_untracked,
for_iter, call_len)
for_iter, call_len, RefType)
from numba.utils import cached_property
from . import quicksort, slicing

Expand Down Expand Up @@ -1215,7 +1215,7 @@ def getiter_set(context, builder, sig, args):
return impl_ret_borrowed(context, builder, sig.return_type, inst.value)

@lower_builtin('iternext', types.SetIter)
@iternext_impl()
@iternext_impl(RefType.BORROWED)
def iternext_listiter(context, builder, sig, args, result):
inst = SetIterInstance(context, builder, sig.args[0], args[0])
inst.iternext(result)
Expand Down
5 changes: 3 additions & 2 deletions numba/targets/tupleobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from .imputils import (lower_builtin, lower_getattr_generic, lower_cast,
lower_constant,
iternext_impl, impl_ret_borrowed, impl_ret_untracked)
iternext_impl, impl_ret_borrowed, impl_ret_untracked,
RefType)
from .. import typing, types, cgutils
from ..extending import overload_method

Expand Down Expand Up @@ -148,7 +149,7 @@ def getiter_unituple(context, builder, sig, args):


@lower_builtin('iternext', types.UniTupleIter)
@iternext_impl()
@iternext_impl(RefType.BORROWED)
def iternext_unituple(context, builder, sig, args, result):
[tupiterty] = sig.args
[tupiter] = args
Expand Down
4 changes: 2 additions & 2 deletions numba/unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
intrinsic,
)
from numba.targets.imputils import (lower_constant, lower_cast, lower_builtin,
iternext_impl, impl_ret_new_ref)
iternext_impl, impl_ret_new_ref, RefType)
from numba.datamodel import register_default, StructModel
from numba import cgutils
from numba import types
Expand Down Expand Up @@ -771,7 +771,7 @@ def getiter_unicode(context, builder, sig, args):

@lower_builtin('iternext', types.UnicodeIteratorType)
# a new ref counted object is put into result._yield so set the new_ref to True!
@iternext_impl(new_ref=True)
@iternext_impl(RefType.NEW)
def iternext_unicode(context, builder, sig, args, result):
[iterty] = sig.args
[iter] = args
Expand Down

0 comments on commit 83ae051

Please sign in to comment.