Skip to content

Commit

Permalink
Fix decorators warnings stacklevel (#6183)
Browse files Browse the repository at this point in the history
* Add _DecoratorBaseClass.update_stacklevel

* Add test_decorator_warnings_stacklevel
  • Loading branch information
rfezzani committed Feb 8, 2022
1 parent c295090 commit 6877386
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 11 deletions.
25 changes: 25 additions & 0 deletions skimage/_shared/tests/test_utils.py
@@ -1,5 +1,6 @@
import sys
import warnings
import inspect

import numpy as np
import pytest
Expand All @@ -10,6 +11,8 @@
change_default_value, remove_arg,
_supported_float_type,
channel_as_last_axis)
from skimage.feature import hog
from skimage.transform import pyramid_gaussian

complex_dtypes = [np.complex64, np.complex128]
if hasattr(np, 'complex256'):
Expand Down Expand Up @@ -275,3 +278,25 @@ def test_decorated_channel_axis_shape(channel_axis):
assert size is None
else:
assert size == x.shape[channel_axis]


def test_decorator_warnings():
"""Assert that warning message issued by decorator points to
expected file and line number.
"""

with pytest.warns(FutureWarning) as record:
pyramid_gaussian(None, multichannel=True)
expected_lineno = inspect.currentframe().f_lineno - 1

assert record[0].lineno == expected_lineno
assert record[0].filename == __file__

img = np.random.rand(100, 100, 3)

with pytest.warns(FutureWarning) as record:
hog(img, multichannel=True)
expected_lineno = inspect.currentframe().f_lineno - 1

assert record[0].lineno == expected_lineno
assert record[0].filename == __file__
78 changes: 67 additions & 11 deletions skimage/_shared/utils.py
Expand Up @@ -24,7 +24,43 @@ class skimage_deprecation(Warning):
pass


class change_default_value:
def _get_stack_rank(func):
"""Return function rank in the call stack."""
if _is_wrapped(func):
return 1 + _get_stack_rank(func.__wrapped__)
else:
return 0


def _is_wrapped(func):
return "__wrapped__" in dir(func)


def _get_stack_length(func):
"""Return function call stack length."""
return _get_stack_rank(func.__globals__.get(func.__name__, func))


class _DecoratorBaseClass:
"""Used to manage decorators' warnings stacklevel.
The `_stack_length` class variable is used to store the number of
times a function is wrapped by a decorator.
Let `stack_length` be the total number of times a decorated
function is wrapped, and `stack_rank` be the rank of the decorator
in the decorators stack. The stacklevel of a warning is then
`stacklevel = 1 + stack_length - stack_rank`.
"""

_stack_length = {}

def get_stack_length(self, func):
return self._stack_length.get(func.__name__,
_get_stack_length(func))


class change_default_value(_DecoratorBaseClass):
"""Decorator for changing the default value of an argument.
Parameters
Expand Down Expand Up @@ -53,6 +89,8 @@ def __call__(self, func):
arg_idx = list(parameters.keys()).index(self.arg_name)
old_value = parameters[self.arg_name].default

stack_rank = _get_stack_rank(func)

if self.warning_msg is None:
self.warning_msg = (
f'The new recommended value for {self.arg_name} is '
Expand All @@ -64,15 +102,17 @@ def __call__(self, func):

@functools.wraps(func)
def fixed_func(*args, **kwargs):
stacklevel = 1 + self.get_stack_length(func) - stack_rank
if len(args) < arg_idx + 1 and self.arg_name not in kwargs.keys():
# warn that arg_name default value changed:
warnings.warn(self.warning_msg, FutureWarning, stacklevel=2)
warnings.warn(self.warning_msg, FutureWarning,
stacklevel=stacklevel)
return func(*args, **kwargs)

return fixed_func


class remove_arg:
class remove_arg(_DecoratorBaseClass):
"""Decorator to remove an argument from function's signature.
Parameters
Expand All @@ -93,6 +133,7 @@ def __init__(self, arg_name, *, changed_version, help_msg=None):
self.changed_version = changed_version

def __call__(self, func):

