diff --git a/sklearn/cluster/__init__.py b/sklearn/cluster/__init__.py index 16f43d374e510..6accb53e2dcad 100644 --- a/sklearn/cluster/__init__.py +++ b/sklearn/cluster/__init__.py @@ -3,7 +3,8 @@ algorithms. """ -from .spectral import spectral_clustering, SpectralClustering +from .spectral import (spectral_clustering, SpectralClustering, + diffusion_clustering) from .mean_shift_ import (mean_shift, MeanShift, estimate_bandwidth, get_bin_seeds) from .affinity_propagation_ import affinity_propagation, AffinityPropagation @@ -34,6 +35,7 @@ 'linkage_tree', 'mean_shift', 'spectral_clustering', + 'diffusion_clustering', 'ward_tree', 'SpectralBiclustering', 'SpectralCoclustering'] diff --git a/sklearn/cluster/spectral.py b/sklearn/cluster/spectral.py index 39475929ebd0a..a155797758c9d 100644 --- a/sklearn/cluster/spectral.py +++ b/sklearn/cluster/spectral.py @@ -15,7 +15,7 @@ from ..utils.extmath import norm from ..metrics.pairwise import pairwise_kernels from ..neighbors import kneighbors_graph -from ..manifold import spectral_embedding +from ..manifold import spectral_embedding, diffusion_embedding from .k_means_ import k_means @@ -453,3 +453,114 @@ def fit(self, X, y=None): @property def _pairwise(self): return self.affinity == "precomputed" + + +def diffusion_clustering(affinity, n_clusters=8, n_components=None, + eigen_solver=None, random_state=None, n_init=10, + eigen_tol=0.0, assign_labels='kmeans', + diffusion_time=0): + """Apply clustering to a projection to the normalized laplacian. + + In practice Spectral Clustering is very useful when the structure of + the individual clusters is highly non-convex or more generally when + a measure of the center and spread of the cluster is not a suitable + description of the complete cluster. For instance when clusters are + nested circles on the 2D plan. + + If affinity is the adjacency matrix of a graph, this method can be + used to find normalized graph cuts. 
+ + Parameters + ----------- + affinity: array-like or sparse matrix, shape: (n_samples, n_samples) + The affinity matrix describing the relationship of the samples to + embed. **Must be symmetric**. + + Possible examples: + - adjacency matrix of a graph, + - heat kernel of the pairwise distance matrix of the samples, + - symmetric k-nearest neighbours connectivity matrix of the samples. + + n_clusters: integer, optional + Number of clusters to extract. + + n_components: integer, optional, default is n_clusters + Number of eigen vectors to use for the spectral embedding + + eigen_solver: {None, 'arpack', 'lobpcg', or 'amg'} + The eigenvalue decomposition strategy to use. AMG requires pyamg + to be installed. It can be faster on very large, sparse problems, + but may also lead to instabilities. + + random_state: int seed, RandomState instance, or None (default) + A pseudo random number generator used for the initialization + of the lobpcg eigen vectors decomposition when eigen_solver == 'amg' + and by the K-Means initialization. + + n_init: int, optional, default: 10 + Number of times the k-means algorithm will be run with different + centroid seeds. The final results will be the best output of + n_init consecutive runs in terms of inertia. + + eigen_tol : float, optional, default: 0.0 + Stopping criterion for eigendecomposition of the Laplacian matrix + when using arpack eigen_solver. + + assign_labels : {'kmeans', 'discretize'}, default: 'kmeans' + The strategy to use to assign labels in the embedding + space. There are two ways to assign labels after the laplacian + embedding. k-means can be applied and is a popular choice. But it can + also be sensitive to initialization. Discretization is another + approach which is less sensitive to random initialization. See + the 'Multiclass spectral clustering' paper referenced below for + more details on the discretization approach. + + Returns + ------- + labels: array of integers, shape: n_samples + The labels of the clusters. 
+ + References + ---------- + + - Normalized cuts and image segmentation, 2000 + Jianbo Shi, Jitendra Malik + http://citeseer.ist.psu.edu/viewdoc/summary?doi=10.1.1.160.2324 + + - A Tutorial on Spectral Clustering, 2007 + Ulrike von Luxburg + http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.165.9323 + + - Multiclass spectral clustering, 2003 + Stella X. Yu, Jianbo Shi + http://www1.icsi.berkeley.edu/~stellayu/publication/doc/2003kwayICCV.pdf + + Notes + ------ + The graph should contain only one connect component, elsewhere + the results make little sense. + + This algorithm solves the normalized cut for k=2: it is a + normalized spectral clustering. + """ + if not assign_labels in ('kmeans', 'discretize'): + raise ValueError("The 'assign_labels' parameter should be " + "'kmeans' or 'discretize', but '%s' was given" + % assign_labels) + + random_state = check_random_state(random_state) + n_components = n_clusters if n_components is None else n_components + maps = diffusion_embedding(affinity, + n_components=n_components, + eigen_solver=eigen_solver, + random_state=random_state, + eigen_tol=eigen_tol, + diffusion_time=diffusion_time) + + if assign_labels == 'kmeans': + _, labels, _ = k_means(maps, n_clusters, random_state=random_state, + n_init=n_init) + else: + labels = discretize(maps, random_state=random_state) + + return labels diff --git a/sklearn/manifold/__init__.py b/sklearn/manifold/__init__.py index c16d87881a1ad..fe0254020cbfe 100644 --- a/sklearn/manifold/__init__.py +++ b/sklearn/manifold/__init__.py @@ -5,8 +5,10 @@ from .locally_linear import locally_linear_embedding, LocallyLinearEmbedding from .isomap import Isomap from .mds import MDS -from .spectral_embedding_ import SpectralEmbedding, spectral_embedding from .t_sne import TSNE +from .spectral_embedding_ import (SpectralEmbedding, spectral_embedding, + diffusion_embedding, DiffusionEmbedding) __all__ = ['locally_linear_embedding', 'LocallyLinearEmbedding', 'Isomap', - 'MDS', 
'SpectralEmbedding', 'spectral_embedding', "TSNE"] + 'MDS', 'SpectralEmbedding', 'spectral_embedding', + 'diffusion_embedding', 'DiffusionEmbedding', 'TSNE'] diff --git a/sklearn/manifold/spectral_embedding_.py b/sklearn/manifold/spectral_embedding_.py index 277ce0c48c9f3..a0ded36b1d0f8 100644 --- a/sklearn/manifold/spectral_embedding_.py +++ b/sklearn/manifold/spectral_embedding_.py @@ -17,7 +17,7 @@ from ..utils.extmath import _deterministic_vector_sign_flip from ..utils.graph import graph_laplacian from ..utils.sparsetools import connected_components -from ..utils.arpack import eigsh +from ..utils.arpack import eigsh, eigs from ..metrics.pairwise import rbf_kernel from ..neighbors import kneighbors_graph @@ -118,9 +118,9 @@ def _set_diag(laplacian, value): return laplacian -def spectral_embedding(adjacency, n_components=8, eigen_solver=None, - random_state=None, eigen_tol=0.0, - norm_laplacian=True, drop_first=True): +def _solve_eigenvalue_problem(adjacency, n_components=1, eigen_solver=None, + random_state=None, eigen_tol=0.0, + norm_laplacian=True, mode=None): """Project the sample on the first eigen vectors of the graph Laplacian. The adjacency matrix is used to compute a normalized graph Laplacian @@ -170,8 +170,12 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, Returns ------- - embedding : array, shape=(n_samples, n_components) - The reduced samples. 
+ lambdas : array, shape=(n_components,) + The eigenvalues of the Laplacian + vectors : array, shape=(n_components, n_samples) + The eigenvectors of the Laplacian + degrees : array, shape=(n_samples,) + The degrees of the graph Notes ----- @@ -207,9 +211,23 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, random_state = check_random_state(random_state) n_nodes = adjacency.shape[0] + ''' +<<<<<<< HEAD + # Check that the matrices given is symmetric + if ((not sparse.isspmatrix(adjacency) and + not np.all((adjacency - adjacency.T) < 1e-10)) or + (sparse.isspmatrix(adjacency) and + not np.all((adjacency - adjacency.T).data < 1e-10))): + warnings.warn("Graph adjacency matrix should be symmetric. " + "Converted to be symmetric by average with its " + "transpose.") + adjacency = .5 * (adjacency + adjacency.T) +======= # Whether to drop the first eigenvector if drop_first: n_components = n_components + 1 +>>>>>>> upstream/master + ''' if not _graph_is_connected(adjacency): warnings.warn("Graph is not fully connected, spectral embedding" @@ -217,6 +235,7 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, laplacian, dd = graph_laplacian(adjacency, normed=norm_laplacian, return_diag=True) + if (eigen_solver == 'arpack' or eigen_solver != 'lobpcg' and (not sparse.isspmatrix(laplacian) @@ -250,7 +269,8 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, lambdas, diffusion_map = eigsh(laplacian, k=n_components, sigma=1.0, which='LM', tol=eigen_tol) - embedding = diffusion_map.T[n_components::-1] * dd + vectors = diffusion_map.T[n_components::-1] + lambdas = lambdas[n_components::-1] except RuntimeError: # When submatrices are exactly singular, an LU decomposition # in arpack fails. 
We fallback to lobpcg @@ -271,8 +291,8 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, X[:, 0] = dd.ravel() lambdas, diffusion_map = lobpcg(laplacian, X, M=M, tol=1.e-12, largest=False) - embedding = diffusion_map.T * dd - if embedding.shape[0] == 1: + vectors = diffusion_map.T + if vectors.shape[0] == 1: raise ValueError elif eigen_solver == "lobpcg": @@ -282,9 +302,9 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, # number of nodes # lobpcg will fallback to eigh, so we short circuit it if sparse.isspmatrix(laplacian): - laplacian = laplacian.toarray() + laplacian = laplacian.todense() lambdas, diffusion_map = eigh(laplacian) - embedding = diffusion_map.T[:n_components] * dd + vectors = diffusion_map.T[:n_components] else: # lobpcg needs native floats laplacian = laplacian.astype(np.float) @@ -295,15 +315,98 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, X[:, 0] = dd.ravel() lambdas, diffusion_map = lobpcg(laplacian, X, tol=1e-15, largest=False, maxiter=2000) - embedding = diffusion_map.T[:n_components] * dd - if embedding.shape[0] == 1: + vectors = diffusion_map.T[:n_components] + if vectors.shape[0] == 1: raise ValueError + return lambdas[:n_components], vectors, dd + + +def spectral_embedding(adjacency, n_components=8, eigen_solver=None, + random_state=None, eigen_tol=0.0, + norm_laplacian=True, drop_first=True, + mode=None, diffusion_time=None): + """Project the sample on the first eigen vectors of the graph Laplacian + + The adjacency matrix is used to compute a normalized graph Laplacian + whose spectrum (especially the eigen vectors associated to the + smallest eigen values) has an interpretation in terms of minimal + number of cuts necessary to split the graph into comparably sized + components. 
+ + This embedding can also 'work' even if the ``adjacency`` variable is + not strictly the adjacency matrix of a graph but more generally + an affinity or similarity matrix between samples (for instance the + heat kernel of a euclidean distance matrix or a k-NN matrix). + + However care must be taken to always make the affinity matrix symmetric + so that the eigen vector decomposition works as expected. - embedding = _deterministic_vector_sign_flip(embedding) + Parameters + ---------- + adjacency : array-like or sparse matrix, shape: (n_samples, n_samples) + The adjacency matrix of the graph to embed. + + n_components : integer, optional + The dimension of the projection subspace. + + eigen_solver : {None, 'arpack', 'lobpcg', or 'amg'} + The eigenvalue decomposition strategy to use. AMG requires pyamg + to be installed. It can be faster on very large, sparse problems, + but may also lead to instabilities. + + random_state : int seed, RandomState instance, or None (default) + A pseudo random number generator used for the initialization of the + lobpcg eigen vectors decomposition when eigen_solver == 'amg'. + By default, arpack is used. + + eigen_tol : float, optional, default=0.0 + Stopping criterion for eigendecomposition of the Laplacian matrix + when using arpack eigen_solver. + + drop_first : bool, optional, default=True + Whether to drop the first eigenvector. For spectral embedding, this + should be True as the first eigenvector should be a constant vector for + a connected graph, but for spectral clustering, this should be kept as + False to retain the first eigenvector. + + diffusion_time: float, optional, default=None + Determines the scaling of the eigenvalues of the Laplacian + + Returns + ------- + embedding : array, shape=(n_samples, n_components) + The reduced samples. + + Notes + ----- + Spectral embedding is most useful when the graph has one connected + component. 
If there graph has many components, the first few eigenvectors + will simply uncover the connected components of the graph. + + References + ---------- + * http://en.wikipedia.org/wiki/LOBPCG + + * Toward the Optimal Preconditioned Eigensolver: Locally Optimal + Block Preconditioned Conjugate Gradient Method + Andrew V. Knyazev + http://dx.doi.org/10.1137%2FS1064827500366124 + """ + # Whether to drop the first eigenvector + if drop_first: + n_components += 1 + _, vectors, dd = _solve_eigenvalue_problem(adjacency=adjacency, + n_components=n_components, + eigen_solver=eigen_solver, + random_state=random_state, + eigen_tol=eigen_tol, + norm_laplacian=norm_laplacian, + mode=mode) + embedding = _deterministic_vector_sign_flip(vectors * dd) if drop_first: - return embedding[1:n_components].T + return embedding[1:].T else: - return embedding[:n_components].T + return embedding.T class SpectralEmbedding(BaseEstimator): @@ -482,3 +585,208 @@ def fit_transform(self, X, y=None): """ self.fit(X) return self.embedding_ + + +def diffusion_embedding(adjacency, n_components=8, diffusion_time=None, + eigen_solver=None, random_state=None, eigen_tol=0.0, + norm_laplacian=True): + """Project the sample on the first eigenvectors of the graph Laplacian + + The adjacency matrix is used to compute a normalized graph Laplacian + whose spectrum (especially the eigen vectors associated to the + smallest eigen values) has an interpretation in terms of minimal + number of cuts necessary to split the graph into comparably sized + components. + + This embedding can also 'work' even if the ``adjacency`` variable is + not strictly the adjacency matrix of a graph but more generally + an affinity or similarity matrix between samples (for instance the + heat kernel of a euclidean distance matrix or a k-NN matrix). + + However care must taken to always make the affinity matrix symmetric + so that the eigen vector decomposition works as expected. 
+ + Parameters + ---------- + adjacency : array-like or sparse matrix, shape: (n_samples, n_samples) + The adjacency matrix of the graph to embed. + + n_components : integer, optional + The dimension of the projection subspace. + + diffusion_time: float, optional, default=None + Determines the scaling of the eigenvalues of the Laplacian + + eigen_solver : {None, 'arpack', 'lobpcg', or 'amg'} + The eigenvalue decomposition strategy to use. AMG requires pyamg + to be installed. It can be faster on very large, sparse problems, + but may also lead to instabilities. + + random_state : int seed, RandomState instance, or None (default) + A pseudo random number generator used for the initialization of the + lobpcg eigen vectors decomposition when eigen_solver == 'amg'. + By default, arpack is used. + + eigen_tol : float, optional, default=0.0 + Stopping criterion for eigendecomposition of the Laplacian matrix + when using arpack eigen_solver. + + Returns + ------- + embedding : array, shape=(n_samples, n_components) + The reduced samples. + + Notes + ----- + Diffusion embedding is most useful when the graph has one connected + component. If the graph has many components, the first few eigenvectors + will simply uncover the connected components of the graph. + + References + ---------- + + - Lafon, Stephane, and Ann B. Lee. "Diffusion maps and coarse-graining: A + unified framework for dimensionality reduction, graph partitioning, and + data set parameterization." Pattern Analysis and Machine Intelligence, + IEEE Transactions on 28.9 (2006): 1393-1403. + - Coifman, Ronald R., and Stephane Lafon. Diffusion maps. Applied and + Computational Harmonic Analysis 21.1 (2006): 5-30. 
+ + """ + + if not _graph_is_connected(adjacency): + warnings.warn("Graph is not fully connected, spectral embedding" + " may not work as expected.") + K = sparse.csr_matrix(adjacency) + ndim = K.shape[0] + v = np.array(np.sqrt(K.sum(axis=1))).flatten() + A = K.copy() + del K + A.data /= v[A.indices] + A = sparse.csr_matrix(A.transpose().toarray()) + A.data /= v[A.indices] + A = sparse.csr_matrix(A.transpose().toarray()) + + func = eigs + if n_components is not None: + lambdas, vectors = func(A, k=n_components + 1) + else: + lambdas, vectors = func(A, k=max(2, int(np.sqrt(ndim)))) + del A + + if func == eigsh: + lambdas = lambdas[::-1] + vectors = vectors[:, ::-1] + else: + lambdas = np.real(lambdas) + vectors = np.real(vectors) + lambda_idx = np.argsort(lambdas)[::-1] + lambdas = lambdas[lambda_idx] + vectors = vectors[:, lambda_idx] + + psi = vectors/vectors[:, [0]] + if diffusion_time <= 0: + lambdas = lambdas[1:] / (1 - lambdas[1:]) + else: + lambdas = lambdas[1:] ** float(diffusion_time) + lambda_ratio = lambdas/lambdas[0] + threshold = max(0.05, lambda_ratio[-1]) + + n_components_auto = np.amax(np.nonzero(lambda_ratio > threshold)[0]) + n_components_auto = min(n_components_auto, ndim) + if n_components is None: + n_components = n_components_auto + embedding = psi[:, 1:(n_components + 1)] * lambdas[:n_components][None, :] + return embedding + + +class DiffusionEmbedding(SpectralEmbedding): + """Diffusion embedding for nonlinear dimensionality reduction + + Diffusion embedding adds an extra parameter `diffusion_time` to spectral + embedding. + + Parameters + ----------- + diffusion_time : float + Determines the scaling of the eigenvalues of the Laplacian + + For all other parameters see `SpectralEmbedding` + + References + ---------- + + - Lafon, Stephane, and Ann B. Lee. "Diffusion maps and coarse-graining: A + unified framework for dimensionality reduction, graph partitioning, and + data set parameterization." 
Pattern Analysis and Machine Intelligence, + IEEE Transactions on 28.9 (2006): 1393-1403. + - Coifman, Ronald R., and Stephane Lafon. Diffusion maps. Applied and + Computational Harmonic Analysis 21.1 (2006): 5-30. + + """ + + def __init__(self, diffusion_time=0., n_components=None, + affinity="nearest_neighbors", gamma=None, random_state=None, + eigen_solver=None, n_neighbors=None): + super(DiffusionEmbedding, self).__init__(n_components=n_components, + affinity=affinity, + gamma=gamma, + random_state=random_state, + eigen_solver=eigen_solver, + n_neighbors=n_neighbors) + self.diffusion_time = diffusion_time + + def fit(self, X, y=None): + """Fit the model from data in X. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Training vector, where n_samples in the number of samples + and n_features is the number of features. + + If affinity is "precomputed" + X : array-like, shape (n_samples, n_samples), + Interpret X as precomputed adjacency graph computed from + samples. + + Returns + ------- + self : object + Returns the instance itself. + """ + random_state = check_random_state(self.random_state) + if isinstance(self.affinity, basestring): + if self.affinity not in set(("nearest_neighbors", "rbf", "cauchy", + "precomputed", "markov")): + raise ValueError(("%s is not a valid affinity. Expected 'markov', " + "'precomputed', 'rbf', 'nearest_neighbors' " + "or a callable.") % self.affinity) + elif not hasattr(self.affinity, "__call__"): + raise ValueError(("'affinity' is expected to be an an affinity " + "name or a callable. 
Got: %s") % self.affinity) + + from ..decomposition import RandomizedPCA + pca = RandomizedPCA(n_components=self.n_components, + random_state=random_state) + #X = pca.fit_transform(X) + eps = self.gamma + if self.affinity in ['markov', 'cauchy']: + from ..metrics import pairwise_distances + D = pairwise_distances(X, metric='euclidean') #, squared=True) + if eps is None: + k = int(max(2, np.round(D.shape[0] * 0.01))) + eps = 2 * np.median(np.sort(D, axis=0)[k+1, :])**2 + if self.affinity == 'markov': + affinity_matrix = np.exp(-(D * D) / eps) + elif self.affinity == 'cauchy': + affinity_matrix = 1./(D * D + eps) + else: + affinity_matrix = self._get_affinity_matrix(X) + self.eps_ = eps + self.embedding_ = diffusion_embedding(affinity_matrix, + n_components=self.n_components, + eigen_solver=self.eigen_solver, + random_state=random_state, + diffusion_time=self.diffusion_time) + return self