Skip to content

Commit

Permalink
Merge pull request #2190 from stsievert/wavelet-denoise-v2
Browse files Browse the repository at this point in the history
ENH: Implements wavelet denoising (from #1833)
  • Loading branch information
JDWarner committed Aug 5, 2016
2 parents 0852e3b + 13490a6 commit 0ea2a34
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,6 @@

- Abdeali Kothari
Alpha blending to convert from rgba to rgb

- Scott Sievert
Wavelet denoising
28 changes: 17 additions & 11 deletions doc/examples/filters/plot_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@
import matplotlib.pyplot as plt

from skimage import data, img_as_float
from skimage.restoration import denoise_tv_chambolle, denoise_bilateral
from skimage.restoration import denoise_tv_chambolle, denoise_bilateral, denoise_wavelet
from skimage.util import random_noise


astro = img_as_float(data.astronaut())
astro = astro[220:300, 220:320]

noisy = astro + 0.6 * astro.std() * np.random.random(astro.shape)
noisy = np.clip(noisy, 0, 1)
noisy = random_noise(astro, var=(0.6 * astro.std())**2)

fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(8, 5), sharex=True,
fig, ax = plt.subplots(nrows=2, ncols=4, figsize=(8, 5), sharex=True,
sharey=True, subplot_kw={'adjustable': 'box-forced'})

plt.gray()
Expand All @@ -52,16 +52,22 @@
ax[0, 2].imshow(denoise_bilateral(noisy, sigma_color=0.05, sigma_spatial=15))
ax[0, 2].axis('off')
ax[0, 2].set_title('Bilateral')
ax[0, 3].imshow(denoise_wavelet(noisy, sigma=0.4*astro.std()))
ax[0, 3].axis('off')
ax[0, 3].set_title('Wavelet')

ax[1, 0].imshow(denoise_tv_chambolle(noisy, weight=0.2, multichannel=True))
ax[1, 0].axis('off')
ax[1, 0].set_title('(more) TV')
ax[1, 1].imshow(denoise_bilateral(noisy, sigma_color=0.1, sigma_spatial=15))
ax[1, 1].imshow(denoise_tv_chambolle(noisy, weight=0.2, multichannel=True))
ax[1, 1].axis('off')
ax[1, 1].set_title('(more) Bilateral')
ax[1, 2].imshow(astro)
ax[1, 1].set_title('(more) TV')
ax[1, 2].imshow(denoise_bilateral(noisy, sigma_color=0.1, sigma_spatial=15))
ax[1, 2].axis('off')
ax[1, 2].set_title('original')
ax[1, 2].set_title('(more) Bilateral')
ax[1, 3].imshow(denoise_wavelet(noisy, sigma=0.6*astro.std()))
ax[1, 3].axis('off')
ax[1, 3].set_title('(more) Wavelet')
ax[1, 0].imshow(astro)
ax[1, 0].axis('off')
ax[1, 0].set_title('original')

fig.tight_layout()

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ six>=1.7.3
networkx>=1.8
pillow>=2.1.0
dask[array]>=0.5.0
PyWavelets>=0.4.0
3 changes: 2 additions & 1 deletion skimage/restoration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .deconvolution import wiener, unsupervised_wiener, richardson_lucy
from .unwrap import unwrap_phase
from ._denoise import denoise_tv_chambolle, denoise_tv_bregman, \
denoise_bilateral
denoise_bilateral, denoise_wavelet
from .non_local_means import denoise_nl_means
from .inpaint import inpaint_biharmonic
from .._shared.utils import copy_func, deprecated
Expand All @@ -37,6 +37,7 @@
'denoise_tv_bregman',
'denoise_tv_chambolle',
'denoise_bilateral',
'denoise_wavelet',
'denoise_nl_means',
'nl_means_denoising',
'inpaint_biharmonic']
Expand Down
138 changes: 138 additions & 0 deletions skimage/restoration/_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ..restoration._denoise_cy import _denoise_bilateral, _denoise_tv_bregman
from .._shared.utils import skimage_deprecation, warn
import warnings
import pywt


