-
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
utils.py
375 lines (300 loc) · 12.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
import inspect
import warnings
import functools
import sys
import numpy as np
import numbers
from ..util import img_as_float
from ._warnings import all_warnings, warn
# Public API of this module: the names bound by a star-import.
# ``all_warnings`` and ``warn`` are re-exported from ``._warnings``.
__all__ = ['deprecated', 'get_bound_method_class', 'all_warnings',
           'safe_as_int', 'check_nD', 'check_shape_equality', 'warn']
class skimage_deprecation(Warning):
    """Warning category used for scikit-image deprecations.

    A dedicated ``Warning`` subclass is used instead of
    ``DeprecationWarning`` because Python (since 2.7) silences
    ``DeprecationWarning`` by default in most contexts.
    """
class change_default_value:
    """Decorator for changing the default value of an argument.

    The decorated function emits a ``FutureWarning`` whenever it is called
    without an explicit value for ``arg_name`` (neither positionally nor by
    keyword), i.e. whenever the soon-to-change default would be used.

    Parameters
    ----------
    arg_name : str
        The name of the argument to be updated.
    new_value : any
        The argument new value.
    changed_version : str
        The package version in which the change will be introduced.
    warning_msg : str
        Optional warning message. If None, a generic warning message
        is used.
    """

    def __init__(self, arg_name, *, new_value, changed_version,
                 warning_msg=None):
        self.arg_name = arg_name
        self.new_value = new_value
        self.warning_msg = warning_msg
        self.changed_version = changed_version

    def __call__(self, func):
        parameters = inspect.signature(func).parameters
        # Position of the argument in the signature; raises ValueError if
        # ``func`` has no parameter called ``arg_name``.
        arg_idx = list(parameters.keys()).index(self.arg_name)
        old_value = parameters[self.arg_name].default

        # Build the message locally instead of storing it on ``self``: the
        # generic message embeds ``old_value``, which belongs to this
        # particular ``func``. Caching it on the instance would make any other
        # function decorated by the same instance report the wrong default.
        warning_msg = self.warning_msg
        if warning_msg is None:
            warning_msg = (
                f"The new recommended value for {self.arg_name} is "
                f"{self.new_value}. Until version {self.changed_version}, "
                f"the default {self.arg_name} value is {old_value}. "
                f"From version {self.changed_version}, the {self.arg_name} "
                f"default value will be {self.new_value}. To avoid "
                f"this warning, please explicitly set {self.arg_name} value.")

        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            # The default is in use when too few positional arguments were
            # given to reach ``arg_idx`` and the name was not passed by keyword.
            if len(args) <= arg_idx and self.arg_name not in kwargs:
                # warn that arg_name default value changed:
                warnings.warn(warning_msg, FutureWarning, stacklevel=2)
            return func(*args, **kwargs)

        return fixed_func
class deprecate_kwarg:
    """Decorator ensuring backward compatibility when argument names are
    modified in a function definition.

    Calls that use an old keyword name emit a ``FutureWarning`` and are
    forwarded with the value moved to the new keyword name.

    Parameters
    ----------
    kwarg_mapping : dict
        Mapping between the function's old argument names and the new
        ones.
    warning_msg : str
        Optional warning message. If None, a generic warning message
        is used. The message is formatted with the ``old_arg``,
        ``func_name`` and ``new_arg`` fields.
    removed_version : str
        The package version in which the deprecated argument will be
        removed. Only used when building the generic warning message.

    Notes
    -----
    The decorated function raises ``TypeError`` if a caller passes both an
    old argument name and its replacement, since it is ambiguous which value
    was intended.
    """

    def __init__(self, kwarg_mapping, warning_msg=None, removed_version=None):
        self.kwarg_mapping = kwarg_mapping
        if warning_msg is None:
            self.warning_msg = ("'{old_arg}' is a deprecated argument name "
                                "for `{func_name}`. ")
            if removed_version is not None:
                self.warning_msg += ("It will be removed in version {}. "
                                     .format(removed_version))
            self.warning_msg += "Please use '{new_arg}' instead."
        else:
            self.warning_msg = warning_msg

    def __call__(self, func):
        @functools.wraps(func)
        def fixed_func(*args, **kwargs):
            for old_arg, new_arg in self.kwarg_mapping.items():
                if old_arg not in kwargs:
                    continue
                # Both spellings given: moving the old value over would
                # silently discard the caller's explicit new-name value.
                if new_arg in kwargs:
                    raise TypeError(
                        f"{func.__name__}() got values for both the "
                        f"deprecated argument '{old_arg}' and its "
                        f"replacement '{new_arg}'; please use only "
                        f"'{new_arg}'.")
                # warn that the function interface has changed:
                warnings.warn(self.warning_msg.format(
                    old_arg=old_arg, func_name=func.__name__,
                    new_arg=new_arg), FutureWarning, stacklevel=2)
                # Substitute new_arg to old_arg
                kwargs[new_arg] = kwargs.pop(old_arg)
            # Call the function with the fixed arguments
            return func(*args, **kwargs)

        return fixed_func
