Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix decorators warnings stacklevel #6183

Merged
Merged
25 changes: 25 additions & 0 deletions skimage/_shared/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import warnings
import inspect

import numpy as np
import pytest
Expand All @@ -10,6 +11,8 @@
change_default_value, remove_arg,
_supported_float_type,
channel_as_last_axis)
from skimage.feature import hog
from skimage.transform import pyramid_gaussian

complex_dtypes = [np.complex64, np.complex128]
if hasattr(np, 'complex256'):
Expand Down Expand Up @@ -275,3 +278,25 @@ def test_decorated_channel_axis_shape(channel_axis):
assert size is None
else:
assert size == x.shape[channel_axis]


def test_decorator_warnings():
"""Assert that warning message issued by decorator points to
expected file and line number.
"""

with pytest.warns(FutureWarning) as record:
pyramid_gaussian(None, multichannel=True)
expected_lineno = inspect.currentframe().f_lineno - 1

assert record[0].lineno == expected_lineno
assert record[0].filename == __file__

img = np.random.rand(100, 100, 3)

with pytest.warns(FutureWarning) as record:
hog(img, multichannel=True)
expected_lineno = inspect.currentframe().f_lineno - 1

assert record[0].lineno == expected_lineno
assert record[0].filename == __file__
78 changes: 67 additions & 11 deletions skimage/_shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,43 @@ class skimage_deprecation(Warning):
pass


class change_default_value:
def _get_stack_rank(func):
"""Return function rank in the call stack."""
if _is_wrapped(func):
return 1 + _get_stack_rank(func.__wrapped__)
else:
return 0


def _is_wrapped(func):
return "__wrapped__" in dir(func)


rfezzani marked this conversation as resolved.
Show resolved Hide resolved
def _get_stack_length(func):
"""Return function call stack length."""
return _get_stack_rank(func.__globals__.get(func.__name__, func))


class _DecoratorBaseClass:
"""Used to manage decorators' warnings stacklevel.
The `_stack_length` class variable is used to store the number of
times a function is wrapped by a decorator.
Let `stack_length` be the total number of times a decorated
function is wrapped, and `stack_rank` be the rank of the decorator
in the decorators stack. The stacklevel of a warning is then
`stacklevel = 1 + stack_length - stack_rank`.
"""

_stack_length = {}

def get_stack_length(self, func):
return self._stack_length.get(func.__name__,
_get_stack_length(func))


class change_default_value(_DecoratorBaseClass):
"""Decorator for changing the default value of an argument.
Parameters
Expand Down Expand Up @@ -53,6 +89,8 @@ def __call__(self, func):
arg_idx = list(parameters.keys()).index(self.arg_name)
old_value = parameters[self.arg_name].default

stack_rank = _get_stack_rank(func)

if self.warning_msg is None:
self.warning_msg = (
f'The new recommended value for {self.arg_name} is '
Expand All @@ -64,15 +102,17 @@ def __call__(self, func):

@functools.wraps(func)
def fixed_func(*args, **kwargs):
stacklevel = 1 + self.get_stack_length(func) - stack_rank
if len(args) < arg_idx + 1 and self.arg_name not in kwargs.keys():
# warn that arg_name default value changed:
warnings.warn(self.warning_msg, FutureWarning, stacklevel=2)
warnings.warn(self.warning_msg, FutureWarning,
stacklevel=stacklevel)
return func(*args, **kwargs)

return fixed_func


class remove_arg:
class remove_arg(_DecoratorBaseClass):
"""Decorator to remove an argument from function's signature.
Parameters
Expand All @@ -93,6 +133,7 @@ def __init__(self, arg_name, *, changed_version, help_msg=None):
self.changed_version = changed_version

def __call__(self, func):

