Skip to content

Commit

Permalink
[MRG] FIX: sdml formulation and solvers (#162)
Browse files Browse the repository at this point in the history
* FIX: make proposal for sdml formulation

* MAINT clearer formulation to make the prior appear

* MAINT call the prior prior

* Use skggm instead of graphical lasso

* Be more severe for the class separation

* Put back verbose param

* MAINT: make more explicit the fact that to use identity (i.e. an SPD matrix) as initialization

* Add skggm as a requirement for SDML

* Add skggm to required packages for travis

* Also add cython as a dependency

* FIX: install all except skggm and then skggm

* Remove cython dependency

* Install skggm only if we have at least python 3.6

* Should work if we want other versions superior to 3.6

* Fix bash >= which should be written -ge

* Deal with tests when skggm is not installed and fix some PEP8 warnings

* replace manual calls of algorithms with tuples_learners

* Remove another call of SDML if skggm is not installed

* FIX fix the test_error_message_tuple_size

* FIX fix test_sdml_supervised

* FIX: fix another sdml test

* FIX quic call for python 2.7

* Fix quic import

* Add Sigma0 initalization (both sigma zero and theta zero should be specified otherwise an error is returned

* Deal with SDML making some tests fail

* Remove epsilon that was unnecessary

* FIX: use latest commit of skggm that fixes the non deterministic problem

* MAINT: add message for SDML when not SPD

* MAINT: add test for error message if skggm not installed

* Try other syntax for installing the right commit of skggm

* MAINT: make sklearn compat sdml test be run only if skggm is installed

* Try another syntax for running travis

* Better bash syntax

* Fix tests by removing duplicates

* FIX: fix for sdml by reducing balance parameter

* FIX: update code to work with old version of numpy that does not have axis for unique

* Remove the need for skggm

* Update travis not to use skggm

* Add a stable init for sklearn checks

* FIX test_sdml_supervised

* Revert "Update travis not to use skggm"

This reverts commit 57b0567.

* Add fallback on skggm

* FIX: fix versions comparison and tests

* MAINT: improve test of no warning

* FIX: fix wrap pairs that was returning column y (we need line y), and fix the example for SDML to not raise another warning

* FIX: force travis to do the right check

* TST: add non SPD test that works with skggm's quic but not sklearn's graphical_lasso

* Try again travis this time installing cython

* Try to make travis work with build_essential

* Try with installing liblapack

* TST: fix tests for when skggm is not installed

* TST: use better pytest skipif syntax

* FIX: fix broken link in README.md

* use rst syntax for link

* use rst syntax for link

* use rst syntax for link

* MAINT: remove test_sdml that was remaining from drafts tests

* TST: remove skipping SDML in test_cross_validation_manual_vs_scikit

* FIX link also in getting started

* Put back right indent

* Remove unnecessary changes

* Nitpick for concatenation and refactor HAS_SKGGM

* ENH: Deal better with errors and skggm/scikit-learn

* Better creation of prior

* Simplification for init of sdml

* Put skggm as optional

* Specify skggm version

* TST: make test about 1 feature arrays more readable

* DOC: fix rst formatting

* DOC: reformulated skggm optional dependency

* TST: give an example for sdml_supervised with skggm where it indeed fails

* TST: fix test that fails weirdly when executing the whole test file and not just the test

* Revert "TST: fix test that fails weirdly when executing the whole test file and not just the test"

This reverts commit 6f5666b.

* Add coverage for all versions of python

* Install pytest-cov for all versions
  • Loading branch information
wdevazelhes authored and bellet committed Mar 22, 2019
1 parent 4e37d7c commit e8c74d0
Show file tree
Hide file tree
Showing 12 changed files with 375 additions and 63 deletions.
19 changes: 10 additions & 9 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@ cache: pip
python:
- "2.7"
- "3.4"
- "3.6"
before_install:
- sudo apt-get install liblapack-dev
- pip install --upgrade pip pytest
- pip install wheel
- pip install codecov
- if [[ $TRAVIS_PYTHON_VERSION == "3.4" ]];
then pip install pytest-cov;
- pip install wheel cython numpy scipy scikit-learn codecov pytest-cov
- if [[ ($TRAVIS_PYTHON_VERSION == "3.6") ||
($TRAVIS_PYTHON_VERSION == "2.7")]]; then
pip install git+https://github.com/skggm/skggm.git@a0ed406586c4364ea3297a658f415e13b5cbdaf8;
fi
- pip install numpy scipy scikit-learn
script:
- if [[ $TRAVIS_PYTHON_VERSION == "3.4" ]];
then pytest test --cov;
else pytest test;
fi
# we do coverage for all versions so that codecov will merge them: this
# way we will see that both paths (with or without skggm) are tested
- pytest test --cov;
after_success:
- bash <(curl -s https://codecov.io/bash)

7 changes: 6 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ Metric Learning algorithms in Python.

- Python 2.7+, 3.4+
- numpy, scipy, scikit-learn
- (for running the examples only: matplotlib)

**Optional dependencies**

- For SDML, using skggm will allow the algorithm to solve problematic cases
(install from commit `a0ed406 <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_).
- For running the examples only: matplotlib

**Installation/Setup**

Expand Down
7 changes: 6 additions & 1 deletion doc/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ Alternately, download the source repository and run:

- Python 2.7+, 3.4+
- numpy, scipy, scikit-learn
- (for running the examples only: matplotlib)

**Optional dependencies**

- For SDML, using skggm will allow the algorithm to solve problematic cases
(install from commit `a0ed406 <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_).
- For running the examples only: matplotlib

**Notes**

Expand Down
2 changes: 1 addition & 1 deletion metric_learn/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,6 @@ def wrap_pairs(X, constraints):
c = np.array(constraints[2])
d = np.array(constraints[3])
constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d))))
y = np.vstack([np.ones((len(a), 1)), - np.ones((len(c), 1))])
y = np.concatenate([np.ones_like(a), -np.ones_like(c)])
pairs = X[constraints]
return pairs, y
81 changes: 69 additions & 12 deletions metric_learn/sdml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,19 @@
import warnings
import numpy as np
from sklearn.base import TransformerMixin
from sklearn.covariance import graph_lasso
from sklearn.utils.extmath import pinvh
from scipy.linalg import pinvh
from sklearn.covariance import graphical_lasso
from sklearn.exceptions import ConvergenceWarning