class deprecated:
    """Decorator to mark deprecated functions with warning.

    Adapted from <http://wiki.python.org/moin/PythonDecoratorLibrary>.

    Parameters
    ----------
    alt_func : str
        If given, tell user what function to use instead.
    behavior : {'warn', 'raise'}
        Behavior during call to deprecated function: 'warn' = warn user that
        function is deprecated; 'raise' = raise error. Any other value makes
        the wrapper call the function without warning or raising.
    removed_version : str
        The package version in which the deprecated function will be removed.
    """

    def __init__(self, alt_func=None, behavior='warn', removed_version=None):
        self.alt_func = alt_func
        self.behavior = behavior
        self.removed_version = removed_version

    def __call__(self, func):
        # Optional message fragments, left empty when the option is unset.
        if self.alt_func is None:
            alt_msg = ''
        else:
            alt_msg = ' Use ``%s`` instead.' % self.alt_func
        if self.removed_version is None:
            rmv_msg = ''
        else:
            rmv_msg = (' and will be removed in version %s'
                       % self.removed_version)
        msg = ''.join(('Function ``%s`` is deprecated' % func.__name__,
                       rmv_msg, '.', alt_msg))

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            if self.behavior == 'raise':
                raise skimage_deprecation(msg)
            if self.behavior == 'warn':
                code = func.__code__
                # NOTE(review): this installs an 'always' filter for the
                # category in the process-wide warning filters on each call.
                warnings.simplefilter('always', skimage_deprecation)
                # Report the warning at the deprecated function's own source
                # location rather than at this wrapper.
                warnings.warn_explicit(msg,
                                       category=skimage_deprecation,
                                       filename=code.co_filename,
                                       lineno=code.co_firstlineno + 1)
            return func(*args, **kwargs)

        # Prefix the docstring with a deprecation banner.
        banner = '**Deprecated function**.' + alt_msg
        existing_doc = wrapped.__doc__
        if existing_doc is None:
            wrapped.__doc__ = banner
        else:
            wrapped.__doc__ = banner + '\n\n ' + existing_doc
        return wrapped
def get_bound_method_class(m):
    """Return the class for a bound method.

    Parameters
    ----------
    m : bound method
        A method bound to an object, e.g. ``obj.method``.

    Returns
    -------
    cls : type
        ``m.__self__.__class__``, the class of the object ``m`` is bound to.
        For a bound classmethod ``__self__`` is the class itself, so its
        metaclass is returned.
    """
    # The former ``sys.version < '3'`` guard compared version strings
    # lexicographically ('10.0' < '3' is True) and selected the Python 2-only
    # ``im_class`` attribute; this module already requires Python 3.
    return m.__self__.__class__
def safe_as_int(val, atol=1e-3):
    """
    Attempt to safely cast values to integer format.

    Parameters
    ----------
    val : scalar or iterable of scalars
        Number or container of numbers which are intended to be interpreted as
        integers, e.g., for indexing purposes, but which may not carry integer
        type.
    atol : float
        Absolute tolerance away from nearest integer to consider values in
        ``val`` functionally integers.

    Returns
    -------
    val_int : NumPy scalar or ndarray of dtype `np.int64`
        Returns the input value(s) coerced to dtype `np.int64` assuming all
        were within ``atol`` of the nearest integer.

    Raises
    ------
    ValueError
        If any value is farther than ``atol`` from its nearest integer.
        Non-finite values (NaN, inf) always fail this check.

    Notes
    -----
    This operation calculates ``val`` modulo 1, which returns the fractional
    part of all values. The distance to the nearest integer is the smaller of
    that fractional part and one minus it. If this distance is at most
    ``atol`` for all value(s) in ``val``, they are rounded and returned
    in an integer array. Or, if ``val`` was a scalar, a NumPy scalar type is
    returned.

    Examples
    --------
    >>> safe_as_int(7.0)
    7

    >>> safe_as_int([9, 4, 2.9999999999])
    array([9, 4, 3])

    >>> safe_as_int(53.1)
    Traceback (most recent call last):
        ...
    ValueError: Integer argument required but received 53.1, check inputs.

    >>> safe_as_int(53.01, atol=0.01)
    53

    """
    # Fractional part; ``% 1`` with a positive divisor lies in [0, 1), also
    # for negative inputs. Non-finite inputs yield NaN here.
    mod = np.asarray(val) % 1
    # Distance to the nearest integer, measured on either side. Works for
    # 0-d (scalar) and n-d inputs alike, and propagates NaN.
    dist = np.minimum(mod, 1 - mod)
    # NaN compares False, so non-finite inputs are rejected.
    if not np.all(dist <= atol):
        raise ValueError("Integer argument required but received "
                         "{0}, check inputs.".format(val))
    return np.round(val).astype(np.int64)
