26 changes: 16 additions & 10 deletions examples/cluster/plot_coin_segmentation.py
@@ -14,12 +14,15 @@

* with 'kmeans' spectral clustering will cluster samples in the embedding space
using a kmeans algorithm
* with 'clusterQR' spectral clustering will cluster samples in the embedding space
using the clusterQR algorithm,
* whereas 'discretize' will iteratively search for the closest partition
space to the embedding space.
"""
print(__doc__)

# Author: Gael Varoquaux <gael.varoquaux@normalesup.org>, Brian Cheung
# Andrew Knyazev added clusterQR
# License: BSD 3 clause

import time
@@ -54,28 +54,31 @@
eps = 1e-6
graph.data = np.exp(-beta * graph.data / graph.data.std()) + eps

# Apply spectral clustering (this step goes much faster if you have pyamg
# installed)
N_REGIONS = 25
# the actual number of regions in this example is 27: background and 26 coins
N_REGIONS = 26

#############################################################################
# Visualize the resulting regions
# Compute and visualize the resulting regions

for assign_labels in ('kmeans', 'discretize'):
# it often helps the spectral clustering to compute a few extra eigenvectors
N_REGIONS_PLUS = 3

for assign_labels in ('kmeans', 'discretize', 'clusterQR'):
t0 = time.time()
labels = spectral_clustering(graph, n_clusters=N_REGIONS,
labels = spectral_clustering(graph, n_clusters=N_REGIONS + N_REGIONS_PLUS,
                             assign_labels=assign_labels, random_state=42)
t1 = time.time()
labels = labels.reshape(rescaled_coins.shape)

plt.figure(figsize=(5, 5))
plt.imshow(rescaled_coins, cmap=plt.cm.gray)
for l in range(N_REGIONS):
plt.contour(labels == l,
colors=[plt.cm.nipy_spectral(l / float(N_REGIONS))])
plt.imshow(rescaled_coins, cmap=plt.get_cmap('gray'))
plt.xticks(())
plt.yticks(())
title = 'Spectral clustering: %s, %.2fs' % (assign_labels, (t1 - t0))
print(title)
plt.title(title)
for l in range(N_REGIONS):
    plt.contour(labels == l,
                colors=[plt.cm.nipy_spectral((l + N_REGIONS_PLUS) /
                                             float(N_REGIONS + N_REGIONS_PLUS))])
plt.pause(0.5)
plt.show()
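The three strategies can also be compared outside the coin example. Below is a minimal sketch, assuming this branch is installed so that assign_labels='clusterQR' is accepted; the tiny two-block affinity matrix is purely illustrative:

import numpy as np

from sklearn.cluster import spectral_clustering

# toy affinity matrix: two dense blocks, weakly connected to each other
S = np.array([[1.0, 1.0, 1.0, 0.2, 0.0, 0.0],
              [1.0, 1.0, 1.0, 0.2, 0.0, 0.0],
              [1.0, 1.0, 1.0, 0.2, 0.0, 0.0],
              [0.2, 0.2, 0.2, 1.0, 1.0, 1.0],
              [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
              [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]])

for assign_labels in ('kmeans', 'discretize', 'clusterQR'):
    labels = spectral_clustering(S, n_clusters=2,
                                 assign_labels=assign_labels,
                                 random_state=42)
    print(assign_labels, labels)  # each block should receive one label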
51 changes: 45 additions & 6 deletions sklearn/cluster/spectral.py
@@ -4,11 +4,14 @@
# Author: Gael Varoquaux gael.varoquaux@normalesup.org
# Brian Cheung
# Wei LI <kuantkid@gmail.com>
# Modified by Andrew Knyazev to add clusterQR
# License: BSD 3 clause
import warnings

import numpy as np

from scipy.linalg import qr, svd

from ..base import BaseEstimator, ClusterMixin
from ..utils import check_random_state, as_float_array
from ..utils.validation import check_array
@@ -18,6 +21,40 @@
from .k_means_ import k_means


def clusterQR(vectors):
"""Search for a partition matrix (clustering) which is
closest to the eigenvector embedding.

Parameters
----------
vectors : array-like, shape: (n_samples, n_clusters)
The embedding space of the samples.

Returns
-------
labels : array of integers, shape: n_samples
The labels of the clusters.

References
----------
https://github.com/asdamle/QR-spectral-clustering
https://arxiv.org/abs/1708.07481

Notes
-----
The transpose-conjugate T.conj() keeps the code valid for complex-valued
vectors, in case of future use.

"""

k = vectors.shape[1]
# pivoted QR selects the k best-conditioned rows of the embedding
piv = qr(vectors.T.conj(), pivoting=True)[2]
piv = piv[0:k]
# a single SVD of the selected k-by-k block gives the orthogonal
# polar factor Ut @ Vt that aligns the clusters with the coordinate axes
Ut, _, Vt = svd(vectors[piv, :].T.conj())
vectors = abs(np.dot(vectors, np.dot(Ut, Vt)))
return vectors.argmax(axis=1)


def discretize(vectors, copy=True, max_svd_restarts=30, n_iter_max=20,
random_state=None):
"""Search for a partition matrix (clustering) which is closest to the
@@ -42,7 +79,6 @@ def discretize(vectors, copy=True, max_svd_restarts=30, n_iter_max=20,
Determines random number generation for rotation matrix initialization.
Use an int to make the randomness deterministic.
See :term:`Glossary <random_state>`.

Returns
-------
labels : array of integers, shape: n_samples
@@ -210,7 +246,7 @@ def spectral_clustering(affinity, n_clusters=8, n_components=None,
Stopping criterion for eigendecomposition of the Laplacian matrix
when using arpack eigen_solver.

assign_labels : {'kmeans', 'discretize'}, default: 'kmeans'
assign_labels : {'kmeans', 'discretize', 'clusterQR'}, default: 'kmeans'
The strategy to use to assign labels in the embedding
space. There are three ways to assign labels after the Laplacian
embedding. k-means can be applied and is a popular choice. But it can
@@ -247,10 +283,11 @@ def spectral_clustering(affinity, n_clusters=8, n_components=None,
This algorithm solves the normalized cut for k=2: it is a
normalized spectral clustering.
"""
if assign_labels not in ('kmeans', 'discretize'):
raise ValueError("The 'assign_labels' parameter should be "
"'kmeans' or 'discretize', but '%s' was given"
% assign_labels)
if assign_labels not in ('kmeans', 'discretize', 'clusterQR'):
    raise ValueError(
        "The 'assign_labels' parameter should be "
        "'kmeans', 'discretize', or 'clusterQR', but '%s' was given"
        % assign_labels)

random_state = check_random_state(random_state)
n_components = n_clusters if n_components is None else n_components
@@ -266,6 +303,8 @@
if assign_labels == 'kmeans':
_, labels, _ = k_means(maps, n_clusters, random_state=random_state,
n_init=n_init)
elif assign_labels == 'clusterQR':
labels = clusterQR(maps)
else:
labels = discretize(maps, random_state=random_state)

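To see what the new clusterQR helper computes in isolation, here is a small hand-run of the same steps, a sketch only: it calls scipy's qr and svd directly, and the embedding below is made up so that rows 0-2 and rows 3-5 form two obvious groups:

import numpy as np
from scipy.linalg import qr, svd

# synthetic two-cluster embedding, standing in for the spectral maps
vectors = np.array([[1.0, 0.1],
                    [0.9, 0.0],
                    [1.1, 0.1],
                    [0.0, 1.0],
                    [0.1, 0.9],
                    [0.0, 1.1]])

k = vectors.shape[1]
# pivoted QR picks the k most representative rows of the embedding
piv = qr(vectors.T.conj(), pivoting=True)[2][:k]
# the orthogonal polar factor Ut @ Vt aligns the clusters with the axes
Ut, _, Vt = svd(vectors[piv, :].T.conj())
labels = abs(np.dot(vectors, np.dot(Ut, Vt))).argmax(axis=1)
print(labels)  # the two row groups receive two distinct labels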
8 changes: 6 additions & 2 deletions sklearn/cluster/tests/test_spectral.py
@@ -30,7 +30,11 @@


@pytest.mark.parametrize('eigen_solver', ('arpack', 'lobpcg'))
@pytest.mark.parametrize('assign_labels', ('kmeans', 'discretize'))
@pytest.mark.parametrize(
'assign_labels',
('kmeans',
'discretize',
'clusterQR'))
def test_spectral_clustering(eigen_solver, assign_labels):
S = np.array([[1.0, 1.0, 1.0, 0.2, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 0.2, 0.0, 0.0, 0.0],
@@ -109,7 +113,7 @@ def test_affinities():
# on OSX and Linux
X, y = make_blobs(n_samples=20, random_state=0,
centers=[[1, 1], [-1, -1]], cluster_std=0.01
)
)
# nearest neighbors affinity
sp = SpectralClustering(n_clusters=2, affinity='nearest_neighbors',
random_state=0)
24 changes: 16 additions & 8 deletions sklearn/decomposition/pca.py
@@ -7,7 +7,7 @@
# Denis A. Engemann <denis-alexander.engemann@inria.fr>
# Michael Eickenberg <michael.eickenberg@inria.fr>
# Giorgio Patrini <giorgio.patrini@anu.edu.au>
#
# Andrew Knyazev added lobpcg
# License: BSD 3 clause

from math import log, sqrt
@@ -24,7 +24,7 @@
from .base import _BasePCA
from ..utils import check_random_state
from ..utils import check_array
from ..utils.extmath import fast_logdet, randomized_svd, svd_flip
from ..utils.extmath import (fast_logdet, randomized_svd, svd_flip,
                             lobpcg_svd)
from ..utils.extmath import stable_cumsum
from ..utils.validation import check_is_fitted

@@ -159,7 +159,7 @@ class PCA(_BasePCA):
improve the predictive accuracy of the downstream estimators by
making their data respect some hard-wired assumptions.

svd_solver : string {'auto', 'full', 'arpack', 'randomized'}
svd_solver : string {'auto', 'full', 'arpack', 'randomized', 'lobpcg'}
auto :
the solver is selected by a default policy based on `X.shape` and
`n_components`: if the input data is larger than 500x500 and the
@@ -185,7 +185,7 @@
.. versionadded:: 0.18.0

iterated_power : int >= 0, or 'auto', (default 'auto')
Number of iterations for the power method computed by
Number of iterations for svd_solver == 'lobpcg', or for the power method computed by
svd_solver == 'randomized'.

.. versionadded:: 0.18.0
@@ -194,7 +194,7 @@
If int, random_state is the seed used by the random number generator;
If RandomState instance, random_state is the random number generator;
If None, the random number generator is the RandomState instance used
by `np.random`. Used when ``svd_solver`` == 'arpack' or 'randomized'.
by `np.random`. Used when ``svd_solver`` == 'arpack', 'lobpcg', or 'randomized'.

.. versionadded:: 0.18.0

@@ -356,7 +356,7 @@ def fit_transform(self, X, y=None):
X_new : array-like, shape (n_samples, n_components)

"""
U, S, V = self._fit(X)
U, S, _ = self._fit(X)
U = U[:, :self.n_components_]

if self.whiten:
@@ -397,14 +397,15 @@ def _fit(self, X):
self._fit_svd_solver = 'full'
elif n_components >= 1 and n_components < .8 * min(X.shape):
self._fit_svd_solver = 'randomized'
# TODO: 'lobpcg' could also be selected here
# This is also the case of n_components in (0,1)
else:
self._fit_svd_solver = 'full'

# Call different fits for either full or truncated SVD
if self._fit_svd_solver == 'full':
return self._fit_full(X, n_components)
elif self._fit_svd_solver in ['arpack', 'randomized']:
elif self._fit_svd_solver in ['arpack', 'randomized', 'lobpcg']:
return self._fit_truncated(X, n_components, self._fit_svd_solver)
else:
raise ValueError("Unrecognized svd_solver='{0}'"
@@ -474,7 +475,7 @@ def _fit_full(self, X, n_components):
return U, S, V

def _fit_truncated(self, X, n_components, svd_solver):
"""Fit the model by computing truncated SVD (by ARPACK or randomized)
"""Fit the model by computing truncated SVD (by ARPACK, LOBPCG, or randomized)
on X
"""
n_samples, n_features = X.shape
@@ -524,6 +525,13 @@ def _fit_truncated(self, X, n_components, svd_solver):
flip_sign=True,
random_state=random_state)

elif svd_solver == 'lobpcg':
# sign flipping is done inside
U, S, V = lobpcg_svd(X, n_components=n_components,
n_iter=self.iterated_power,
flip_sign=True,
random_state=random_state)

self.n_samples_, self.n_features_ = n_samples, n_features
self.components_ = V
self.n_components_ = n_components
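Since the 'auto' policy in _fit never selects 'lobpcg' (see the TODO above), the new solver must be requested explicitly. A hedged usage sketch, assuming this branch is installed and that lobpcg_svd takes iterated_power as its iteration cap, per the docstring change above:

import numpy as np
from sklearn.decomposition import PCA

rng = np.random.RandomState(42)
X = rng.randn(500, 80)

pca = PCA(n_components=5, svd_solver='lobpcg', iterated_power=40,
          random_state=42)
X_new = pca.fit_transform(X)
print(X_new.shape)  # (500, 5)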
2 changes: 1 addition & 1 deletion sklearn/decomposition/tests/test_pca.py
@@ -23,7 +23,7 @@
from sklearn.decomposition.pca import _infer_dimension_

iris = datasets.load_iris()
solver_list = ['full', 'arpack', 'randomized', 'auto']
solver_list = ['full', 'arpack', 'randomized', 'lobpcg', 'auto']


def test_pca():
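With 'lobpcg' now in solver_list, a quick cross-solver sanity check mirroring these tests might look like the sketch below (this branch assumed installed; iris as in the test module):

from sklearn import datasets
from sklearn.decomposition import PCA

X = datasets.load_iris().data
for solver in ('full', 'arpack', 'randomized', 'lobpcg'):
    pca = PCA(n_components=2, svd_solver=solver, random_state=0).fit(X)
    # the solvers should agree closely on the captured variance
    print(solver, pca.explained_variance_ratio_.sum())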
26 changes: 21 additions & 5 deletions sklearn/decomposition/tests/test_truncated_svd.py
@@ -26,16 +26,22 @@
def test_algorithms():
svd_a = TruncatedSVD(30, algorithm="arpack")
svd_r = TruncatedSVD(30, algorithm="randomized", random_state=42)
svd_l = TruncatedSVD(30, algorithm="lobpcg", random_state=42)

Xa = svd_a.fit_transform(X)[:, :6]
Xr = svd_r.fit_transform(X)[:, :6]
Xl = svd_l.fit_transform(X)[:, :6]
assert_array_almost_equal(Xa, Xr, decimal=5)
assert_array_almost_equal(Xa, Xl, decimal=5)

comp_a = np.abs(svd_a.components_)
comp_r = np.abs(svd_r.components_)
comp_l = np.abs(svd_l.components_)
# All elements are equal, but some elements are more equal than others.
assert_array_almost_equal(comp_a[:9], comp_r[:9])
assert_array_almost_equal(comp_a[9:], comp_r[9:], decimal=2)
assert_array_almost_equal(comp_a[:9], comp_l[:9])
assert_array_almost_equal(comp_a[9:], comp_l[9:], decimal=2)


def test_attributes():
@@ -45,7 +51,7 @@ def test_attributes():
assert_equal(tsvd.components_.shape, (n_components, n_features))


@pytest.mark.parametrize('algorithm', ("arpack", "randomized"))
@pytest.mark.parametrize('algorithm', ("arpack", "randomized", "lobpcg"))
def test_too_many_components(algorithm):
for n_components in (n_features, n_features + 1):
tsvd = TruncatedSVD(n_components=n_components, algorithm=algorithm)
@@ -62,7 +68,7 @@ def test_sparse_formats(fmt):
assert_equal(Xtrans.shape, (n_samples, 11))


@pytest.mark.parametrize('algo', ("arpack", "randomized"))
@pytest.mark.parametrize('algo', ("arpack", "randomized", "lobpcg"))
def test_inverse_transform(algo):
# We need a lot of components for the reconstruction to be "almost
# equal" in all positions. XXX Test means or sums instead?
@@ -83,26 +89,36 @@ def test_explained_variance():
# Test sparse data
svd_a_10_sp = TruncatedSVD(10, algorithm="arpack")
svd_r_10_sp = TruncatedSVD(10, algorithm="randomized", random_state=42)
svd_l_10_sp = TruncatedSVD(10, algorithm="lobpcg", random_state=42)
svd_a_20_sp = TruncatedSVD(20, algorithm="arpack")
svd_r_20_sp = TruncatedSVD(20, algorithm="randomized", random_state=42)
svd_l_20_sp = TruncatedSVD(20, algorithm="lobpcg", random_state=42)
X_trans_a_10_sp = svd_a_10_sp.fit_transform(X)
X_trans_r_10_sp = svd_r_10_sp.fit_transform(X)
X_trans_l_10_sp = svd_l_10_sp.fit_transform(X)
X_trans_a_20_sp = svd_a_20_sp.fit_transform(X)
X_trans_r_20_sp = svd_r_20_sp.fit_transform(X)
X_trans_l_20_sp = svd_l_20_sp.fit_transform(X)

# Test dense data
svd_a_10_de = TruncatedSVD(10, algorithm="arpack")
svd_r_10_de = TruncatedSVD(10, algorithm="randomized", random_state=42)
svd_l_10_de = TruncatedSVD(10, algorithm="lobpcg", random_state=42)
svd_a_20_de = TruncatedSVD(20, algorithm="arpack")
svd_r_20_de = TruncatedSVD(20, algorithm="randomized", random_state=42)
svd_l_20_de = TruncatedSVD(20, algorithm="lobpcg", random_state=42)
X_trans_a_10_de = svd_a_10_de.fit_transform(X.toarray())
X_trans_r_10_de = svd_r_10_de.fit_transform(X.toarray())
X_trans_l_10_de = svd_l_10_de.fit_transform(X.toarray())
X_trans_a_20_de = svd_a_20_de.fit_transform(X.toarray())
X_trans_r_20_de = svd_r_20_de.fit_transform(X.toarray())
X_trans_l_20_de = svd_l_20_de.fit_transform(X.toarray())

# helper arrays for tests below
svds = (svd_a_10_sp, svd_r_10_sp, svd_a_20_sp, svd_r_20_sp, svd_a_10_de,
svd_r_10_de, svd_a_20_de, svd_r_20_de)
svds = (svd_a_10_sp, svd_r_10_sp, svd_l_10_sp,
svd_a_20_sp, svd_r_20_sp, svd_l_20_sp,
svd_a_10_de, svd_r_10_de, svd_l_10_de,
svd_a_20_de, svd_r_20_de, svd_l_20_de)
svds_trans = (
(svd_a_10_sp, X_trans_a_10_sp),
(svd_r_10_sp, X_trans_r_10_sp),
@@ -177,7 +193,7 @@ def test_singular_values():

apca = TruncatedSVD(n_components=2, algorithm='arpack',
random_state=rng).fit(X)
rpca = TruncatedSVD(n_components=2, algorithm='arpack',
rpca = TruncatedSVD(n_components=2, algorithm='randomized',
random_state=rng).fit(X)
assert_array_almost_equal(apca.singular_values_, rpca.singular_values_, 12)

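Finally, a hedged sketch of the new TruncatedSVD path on sparse input, matching the pattern of the tests above (assuming this branch is installed):

import numpy as np
import scipy.sparse as sp

from sklearn.decomposition import TruncatedSVD

rng = np.random.RandomState(42)
X = sp.random(60, 55, density=0.25, format='csr', random_state=rng)

svd = TruncatedSVD(n_components=10, algorithm='lobpcg', random_state=42)
X_reduced = svd.fit_transform(X)
print(X_reduced.shape)  # (60, 10)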