Skip to content

Commit

Permalink
ENH: Deal better with errors and skggm/scikit-learn
Browse files Browse the repository at this point in the history
  • Loading branch information
William de Vazelhes committed Mar 18, 2019
1 parent 60866cb commit eb95719
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 36 deletions.
53 changes: 37 additions & 16 deletions metric_learn/sdml.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,11 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,

def _fit(self, pairs, y):
if not HAS_SKGGM:
msg = ("Warning, skggm is not installed, so SDML will use "
"scikit-learn's graphical_lasso method. It can fail to converge"
"on some non SPD matrices where skggm would converge. If so, "
"try to install skggm. (see the README.md for the right "
"version.)")
warnings.warn(msg)
if self.verbose:
print("SDML will use scikit-learn's graphical lasso solver.")
else:
print("SDML will use skggm's solver.")
if self.verbose:
print("SDML will use skggm's graphical lasso solver.")
pairs, y = self._prepare_inputs(pairs, y,
type_of_inputs='tuples')

Expand All @@ -93,15 +90,39 @@ def _fit(self, pairs, y):
"`balance_param` and/or to set use_covariance=False.",
ConvergenceWarning)
sigma0 = (V * (w - min(0, np.min(w)) + 1e-10)).dot(V.T)
if HAS_SKGGM:
theta0 = pinvh(sigma0)
M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
msg=self.verbose,
Theta0=theta0, Sigma0=sigma0)
else:
_, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
verbose=self.verbose,
cov_init=sigma0)
try:
if HAS_SKGGM:
theta0 = pinvh(sigma0)
M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
msg=self.verbose,
Theta0=theta0, Sigma0=sigma0)
else:
_, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
verbose=self.verbose,
cov_init=sigma0)
raised_error = None
w_mahalanobis, _ = np.linalg.eigh(M)
not_spd = any(w_mahalanobis < 0.)
not_finite = not np.isfinite(M).all()
except Exception as e:
raised_error = e
not_spd = False # not_spd not applicable here so we set to False
not_finite = False # not_finite not applicable here so we set to False
if raised_error is not None or not_spd or not_finite:
msg = ("There was a problem in SDML when using {}'s graphical "
"lasso solver.").format("skggm" if HAS_SKGGM else "scikit-learn")
if not HAS_SKGGM:
skggm_advice = (" skggm's graphical lasso can sometimes converge "
"on non SPD cases where scikit-learn's graphical "
"lasso fails to converge. Try to install skggm and "
"rerun the algorithm (see the README.md for the "
"right version of skggm).")
msg += skggm_advice
if raised_error is not None:
msg += " The following error message was thrown: {}.".format(
raised_error)
raise RuntimeError(msg)

self.transformer_ = transformer_from_metric(np.atleast_2d(M))
return self

Expand Down
130 changes: 110 additions & 20 deletions test/metric_learn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,28 +155,89 @@ def test_no_twice_same_objective(capsys):
class TestSDML(MetricTestCase):

