Skip to content

Commit

Permalink
Merge pull request #2242 from grlee77/wavelet_nd
Browse files Browse the repository at this point in the history
add n-dimensional support to denoise_wavelet
  • Loading branch information
jni committed Aug 13, 2016
2 parents e691467 + 721dd37 commit 99136ab
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 21 deletions.
6 changes: 4 additions & 2 deletions doc/examples/filters/plot_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
ax[0, 2].imshow(denoise_bilateral(noisy, sigma_color=0.05, sigma_spatial=15))
ax[0, 2].axis('off')
ax[0, 2].set_title('Bilateral')
ax[0, 3].imshow(denoise_wavelet(noisy, sigma=0.4*astro.std()))
ax[0, 3].imshow(denoise_wavelet(noisy, sigma=0.4*astro.std(),
multichannel=True))
ax[0, 3].axis('off')
ax[0, 3].set_title('Wavelet')

Expand All @@ -62,7 +63,8 @@
ax[1, 2].imshow(denoise_bilateral(noisy, sigma_color=0.1, sigma_spatial=15))
ax[1, 2].axis('off')
ax[1, 2].set_title('(more) Bilateral')
ax[1, 3].imshow(denoise_wavelet(noisy, sigma=0.6*astro.std()))
ax[1, 3].imshow(denoise_wavelet(noisy, sigma=0.6*astro.std(),
multichannel=True))
ax[1, 3].axis('off')
ax[1, 3].set_title('(more) Wavelet')
ax[1, 0].imshow(astro)
Expand Down
22 changes: 12 additions & 10 deletions skimage/restoration/_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,13 @@ def _wavelet_threshold(img, wavelet, threshold=None, sigma=None, mode='soft'):
return pywt.waverecn(denoised_coeffs, wavelet)


def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'):
def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft',
multichannel=False):
"""Performs wavelet denoising on an image.
Parameters
----------
img : ndarray (2D/3D) of ints, uints or floats
img : ndarray ([M[, N[, ...P]][, C]) of ints, uints or floats
Input data to be denoised. `img` can be of any numeric type,
but it is cast into an ndarray of floats for the computation
of the denoised image.
Expand All @@ -415,6 +416,9 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'):
An optional argument to choose the type of denoising performed. It
noted that choosing soft thresholding given additive noise finds the
best approximation of the original image.
multichannel : bool, optional
Apply wavelet denoising separately for each channel (where channels
correspond to the final axis of the array).
Returns
-------
Expand Down Expand Up @@ -457,16 +461,14 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'):

img = img_as_float(img)

if img.ndim not in {2, 3}:
raise ValueError('denoise_wavelet only supports 2D and 3D images')

if img.ndim == 2:
if multichannel:
out = np.empty_like(img)
for c in range(img.shape[-1]):
out[..., c] = _wavelet_threshold(img[..., c], wavelet=wavelet,
mode=mode, sigma=sigma)
else:
out = _wavelet_threshold(img, wavelet=wavelet, mode=mode,
sigma=sigma)
else:
out = np.dstack([_wavelet_threshold(img[..., c], wavelet=wavelet,
mode=mode, sigma=sigma)
for c in range(img.ndim)])

clip_range = (-1, 1) if img.min() < 0 else (0, 1)
return np.clip(out, *clip_range)
50 changes: 41 additions & 9 deletions skimage/restoration/tests/test_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from skimage import restoration, data, color, img_as_float, measure
from skimage._shared._warnings import expected_warnings
from skimage.measure import compare_ssim
from skimage.measure import compare_psnr

np.random.seed(1234)

Expand Down Expand Up @@ -310,16 +310,48 @@ def test_no_denoising_for_small_h():


def test_wavelet_denoising():
for img in [astro_gray, astro]:
noisy = img.copy() + 0.1 * np.random.randn(*(img.shape))
for img, multichannel in [(astro_gray, False), (astro, True)]:
sigma = 0.1
noisy = img + sigma * np.random.randn(*(img.shape))
noisy = np.clip(noisy, 0, 1)
# less energy in signal
denoised = restoration.denoise_wavelet(noisy, sigma=0.3)
assert denoised.sum()**2 <= img.sum()**2

# test changing noise_std (higher threshold, so less energy in signal)
assert (restoration.denoise_wavelet(noisy, sigma=0.2).sum()**2 <=
restoration.denoise_wavelet(noisy, sigma=0.1).sum()**2)
# Verify that SNR is improved when true sigma is used
denoised = restoration.denoise_wavelet(noisy, sigma=sigma,
multichannel=multichannel)
psnr_noisy = compare_psnr(img, noisy)
psnr_denoised = compare_psnr(img, denoised)
assert psnr_denoised > psnr_noisy

# Verify that SNR is improved with internally estimated sigma
denoised = restoration.denoise_wavelet(noisy,
multichannel=multichannel)
psnr_noisy = compare_psnr(img, noisy)
psnr_denoised = compare_psnr(img, denoised)
assert psnr_denoised > psnr_noisy

# Test changing noise_std (higher threshold, so less energy in signal)
res1 = restoration.denoise_wavelet(noisy, sigma=2*sigma,
multichannel=multichannel)
res2 = restoration.denoise_wavelet(noisy, sigma=sigma,
multichannel=multichannel)
assert (res1.sum()**2 <= res2.sum()**2)


def test_wavelet_denoising_nd():
for ndim in range(1, 5):
# Generate a very simple test image
img = 0.2*np.ones((16, )*ndim)
img[[slice(5, 13), ] * ndim] = 0.8

sigma = 0.1
noisy = img + sigma * np.random.randn(*(img.shape))
noisy = np.clip(noisy, 0, 1)

# Verify that SNR is improved with internally estimated sigma
denoised = restoration.denoise_wavelet(noisy)
psnr_noisy = compare_psnr(img, noisy)
psnr_denoised = compare_psnr(img, denoised)
assert psnr_denoised > psnr_noisy


if __name__ == "__main__":
Expand Down

0 comments on commit 99136ab

Please sign in to comment.