Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

single precision support in skimage.restoration module #5219

Merged
merged 14 commits into from
May 21, 2021
3 changes: 3 additions & 0 deletions TODO.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ Other (2022)
------------
* Remove conditional import of ``scipy.fft`` in ``skimage._shared.fft`` once
the minimum supported version of ``scipy`` reaches 1.4.
* Remove unneeded `deconv_type` and astype call in `weiner` and
`unsupervised_wiener` in `skimage.restoration.deconvolution` once the minimum
supported version of ``scipy`` reaches 1.4.
* Remove pillow version related warning for CVE when pillow > 8.1.2 in
`skimage/io/_plugins/pil_plugin.py` and `skimage/io/collection.py`.

Expand Down
16 changes: 11 additions & 5 deletions skimage/restoration/_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import pywt

from .. import img_as_float
from ._denoise_cy import _denoise_bilateral, _denoise_tv_bregman
from .._shared import utils
from .._shared.utils import warn
import skimage.color as color
from skimage.color.colorconv import ycbcr_from_rgb
from .._shared.utils import _supported_float_type, warn
from ._denoise_cy import _denoise_bilateral, _denoise_tv_bregman
from .. import color
from ..color.colorconv import ycbcr_from_rgb


def _gaussian_weight(array, sigma_squared, *, dtype=float):
Expand Down Expand Up @@ -503,10 +503,13 @@ def denoise_tv_chambolle(image, weight=0.1, eps=2.e-4, n_iter_max=200,
if not im_type.kind == 'f':
image = img_as_float(image)

# enforce float16->float32 and float128->float64
float_dtype = _supported_float_type(image.dtype)
image = image.astype(float_dtype, copy=False)

if channel_axis is not None:
channel_axis = channel_axis % image.ndim
_at = functools.partial(utils.slice_at_axis, axis=channel_axis)

out = np.zeros_like(image)
for c in range(image.shape[channel_axis]):
out[_at(c)] = _denoise_tv_chambolle_nd(image[_at(c)], weight, eps,
Expand Down Expand Up @@ -682,6 +685,7 @@ def _scale_sigma_and_image_consistently(image, sigma, multichannel,
"""If the ``image`` is rescaled, also rescale ``sigma`` consistently.

Images that are not floating point will be rescaled via ``img_as_float``.
Half-precision images will be promoted to single precision.
"""
if multichannel:
if isinstance(sigma, numbers.Number) or sigma is None:
Expand All @@ -703,6 +707,8 @@ def _scale_sigma_and_image_consistently(image, sigma, multichannel,
for s in sigma]
elif sigma is not None:
sigma *= scale_factor
elif image.dtype == np.float16:
image = image.astype(np.float32)
return image, sigma


Expand Down
47 changes: 37 additions & 10 deletions skimage/restoration/deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from scipy.signal import convolve

from . import uft
from .._shared.utils import _supported_float_type

__keywords__ = "restoration, image, deconvolution"


def wiener(image, psf, balance, reg=None, is_real=True, clip=True):
Expand Down Expand Up @@ -116,6 +116,10 @@ def wiener(image, psf, balance, reg=None, is_real=True, clip=True):
reg, _ = uft.laplacian(image.ndim, image.shape, is_real=is_real)
if not np.iscomplexobj(reg):
reg = uft.ir2tf(reg, image.shape, is_real=is_real)
float_type = _supported_float_type(image.dtype)
image = image.astype(float_type, copy=False)
psf = psf.real.astype(float_type, copy=False)
reg = reg.real.astype(float_type, copy=False)

if psf.shape != reg.shape:
trans_func = uft.ir2tf(psf, image.shape, is_real=is_real)
Expand All @@ -130,6 +134,13 @@ def wiener(image, psf, balance, reg=None, is_real=True, clip=True):
else:
deconv = uft.uifft2(wiener_filter * uft.ufft2(image))

# TODO: can remove astype call below once minimum SciPy >= 1.4
if deconv.dtype.kind == 'c':
deconv_type = np.promote_types(float_type, np.complex64)
else:
deconv_type = float_type
deconv = deconv.astype(deconv_type, copy=False)

if clip:
deconv[deconv > 1] = 1
deconv[deconv < -1] = -1
Expand Down Expand Up @@ -238,16 +249,20 @@ def unsupervised_wiener(image, psf, reg=None, user_params=None, is_real=True,
reg, _ = uft.laplacian(image.ndim, image.shape, is_real=is_real)
if not np.iscomplexobj(reg):
reg = uft.ir2tf(reg, image.shape, is_real=is_real)
float_type = _supported_float_type(image.dtype)
image = image.astype(float_type, copy=False)
psf = psf.real.astype(float_type, copy=False)
reg = reg.real.astype(float_type, copy=False)

if psf.shape != reg.shape:
trans_fct = uft.ir2tf(psf, image.shape, is_real=is_real)
trans_fct = uft.ir2tf(psf, image.shape, is_real=is_real)
else:
trans_fct = psf

# The mean of the object
x_postmean = np.zeros(trans_fct.shape)
x_postmean = np.zeros(trans_fct.shape, dtype=float_type)
# The previous computed mean in the iterative loop
prev_x_postmean = np.zeros(trans_fct.shape)
prev_x_postmean = np.zeros(trans_fct.shape, dtype=float_type)

# Difference between two successive mean
delta = np.NAN
Expand All @@ -263,19 +278,22 @@ def unsupervised_wiener(image, psf, reg=None, user_params=None, is_real=True,
# The Fourier transform may change the image.size attribute, so we
# store it.
if is_real:
data_spectrum = uft.urfft2(image.astype(float))
data_spectrum = uft.urfft2(image)
else:
data_spectrum = uft.ufft2(image.astype(float))
data_spectrum = uft.ufft2(image)

# Gibbs sampling
for iteration in range(params['max_iter']):
# Sample of Eq. 27 p(circX^k | gn^k-1, gx^k-1, y).

# weighting (correlation in direct space)
precision = gn_chain[-1] * atf2 + gx_chain[-1] * areg2 # Eq. 29
_rand1 = np.random.standard_normal(data_spectrum.shape)
_rand1 = _rand1.astype(float_type, copy=False)
_rand2 = np.random.standard_normal(data_spectrum.shape)
_rand2 = _rand2.astype(float_type, copy=False)
excursion = np.sqrt(0.5) / np.sqrt(precision) * (
np.random.standard_normal(data_spectrum.shape) +
1j * np.random.standard_normal(data_spectrum.shape))
_rand1 + 1j * _rand2)

# mean Eq. 30 (RLS for fixed gn, gamma0 and gamma1 ...)
wiener_filter = gn_chain[-1] * np.conj(trans_fct) / precision
Expand Down Expand Up @@ -319,6 +337,13 @@ def unsupervised_wiener(image, psf, reg=None, user_params=None, is_real=True,
else:
x_postmean = uft.uifft2(x_postmean)

# TODO: remove astype call below once minimum SciPy >= 1.4
if x_postmean.dtype.kind == 'c':
deconv_type = np.promote_types(float_type, np.complex64)
else:
deconv_type = float_type
x_postmean = x_postmean.astype(deconv_type, copy=False)

if clip:
x_postmean[x_postmean > 1] = 1
x_postmean[x_postmean < -1] = -1
Expand Down Expand Up @@ -364,7 +389,7 @@ def richardson_lucy(image, psf, iterations=50, clip=True, filter_epsilon=None):
----------
.. [1] https://en.wikipedia.org/wiki/Richardson%E2%80%93Lucy_deconvolution
"""
float_type = np.promote_types(image.dtype, np.float32)
float_type = _supported_float_type(image.dtype)
image = image.astype(float_type, copy=False)
psf = psf.astype(float_type, copy=False)
im_deconv = np.full(image.shape, 0.5, dtype=float_type)
Expand All @@ -373,7 +398,9 @@ def richardson_lucy(image, psf, iterations=50, clip=True, filter_epsilon=None):
for _ in range(iterations):
conv = convolve(im_deconv, psf, mode='same')
if filter_epsilon:
relative_blur = np.where(conv < filter_epsilon, 0, image / conv)
with np.errstate(invalid='ignore'):
relative_blur = np.where(conv < filter_epsilon, 0,
image / conv)
else:
relative_blur = image / conv
im_deconv *= convolve(relative_blur, psf_mirror, mode='same')
Expand Down
6 changes: 5 additions & 1 deletion skimage/restoration/inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,12 @@ def inpaint_biharmonic(image, mask, multichannel=False, *,
raise TypeError('Masked arrays are not supported')

image = skimage.img_as_float(image)
mask = mask.astype(bool, copy=False)

# float16->float32 and float128->float64
float_dtype = utils._supported_float_type(image.dtype)
image = image.astype(float_dtype, copy=False)

mask = mask.astype(bool, copy=False)
if not multichannel:
image = image[..., np.newaxis]
out = np.copy(image)
Expand Down
8 changes: 8 additions & 0 deletions skimage/restoration/j_invariant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import numpy as np
from scipy import ndimage as ndi

from .._shared.utils import _supported_float_type
from ..metrics import mean_squared_error
from ..util import img_as_float



def _interpolate_image(image, *, multichannel=False):
"""Replacing each pixel in ``image`` with the average of its neighbors.

Expand Down Expand Up @@ -112,9 +114,15 @@ def _invariant_denoise(image, denoise_function, *, stride=4,
Denoised image, of same shape as `image`.
"""
image = img_as_float(image)

# promote float16->float32 if needed
float_dtype = _supported_float_type(image.dtype)
image = image.astype(float_dtype, copy=False)

if denoiser_kwargs is None:
denoiser_kwargs = {}


if 'multichannel' in denoiser_kwargs:
multichannel = denoiser_kwargs['multichannel']
else:
Expand Down
4 changes: 2 additions & 2 deletions skimage/restoration/non_local_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,10 @@ def denoise_nl_means(image, patch_size=7, patch_distance=11, h=0.1,
preserve_range = True

image = convert_to_float(image, preserve_range)

kwargs = dict(s=patch_size, d=patch_distance, h=h, var=sigma * sigma)
if not image.flags.c_contiguous:
image = np.ascontiguousarray(image)

kwargs = dict(s=patch_size, d=patch_distance, h=h, var=sigma * sigma)
if channel_axis is not None: # 2-D images
if fast_mode:
return _fast_nl_means_denoising_2d(image, **kwargs)
Expand Down
5 changes: 4 additions & 1 deletion skimage/restoration/rolling_ball.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np

from .._shared.utils import _supported_float_type
from ._rolling_ball_cy import apply_kernel, apply_kernel_nan


Expand Down Expand Up @@ -78,14 +79,16 @@ def rolling_ball(image, *, radius=100, kernel=None,
"""

image = np.asarray(image)
img = image.astype(float)
float_type = _supported_float_type(image.dtype)
img = image.astype(float_type, copy=False)

if num_threads is None:
num_threads = 0

if kernel is None:
kernel = ball_kernel(radius, image.ndim)

kernel = kernel.astype(float_type)
kernel_shape = np.asarray(kernel.shape)
kernel_center = (kernel_shape // 2)
center_intensity = kernel[tuple(kernel_center)]
Expand Down
57 changes: 40 additions & 17 deletions skimage/restoration/tests/test_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from skimage._shared import testing, utils
from skimage._shared.testing import (assert_equal, assert_almost_equal,
assert_warns, assert_)
from skimage._shared.utils import _supported_float_type
from skimage._shared._warnings import expected_warnings


Expand All @@ -35,23 +36,36 @@
astro_odd = astro[:, :-1]


def test_denoise_tv_chambolle_2d():
float_dtypes = [np.float16, np.float32, np.float64]
try:
float_dtypes += [np.float128]
except AttributeError:
pass


@testing.parametrize('dtype',float_dtypes)
def test_denoise_tv_chambolle_2d(dtype):
# astronaut image
img = astro_gray.copy()
img = astro_gray.astype(dtype, copy=True)
# add noise to astronaut
img += 0.5 * img.std() * np.random.rand(*img.shape)
# clip noise so that it does not exceed allowed range for float images.
img = np.clip(img, 0, 1)
# denoise
denoised_astro = restoration.denoise_tv_chambolle(img, weight=0.1)
# which dtype?
assert_(denoised_astro.dtype in [float, np.float32, np.float64])
assert denoised_astro.dtype == _supported_float_type(img)

from scipy import ndimage as ndi

# Convert to a floating point type supported by scipy.ndimage
float_dtype = _supported_float_type(img)
img = img.astype(float_dtype, copy=False)

grad = ndi.morphological_gradient(img, size=((3, 3)))
grad_denoised = ndi.morphological_gradient(denoised_astro, size=((3, 3)))
# test if the total variation has decreased
assert_(grad_denoised.dtype == float)
assert_(np.sqrt((grad_denoised**2).sum()) < np.sqrt((grad**2).sum()))
assert grad_denoised.dtype == float_dtype
assert np.sqrt((grad_denoised**2).sum()) < np.sqrt((grad**2).sum())


@testing.parametrize('channel_axis', [0, 1, 2, -1])
Expand Down Expand Up @@ -466,28 +480,36 @@ def test_denoise_nl_means_3d(fast_mode, dtype):


@pytest.mark.parametrize('fast_mode', [False, True])
@pytest.mark.parametrize('dtype', ['float64', 'float32'])
@pytest.mark.parametrize('dtype', ['float64', 'float32', 'float16'])
@pytest.mark.parametrize('channel_axis', [0, -1])
def test_denoise_nl_means_multichannel(fast_mode, dtype, channel_axis):
# for true 3D data, 3D denoising is better than denoising as 2D+channels
img = np.zeros((13, 10, 8), dtype=dtype)
img[6, 4:6, 2:-2] = 1.
sigma = 0.3
imgn = img + sigma * np.random.randn(*img.shape)
dtype = np.float64
rstate = np.random.RandomState(5)

# synthetic 3d volume
img = data.binary_blobs(length=32, n_dim=3, seed=5)
img = img[:, :24, :16].astype(dtype, copy=False)

sigma = 0.2
imgn = img + sigma * rstate.randn(*img.shape)
imgn = imgn.astype(dtype)

# test 3D denoising (channel_axis = None)
denoised_ok_multichannel = restoration.denoise_nl_means(
imgn, 3, 2, h=0.6 * sigma, sigma=sigma, fast_mode=fast_mode,
channel_axis=None)

# set a channel axis: one dimension is (incorrectly) considered "channels"
imgn = np.moveaxis(imgn, -1, channel_axis)
denoised_wrong_multichannel = restoration.denoise_nl_means(
imgn, 3, 4, 0.6 * sigma, fast_mode=fast_mode,
imgn, 3, 2, h=0.6 * sigma, sigma=sigma, fast_mode=fast_mode,
channel_axis=channel_axis
)
denoised_ok_multichannel = restoration.denoise_nl_means(
imgn, 3, 4, 0.6 * sigma, fast_mode=fast_mode, channel_axis=None)
denoised_wrong_multichannel = np.moveaxis(
denoised_wrong_multichannel, channel_axis, -1
)
denoised_ok_multichannel = np.moveaxis(
denoised_ok_multichannel, channel_axis, -1
)

psnr_wrong = peak_signal_noise_ratio(img, denoised_wrong_multichannel)
psnr_ok = peak_signal_noise_ratio(img, denoised_ok_multichannel)
assert_(psnr_ok > psnr_wrong)
Expand Down Expand Up @@ -699,6 +721,7 @@ def test_wavelet_denoising_scaling(case, dtype, convert2ycbcr,
channel_axis=channel_axis,
convert2ycbcr=convert2ycbcr,
rescale_sigma=True)
assert denoised.dtype == _supported_float_type(noisy)

data_range = x.max() - x.min()
psnr_noisy = peak_signal_noise_ratio(x, noisy, data_range=data_range)
Expand Down