Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add n-dimensional support to denoise_wavelet #2242

Merged
merged 8 commits into from
Aug 13, 2016
6 changes: 4 additions & 2 deletions doc/examples/filters/plot_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
ax[0, 2].imshow(denoise_bilateral(noisy, sigma_color=0.05, sigma_spatial=15))
ax[0, 2].axis('off')
ax[0, 2].set_title('Bilateral')
ax[0, 3].imshow(denoise_wavelet(noisy, sigma=0.4*astro.std()))
ax[0, 3].imshow(denoise_wavelet(noisy, sigma=0.4*astro.std(),
multichannel=True))
ax[0, 3].axis('off')
ax[0, 3].set_title('Wavelet')

Expand All @@ -62,7 +63,8 @@
ax[1, 2].imshow(denoise_bilateral(noisy, sigma_color=0.1, sigma_spatial=15))
ax[1, 2].axis('off')
ax[1, 2].set_title('(more) Bilateral')
ax[1, 3].imshow(denoise_wavelet(noisy, sigma=0.6*astro.std()))
ax[1, 3].imshow(denoise_wavelet(noisy, sigma=0.6*astro.std(),
multichannel=True))
ax[1, 3].axis('off')
ax[1, 3].set_title('(more) Wavelet')
ax[1, 0].imshow(astro)
Expand Down
22 changes: 12 additions & 10 deletions skimage/restoration/_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,13 @@ def _wavelet_threshold(img, wavelet, threshold=None, sigma=None, mode='soft'):
return pywt.waverecn(denoised_coeffs, wavelet)


def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'):
def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft',
multichannel=False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the docstring, please, change expected data type for img to smth like (M[, N[, ..., P]][, C]), and add description for multichannel parameter.

"""Performs wavelet denoising on an image.

Parameters
----------
img : ndarray (2D/3D) of ints, uints or floats
img : ndarray ([M[, N[, ...P]][, C]) of ints, uints or floats
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, actually this should be (M, N[, ...P][, C]) if we declare 2+D support. Please, notice that [, ] mark optional dimensions.

Copy link
Contributor Author

@grlee77 grlee77 Aug 12, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 1D is already supported here, though?

edit: 1D through 4D operation now verified via new tests

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. Let it be (M[, N[, ...P]][, C]) then.

Input data to be denoised. `img` can be of any numeric type,
but it is cast into an ndarray of floats for the computation
of the denoised image.
Expand All @@ -415,6 +416,9 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'):
An optional argument to choose the type of denoising performed. It
noted that choosing soft thresholding given additive noise finds the
best approximation of the original image.
multichannel : bool, optional
Apply wavelet denoising separately for each channel (where channels
correspond to the final axis of the array).

Returns
-------
Expand Down Expand Up @@ -457,16 +461,14 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'):

img = img_as_float(img)

if img.ndim not in {2, 3}:
raise ValueError('denoise_wavelet only supports 2D and 3D images')

if img.ndim == 2:
if multichannel:
out = np.empty_like(img)
for c in range(img.shape[-1]):
out[..., c] = _wavelet_threshold(img[..., c], wavelet=wavelet,
mode=mode, sigma=sigma)
else:
out = _wavelet_threshold(img, wavelet=wavelet, mode=mode,
sigma=sigma)
else:
out = np.dstack([_wavelet_threshold(img[..., c], wavelet=wavelet,
mode=mode, sigma=sigma)
for c in range(img.ndim)])

clip_range = (-1, 1) if img.min() < 0 else (0, 1)
return np.clip(out, *clip_range)
50 changes: 41 additions & 9 deletions skimage/restoration/tests/test_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from skimage import restoration, data, color, img_as_float, measure
from skimage._shared._warnings import expected_warnings
from skimage.measure import compare_ssim
from skimage.measure import compare_psnr

np.random.seed(1234)

Expand Down Expand Up @@ -310,16 +310,48 @@ def test_no_denoising_for_small_h():


def test_wavelet_denoising():
for img in [astro_gray, astro]:
noisy = img.copy() + 0.1 * np.random.randn(*(img.shape))
for img, multichannel in [(astro_gray, False), (astro, True)]:
sigma = 0.1
noisy = img + sigma * np.random.randn(*(img.shape))
noisy = np.clip(noisy, 0, 1)
# less energy in signal
denoised = restoration.denoise_wavelet(noisy, sigma=0.3)
assert denoised.sum()**2 <= img.sum()**2

# test changing noise_std (higher threshold, so less energy in signal)
assert (restoration.denoise_wavelet(noisy, sigma=0.2).sum()**2 <=
restoration.denoise_wavelet(noisy, sigma=0.1).sum()**2)
# Verify that SNR is improved when true sigma is used
denoised = restoration.denoise_wavelet(noisy, sigma=sigma,
multichannel=multichannel)
psnr_noisy = compare_psnr(img, noisy)
psnr_denoised = compare_psnr(img, denoised)
assert psnr_denoised > psnr_noisy

# Verify that SNR is improved with internally estimated sigma
denoised = restoration.denoise_wavelet(noisy,
multichannel=multichannel)
psnr_noisy = compare_psnr(img, noisy)
psnr_denoised = compare_psnr(img, denoised)
assert psnr_denoised > psnr_noisy

# Test changing noise_std (higher threshold, so less energy in signal)
res1 = restoration.denoise_wavelet(noisy, sigma=2*sigma,
multichannel=multichannel)
res2 = restoration.denoise_wavelet(noisy, sigma=sigma,
multichannel=multichannel)
assert (res1.sum()**2 <= res2.sum()**2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this wasn't your approach, but this test seems quite weak to me! How about showing a reduction in the noise after denoise_wavelet, or increase in SNR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure. it should be simple to verify at least a rough improvement via compare_psnr



def test_wavelet_denoising_nd():
for ndim in range(1, 5):
# Generate a very simple test image
img = 0.2*np.ones((16, )*ndim)
img[[slice(5, 13), ] * ndim] = 0.8

sigma = 0.1
noisy = img + sigma * np.random.randn(*(img.shape))
noisy = np.clip(noisy, 0, 1)

# Verify that SNR is improved with internally estimated sigma
denoised = restoration.denoise_wavelet(noisy)
psnr_noisy = compare_psnr(img, noisy)
psnr_denoised = compare_psnr(img, denoised)
assert psnr_denoised > psnr_noisy


if __name__ == "__main__":
Expand Down