From 72969aded5b00d92f01bd0b3c56b89dd7e63881c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 1 Nov 2019 00:14:21 +0100 Subject: [PATCH 1/3] MNT synchronize forest with scikit-learn --- imblearn/ensemble/_forest.py | 23 +++++++++++++++++++- imblearn/ensemble/tests/test_forest.py | 30 ++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/imblearn/ensemble/_forest.py b/imblearn/ensemble/_forest.py index a3066fef5..7aae380bb 100644 --- a/imblearn/ensemble/_forest.py +++ b/imblearn/ensemble/_forest.py @@ -17,6 +17,7 @@ from sklearn.base import clone from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble._base import _set_random_states +from sklearn.ensemble._forest import _get_n_samples_bootstrap from sklearn.ensemble._forest import _parallel_build_trees from sklearn.exceptions import DataConversionWarning from sklearn.tree import DecisionTreeClassifier @@ -44,6 +45,7 @@ def _local_parallel_build_trees( n_trees, verbose=0, class_weight=None, + n_samples_bootstrap=None ): # resample before to fit the tree X_resampled, y_resampled = sampler.fit_resample(X, y) @@ -59,7 +61,7 @@ def _local_parallel_build_trees( n_trees, verbose=verbose, class_weight=class_weight, - n_samples_bootstrap=X_resampled.shape[0], + n_samples_bootstrap=n_samples_bootstrap, ) return sampler, tree @@ -195,6 +197,16 @@ class BalancedRandomForestClassifier(RandomForestClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. + max_samples : int or float, default=None + If bootstrap is True, the number of samples to draw from X + to train each base estimator. + - If None (default), then draw `X.shape[0]` samples. + - If int, then draw `max_samples` samples. + - If float, then draw `max_samples * X.shape[0]` samples. Thus, + `max_samples` should be in the interval `(0, 1)`. + + .. versionadded:: 0.22 + Attributes ---------- estimators_ : list of DecisionTreeClassifier @@ -281,6 +293,7 @@ def __init__( verbose=0, warm_start=False, class_weight=None, + max_samples=None, ): super().__init__( criterion=criterion, @@ -299,6 +312,7 @@ def __init__( max_features=max_features, max_leaf_nodes=max_leaf_nodes, min_impurity_decrease=min_impurity_decrease, + max_samples=max_samples, ) self.sampling_strategy = sampling_strategy @@ -414,6 +428,12 @@ def fit(self, X, y, sample_weight=None): else: sample_weight = expanded_class_weight + # Get bootstrap sample size + n_samples_bootstrap = _get_n_samples_bootstrap( + n_samples=X.shape[0], + max_samples=self.max_samples + ) + # Check parameters self._validate_estimator() @@ -479,6 +499,7 @@ def fit(self, X, y, sample_weight=None): len(trees), verbose=self.verbose, class_weight=self.class_weight, + n_samples_bootstrap=n_samples_bootstrap, ) for i, (s, t) in enumerate(zip(samplers, trees)) ) diff --git a/imblearn/ensemble/tests/test_forest.py b/imblearn/ensemble/tests/test_forest.py index 0f386cb1a..81f35b0e5 100644 --- a/imblearn/ensemble/tests/test_forest.py +++ b/imblearn/ensemble/tests/test_forest.py @@ -134,3 +134,33 @@ def test_balanced_random_forest_grid_search(imbalanced_dataset): brf, {"n_estimators": (1, 2), "max_depth": (1, 2)}, cv=3 ) grid.fit(*imbalanced_dataset) + + +def test_little_tree_with_small_max_samples(): + rng = np.random.RandomState(1) + + X = rng.randn(10000, 2) + y = rng.randn(10000) > 0 + + # First fit with no restriction on max samples + est1 = BalancedRandomForestClassifier( + n_estimators=1, + random_state=rng, + max_samples=None, + ) + + # Second fit with max samples restricted to just 2 + est2 = BalancedRandomForestClassifier( + n_estimators=1, + random_state=rng, + max_samples=2, + ) + + est1.fit(X, y) + est2.fit(X, y) + + tree1 = est1.estimators_[0].tree_ + tree2 = est2.estimators_[0].tree_ + + msg = "Tree without `max_samples` restriction should have more nodes" + assert tree1.node_count > tree2.node_count, msg From 4d9cb8a1a6d9fcd04b8423d1e726fe5c70a95652 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 1 Nov 2019 00:23:29 +0100 Subject: [PATCH 2/3] add ccp_alpha --- imblearn/ensemble/_forest.py | 13 +++++++++++++ imblearn/ensemble/tests/test_forest.py | 12 ++++++++++++ 2 files changed, 25 insertions(+) diff --git a/imblearn/ensemble/_forest.py b/imblearn/ensemble/_forest.py index 7aae380bb..68769d12e 100644 --- a/imblearn/ensemble/_forest.py +++ b/imblearn/ensemble/_forest.py @@ -197,6 +197,16 @@ class BalancedRandomForestClassifier(RandomForestClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. + + ccp_alpha : non-negative float, optional (default=0.0) + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + Added in `scikit-learn` in 0.22 + max_samples : int or float, default=None If bootstrap is True, the number of samples to draw from X to train each base estimator. @@ -206,6 +216,7 @@ class BalancedRandomForestClassifier(RandomForestClassifier): `max_samples` should be in the interval `(0, 1)`. .. versionadded:: 0.22 + Added in `scikit-learn` in 0.22 Attributes ---------- @@ -293,6 +304,7 @@ def __init__( verbose=0, warm_start=False, class_weight=None, + ccp_alpha=0.0, max_samples=None, ): super().__init__( @@ -312,6 +324,7 @@ def __init__( max_features=max_features, max_leaf_nodes=max_leaf_nodes, min_impurity_decrease=min_impurity_decrease, + ccp_alpha=ccp_alpha, max_samples=max_samples, ) diff --git a/imblearn/ensemble/tests/test_forest.py b/imblearn/ensemble/tests/test_forest.py index 81f35b0e5..0c8c615f8 100644 --- a/imblearn/ensemble/tests/test_forest.py +++ b/imblearn/ensemble/tests/test_forest.py @@ -164,3 +164,15 @@ def test_little_tree_with_small_max_samples(): msg = "Tree without `max_samples` restriction should have more nodes" assert tree1.node_count > tree2.node_count, msg + + +def test_balanced_random_forest_pruning(imbalanced_dataset): + brf = BalancedRandomForestClassifier() + brf.fit(*imbalanced_dataset) + n_nodes_no_pruning = brf.estimators_[0].tree_.node_count + + brf_pruned = BalancedRandomForestClassifier(ccp_alpha=0.015) + brf_pruned.fit(*imbalanced_dataset) + n_nodes_pruning = brf_pruned.estimators_[0].tree_.node_count + + assert n_nodes_no_pruning > n_nodes_pruning From d51c457de5a29456b73e6152e9bc99245c3c37cf Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 1 Nov 2019 00:26:42 +0100 Subject: [PATCH 3/3] DOC whats new --- doc/whats_new/v0.6.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v0.6.rst b/doc/whats_new/v0.6.rst index 869dff77e..8d846faaf 100644 --- a/doc/whats_new/v0.6.rst +++ b/doc/whats_new/v0.6.rst @@ -32,7 +32,11 @@ Maintenance :pr:`617` by :user:`Guillaume Lemaitre `. - Synchronize :mod:`imblearn.pipeline` with :mod:`sklearn.pipeline`. - :pr:`617` by :user:`Guillaume Lemaitre `. + :pr:`620` by :user:`Guillaume Lemaitre `. + +- Synchronize :class:`imblearn.ensemble.BalancedRandomForestClassifier` and add + parameters `max_samples` and `ccp_alpha`. + :pr:`621` by :user:`Guillaume Lemaitre `. Deprecation ...........