def denoise_bilateral(image, win_size=None, sigma_color=None, sigma_spatial=1,
Expand Down Expand Up @@ -332,3 +333,140 @@ def denoise_tv_chambolle(im, weight=0.1, eps=2.e-4, n_iter_max=200,
else:
out = _denoise_tv_chambolle_nd(im, weight, eps, n_iter_max)
return out


def _wavelet_threshold(img, wavelet, threshold=None, sigma=None, mode='soft'):
"""Performs wavelet denoising.
Parameters
----------
img : ndarray (2d or 3d) of ints, uints or floats
Input data to be denoised. `img` can be of any numeric type,
but it is cast into an ndarray of floats for the computation
of the denoised image.
wavelet : string
The type of wavelet to perform. Can be any of the options
pywt.wavelist outputs. For example, this may be any of ``{db1, db2,
db3, db4, haar}``.
sigma : float, optional
The standard deviation of the noise. The noise is estimated when sigma
is None (the default) by the method in [2]_.
threshold : float, optional
The thresholding value. All wavelet coefficients less than this value
are set to 0. The default value (None) uses the SureShrink method found
in [1]_ to remove noise.
mode : {'soft', 'hard'}, optional
An optional argument to choose the type of denoising performed. It
noted that choosing soft thresholding given additive noise finds the
best approximation of the original image.
Returns
-------
out : ndarray
Denoised image.
References
----------
.. [1] Chang, S. Grace, Bin Yu, and Martin Vetterli. "Adaptive wavelet
thresholding for image denoising and compression." Image Processing,
IEEE Transactions on 9.9 (2000): 1532-1546.
DOI: 10.1109/83.862633
.. [2] D. L. Donoho and I. M. Johnstone. "Ideal spatial adaptation
by wavelet shrinkage." Biometrika 81.3 (1994): 425-455.
DOI: 10.1093/biomet/81.3.425
"""
coeffs = pywt.wavedecn(img, wavelet=wavelet)
detail_coeffs = coeffs[-1]['d' * img.ndim]

if sigma is None:
# Estimates via the noise via method in [2]
sigma = np.median(np.abs(detail_coeffs)) / 0.67448975019608171

if threshold is None:
# The BayesShrink threshold from [1]_ in docstring
threshold = sigma**2 / np.sqrt(max(img.var() - sigma**2, 0))

denoised_detail = [{key: pywt.threshold(level[key], value=threshold,
mode=mode) for key in level} for level in coeffs[1:]]
denoised_root = pywt.threshold(coeffs[0], value=threshold, mode=mode)
denoised_coeffs = [denoised_root] + [d for d in denoised_detail]
return pywt.waverecn(denoised_coeffs, wavelet)


def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'):
"""Performs wavelet denoising on an image.
Parameters
----------
img : ndarray (2D/3D) of ints, uints or floats
Input data to be denoised. `img` can be of any numeric type,
but it is cast into an ndarray of floats for the computation
of the denoised image.
sigma : float, optional
The noise standard deviation used when computing the threshold
adaptively as described in [1]_. When None (default), the noise
standard deviation is estimated via the method in [2]_.
wavelet : string, optional
The type of wavelet to perform and can be any of the options
``pywt.wavelist`` outputs. The default is `'db1'`. For example,
``wavelet`` can be any of ``{'db2', 'haar', 'sym9'}`` and many more.
mode : {'soft', 'hard'}, optional
An optional argument to choose the type of denoising performed. It
noted that choosing soft thresholding given additive noise finds the
best approximation of the original image.
Returns
-------
out : ndarray
Denoised image.
Notes
-----
The wavelet domain is a sparse representation of the image, and can be
thought of similarly to the frequency domain of the Fourier transform.
Sparse representations have most values zero or near-zero and truly random
noise is (usually) represented by many small values in the wavelet domain.
Setting all values below some threshold to 0 reduces the noise in the
image, but larger thresholds also decrease the detail present in the image.
If the input is 3D, this function performs wavelet denoising on each color
plane separately. The output image is clipped between either [-1, 1] and
[0, 1] depending on the input image range.
References
----------
.. [1] Chang, S. Grace, Bin Yu, and Martin Vetterli. "Adaptive wavelet
thresholding for image denoising and compression." Image Processing,
IEEE Transactions on 9.9 (2000): 1532-1546.
DOI: 10.1109/83.862633
.. [2] D. L. Donoho and I. M. Johnstone. "Ideal spatial adaptation
by wavelet shrinkage." Biometrika 81.3 (1994): 425-455.
DOI: 10.1093/biomet/81.3.425
Examples
--------
>>> from skimage import color, data
>>> img = img_as_float(data.astronaut())
>>> img = color.rgb2gray(img)
>>> img += 0.1 * np.random.randn(*img.shape)
>>> img = np.clip(img, 0, 1)
>>> denoised_img = denoise_wavelet(img, sigma=0.1)
"""

img = img_as_float(img)

if img.ndim not in {2, 3}:
raise ValueError('denoise_wavelet only supports 2D and 3D images')

if img.ndim == 2:
out = _wavelet_threshold(img, wavelet=wavelet, mode=mode,
sigma=sigma)
else:
out = np.dstack([_wavelet_threshold(img[..., c], wavelet=wavelet,
mode=mode, sigma=sigma)
for c in range(img.ndim)])

clip_range = (-1, 1) if img.min() < 0 else (0, 1)
return np.clip(out, *clip_range)
14 changes: 14 additions & 0 deletions skimage/restoration/tests/test_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from skimage import restoration, data, color, img_as_float, measure
from skimage._shared._warnings import expected_warnings
from skimage.measure import compare_ssim

np.random.seed(1234)

Expand Down Expand Up @@ -308,5 +309,18 @@ def test_no_denoising_for_small_h():
assert np.allclose(denoised, img)


def test_wavelet_denoising():
for img in [astro_gray, astro]:
noisy = img.copy() + 0.1 * np.random.randn(*(img.shape))
noisy = np.clip(noisy, 0, 1)
# less energy in signal
denoised = restoration.denoise_wavelet(noisy, sigma=0.3)
assert denoised.sum()**2 <= img.sum()**2

# test changing noise_std (higher threshold, so less energy in signal)
assert (restoration.denoise_wavelet(noisy, sigma=0.2).sum()**2 <=
restoration.denoise_wavelet(noisy, sigma=0.1).sum()**2)


if __name__ == "__main__":
run_module_suite()

0 comments on commit 0ea2a34

Please sign in to comment.