parameters = inspect.signature(func).parameters
arg_idx = list(parameters.keys()).index(self.arg_name)
warning_msg = (
Expand All @@ -104,11 +145,15 @@ def __call__(self, func):
if self.help_msg is not None:
warning_msg += f' {self.help_msg}'

stack_rank = _get_stack_rank(func)

@functools.wraps(func)
def fixed_func(*args, **kwargs):
stacklevel = 1 + self.get_stack_length(func) - stack_rank
if len(args) > arg_idx or self.arg_name in kwargs.keys():
# warn that arg_name is deprecated
warnings.warn(warning_msg, FutureWarning, stacklevel=2)
warnings.warn(warning_msg, FutureWarning,
stacklevel=stacklevel)
return func(*args, **kwargs)

return fixed_func
Expand Down Expand Up @@ -180,7 +225,7 @@ def docstring_add_deprecated(func, kwarg_mapping, deprecated_version):
return final_docstring


class deprecate_kwarg:
class deprecate_kwarg(_DecoratorBaseClass):
"""Decorator ensuring backward compatibility when argument names are
modified in a function definition.
Expand Down Expand Up @@ -208,7 +253,7 @@ def __init__(self, kwarg_mapping, deprecated_version, warning_msg=None,
"for `{func_name}`. ")
if removed_version is not None:
self.warning_msg += (f'It will be removed in '
f'version {removed_version}.')
f'version {removed_version}. ')
self.warning_msg += "Please use `{new_arg}` instead."
else:
self.warning_msg = warning_msg
Expand All @@ -217,14 +262,19 @@ def __init__(self, kwarg_mapping, deprecated_version, warning_msg=None,

def __call__(self, func):

stack_rank = _get_stack_rank(func)

@functools.wraps(func)
def fixed_func(*args, **kwargs):
stacklevel = 1 + self.get_stack_length(func) - stack_rank

for old_arg, new_arg in self.kwarg_mapping.items():
if old_arg in kwargs:
# warn that the function interface has changed:
warnings.warn(self.warning_msg.format(
old_arg=old_arg, func_name=func.__name__,
new_arg=new_arg), FutureWarning, stacklevel=2)
new_arg=new_arg), FutureWarning,
stacklevel=stacklevel)
# Substitute new_arg to old_arg
kwargs[new_arg] = kwargs.pop(old_arg)

Expand Down Expand Up @@ -258,8 +308,12 @@ def __init__(self, removed_version='1.0', multichannel_position=None):
self.position = multichannel_position

def __call__(self, func):

stack_rank = _get_stack_rank(func)

@functools.wraps(func)
def fixed_func(*args, **kwargs):
stacklevel = 1 + self.get_stack_length(func) - stack_rank

if self.position is not None and len(args) > self.position:
warning_msg = (
Expand All @@ -269,7 +323,7 @@ def fixed_func(*args, **kwargs):
)
warnings.warn(warning_msg.format(func_name=func.__name__),
FutureWarning,
stacklevel=2)
stacklevel=stacklevel)
if 'channel_axis' in kwargs:
raise ValueError(
"Cannot provide both a `channel_axis` kwarg and a "
Expand All @@ -283,7 +337,8 @@ def fixed_func(*args, **kwargs):
# warn that the function interface has changed:
warnings.warn(self.warning_msg.format(
old_arg='multichannel', func_name=func.__name__,
new_arg='channel_axis'), FutureWarning, stacklevel=2)
new_arg='channel_axis'), FutureWarning,
stacklevel=stacklevel)

# multichannel = True -> last axis corresponds to channels
convert = {True: -1, False: None}
Expand All @@ -299,7 +354,7 @@ def fixed_func(*args, **kwargs):
return fixed_func


class channel_as_last_axis():
class channel_as_last_axis:
"""Decorator for automatically making channels axis last for all arrays.
This decorator reorders axes for compatibility with functions that only
Expand Down Expand Up @@ -329,6 +384,7 @@ def __init__(self, channel_arg_positions=(0,), channel_kwarg_names=(),
self.multichannel_output = multichannel_output

def __call__(self, func):

@functools.wraps(func)
def fixed_func(*args, **kwargs):

Expand Down Expand Up @@ -376,7 +432,7 @@ def fixed_func(*args, **kwargs):
return fixed_func


class deprecated(object):
class deprecated:
"""Decorator to mark deprecated functions with warning.
Adapted from <http://wiki.python.org/moin/PythonDecoratorLibrary>.
Expand Down

0 comments on commit 6877386

Please sign in to comment.