[MRG] FIX: make proposal for sdml formulation #162

Merged: 78 commits into master from fix/proposal_for_sdml, Mar 22, 2019

Changes from 60 commits

Commits (78)
02cc937
FIX: make proposal for sdml formulation
Jan 24, 2019
aebd47f
MAINT clearer formulation to make the prior appear
Jan 29, 2019
40b2c88
MAINT call the prior prior
Jan 29, 2019
fb04cc9
Use skggm instead of graphical lasso
Feb 1, 2019
518d6e8
Be more severe for the class separation
Feb 1, 2019
c912e93
Merge branch 'master' into fix/proposal_for_sdml
Feb 1, 2019
8f0b113
Put back verbose param
Feb 1, 2019
c57a35a
MAINT: make more explicit the fact that to use identity (i.e. an SPD …
Feb 1, 2019
f0eb938
Add skggm as a requirement for SDML
Feb 1, 2019
821db0b
Add skggm to required packages for travis
Feb 1, 2019
bd2862d
Also add cython as a dependency
Feb 1, 2019
c6a2daa
FIX: install all except skggm and then skggm
Feb 1, 2019
93d790e
Remove cython dependency
Feb 1, 2019
cae6c28
Install skggm only if we have at least python 3.6
Feb 15, 2019
5d673ba
Should work if we want other versions superior to 3.6
Feb 15, 2019
e8a28d5
Fix bash >= which should be written -ge
Feb 15, 2019
e740702
Deal with tests when skggm is not installed and fix some PEP8 warnings
Feb 15, 2019
333675b
replace manual calls of algorithms with tuples_learners
Feb 15, 2019
1a6e97b
Remove another call of SDML if skggm is not installed
Feb 15, 2019
7cecf27
FIX fix the test_error_message_tuple_size
Feb 15, 2019
5303e1a
FIX fix test_sdml_supervised
Feb 15, 2019
377760a
FIX: fix another sdml test
Feb 15, 2019
0a46ad5
FIX quic call for python 2.7
Feb 15, 2019
391d773
Fix quic import
Feb 15, 2019
6654769
Add Sigma0 initialization (both sigma zero and theta zero should be sp…
Feb 15, 2019
ac4e18a
Deal with SDML making some tests fail
Feb 15, 2019
458d646
Remove epsilon that was unnecessary
Feb 15, 2019
fd7c9fb
FIX: use latest commit of skggm that fixes the non deterministic problem
Feb 19, 2019
e118cd8
MAINT: add message for SDML when not SPD
Mar 5, 2019
b0c4753
MAINT: add test for error message if skggm not installed
Mar 5, 2019
13146d8
Try other syntax for installing the right commit of skggm
Mar 5, 2019
db4a799
MAINT: make sklearn compat sdml test be run only if skggm is installed
Mar 5, 2019
1011391
Try another syntax for running travis
Mar 5, 2019
5ea7ba0
Better bash syntax
Mar 5, 2019
45d3b7b
Fix tests by removing duplicates
Mar 6, 2019
dbf5257
FIX: fix for sdml by reducing balance parameter
Mar 6, 2019
4b0bae9
FIX: update code to work with old version of numpy that does not have…
Mar 6, 2019
f3c690e
Remove the need for skggm
Mar 7, 2019
57b0567
Update travis not to use skggm
Mar 7, 2019
04316b2
Add a stable init for sklearn checks
Mar 7, 2019
b641641
FIX test_sdml_supervised
Mar 7, 2019
fedfb8e
Revert "Update travis not to use skggm"
Mar 8, 2019
f0bbf6d
Add fallback on skggm
Mar 8, 2019
520d7c2
FIX: fix versions comparison and tests
Mar 8, 2019
0437c62
MAINT: improve test of no warning
Mar 8, 2019
be1a5e6
FIX: fix wrap pairs that was returning column y (we need line y), and…
Mar 8, 2019
56efa09
FIX: force travis to do the right check
Mar 8, 2019
142eea9
TST: add non SPD test that works with skggm's quic but not sklearn's …
Mar 8, 2019
fcfd44c
Try again travis this time installing cython
Mar 8, 2019
019e28b
Try to make travis work with build_essential
Mar 8, 2019
04a5107
Try with installing liblapack
Mar 8, 2019
be3a2ad
TST: fix tests for when skggm is not installed
Mar 8, 2019
1ee8d1f
TST: use better pytest skipif syntax
Mar 8, 2019
03f4158
FIX: fix broken link in README.md
Mar 8, 2019
e621e27
use rst syntax for link
Mar 8, 2019
0086c98
use rst syntax for link
Mar 8, 2019
001600e
use rst syntax for link
Mar 8, 2019
8c50a0d
MAINT: remove test_sdml that was remaining from drafts tests
Mar 8, 2019
e4132d6
TST: remove skipping SDML in test_cross_validation_manual_vs_scikit
Mar 8, 2019
b3bf6a8
FIX link also in getting started
Mar 8, 2019
49f3b9e
Put back right indent
Mar 8, 2019
e1664c7
Remove unnecessary changes
Mar 8, 2019
187e22c
merging
Mar 18, 2019
60866cb
Nitpick for concatenation and refactor HAS_SKGGM
Mar 18, 2019
eb95719
ENH: Deal better with errors and skggm/scikit-learn
Mar 18, 2019
4d61dba
Better creation of prior
Mar 18, 2019
71a02e0
Simplification for init of sdml
Mar 18, 2019
1e6d440
Put skggm as optional
Mar 18, 2019
a7ed1bb
Specify skggm version
Mar 18, 2019
31072d3
TST: make test about 1 feature arrays more readable
Mar 18, 2019
000f29a
DOC: fix rst formatting
Mar 18, 2019
169dccf
DOC: reformulated skggm optional dependency
Mar 18, 2019
bfb0f8f
TST: give an example for sdml_supervised with skggm where it indeed f…
Mar 20, 2019
6f5666b
TST: fix test that fails weirdly when executing the whole test file a…
Mar 20, 2019
0973ef2
Revert "TST: fix test that fails weirdly when executing the whole tes…
Mar 20, 2019
1c28ecd
Merge branch 'master' into fix/proposal_for_sdml
wdevazelhes Mar 21, 2019
df2ae9c
Add coverage for all versions of python
Mar 21, 2019
9683934
Install pytest-cov for all versions
Mar 21, 2019
8 changes: 7 additions & 1 deletion .travis.yml
@@ -4,8 +4,14 @@ cache: pip
python:
- "2.7"
- "3.4"
- "3.6"
before_install:
- sudo apt-get install liblapack-dev
- pip install --upgrade pip
- pip install wheel
- pip install numpy scipy scikit-learn
- pip install cython numpy scipy scikit-learn
- if [[ ($TRAVIS_PYTHON_VERSION == "3.6") ||
($TRAVIS_PYTHON_VERSION == "2.7")]]; then
pip install git+https://github.com/skggm/skggm.git@a0ed406586c4364ea3297a658f415e13b5cbdaf8;
fi
script: pytest test
2 changes: 1 addition & 1 deletion README.rst
@@ -20,7 +20,7 @@ Metric Learning algorithms in Python.
**Dependencies**

- Python 2.7+, 3.4+
- numpy, scipy, scikit-learn
- numpy, scipy, scikit-learn, and skggm (commit `a0ed406 <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_) for `SDML`
Contributor:

I'd call out skggm separately as an optional dependency.

Member Author:

Thanks, I agree; I forgot to change that after making skggm optional.

- (for running the examples only: matplotlib)

**Installation/Setup**
2 changes: 1 addition & 1 deletion doc/getting_started.rst
@@ -15,7 +15,7 @@ Alternately, download the source repository and run:
**Dependencies**

- Python 2.7+, 3.4+
- numpy, scipy, scikit-learn
- numpy, scipy, scikit-learn, and skggm (commit `a0ed406 <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_) for `SDML`
- (for running the examples only: matplotlib)

**Notes**
8 changes: 8 additions & 0 deletions metric_learn/_util.py
@@ -15,6 +15,14 @@ def vector_norm(X):
return np.linalg.norm(X, axis=1)


def has_installed_skggm():
Contributor:

I don't think this needs to be a function, as the answer isn't going to change during execution.

Member Author:

That's right, thanks.

try:
import inverse_covariance
return True
except ImportError:
return False


def check_input(input_data, y=None, preprocessor=None,
type_of_inputs='classic', tuple_size=None, accept_sparse=False,
dtype='numeric', order=None,
2 changes: 1 addition & 1 deletion metric_learn/constraints.py
@@ -96,6 +96,6 @@ def wrap_pairs(X, constraints):
c = np.array(constraints[2])
d = np.array(constraints[3])
constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d))))
y = np.vstack([np.ones((len(a), 1)), - np.ones((len(c), 1))])
y = np.hstack([np.ones((len(a),)), - np.ones((len(c),))])
Member Author:

In fact we should never return a column vector but a row vector (these are the ones scikit-learn likes to work with).
Otherwise scikit-learn's checks, called by our own checks, will raise a warning.

Contributor:

Maybe it would be simpler to do:

y = np.concatenate([np.ones_like(a), -np.ones_like(c)])

Member Author:

Indeed, thanks

pairs = X[constraints]
return pairs, y
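
For illustration, a minimal standalone sketch (toy index arrays made up for the example) of the shape difference this fix addresses, plus the reviewer's np.concatenate variant, which is equivalent up to dtype:

import numpy as np

a = np.array([0, 1])  # hypothetical indices of similar pairs
c = np.array([2, 3])  # hypothetical indices of dissimilar pairs

# Old behavior: a (4, 1) column vector, which makes scikit-learn's checks warn.
y_col = np.vstack([np.ones((len(a), 1)), -np.ones((len(c), 1))])

# New behavior: a flat (4,) array, the label shape scikit-learn expects.
y_row = np.hstack([np.ones((len(a),)), -np.ones((len(c),))])

# Reviewer's simplification (integer dtype here, since a and c are int arrays).
y_simple = np.concatenate([np.ones_like(a), -np.ones_like(c)])

print(y_col.shape, y_row.shape, y_simple.shape)  # (4, 1) (4,) (4,)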
53 changes: 41 additions & 12 deletions metric_learn/sdml.py
@@ -12,12 +12,15 @@
import warnings
import numpy as np
from sklearn.base import TransformerMixin
from sklearn.covariance import graph_lasso
from sklearn.utils.extmath import pinvh
from scipy.linalg import pinvh
from sklearn.covariance import graphical_lasso
from sklearn.exceptions import ConvergenceWarning

from .base_metric import MahalanobisMixin, _PairsClassifierMixin
from .constraints import Constraints, wrap_pairs
from ._util import transformer_from_metric
from ._util import transformer_from_metric, has_installed_skggm
if has_installed_skggm():
Contributor:

I'd prefer a simpler conditional import here:

try:
  from inverse_covariance import quic
except ImportError:
  HAS_SKGGM = False
else:
  HAS_SKGGM = True

Contributor:

We'd need to duplicate the logic in the tests, but I'm fine with that.

Member Author:

Agreed

from inverse_covariance import quic


class _BaseSDML(MahalanobisMixin):
@@ -52,24 +55,50 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
super(_BaseSDML, self).__init__(preprocessor)

def _fit(self, pairs, y):
if not has_installed_skggm():
msg = ("Warning, skggm is not installed, so SDML will use "
"scikit-learn's graphical_lasso method. It can fail to converge"
"on some non SPD matrices where skggm would converge. If so, "
"try to install skggm. (see the README.md for the right "
"version.)")
Contributor:

Perhaps we can catch the case where scikit-learn's version fails and emit the warning then?

Member Author:

I agree; here is the new version I'll commit. When using scikit-learn's graphical lasso, we try it, and if an error is returned or the result is not finite, we raise a warning that will be printed before the error (if there is an error) or before returning M (if there is no error but there are NaNs). Tell me what you think:

    if HAS_SKGGM:
      theta0 = pinvh(sigma0)
      M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
                              msg=self.verbose,
                              Theta0=theta0, Sigma0=sigma0)
    else:
      try:
        _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
                               verbose=self.verbose,
                               cov_init=sigma0)
        error = None
      except FloatingPointError as e:
        error = e
      if not np.isfinite(M).all() or error is not None:
        msg = ("Scikit-learn's graphical lasso has failed to converge. "
               "Package skggm's graphical lasso can sometimes converge on "
               "non SPD cases where scikit-learn's graphical lasso fails to "
               "converge. Try to install skggm and rerun the algorithm. (see "
               "the README.md for the right version.)")
        warnings.warn(msg)
        if error is not None:
          raise(error)  

EDIT:

    if HAS_SKGGM:
      theta0 = pinvh(sigma0)
      M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
                              msg=self.verbose,
                              Theta0=theta0, Sigma0=sigma0)
    else:
      try:
        _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
                               verbose=self.verbose,
                               cov_init=sigma0)
      except FloatingPointError as e:
        msg = ("Scikit-learn's graphical lasso has failed to converge. "
               "Package skggm's graphical lasso can sometimes converge on "
               "non SPD cases where scikit-learn's graphical lasso fails to "
               "converge. Try to install skggm and rerun the algorithm. (see "
               "the README.md for the right version of skggm.)")
        warnings.warn(msg)
        raise(e)

(In fact it's skggm's graphical lasso that produces NaNs; scikit-learn's graphical lasso raises a FloatingPointError in case of error (I didn't find cases where it would give NaNs), so it's better to stick to the case we know.)

Member Author:

Actually, while trying to come up with the right tests, I realized the following:
In pathological diagonal cases (see above), scikit-learn's graphical lasso can return no error yet produce a result that is not SPD (in the example below, a diagonal matrix with a negative coefficient on the diagonal). A NaN then appears when computing transformer_from_metric. So maybe, when using scikit-learn, I should additionally check that the result is indeed SPD, and if not raise the error "SDML has failed to converge to an SPD matrix using scikit-learn's graphical lasso".
I don't know if skggm guarantees to return an SPD matrix, so maybe I should do the check there too? (And later, if we dig deeper into this, we could remove the check if it turns out to be unnecessary.)

Example (go into debug mode, or put a print statement in SDML, to see the result of the graphical lasso):

    from metric_learn import SDML
    import numpy as np

    pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
    y_pairs = [1, -1]

    sdml = SDML(use_cov=False, balance_param=100,verbose=True)

    diff = pairs[:, 0] - pairs[:, 1]
    emp_cov = np.identity(pairs.shape[2]) + 100 * (diff.T * y_pairs).dot(diff)

    print(emp_cov)

    sdml.fit(pairs, y_pairs)
    print(sdml.get_mahalanobis_matrix())

Returns:

[[   40001.        0.]
 [       0. -1209999.]]
SDML will use scikit-learn's graphical lasso solver.
/home/will/Code/metric-learn/metric_learn/sdml.py:92: ConvergenceWarning: Warning, the input matrix of graphical lasso is not positive semi-definite (PSD). The algorithm may diverge, and lead to degenerate solutions. To prevent that, try to decrease the balance parameter `balance_param` and/or to set use_covariance=False.
  ConvergenceWarning)
[graphical_lasso] Iteration   0, cost  inf, dual gap 0.000e+00
/home/will/Code/metric-learn/metric_learn/_util.py:346: RuntimeWarning: invalid value encountered in sqrt
  return np.sqrt(metric)
[[2.4999375e-05           nan]
 [          nan           nan]]

And if we print the result of graphical lasso (note that it's the inverse of the initial matrix):

[[ 2.49993750e-05  0.00000000e+00]
 [ 0.00000000e+00 -8.26446964e-07]]

Member:

How about this:

1. First check if the matrix given as input is PSD. If not, show the warning that the algorithm may diverge, etc.
2. Regardless of the solver: if an error is thrown by the solver, or the solution is not a PSD matrix, show an error and abort (as there is either no result, or the result does not correspond to a valid distance).
3. If we fall into case 2 and we use the sklearn solver, mention that the user could try to install skggm, as its graphical lasso solver is more stable.

Member Author:

I agree, I'll go for something like this:

    try:
      if HAS_SKGGM:
        theta0 = pinvh(sigma0)
        M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
                                msg=self.verbose,
                                Theta0=theta0, Sigma0=sigma0)
      else:
        _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
                               verbose=self.verbose,
                               cov_init=sigma0)
      raised_error = None
      w_mahalanobis, _ = np.linalg.eigh(M)
      not_spd = any(w_mahalanobis < 0.)
    except Exception as e:
      raised_error = e
      not_spd = False  # not_spd not applicable so we set to False
    if raised_error is not None or not_spd:
      msg = ("There was a problem in SDML when using {}'s graphical "
             "lasso.").format("skggm" if HAS_SKGGM else "scikit-learn")
      if not HAS_SKGGM:
        skggm_advice = ("skggm's graphical lasso can sometimes converge "
                        "on non SPD cases where scikit-learn's graphical "
                        "lasso fails to converge. Try to install skggm and "
                        "rerun the algorithm. (See the README.md for the " 
                        "right version of skggm.)")
        msg += skggm_advice
      if raised_error is not None:
        msg += "The following error message was thrown: {}.".format(
          raised_error)
      raise RuntimeError(msg)

Member Author:

I'll add infinite values that can be returned by skggm as a failure case too

Member Author:

Done, in commit eb95719
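
As a standalone sketch of the check being discussed (the helper name is made up for illustration): the learned matrix must be finite and PSD to define a valid metric.

import numpy as np

def is_usable_metric(M):
    # Hypothetical helper mirroring the check above: reject non-finite
    # results (skggm can return NaN/inf) and non-PSD results.
    if not np.isfinite(M).all():
        return False
    w, _ = np.linalg.eigh(M)  # eigenvalues of the symmetric matrix M
    return not np.any(w < 0.)

print(is_usable_metric(np.eye(2)))                        # True
print(is_usable_metric(np.array([[1., 0.], [0., -1.]])))  # False: not PSD
print(is_usable_metric(np.full((2, 2), np.nan)))          # False: not finite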

warnings.warn(msg)
else:
print("SDML will use skggm's solver.")
Contributor:

We should only print this if self.verbose is true.

Member:

And maybe clarify: skggm's graphical lasso solver

Member:

And maybe print something similar when the sklearn solver is used.

Member Author:

Agreed

pairs, y = self._prepare_inputs(pairs, y,
type_of_inputs='tuples')

# set up prior M
if self.use_cov:
X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
M = pinvh(np.atleast_2d(np.cov(X, rowvar = False)))
prior = pinvh(np.atleast_2d(np.cov(X, rowvar=False)))
else:
M = np.identity(pairs.shape[2])
prior = np.identity(pairs.shape[2])
diff = pairs[:, 0] - pairs[:, 1]
loss_matrix = (diff.T * y).dot(diff)
P = M + self.balance_param * loss_matrix
emp_cov = pinvh(P)
# hack: ensure positive semidefinite
emp_cov = emp_cov.T.dot(emp_cov)
_, M = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)

self.transformer_ = transformer_from_metric(M)
emp_cov = pinvh(prior) + self.balance_param * loss_matrix
Contributor:

Seems odd to round-trip the covariance matrix through two pinvh calls in the self.use_cov case.

Member:

I agree

Member Author:

Agreed, done
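
To make the simplification concrete, a hedged sketch (the standalone function and its name are assumptions, not the PR's code): since pinvh(pinvh(C)) recovers a well-conditioned symmetric C, the use_cov branch can use the covariance directly:

import numpy as np

def sdml_emp_cov(pairs, y, balance_param, use_cov=True):
    # Hypothetical helper: build the graphical-lasso input without the
    # pinvh(prior) round-trip, using the covariance itself in place of
    # pinvh(pinvh(cov)).
    if use_cov:
        X = np.vstack(list({tuple(row)
                            for row in pairs.reshape(-1, pairs.shape[2])}))
        prior_pinv = np.atleast_2d(np.cov(X, rowvar=False))
    else:
        prior_pinv = np.identity(pairs.shape[2])
    diff = pairs[:, 0] - pairs[:, 1]
    loss_matrix = (diff.T * np.asarray(y)).dot(diff)
    return prior_pinv + balance_param * loss_matrix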


# our initialization will be the matrix with emp_cov's eigenvalues,
Member Author:

This is an init that @bellet and I discussed, and that I found worked better (it allowed tests to pass where, with the identity init, I got a lot of LinAlgError failures).

# with a constant added so that they are all positive (plus an epsilon
# to ensure definiteness). This is empirical.
w, V = np.linalg.eigh(emp_cov)
if any(w < 0.):
warnings.warn("Warning, the input matrix of graphical lasso is not "
"positive semi-definite (PSD). The algorithm may diverge, "
"and lead to degenerate solutions. "
"To prevent that, try to decrease the balance parameter "
"`balance_param` and/or to set use_covariance=False.",
ConvergenceWarning)
Member Author:

Is ConvergenceWarning OK? It's the one that seemed the most appropriate, but here we raise it before even running the graphical lasso, so maybe it's a bit weird... Perhaps some kind of PossibleConvergenceWarning would be better?

Contributor:

I'm fine with it.

sigma0 = (V * (w - min(0, np.min(w)) + 1e-10)).dot(V.T)
Contributor:

Maybe simpler:

min_eigval = w.min()
if min_eigval < 0:
  warnings.warn(...)
  min_eigval = 0

w += 1e-10 - min_eigval
sigma0 = (V * w).dot(V.T)

Member Author:

In the heuristic, if min_eigval > 0 we want to keep the matrix untouched (with just an epsilon added), so I don't think your solution does that, since it would subtract the min eigenvalue even then. But I agree something simpler would be better, so something like:

min_eigval = w.min()
if min_eigval < 0:
  warnings.warn(...)
  w -= min_eigval   # we translate the eigenvalues to make them all positive

w += 1e-10  # we add a small offset to avoid definiteness problems
sigma0 = (V * w).dot(V.T)

Member Author:

Done.
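
For completeness, a runnable toy version of this shift (the input matrix is made up; a symmetric input is assumed):

import numpy as np

emp_cov = np.array([[2., 0.], [0., -1.]])  # toy symmetric, non-PSD input
w, V = np.linalg.eigh(emp_cov)
if w.min() < 0.:
    w -= w.min()  # translate the eigenvalues so they are all >= 0
w += 1e-10        # small offset to avoid definiteness problems
sigma0 = (V * w).dot(V.T)  # V @ diag(w) @ V.T
print(np.linalg.eigvalsh(sigma0))  # all strictly positive now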

if has_installed_skggm():
theta0 = pinvh(sigma0)
M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
msg=self.verbose,
Theta0=theta0, Sigma0=sigma0)
else:
_, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
verbose=self.verbose,
cov_init=sigma0)
self.transformer_ = transformer_from_metric(np.atleast_2d(M))
return self
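
End to end, a minimal usage sketch that mirrors the PSD toy data used in the tests below (so either solver should converge):

import numpy as np
from metric_learn import SDML

# Toy pairs from the PR's tests: the pseudo-covariance matrix is PSD here.
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60.]]])
y_pairs = [1, -1]

sdml = SDML(use_cov=True, sparsity_param=0.01, balance_param=0.5)
sdml.fit(pairs, y_pairs)
print(sdml.get_mahalanobis_matrix())  # expected to be finite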


1 change: 1 addition & 0 deletions setup.py
@@ -38,6 +38,7 @@
extras_require=dict(
docs=['sphinx', 'sphinx_rtd_theme', 'numpydoc'],
demo=['matplotlib'],
sdml=['skggm']
Contributor:

Can we specify a commit hash here? Or maybe since their latest release is 0.2.8, we could specify skggm>=0.2.9.

Member Author:

I did try, indeed, but didn't manage to make it work. Good idea though: skggm>=0.2.9 is better than nothing here.

),
test_suite='test',
keywords=[
162 changes: 139 additions & 23 deletions test/metric_learn_test.py
@@ -11,10 +11,11 @@
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils.validation import check_X_y

from metric_learn import (
LMNN, NCA, LFDA, Covariance, MLKR, MMC,
LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised)
from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC,
LSML_Supervised, ITML_Supervised, SDML_Supervised,
RCA_Supervised, MMC_Supervised, SDML)
# Import this specially for testing.
from metric_learn._util import has_installed_skggm
from metric_learn.constraints import wrap_pairs
from metric_learn.lmnn import python_LMNN

@@ -148,27 +149,142 @@ def test_no_twice_same_objective(capsys):


class TestSDML(MetricTestCase):
def test_iris(self):
# Note: this is a flaky test, which fails for certain seeds.
# TODO: un-flake it!
rs = np.random.RandomState(5555)

sdml = SDML_Supervised(num_constraints=1500)
sdml.fit(self.iris_points, self.iris_labels, random_state=rs)
csep = class_separation(sdml.transform(self.iris_points), self.iris_labels)
self.assertLess(csep, 0.25)

def test_deprecation_num_labeled(self):
# test that a deprecation message is thrown if num_labeled is set at
# initialization
# TODO: remove in v.0.6
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
y = np.array([1, 0, 1, 0])
sdml_supervised = SDML_Supervised(num_labeled=np.inf)
msg = ('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
'removed in 0.6.0')
assert_warns_message(DeprecationWarning, msg, sdml_supervised.fit, X, y)
@pytest.mark.skipif(has_installed_skggm(),
reason="The warning will be thrown only if skggm is "
"not installed.")
def test_raises_warning_msg_not_installed_skggm(self):
"""Tests that the right warning message is raised if someone tries to
use SDML but has not installed skggm"""
# TODO: remove if we don't need skggm anymore
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
y_pairs = [1, -1]
X, y = make_classification(random_state=42)
sdml = SDML()
sdml_supervised = SDML_Supervised(use_cov=False, balance_param=1e-5)
msg = ("Warning, skggm is not installed, so SDML will use "
"scikit-learn's graphical_lasso method. It can fail to converge"
"on some non SPD matrices where skggm would converge. If so, "
"try to install skggm. (see the README.md for the right "
"version.)")
with pytest.warns(None) as record:
sdml.fit(pairs, y_pairs)
assert str(record[0].message) == msg
with pytest.warns(None) as record:
sdml_supervised.fit(X, y)
assert str(record[0].message) == msg

@pytest.mark.skipif(not has_installed_skggm(),
reason="It's only in the case where skggm is installed"
"that no warning should be thrown.")
def test_raises_no_warning_installed_skggm(self):
# otherwise we should be able to instantiate and fit SDML and it
# should raise no warning
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
y_pairs = [1, -1]
X, y = make_classification(random_state=42)
with pytest.warns(None) as record:
sdml = SDML()
sdml.fit(pairs, y_pairs)
assert len(record) == 0
with pytest.warns(None) as record:
sdml = SDML_Supervised(use_cov=False, balance_param=1e-5)
sdml.fit(X, y)
assert len(record) == 0

def test_iris(self):
# Note: this is a flaky test, which fails for certain seeds.
# TODO: un-flake it!
rs = np.random.RandomState(5555)

sdml = SDML_Supervised(num_constraints=1500, use_cov=False,
balance_param=5e-5)
sdml.fit(self.iris_points, self.iris_labels, random_state=rs)
csep = class_separation(sdml.transform(self.iris_points),
self.iris_labels)
self.assertLess(csep, 0.22)

def test_deprecation_num_labeled(self):
# test that a deprecation message is thrown if num_labeled is set at
# initialization
# TODO: remove in v.0.6
X, y = make_classification(random_state=42)
sdml_supervised = SDML_Supervised(num_labeled=np.inf, use_cov=False,
balance_param=5e-5)
msg = ('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
'removed in 0.6.0')
assert_warns_message(DeprecationWarning, msg, sdml_supervised.fit, X, y)

def test_sdml_raises_warning_non_psd(self):
"""Tests that SDML raises a warning on a toy example where we know the
pseudo-covariance matrix is not PSD"""
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
y = [1, -1]
sdml = SDML(use_cov=True, sparsity_param=0.01, balance_param=0.5)
msg = ("Warning, the input matrix of graphical lasso is not "
"positive semi-definite (PSD). The algorithm may diverge, "
"and lead to degenerate solutions. "
"To prevent that, try to decrease the balance parameter "
"`balance_param` and/or to set use_covariance=False.")
with pytest.warns(ConvergenceWarning) as raised_warning:
try:
sdml.fit(pairs, y)
except Exception:
pass
# we assert that this warning is in one of the warning raised by the
# estimator
assert msg in list(map(lambda w: str(w.message), raised_warning))

def test_sdml_converges_if_psd(self):
"""Tests that sdml converges on a simple problem where we know the
pseudo-covariance matrix is PSD"""
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
y = [1, -1]
sdml = SDML(use_cov=True, sparsity_param=0.01, balance_param=0.5)
sdml.fit(pairs, y)
assert np.isfinite(sdml.get_mahalanobis_matrix()).all()

@pytest.mark.skipif(not has_installed_skggm(),
reason="sklearn's graphical_lasso can sometimes not "
"work on some non SPD problems. We test that "
"is works only if skggm is installed.")
def test_sdml_works_on_non_spd_pb_with_skggm(self):
"""Test that SDML works on a certain non SPD problem on which we know
it should work, but scikit-learn's graphical_lasso does not work"""
X, y = load_iris(return_X_y=True)
sdml = SDML_Supervised(balance_param=0.5, sparsity_param=0.01,
use_cov=True)
sdml.fit(X, y)


@pytest.mark.skipif(not has_installed_skggm(),
reason='The message should be printed only if skggm is '
'installed.')
def test_verbose_has_installed_skggm_sdml(capsys):
# Test that if users have installed skggm, a message is printed telling them
# skggm's solver is used (when they use SDML)
# TODO: remove if we don't need skggm anymore
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
y_pairs = [1, -1]
sdml = SDML()
sdml.fit(pairs, y_pairs)
out, _ = capsys.readouterr()
assert "SDML will use skggm's solver." in out


@pytest.mark.skipif(not has_installed_skggm(),
reason='The message should be printed only if skggm is '
'installed.')
def test_verbose_has_installed_skggm_sdml_supervised(capsys):
# Test that if users have installed skggm, a message is printed telling them
# skggm's solver is used (when they use SDML_Supervised)
# TODO: remove if we don't need skggm anymore
X, y = make_classification(random_state=42)
sdml = SDML_Supervised()
sdml.fit(X, y)
out, _ = capsys.readouterr()
assert "SDML will use skggm's solver." in out


class TestNCA(MetricTestCase):
4 changes: 3 additions & 1 deletion test/test_base_metric.py
@@ -4,6 +4,7 @@
import numpy as np
from sklearn import clone
from sklearn.utils.testing import set_random_state

from test.test_utils import ids_metric_learners, metric_learners


@@ -55,7 +56,8 @@ def test_lsml(self):
def test_sdml(self):
self.assertEqual(str(metric_learn.SDML()),
"SDML(balance_param=0.5, preprocessor=None, "
"sparsity_param=0.01, use_cov=True,\n verbose=False)")
"sparsity_param=0.01, use_cov=True,\n "
"verbose=False)")
self.assertEqual(str(metric_learn.SDML_Supervised()), """
SDML_Supervised(balance_param=0.5, num_constraints=None,
num_labeled='deprecated', preprocessor=None, sparsity_param=0.01,
10 changes: 7 additions & 3 deletions test/test_fit_transform.py
@@ -1,11 +1,13 @@
import pytest
import unittest
import numpy as np
from sklearn.datasets import load_iris
from numpy.testing import assert_array_almost_equal

from metric_learn import (
LMNN, NCA, LFDA, Covariance, MLKR,
LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised)
LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised,
MMC_Supervised)


class TestFitTransform(unittest.TestCase):
@@ -62,12 +64,14 @@ def test_lmnn(self):

def test_sdml_supervised(self):
seed = np.random.RandomState(1234)
sdml = SDML_Supervised(num_constraints=1500)
sdml = SDML_Supervised(num_constraints=1500, balance_param=1e-5,
use_cov=False)
sdml.fit(self.X, self.y, random_state=seed)
res_1 = sdml.transform(self.X)

seed = np.random.RandomState(1234)
sdml = SDML_Supervised(num_constraints=1500)
sdml = SDML_Supervised(num_constraints=1500, balance_param=1e-5,
use_cov=False)
res_2 = sdml.fit_transform(self.X, self.y, random_state=seed)

assert_array_almost_equal(res_1, res_2)
12 changes: 11 additions & 1 deletion test/test_mahalanobis_mixin.py
@@ -96,7 +96,7 @@ def check_is_distance_matrix(pairwise):
assert np.array_equal(pairwise, pairwise.T) # symmetry
assert (pairwise.diagonal() == 0).all() # identity
# triangular inequality
tol = 1e-15
tol = 1e-12
Member Author:

SDML was failing due to the harsh tolerance, so I relaxed it; I think it's still reasonable.

assert (pairwise <= pairwise[:, :, np.newaxis] +
pairwise[:, np.newaxis, :] + tol).all()

@@ -281,5 +281,15 @@ def test_transformer_is_2D(estimator, build_dataset):

# test that it works for 1 feature
trunc_data = input_data[..., :1]
# we drop duplicates that might have been formed, i.e. of the form
# aabc or abcc or aabb for quadruplets, and aa for pairs.
slices = {4: [slice(0, 2), slice(2, 4)], 2: [slice(0, 2)]}
if trunc_data.ndim == 3:
for slice_idx in slices[trunc_data.shape[1]]:
pairs = trunc_data[:, slice_idx, :]
diffs = pairs[:, 1, :] - pairs[:, 0, :]
to_keep = np.nonzero(diffs.ravel())
Member:

This is a bit difficult to parse. Why do we need these slices? I am a bit lazy to check, but maybe it should be made clearer even if it is less efficient (this is a small 1D dataset anyway, so we don't care).

Also, removing things that are very close to being the same (as opposed to exactly the same) might be more robust.

Member Author:

I agree it's difficult to parse; I'll change the test, copying/pasting for the quadruplets and pairs cases.

Member Author:

Done, in commit 31072d3. What do you think?

trunc_data = trunc_data[to_keep]
labels = labels[to_keep]
model.fit(trunc_data, labels)
assert model.transformer_.shape == (1, 1) # the transformer must be 2D
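
As a toy illustration of why this filtering is needed (data made up for the example): truncating to a single feature can turn a pair into a degenerate duplicate with zero difference.

import numpy as np

# Two pairs that differ only in the second feature.
pairs = np.array([[[1., 2.], [1., 3.]],
                  [[0., 1.], [2., 1.]]])
trunc = pairs[..., :1]                    # keep only the first feature
diffs = trunc[:, 1, :] - trunc[:, 0, :]
print(diffs.ravel())                      # [0. 2.]: the first pair degenerated
to_keep = np.nonzero(diffs.ravel())
print(trunc[to_keep].shape)               # (1, 2, 1): only the informative pair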