def check_shape_equality(im1, im2):
    """Raise an error if the shapes of two arrays do not match.

    Parameters
    ----------
    im1, im2 : ndarray
        Arrays whose ``shape`` attributes are compared.

    Raises
    ------
    ValueError
        If ``im1.shape`` differs from ``im2.shape``.
    """
    if im1.shape != im2.shape:
        raise ValueError('Input images must have the same dimensions.')
def check_nD(array, ndim, arg_name='image'):
    """
    Verify an array meets the desired ndims and array isn't empty.

    Parameters
    ----------
    array : array-like
        Input array to be validated.
    ndim : int or iterable of ints
        Allowable ndim or ndims for the array. Any integer type is accepted,
        including NumPy integer scalars.
    arg_name : str, optional
        The name of the array in the original function.

    Raises
    ------
    ValueError
        If ``array`` is empty, or if its number of dimensions is not one of
        the allowed ``ndim`` values.
    """
    array = np.asanyarray(array)
    msg_incorrect_dim = "The parameter `%s` must be a %s-dimensional array"
    msg_empty_array = "The parameter `%s` cannot be an empty array"
    # ``numbers.Integral`` also covers NumPy integer scalars (e.g. np.int64),
    # which are not ``int`` instances. Other iterables are materialised so a
    # one-shot iterator is not exhausted by the membership test before the
    # error message is built.
    if isinstance(ndim, numbers.Integral):
        allowed = [ndim]
    else:
        allowed = list(ndim)
    if array.size == 0:
        raise ValueError(msg_empty_array % (arg_name))
    if array.ndim not in allowed:
        raise ValueError(msg_incorrect_dim
                         % (arg_name, '-or-'.join([str(n) for n in allowed])))
def check_random_state(seed):
    """Turn seed into a `np.random.RandomState` instance.

    Parameters
    ----------
    seed : None, int, the `np.random` module, or np.random.RandomState
        If `seed` is None (or the `np.random` module itself), return the
        RandomState singleton used by `np.random`.
        If `seed` is an int, return a new RandomState instance seeded with `seed`.
        If `seed` is already a RandomState instance, return it.

    Returns
    -------
    rng : np.random.RandomState
        The resolved random state.

    Raises
    ------
    ValueError
        If `seed` is of the wrong type.
    """
    # Function originally from scikit-learn's module sklearn.utils.validation
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    if isinstance(seed, (numbers.Integral, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    # ``seed`` is wrapped in a 1-tuple: a bare ``% seed`` would unpack a tuple
    # seed as several format arguments and raise TypeError instead.
    raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
                     ' instance' % (seed,))
def convert_to_float(image, preserve_range):
    """Convert input image to float image with the appropriate range.

    Parameters
    ----------
    image : ndarray
        Input image.
    preserve_range : bool
        Determines if the range of the image should be kept or transformed
        using img_as_float. Also see
        https://scikit-image.org/docs/dev/user_guide/data_types.html

    Returns
    -------
    image : ndarray
        Transformed version of the input.

    Notes
    -----
    * Input images with `float32` data type are not upcast.
    """
    if not preserve_range:
        return img_as_float(image)
    # Single ('f') and double ('d') precision floats are returned as-is;
    # every other dtype is cast to Python ``float`` (float64) with its
    # values unchanged.
    if image.dtype.char in 'df':
        return image
    return image.astype(float)
def _validate_interpolation_order(image_dtype, order):
"""Validate and return spline interpolation's order.
Parameters
----------
image_dtype : dtype
Image dtype.
order : int, optional
The order of the spline interpolation. The order has to be in
the range 0-5. See `skimage.transform.warp` for detail.
Returns
-------
order : int
if input order is None, returns 0 if image_dtype is bool and 1
otherwise. Otherwise, image_dtype is checked and input order
is validated accordingly (order > 0 is not supported for bool
image dtype)
"""
if order is None:
return 0 if image_dtype == bool else 1
if order < 0 or order > 5:
raise ValueError("Spline interpolation order has to be in the "
"range 0-5.")
if image_dtype == bool and order != 0:
warn("Input image dtype is bool. Interpolation is not defined "
"with bool data type. Please set order to 0 or explicitely "
"cast input image to another data type. Starting from version "
"0.19 a ValueError will be raised instead of this warning.",
FutureWarning, stacklevel=2)
return order