-
-
Notifications
You must be signed in to change notification settings - Fork 25.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MRG] Fast, low memory, single linkage implementation (#11514)
* First cut at basic single linkage internals * Refer to correct dist_metrics package * Add csgraph sparse implementation for single linkage * Add fast labelling/conversion from MST to single linkage tree; remove uneeded single_linkage.pyx file. * Ensure existing tests cover single linkage * Name cingle linkage labelling correctly. * Iterating toward correct solution. Still have to get n_clusters, compute_full_tree=False working * Get n_components correct. * Update docstrings. * Fix the parents array when we don't get the "full tree" * Add single linkage to agglomerative clustering example. * Add single linkage to digits agglomerative clustering example. * Update documentation to reflect the addition of single linkage. * Update documentation to reflect the addition of single linkage. * Pep8 fix for class declaration in cython * Fix heading in clustering docs * Update the digits clustering text to reflect the new reality. * Provide a more complete comparison of the different linkage methods, highlighting the relative strengths and weaknesses. * We don't need connectivity here, and we can ignore issues with warnings for spectral clustering. * Add an explicit test that single linkage successfully works on examples it should perform well on. * Update docs with a more complete comparison on linkage methods (scale to be determined?) * List formatting in example linkage comparison. * Flake8 fixes. * Flake8 fixes. * More Flake8 fixes. * Fix agglomerative plot example with correct subplot spec * Explicitly test linkages (including single) produce results identical to scipy.cluster.hierarchical * Fix comment on why we sort (consistency) * Make dense single linkage faster * Add docstring to new mst-linkage-core computations. * Add a test that new single linkage code matches scipy * Ensure we only attemtp this for metrics Jake implemented. * Per amueller; it's a long paper, ref the figure. * Clean up a few things. * Too many blank lines for flake8 * Bad scipy slink input * Flake8 fixes * Clean up cython a little; fix typo/carryover * Convert memoryview to numpy array on return * Just convert to the correct dtype * Update sklearn/cluster/_hierarchical.pyx Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com> * Update sklearn/cluster/_hierarchical.pyx Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com> * Update sklearn/cluster/_hierarchical.pyx Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com> * Update sklearn/cluster/tests/test_hierarchical.py Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com> * Fixes as per @NicolasHug suggestions. * Update renaming of params in test_hierarchical * Relative import? * Ah, it got renamed in master... * A bad merge on my part. * In principle this is in sklearn.neighbors now... * No; not that way... * Declare dim before use. * Update sklearn/cluster/tests/test_hierarchical.py Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com> * Remaining fixes per Nicolas Hug. * Update sklearn/cluster/tests/test_hierarchical.py Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com> * Fix flake8 issues. * Switch from stable to mergesort per jnotham * Update sklearn/cluster/_hierarchical.py Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com> * Skip checks that are already validated. * Update docstring per Gael's suggestion * Add a benchmark script for agglomerative clustering * Fix some flake8 issues * No flake8 on the one line * Update parameters and output for benchmark hierarchical * Switch to 2D plotting for hierarchical benchmark * Wrong colormap name * Formatting fpr bench hierarchical * Add an item to WhatsNew
- Loading branch information
Showing
6 changed files
with
229 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from collections import defaultdict | ||
from time import time | ||
|
||
import numpy as np | ||
from numpy import random as nr | ||
|
||
from sklearn.cluster import AgglomerativeClustering | ||
|
||
|
||
def compute_bench(samples_range, features_range): | ||
|
||
it = 0 | ||
results = defaultdict(lambda: []) | ||
|
||
max_it = len(samples_range) * len(features_range) | ||
for n_samples in samples_range: | ||
for n_features in features_range: | ||
it += 1 | ||
print('==============================') | ||
print('Iteration %03d of %03d' % (it, max_it)) | ||
print('n_samples %05d; n_features %02d' % (n_samples, n_features)) | ||
print('==============================') | ||
print() | ||
data = nr.randint(-50, 51, (n_samples, n_features)) | ||
|
||
for linkage in ("single", "average", "complete", "ward"): | ||
print(linkage.capitalize()) | ||
tstart = time() | ||
AgglomerativeClustering( | ||
linkage=linkage, | ||
n_clusters=10 | ||
).fit(data) | ||
|
||
delta = time() - tstart | ||
print("Speed: %0.3fs" % delta) | ||
print() | ||
|
||
results[linkage].append(delta) | ||
|
||
return results | ||
|
||
|
||
if __name__ == '__main__': | ||
import matplotlib.pyplot as plt | ||
|
||
samples_range = np.linspace(1000, 15000, 8).astype(np.int) | ||
features_range = np.array([2, 10, 20, 50]) | ||
|
||
results = compute_bench(samples_range, features_range) | ||
|
||
max_time = max([max(i) for i in [t for (label, t) in results.items()]]) | ||
|
||
colors = plt.get_cmap('tab10')(np.linspace(0, 1, 10))[:4] | ||
lines = {linkage: None for linkage in results.keys()} | ||
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True) | ||
fig.suptitle( | ||
'Scikit-learn agglomerative clustering benchmark results', | ||
fontsize=16 | ||
) | ||
for c, (label, timings) in zip(colors, | ||
sorted(results.items())): | ||
timing_by_samples = np.asarray(timings).reshape( | ||
samples_range.shape[0], | ||
features_range.shape[0] | ||
) | ||
|
||
for n in range(timing_by_samples.shape[1]): | ||
ax = axs.flatten()[n] | ||
lines[label], = ax.plot( | ||
samples_range, | ||
timing_by_samples[:, n], | ||
color=c, | ||
label=label | ||
) | ||
ax.set_title('n_features = %d' % features_range[n]) | ||
if n >= 2: | ||
ax.set_xlabel('n_samples') | ||
if n % 2 == 0: | ||
ax.set_ylabel('time (s)') | ||
|
||
fig.subplots_adjust(right=0.8) | ||
fig.legend([lines[link] for link in sorted(results.keys())], | ||
sorted(results.keys()), loc="center right", fontsize=8) | ||
|
||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters