From dd68aa343f31caba87440b16b9a7517a7d45333f Mon Sep 17 00:00:00 2001 From: Timothy Click Date: Tue, 26 Feb 2019 16:44:58 +0800 Subject: [PATCH] Changed to safe_sqr and corrected an import bug in test_svd. Signed-off-by: Timothy Click --- src/fluctmatch/decomposition/svd.py | 4 ++-- tests/decomposition/test_svd.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/fluctmatch/decomposition/svd.py b/src/fluctmatch/decomposition/svd.py index df6fd45..6ed5d7d 100644 --- a/src/fluctmatch/decomposition/svd.py +++ b/src/fluctmatch/decomposition/svd.py @@ -21,7 +21,7 @@ from scipy import linalg from sklearn.base import BaseEstimator, TransformerMixin from sklearn.decomposition.base import _BasePCA -from sklearn.utils import check_array +from sklearn.utils import check_array, safe_sqr from sklearn.utils.validation import check_is_fitted from sklearn.utils.extmath import fast_logdet, svd_flip @@ -135,7 +135,7 @@ def _fit(self, X: np.ndarray): U, Vt = svd_flip(U, Vt) # Get variance explained by singular values - explained_variance_: np.ndarray = (S ** 2) / (n_samples - 1) + explained_variance_: np.ndarray = safe_sqr(S) / (n_samples - 1) total_var: float = explained_variance_.sum() explained_variance_ratio_: np.ndarray = explained_variance_ / total_var singular_values_: np.ndarray = S.copy() # Store the singular values. diff --git a/tests/decomposition/test_svd.py b/tests/decomposition/test_svd.py index 8357ab6..99c3afb 100644 --- a/tests/decomposition/test_svd.py +++ b/tests/decomposition/test_svd.py @@ -31,10 +31,10 @@ # These tests were taken from the scikit-learn tests for PCA. import numpy as np -from numpy.testing import assert_almost_equal -from numpy.testing import assert_array_almost_equal -from numpy.testing import assert_greater -from numpy.testing import assert_no_warnings +from sklearn.utils.testing import assert_almost_equal +from sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_greater +from sklearn.utils.testing import assert_no_warnings from sklearn import datasets from fluctmatch.decomposition.svd import SVD