From 1d945457b2f2565d31b35086f2b978adaf69d64e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 10 Jul 2023 12:32:10 +0200 Subject: [PATCH 1/3] FIX/DEPR follow literature for the implementation of NCR --- doc/whats_new/v0.12.rst | 18 ++- .../_neighbourhood_cleaning_rule.py | 120 ++++++++++----- .../tests/test_neighbourhood_cleaning_rule.py | 140 +++++++++--------- 3 files changed, 174 insertions(+), 104 deletions(-) diff --git a/doc/whats_new/v0.12.rst b/doc/whats_new/v0.12.rst index 1063ffb05..6c5974231 100644 --- a/doc/whats_new/v0.12.rst +++ b/doc/whats_new/v0.12.rst @@ -6,6 +6,18 @@ Version 0.12.0 (Under development) Changelog --------- +Bug fixes +......... + +- Fix a bug in :class:`~imblearn.under_sampling.NeighbourhoodCleaningRule` where the + `kind_sel="all"` was not working as explained in the literature. + :pr:`1012` by :user:`Guillaume Lemaitre `. + +- Fix a bug in :class:`~imblearn.under_sampling.NeighbourhoodCleaningRule` where the + `threshold_cleaning` ratio was multiplied on the total number of samples instead of + the number of samples in the minority class. + :pr:`1012` by :user:`Guillaume Lemaitre `. + Deprecations ............ @@ -13,4 +25,8 @@ Deprecations :class:`~imblearn.under_sampling.CondensedNearestNeighbour` and :class:`~imblearn.under_sampling.OneSidedSelection`. `estimator_` will be removed in 0.14. - :pr:`xxx` by :user:`Guillaume Lemaitre `. + :pr:`1011` by :user:`Guillaume Lemaitre `. + +- Deprecate `kind_sel` in :class:`~imblearn.under_sampling.NeighbourhoodCleaningRule. + It will be removed in 0.14. The parameter does not have any effect. + :pr:`1012` by :user:`Guillaume Lemaitre `. diff --git a/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py b/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py index f9c08ea56..188ba32f3 100644 --- a/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py +++ b/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py @@ -5,15 +5,17 @@ # License: MIT import numbers +import warnings from collections import Counter import numpy as np +from sklearn.base import clone +from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors from sklearn.utils import _safe_indexing -from ...utils import Substitution, check_neighbors_object +from ...utils import Substitution from ...utils._docstring import _n_jobs_docstring -from ...utils._param_validation import HasMethods, Interval, StrOptions -from ...utils.fixes import _mode +from ...utils._param_validation import HasMethods, Hidden, Interval, StrOptions from ..base import BaseCleaningSampler from ._edited_nearest_neighbours import EditedNearestNeighbours @@ -35,9 +37,14 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler): ---------- {sampling_strategy} + edited_nearest_neighbours : estimator object, default=None + The :class:`~imblearn.under_sampling.EditedNearestNeighbours` (ENN) + object to clean the dataset. If `None`, a default ENN is created with + `kind_sel="mode"` and `n_neighbors=n_neighbors`. + n_neighbors : int or estimator object, default=3 If ``int``, size of the neighbourhood to consider to compute the - nearest neighbors. If object, an estimator that inherits from + K-nearest neighbors. If object, an estimator that inherits from :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to find the nearest-neighbors. By default, it will be a 3-NN. @@ -52,6 +59,11 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler): The strategy `"all"` will be less conservative than `'mode'`. Thus, more samples will be removed when `kind_sel="all"` generally. + .. deprecated:: 0.12 + `kind_sel` is deprecated in 0.12 and will be removed in 0.14. + Currently the parameter has no effect and corresponds always to the + `"all"` strategy. + threshold_cleaning : float, default=0.5 Threshold used to whether consider a class or not during the cleaning after applying ENN. A class will be considered during cleaning when: @@ -70,9 +82,16 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler): corresponds to the class labels from which to sample and the values are the number of samples to sample. + edited_nearest_neighbours_ : estimator object + The edited nearest neighbour object used to make the first resampling. + nn_ : estimator object Validated K-nearest Neighbours object created from `n_neighbors` parameter. + classes_to_clean_ : list + The classes considered with under-sampling by `nn_` in the second cleaning + phase. + sample_indices_ : ndarray of shape (n_new_samples,) Indices of the samples selected. @@ -118,17 +137,21 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler): >>> ncr = NeighbourhoodCleaningRule() >>> X_res, y_res = ncr.fit_resample(X, y) >>> print('Resampled dataset shape %s' % Counter(y_res)) - Resampled dataset shape Counter({{1: 877, 0: 100}}) + Resampled dataset shape Counter({{1: 888, 0: 100}}) """ _parameter_constraints: dict = { **BaseCleaningSampler._parameter_constraints, + "edited_nearest_neighbours": [ + HasMethods(["fit_resample"]), + None, + ], "n_neighbors": [ Interval(numbers.Integral, 1, None, closed="left"), HasMethods(["kneighbors", "kneighbors_graph"]), ], - "kind_sel": [StrOptions({"all", "mode"})], - "threshold_cleaning": [Interval(numbers.Real, 0, 1, closed="neither")], + "kind_sel": [StrOptions({"all", "mode"}), Hidden(StrOptions({"deprecated"}))], + "threshold_cleaning": [Interval(numbers.Real, 0, None, closed="neither")], "n_jobs": [numbers.Integral, None], } @@ -136,12 +159,14 @@ def __init__( self, *, sampling_strategy="auto", + edited_nearest_neighbours=None, n_neighbors=3, - kind_sel="all", + kind_sel="deprecated", threshold_cleaning=0.5, n_jobs=None, ): super().__init__(sampling_strategy=sampling_strategy) + self.edited_nearest_neighbours = edited_nearest_neighbours self.n_neighbors = n_neighbors self.kind_sel = kind_sel self.threshold_cleaning = threshold_cleaning @@ -149,21 +174,38 @@ def __init__( def _validate_estimator(self): """Create the objects required by NCR.""" - self.nn_ = check_neighbors_object( - "n_neighbors", self.n_neighbors, additional_neighbor=1 - ) - self.nn_.set_params(**{"n_jobs": self.n_jobs}) + if isinstance(self.n_neighbors, numbers.Integral): + self.nn_ = KNeighborsClassifier( + n_neighbors=self.n_neighbors, n_jobs=self.n_jobs + ) + elif isinstance(self.n_neighbors, NearestNeighbors): + # backward compatibility when passing a NearestNeighbors object + self.nn_ = KNeighborsClassifier( + n_neighbors=self.n_neighbors.n_neighbors - 1, n_jobs=self.n_jobs + ) + else: + self.nn_ = clone(self.n_neighbors) + + if self.edited_nearest_neighbours is None: + self.edited_nearest_neighbours_ = EditedNearestNeighbours( + sampling_strategy=self.sampling_strategy, + n_neighbors=self.n_neighbors, + kind_sel="mode", + n_jobs=self.n_jobs, + ) + else: + self.edited_nearest_neighbours_ = clone(self.edited_nearest_neighbours) def _fit_resample(self, X, y): + if self.kind_sel != "deprecated": + warnings.warn( + "`kind_sel` is deprecated in 0.12 and will be removed in 0.14. " + "It already has not effect and corresponds to the `'all'` option.", + FutureWarning, + ) self._validate_estimator() - enn = EditedNearestNeighbours( - sampling_strategy=self.sampling_strategy, - n_neighbors=self.n_neighbors, - kind_sel="mode", - n_jobs=self.n_jobs, - ) - enn.fit_resample(X, y) - index_not_a1 = enn.sample_indices_ + self.edited_nearest_neighbours_.fit_resample(X, y) + index_not_a1 = self.edited_nearest_neighbours_.sample_indices_ index_a1 = np.ones(y.shape, dtype=bool) index_a1[index_not_a1] = False index_a1 = np.flatnonzero(index_a1) @@ -172,30 +214,34 @@ def _fit_resample(self, X, y): target_stats = Counter(y) class_minority = min(target_stats, key=target_stats.get) # compute which classes to consider for cleaning for the A2 group - classes_under_sample = [ + self.classes_to_clean_ = [ c for c, n_samples in target_stats.items() if ( c in self.sampling_strategy_.keys() - and (n_samples > X.shape[0] * self.threshold_cleaning) + and (n_samples > target_stats[class_minority] * self.threshold_cleaning) ) ] - self.nn_.fit(X) + self.nn_.fit(X, y) + class_minority_indices = np.flatnonzero(y == class_minority) - X_class = _safe_indexing(X, class_minority_indices) - y_class = _safe_indexing(y, class_minority_indices) - nnhood_idx = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:] - nnhood_label = y[nnhood_idx] - if self.kind_sel == "mode": - nnhood_label_majority, _ = _mode(nnhood_label, axis=1) - nnhood_bool = np.ravel(nnhood_label_majority) == y_class - else: # self.kind_sel == "all": - nnhood_label_majority = nnhood_label == class_minority - nnhood_bool = np.all(nnhood_label, axis=1) - # compute a2 group - index_a2 = np.ravel(nnhood_idx[~nnhood_bool]) - index_a2 = np.unique( - [index for index in index_a2 if y[index] in classes_under_sample] + X_minority = _safe_indexing(X, class_minority_indices) + y_minority = _safe_indexing(y, class_minority_indices) + + y_pred_minority = self.nn_.predict(X_minority) + # add an additional sample since the query points contains the original dataset + neighbors_to_minority_indices = self.nn_.kneighbors( + X_minority, n_neighbors=self.nn_.n_neighbors + 1, return_distance=False + )[:, 1:] + + mask_misclassified_minority = y_pred_minority != y_minority + index_a2 = np.ravel(neighbors_to_minority_indices[mask_misclassified_minority]) + index_a2 = np.array( + [ + index + for index in np.unique(index_a2) + if y[index] in self.classes_to_clean_ + ] ) union_a1_a2 = np.union1d(index_a1, index_a2).astype(int) diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py b/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py index 971d5b559..97a1d0259 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py @@ -3,74 +3,82 @@ # Christos Aridas # License: MIT +from collections import Counter + import numpy as np +import pytest +from sklearn.datasets import make_classification from sklearn.utils._testing import assert_array_equal -from imblearn.under_sampling import NeighbourhoodCleaningRule - -X = np.array( - [ - [1.57737838, 0.1997882], - [0.8960075, 0.46130762], - [0.34096173, 0.50947647], - [-0.91735824, 0.93110278], - [-0.14619583, 1.33009918], - [-0.20413357, 0.64628718], - [0.85713638, 0.91069295], - [0.35967591, 2.61186964], - [0.43142011, 0.52323596], - [0.90701028, -0.57636928], - [-1.20809175, -1.49917302], - [-0.60497017, -0.66630228], - [1.39272351, -0.51631728], - [-1.55581933, 1.09609604], - [1.55157493, -1.6981518], - ] -) -Y = np.array([1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 0, 0, 2, 1, 2]) - - -def test_ncr_fit_resample(): - ncr = NeighbourhoodCleaningRule() - X_resampled, y_resampled = ncr.fit_resample(X, Y) - - X_gt = np.array( - [ - [0.34096173, 0.50947647], - [-0.91735824, 0.93110278], - [-0.20413357, 0.64628718], - [0.35967591, 2.61186964], - [0.90701028, -0.57636928], - [-1.20809175, -1.49917302], - [-0.60497017, -0.66630228], - [1.39272351, -0.51631728], - [-1.55581933, 1.09609604], - [1.55157493, -1.6981518], - ] +from imblearn.under_sampling import EditedNearestNeighbours, NeighbourhoodCleaningRule + + +@pytest.fixture(scope="module") +def data(): + return make_classification( + n_samples=200, + n_features=2, + n_informative=2, + n_redundant=0, + n_repeated=0, + n_clusters_per_class=1, + n_classes=3, + weights=[0.1, 0.3, 0.6], + random_state=0, ) - y_gt = np.array([1, 1, 1, 2, 2, 0, 0, 2, 1, 2]) - assert_array_equal(X_resampled, X_gt) - assert_array_equal(y_resampled, y_gt) - - -def test_ncr_fit_resample_mode(): - ncr = NeighbourhoodCleaningRule(kind_sel="mode") - X_resampled, y_resampled = ncr.fit_resample(X, Y) - - X_gt = np.array( - [ - [0.34096173, 0.50947647], - [-0.91735824, 0.93110278], - [-0.20413357, 0.64628718], - [0.35967591, 2.61186964], - [0.90701028, -0.57636928], - [-1.20809175, -1.49917302], - [-0.60497017, -0.66630228], - [1.39272351, -0.51631728], - [-1.55581933, 1.09609604], - [1.55157493, -1.6981518], - ] + + +def test_ncr_threshold_cleaning(data): + """Test the effect of the `threshold_cleaning` parameter.""" + X, y = data + # with a large `threshold_cleaning`, the algorithm is equivalent to ENN + enn = EditedNearestNeighbours() + ncr = NeighbourhoodCleaningRule( + edited_nearest_neighbours=enn, n_neighbors=10, threshold_cleaning=10 ) - y_gt = np.array([1, 1, 1, 2, 2, 0, 0, 2, 1, 2]) - assert_array_equal(X_resampled, X_gt) - assert_array_equal(y_resampled, y_gt) + + enn.fit_resample(X, y) + ncr.fit_resample(X, y) + + assert_array_equal(np.sort(enn.sample_indices_), np.sort(ncr.sample_indices_)) + assert ncr.classes_to_clean_ == [] + + # set a threshold that we should consider only the class #2 + counter = Counter(y) + threshold = counter[1] / counter[0] + ncr.set_params(threshold_cleaning=threshold) + ncr.fit_resample(X, y) + + assert set(ncr.classes_to_clean_) == {2} + + # making the threshold slightly smaller to take into account class #1 + ncr.set_params(threshold_cleaning=threshold - np.finfo(np.float32).eps) + ncr.fit_resample(X, y) + + assert set(ncr.classes_to_clean_) == {1, 2} + + +def test_ncr_n_neighbors(data): + """Check the effect of the NN on the cleaning of the second phase.""" + X, y = data + + enn = EditedNearestNeighbours() + ncr = NeighbourhoodCleaningRule(edited_nearest_neighbours=enn, n_neighbors=3) + + ncr.fit_resample(X, y) + sample_indices_3_nn = ncr.sample_indices_ + + ncr.set_params(n_neighbors=10).fit_resample(X, y) + sample_indices_10_nn = ncr.sample_indices_ + + # we should have a more aggressive cleaning with n_neighbors is larger + assert len(sample_indices_3_nn) > len(sample_indices_10_nn) + + +# TODO: remove in 0.14 +@pytest.mark.parametrize("kind_sel", ["all", "mode"]) +def test_ncr_deprecate_kind_sel(data, kind_sel): + X, y = data + + with pytest.warns(FutureWarning, match="`kind_sel` is deprecated"): + NeighbourhoodCleaningRule(kind_sel=kind_sel).fit_resample(X, y) From ae10b7fea65942dcd12ae79b20a7844cb1156082 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 10 Jul 2023 12:44:18 +0200 Subject: [PATCH 2/3] iter --- examples/under-sampling/plot_comparison_under_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/under-sampling/plot_comparison_under_sampling.py b/examples/under-sampling/plot_comparison_under_sampling.py index 4f7d09cb8..e2418e304 100644 --- a/examples/under-sampling/plot_comparison_under_sampling.py +++ b/examples/under-sampling/plot_comparison_under_sampling.py @@ -264,7 +264,7 @@ def plot_decision_function(X, y, clf, ax, title=None): samplers = [ CondensedNearestNeighbour(random_state=0), OneSidedSelection(random_state=0), - NeighbourhoodCleaningRule(), + NeighbourhoodCleaningRule(n_neighbors=11), ] for ax, sampler in zip(axs, samplers): From 7b96b24ef3a19fbdfc6218c195b1fb0af67174b0 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 10 Jul 2023 13:14:48 +0200 Subject: [PATCH 3/3] iter --- doc/under_sampling.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/under_sampling.rst b/doc/under_sampling.rst index c3b8ed4f6..9f2795430 100644 --- a/doc/under_sampling.rst +++ b/doc/under_sampling.rst @@ -353,10 +353,10 @@ union of samples to be rejected between the :class:`EditedNearestNeighbours` and the output a 3 nearest neighbors classifier. The class can be used as:: >>> from imblearn.under_sampling import NeighbourhoodCleaningRule - >>> ncr = NeighbourhoodCleaningRule() + >>> ncr = NeighbourhoodCleaningRule(n_neighbors=11) >>> X_resampled, y_resampled = ncr.fit_resample(X, y) >>> print(sorted(Counter(y_resampled).items())) - [(0, 64), (1, 234), (2, 4666)] + [(0, 64), (1, 193), (2, 4535)] .. image:: ./auto_examples/under-sampling/images/sphx_glr_plot_comparison_under_sampling_005.png :target: ./auto_examples/under-sampling/plot_comparison_under_sampling.html