From 7e43cfd330e6b11bf45a083e201220a75f901468 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 8 Jun 2019 00:38:50 +0200 Subject: [PATCH 1/2] MAINT use sklearn estimator tag in our test suite --- imblearn/base.py | 5 +++ .../over_sampling/_random_over_sampler.py | 6 +++ .../_cluster_centroids.py | 3 ++ .../_condensed_nearest_neighbour.py | 3 ++ .../_edited_nearest_neighbours.py | 9 +++++ .../_instance_hardness_threshold.py | 3 ++ .../_prototype_selection/_nearmiss.py | 3 ++ .../_neighbourhood_cleaning_rule.py | 3 ++ .../_one_sided_selection.py | 3 ++ .../_random_under_sampler.py | 6 +++ .../_prototype_selection/_tomek_links.py | 3 ++ imblearn/utils/estimator_checks.py | 38 +++---------------- 12 files changed, 52 insertions(+), 33 deletions(-) diff --git a/imblearn/base.py b/imblearn/base.py index 0ddc9386a..07945b0d1 100644 --- a/imblearn/base.py +++ b/imblearn/base.py @@ -12,6 +12,7 @@ import numpy as np from sklearn.base import BaseEstimator +from sklearn.base import _DEFAULT_TAGS from sklearn.preprocessing import label_binarize from sklearn.utils import check_X_y from sklearn.utils.multiclass import check_classification_targets @@ -19,6 +20,10 @@ from .utils import check_sampling_strategy, check_target_type from .utils.deprecation import deprecate_parameter +_DEFAULT_TAGS.update( + {'sample_indices': False} +) + class SamplerMixin(BaseEstimator, metaclass=ABCMeta): """Mixin class for samplers with abstract method. diff --git a/imblearn/over_sampling/_random_over_sampler.py b/imblearn/over_sampling/_random_over_sampler.py index 26d350776..2f1fbfdc7 100644 --- a/imblearn/over_sampling/_random_over_sampler.py +++ b/imblearn/over_sampling/_random_over_sampler.py @@ -119,3 +119,9 @@ def _fit_resample(self, X, y): safe_indexing(y, sample_indices), sample_indices) return (safe_indexing(X, sample_indices), safe_indexing(y, sample_indices)) + + def _more_tags(self): + # TODO: remove the str tag once the following PR is merged: + # https://github.com/scikit-learn/scikit-learn/pull/14043 + return {'X_types': ['2darray', 'str', 'string'], + 'sample_indices': True} diff --git a/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py b/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py index ffeb0127c..1588afc07 100644 --- a/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py +++ b/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py @@ -170,3 +170,6 @@ def _fit_resample(self, X, y): y_resampled = np.hstack(y_resampled) return X_resampled, np.array(y_resampled, dtype=y.dtype) + + def _more_tags(self): + return {'sample_indices': False} diff --git a/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py b/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py index 7e22109cc..dae8c862e 100644 --- a/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py +++ b/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py @@ -220,3 +220,6 @@ def _fit_resample(self, X, y): return (safe_indexing(X, idx_under), safe_indexing(y, idx_under), idx_under) return safe_indexing(X, idx_under), safe_indexing(y, idx_under) + + def _more_tags(self): + return {'sample_indices': True} diff --git a/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py b/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py index 20ba41c4a..48cd1f94d 100644 --- a/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py +++ b/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py @@ -186,6 +186,9 @@ def _fit_resample(self, X, y): idx_under) return safe_indexing(X, idx_under), safe_indexing(y, idx_under) + def _more_tags(self): + return {'sample_indices': True} + @Substitution( sampling_strategy=BaseCleaningSampler._sampling_strategy_docstring, @@ -377,6 +380,9 @@ def _fit_resample(self, X, y): return X_resampled, y_resampled, self.sample_indices_ return X_resampled, y_resampled + def _more_tags(self): + return {'sample_indices': True} + @Substitution( sampling_strategy=BaseCleaningSampler._sampling_strategy_docstring, @@ -564,3 +570,6 @@ def _fit_resample(self, X, y): if self.return_indices: return X_resampled, y_resampled, self.sample_indices_ return X_resampled, y_resampled + + def _more_tags(self): + return {'sample_indices': True} diff --git a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py index 63ec28a1b..343c6a251 100644 --- a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py +++ b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py @@ -187,3 +187,6 @@ def _fit_resample(self, X, y): return (safe_indexing(X, idx_under), safe_indexing(y, idx_under), idx_under) return safe_indexing(X, idx_under), safe_indexing(y, idx_under) + + def _more_tags(self): + return {'sample_indices': True} diff --git a/imblearn/under_sampling/_prototype_selection/_nearmiss.py b/imblearn/under_sampling/_prototype_selection/_nearmiss.py index d5f4d228d..1786352ad 100644 --- a/imblearn/under_sampling/_prototype_selection/_nearmiss.py +++ b/imblearn/under_sampling/_prototype_selection/_nearmiss.py @@ -293,3 +293,6 @@ def _fit_resample(self, X, y): return (safe_indexing(X, idx_under), safe_indexing(y, idx_under), idx_under) return safe_indexing(X, idx_under), safe_indexing(y, idx_under) + + def _more_tags(self): + return {'sample_indices': True} diff --git a/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py b/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py index 93ca7df5a..28731c5dd 100644 --- a/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py +++ b/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py @@ -204,3 +204,6 @@ def _fit_resample(self, X, y): self.sample_indices_) return (safe_indexing(X, self.sample_indices_), safe_indexing(y, self.sample_indices_)) + + def _more_tags(self): + return {'sample_indices': True} diff --git a/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py b/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py index 262bea6d9..fede527d1 100644 --- a/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py +++ b/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py @@ -189,3 +189,6 @@ def _fit_resample(self, X, y): if self.return_indices: return (X_cleaned, y_cleaned, self.sample_indices_) return X_cleaned, y_cleaned + + def _more_tags(self): + return {'sample_indices': True} diff --git a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py index 41237969e..0394c8625 100644 --- a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py +++ b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py @@ -135,3 +135,9 @@ def _fit_resample(self, X, y): return (safe_indexing(X, idx_under), safe_indexing(y, idx_under), idx_under) return safe_indexing(X, idx_under), safe_indexing(y, idx_under) + + def _more_tags(self): + # TODO: remove the str tag once the following PR is merged: + # https://github.com/scikit-learn/scikit-learn/pull/14043 + return {'X_types': ['2darray', 'str', 'string'], + 'sample_indices': True} diff --git a/imblearn/under_sampling/_prototype_selection/_tomek_links.py b/imblearn/under_sampling/_prototype_selection/_tomek_links.py index 18f986197..50d43217b 100644 --- a/imblearn/under_sampling/_prototype_selection/_tomek_links.py +++ b/imblearn/under_sampling/_prototype_selection/_tomek_links.py @@ -166,3 +166,6 @@ def _fit_resample(self, X, y): self.sample_indices_) return (safe_indexing(X, self.sample_indices_), safe_indexing(y, self.sample_indices_)) + + def _more_tags(self): + return {'sample_indices': True} diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index b564f2f95..c7bfaa4c9 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -27,6 +27,10 @@ from sklearn.utils.testing import set_random_state from sklearn.utils.multiclass import type_of_target +# import the _safe_tags from sklearn and then the updated _DEFAULT_TAG +from sklearn.utils.estimator_checks import _safe_tags +from imblearn.base import _DEFAULT_TAGS + from imblearn.over_sampling.base import BaseOverSampler from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler from imblearn.ensemble.base import BaseEnsembleSampler @@ -34,40 +38,12 @@ from imblearn.under_sampling import NearMiss, ClusterCentroids DONT_SUPPORT_RATIO = ['SVMSMOTE', 'BorderlineSMOTE'] -SUPPORT_STRING = ['RandomUnderSampler', 'RandomOverSampler'] -HAVE_SAMPLE_INDICES = [ - 'RandomOverSampler', 'RandomUnderSampler', 'InstanceHardnessThreshold', - 'NearMiss', 'TomekLinks', 'EditedNearestNeighbours', - 'RepeatedEditedNearestNeighbours', 'AllKNN', 'OneSidedSelection', - 'CondensedNearestNeighbour', 'NeighbourhoodCleaningRule'] # FIXME: remove in 0.6 DONT_HAVE_RANDOM_STATE = ('NearMiss', 'EditedNearestNeighbours', 'RepeatedEditedNearestNeighbours', 'AllKNN', 'NeighbourhoodCleaningRule', 'TomekLinks') -def monkey_patch_check_dtype_object(name, estimator_orig): - # check that estimators treat dtype object as numeric if possible - rng = np.random.RandomState(0) - X = rng.rand(40, 10).astype(object) - y = np.array([0] * 10 + [1] * 30, dtype=np.int) - estimator = clone(estimator_orig) - estimator.fit(X, y) - - try: - estimator.fit(X, y.astype(object)) - except Exception as e: - if "Unknown label type" not in str(e): - raise - - if name not in SUPPORT_STRING: - X[0, 0] = {'foo': 'bar'} - msg = "argument must be a string or a number" - assert_raises_regex(TypeError, msg, estimator.fit, X, y) - else: - estimator.fit(X, y) - - def _yield_sampler_checks(name, Estimator): yield check_target_type yield check_samplers_one_label @@ -106,10 +82,6 @@ def check_estimator(Estimator, run_sampler_tests=True): Will run or not the samplers tests. """ name = Estimator.__name__ - # monkey patch check_dtype_object for the sampler allowing strings - import sklearn.utils.estimator_checks - sklearn.utils.estimator_checks.check_dtype_object = \ - monkey_patch_check_dtype_object # scikit-learn common tests sklearn_check_estimator(Estimator) check_parameters_default_constructible(name, Estimator) @@ -369,7 +341,7 @@ def check_samplers_sample_indices(name, Sampler): weights=[0.2, 0.3, 0.5], random_state=0) sampler = Sampler() sampler.fit_resample(X, y) - if name in HAVE_SAMPLE_INDICES: + if _safe_tags(sampler, 'sample_indices'): assert hasattr(sampler, 'sample_indices_') else: assert not hasattr(sampler, 'sample_indices_') From 4e506472d690c6225eb6c98386dd3718600fe8c5 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 11 Jun 2019 21:19:08 +0200 Subject: [PATCH 2/2] iter --- imblearn/base.py | 5 ----- imblearn/utils/estimator_checks.py | 9 +++------ 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/imblearn/base.py b/imblearn/base.py index 07945b0d1..0ddc9386a 100644 --- a/imblearn/base.py +++ b/imblearn/base.py @@ -12,7 +12,6 @@ import numpy as np from sklearn.base import BaseEstimator -from sklearn.base import _DEFAULT_TAGS from sklearn.preprocessing import label_binarize from sklearn.utils import check_X_y from sklearn.utils.multiclass import check_classification_targets @@ -20,10 +19,6 @@ from .utils import check_sampling_strategy, check_target_type from .utils.deprecation import deprecate_parameter -_DEFAULT_TAGS.update( - {'sample_indices': False} -) - class SamplerMixin(BaseEstimator, metaclass=ABCMeta): """Mixin class for samplers with abstract method. diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index c7bfaa4c9..72960c9c8 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -27,10 +27,6 @@ from sklearn.utils.testing import set_random_state from sklearn.utils.multiclass import type_of_target -# import the _safe_tags from sklearn and then the updated _DEFAULT_TAG -from sklearn.utils.estimator_checks import _safe_tags -from imblearn.base import _DEFAULT_TAGS - from imblearn.over_sampling.base import BaseOverSampler from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler from imblearn.ensemble.base import BaseEnsembleSampler @@ -341,7 +337,8 @@ def check_samplers_sample_indices(name, Sampler): weights=[0.2, 0.3, 0.5], random_state=0) sampler = Sampler() sampler.fit_resample(X, y) - if _safe_tags(sampler, 'sample_indices'): - assert hasattr(sampler, 'sample_indices_') + sample_indices = sampler._get_tags().get('sample_indices', None) + if sample_indices: + assert hasattr(sampler, 'sample_indices_') is sample_indices else: assert not hasattr(sampler, 'sample_indices_')