Fix covariance initialization when matrix is not invertible (#277)
* Fix covariance init when matrix is not invertible

* replace `import scipy` with imports of only the required functions

* Change inv to pseudo-inv on custom matrix init

* Change from EVD to SVD

* Roll back to EVD and pseudo inverse of EVD

* Fix non-ASCII char

* rephrasing warnings

* added tests

* more rephrasing

* fix test

* add test

* fixes & adds singular pinv test from eig

* fix tolerance of assert

* fix tolerance of assert

* fix tolerance of assert

* fix random seed

* isolate random seed setting
grudloff committed Feb 4, 2020
1 parent 2380f51 commit e739239
Showing 3 changed files with 151 additions and 15 deletions.
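
For context, an illustrative sketch (not part of the commit) of the failure mode being fixed: a feature that is a linear combination of other features makes the covariance matrix singular, so a plain inverse fails or returns numerically meaningless values, while the pseudo-inverse stays finite.

    import numpy as np
    from scipy.linalg import pinvh

    rng = np.random.RandomState(42)
    X = rng.randn(100, 3)
    # Append a feature that is a linear combination of the first two,
    # making the covariance matrix rank-deficient (singular).
    X = np.concatenate([X, X[:, :2].dot([[2], [3]])], axis=1)
    cov = np.cov(X, rowvar=False)
    print(np.linalg.matrix_rank(cov))      # 3, not 4: cov is singular
    # np.linalg.inv(cov) may raise LinAlgError or return huge values;
    # the pseudo-inverse remains well-defined and finite:
    print(np.isfinite(pinvh(cov)).all())   # True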
62 changes: 53 additions & 9 deletions metric_learn/_util.py
@@ -1,5 +1,4 @@
 import numpy as np
-import scipy
 import six
 from numpy.linalg import LinAlgError
 from sklearn.datasets import make_spd_matrix
@@ -8,9 +7,10 @@
 from sklearn.utils.validation import check_X_y, check_random_state
 from .exceptions import PreprocessorError, NonPSDError
 from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
-from scipy.linalg import pinvh
+from scipy.linalg import pinvh, eigh
 import sys
 import time
+import warnings

 # hack around lack of axis kwarg in older numpy versions
 try:
@@ -678,17 +678,20 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None,

   random_state = check_random_state(random_state)
   M = init
-  if isinstance(init, np.ndarray):
-    s, u = scipy.linalg.eigh(init)
-    init_is_definite = _check_sdp_from_eigen(s)
+  if isinstance(M, np.ndarray):
+    w, V = eigh(M, check_finite=False)
+    init_is_definite = _check_sdp_from_eigen(w)
     if strict_pd and not init_is_definite:
       raise LinAlgError("You should provide a strictly positive definite "
                         "matrix as `{}`. This one is not definite. Try another"
                         " {}, or an algorithm that does not "
                         "require the {} to be strictly positive definite."
                         .format(*((matrix_name,) * 3)))
+    elif return_inverse and not init_is_definite:
+      warnings.warn('The initialization matrix is not invertible: '
+                    'using the pseudo-inverse instead.')
     if return_inverse:
-      M_inv = np.dot(u / s, u.T)
+      M_inv = _pseudo_inverse_from_eig(w, V)
       return M, M_inv
     else:
       return M
@@ -707,15 +710,23 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None,
       X = input
     # atleast2d is necessary to deal with scalar covariance matrices
     M_inv = np.atleast_2d(np.cov(X, rowvar=False))
-    s, u = scipy.linalg.eigh(M_inv)
-    cov_is_definite = _check_sdp_from_eigen(s)
+    w, V = eigh(M_inv, check_finite=False)
+    cov_is_definite = _check_sdp_from_eigen(w)
     if strict_pd and not cov_is_definite:
       raise LinAlgError("Unable to get a true inverse of the covariance "
                         "matrix since it is not definite. Try another "
                         "`{}`, or an algorithm that does not "
                         "require the `{}` to be strictly positive definite."
                         .format(*((matrix_name,) * 2)))
-    M = np.dot(u / s, u.T)
+    elif not cov_is_definite:
+      warnings.warn('The covariance matrix is not invertible: '
+                    'using the pseudo-inverse instead. '
+                    'To make the covariance matrix invertible '
+                    'you can remove any linearly dependent features and/or '
+                    'reduce the dimensionality of your input, '
+                    'for instance using `sklearn.decomposition.PCA` as a '
+                    'preprocessing step.')
+    M = _pseudo_inverse_from_eig(w, V)
     if return_inverse:
       return M, M_inv
     else:
@@ -742,3 +753,36 @@ def _check_n_components(n_features, n_components):
   if 0 < n_components <= n_features:
     return n_components
   raise ValueError('Invalid n_components, must be in [1, %d]' % n_features)
+
+
+def _pseudo_inverse_from_eig(w, V, tol=None):
+  """Compute the (Moore-Penrose) pseudo-inverse from the EVD of a symmetric
+  matrix.
+
+  Parameters
+  ----------
+  w : (..., M) ndarray
+    The eigenvalues in ascending order, each repeated according to
+    its multiplicity.
+
+  V : (..., M, M) ndarray
+    The column ``V[:, i]`` is the normalized eigenvector corresponding
+    to the eigenvalue ``w[i]``.
+
+  tol : positive `float`, optional
+    Eigenvalues whose absolute value is below tol are considered zero.
+
+  Returns
+  -------
+  output : (..., M, M) ndarray
+    The pseudo-inverse given by the EVD.
+  """
+  if tol is None:
+    tol = np.amax(w) * np.max(w.shape) * np.finfo(w.dtype).eps
+  # discard eigenvalues below tol and invert the rest
+  large = np.abs(w) > tol
+  w = np.divide(1, w, where=large, out=w)
+  w[~large] = 0
+
+  return np.dot(V * w, np.conjugate(V).T)
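
A minimal usage sketch of the helper added above (assuming it is importable from metric_learn._util after this commit); it mirrors the new tests below, which check that the EVD-based pseudo-inverse agrees with scipy.linalg.pinvh on a symmetric singular matrix:

    import numpy as np
    from scipy.linalg import eigh, pinvh
    from metric_learn._util import _pseudo_inverse_from_eig

    rng = np.random.RandomState(42)
    A = rng.rand(5, 5)
    A = A + A.T                     # symmetrize
    w, V = eigh(A)
    w[0] = 0.                       # zero out one eigenvalue: A is singular
    A = V.dot(np.diag(w)).dot(V.T)  # rebuild the singular matrix

    # The pseudo-inverse built from the EVD matches scipy's pinvh:
    np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A),
                               rtol=1e-5)

Note that the helper reuses `w` as its output buffer (`np.divide(..., out=w)`), so pass a copy if you still need the eigenvalues afterwards.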
77 changes: 72 additions & 5 deletions test/test_mahalanobis_mixin.py
@@ -8,12 +8,12 @@
 from scipy.stats import ortho_group
 from sklearn import clone
 from sklearn.cluster import DBSCAN
-from sklearn.datasets import make_spd_matrix
-from sklearn.utils import check_random_state
+from sklearn.datasets import make_spd_matrix, make_blobs
+from sklearn.utils import check_random_state, shuffle
 from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.testing import set_random_state

-from metric_learn._util import make_context
+from metric_learn._util import make_context, _initialize_metric_mahalanobis
 from metric_learn.base_metric import (_QuadrupletsClassifierMixin,
                                       _PairsClassifierMixin)
 from metric_learn.exceptions import NonPSDError
@@ -569,7 +569,7 @@ def test_init_mahalanobis(estimator, build_dataset):
                          in zip(ids_metric_learners,
                                 metric_learners)
                          if idml[:4] in ['ITML', 'SDML', 'LSML']])
-def test_singular_covariance_init_or_prior(estimator, build_dataset):
+def test_singular_covariance_init_or_prior_strictpd(estimator, build_dataset):
   """Tests that when using the 'covariance' init or prior, it returns the
   appropriate error if the covariance matrix is singular, for algorithms
   that need a strictly PD prior or init (see
@@ -603,6 +603,48 @@ def test_singular_covariance_init_or_prior_strictpd(estimator,
   assert str(raised_err.value) == msg


+@pytest.mark.integration
+@pytest.mark.parametrize('estimator, build_dataset',
+                         [(ml, bd) for idml, (ml, bd)
+                          in zip(ids_metric_learners,
+                                 metric_learners)
+                          if idml[:3] in ['MMC']],
+                         ids=[idml for idml, (ml, _)
+                              in zip(ids_metric_learners,
+                                     metric_learners)
+                              if idml[:3] in ['MMC']])
+def test_singular_covariance_init_of_non_strict_pd(estimator, build_dataset):
+  """Tests that when using the 'covariance' init, a warning is raised if
+  the covariance matrix is singular, for algorithms that do not need a
+  strictly PD init. Also checks that the returned inverse matrix has
+  finite values.
+  """
+  input_data, labels, _, X = build_dataset()
+  model = clone(estimator)
+  set_random_state(model)
+  # We create a feature that is a linear combination of the first two
+  # features:
+  input_data = np.concatenate([input_data,
+                               input_data[:, ..., :2].dot([[2], [3]])],
+                              axis=-1)
+  model.set_params(init='covariance')
+  msg = ('The covariance matrix is not invertible: '
+         'using the pseudo-inverse instead. '
+         'To make the covariance matrix invertible '
+         'you can remove any linearly dependent features and/or '
+         'reduce the dimensionality of your input, '
+         'for instance using `sklearn.decomposition.PCA` as a '
+         'preprocessing step.')
+  with pytest.warns(UserWarning) as raised_warning:
+    model.fit(input_data, labels)
+  assert np.any([str(warning.message) == msg for warning in raised_warning])
+  M, _ = _initialize_metric_mahalanobis(X, init='covariance',
+                                        random_state=RNG,
+                                        return_inverse=True,
+                                        strict_pd=False)
+  assert np.isfinite(M).all()
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize('estimator, build_dataset',
                          [(ml, bd) for idml, (ml, bd)
@@ -614,7 +656,7 @@ def test_singular_covariance_init_or_prior_strictpd(estimator,
                          in zip(ids_metric_learners,
                                 metric_learners)
                          if idml[:4] in ['ITML', 'SDML', 'LSML']])
 @pytest.mark.parametrize('w0', [1e-20, 0., -1e-20])
-def test_singular_array_init_or_prior(estimator, build_dataset, w0):
+def test_singular_array_init_or_prior_strictpd(estimator, build_dataset, w0):
   """Tests that when using a custom array init (or prior), it returns the
   appropriate error if it is singular, for algorithms
   that need a strictly PD prior or init (see
@@ -654,6 +696,31 @@ def test_singular_array_init_or_prior_strictpd(estimator,
   assert str(raised_err.value) == msg


+@pytest.mark.parametrize('w0', [1e-20, 0., -1e-20])
+def test_singular_array_init_of_non_strict_pd(w0):
+  """Tests that when using a custom array init, a warning is raised if
+  the matrix is singular. Also checks that the returned inverse matrix
+  has finite values. This is not checked through model fitting, as no
+  model currently uses this setting.
+  """
+  rng = np.random.RandomState(42)
+  X, y = shuffle(*make_blobs(random_state=rng),
+                 random_state=rng)
+  P = ortho_group.rvs(X.shape[1], random_state=rng)
+  w = np.abs(rng.randn(X.shape[1]))
+  w[0] = w0
+  M = P.dot(np.diag(w)).dot(P.T)
+  msg = ('The initialization matrix is not invertible: '
+         'using the pseudo-inverse instead.')
+  with pytest.warns(UserWarning) as raised_warning:
+    _, M_inv = _initialize_metric_mahalanobis(X, init=M,
+                                              random_state=rng,
+                                              return_inverse=True,
+                                              strict_pd=False)
+  assert str(raised_warning[0].message) == msg
+  assert np.isfinite(M_inv).all()
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize('estimator, build_dataset', metric_learners,
                          ids=ids_metric_learners)
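
The remedy suggested by the new warning can also be sketched (illustrative only): projecting out the zero-variance direction with `sklearn.decomposition.PCA` restores an invertible covariance matrix.

    import numpy as np
    from sklearn.decomposition import PCA

    rng = np.random.RandomState(42)
    X = rng.randn(100, 3)
    X = np.concatenate([X, X[:, :2].dot([[2], [3]])], axis=1)  # singular cov
    # Keep only 3 components: the linearly dependent direction is removed,
    # so the covariance of the reduced data is invertible again.
    X_reduced = PCA(n_components=3).fit_transform(X)
    cov = np.cov(X_reduced, rowvar=False)
    print(np.linalg.matrix_rank(cov))  # 3: full rank, invertible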
27 changes: 26 additions & 1 deletion test/test_utils.py
@@ -1,4 +1,5 @@
 import pytest
+from scipy.linalg import eigh, pinvh
 from collections import namedtuple
 import numpy as np
 from numpy.testing import assert_array_equal, assert_equal
@@ -11,7 +12,7 @@
                           check_collapsed_pairs, validate_vector,
                           _check_sdp_from_eigen, _check_n_components,
                           check_y_valid_values_for_pairs,
-                          _auto_select_init)
+                          _auto_select_init, _pseudo_inverse_from_eig)
 from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA,
                           LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised,
                           MMC_Supervised, RCA_Supervised, SDML_Supervised,
@@ -1150,3 +1151,27 @@ def test__auto_select_init(has_classes, n_features, n_samples, n_components,
   """Checks that the auto selection of the init works as expected"""
   assert (_auto_select_init(has_classes, n_features,
                             n_samples, n_components, n_classes) == result)
+
+
+@pytest.mark.parametrize('w0', [1e-20, 0., -1e-20])
+def test_pseudo_inverse_from_eig_and_pinvh_singular(w0):
+  """Checks that _pseudo_inverse_from_eig returns the same result as
+  scipy.linalg.pinvh for a singular matrix"""
+  rng = np.random.RandomState(SEED)
+  A = rng.rand(100, 100)
+  A = A + A.T
+  w, V = eigh(A)
+  w[0] = w0
+  A = V.dot(np.diag(w)).dot(V.T)
+  np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A),
+                             rtol=1e-05)
+
+
+def test_pseudo_inverse_from_eig_and_pinvh_nonsingular():
+  """Checks that _pseudo_inverse_from_eig returns the same result as
+  scipy.linalg.pinvh for a non-singular matrix"""
+  rng = np.random.RandomState(SEED)
+  A = rng.rand(100, 100)
+  A = A + A.T
+  w, V = eigh(A, check_finite=False)
+  np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A))