Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

Already on GitHub? Sign in to your account

Gaussian filter truncation #239

Closed
wants to merge 2 commits into
from
Jump to file or symbol
Failed to load files and symbols.
+33 −14
Split
View
@@ -168,7 +168,7 @@ def convolve1d(input, weights, axis = -1, output = None, mode = "reflect",
@docfiller
def gaussian_filter1d(input, sigma, axis = -1, order = 0, output = None,
- mode = "reflect", cval = 0.0):
+ mode = "reflect", cval = 0.0, truncate = 4.0):
"""One-dimensional Gaussian filter.
Parameters
@@ -185,13 +185,15 @@ def gaussian_filter1d(input, sigma, axis = -1, order = 0, output = None,
%(output)s
%(mode)s
%(cval)s
+ truncate : float
+ Truncate the filter at this many standard deviations.
+ Default is 4.0.
"""
if order not in range(4):
raise ValueError('Order outside 0..3 not implemented')
sd = float(sigma)
- # make the length of the filter equal to 4 times the standard
- # deviations:
- lw = int(4.0 * sd + 0.5)
+ # make the radius of the filter equal to truncate standard deviations
+ lw = int(truncate * sd + 0.5)
weights = [0.0] * (2 * lw + 1)
weights[lw] = 1.0
sum = 1.0
@@ -232,7 +234,7 @@ def gaussian_filter1d(input, sigma, axis = -1, order = 0, output = None,
@docfiller
def gaussian_filter(input, sigma, order = 0, output = None,
- mode = "reflect", cval = 0.0):
+ mode = "reflect", cval = 0.0, truncate = 4.0):
"""Multi-dimensional Gaussian filter.
Parameters
@@ -253,6 +255,9 @@ def gaussian_filter(input, sigma, order = 0, output = None,
%(output)s
%(mode)s
%(cval)s
+ truncate : float
+ Truncate the filter at this many standard deviations.
+ Default is 4.0.
Notes
-----
@@ -275,7 +280,7 @@ def gaussian_filter(input, sigma, order = 0, output = None,
if len(axes) > 0:
for axis, sigma, order in axes:
gaussian_filter1d(input, sigma, axis, order, output,
- mode, cval)
+ mode, cval, truncate)
input = output
else:
output[...] = input[...]
@@ -386,7 +391,7 @@ def derivative2(input, axis, output, mode, cval):
@docfiller
def gaussian_laplace(input, sigma, output = None, mode = "reflect",
- cval = 0.0):
+ cval = 0.0, **kwargs):
"""Calculate a multidimensional laplace filter using gaussian
second derivatives.
@@ -400,14 +405,17 @@ def gaussian_laplace(input, sigma, output = None, mode = "reflect",
%(output)s
%(mode)s
%(cval)s
+ Extra keyword arguments will be passed to gaussian_filter().
"""
input = numpy.asarray(input)
- def derivative2(input, axis, output, mode, cval, sigma):
+ def derivative2(input, axis, output, mode, cval, sigma, **kwargs):
order = [0] * input.ndim
order[axis] = 2
- return gaussian_filter(input, sigma, order, output, mode, cval)
+ return gaussian_filter(input, sigma, order, output, mode, cval,
+ **kwargs)
return generic_laplace(input, derivative2, output, mode, cval,
- extra_arguments = (sigma,))
+ extra_arguments = (sigma,),
+ extra_keywords = kwargs)
@docfiller
@@ -462,7 +470,7 @@ def generic_gradient_magnitude(input, derivative, output = None,
@docfiller
def gaussian_gradient_magnitude(input, sigma, output = None,
- mode = "reflect", cval = 0.0):
+ mode = "reflect", cval = 0.0, **kwargs):
"""Calculate a multidimensional gradient magnitude using gaussian
derivatives.
@@ -476,14 +484,17 @@ def gaussian_gradient_magnitude(input, sigma, output = None,
%(output)s
%(mode)s
%(cval)s
+ Extra keyword arguments will be passed to gaussian_filter().
"""
input = numpy.asarray(input)
- def derivative(input, axis, output, mode, cval, sigma):
+ def derivative(input, axis, output, mode, cval, sigma, **kwargs):
order = [0] * input.ndim
order[axis] = 1
- return gaussian_filter(input, sigma, order, output, mode, cval)
+ return gaussian_filter(input, sigma, order, output, mode,
+ cval, **kwargs)
return generic_gradient_magnitude(input, derivative, output, mode,
- cval, extra_arguments = (sigma,))
+ cval, extra_arguments = (sigma,),
+ extra_keywords = kwargs)
def _correlate_or_convolve(input, weights, output, mode, cval, origin,
@@ -52,3 +52,11 @@ def test_valid_origins():
# Just check this raises an error instead of silently accepting or
# segfaulting.
assert_raises(ValueError, filter, data, 3, origin=2)
+
+def test_gaussian_truncate():
+ """Test that Gaussian filters can be truncated at different widths."""
+ arr = np.zeros((100, 100), np.float)
+ arr[50, 50] = 1
+ num_zeros_2 = (sndi.gaussian_filter(arr, 5, truncate=2) > 0).sum()
+ num_zeros_5 = (sndi.gaussian_filter(arr, 5, truncate=5) > 0).sum()
+ assert (num_zeros_5 > num_zeros_2 * 2)