diff --git a/examples/applications/plot_face_recognition.py b/examples/applications/plot_face_recognition.py
index b80cc7fd5e5a4..2d17a9f719343 100644
--- a/examples/applications/plot_face_recognition.py
+++ b/examples/applications/plot_face_recognition.py
@@ -34,7 +34,7 @@
 from scikits.learn.grid_search import GridSearchCV
 from scikits.learn.metrics import classification_report
 from scikits.learn.metrics import confusion_matrix
-from scikits.learn.pca import PCA
+from scikits.learn.pca import RandomizedPCA
 from scikits.learn.svm import SVC
 
 ################################################################################
@@ -115,7 +115,7 @@
 n_components = 150
 
 print "Extracting the top %d eigenfaces" % n_components
-pca = PCA(n_comp=n_components, whiten=True, do_fast_svd=True).fit(X_train)
+pca = RandomizedPCA(n_components=n_components, whiten=True).fit(X_train)
 
 eigenfaces = pca.components_.T.reshape((n_components, 64, 64))
 
diff --git a/examples/cluster/kmeans_digits.py b/examples/cluster/kmeans_digits.py
index ba63405f3b02b..c902413f7fb2e 100644
--- a/examples/cluster/kmeans_digits.py
+++ b/examples/cluster/kmeans_digits.py
@@ -51,7 +51,7 @@
 # in this case the seeding of the centers is deterministic, hence we run the
 # kmeans algorithm only once with n_init=1
 t0 = time()
-pca = PCA(n_comp=n_digits).fit(data)
+pca = PCA(n_components=n_digits).fit(data)
 km = KMeans(init=pca.components_.T, k=n_digits, n_init=1).fit(data)
 print "done in %0.3fs" % (time() - t0)
 print "inertia: %f" % km.inertia_
diff --git a/examples/plot_pca.py b/examples/plot_pca.py
index 5832c0bd7db9b..7e40b2c73674b 100644
--- a/examples/plot_pca.py
+++ b/examples/plot_pca.py
@@ -25,7 +25,7 @@
 y = iris.target
 target_names = iris.target_names
 
-pca = PCA(n_comp=2)
+pca = PCA(n_components=2)
 X_r = pca.fit(X).transform(X)
 
 # Percentage of variance explained for each components
diff --git a/scikits/learn/fastica.py b/scikits/learn/fastica.py
index 6bc42f425a8d7..3e5c8801d13a9 100644
--- a/scikits/learn/fastica.py
+++ b/scikits/learn/fastica.py
@@ -55,11 +55,11 @@ def _ica_def(X, tol, g, gprime, fun_args, max_iter, w_init):
 
     Used internally by FastICA.
     """
-    n_comp = w_init.shape[0]
-    W = np.zeros((n_comp, n_comp), dtype=float)
+    n_components = w_init.shape[0]
+    W = np.zeros((n_components, n_components), dtype=float)
 
     # j is the index of the extracted component
-    for j in range(n_comp):
+    for j in range(n_components):
         w = w_init[j, :].copy()
         w /= np.sqrt((w**2).sum())
 
@@ -114,7 +114,7 @@ def _ica_par(X, tol, g, gprime, fun_args, max_iter, w_init):
     return W
 
 
-def fastica(X, n_comp=None, algorithm="parallel", whiten=True,
+def fastica(X, n_components=None, algorithm="parallel", whiten=True,
             fun="logcosh", fun_prime='', fun_args={}, max_iter=200,
             tol=1e-04, w_init=None):
     """Perform Fast Independent Component Analysis.
@@ -124,7 +124,7 @@ def fastica(X, n_comp=None, algorithm="parallel", whiten=True,
     X : (n, p) array of shape = [n_samples, n_features]
         Training vector, where n_samples is the number of samples and
         n_features is the number of features.
-    n_comp : int, optional
+    n_components : int, optional
         Number of components to extract. If None no dimension reduction
         is performed.
     algorithm : {'parallel','deflation'}
@@ -151,22 +151,22 @@ def fastica(X, n_comp=None, algorithm="parallel", whiten=True,
     tol : float
         A positive scalar giving the tolerance at which the
         un-mixing matrix is considered to have converged
-    w_init : (n_comp,n_comp) array
+    w_init : (n_components,n_components) array
         Initial un-mixing array of dimension (n.comp,n.comp).
         If None (default) then an array of normal r.v.'s is used
     source_only: if True, only the sources matrix is returned
 
     Results
     -------
-    K : (n_comp, p) array
+    K : (n_components, p) array
         pre-whitening matrix that projects data onto th first n.comp
         principal components. Returned only if whiten is True
-    W : (n_comp, n_comp) array
+    W : (n_components, n_components) array
         estimated un-mixing matrix
         The mixing matrix can be obtained by::
             w = np.dot(W, K.T)
             A = w.T * (w * w.T).I
-    S : (n_comp, n) array
+    S : (n_components, n) array
         estimated source matrix
 
@@ -227,11 +227,11 @@ def gprime(x, fun_args):
 
     n, p = X.shape
 
-    if n_comp is None:
-        n_comp = min(n, p)
-    if (n_comp > min(n, p)):
-        n_comp = min(n, p)
-        print("n_comp is too large: it will be set to %s" % n_comp)
+    if n_components is None:
+        n_components = min(n, p)
+    if (n_components > min(n, p)):
+        n_components = min(n, p)
+        print("n_components is too large: it will be set to %s" % n_components)
 
     if whiten:
         # Centering the columns (ie the variables)
@@ -241,7 +241,7 @@ def gprime(x, fun_args):
 
         u, d, _ = linalg.svd(X, full_matrices=False)
         del _
-        K = (u/d).T[:n_comp]  # see (6.33) p.140
+        K = (u/d).T[:n_components]  # see (6.33) p.140
         del u, d
         X1 = np.dot(K, X)  # see (13.6) p.267 Here X1 is white and data
 
@@ -251,12 +251,12 @@ def gprime(x, fun_args):
         X1 *= np.sqrt(p)
 
     if w_init is None:
-        w_init = np.random.normal(size=(n_comp, n_comp))
+        w_init = np.random.normal(size=(n_components, n_components))
     else:
         w_init = np.asarray(w_init)
-        if w_init.shape != (n_comp, n_comp):
+        if w_init.shape != (n_components, n_components):
             raise ValueError("w_init has invalid shape -- should be %(shape)s"
-                             % {'shape': (n_comp, n_comp)})
+                             % {'shape': (n_components, n_components)})
 
     kwargs = {'tol': tol,
               'g': g,
@@ -283,7 +283,7 @@ class FastICA(BaseEstimator):
 
     Parameters
     ----------
-    n_comp : int, optional
+    n_components : int, optional
         Number of components to use. If none is passed, all are used.
     algorithm: {'parallel', 'deflation'}
         Apply parallel or deflational algorithm for FastICA
@@ -300,12 +300,12 @@ class FastICA(BaseEstimator):
         Maximum number of iterations during fit
     tol : float, optional
         Tolerance on update at each iteration
-    w_init: None of an (n_comp, n_comp) ndarray
+    w_init: None or an (n_components, n_components) ndarray
         The mixing matrix to be used to initialize the algorithm.
 
     Attributes
     ----------
-    unmixing_matrix_ : 2D array, [n_comp, n_samples]
+    unmixing_matrix_ : 2D array, [n_components, n_samples]
 
     Methods
     -------
@@ -322,11 +322,11 @@ class FastICA(BaseEstimator):
 
     """
 
-    def __init__(self, n_comp=None, algorithm='parallel', whiten=True,
+    def __init__(self, n_components=None, algorithm='parallel', whiten=True,
                  fun='logcosh', fun_prime='', fun_args={}, max_iter=200,
                  tol=1e-4, w_init=None):
         super(FastICA, self).__init__()
-        self.n_comp = n_comp
+        self.n_components = n_components
         self.algorithm = algorithm
         self.whiten = whiten
         self.fun = fun
@@ -338,7 +338,7 @@ def __init__(self, n_comp=None, algorithm='parallel', whiten=True,
 
     def fit(self, X, **params):
         self._set_params(**params)
-        whitening_, unmixing_, sources_ = fastica(X, self.n_comp,
+        whitening_, unmixing_, sources_ = fastica(X, self.n_components,
                 self.algorithm, self.whiten,
                 self.fun, self.fun_prime, self.fun_args, self.max_iter,
                 self.tol, self.w_init)
diff --git a/scikits/learn/pca.py b/scikits/learn/pca.py
index 1a19ebfaca59b..00bab94e3b23f 100644
--- a/scikits/learn/pca.py
+++ b/scikits/learn/pca.py
@@ -2,14 +2,16 @@
 """
 
 # Author: Alexandre Gramfort
+#         Olivier Grisel
 # License: BSD Style.
 
-import warnings
 import numpy as np
 from scipy import linalg
 
 from .base import BaseEstimator
-from .utils.extmath import fast_logdet, fast_svd
+from .utils.extmath import fast_logdet
+from .utils.extmath import fast_svd
+from .utils.extmath import safe_sparse_dot
 
 
 def _assess_dimension_(spectrum, rank, n_samples, dim):
@@ -86,7 +88,6 @@ def _infer_dimension_(spectrum, n, p):
     return ll.argmax()
 
 
-################################################################################
 class PCA(BaseEstimator):
     """Principal component analysis (PCA)
 
@@ -94,55 +95,51 @@ class PCA(BaseEstimator):
     data and keeping only the most significant singular vectors to project the
     data to a lower dimensional space.
 
+    This implementation uses the scipy.linalg implementation of the singular
+    value decomposition. It only works for dense arrays and is not scalable to
+    high-dimensional data.
+
+    The time complexity of this implementation is O(n ** 3) assuming
+    n ~ n_samples ~ n_features.
+
     Parameters
     ----------
     X: array-like, shape (n_samples, n_features)
         Training vector, where n_samples in the number of samples and
         n_features is the number of features.
 
-    Attributes
-    ----------
-    n_comp: int, none or string
-        Number of components
-        if n_comp is not set all components are kept
-        if n_comp=='mle', Minka's MLE is used to guess the dimension
+    n_components: int, None or string
+        Number of components to keep.
+        if n_components is not set all components are kept:
+            n_components == min(n_samples, n_features)
+        if n_components == 'mle', Minka's MLE is used to guess the dimension
 
     copy: bool
         If False, data passed to fit are overwritten
 
-    components_: array, [n_features, n_comp]
-        Components with maximum variance.
-
-    do_fast_svd: bool, optional
-        If True, the k-truncated SVD is computed using random projections
-        which speeds up the computation on large arrays. If all the
-        components are to be computed (as in n_comp=None or
-        n_comp='mle'), this option has no effects. Note that the solution will
-        be correct only if the requested n_comp is as large as the approximate
-        effective rank of the data.
-
-    explained_variance_: array, [n_comp]
-        Percentage of variance explained by each of the selected components.
-        k is not set then all components are stored and the sum of
-        explained variances is equal to 1.0
-
     whiten: bool, optional
         When True (False by default) the components_ vectors are divided
-        by the singular values to ensure uncorrelated outputs with unit
-        component-wise variances.
+        by n_samples times singular values to ensure uncorrelated outputs
+        with unit component-wise variances.
 
         Whitening will remove some information from the transformed signal
         (the relative variance scales of the components) but can sometime
         improve the predictive accuracy of the downstream estimators by
         making there data respect some hard-wired assumptions.
 
-    iterated_power: int, optional
-        Number of iteration for the power method if do_fast_svd is True. 3 by
-        default.
+    Attributes
+    ----------
+    components_: array, [n_features, n_components]
+        Components with maximum variance.
+
+    explained_variance_ratio_: array, [n_components]
+        Percentage of variance explained by each of the selected components.
+        If n_components is not set then all components are stored and the
+        sum of explained variances is equal to 1.0
 
     Notes
     -----
-    For n_comp='mle', this class uses the method of Thomas P. Minka:
+    For n_components='mle', this class uses the method of Thomas P. Minka:
     Automatic Choice of Dimensionality for PCA. NIPS 2000: 598-604
 
     Examples
@@ -150,23 +147,21 @@ class PCA(BaseEstimator):
     >>> import numpy as np
    >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
     >>> from scikits.learn.pca import PCA
-    >>> pca = PCA(n_comp=2)
+    >>> pca = PCA(n_components=2)
     >>> pca.fit(X)
-    PCA(do_fast_svd=False, n_comp=2, copy=True, whiten=False, iterated_power=3)
+    PCA(copy=True, n_components=2, whiten=False)
     >>> print pca.explained_variance_ratio_
     [ 0.99244289  0.00755711]
 
     See also
     --------
     ProbabilisticPCA
+    RandomizedPCA
     """
-    def __init__(self, n_comp=None, copy=True, do_fast_svd=False,
-                 iterated_power=3, whiten=False):
-        self.n_comp = n_comp
+    def __init__(self, n_components=None, copy=True, whiten=False):
+        self.n_components = n_components
         self.copy = copy
-        self.do_fast_svd = do_fast_svd
-        self.iterated_power = iterated_power
         self.whiten = whiten
 
     def fit(self, X, **params):
@@ -179,15 +174,7 @@ def fit(self, X, **params):
         # Center data
         self.mean_ = np.mean(X, axis=0)
         X -= self.mean_
-        if self.do_fast_svd:
-            if self.n_comp == "mle" or self.n_comp is None:
-                warnings.warn('All components are to be computed'
-                              'Not using fast truncated SVD')
-                U, S, V = linalg.svd(X, full_matrices=False)
-            else:
-                U, S, V = fast_svd(X, self.n_comp, q=self.iterated_power)
-        else:
-            U, S, V = linalg.svd(X, full_matrices=False)
+        U, S, V = linalg.svd(X, full_matrices=False)
         self.explained_variance_ = (S ** 2) / n_samples
         self.explained_variance_ratio_ = self.explained_variance_ / \
                                          self.explained_variance_.sum()
@@ -198,15 +185,16 @@ def fit(self, X, **params):
         else:
             self.components_ = V.T
 
-        if self.n_comp == 'mle':
-            self.n_comp = _infer_dimension_(self.explained_variance_,
+        if self.n_components == 'mle':
+            self.n_components = _infer_dimension_(self.explained_variance_,
                                             n_samples, X.shape[1])
 
-        if self.n_comp is not None:
-            self.components_ = self.components_[:, :self.n_comp]
-            self.explained_variance_ = self.explained_variance_[:self.n_comp]
-            self.explained_variance_ratio_ = self.explained_variance_ratio_[
-                :self.n_comp]
+        if self.n_components is not None:
+            self.components_ = self.components_[:, :self.n_components]
+            self.explained_variance_ = \
+                self.explained_variance_[:self.n_components]
+            self.explained_variance_ratio_ = \
+                self.explained_variance_ratio_[:self.n_components]
 
         return self
 
@@ -217,7 +205,6 @@ def transform(self, X):
         return Xr
 
 
-################################################################################
 class ProbabilisticPCA(PCA):
     """Additional layer on top of PCA that add a probabilistic evaluation
 
@@ -238,14 +225,14 @@ def fit(self, X, homoscedastic=True):
         Xr = X - self.mean_
         Xr -= np.dot(np.dot(Xr, self.components_), self.components_.T)
         n_samples = X.shape[0]
-        if self.dim <= self.n_comp:
+        if self.dim <= self.n_components:
             delta = np.zeros(self.dim)
         elif homoscedastic:
             delta = (Xr ** 2).sum() / (n_samples*(self.dim)) * np.ones(self.dim)
         else:
-            delta = (Xr ** 2).mean(0) / (self.dim - self.n_comp)
+            delta = (Xr ** 2).mean(0) / (self.dim - self.n_components)
         self.covariance_ = np.diag(delta)
-        for k in range(self.n_comp):
+        for k in range(self.n_components):
             add_cov = np.dot(
                 self.components_[:, k:k+1], self.components_[:, k:k+1].T)
             self.covariance_ += self.explained_variance_[k] * add_cov
@@ -272,3 +259,123 @@ def score(self, X):
         log_like += fast_logdet(self.precision_) - \
                     self.dim / 2 * np.log(2 * np.pi)
         return log_like
+
+
+class RandomizedPCA(BaseEstimator):
+    """Principal component analysis (PCA) using randomized SVD
+
+    Linear dimensionality reduction using approximated Singular Value
+    Decomposition of the data and keeping only the most significant
+    singular vectors to project the data to a lower dimensional space.
+
+    This implementation uses a randomized SVD implementation and can
+    handle both scipy.sparse and numpy dense arrays as input.
+
+    Parameters
+    ----------
+    X: array-like or scipy.sparse matrix, shape (n_samples, n_features)
+        Training vector, where n_samples is the number of samples and
+        n_features is the number of features.
+
+    n_components: int
+        Maximum number of components to keep: default is 50.
+
+    copy: bool
+        If False, data passed to fit are overwritten
+
+    iterated_power: int, optional
+        Number of iterations for the power method. 3 by default.
+
+    whiten: bool, optional
+        When True (False by default) the components_ vectors are divided
+        by the singular values to ensure uncorrelated outputs with unit
+        component-wise variances.
+
+        Whitening will remove some information from the transformed signal
+        (the relative variance scales of the components) but can sometimes
+        improve the predictive accuracy of the downstream estimators by
+        making their data respect some hard-wired assumptions.
+
+    Attributes
+    ----------
+    components_: array, [n_features, n_components]
+        Components with maximum variance.
+
+    explained_variance_ratio_: array, [n_components]
+        Percentage of variance explained by each of the selected components.
+        If n_components is not set then all components are stored and the
+        sum of explained variances is equal to 1.0
+
+    References
+    ----------
+    Finding structure with randomness: Stochastic algorithms for constructing
+    approximate matrix decompositions
+    Halko, et al., 2009 (arXiv:909)
+
+    A randomized algorithm for the decomposition of matrices
+    Per-Gunnar Martinsson, Vladimir Rokhlin and Mark Tygert
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
+    >>> from scikits.learn.pca import RandomizedPCA
+    >>> pca = RandomizedPCA(n_components=2)
+    >>> pca.fit(X)
+    RandomizedPCA(copy=True, n_components=2, iterated_power=3, whiten=False)
+    >>> print pca.explained_variance_ratio_
+    [ 0.99244289  0.00755711]
+
+    See also
+    --------
+    PCA
+    ProbabilisticPCA
+    """
+
+    def __init__(self, n_components, copy=True, iterated_power=3,
+                 whiten=False):
+        self.n_components = n_components
+        self.copy = copy
+        self.iterated_power = iterated_power
+        self.whiten = whiten
+        self.mean_ = None
+
+    def fit(self, X, **params):
+        """Fit the model to the data X"""
+        self._set_params(**params)
+        n_samples = X.shape[0]
+
+        if self.copy:
+            X = X.copy()
+
+        if not hasattr(X, 'todense'):
+            # not a sparse matrix, ensure this is a 2D array
+            X = np.atleast_2d(X)
+
+        # Center data
+        self.mean_ = np.mean(X, axis=0)
+        X -= self.mean_
+
+        U, S, V = fast_svd(X, self.n_components, q=self.iterated_power)
+
+        self.explained_variance_ = (S ** 2) / n_samples
+        self.explained_variance_ratio_ = self.explained_variance_ / \
+                self.explained_variance_.sum()
+
+        if self.whiten:
+            n = X.shape[0]
+            self.components_ = np.dot(V.T, np.diag(1.0 / S)) * np.sqrt(n)
+        else:
+            self.components_ = V.T
+
+        return self
+
+    def transform(self, X):
+        """Apply the dimension reduction learned on the training data."""
+        if self.mean_ is not None:
+            X = X - self.mean_
+
+        X = safe_sparse_dot(X, self.components_)
+        return X
+
diff --git a/scikits/learn/tests/test_fastica.py b/scikits/learn/tests/test_fastica.py
index a3fad0acdb22e..6bf7871f7eed6 100644
--- a/scikits/learn/tests/test_fastica.py
+++ b/scikits/learn/tests/test_fastica.py
@@ -121,7 +121,7 @@ def test_non_square_fastica(add_noise=False):
 
     center_and_norm(m)
 
-    k_, mixing_, s_ = fastica.fastica(m, n_comp=2)
+    k_, mixing_, s_ = fastica.fastica(m, n_components=2)
 
     # Check that the mixing model described in the docstring holds:
     np.testing.assert_almost_equal(s_, np.dot(np.dot(mixing_, k_), m))
diff --git a/scikits/learn/tests/test_pca.py b/scikits/learn/tests/test_pca.py
index 1282213625caa..713e66746e165 100644
--- a/scikits/learn/tests/test_pca.py
+++ b/scikits/learn/tests/test_pca.py
@@ -3,17 +3,22 @@
 from nose.tools import assert_true
 from nose.tools import assert_equal
 
+from scipy.sparse import csr_matrix
 from numpy.testing import assert_almost_equal
 
 from .. import datasets
-from ..pca import PCA, ProbabilisticPCA, _assess_dimension_, _infer_dimension_
+from ..pca import PCA
+from ..pca import ProbabilisticPCA
+from ..pca import RandomizedPCA
+from ..pca import _assess_dimension_
+from ..pca import _infer_dimension_
 
 iris = datasets.load_iris()
 
 
 def test_pca():
     """PCA on dense arrays"""
-    pca = PCA(n_comp=2)
+    pca = PCA(n_components=2)
     X = iris.data
     X_r = pca.fit(X).transform(X)
     np.testing.assert_equal(X_r.shape[1], 2)
@@ -45,7 +50,7 @@ def test_whitening():
     assert_almost_equal(X.std(axis=0).std(), 43.9, 1)
 
     # whiten the data while projecting to the lower dim subspace
-    pca = PCA(n_comp=n_components, whiten=True).fit(X)
+    pca = PCA(n_components=n_components, whiten=True).fit(X)
     X_whitened = pca.transform(X)
     assert_equal(X_whitened.shape, (n_samples, n_components))
 
@@ -54,7 +59,7 @@ def test_whitening():
 
     # is possible to project on the low dim space without scaling by the
     # singular values
-    pca = PCA(n_comp=n_components, whiten=False).fit(X)
+    pca = PCA(n_components=n_components, whiten=False).fit(X)
     X_unwhitened = pca.transform(X)
     assert_equal(X_unwhitened.shape, (n_samples, n_components))
 
@@ -67,24 +72,39 @@ def test_pca_check_projection():
     n, p = 100, 3
     X = randn(n, p) * .1
     X[:10] += np.array([3, 4, 5])
-    pca = PCA(n_comp=2)
-    pca.fit(X)
-    Xt = 0.1* randn(1, p) + np.array([3, 4, 5])
-    Yt = pca.transform(Xt)
+    Xt = 0.1 * randn(1, p) + np.array([3, 4, 5])
+
+    Yt = PCA(n_components=2).fit(X).transform(Xt)
     Yt /= np.sqrt((Yt**2).sum())
+
     np.testing.assert_almost_equal(np.abs(Yt[0][0]), 1., 1)
 
 
-def test_fast_pca_check_projection():
-    """Test that the projection of data is correct"""
+def test_randomized_pca_check_projection():
+    """Test that the projection by RandomizedPCA on dense data is correct"""
     n, p = 100, 3
     X = randn(n, p) * .1
     X[:10] += np.array([3, 4, 5])
-    pca = PCA(n_comp=2, do_fast_svd=True)
-    pca.fit(X)
-    Xt = 0.1* randn(1, p) + np.array([3, 4, 5])
-    Yt = pca.transform(Xt)
+    Xt = 0.1 * randn(1, p) + np.array([3, 4, 5])
+
+    Yt = RandomizedPCA(n_components=2).fit(X).transform(Xt)
+    Yt /= np.sqrt((Yt ** 2).sum())
+
+    np.testing.assert_almost_equal(np.abs(Yt[0][0]), 1., 1)
+
+
+def test_sparse_randomized_pca_check_projection():
+    """Test that the projection by RandomizedPCA on sparse data is correct"""
+    n, p = 100, 3
+    X = randn(n, p) * .1
+    X[:10] += np.array([3, 4, 5])
+    X = csr_matrix(X)
+    Xt = 0.1 * randn(1, p) + np.array([3, 4, 5])
+    Xt = csr_matrix(Xt)
+
+    Yt = RandomizedPCA(n_components=2).fit(X).transform(Xt)
     Yt /= np.sqrt((Yt ** 2).sum())
+
     np.testing.assert_almost_equal(np.abs(Yt[0][0]), 1., 1)
 
@@ -93,9 +113,9 @@ def test_pca_dim():
     n, p = 100, 5
     X = randn(n, p) * .1
     X[:10] += np.array([3, 4, 5, 1, 2])
-    pca = PCA(n_comp='mle')
+    pca = PCA(n_components='mle')
     pca.fit(X)
-    assert_true(pca.n_comp == 1)
+    assert_true(pca.n_components == 1)
 
@@ -106,7 +126,7 @@ def test_infer_dim_1():
     n, p = 1000, 5
     X = randn(n, p) * .1 + randn(n, 1) * np.array([3, 4, 5, 1, 2]) \
         + np.array([1, 0, 7, 4, 6])
-    pca = PCA(n_comp=p)
+    pca = PCA(n_components=p)
     pca.fit(X)
     spect = pca.explained_variance_
     ll = []
@@ -125,7 +145,7 @@ def test_infer_dim_2():
     X = randn(n, p) * .1
     X[:10] += np.array([3, 4, 5, 1, 2])
     X[10:20] += np.array([6, 0, 7, 2, -1])
-    pca = PCA(n_comp=p)
+    pca = PCA(n_components=p)
     pca.fit(X)
     spect = pca.explained_variance_
     assert_true(_infer_dimension_(spect, n, p) > 1)
@@ -139,7 +159,7 @@ def test_infer_dim_3():
     X[:10] += np.array([3, 4, 5, 1, 2])
     X[10:20] += np.array([6, 0, 7, 2, -1])
     X[30:40] += 2*np.array([-1, 1, -1, 1, -1])
-    pca = PCA(n_comp=p)
+    pca = PCA(n_components=p)
     pca.fit(X)
     spect = pca.explained_variance_
     assert_true(_infer_dimension_(spect, n, p) > 2)
@@ -149,7 +169,7 @@ def test_probabilistic_pca_1():
     """Test that probabilistic PCA yields a reasonable score"""
     n, p = 1000, 3
     X = randn(n, p)*.1 + np.array([3, 4, 5])
-    ppca = ProbabilisticPCA(n_comp=2)
+    ppca = ProbabilisticPCA(n_components=2)
     ppca.fit(X)
     ll1 = ppca.score(X)
     h = 0.5 * np.log(2 * np.pi * np.exp(1) / 0.1**2) * p
@@ -160,7 +180,7 @@ def test_probabilistic_pca_2():
     """Test that probabilistic PCA correctly separated different datasets"""
     n, p = 100, 3
     X = randn(n, p) * .1 + np.array([3, 4, 5])
-    ppca = ProbabilisticPCA(n_comp=2)
+    ppca = ProbabilisticPCA(n_components=2)
     ppca.fit(X)
     ll1 = ppca.score(X)
     ll2 = ppca.score(randn(n, p) * .2 + np.array([3, 4, 5]))
@@ -173,7 +193,7 @@ def test_probabilistic_pca_3():
     """
     n, p = 100, 3
     X = randn(n, p)*.1 + np.array([3, 4, 5])
-    ppca = ProbabilisticPCA(n_comp=2)
+    ppca = ProbabilisticPCA(n_components=2)
     ppca.fit(X)
     ll1 = ppca.score(X)
     ppca.fit(X, False)
@@ -188,7 +208,7 @@ def test_probabilistic_pca_4():
     Xt = randn(n, p) + randn(n, 1)*np.array([3, 4, 5]) + np.array([1, 0, 7])
     ll = np.zeros(p)
     for k in range(p):
-        ppca = ProbabilisticPCA(n_comp=k)
+        ppca = ProbabilisticPCA(n_components=k)
         ppca.fit(Xl)
         ll[k] = ppca.score(Xt).mean()
 
diff --git a/scikits/learn/utils/_csgraph.py b/scikits/learn/utils/_csgraph.py
index e395c6b1d8ad7..c577fc35a39bd 100644
--- a/scikits/learn/utils/_csgraph.py
+++ b/scikits/learn/utils/_csgraph.py
@@ -32,7 +32,7 @@ def cs_graph_components(x):
 
     Returns
     --------
-    n_comp: int
+    n_components: int
         The number of connected components.
     label: ndarray (ints, 1 dimension):
         The label array of each connected component (-2 is used to
@@ -74,8 +74,8 @@ def cs_graph_components(x):
 
     label = np.empty((shape[0],), dtype=x.indptr.dtype)
 
-    n_comp = _cs_graph_components(shape[0], x.indptr, x.indices, label)
+    n_components = _cs_graph_components(shape[0], x.indptr, x.indices, label)
 
-    return n_comp, label
+    return n_components, label
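For reviewers, here is a minimal usage sketch of the renamed estimators touched by this patch. It is not part of the diff; it assumes the scikits.learn package layout and the PCA/RandomizedPCA signatures shown above, and it follows the patched doctests and the new test_pca tests (dense fit/transform, sparse input via scipy.sparse).

# Usage sketch (not part of the patch): illustrates the n_comp -> n_components
# rename and the new RandomizedPCA estimator, mirroring the updated doctests
# and tests in this diff.
import numpy as np
from scipy.sparse import csr_matrix

from scikits.learn.pca import PCA, RandomizedPCA

X = np.random.randn(100, 5)

# Exact PCA on a dense array, keyword renamed from n_comp to n_components.
pca = PCA(n_components=2, whiten=False).fit(X)
X_reduced = pca.transform(X)
print pca.explained_variance_ratio_

# Approximate PCA via randomized SVD; also accepts scipy.sparse input,
# as exercised by test_sparse_randomized_pca_check_projection above.
rpca = RandomizedPCA(n_components=2, iterated_power=3).fit(csr_matrix(X))
X_reduced_sparse = rpca.transform(csr_matrix(X))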