From 5f7e6de69e425cfdfeeb82cf692b81fe22589dd5 Mon Sep 17 00:00:00 2001 From: huntzhan Date: Fri, 23 Aug 2019 11:53:31 +0800 Subject: [PATCH 1/6] Try csr support. --- sklearn/cluster/optics_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/cluster/optics_.py b/sklearn/cluster/optics_.py index ecf5fa6a2bcc0..732c7cfb352fc 100755 --- a/sklearn/cluster/optics_.py +++ b/sklearn/cluster/optics_.py @@ -233,7 +233,7 @@ def fit(self, X, y=None): self : instance of OPTICS The instance. """ - X = check_array(X, dtype=np.float) + X = check_array(X, accept_sparse='csr') if self.cluster_method not in ['dbscan', 'xi']: raise ValueError("cluster_method should be one of" From 18425aaffc785024d8baaa2bb31e4fd52427eed8 Mon Sep 17 00:00:00 2001 From: huntzhan Date: Fri, 23 Aug 2019 16:15:39 +0800 Subject: [PATCH 2/6] Change the default metric of OPTICS to euclidean. --- sklearn/cluster/optics_.py | 4 ++-- sklearn/cluster/tests/test_optics.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/cluster/optics_.py b/sklearn/cluster/optics_.py index 732c7cfb352fc..26734da87d778 100755 --- a/sklearn/cluster/optics_.py +++ b/sklearn/cluster/optics_.py @@ -194,7 +194,7 @@ class OPTICS(BaseEstimator, ClusterMixin): the Conference "Lernen, Wissen, Daten, Analysen" (LWDA) (2018): 318-329. """ - def __init__(self, min_samples=5, max_eps=np.inf, metric='minkowski', p=2, + def __init__(self, min_samples=5, max_eps=np.inf, metric='euclidean', p=2, metric_params=None, cluster_method='xi', eps=None, xi=0.05, predecessor_correction=True, min_cluster_size=None, algorithm='auto', leaf_size=30, n_jobs=None): @@ -517,7 +517,7 @@ def _set_reach_dist(core_distances_, reachability_, predecessor_, # the same logic as neighbors, p is ignored if explicitly set # in the dict params _params['p'] = p - dists = pairwise_distances(P, np.take(X, unproc, axis=0), + dists = pairwise_distances(P, X[unproc], metric, n_jobs=None, **_params).ravel() diff --git a/sklearn/cluster/tests/test_optics.py b/sklearn/cluster/tests/test_optics.py index 5ae8b3f898fcf..b7e46fe18e921 100644 --- a/sklearn/cluster/tests/test_optics.py +++ b/sklearn/cluster/tests/test_optics.py @@ -352,7 +352,7 @@ def test_compare_to_ELKI(): # Tests against known extraction array # Does NOT work with metric='euclidean', because sklearn euclidean has # worse numeric precision. 'minkowski' is slower but more accurate. - clust1 = OPTICS(min_samples=5).fit(X) + clust1 = OPTICS(metric='minkowski', min_samples=5).fit(X) assert_array_equal(clust1.ordering_, np.array(o1)) assert_array_equal(clust1.predecessor_[clust1.ordering_], np.array(p1)) @@ -386,7 +386,7 @@ def test_compare_to_ELKI(): 11, 19, 15, 10, 47, -1, 20, 22, 25, 25, 25, 25, 22, 22, 23, -1, 30, 30, 34, 34, 34, 32, 32, 37, 38, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] - clust2 = OPTICS(min_samples=5, max_eps=0.5).fit(X) + clust2 = OPTICS(metric='minkowski', min_samples=5, max_eps=0.5).fit(X) assert_array_equal(clust2.ordering_, np.array(o2)) assert_array_equal(clust2.predecessor_[clust2.ordering_], np.array(p2)) From f712b46baa1ad3ca25d52f85f01162e4e613f1a7 Mon Sep 17 00:00:00 2001 From: huntzhan Date: Mon, 26 Aug 2019 11:14:49 +0800 Subject: [PATCH 3/6] Retain the default metric minkowski. --- sklearn/cluster/optics_.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sklearn/cluster/optics_.py b/sklearn/cluster/optics_.py index 26734da87d778..d3faed82d8650 100755 --- a/sklearn/cluster/optics_.py +++ b/sklearn/cluster/optics_.py @@ -19,6 +19,7 @@ from ..neighbors import NearestNeighbors from ..base import BaseEstimator, ClusterMixin from ..metrics import pairwise_distances +from ..metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS class OPTICS(BaseEstimator, ClusterMixin): @@ -194,7 +195,7 @@ class OPTICS(BaseEstimator, ClusterMixin): the Conference "Lernen, Wissen, Daten, Analysen" (LWDA) (2018): 318-329. """ - def __init__(self, min_samples=5, max_eps=np.inf, metric='euclidean', p=2, + def __init__(self, min_samples=5, max_eps=np.inf, metric='minkowski', p=2, metric_params=None, cluster_method='xi', eps=None, xi=0.05, predecessor_correction=True, min_cluster_size=None, algorithm='auto', leaf_size=30, n_jobs=None): @@ -222,7 +223,8 @@ def fit(self, X, y=None): Parameters ---------- X : array, shape (n_samples, n_features), or (n_samples, n_samples) \ -if metric=’precomputed’. +if metric=’precomputed’, or sparse matrix \ + if metric in ['cityblock', 'cosine', 'euclidean', 'haversine', 'l2', 'l1', 'manhattan']. A feature array, or array of distances between samples if metric='precomputed'. @@ -233,7 +235,10 @@ def fit(self, X, y=None): self : instance of OPTICS The instance. """ - X = check_array(X, accept_sparse='csr') + if self.metric in PAIRWISE_DISTANCE_FUNCTIONS: + X = check_array(X, accept_sparse='csr') + else: + X = check_array(X) if self.cluster_method not in ['dbscan', 'xi']: raise ValueError("cluster_method should be one of" From 876bb2935599e37d5c827a83fbcee12829101e6d Mon Sep 17 00:00:00 2001 From: huntzhan Date: Mon, 26 Aug 2019 11:17:14 +0800 Subject: [PATCH 4/6] Undo tests. --- sklearn/cluster/tests/test_optics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/cluster/tests/test_optics.py b/sklearn/cluster/tests/test_optics.py index b7e46fe18e921..5ae8b3f898fcf 100644 --- a/sklearn/cluster/tests/test_optics.py +++ b/sklearn/cluster/tests/test_optics.py @@ -352,7 +352,7 @@ def test_compare_to_ELKI(): # Tests against known extraction array # Does NOT work with metric='euclidean', because sklearn euclidean has # worse numeric precision. 'minkowski' is slower but more accurate. - clust1 = OPTICS(metric='minkowski', min_samples=5).fit(X) + clust1 = OPTICS(min_samples=5).fit(X) assert_array_equal(clust1.ordering_, np.array(o1)) assert_array_equal(clust1.predecessor_[clust1.ordering_], np.array(p1)) @@ -386,7 +386,7 @@ def test_compare_to_ELKI(): 11, 19, 15, 10, 47, -1, 20, 22, 25, 25, 25, 25, 22, 22, 23, -1, 30, 30, 34, 34, 34, 32, 32, 37, 38, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] - clust2 = OPTICS(metric='minkowski', min_samples=5, max_eps=0.5).fit(X) + clust2 = OPTICS(min_samples=5, max_eps=0.5).fit(X) assert_array_equal(clust2.ordering_, np.array(o2)) assert_array_equal(clust2.predecessor_[clust2.ordering_], np.array(p2)) From 71b5c1e88da2424b2289779c29c20c78040ee861 Mon Sep 17 00:00:00 2001 From: huntzhan Date: Mon, 26 Aug 2019 11:24:18 +0800 Subject: [PATCH 5/6] Fix flake8. --- sklearn/cluster/optics_.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/cluster/optics_.py b/sklearn/cluster/optics_.py index d3faed82d8650..3864750901509 100755 --- a/sklearn/cluster/optics_.py +++ b/sklearn/cluster/optics_.py @@ -223,8 +223,9 @@ def fit(self, X, y=None): Parameters ---------- X : array, shape (n_samples, n_features), or (n_samples, n_samples) \ -if metric=’precomputed’, or sparse matrix \ - if metric in ['cityblock', 'cosine', 'euclidean', 'haversine', 'l2', 'l1', 'manhattan']. +if metric=’precomputed’, or sparse matrix (n_samples, n_features) if metric + in ['cityblock', 'cosine', 'euclidean', 'haversine', 'l2', 'l1', + 'manhattan']. A feature array, or array of distances between samples if metric='precomputed'. From 6f498a9206f8e0b0a7422bb7eeb571918d4f7dd3 Mon Sep 17 00:00:00 2001 From: huntzhan Date: Mon, 26 Aug 2019 12:40:06 +0800 Subject: [PATCH 6/6] Add sparse tests. --- sklearn/cluster/optics_.py | 4 +- sklearn/cluster/tests/test_optics.py | 197 ++++++++++++++++++++------- 2 files changed, 151 insertions(+), 50 deletions(-) diff --git a/sklearn/cluster/optics_.py b/sklearn/cluster/optics_.py index 3864750901509..be417b98cb7ab 100755 --- a/sklearn/cluster/optics_.py +++ b/sklearn/cluster/optics_.py @@ -236,7 +236,9 @@ def fit(self, X, y=None): self : instance of OPTICS The instance. """ - if self.metric in PAIRWISE_DISTANCE_FUNCTIONS: + # TODO: Support the sparse input for metric = 'precopmuted'. + if self.metric != 'precomputed' \ + and self.metric in PAIRWISE_DISTANCE_FUNCTIONS: X = check_array(X, accept_sparse='csr') else: X = check_array(X) diff --git a/sklearn/cluster/tests/test_optics.py b/sklearn/cluster/tests/test_optics.py index 5ae8b3f898fcf..f251978a2eab4 100644 --- a/sklearn/cluster/tests/test_optics.py +++ b/sklearn/cluster/tests/test_optics.py @@ -5,6 +5,8 @@ import numpy as np import pytest +from scipy import sparse + from sklearn.datasets.samples_generator import make_blobs from sklearn.cluster.optics_ import (OPTICS, _extend_region, @@ -101,12 +103,22 @@ def test_extract_xi(): xi=0.4).fit(X) assert_array_equal(clust.labels_, expected_labels) + clust = OPTICS(min_samples=3, min_cluster_size=2, + max_eps=20, cluster_method='xi', + xi=0.4, metric='euclidean').fit(sparse.lil_matrix(X)) + assert_array_equal(clust.labels_, expected_labels) + # check float min_samples and min_cluster_size clust = OPTICS(min_samples=0.1, min_cluster_size=0.08, max_eps=20, cluster_method='xi', xi=0.4).fit(X) assert_array_equal(clust.labels_, expected_labels) + clust = OPTICS(min_samples=0.1, min_cluster_size=0.08, + max_eps=20, cluster_method='xi', + xi=0.4, metric='euclidean').fit(sparse.lil_matrix(X)) + assert_array_equal(clust.labels_, expected_labels) + X = np.vstack((C1, C2, C3, C4, C5, np.array([[100, 100]] * 2), C6)) expected_labels = np.r_[[1] * 5, [3] * 5, [2] * 5, [0] * 5, [2] * 5, -1, -1, [4] * 5] @@ -118,6 +130,11 @@ def test_extract_xi(): # this may fail if the predecessor correction is not at work! assert_array_equal(clust.labels_, expected_labels) + clust = OPTICS(min_samples=3, min_cluster_size=3, + max_eps=20, cluster_method='xi', + xi=0.3, metric='euclidean').fit(sparse.lil_matrix(X)) + assert_array_equal(clust.labels_, expected_labels) + C1 = [[0, 0], [0, 0.1], [0, -.1], [0.1, 0]] C2 = [[10, 10], [10, 9], [10, 11], [9, 10]] C3 = [[100, 100], [100, 90], [100, 110], [90, 100]] @@ -130,6 +147,11 @@ def test_extract_xi(): xi=0.04).fit(X) assert_array_equal(clust.labels_, expected_labels) + clust = OPTICS(min_samples=2, min_cluster_size=2, + max_eps=np.inf, cluster_method='xi', + xi=0.04, metric='euclidean').fit(sparse.lil_matrix(X)) + assert_array_equal(clust.labels_, expected_labels) + def test_cluster_hierarchy_(): rng = np.random.RandomState(0) @@ -144,33 +166,42 @@ def test_cluster_hierarchy_(): diff = np.sum(clusters - np.array([[0, 99], [0, 199]])) assert diff / len(X) < 0.05 + clust = OPTICS(min_samples=20, xi=.1, + metric='euclidean').fit(sparse.lil_matrix(X)) + clusters = clust.cluster_hierarchy_ + assert clusters.shape == (2, 2) + diff = np.sum(clusters - np.array([[0, 99], [0, 199]])) + assert diff / len(X) < 0.05 + def test_correct_number_of_clusters(): # in 'auto' mode n_clusters = 3 X = generate_clustered_data(n_clusters=n_clusters) - # Parameters chosen specifically for this task. - # Compute OPTICS - clust = OPTICS(max_eps=5.0 * 6.0, min_samples=4, xi=.1) - clust.fit(X) - # number of clusters, ignoring noise if present - n_clusters_1 = len(set(clust.labels_)) - int(-1 in clust.labels_) - assert n_clusters_1 == n_clusters - # check attribute types and sizes - assert clust.labels_.shape == (len(X),) - assert clust.labels_.dtype.kind == 'i' + for metric in ['minkowski', 'euclidean']: + # Parameters chosen specifically for this task. + # Compute OPTICS + clust = OPTICS(max_eps=5.0 * 6.0, min_samples=4, xi=.1, metric=metric) + clust.fit(X if metric == 'minkowski' else sparse.lil_matrix(X)) + # number of clusters, ignoring noise if present + n_clusters_1 = len(set(clust.labels_)) - int(-1 in clust.labels_) + assert n_clusters_1 == n_clusters - assert clust.reachability_.shape == (len(X),) - assert clust.reachability_.dtype.kind == 'f' + # check attribute types and sizes + assert clust.labels_.shape == (len(X),) + assert clust.labels_.dtype.kind == 'i' - assert clust.core_distances_.shape == (len(X),) - assert clust.core_distances_.dtype.kind == 'f' + assert clust.reachability_.shape == (len(X),) + assert clust.reachability_.dtype.kind == 'f' - assert clust.ordering_.shape == (len(X),) - assert clust.ordering_.dtype.kind == 'i' - assert set(clust.ordering_) == set(range(len(X))) + assert clust.core_distances_.shape == (len(X),) + assert clust.core_distances_.dtype.kind == 'f' + + assert clust.ordering_.shape == (len(X),) + assert clust.ordering_.dtype.kind == 'i' + assert set(clust.ordering_) == set(range(len(X))) def test_minimum_number_of_sample_check(): @@ -184,6 +215,14 @@ def test_minimum_number_of_sample_check(): # Run the fit assert_raise_message(ValueError, msg, clust.fit, X) + # Compute OPTICS + X = sparse.lil_matrix([[1, 1]]) + clust = OPTICS(max_eps=5.0 * 0.3, min_samples=10, min_cluster_size=1, + metric='euclidean') + + # Run the fit + assert_raise_message(ValueError, msg, clust.fit, X) + def test_bad_extract(): # Test an extraction of eps too close to original eps @@ -198,6 +237,13 @@ def test_bad_extract(): eps=0.3, min_samples=10) assert_raise_message(ValueError, msg, clust.fit, X) + # Compute OPTICS + clust = OPTICS(max_eps=5.0 * 0.03, + cluster_method='dbscan', + eps=0.3, min_samples=10, + metric='euclidean') + assert_raise_message(ValueError, msg, clust.fit, sparse.lil_matrix(X)) + def test_bad_reachability(): msg = "All reachability values are inf. Set a larger max_eps." @@ -209,6 +255,11 @@ def test_bad_reachability(): clust = OPTICS(max_eps=5.0 * 0.003, min_samples=10, eps=0.015) clust.fit(X) + with pytest.warns(UserWarning, match=msg): + clust = OPTICS(max_eps=5.0 * 0.003, min_samples=10, eps=0.015, + metric='euclidean') + clust.fit(sparse.lil_matrix(X)) + def test_close_extract(): # Test extract where extraction eps is close to scaled max_eps @@ -223,32 +274,43 @@ def test_close_extract(): # Cluster ordering starts at 0; max cluster label = 2 is 3 clusters assert max(clust.labels_) == 2 + # Compute OPTICS + clust = OPTICS(max_eps=1.0, cluster_method='dbscan', + eps=0.3, min_samples=10, + metric='euclidean').fit(sparse.lil_matrix(X)) + # Cluster ordering starts at 0; max cluster label = 2 is 3 clusters + assert max(clust.labels_) == 2 + @pytest.mark.parametrize('eps', [0.1, .3, .5]) @pytest.mark.parametrize('min_samples', [3, 10, 20]) def test_dbscan_optics_parity(eps, min_samples): # Test that OPTICS clustering labels are <= 5% difference of DBSCAN - centers = [[1, 1], [-1, -1], [1, -1]] - X, labels_true = make_blobs(n_samples=750, centers=centers, - cluster_std=0.4, random_state=0) + for metric in ['minkowski', 'euclidean']: - # calculate optics with dbscan extract at 0.3 epsilon - op = OPTICS(min_samples=min_samples, cluster_method='dbscan', - eps=eps).fit(X) + centers = [[1, 1], [-1, -1], [1, -1]] + _X, labels_true = make_blobs(n_samples=750, centers=centers, + cluster_std=0.4, random_state=0) + X = _X if metric == 'minkowski' else sparse.lil_matrix(_X) - # calculate dbscan labels - db = DBSCAN(eps=eps, min_samples=min_samples).fit(X) + # calculate optics with dbscan extract at 0.3 epsilon + op = OPTICS(min_samples=min_samples, cluster_method='dbscan', + eps=eps, + metric=metric).fit(X) - contingency = contingency_matrix(db.labels_, op.labels_) - agree = min(np.sum(np.max(contingency, axis=0)), - np.sum(np.max(contingency, axis=1))) - disagree = X.shape[0] - agree + # calculate dbscan labels + db = DBSCAN(eps=eps, min_samples=min_samples).fit(X) - percent_mismatch = np.round((disagree - 1) / X.shape[0], 2) + contingency = contingency_matrix(db.labels_, op.labels_) + agree = min(np.sum(np.max(contingency, axis=0)), + np.sum(np.max(contingency, axis=1))) + disagree = X.shape[0] - agree - # verify label mismatch is <= 5% labels - assert percent_mismatch <= 0.05 + percent_mismatch = np.round((disagree - 1) / X.shape[0], 2) + + # verify label mismatch is <= 5% labels + assert percent_mismatch <= 0.05 def test_min_samples_edge_case(): @@ -263,12 +325,24 @@ def test_min_samples_edge_case(): xi=0.04).fit(X) assert_array_equal(clust.labels_, expected_labels) + clust = OPTICS(min_samples=3, + max_eps=7, cluster_method='xi', + xi=0.04, + metric='euclidean').fit(sparse.lil_matrix(X)) + assert_array_equal(clust.labels_, expected_labels) + expected_labels = np.r_[[0] * 3, [1] * 3, [-1] * 3] clust = OPTICS(min_samples=3, max_eps=3, cluster_method='xi', xi=0.04).fit(X) assert_array_equal(clust.labels_, expected_labels) + clust = OPTICS(min_samples=3, + max_eps=3, cluster_method='xi', + xi=0.04, + metric='euclidean').fit(sparse.lil_matrix(X)) + assert_array_equal(clust.labels_, expected_labels) + expected_labels = np.r_[[-1] * 9] with pytest.warns(UserWarning, match="All reachability values"): clust = OPTICS(min_samples=4, @@ -276,20 +350,33 @@ def test_min_samples_edge_case(): xi=0.04).fit(X) assert_array_equal(clust.labels_, expected_labels) + with pytest.warns(UserWarning, match="All reachability values"): + clust = OPTICS(min_samples=4, + max_eps=3, cluster_method='xi', + xi=0.04, + metric='euclidean').fit(sparse.lil_matrix(X)) + assert_array_equal(clust.labels_, expected_labels) + # try arbitrary minimum sizes @pytest.mark.parametrize('min_cluster_size', range(2, X.shape[0] // 10, 23)) def test_min_cluster_size(min_cluster_size): - redX = X[::2] # reduce for speed - clust = OPTICS(min_samples=9, min_cluster_size=min_cluster_size).fit(redX) - cluster_sizes = np.bincount(clust.labels_[clust.labels_ != -1]) - if cluster_sizes.size: - assert min(cluster_sizes) >= min_cluster_size - # check behaviour is the same when min_cluster_size is a fraction - clust_frac = OPTICS(min_samples=9, - min_cluster_size=min_cluster_size / redX.shape[0]) - clust_frac.fit(redX) - assert_array_equal(clust.labels_, clust_frac.labels_) + _redX = X[::2] # reduce for speed + + for metric in ['minkowski', 'euclidean']: + redX = _redX if metric == 'minkowski' else sparse.lil_matrix(_redX) + + clust = OPTICS(min_samples=9, min_cluster_size=min_cluster_size, + metric=metric).fit(redX) + cluster_sizes = np.bincount(clust.labels_[clust.labels_ != -1]) + if cluster_sizes.size: + assert min(cluster_sizes) >= min_cluster_size + # check behaviour is the same when min_cluster_size is a fraction + clust_frac = OPTICS(min_samples=9, + min_cluster_size=min_cluster_size / redX.shape[0], + metric=metric) + clust_frac.fit(redX) + assert_array_equal(clust.labels_, clust_frac.labels_) @pytest.mark.parametrize('min_cluster_size', [0, -1, 1.1, 2.2]) @@ -298,21 +385,33 @@ def test_min_cluster_size_invalid(min_cluster_size): with pytest.raises(ValueError, match="must be a positive integer or a "): clust.fit(X) + clust = OPTICS(min_cluster_size=min_cluster_size, metric='euclidean') + with pytest.raises(ValueError, match="must be a positive integer or a "): + clust.fit(sparse.lil_matrix(X)) + def test_min_cluster_size_invalid2(): clust = OPTICS(min_cluster_size=len(X) + 1) with pytest.raises(ValueError, match="must be no greater than the "): clust.fit(X) + clust = OPTICS(min_cluster_size=len(X) + 1, metric='euclidean') + with pytest.raises(ValueError, match="must be no greater than the "): + clust.fit(sparse.lil_matrix(X)) + def test_processing_order(): - # Ensure that we consider all unprocessed points, - # not only direct neighbors. when picking the next point. - Y = [[0], [10], [-10], [25]] - clust = OPTICS(min_samples=3, max_eps=15).fit(Y) - assert_array_equal(clust.reachability_, [np.inf, 10, 10, 15]) - assert_array_equal(clust.core_distances_, [10, 15, np.inf, np.inf]) - assert_array_equal(clust.ordering_, [0, 1, 2, 3]) + for metric in ['minkowski', 'euclidean']: + + # Ensure that we consider all unprocessed points, + # not only direct neighbors. when picking the next point. + _Y = [[0], [10], [-10], [25]] + Y = _Y if metric == 'minkowski' else sparse.lil_matrix(_Y) + + clust = OPTICS(min_samples=3, max_eps=15, metric=metric).fit(Y) + assert_array_equal(clust.reachability_, [np.inf, 10, 10, 15]) + assert_array_equal(clust.core_distances_, [10, 15, np.inf, np.inf]) + assert_array_equal(clust.ordering_, [0, 1, 2, 3]) def test_compare_to_ELKI():