kmeans2 '++' method is orders of magnitude slower than sklearn.cluster.KMeans() #11877

Closed
Gabriel-p opened this issue Apr 17, 2020 · 2 comments
Labels
defect A clear bug or issue that prevents SciPy from being installed or used as expected scipy.cluster

Gabriel-p commented Apr 17, 2020

The kmeans2 implementation with the '++' initialization method (minit='++') is orders of magnitude slower than the same algorithm in sklearn.cluster.KMeans(). This happens when n_clusters > 10 in the code below.

With a 'random' initialization scipy's method is faster, so the problem is with '++'.

Reproducing code example:

import time as t

import numpy as np
from scipy.cluster.vq import kmeans2
from sklearn.cluster import KMeans

N = 1000
data = np.random.uniform(0., 1., (N, 4))

# The slowdown shows up for n_clusters > 10
n_clusters = 50

# SciPy: k-means++ initialization, 10 iterations
s = t.perf_counter()
centroid, labels1 = kmeans2(
    data, n_clusters, minit='++', iter=10, thresh=1.e-4)
print("scipy:", t.perf_counter() - s)

# scikit-learn: k-means++ initialization, 10 restarts
s = t.perf_counter()
model = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10, tol=1.e-4,
               max_iter=1000, verbose=0)
model.fit(data)
labels2 = model.labels_
print("sklearn:", t.perf_counter() - s)
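
The two calls in the report use different settings (iter=10 for kmeans2; n_init=10 and max_iter=1000 for KMeans), so the totals are not directly comparable. One way to isolate the initialization cost might be to time kmeans2 alone with both minit options at a small and a large cluster count. This is only a sketch: the helper name time_minit and its parameters are hypothetical, and timings will vary by machine and SciPy version.

```python
import timeit

import numpy as np
from scipy.cluster.vq import kmeans2


def time_minit(data, k_values=(5, 50), repeat=3):
    """Return {(minit, k): best wall time in seconds} for kmeans2.

    Hypothetical helper for illustration; not part of SciPy.
    """
    results = {}
    for minit in ("random", "++"):
        for k in k_values:
            # Each measurement runs kmeans2 once; keep the best of `repeat` runs.
            timer = timeit.Timer(lambda: kmeans2(data, k, minit=minit, iter=10))
            results[(minit, k)] = min(timer.repeat(repeat=repeat, number=1))
    return results


if __name__ == "__main__":
    data = np.random.uniform(0.0, 1.0, (1000, 4))
    for (minit, k), seconds in time_minit(data).items():
        print(f"minit={minit!r:10} k={k:3d}  {seconds:.4f} s")
```

If the '++' timings grow much faster with k than the 'random' ones, that would point at the initialization step rather than the iterations.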

Scipy/Numpy/Python version information:

SciPy 1.4.1, NumPy 1.18.1, Python sys.version_info(major=3, minor=7, micro=6, releaselevel='final', serial=0)

@miladsade96 miladsade96 added defect A clear bug or issue that prevents SciPy from being installed or used as expected scipy.cluster labels Apr 17, 2020
Contributor

jjerphan commented Apr 30, 2020

Hello,
I am interested in contributing to SciPy. Is this a good first issue to pick up?

jjerphan added a commit to jjerphan/scipy that referenced this issue Apr 30, 2020
The squared distances were computed in Python in a suboptimal way with a
complexity of O(n²). It took the overall running time of the function.

They are now computed using cdist, making the overall algorithm fast.

Relates to: scipy#11877
jjerphan added a commit to jjerphan/scipy that referenced this issue May 1, 2020
rgommers pushed a commit that referenced this issue May 2, 2020
The squared distances were computed in Python in a suboptimal way with a
complexity of O(n²). It took the overall running time of the function.

They are now computed using cdist, making the overall algorithm fast.

Relates to: #11877
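
For readers who want a picture of what the commit message describes, here is a minimal sketch of k-means++ seeding in which the squared distances come from cdist rather than a Python-level loop. It is an illustration written under assumptions, not the code merged in SciPy: the function name kmeanspp_init, its rng parameter, and the incremental minimum update are all choices made for this example.

```python
import numpy as np
from scipy.spatial.distance import cdist


def kmeanspp_init(data, k, rng=None):
    """Pick k initial centroids from `data` using k-means++ seeding.

    Illustrative sketch only; names and structure are assumptions.
    """
    rng = np.random.default_rng(rng)
    n = data.shape[0]
    centroids = np.empty((k, data.shape[1]), dtype=float)

    # First centroid: one data point chosen uniformly at random.
    centroids[0] = data[rng.integers(n)]
    # Squared distance from every point to its nearest chosen centroid.
    closest_sq = cdist(data, centroids[:1], metric="sqeuclidean").ravel()

    for i in range(1, k):
        total = closest_sq.sum()
        if total > 0:
            # Sample the next centroid with probability proportional to D(x)^2.
            idx = rng.choice(n, p=closest_sq / total)
        else:
            # All points coincide with a chosen centroid; fall back to uniform.
            idx = rng.integers(n)
        centroids[i] = data[idx]
        # One vectorized cdist call per new centroid, then keep the minimum.
        new_sq = cdist(data, centroids[i:i + 1], metric="sqeuclidean").ravel()
        np.minimum(closest_sq, new_sq, out=closest_sq)

    return centroids
```

Each new centroid costs one cdist call over the n points, so seeding k centroids does O(n·k) distance evaluations in compiled code.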
Member

rgommers commented May 2, 2020

Fixed by gh-11982. Thanks @jjerphan, and thanks for reporting @Gabriel-p
