Skip to content
This repository
Browse code

Merge pull request #3407 from argriffing/use-assert-warns

MAINT: use assert_warns instead of a more complicated mechanism
commit 13cbe0698c36ebc0c7a5d41ed9ed7008733871b8 2 parents a491a37 + 31491ee
authored April 20, 2014
47  scipy/lib/_numpy_compat.py
... ...
@@ -0,0 +1,47 @@
  1
+"""Functions copypasted from newer versions of numpy.
  2
+
  3
+"""
  4
+from __future__ import division, print_function, absolute_import
  5
+
  6
+import warnings
  7
+
  8
+import numpy as np
  9
+
  10
+from scipy.lib._version import NumpyVersion
  11
+
  12
+if NumpyVersion(np.__version__) > '1.7.0.dev':
  13
+    _assert_warns = np.testing.assert_warns
  14
+else:
  15
+    def _assert_warns(warning_class, func, *args, **kw):
  16
+        r"""
  17
+        Fail unless the given callable throws the specified warning.
  18
+
  19
+        This definition is copypasted from numpy 1.9.0.dev.
  20
+        The version in earlier numpy returns None.
  21
+
  22
+        Parameters
  23
+        ----------
  24
+        warning_class : class
  25
+            The class defining the warning that `func` is expected to throw.
  26
+        func : callable
  27
+            The callable to test.
  28
+        *args : Arguments
  29
+            Arguments passed to `func`.
  30
+        **kwargs : Kwargs
  31
+            Keyword arguments passed to `func`.
  32
+
  33
+        Returns
  34
+        -------
  35
+        The value returned by `func`.
  36
+
  37
+        """
  38
+        with warnings.catch_warnings(record=True) as l:
  39
+            warnings.simplefilter('always')
  40
+            result = func(*args, **kw)
  41
+            if not len(l) > 0:
  42
+                raise AssertionError("No warning raised when calling %s"
  43
+                        % func.__name__)
  44
+            if not l[0].category is warning_class:
  45
+                raise AssertionError("First warning for %s is not a "
  46
+                        "%s( is %s)" % (func.__name__, warning_class, l[0]))
  47
+        return result
29  scipy/linalg/tests/test_matfuncs.py
@@ -5,7 +5,6 @@
5 5
 """ Test functions for linalg.matfuncs module
6 6
 
7 7
 """
8  
-
9 8
 from __future__ import division, print_function, absolute_import
10 9
 
11 10
 import random
@@ -17,13 +16,13 @@
17 16
 from numpy.testing import (TestCase, run_module_suite,
18 17
         assert_array_equal, assert_array_less, assert_equal,
19 18
         assert_array_almost_equal, assert_array_almost_equal_nulp,
20  
-        assert_allclose, assert_, assert_raises, decorators,
21  
-        assert_raises)
  19
+        assert_allclose, assert_, decorators)
  20
+
  21
+from scipy.lib._numpy_compat import _assert_warns
22 22
 
23 23
 import scipy.linalg
24  
-from scipy.linalg import norm
25 24
 from scipy.linalg import (funm, signm, logm, sqrtm, fractional_matrix_power,
26  
-        expm, expm_frechet, expm_cond)
  25
+        expm, expm_frechet, expm_cond, norm)
27 26
 from scipy.linalg.matfuncs import expm2, expm3
28 27
 from scipy.linalg import _matfuncs_inv_ssq
29 28
 import scipy.linalg._expm_frechet
@@ -212,24 +211,16 @@ def test_logm_exactly_singular(self):
212 211
         B = np.asarray([[1, 1], [0, 0]])
213 212
         for M in A, A.T, B, B.T:
214 213
             expected_warning = _matfuncs_inv_ssq.LogmExactlySingularWarning
215  
-            with warnings.catch_warnings(record=True) as w:
216  
-                warnings.simplefilter('always')
217  
-                L, info = logm(M, disp=False)
218  
-                assert_equal(len(w), 1)
219  
-                assert_(issubclass(w[-1].category, expected_warning))
220  
-                E = expm(L)
221  
-                assert_allclose(E, M, atol=1e-14)
  214
+            L, info = _assert_warns(expected_warning, logm, M, disp=False)
  215
+            E = expm(L)
  216
+            assert_allclose(E, M, atol=1e-14)
222 217
 
223 218
     def test_logm_nearly_singular(self):
224 219
         M = np.array([[1e-100]])
225 220
         expected_warning = _matfuncs_inv_ssq.LogmNearlySingularWarning
226  
-        with warnings.catch_warnings(record=True) as w:
227  
-            warnings.simplefilter('always')
228  
-            L, info = logm(M, disp=False)
229  
-            assert_equal(len(w), 1)
230  
-            assert_(issubclass(w[-1].category, expected_warning))
231  
-            E = expm(L)
232  
-            assert_allclose(E, M, atol=1e-14)
  221
+        L, info = _assert_warns(expected_warning, logm, M, disp=False)
  222
+        E = expm(L)
  223
+        assert_allclose(E, M, atol=1e-14)
233 224
 
234 225
 
235 226
 class TestSqrtM(TestCase):

0 notes on commit 13cbe06

Please sign in to comment.
Something went wrong with that request. Please try again.