diff --git a/doc/introduction.rst b/doc/introduction.rst
index 0612c6807..933a8eca5 100644
--- a/doc/introduction.rst
+++ b/doc/introduction.rst
@@ -9,7 +9,8 @@ Introduction
 API's of imbalanced-learn samplers
 ----------------------------------
 
-The available samplers follows the scikit-learn API using the base estimator and adding a sampling functionality throw the ``sample`` method::
+The available samplers follow the scikit-learn API, using the base estimator
+and adding a sampling functionality through the ``sample`` method::
 
 :Estimator:
 
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 142668f27..3852051ed 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -40,6 +40,10 @@ Bug fixes
 New features
 ~~~~~~~~~~~~
 
+- :class:`under_sampling.ClusterCentroids` accepts a parameter ``voting`` to
+  use the nearest neighbors of the centroids instead of the centroids
+  themselves; more efficient for sparse input. By `Guillaume Lemaitre`_.
+
 - Turn off steps in :class:`pipeline.Pipeline` using the `None` object.
   By `Christos Aridas`_.
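For illustration, a minimal sketch of the new ``voting`` option in use; the
toy dataset and the parameter values below are assumptions made for this
sketch, not part of the changeset::

    # Hypothetical usage of the new ``voting`` parameter of ClusterCentroids.
    from collections import Counter

    from sklearn.datasets import make_classification
    from imblearn.under_sampling import ClusterCentroids

    X, y = make_classification(n_classes=2, weights=[0.1, 0.9],
                               n_samples=100, random_state=0)

    # 'soft' (the historical behaviour) keeps the k-means centroids.
    X_soft, y_soft = ClusterCentroids(voting='soft',
                                      random_state=0).fit_sample(X, y)

    # 'hard' replaces each centroid by its nearest original sample.
    X_hard, y_hard = ClusterCentroids(voting='hard',
                                      random_state=0).fit_sample(X, y)

    print(Counter(y_soft))  # balanced, e.g. Counter({0: 10, 1: 10})
    print(Counter(y_hard))  # same counts; every row of X_hard comes from X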
diff --git a/examples/under-sampling/plot_cluster_centroids.py b/examples/under-sampling/plot_cluster_centroids.py
index d6f7eaf25..d13b669e3 100644
--- a/examples/under-sampling/plot_cluster_centroids.py
+++ b/examples/under-sampling/plot_cluster_centroids.py
@@ -24,7 +24,7 @@
 X, y = make_classification(n_classes=2, class_sep=2, weights=[0.1, 0.9],
                            n_informative=3, n_redundant=1, flip_y=0,
                            n_features=20, n_clusters_per_class=1,
-                           n_samples=200, random_state=10)
+                           n_samples=50, random_state=10)
 
 # Instanciate a PCA object for the sake of easy visualisation
 pca = PCA(n_components=2)
@@ -34,10 +34,15 @@
 # Apply Cluster Centroids
 cc = ClusterCentroids()
 X_resampled, y_resampled = cc.fit_sample(X, y)
-X_res_vis = pca.transform(X_resampled)
+X_res_vis_soft = pca.transform(X_resampled)
+
+# Use hard voting instead of soft voting
+cc = ClusterCentroids(voting='hard')
+X_resampled, y_resampled = cc.fit_sample(X, y)
+X_res_vis_hard = pca.transform(X_resampled)
 
 # Two subplots, unpack the axes array immediately
-f, (ax1, ax2) = plt.subplots(1, 2)
+f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
 
 c0 = ax1.scatter(X_vis[y == 0, 0], X_vis[y == 0, 1], label="Class #0",
                  alpha=0.5)
@@ -45,14 +50,30 @@
                  alpha=0.5)
 ax1.set_title('Original set')
 
-ax2.scatter(X_res_vis[y_resampled == 0, 0], X_res_vis[y_resampled == 0, 1],
+ax2.scatter(X_res_vis_soft[y_resampled == 0, 0],
+            X_res_vis_soft[y_resampled == 0, 1],
+            label="Class #0", alpha=.5)
+ax2.scatter(X_res_vis_soft[y_resampled == 1, 0],
+            X_res_vis_soft[y_resampled == 1, 1],
+            label="Class #1", alpha=.5)
+c2 = ax2.scatter(X_vis[y == 1, 0],
+                 X_vis[y == 1, 1], label="Original #1",
+                 alpha=0.2)
+ax2.set_title('Cluster centroids with soft voting')
+
+ax3.scatter(X_res_vis_hard[y_resampled == 0, 0],
+            X_res_vis_hard[y_resampled == 0, 1],
             label="Class #0", alpha=.5)
-ax2.scatter(X_res_vis[y_resampled == 1, 0], X_res_vis[y_resampled == 1, 1],
+ax3.scatter(X_res_vis_hard[y_resampled == 1, 0],
+            X_res_vis_hard[y_resampled == 1, 1],
             label="Class #1", alpha=.5)
-ax2.set_title('Cluster centroids')
+ax3.scatter(X_vis[y == 1, 0],
+            X_vis[y == 1, 1],
+            alpha=0.2)
+ax3.set_title('Cluster centroids with hard voting')
 
 # make nice plotting
-for ax in (ax1, ax2):
+for ax in (ax1, ax2, ax3):
     ax.spines['top'].set_visible(False)
     ax.spines['right'].set_visible(False)
     ax.get_xaxis().tick_bottom()
@@ -62,7 +83,8 @@
     ax.set_xlim([-6, 8])
     ax.set_ylim([-6, 6])
 
-plt.figlegend((c0, c1), ('Class #0', 'Class #1'), loc='lower center',
-              ncol=2, labelspacing=0.)
+plt.figlegend((c0, c1, c2), ('Class #0', 'Class #1', 'Original Class #1'),
+              loc='lower center',
+              ncol=3, labelspacing=0.)
 plt.tight_layout(pad=3)
 plt.show()
diff --git a/imblearn/under_sampling/prototype_generation/cluster_centroids.py b/imblearn/under_sampling/prototype_generation/cluster_centroids.py
index 0cfebb193..28d70b6c9 100644
--- a/imblearn/under_sampling/prototype_generation/cluster_centroids.py
+++ b/imblearn/under_sampling/prototype_generation/cluster_centroids.py
@@ -12,10 +12,13 @@
 from scipy import sparse
 
 from sklearn.cluster import KMeans
+from sklearn.neighbors import NearestNeighbors
 from sklearn.utils import safe_indexing
 
 from ..base import BaseUnderSampler
 
+VOTING_KIND = ('auto', 'hard', 'soft')
+
 
 class ClusterCentroids(BaseUnderSampler):
     """Perform under-sampling by generating centroids based on
@@ -58,6 +61,18 @@ class ClusterCentroids(BaseUnderSampler):
     estimator : object, optional(default=KMeans())
         Pass a :class:`sklearn.cluster.KMeans` estimator.
 
+    voting : str, optional (default='auto')
+        Voting strategy to generate the new samples:
+
+        - If ``'hard'``, the nearest neighbors of the centroids found by the
+          clustering algorithm will be used.
+        - If ``'soft'``, the centroids found by the clustering algorithm will
+          be used.
+        - If ``'auto'``, ``'hard'`` will be used when the input is sparse and
+          ``'soft'`` otherwise.
+
+        .. versionadded:: 0.3.0
+
     n_jobs : int, optional (default=1)
         The number of threads to open if possible.
 
@@ -91,10 +106,12 @@ def __init__(self,
                  ratio='auto',
                  random_state=None,
                  estimator=None,
+                 voting='auto',
                  n_jobs=1):
         super(ClusterCentroids, self).__init__(
             ratio=ratio, random_state=random_state)
         self.estimator = estimator
+        self.voting = voting
         self.n_jobs = n_jobs
 
     def _validate_estimator(self):
@@ -108,6 +125,22 @@
             raise ValueError('`estimator` has to be a KMeans clustering.'
                              ' Got {} instead.'.format(type(self.estimator)))
 
+    def _generate_sample(self, X, y, centroids, target_class):
+        if self.voting_ == 'hard':
+            nearest_neighbors = NearestNeighbors(n_neighbors=1)
+            nearest_neighbors.fit(X, y)
+            indices = nearest_neighbors.kneighbors(centroids,
+                                                   return_distance=False)
+            X_new = safe_indexing(X, np.squeeze(indices))
+        else:
+            if sparse.issparse(X):
+                X_new = sparse.csr_matrix(centroids)
+            else:
+                X_new = centroids
+        y_new = np.array([target_class] * centroids.shape[0])
+
+        return X_new, y_new
+
     def _sample(self, X, y):
         """Resample the dataset.
 
@@ -131,28 +164,37 @@
         """
         self._validate_estimator()
 
-        idx_under = np.empty((0, ), dtype=int)
-        centroids, y_resampled = [], []
+        if self.voting == 'auto':
+            if sparse.issparse(X):
+                self.voting_ = 'hard'
+            else:
+                self.voting_ = 'soft'
+        else:
+            if self.voting in VOTING_KIND:
+                self.voting_ = self.voting
+            else:
+                raise ValueError("'voting' needs to be one of {}. Got {}"
+                                 " instead.".format(VOTING_KIND, self.voting))
+
+        X_resampled, y_resampled = [], []
 
         for target_class in np.unique(y):
             if target_class in self.ratio_.keys():
                 n_samples = self.ratio_[target_class]
                 self.estimator_.set_params(**{'n_clusters': n_samples})
                 self.estimator_.fit(X[y == target_class])
-                centroids.append(self.estimator_.cluster_centers_)
-                y_resampled += [target_class] * n_samples
-
+                X_new, y_new = self._generate_sample(
+                    X, y, self.estimator_.cluster_centers_, target_class)
+                X_resampled.append(X_new)
+                y_resampled.append(y_new)
             else:
                 target_class_indices = np.flatnonzero(y == target_class)
-                idx_under = np.concatenate(
-                    (idx_under, target_class_indices), axis=0)
-
-        X_resampled = np.concatenate((centroids))
+                X_resampled.append(safe_indexing(X, target_class_indices))
+                y_resampled.append(safe_indexing(y, target_class_indices))
 
         if sparse.issparse(X):
-            X_resampled = sparse.vstack([sparse.csr_matrix(X_resampled),
-                                         safe_indexing(X, idx_under)])
+            X_resampled = sparse.vstack(X_resampled)
         else:
-            X_resampled = np.vstack((X_resampled, safe_indexing(X, idx_under)))
-        y_resampled = np.hstack((y_resampled, safe_indexing(y, idx_under)))
+            X_resampled = np.vstack(X_resampled)
+        y_resampled = np.hstack(y_resampled)
 
         return X_resampled, np.array(y_resampled)
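To make the hard-voting branch of ``_generate_sample`` concrete, here is a
self-contained sketch of the same nearest-neighbor mapping outside the class;
the random data is an assumption for illustration only::

    # Standalone sketch of hard voting: map each k-means centroid back to
    # the closest real sample of the class being under-sampled.
    import numpy as np
    from sklearn.cluster import KMeans
    from sklearn.neighbors import NearestNeighbors

    rng = np.random.RandomState(0)
    X_class = rng.uniform(size=(20, 2))  # samples of one class

    kmeans = KMeans(n_clusters=3, random_state=0).fit(X_class)
    centroids = kmeans.cluster_centers_  # what 'soft' voting would return

    nn = NearestNeighbors(n_neighbors=1).fit(X_class)
    indices = nn.kneighbors(centroids, return_distance=False)
    X_hard = X_class[np.squeeze(indices)]  # what 'hard' voting returns

    # every generated sample is an actual member of X_class
    assert all(np.any(np.all(x == X_class, axis=1)) for x in X_hard)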
diff --git a/imblearn/under_sampling/prototype_generation/tests/test_cluster_centroids.py b/imblearn/under_sampling/prototype_generation/tests/test_cluster_centroids.py
index 09e634832..1788a8402 100644
--- a/imblearn/under_sampling/prototype_generation/tests/test_cluster_centroids.py
+++ b/imblearn/under_sampling/prototype_generation/tests/test_cluster_centroids.py
@@ -4,11 +4,11 @@
 from collections import Counter
 
 import numpy as np
+from scipy import sparse
 from pytest import raises
 
 from sklearn.utils.testing import assert_allclose
 from sklearn.utils.testing import assert_array_equal
-
 from sklearn.cluster import KMeans
 
 from imblearn.under_sampling import ClusterCentroids
@@ -23,17 +23,26 @@
 R_TOL = 1e-4
 
 
+def test_fit_sample_check_voting():
+    cc = ClusterCentroids(random_state=RND_SEED)
+    cc.fit_sample(X, Y)
+    assert cc.voting_ == 'soft'
+    cc = ClusterCentroids(random_state=RND_SEED)
+    cc.fit_sample(sparse.csr_matrix(X), Y)
+    assert cc.voting_ == 'hard'
+
+
 def test_fit_sample_auto():
     ratio = 'auto'
     cc = ClusterCentroids(ratio=ratio, random_state=RND_SEED)
     X_resampled, y_resampled = cc.fit_sample(X, Y)
-    X_gt = np.array([[0.06738818, -0.529627],
-                     [0.17901516, 0.69860992],
-                     [0.094035, -2.55298982],
-                     [0.92923648, 0.76103773],
+    X_gt = np.array([[0.92923648, 0.76103773],
                      [0.47104475, 0.44386323],
-                     [0.13347175, 0.12167502]])
-    y_gt = np.array([1, 1, 1, 0, 0, 0])
+                     [0.13347175, 0.12167502],
+                     [0.06738818, -0.529627],
+                     [0.17901516, 0.69860992],
+                     [0.094035, -2.55298982]])
+    y_gt = np.array([0, 0, 0, 1, 1, 1])
     assert_allclose(X_resampled, X_gt, rtol=R_TOL)
     assert_array_equal(y_resampled, y_gt)
@@ -42,16 +51,16 @@
 def test_fit_sample_half():
     ratio = .5
     cc = ClusterCentroids(ratio=ratio, random_state=RND_SEED)
     X_resampled, y_resampled = cc.fit_sample(X, Y)
-    X_gt = np.array([[0.09125309, -0.85409574],
+    X_gt = np.array([[0.92923648, 0.76103773],
+                     [0.47104475, 0.44386323],
+                     [0.13347175, 0.12167502],
+                     [0.09125309, -0.85409574],
                      [0.19220316, 0.32337101],
                      [0.094035, -2.55298982],
                      [0.20792588, 1.49407907],
                      [0.04352327, -0.20515826],
-                     [0.12372842, 0.6536186],
-                     [0.92923648, 0.76103773],
-                     [0.47104475, 0.44386323],
-                     [0.13347175, 0.12167502]])
-    y_gt = np.array([1, 1, 1, 1, 1, 1, 0, 0, 0])
+                     [0.12372842, 0.6536186]])
+    y_gt = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1])
     assert_allclose(X_resampled, X_gt, rtol=R_TOL)
     assert_array_equal(y_resampled, y_gt)
@@ -75,21 +84,48 @@
 def test_fit_sample_object():
     ratio = 'auto'
     cluster = KMeans(random_state=RND_SEED)
     cc = ClusterCentroids(
         ratio=ratio, random_state=RND_SEED, estimator=cluster)
 
     X_resampled, y_resampled = cc.fit_sample(X, Y)
-    X_gt = np.array([[0.06738818, -0.529627],
+    X_gt = np.array([[0.92923648, 0.76103773],
+                     [0.47104475, 0.44386323],
+                     [0.13347175, 0.12167502],
+                     [0.06738818, -0.529627],
                      [0.17901516, 0.69860992],
-                     [0.094035, -2.55298982],
-                     [0.92923648, 0.76103773],
+                     [0.094035, -2.55298982]])
+    y_gt = np.array([0, 0, 0, 1, 1, 1])
+    assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+    assert_array_equal(y_resampled, y_gt)
+
+
+def test_fit_hard_voting():
+    ratio = 'auto'
+    voting = 'hard'
+    cluster = KMeans(random_state=RND_SEED)
+    cc = ClusterCentroids(
+        ratio=ratio, random_state=RND_SEED, estimator=cluster,
+        voting=voting)
+
+    X_resampled, y_resampled = cc.fit_sample(X, Y)
+    X_gt = np.array([[0.92923648, 0.76103773],
                      [0.47104475, 0.44386323],
-                     [0.13347175, 0.12167502]])
-    y_gt = np.array([1, 1, 1, 0, 0, 0])
+                     [0.13347175, 0.12167502],
+                     [0.09125309, -0.85409574],
+                     [0.12372842, 0.6536186],
+                     [0.094035, -2.55298982]])
+    y_gt = np.array([0, 0, 0, 1, 1, 1])
     assert_allclose(X_resampled, X_gt, rtol=R_TOL)
     assert_array_equal(y_resampled, y_gt)
+    for x in X_resampled:
+        assert np.any(np.all(x == X, axis=1))
 
 
-def test_fit_sample_wrong_object():
+def test_fit_sample_error():
     ratio = 'auto'
     cluster = 'rnd'
     cc = ClusterCentroids(
         ratio=ratio, random_state=RND_SEED, estimator=cluster)
     with raises(ValueError, match="has to be a KMeans clustering"):
         cc.fit_sample(X, Y)
+
+    voting = 'unknown'
+    cc = ClusterCentroids(ratio=ratio, voting=voting, random_state=RND_SEED)
+    with raises(ValueError, match="needs to be one of"):
+        cc.fit_sample(X, Y)
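``test_fit_sample_check_voting`` above pins down how ``voting='auto'`` is
resolved; the same behaviour can be checked interactively as in the sketch
below, where the toy data is an assumption rather than the module's fixtures::

    # 'auto' resolves to 'soft' on dense input and to 'hard' on sparse input.
    import numpy as np
    from scipy import sparse
    from imblearn.under_sampling import ClusterCentroids

    X = np.random.RandomState(0).uniform(size=(30, 2))
    y = np.array([0] * 10 + [1] * 20)

    cc = ClusterCentroids(random_state=0)
    cc.fit_sample(X, y)
    assert cc.voting_ == 'soft'   # dense input -> centroids themselves

    cc = ClusterCentroids(random_state=0)
    cc.fit_sample(sparse.csr_matrix(X), y)
    assert cc.voting_ == 'hard'   # sparse input -> nearest real samples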
diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py
index 2df602a08..319a57a14 100644
--- a/imblearn/utils/estimator_checks.py
+++ b/imblearn/utils/estimator_checks.py
@@ -271,6 +271,7 @@ def check_samplers_sparse(name, Sampler):
     elif isinstance(Sampler(), ClusterCentroids):
         # set KMeans to full since it support sparse and dense
         samplers = [Sampler(random_state=0,
+                            voting='soft',
                             estimator=KMeans(random_state=1,
                                              algorithm='full'))]
     else:
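The common sparse check above forces ``voting='soft'`` together with a KMeans
that handles sparse input; a direct reproduction of that configuration might
look like the following sketch, with made-up toy data::

    # Soft voting on sparse input returns the centroids as a sparse matrix.
    import numpy as np
    from scipy import sparse
    from sklearn.cluster import KMeans
    from imblearn.under_sampling import ClusterCentroids

    X = sparse.csr_matrix(np.random.RandomState(0).uniform(size=(40, 3)))
    y = np.array([0] * 10 + [1] * 30)

    cc = ClusterCentroids(random_state=0, voting='soft',
                          estimator=KMeans(random_state=1, algorithm='full'))
    X_res, y_res = cc.fit_sample(X, y)
    assert sparse.issparse(X_res)  # centroids come back as a csr_matrix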