diff --git a/metric_learn/covariance.py b/metric_learn/covariance.py index 7a04923d..83d2f9d8 100644 --- a/metric_learn/covariance.py +++ b/metric_learn/covariance.py @@ -10,6 +10,7 @@ from __future__ import absolute_import import numpy as np +import scipy from sklearn.base import TransformerMixin from .base_metric import MahalanobisMixin @@ -35,11 +36,11 @@ def fit(self, X, y=None): y : unused """ X = self._prepare_inputs(X, ensure_min_samples=2) - M = np.cov(X, rowvar = False) - if M.ndim == 0: - M = 1./M + M = np.atleast_2d(np.cov(X, rowvar=False)) + if M.size == 1: + M = 1. / M else: - M = np.linalg.inv(M) + M = scipy.linalg.pinvh(M) self.transformer_ = transformer_from_metric(np.atleast_2d(M)) return self diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index a785d60d..26e204ea 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -6,7 +6,8 @@ from six.moves import xrange from sklearn.metrics import pairwise_distances from sklearn.datasets import load_iris, make_classification, make_regression -from numpy.testing import assert_array_almost_equal, assert_array_equal +from numpy.testing import (assert_array_almost_equal, assert_array_equal, + assert_allclose) from sklearn.utils.testing import assert_warns_message from sklearn.exceptions import ConvergenceWarning from sklearn.utils.validation import check_X_y @@ -53,6 +54,23 @@ def test_iris(self): # deterministic result self.assertAlmostEqual(csep, 0.72981476) + def test_singular_returns_pseudo_inverse(self): + """Checks that if the input covariance matrix is singular, we return + the pseudo inverse""" + X, y = load_iris(return_X_y=True) + # We add a virtual column that is a linear combination of the other + # columns so that the covariance matrix will be singular + X = np.concatenate([X, X[:, :2].dot([[2], [3]])], axis=1) + cov_matrix = np.cov(X, rowvar=False) + covariance = Covariance() + covariance.fit(X) + pseudo_inverse = covariance.get_mahalanobis_matrix() + # here is the definition of a pseudo inverse according to wikipedia: + assert_allclose(cov_matrix.dot(pseudo_inverse).dot(cov_matrix), + cov_matrix) + assert_allclose(pseudo_inverse.dot(cov_matrix).dot(pseudo_inverse), + pseudo_inverse) + class TestLSML(MetricTestCase): def test_iris(self):