Merge pull request #2241 from grlee77/bayesshrink
[MRG+1] more closely match the BayesShrink paper in _wavelet_threshold
jni committed Sep 8, 2016
2 parents 6f1ca7f + c4d8058 commit aa65ade
Showing 2 changed files with 166 additions and 63 deletions.
171 changes: 111 additions & 60 deletions skimage/restoration/_denoise.py
@@ -335,8 +335,53 @@ def denoise_tv_chambolle(im, weight=0.1, eps=2.e-4, n_iter_max=200,
return out


def _wavelet_threshold(img, wavelet, threshold=None, sigma=None, mode='soft'):
"""Performs wavelet denoising.
def _bayes_thresh(details, var):
"""BayesShrink threshold for a zero-mean details coeff array."""
# Equivalent to: dvar = np.var(details) for 0-mean details array
dvar = np.mean(details*details)
eps = np.finfo(details.dtype).eps
thresh = var / np.sqrt(max(dvar - var, eps))
return thresh
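As a quick sanity check of the formula above (an illustration, not part of this diff; the numbers are arbitrary), the threshold is the noise variance divided by the estimated signal standard deviation:

    import numpy as np

    noise_var = 0.1 ** 2                     # assumed-known noise variance
    details = 0.3 * np.random.randn(10000)   # synthetic zero-mean subband
    dvar = np.mean(details * details)        # total variance of the subband
    eps = np.finfo(details.dtype).eps
    # BayesShrink: noise variance over signal std, guarding against dvar <= var
    thresh = noise_var / np.sqrt(max(dvar - noise_var, eps))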


def _sigma_est_dwt(detail_coeffs, distribution='Gaussian'):
"""Calculate the robust median estimator of the noise standard deviation.
Parameters
----------
detail_coeffs : ndarray
The detail coefficients corresponding to the discrete wavelet
transform of an image.
distribution : str
The underlying noise distribution.
Returns
-------
sigma : float
The estimated noise standard deviation (see section 4.2 of [1]_).
References
----------
.. [1] D. L. Donoho and I. M. Johnstone. "Ideal spatial adaptation
by wavelet shrinkage." Biometrika 81.3 (1994): 425-455.
DOI:10.1093/biomet/81.3.425
"""
# Consider regions with detail coefficients exactly zero to be masked out
detail_coeffs = detail_coeffs[np.nonzero(detail_coeffs)]

if distribution.lower() == 'gaussian':
# 75th quantile of the underlying, symmetric noise distribution
denom = scipy.stats.norm.ppf(0.75)
sigma = np.median(np.abs(detail_coeffs)) / denom
else:
raise ValueError("Only Gaussian noise estimation is currently "
"supported")
return sigma
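A small synthetic check of this estimator (my own illustration; the values are arbitrary): on pure Gaussian noise, the median absolute coefficient divided by the 0.75 quantile of the standard normal (about 0.6745) recovers the true standard deviation.

    import numpy as np
    import scipy.stats

    rstate = np.random.RandomState(0)
    coeffs = 0.05 * rstate.randn(100000)   # pure-noise "detail" coefficients
    sigma_est = np.median(np.abs(coeffs)) / scipy.stats.norm.ppf(0.75)
    # sigma_est should come out close to the true sigma of 0.05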


def _wavelet_threshold(img, wavelet, threshold=None, sigma=None, mode='soft',
wavelet_levels=None):
"""Perform wavelet denoising.
Parameters
----------
@@ -353,18 +398,33 @@ def _wavelet_threshold(img, wavelet, threshold=None, sigma=None, mode='soft'):
is None (the default) by the method in [2]_.
threshold : float, optional
The thresholding value. All wavelet coefficients less than this value
are set to 0. The default value (None) uses the SureShrink method found
in [1]_ to remove noise.
are set to 0. The default value (None) uses the BayesShrink method
found in [1]_ to remove noise.
mode : {'soft', 'hard'}, optional
An optional argument to choose the type of denoising performed. It is
noted that choosing soft thresholding given additive noise finds the
best approximation of the original image.
wavelet_levels : int or None, optional
The number of wavelet decomposition levels to use. The default is
three less than the maximum number of possible decomposition levels
(see Notes below).
Returns
-------
out : ndarray
Denoised image.
Notes
-----
Reference [1]_ used four levels of wavelet decomposition. To be more
flexible for a range of input sizes, the implementation here stops 3 levels
prior to the maximum level of decomposition for `img` (the exact number of
levels thus depends on `img.shape` and the chosen wavelet). BayesShrink
variance estimation does not work well on levels with extremely small
coefficient arrays, which is the rationale for skipping a few of the
coarsest levels. The user can override the automated setting by explicitly
specifying `wavelet_levels`.
References
----------
.. [1] Chang, S. Grace, Bin Yu, and Martin Vetterli. "Adaptive wavelet
@@ -376,27 +436,52 @@ def _wavelet_threshold(img, wavelet, threshold=None, sigma=None, mode='soft'):
DOI: 10.1093/biomet/81.3.425
"""
coeffs = pywt.wavedecn(img, wavelet=wavelet)
detail_coeffs = coeffs[-1]['d' * img.ndim]
wavelet = pywt.Wavelet(wavelet)

# Determine the number of wavelet decomposition levels
if wavelet_levels is None:
# Determine the maximum number of possible levels for img
dlen = wavelet.dec_len
wavelet_levels = np.min(
[pywt.dwt_max_level(s, dlen) for s in img.shape])

# Skip coarsest wavelet scales (see Notes in docstring).
wavelet_levels = max(wavelet_levels - 3, 1)

coeffs = pywt.wavedecn(img, wavelet=wavelet, level=wavelet_levels)
# Detail coefficients at each decomposition level
dcoeffs = coeffs[1:]

if sigma is None:
# Estimates via the noise via method in [2]
sigma = np.median(np.abs(detail_coeffs)) / 0.67448975019608171
# Estimate the noise via the method in [2]_
detail_coeffs = dcoeffs[-1]['d' * img.ndim]
sigma = _sigma_est_dwt(detail_coeffs, distribution='Gaussian')

if threshold is None:
# The BayesShrink threshold from [1]_ in docstring
threshold = sigma**2 / np.sqrt(max(img.var() - sigma**2, 0))

denoised_detail = [{key: pywt.threshold(level[key], value=threshold,
mode=mode) for key in level} for level in coeffs[1:]]
denoised_root = pywt.threshold(coeffs[0], value=threshold, mode=mode)
denoised_coeffs = [denoised_root] + [d for d in denoised_detail]
# The BayesShrink thresholds from [1]_ in docstring
var = sigma**2
threshold = [{key: _bayes_thresh(level[key], var) for key in level}
for level in dcoeffs]

if np.isscalar(threshold):
# A single threshold for all coefficient arrays
denoised_detail = [{key: pywt.threshold(level[key],
value=threshold,
mode=mode) for key in level}
for level in dcoeffs]
else:
# Dict of unique threshold coefficients for each detail coeff. array
denoised_detail = [{key: pywt.threshold(level[key],
value=thresh[key],
mode=mode) for key in level}
for thresh, level in zip(threshold, dcoeffs)]
denoised_coeffs = [coeffs[0]] + denoised_detail
return pywt.waverecn(denoised_coeffs, wavelet)
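To make the level-selection logic above concrete, here is a sketch (assuming PyWavelets is installed; the image size and threshold value are arbitrary) of how the default number of levels is derived and how one detail subband would be thresholded:

    import numpy as np
    import pywt

    img = np.random.rand(128, 128)
    wavelet = pywt.Wavelet('db1')
    max_level = min(pywt.dwt_max_level(s, wavelet.dec_len) for s in img.shape)
    levels = max(max_level - 3, 1)   # the default above: skip 3 coarsest levels
    coeffs = pywt.wavedecn(img, wavelet='db1', level=levels)
    # coeffs[0] is the approximation; each entry of coeffs[1:] is a dict of
    # detail subbands keyed by 'a'/'d' strings ('ad', 'da', 'dd' in 2-D)
    shrunk = pywt.threshold(coeffs[1]['dd'], value=0.1, mode='soft')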


def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft',
multichannel=False):
"""Performs wavelet denoising on an image.
wavelet_levels=None, multichannel=False):
"""Perform wavelet denoising on an image.
Parameters
----------
@@ -416,6 +501,9 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft',
An optional argument to choose the type of denoising performed. It is
noted that choosing soft thresholding given additive noise finds the
best approximation of the original image.
wavelet_levels : int or None, optional
The number of wavelet decomposition levels to use. The default is
three less than the maximum number of possible decomposition levels.
multichannel : bool, optional
Apply wavelet denoising separately for each channel (where channels
correspond to the final axis of the array).
@@ -458,72 +546,35 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft',
>>> denoised_img = denoise_wavelet(img, sigma=0.1)
"""

img = img_as_float(img)

if multichannel:
out = np.empty_like(img)
for c in range(img.shape[-1]):
out[..., c] = _wavelet_threshold(img[..., c], wavelet=wavelet,
mode=mode, sigma=sigma)
mode=mode, sigma=sigma,
wavelet_levels=wavelet_levels)
else:
out = _wavelet_threshold(img, wavelet=wavelet, mode=mode,
sigma=sigma)
sigma=sigma, wavelet_levels=wavelet_levels)

clip_range = (-1, 1) if img.min() < 0 else (0, 1)
return np.clip(out, *clip_range)
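For example, the multichannel path can be exercised like this (an illustrative call; the sigma of the added noise and the level count are arbitrary choices, not defaults):

    import numpy as np
    from skimage import data, img_as_float, restoration

    img = img_as_float(data.astronaut())   # RGB image with values in [0, 1]
    noisy = np.clip(img + 0.1 * np.random.randn(*img.shape), 0, 1)
    denoised = restoration.denoise_wavelet(noisy, multichannel=True,
                                           wavelet_levels=4)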


def estimate_sigma(im, multichannel=False, average_sigmas=False):
def estimate_sigma(im, average_sigmas=False, multichannel=False):
"""
Robust wavelet-based estimator of the (Gaussian) noise standard deviation.
Parameters
----------
im : ndarray
Image for which to estimate the noise standard deviation.
multichannel : bool
Estimate sigma separately for each channel.
average_sigmas : bool, optional
If true, average the channel estimates of `sigma`. Otherwise return
a list of sigmas corresponding to each channel.
multichannel : bool
Estimate sigma separately for each channel.
Returns
-------
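A typical use of this estimator looks like the following (my example; the noise level is arbitrary):

    import numpy as np
    from skimage import data, img_as_float, restoration

    img = img_as_float(data.camera())
    noisy = img + 0.08 * np.random.randn(*img.shape)
    sigma_hat = restoration.estimate_sigma(noisy)   # expect a value near 0.08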
58 changes: 55 additions & 3 deletions skimage/restoration/tests/test_denoise.py
@@ -5,6 +5,9 @@
from skimage import restoration, data, color, img_as_float, measure
from skimage._shared._warnings import expected_warnings
from skimage.measure import compare_psnr
from skimage.restoration._denoise import _wavelet_threshold

import pywt

np.random.seed(1234)

@@ -311,9 +314,10 @@ def test_no_denoising_for_small_h():


def test_wavelet_denoising():
rstate = np.random.RandomState(1234)
for img, multichannel in [(astro_gray, False), (astro, True)]:
sigma = 0.1
noisy = img + sigma * np.random.randn(*(img.shape))
noisy = img + sigma * rstate.randn(*(img.shape))
noisy = np.clip(noisy, 0, 1)

# Verify that SNR is improved when true sigma is used
@@ -335,17 +339,33 @@ def test_wavelet_denoising():
multichannel=multichannel)
res2 = restoration.denoise_wavelet(noisy, sigma=sigma,
multichannel=multichannel)
assert (res1.sum()**2 <= res2.sum()**2)
assert np.sum(res1**2) <= np.sum(res2**2)


def test_wavelet_threshold():
rstate = np.random.RandomState(1234)

img = astro_gray
sigma = 0.1
noisy = img + sigma * rstate.randn(*(img.shape))
noisy = np.clip(noisy, 0, 1)

# employ a single, uniform threshold instead of BayesShrink sigmas
denoised = _wavelet_threshold(noisy, wavelet='db1', threshold=sigma)
psnr_noisy = compare_psnr(img, noisy)
psnr_denoised = compare_psnr(img, denoised)
assert psnr_denoised > psnr_noisy
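The PSNR comparisons used throughout these tests boil down to the following (an illustrative computation, not part of the test suite): PSNR is 10*log10(peak**2 / MSE), so less corrupted images score higher.

    import numpy as np
    from skimage.measure import compare_psnr

    img = np.zeros((64, 64))
    noisy = np.clip(img + 0.1 * np.random.randn(64, 64), 0, 1)
    psnr = compare_psnr(img, noisy)   # data range inferred from the dtype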


def test_wavelet_denoising_nd():
rstate = np.random.RandomState(1234)
for ndim in range(1, 5):
# Generate a very simple test image
img = 0.2*np.ones((16, )*ndim)
img[[slice(5, 13), ] * ndim] = 0.8

sigma = 0.1
noisy = img + sigma * np.random.randn(*(img.shape))
noisy = img + sigma * rstate.randn(*(img.shape))
noisy = np.clip(noisy, 0, 1)

# Verify that SNR is improved with internally estimated sigma
@@ -355,6 +375,38 @@ def test_wavelet_denoising_nd():
assert psnr_denoised > psnr_noisy


def test_wavelet_denoising_levels():
rstate = np.random.RandomState(1234)
ndim = 2
N = 256
wavelet = 'db1'
# Generate a very simple test image
img = 0.2*np.ones((N, )*ndim)
img[[slice(5, 13), ] * ndim] = 0.8

sigma = 0.1
noisy = img + sigma * rstate.randn(*(img.shape))
noisy = np.clip(noisy, 0, 1)

denoised = restoration.denoise_wavelet(noisy, wavelet=wavelet)
denoised_1 = restoration.denoise_wavelet(noisy, wavelet=wavelet,
wavelet_levels=1)
psnr_noisy = compare_psnr(img, noisy)
psnr_denoised = compare_psnr(img, denoised)
psnr_denoised_1 = compare_psnr(img, denoised_1)

# multi-level case should outperform single level case
assert psnr_denoised > psnr_denoised_1 > psnr_noisy

# invalid number of wavelet levels results in a ValueError
max_level = pywt.dwt_max_level(np.min(img.shape),
pywt.Wavelet(wavelet).dec_len)
assert_raises(ValueError, restoration.denoise_wavelet, noisy,
wavelet=wavelet, wavelet_levels=max_level+1)
assert_raises(ValueError, restoration.denoise_wavelet, noisy,
wavelet=wavelet, wavelet_levels=-1)
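For reference, the level limit that this test probes comes straight from PyWavelets: the maximum useful decomposition depth grows with the log of the smallest axis length (a small check, assuming the 'db1' filter length of 2):

    import pywt

    # For a 256-pixel axis and a length-2 filter, at most 8 levels are usable
    assert pywt.dwt_max_level(256, 2) == 8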


def test_estimate_sigma_gray():
rstate = np.random.RandomState(1234)
# astronaut image