parameters = inspect.signature(func).parameters
arg_idx = list(parameters.keys()).index(self.arg_name)
warning_msg = (
Expand All @@ -104,11 +145,15 @@ def __call__(self, func):
if self.help_msg is not None:
warning_msg += f' {self.help_msg}'

stack_rank = _get_stack_rank(func)

@functools.wraps(func)
def fixed_func(*args, **kwargs):
stacklevel = 1 + self.get_stack_length(func) - stack_rank
if len(args) > arg_idx or self.arg_name in kwargs.keys():
# warn that arg_name is deprecated
warnings.warn(warning_msg, FutureWarning, stacklevel=2)
warnings.warn(warning_msg, FutureWarning,
stacklevel=stacklevel)
return func(*args, **kwargs)

return fixed_func
Expand Down Expand Up @@ -180,7 +225,7 @@ def docstring_add_deprecated(func, kwarg_mapping, deprecated_version):
return final_docstring


class deprecate_kwarg:
class deprecate_kwarg(_DecoratorBaseClass):
"""Decorator ensuring backward compatibility when argument names are
modified in a function definition.
Expand Down Expand Up @@ -208,7 +253,7 @@ def __init__(self, kwarg_mapping, deprecated_version, warning_msg=None,
"for `{func_name}`. ")
if removed_version is not None:
self.warning_msg += (f'It will be removed in '
f'version {removed_version}.')
f'version {removed_version}. ')
self.warning_msg += "Please use `{new_arg}` instead."
else:
self.warning_msg = warning_msg
Expand All @@ -217,14 +262,19 @@ def __init__(self, kwarg_mapping, deprecated_version, warning_msg=None,

def __call__(self, func):

stack_rank = _get_stack_rank(func)

@functools.wraps(func)
def fixed_func(*args, **kwargs):
stacklevel = 1 + self.get_stack_length(func) - stack_rank

for old_arg, new_arg in self.kwarg_mapping.items():
if old_arg in kwargs:
# warn that the function interface has changed:
warnings.warn(self.warning_msg.format(
old_arg=old_arg, func_name=func.__name__,
new_arg=new_arg), FutureWarning, stacklevel=2)
new_arg=new_arg), FutureWarning,
stacklevel=stacklevel)
# Substitute new_arg to old_arg
kwargs[new_arg] = kwargs.pop(old_arg)

Expand Down Expand Up @@ -258,8 +308,12 @@ def __init__(self, removed_version='1.0', multichannel_position=None):
self.position = multichannel_position

def __call__(self, func):

stack_rank = _get_stack_rank(func)

@functools.wraps(func)
def fixed_func(*args, **kwargs):
stacklevel = 1 + self.get_stack_length(func) - stack_rank

if self.position is not None and len(args) > self.position:
warning_msg = (
Expand All @@ -269,7 +323,7 @@ def fixed_func(*args, **kwargs):
)
warnings.warn(warning_msg.format(func_name=func.__name__),
FutureWarning,
stacklevel=2)
stacklevel=stacklevel)
if 'channel_axis' in kwargs:
raise ValueError(
"Cannot provide both a `channel_axis` kwarg and a "
Expand All @@ -283,7 +337,8 @@ def fixed_func(*args, **kwargs):
# warn that the function interface has changed:
warnings.warn(self.warning_msg.format(
old_arg='multichannel', func_name=func.__name__,
new_arg='channel_axis'), FutureWarning, stacklevel=2)
new_arg='channel_axis'), FutureWarning,
stacklevel=stacklevel)

# multichannel = True -> last axis corresponds to channels
convert = {True: -1, False: None}
Expand All @@ -299,7 +354,7 @@ def fixed_func(*args, **kwargs):
return fixed_func


class channel_as_last_axis():
class channel_as_last_axis:
"""Decorator for automatically making channels axis last for all arrays.
This decorator reorders axes for compatibility with functions that only
Expand Down Expand Up @@ -329,6 +384,7 @@ def __init__(self, channel_arg_positions=(0,), channel_kwarg_names=(),
self.multichannel_output = multichannel_output

def __call__(self, func):

@functools.wraps(func)
def fixed_func(*args, **kwargs):

Expand Down Expand Up @@ -376,7 +432,7 @@ def fixed_func(*args, **kwargs):
return fixed_func


class deprecated(object):
class deprecated:
"""Decorator to mark deprecated functions with warning.
Adapted from <http://wiki.python.org/moin/PythonDecoratorLibrary>.
Expand Down