-
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add n-dimensional support to denoise_wavelet #2242
Changes from all commits
c83fe7d
d4168b6
efb829f
6f3520b
72577e9
87936c0
48bfe20
721dd37
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -394,12 +394,13 @@ def _wavelet_threshold(img, wavelet, threshold=None, sigma=None, mode='soft'): | |
return pywt.waverecn(denoised_coeffs, wavelet) | ||
|
||
|
||
def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'): | ||
def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft', | ||
multichannel=False): | ||
"""Performs wavelet denoising on an image. | ||
|
||
Parameters | ||
---------- | ||
img : ndarray (2D/3D) of ints, uints or floats | ||
img : ndarray ([M[, N[, ...P]][, C]) of ints, uints or floats | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, actually this should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think 1D is already supported here, though? edit: 1D through 4D operation now verified via new tests There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. True. Let it be |
||
Input data to be denoised. `img` can be of any numeric type, | ||
but it is cast into an ndarray of floats for the computation | ||
of the denoised image. | ||
|
@@ -415,6 +416,9 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'): | |
An optional argument to choose the type of denoising performed. It | ||
noted that choosing soft thresholding given additive noise finds the | ||
best approximation of the original image. | ||
multichannel : bool, optional | ||
Apply wavelet denoising separately for each channel (where channels | ||
correspond to the final axis of the array). | ||
|
||
Returns | ||
------- | ||
|
@@ -457,16 +461,14 @@ def denoise_wavelet(img, sigma=None, wavelet='db1', mode='soft'): | |
|
||
img = img_as_float(img) | ||
|
||
if img.ndim not in {2, 3}: | ||
raise ValueError('denoise_wavelet only supports 2D and 3D images') | ||
|
||
if img.ndim == 2: | ||
if multichannel: | ||
out = np.empty_like(img) | ||
for c in range(img.shape[-1]): | ||
out[..., c] = _wavelet_threshold(img[..., c], wavelet=wavelet, | ||
mode=mode, sigma=sigma) | ||
else: | ||
out = _wavelet_threshold(img, wavelet=wavelet, mode=mode, | ||
sigma=sigma) | ||
else: | ||
out = np.dstack([_wavelet_threshold(img[..., c], wavelet=wavelet, | ||
mode=mode, sigma=sigma) | ||
for c in range(img.ndim)]) | ||
|
||
clip_range = (-1, 1) if img.min() < 0 else (0, 1) | ||
return np.clip(out, *clip_range) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
|
||
from skimage import restoration, data, color, img_as_float, measure | ||
from skimage._shared._warnings import expected_warnings | ||
from skimage.measure import compare_ssim | ||
from skimage.measure import compare_psnr | ||
|
||
np.random.seed(1234) | ||
|
||
|
@@ -310,16 +310,48 @@ def test_no_denoising_for_small_h(): | |
|
||
|
||
def test_wavelet_denoising(): | ||
for img in [astro_gray, astro]: | ||
noisy = img.copy() + 0.1 * np.random.randn(*(img.shape)) | ||
for img, multichannel in [(astro_gray, False), (astro, True)]: | ||
sigma = 0.1 | ||
noisy = img + sigma * np.random.randn(*(img.shape)) | ||
noisy = np.clip(noisy, 0, 1) | ||
# less energy in signal | ||
denoised = restoration.denoise_wavelet(noisy, sigma=0.3) | ||
assert denoised.sum()**2 <= img.sum()**2 | ||
|
||
# test changing noise_std (higher threshold, so less energy in signal) | ||
assert (restoration.denoise_wavelet(noisy, sigma=0.2).sum()**2 <= | ||
restoration.denoise_wavelet(noisy, sigma=0.1).sum()**2) | ||
# Verify that SNR is improved when true sigma is used | ||
denoised = restoration.denoise_wavelet(noisy, sigma=sigma, | ||
multichannel=multichannel) | ||
psnr_noisy = compare_psnr(img, noisy) | ||
psnr_denoised = compare_psnr(img, denoised) | ||
assert psnr_denoised > psnr_noisy | ||
|
||
# Verify that SNR is improved with internally estimated sigma | ||
denoised = restoration.denoise_wavelet(noisy, | ||
multichannel=multichannel) | ||
psnr_noisy = compare_psnr(img, noisy) | ||
psnr_denoised = compare_psnr(img, denoised) | ||
assert psnr_denoised > psnr_noisy | ||
|
||
# Test changing noise_std (higher threshold, so less energy in signal) | ||
res1 = restoration.denoise_wavelet(noisy, sigma=2*sigma, | ||
multichannel=multichannel) | ||
res2 = restoration.denoise_wavelet(noisy, sigma=sigma, | ||
multichannel=multichannel) | ||
assert (res1.sum()**2 <= res2.sum()**2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know this wasn't your approach, but this test seems quite weak to me! How about showing a reduction in the noise after denoise_wavelet, or increase in SNR? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure. it should be simple to verify at least a rough improvement via |
||
|
||
|
||
def test_wavelet_denoising_nd(): | ||
for ndim in range(1, 5): | ||
# Generate a very simple test image | ||
img = 0.2*np.ones((16, )*ndim) | ||
img[[slice(5, 13), ] * ndim] = 0.8 | ||
|
||
sigma = 0.1 | ||
noisy = img + sigma * np.random.randn(*(img.shape)) | ||
noisy = np.clip(noisy, 0, 1) | ||
|
||
# Verify that SNR is improved with internally estimated sigma | ||
denoised = restoration.denoise_wavelet(noisy) | ||
psnr_noisy = compare_psnr(img, noisy) | ||
psnr_denoised = compare_psnr(img, denoised) | ||
assert psnr_denoised > psnr_noisy | ||
|
||
|
||
if __name__ == "__main__": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the docstring, please, change expected data type for
img
to smth like(M[, N[, ..., P]][, C])
, and add description formultichannel
parameter.