Skip to content
Browse files

ENH: signal: enhanced scipy.signal.filtfilt

* Rewrote lfilter_zi, and added a docstring that explains what it does.
* Enhanced filtfilt so that it can filter a given axis in an n-dimensional array,
  and also allow more flexibility in how the data is extended before filtering.
* Added the module _arraytools.py, which provides the convenient functions
  axis_slice and axis_reverse, and the axis extension functions odd_ext,
  even_ext and const_ext.
* Changes to filtfilt also fix bug #1410.
  • Loading branch information...
1 parent 2448598 commit 8d4665f77ffab58d40ef712a36030450007b48e0 Warren Weckesser committed Jun 3, 2011
View
7 doc/release/0.10.0-notes.rst
@@ -38,6 +38,13 @@ about our function's call signatures.
New features
============
+Enhanced filtfilt function in ``scipy.signal``
+----------------------------------------------
+
+The forward-backward filter function `scipy.signal.filtfilt` can now
+filter the data in a given axis of an n-dimensional numpy array.
+(Previously it only handled a 1-dimensional array.) Options have been
+added to allow more control over how the data is extended before filtering.
Deprecated features
View
161 scipy/signal/_arraytools.py
@@ -0,0 +1,161 @@
+"""
+Functions for acting on a axis of an array.
+"""
+
+import numpy as np
+
+
+def axis_slice(a, start=None, stop=None, step=None, axis=-1):
+ """Take a slice along axis 'axis' from 'a'.
+
+ Parameters
+ ----------
+ a : numpy.ndarray
+ The array to be sliced.
+ start, stop, step : int or None
+ The slice parameters.
+ axis : int
+ The axis of `a` to be sliced.
+
+ Examples
+ --------
+ >>> a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+ >>> axis_slice(a, start=0, stop=1, axis=1)
+ array([[1],
+ [4],
+ [7]])
+ >>> axis_slice(a, start=1, axis=0)
+ array([[4, 5, 6],
+ [7, 8, 9]])
+
+ Notes
+ -----
+ The keyword arguments start, stop and step are used by calling
+ slice(start, stop, step). This implies axis_slice() does not
+ handle its arguments the exacty the same as indexing. To select
+ a single index k, for example, use
+ axis_slice(a, start=k, stop=k+1)
+ In this case, the length of the axis 'axis' in the result will
+ be 1; the trivial dimension is not removed. (Use numpy.squeeze()
+ to remove trivial axes.)
+ """
+ a_slice = [slice(None)] * a.ndim
+ a_slice[axis] = slice(start, stop, step)
+ b = a[a_slice]
+ return b
+
+
+def axis_reverse(a, axis=-1):
+ """Reverse the 1-d slices of `a` along axis `axis`.
+
+ Returns axis_slice(a, step=-1, axis=axis).
+ """
+ return axis_slice(a, step=-1, axis=axis)
+
+
+def odd_ext(x, n, axis=-1):
+ """Generate a new ndarray by making an odd extension of x along an axis.
+
+ Parameters
+ ----------
+ x : ndarray
+ The array to be extended.
+ n : int
+ The number of elements by which to extend x at each end of the axis.
+ axis : int
+ The axis along which to extend x. Default is -1.
+
+ Examples
+ --------
+ >>> a = array([[1.0,2.0,3.0,4.0,5.0], [0.0, 1.0, 4.0, 9.0, 16.0]])
+ >>> _odd_ext(a, 2)
+ array([[-1., 0., 1., 2., 3., 4., 5., 6., 7.],
+ [-4., -1, 0., 1., 4., 9., 16., 23., 28.]])
+ """
+ if n < 1:
+ return x
+ if n > x.shape[axis] - 1:
+ raise ValueError(("The extension length n (%d) is too big. " +
+ "It must not exceed x.shape[axis]-1, which is %d.")
+ % (n, x.shape[axis] - 1))
+ left_end = axis_slice(x, start=0, stop=1, axis=axis)
+ left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
+ right_end = axis_slice(x, start=-1, axis=axis)
+ right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
+ ext = np.concatenate((2 * left_end - left_ext,
+ x,
+ 2 * right_end - right_ext),
+ axis=axis)
+ return ext
+
+
+def even_ext(x, n, axis=-1):
+ """Create an ndarray that is an even extension of x along an axis.
+
+ Parameters
+ ----------
+ x : ndarray
+ The array to be extended.
+ n : int
+ The number of elements by which to extend x at each end of the axis.
+ axis : int
+ The axis along which to extend x. Default is -1.
+
+ Examples
+ --------
+ >>> a = array([[1.0,2.0,3.0,4.0,5.0], [0.0, 1.0, 4.0, 9.0, 16.0]])
+ >>> _even_ext(a, 2)
+ array([[ 3., 2., 1., 2., 3., 4., 5., 4., 3.],
+ [ 4., 1., 0., 1., 4., 9., 16., 9., 4.]])
+ """
+ if n < 1:
+ return x
+ if n > x.shape[axis] - 1:
+ raise ValueError(("The extension length n (%d) is too big. " +
+ "It must not exceed x.shape[axis]-1, which is %d.")
+ % (n, x.shape[axis] - 1))
+ left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
+ right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
+ ext = np.concatenate((left_ext,
+ x,
+ right_ext),
+ axis=axis)
+ return ext
+
+
+def const_ext(x, n, axis=-1):
+ """Create an ndarray that is a constant extension of x along an axis.
+
+ The extension repeats the values at the first and last element of
+ the axis.
+
+ Parameters
+ ----------
+ x : ndarray
+ The array to be extended.
+ n : int
+ The number of elements by which to extend x at each end of the axis.
+ axis : int
+ The axis along which to extend x. Default is -1.
+
+ Examples
+ --------
+ >>> a = array([[1.0,2.0,3.0,4.0,5.0], [0.0, 1.0, 4.0, 9.0, 16.0]])
+ >>> _const_ext(a, 2)
+ array([[ 1., 1., 1., 2., 3., 4., 5., 5., 5.],
+ [ 0., 0., 0., 1., 4., 9., 16., 16., 16.]])
+ """
+ if n < 1:
+ return x
+ left_end = axis_slice(x, start=0, stop=1, axis=axis)
+ ones_shape = [1] * x.ndim
+ ones_shape[axis] = n
+ ones = np.ones(ones_shape, dtype=x.dtype)
+ left_ext = ones * left_end
+ right_end = axis_slice(x, start=-1, axis=axis)
+ right_ext = ones * right_end
+ ext = np.concatenate((left_ext,
+ x,
+ right_ext),
+ axis=axis)
+ return ext
View
289 scipy/signal/signaltools.py
@@ -11,10 +11,11 @@
ones, real_if_close, zeros, array, arange, where, rank, \
newaxis, product, ravel, sum, r_, iscomplexobj, take, \
argsort, allclose, expand_dims, unique, prod, sort, reshape, \
- transpose, dot, mean, flipud, ndarray, atleast_2d
+ transpose, dot, mean, ndarray, atleast_2d
import numpy as np
from scipy.misc import factorial
from windows import get_window
+from _arraytools import axis_slice, axis_reverse, odd_ext, even_ext, const_ext
__all__ = ['correlate', 'fftconvolve', 'convolve', 'convolve2d', 'correlate2d',
'order_filter', 'medfilt', 'medfilt2d', 'wiener', 'lfilter',
@@ -1276,64 +1277,274 @@ def detrend(data, axis=-1, type='linear', bp=0):
def lfilter_zi(b, a):
- #compute the zi state from the filter parameters. see [Gust96].
+ """
+ Compute an initial state `zi` for the lfilter function that corresponds
+ to the steady state of the step response.
+
+ A typical use of this function is to set the initial state so that the
+ output of the filter starts at the same value as the first element of
+ the signal to be filtered.
+
+ Parameters
+ ----------
+ b, a : array_like (1-D)
+ The IIR filter coefficients. See `scipy.signal.lfilter` for more
+ information.
+
+ Returns
+ -------
+ zi : 1-D ndarray
+ The initial state for the filter.
+
+ Notes
+ -----
+ A linear filter with order m has a state space representation (A, B, C, D),
+ for which the output y of the filter can be expressed as::
+
+ z(n+1) = A*z(n) + B*x(n)
+ y(n) = C*z(n) + D*x(n)
- #Based on:
- # [Gust96] Fredrik Gustafsson, Determining the initial states in
- # forward-backward filtering, IEEE Transactions on
- # Signal Processing, pp. 988--992, April 1996,
- # Volume 44, Issue 4
+ where z(n) is a vector of length m, A has shape (m, m), B has shape
+ (m, 1), C has shape (1, m) and D has shape (1, 1) (assuming x(n) is
+ a scalar). lfilter_zi solves::
+
+ zi = A*zi + B
+
+ In other words, it finds the initial condition for which the response
+ to an input of all ones is a constant.
+
+ Given the filter coefficients `a` and `b`, the state space matrices
+ for the transposed direct form II implementation of the linear filter,
+ which is the implementation used by scipy.signal.lfilter, are::
+
+ A = scipy.linalg.companion(a).T
+ B = b[1:] - a[1:]*b[0]
+
+ assuming `a[0]` is 1.0; if `a[0]` is not 1, `a` and `b` are first
+ divided by a[0].
+
+ Examples
+ --------
+ The following code creates a lowpass Butterworth filter. Then it
+ applies that filter to an array whose values are all 1.0; the
+ output is also all 1.0, as expected for a lowpass filter. If the
+ `zi` argument of `lfilter` had not been given, the output would have
+ shown the transient signal.
+
+ >>> from numpy import array, ones
+ >>> from scipy.signal import lfilter, lfilter_zi, butter
+ >>> b, a = butter(5, 0.25)
+ >>> zi = lfilter_zi(b, a)
+ >>> y, zo = lfilter(b, a, ones(10), zi=zi)
+ >>> y
+ array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
+
+ Another example:
+
+ >>> x = array([0.5, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0])
+ >>> y, zf = lfilter(b, a, x, zi=zi*x[0])
+ >>> y
+ array([ 0.5 , 0.5 , 0.5 , 0.49836039, 0.48610528,
+ 0.44399389, 0.35505241])
+
+ Note that the `zi` argument to `lfilter` was computed using
+ `lfilter_zi` and scaled by `x[0]`. Then the output `y` has no
+ transient until the input drops from 0.5 to 0.0.
+
+ """
+
+ # FIXME: Can this function be replaced with an appropriate
+ # use of lfiltic? For example, when b,a = butter(N,Wn),
+ # lfiltic(b, a, y=numpy.ones_like(a), x=numpy.ones_like(b)).
+ #
+
+ # We could use scipy.signal.normalize, but it uses warnings in
+ # cases where a ValueError is more appropriate, and it allows
+ # b to be 2D.
+ b = np.atleast_1d(b)
+ if b.ndim != 1:
+ raise ValueError("Numerator b must be rank 1.")
+ a = np.atleast_1d(a)
+ if a.ndim != 1:
+ raise ValueError("Denominator a must be rank 1.")
+
+ while len(a) > 1 and a[0] == 0.0:
+ a = a[1:]
+ if a.size < 1:
+ raise ValueError("There must be at least one nonzero `a` coefficient.")
+
+ if a[0] != 1.0:
+ # Normalize the coefficients so a[0] == 1.
+ a = a / a[0]
+ b = b / a[0]
@mbauman
mbauman added a note Jul 3, 2014

Sorry to bump such an old commit, but is this doing what you intend? If I'm reading this correctly, you always divide b by 1 since a is becomes normalized on the previous line. I think you want to reverse the order here:

    b = b / a[0]
    a = a / a[0]

No need to apologize for pointing out a bug! You are correct, those lines need to be reversed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
n = max(len(a), len(b))
- zin = (np.eye(n - 1) - np.hstack((-a[1:n, newaxis],
- np.vstack((np.eye(n - 2), zeros(n - 2))))))
+ # Pad a or b with zeros so they are the same length.
+ if len(a) < n:
+ a = np.r_[a, np.zeros(n - len(a))]
+ elif len(b) < n:
+ b = np.r_[b, np.zeros(n - len(b))]
+
+ IminusA = np.eye(n - 1) - linalg.companion(a).T
+ B = b[1:] - a[1:] * b[0]
+ # Solve zi = A*zi + B
+ zi = np.linalg.solve(IminusA, B)
+
+ # For future reference: we could also use the following
+ # explicit formulas to solve the linear system:
+ #
+ # zi = np.zeros(n - 1)
+ # zi[0] = B.sum() / IminusA[:,0].sum()
+ # asum = 1.0
+ # csum = 0.0
+ # for k in range(1,n-1):
+ # asum += a[k]
+ # csum += b[k] - a[k]*b[0]
+ # zi[k] = asum*zi[0] - csum
- zid = b[1:n] - a[1:n] * b[0]
+ return zi
- zi_matrix = linalg.inv(zin) * (np.matrix(zid).transpose())
- zi_return = []
- #convert the result into a regular array (not a matrix)
- for i in range(len(zi_matrix)):
- zi_return.append(float(zi_matrix[i][0]))
+def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None):
+ """A forward-backward filter.
- return array(zi_return)
+ This function applies a linear filter twice, once forward
+ and once backwards. The combined filter has linear phase.
+ Before applying the filter, the function can pad the data along the
+ given axis in one of three ways: odd, even or constant. The odd
+ and even extensions have the corresponding symmetry about the end point
+ of the data. The constant extension extends the data with the values
+ at end points. On both the forward and backwards passes, the
+ initial condition of the filter is found by using lfilter_zi and
+ scaling it by the end point of the extended data.
-def filtfilt(b, a, x):
- b, a, x = map(asarray, [b, a, x])
- # FIXME: For now only accepting 1d arrays
- ntaps = max(len(a), len(b))
- edge = ntaps * 3
+ Parameters
+ ----------
+ b : array_like, 1-D
+ The numerator coefficient vector of the filter.
+ a : array_like, 1-D
+ The denominator coefficient vector of the filter. If a[0]
+ is not 1, then both a and b are normalized by a[0].
+ x : array_like
+ The array of data to be filtered.
+ axis : int, optional
+ The axis of `x` to which the filter is applied.
+ Default is -1.
+ padtype : str or None, optional
+ Must be 'odd', 'even', 'constant', or None. This determines the
+ type of extension to use for the padded signal to which the filter
+ is applied. If `padtype` is None, no padding is used. The default
+ is 'odd'.
+ padlen : int or None, optional
+ The number of elements by which to extend `x` at both ends of
+ `axis` before applying the filter. This value must be less than
+ `x.shape[axis]-1`. `padlen=0` implies no padding.
+ The default value is 3*max(len(a),len(b)).
- if x.ndim != 1:
- raise ValueError("filtfilt only accepts 1-d arrays.")
+ Returns
+ -------
+ y : ndarray
+ The filtered output, an array of type numpy.float64 with the same
+ shape as `x`.
- #x must be bigger than edge
- if x.size < edge:
- raise ValueError("Input vector needs to be bigger than "
- "3 * max(len(a),len(b).")
+ See Also
+ --------
+ lfilter_zi
+ lfilter
- if len(a) < ntaps:
- a = r_[a, zeros(len(b) - len(a))]
+ Examples
+ --------
+ First we create a one second signal that is the sum of two pure sine
+ waves, with frequencies 5 Hz and 250 Hz, sampled at 2000 Hz.
+
+ >>> t = np.linspace(0, 1.0, 2001)
+ >>> xlow = np.sin(2 * np.pi * 5 * t)
+ >>> xhigh = np.sin(2 * np.pi * 250 * t)
+ >>> x = xlow + xhigh
+
+ Now create a lowpass Butterworth filter with a cutoff of 0.125 times
+ the Nyquist rate, or 125 Hz, and apply it to x with filtfilt. The
+ result should be approximately xlow, with no phase shift.
+
+ >>> from scipy.signal import butter
+ >>> b, a = butter(8, 0.125)
+ >>> y = filtfilt(b, a, x, padlen=150)
+ >>> np.abs(y - xlow).max()
+ 9.1086182074789912e-06
+
+ We get a fairly clean result for this artificial example because
+ the odd extension is exact, and with the moderately long padding,
+ the filter's transients have dissipated by the time the actual data
+ is reached. In general, transient effects at the edges are
+ unavoidable.
+ """
+
+ if padtype not in ['even', 'odd', 'constant', None]:
+ raise ValueError(("Unknown value '%s' given to padtype. padtype must "
+ "be 'even', 'odd', 'constant', or None.") %
+ padtype)
+
+ b = np.asarray(b)
+ a = np.asarray(a)
+ x = np.asarray(x)
- if len(b) < ntaps:
- b = r_[b, zeros(len(a) - len(b))]
+ ntaps = max(len(a), len(b))
+
+ if padtype is None:
+ padlen = 0
+ if padlen is None:
+ # Original padding; preserved for backwards compatibility.
+ edge = ntaps * 3
+ else:
+ edge = padlen
+
+ # x's 'axis' dimension must be bigger than edge.
+ if x.shape[axis] <= edge:
+ raise ValueError("The length of the input vector x must be at least "
+ "padlen, which is %d." % edge)
+
+ if padtype is not None and edge > 0:
+ # Make an extension of length `edge` at each
+ # end of the input array.
+ if padtype == 'even':
+ ext = even_ext(x, edge, axis=axis)
+ elif padtype == 'odd':
+ ext = odd_ext(x, edge, axis=axis)
+ else:
+ ext = const_ext(x, edge, axis=axis)
+ else:
+ ext = x
+
+ # Get the steady state of the filter's step response.
zi = lfilter_zi(b, a)
- #Grow the signal to have edges for stabilizing
- #the filter with inverted replicas of the signal
- s = r_[2 * x[0] - x[edge:1:-1], x, 2 * x[-1] - x[-1:-edge:-1]]
- #in the case of one go we only need one of the extrems
- # both are needed for filtfilt
+ # Reshape zi and create x0 so that zi*x0 broadcasts
+ # to the correct value for the 'zi' keyword argument
+ # to lfilter.
+ zi_shape = [1] * x.ndim
+ zi_shape[axis] = zi.size
+ zi = np.reshape(zi, zi_shape)
+ x0 = axis_slice(ext, stop=1, axis=axis)
+
+ # Forward filter.
+ (y, zf) = lfilter(b, a, ext, zi=zi * x0)
+
+ # Backward filter.
+ # Create y0 so zi*y0 broadcasts appropriately.
+ y0 = axis_slice(y, start=-1, axis=axis)
+ (y, zf) = lfilter(b, a, axis_reverse(y, axis=axis), zi=zi * y0)
- (y, zf) = lfilter(b, a, s, -1, zi * s[0])
+ # Reverse y.
+ y = axis_reverse(y, axis=axis)
- (y, zf) = lfilter(b, a, flipud(y), -1, zi * y[-1])
+ if edge > 0:
+ # Slice the actual signal from the extended signal.
+ y = axis_slice(y, start=edge, stop=-edge, axis=axis)
- return flipud(y[edge - 1:-edge + 1])
+ return y
from scipy.signal.filter_design import cheby1
View
99 scipy/signal/tests/test_array_tools.py
@@ -0,0 +1,99 @@
+
+import numpy as np
+
+from numpy.testing import TestCase, run_module_suite, \
+ assert_array_equal, assert_raises
+
+from scipy.signal.array_tools import axis_slice, axis_reverse, \
+ odd_ext, even_ext, const_ext
+
+
+class TestArrayTools(TestCase):
+
+ def test_axis_slice(self):
+ a = np.arange(12).reshape(3, 4)
+
+ s = axis_slice(a, start=0, stop=1, axis=0)
+ assert_array_equal(s, a[0:1, :])
+
+ s = axis_slice(a, start=-1, axis=0)
+ assert_array_equal(s, a[-1:, :])
+
+ s = axis_slice(a, start=0, stop=1, axis=1)
+ assert_array_equal(s, a[:, 0:1])
+
+ s = axis_slice(a, start=-1, axis=1)
+ assert_array_equal(s, a[:, -1:])
+
+ s = axis_slice(a, start=0, step=2, axis=0)
+ assert_array_equal(s, a[::2, :])
+
+ s = axis_slice(a, start=0, step=2, axis=1)
+ assert_array_equal(s, a[:, ::2])
+
+ def test_axis_reverse(self):
+ a = np.arange(12).reshape(3, 4)
+
+ r = axis_reverse(a, axis=0)
+ assert_array_equal(r, a[::-1, :])
+
+ r = axis_reverse(a, axis=1)
+ assert_array_equal(r, a[:, ::-1])
+
+ def test_odd_ext(self):
+ a = np.array([[1, 2, 3, 4, 5],
+ [9, 8, 7, 6, 5]])
+
+ odd = odd_ext(a, 2, axis=1)
+ expected = np.array([[-1, 0, 1, 2, 3, 4, 5, 6, 7],
+ [11, 10, 9, 8, 7, 6, 5, 4, 3]])
+ assert_array_equal(odd, expected)
+
+ odd = odd_ext(a, 1, axis=0)
+ expected = np.array([[-7, -4, -1, 2, 5],
+ [ 1, 2, 3, 4, 5],
+ [ 9, 8, 7, 6, 5],
+ [17, 14, 11, 8, 5]])
+ assert_array_equal(odd, expected)
+
+ assert_raises(ValueError, odd_ext, a, 2, axis=0)
+ assert_raises(ValueError, odd_ext, a, 5, axis=1)
+
+ def test_even_ext(self):
+ a = np.array([[1, 2, 3, 4, 5],
+ [9, 8, 7, 6, 5]])
+
+ even = even_ext(a, 2, axis=1)
+ expected = np.array([[3, 2, 1, 2, 3, 4, 5, 4, 3],
+ [7, 8, 9, 8, 7, 6, 5, 6, 7]])
+ assert_array_equal(even, expected)
+
+ even = even_ext(a, 1, axis=0)
+ expected = np.array([[ 9, 8, 7, 6, 5],
+ [ 1, 2, 3, 4, 5],
+ [ 9, 8, 7, 6, 5],
+ [ 1, 2, 3, 4, 5]])
+ assert_array_equal(even, expected)
+
+ assert_raises(ValueError, even_ext, a, 2, axis=0)
+ assert_raises(ValueError, even_ext, a, 5, axis=1)
+
+ def test_const_ext(self):
+ a = np.array([[1, 2, 3, 4, 5],
+ [9, 8, 7, 6, 5]])
+
+ const = const_ext(a, 2, axis=1)
+ expected = np.array([[1, 1, 1, 2, 3, 4, 5, 5, 5],
+ [9, 9, 9, 8, 7, 6, 5, 5, 5]])
+ assert_array_equal(const, expected)
+
+ const = const_ext(a, 1, axis=0)
+ expected = np.array([[ 1, 2, 3, 4, 5],
+ [ 1, 2, 3, 4, 5],
+ [ 9, 8, 7, 6, 5],
+ [ 9, 8, 7, 6, 5]])
+ assert_array_equal(const, expected)
+
+
+if __name__ == "__main__":
+ run_module_suite()
View
49 scipy/signal/tests/test_signaltools.py
@@ -6,8 +6,8 @@
assert_raises, assert_, dec
import scipy.signal as signal
-from scipy.signal import lfilter, correlate, convolve, convolve2d, hilbert, \
- hilbert2
+from scipy.signal import correlate, convolve, convolve2d, \
+ hilbert, hilbert2, lfilter, lfilter_zi, filtfilt, butter, tf2zpk
from numpy import array, arange
@@ -534,11 +534,52 @@ def test_rank3(self):
globals()[cls.__name__] = cls
-class TestFiltFilt:
+class TestLFilterZI(TestCase):
+
+ def test_basic(self):
+ a = np.array([1.0, -1.0, 0.5])
+ b = np.array([1.0, 0.0, 2.0])
+ zi_expected = np.array([5.0, -1.0])
+ zi = lfilter_zi(b, a)
+ assert_array_almost_equal(zi, zi_expected)
+
+
+class TestFiltFilt(TestCase):
+
def test_basic(self):
- out = signal.filtfilt([1,2,3], [1,2,3], np.arange(12))
+ out = signal.filtfilt([1, 2, 3], [1, 2, 3], np.arange(12))
assert_equal(out, arange(12))
+ def test_sine(self):
+ rate = 2000
+ t = np.linspace(0, 1.0, rate + 1)
+ # A signal with low frequency and a high frequency.
+ xlow = np.sin(5 * 2 * np.pi * t)
+ xhigh = np.sin(250 * 2 * np.pi * t)
+ x = xlow + xhigh
+
+ b, a = butter(8, 0.125)
+ z, p, k = tf2zpk(b, a)
+ # r is the magnitude of the largest pole.
+ r = np.abs(p).max()
+ eps = 1e-5
+ # n estimates the number of steps for the
+ # transient to decay by a factor of eps.
+ n = int(np.ceil(np.log(eps) / np.log(r)))
+
+ # High order lowpass filter...
+ y = filtfilt(b, a, x, padlen=n)
+ # Result should be just xlow.
+ err = np.abs(y - xlow).max()
+ assert_(err < 1e-4)
+
+ # A 2D case.
+ x2d = np.vstack([xlow, xlow + xhigh])
+ y2d = filtfilt(b, a, x2d, padlen=n, axis=1)
+ assert_equal(y2d.shape, x2d.shape)
+ err = np.abs(y2d - xlow).max()
+ assert_(err < 1e-4)
+
class TestDecimate:
def test_basic(self):

0 comments on commit 8d4665f

Please sign in to comment.
Something went wrong with that request. Please try again.