Skip to content

Commit

Permalink
TST: fix test that fails weirdly when executing the whole test file a…
Browse files Browse the repository at this point in the history
…nd not just the test
  • Loading branch information
William de Vazelhes committed Mar 20, 2019
1 parent bfb0f8f commit 6f5666b
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions test/metric_learn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,26 +285,6 @@ def test_deprecation_num_labeled(self):
'removed in 0.6.0')
assert_warns_message(DeprecationWarning, msg, sdml_supervised.fit, X, y)

def test_sdml_raises_warning_non_psd(self):
"""Tests that SDML raises a warning on a toy example where we know the
pseudo-covariance matrix is not PSD"""
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
y = [1, -1]
sdml = SDML(use_cov=True, sparsity_param=0.01, balance_param=0.5)
msg = ("Warning, the input matrix of graphical lasso is not "
"positive semi-definite (PSD). The algorithm may diverge, "
"and lead to degenerate solutions. "
"To prevent that, try to decrease the balance parameter "
"`balance_param` and/or to set use_covariance=False.")
with pytest.warns(ConvergenceWarning) as raised_warning:
try:
sdml.fit(pairs, y)
except Exception:
pass
# we assert that this warning is in one of the warning raised by the
# estimator
assert msg in list(map(lambda w: str(w.message), raised_warning))

def test_sdml_converges_if_psd(self):
"""Tests that sdml converges on a simple problem where we know the
pseudo-covariance matrix is PSD"""
Expand Down Expand Up @@ -385,6 +365,27 @@ def test_verbose_has_not_installed_skggm_sdml_supervised(capsys):
assert "SDML will use scikit-learn's graphical lasso solver." in out


def test_sdml_raises_warning_non_psd():
"""Tests that SDML raises a warning on a toy example where we know the
pseudo-covariance matrix is not PSD"""
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
y = [1, -1]
sdml = SDML(use_cov=True, sparsity_param=0.01, balance_param=0.5)
msg = ("Warning, the input matrix of graphical lasso is not "
"positive semi-definite (PSD). The algorithm may diverge, "
"and lead to degenerate solutions. "
"To prevent that, try to decrease the balance parameter "
"`balance_param` and/or to set use_covariance=False.")
with pytest.warns(ConvergenceWarning) as raised_warning:
try:
sdml.fit(pairs, y)
except Exception:
pass
# we assert that this warning is in one of the warning raised by the
# estimator
assert msg in list(map(lambda w: str(w.message), raised_warning))


class TestNCA(MetricTestCase):
def test_iris(self):
n = self.iris_points.shape[0]
Expand Down

0 comments on commit 6f5666b

Please sign in to comment.