Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Fixed allclose call in normalize() of filter_design.py and added a te…

…st so check the output against the same filter coefficients from MATLAB
  • Loading branch information...
commit 8bd46f4e7a4f78c69a6d476be52276c231797f1e 1 parent c818914
@wa03 wa03 authored rgommers committed
View
4 scipy/signal/filter_design.py
@@ -283,10 +283,10 @@ def normalize(b, a):
a = a[1:]
outb = b * (1.0) / a[0]
outa = a * (1.0) / a[0]
- if allclose(outb[:, 0], 0, rtol=1e-14):
+ if allclose(0, outb[:, 0], atol=1e-14):
warnings.warn("Badly conditioned filter coefficients (numerator): the "
"results may be meaningless", BadCoefficients)
- while allclose(outb[:, 0], 0, rtol=1e-14) and (outb.shape[-1] > 1):
+ while allclose(0, outb[:, 0], atol=1e-14) and (outb.shape[-1] > 1):
outb = outb[:, 1:]
if outb.shape[0] == 1:
outb = outb[0]
View
48 scipy/signal/tests/test_filter_design.py
@@ -5,7 +5,7 @@
assert_array_equal, assert_raises, assert_equal, assert_, \
run_module_suite
-from scipy.signal import tf2zpk, zpk2tf, BadCoefficients, freqz
+from scipy.signal import tf2zpk, zpk2tf, BadCoefficients, freqz, normalize
class TestTf2zpk(TestCase):
@@ -86,6 +86,52 @@ def plot(w, h):
freqz, [1.0], worN=8, plot=lambda w, h: 1 / 0)
freqz([1.0], worN=8, plot=plot)
+class TestNormalize(TestCase):
+
+ def test_allclose(self):
+ """Test for false positive on allclose in normalize() in
+ filter_design.py"""
+ # Test to make sure the allclose call within signal.normalize does not
+ # choose false positives. Then check against a known output from MATLAB
+ # to make sure the fix doesn't break anything.
+
+ # These are the coefficients returned from
+ # `[b,a] = cheby1(8, 0.5, 0.048)'
+ # in MATLAB. There are at least 15 significant figures in each
+ # coefficient, so it makes sense to test for errors on the order of
+ # 1e-13 (this can always be relaxed if different platforms have
+ # different rounding errors)
+ b_matlab = np.array([2.150733144728282e-11, 1.720586515782626e-10,
+ 6.022052805239190e-10, 1.204410561047838e-09,
+ 1.505513201309798e-09, 1.204410561047838e-09,
+ 6.022052805239190e-10, 1.720586515782626e-10,
+ 2.150733144728282e-11])
+ a_matlab = np.array([1.000000000000000e+00, -7.782402035027959e+00,
+ 2.654354569747454e+01, -5.182182531666387e+01,
+ 6.334127355102684e+01, -4.963358186631157e+01,
+ 2.434862182949389e+01, -6.836925348604676e+00,
+ 8.412934944449140e-01])
+
+ # This is the input to signal.normalize after passing through the
+ # equivalent steps in signal.iirfilter as was done for MATLAB
+ b_norm_in = np.array([1.5543135865293012e-06, 1.2434508692234413e-05,
+ 4.3520780422820447e-05, 8.7041560845640893e-05,
+ 1.0880195105705122e-04, 8.7041560845640975e-05,
+ 4.3520780422820447e-05, 1.2434508692234413e-05,
+ 1.5543135865293012e-06])
+ a_norm_in = np.array([7.2269025909127173e+04, -5.6242661430467968e+05,
+ 1.9182761917308895e+06, -3.7451128364682454e+06,
+ 4.5776121393762771e+06, -3.5869706138592605e+06,
+ 1.7596511818472347e+06, -4.9409793515707983e+05,
+ 6.0799461347219651e+04])
+
+ b_output, a_output = normalize(b_norm_in, a_norm_in)
+
+ # The test on b works for decimal=14 but the one for a does not. For
+ # the sake of consistency, both of these are decimal=13. If something
+ # breaks on another platform, it is probably fine to relax this lower.
+ assert_array_almost_equal(b_matlab, b_output, decimal=13)
+ assert_array_almost_equal(a_matlab, a_output, decimal=13)
if __name__ == "__main__":
run_module_suite()
Please sign in to comment.
Something went wrong with that request. Please try again.