Skip to content

Commit

Permalink
Merge pull request #1792 from scottsievert/fftconvolve
Browse files Browse the repository at this point in the history
Uses fftconvolve instead of convolve2d for speedups
  • Loading branch information
jni committed Dec 14, 2015
2 parents 477f396 + 2892e90 commit 975d1a4
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions skimage/restoration/deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import numpy.random as npr
from scipy.signal import convolve2d
from scipy.signal import fftconvolve, convolve

from . import uft

Expand Down Expand Up @@ -336,7 +336,7 @@ def richardson_lucy(image, psf, iterations=50, clip=True):
Parameters
----------
image : ndarray
Input degraded image.
Input degraded image (can be N dimensional).
psf : ndarray
The point spread function.
iterations : int
Expand Down Expand Up @@ -365,13 +365,32 @@ def richardson_lucy(image, psf, iterations=50, clip=True):
----------
.. [1] http://en.wikipedia.org/wiki/Richardson%E2%80%93Lucy_deconvolution
"""
# compute the times for direct convolution and the fft method. The fft is of
# complexity O(N log(N)) for each dimension and the direct method does
# straight arithmetic (and is O(n*k) to add n elements k times)
def direct_time(img_shape, kernel_shape):
return np.prod(img_shape + kernel_shape)
def fft_time(img_shape, kernel_shape):
return np.sum([n*np.log(n) for n in img_shape+kernel_shape])

# see whether the fourier transform convolution method or the direct
# convolution method is faster (discussed in scikit-image PR #1792)
time_ratio = 40.032 * fft_time(image.shape, psf.shape)
time_ratio /= direct_time(image.shape, psf.shape)

if time_ratio <= 1 or len(image.shape) > 2:
convolve_method = fftconvolve
else:
convolve_method = convolve

image = image.astype(np.float)
psf = psf.astype(np.float)
im_deconv = 0.5 * np.ones(image.shape)
psf_mirror = psf[::-1, ::-1]

for _ in range(iterations):
relative_blur = image / convolve2d(im_deconv, psf, 'same')
im_deconv *= convolve2d(relative_blur, psf_mirror, 'same')
relative_blur = image / convolve_method(im_deconv, psf, 'same')
im_deconv *= convolve_method(relative_blur, psf_mirror, 'same')

if clip:
im_deconv[im_deconv > 1] = 1
Expand Down

0 comments on commit 975d1a4

Please sign in to comment.