@pytest.mark.skipif(HAS_SKGGM,
reason="The warning will be thrown only if skggm is "
reason="The warning can be thrown only if skggm is "
"not installed.")
def test_raises_warning_msg_not_installed_skggm(self):
def test_sdml_supervised_raises_warning_msg_not_installed_skggm(self):
"""Tests that the right warning message is raised if someone tries to
use SDML but has not installed skggm"""
use SDML_Supervised but has not installed skggm, and that the algorithm
fails to converge"""
# TODO: remove if we don't need skggm anymore
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
# load_iris: dataset where we know scikit-learn's graphical lasso fails
# with a Floating Point error
X, y = load_iris(return_X_y=True)
sdml_supervised = SDML_Supervised(balance_param=0.5, use_cov=True,
sparsity_param=0.01)
msg = ("There was a problem in SDML when using scikit-learn's graphical "
"lasso solver. skggm's graphical lasso can sometimes converge on "
"non SPD cases where scikit-learn's graphical lasso fails to "
"converge. Try to install skggm and rerun the algorithm (see "
"the README.md for the right version of skggm). The following "
"error message was thrown:")
with pytest.raises(RuntimeError) as raised_error:
sdml_supervised.fit(X, y)
assert str(raised_error.value).startswith(msg)

@pytest.mark.skipif(HAS_SKGGM,
reason="The warning can be thrown only if skggm is "
"not installed.")
def test_sdml_raises_warning_msg_not_installed_skggm(self):
"""Tests that the right warning message is raised if someone tries to
use SDML but has not installed skggm, and that the algorithm fails to
converge"""
# TODO: remove if we don't need skggm anymore
# case on which we know that scikit-learn's graphical lasso fails
# because it will return a non SPD matrix
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
y_pairs = [1, -1]
X, y = make_classification(random_state=42)
sdml = SDML()
sdml_supervised = SDML_Supervised(use_cov=False, balance_param=1e-5)
msg = ("Warning, skggm is not installed, so SDML will use "
"scikit-learn's graphical_lasso method. It can fail to converge"
"on some non SPD matrices where skggm would converge. If so, "
"try to install skggm. (see the README.md for the right "
"version.)")
with pytest.warns(None) as record:
sdml = SDML(use_cov=False, balance_param=100, verbose=True)

msg = ("There was a problem in SDML when using scikit-learn's graphical "
"lasso solver. skggm's graphical lasso can sometimes converge on "
"non SPD cases where scikit-learn's graphical lasso fails to "
"converge. Try to install skggm and rerun the algorithm (see "
"the README.md for the right version of skggm).")
with pytest.raises(RuntimeError) as raised_error:
sdml.fit(pairs, y_pairs)
assert str(record[0].message) == msg
with pytest.warns(None) as record:
assert msg == str(raised_error.value)

@pytest.mark.skipif(not HAS_SKGGM,
reason="The warning can be thrown only if skggm is "
"installed.")
def test_sdml_raises_warning_msg_installed_skggm(self):
"""Tests that the right warning message is raised if someone tries to
use SDML and has installed skggm, and that the algorithm fails to
converge"""
# TODO: remove if we don't need skggm anymore
# case on which we know that skggm's graphical lasso fails
# because it will return non finite values
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
y_pairs = [1, -1]
sdml = SDML(use_cov=False, balance_param=100, verbose=True)

msg = ("There was a problem in SDML when using skggm's graphical "
"lasso solver.")
with pytest.raises(RuntimeError) as raised_error:
sdml.fit(pairs, y_pairs)
assert msg == str(raised_error.value)

@pytest.mark.skipif(not HAS_SKGGM,
reason="The warning can be thrown only if skggm is "
"installed.")
def test_sdml_supervised_raises_warning_msg_installed_skggm(self):
"""Tests that the right warning message is raised if someone tries to
use SDML_Supervised and has installed skggm, and that the algorithm
fails to converge"""
# TODO: remove if we don't need skggm anymore
# case on which we know that skggm's graphical lasso fails
# because it will return non finite values
X, y = load_iris(return_X_y=True)
sdml_supervised = SDML_Supervised(balance_param=0.5, use_cov=True,
sparsity_param=0.01)
msg = ("There was a problem in SDML when using skggm's graphical "
"lasso solver.")
with pytest.raises(RuntimeError) as raised_error:
sdml_supervised.fit(X, y)
assert str(record[0].message) == msg
assert msg == str(raised_error.value)

@pytest.mark.skipif(not HAS_SKGGM,
reason="It's only in the case where skggm is installed"
Expand Down Expand Up @@ -271,10 +332,10 @@ def test_verbose_has_installed_skggm_sdml(capsys):
# TODO: remove if we don't need skggm anymore
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
y_pairs = [1, -1]
sdml = SDML()
sdml = SDML(verbose=True)
sdml.fit(pairs, y_pairs)
out, _ = capsys.readouterr()
assert "SDML will use skggm's solver." in out
assert "SDML will use skggm's graphical lasso solver." in out


@pytest.mark.skipif(not HAS_SKGGM,
Expand All @@ -285,10 +346,39 @@ def test_verbose_has_installed_skggm_sdml_supervised(capsys):
# skggm's solver is used (when they use SDML_Supervised)
# TODO: remove if we don't need skggm anymore
X, y = make_classification(random_state=42)
sdml = SDML_Supervised()
sdml = SDML_Supervised(verbose=True)
sdml.fit(X, y)
out, _ = capsys.readouterr()
assert "SDML will use skggm's graphical lasso solver." in out


@pytest.mark.skipif(HAS_SKGGM,
reason='The message should be printed only if skggm is '
'not installed.')
def test_verbose_has_not_installed_skggm_sdml(capsys):
# Test that if users have not installed skggm, a message is printed telling
# them scikit-learn's solver is used (when they use SDML)
# TODO: remove if we don't need skggm anymore
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
y_pairs = [1, -1]
sdml = SDML(verbose=True)
sdml.fit(pairs, y_pairs)
out, _ = capsys.readouterr()
assert "SDML will use scikit-learn's graphical lasso solver." in out


@pytest.mark.skipif(HAS_SKGGM,
reason='The message should be printed only if skggm is '
'not installed.')
def test_verbose_has_not_installed_skggm_sdml_supervised(capsys):
# Test that if users have not installed skggm, a message is printed telling
# them scikit-learn's solver is used (when they use SDML_Supervised)
# TODO: remove if we don't need skggm anymore
X, y = make_classification(random_state=42)
sdml = SDML_Supervised(verbose=True, balance_param=1e-5, use_cov=False)
sdml.fit(X, y)
out, _ = capsys.readouterr()
assert "SDML will use skggm's solver." in out
assert "SDML will use scikit-learn's graphical lasso solver." in out


class TestNCA(MetricTestCase):
Expand Down

0 comments on commit eb95719

Please sign in to comment.