from .base_metric import MahalanobisMixin, _PairsClassifierMixin
from .constraints import Constraints, wrap_pairs
from ._util import transformer_from_metric
try:
from inverse_covariance import quic
except ImportError:
HAS_SKGGM = False
else:
HAS_SKGGM = True


class _BaseSDML(MahalanobisMixin):
Expand Down Expand Up @@ -52,24 +59,74 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
super(_BaseSDML, self).__init__(preprocessor)

def _fit(self, pairs, y):
if not HAS_SKGGM:
if self.verbose:
print("SDML will use scikit-learn's graphical lasso solver.")
else:
if self.verbose:
print("SDML will use skggm's graphical lasso solver.")
pairs, y = self._prepare_inputs(pairs, y,
type_of_inputs='tuples')

# set up prior M
# set up (the inverse of) the prior M
if self.use_cov:
X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
M = pinvh(np.atleast_2d(np.cov(X, rowvar = False)))
prior_inv = np.atleast_2d(np.cov(X, rowvar=False))
else:
M = np.identity(pairs.shape[2])
prior_inv = np.identity(pairs.shape[2])
diff = pairs[:, 0] - pairs[:, 1]
loss_matrix = (diff.T * y).dot(diff)
P = M + self.balance_param * loss_matrix
emp_cov = pinvh(P)
# hack: ensure positive semidefinite
emp_cov = emp_cov.T.dot(emp_cov)
_, M = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)

self.transformer_ = transformer_from_metric(M)
emp_cov = prior_inv + self.balance_param * loss_matrix

# our initialization will be the matrix with emp_cov's eigenvalues,
# with a constant added so that they are all positive (plus an epsilon
# to ensure definiteness). This is empirical.
w, V = np.linalg.eigh(emp_cov)
min_eigval = np.min(w)
if min_eigval < 0.:
warnings.warn("Warning, the input matrix of graphical lasso is not "
"positive semi-definite (PSD). The algorithm may diverge, "
"and lead to degenerate solutions. "
"To prevent that, try to decrease the balance parameter "
"`balance_param` and/or to set use_covariance=False.",
ConvergenceWarning)
w -= min_eigval # we translate the eigenvalues to make them all positive
w += 1e-10 # we add a small offset to avoid definiteness problems
sigma0 = (V * w).dot(V.T)
try:
if HAS_SKGGM:
theta0 = pinvh(sigma0)
M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
msg=self.verbose,
Theta0=theta0, Sigma0=sigma0)
else:
_, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
verbose=self.verbose,
cov_init=sigma0)
raised_error = None
w_mahalanobis, _ = np.linalg.eigh(M)
not_spd = any(w_mahalanobis < 0.)
not_finite = not np.isfinite(M).all()
except Exception as e:
raised_error = e
not_spd = False # not_spd not applicable here so we set to False
not_finite = False # not_finite not applicable here so we set to False
if raised_error is not None or not_spd or not_finite:
msg = ("There was a problem in SDML when using {}'s graphical "
"lasso solver.").format("skggm" if HAS_SKGGM else "scikit-learn")
if not HAS_SKGGM:
skggm_advice = (" skggm's graphical lasso can sometimes converge "
"on non SPD cases where scikit-learn's graphical "
"lasso fails to converge. Try to install skggm and "
"rerun the algorithm (see the README.md for the "
"right version of skggm).")
msg += skggm_advice
if raised_error is not None:
msg += " The following error message was thrown: {}.".format(
raised_error)
raise RuntimeError(msg)

self.transformer_ = transformer_from_metric(np.atleast_2d(M))
return self


Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
extras_require=dict(
docs=['sphinx', 'shinx_rtd_theme', 'numpydoc'],
demo=['matplotlib'],
sdml=['skggm>=0.2.9']
),
test_suite='test',
keywords=[
Expand Down

0 comments on commit e8c74d0

Please sign in to comment.