6 changes: 5 additions & 1 deletion doc/whats_new/v0.6.rst
@@ -32,7 +32,11 @@ Maintenance
:pr:`617` by :user:`Guillaume Lemaitre <glemaitre>`.

- Synchronize :mod:`imblearn.pipeline` with :mod:`sklearn.pipeline`.
-  :pr:`617` by :user:`Guillaume Lemaitre <glemaitre>`.
+  :pr:`620` by :user:`Guillaume Lemaitre <glemaitre>`.

- Synchronize :class:`imblearn.ensemble.BalancedRandomForestClassifier` with its
  scikit-learn counterpart and add the parameters `max_samples` and `ccp_alpha`.
:pr:`621` by :user:`Guillaume Lemaitre <glemaitre>`.
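
  A minimal usage sketch of the two new parameters described in this entry
  (dataset and parameter values are illustrative only, not taken from the PR)::

    from sklearn.datasets import make_classification

    from imblearn.ensemble import BalancedRandomForestClassifier

    # Toy imbalanced dataset: roughly 90% majority / 10% minority class.
    X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=0)

    clf = BalancedRandomForestClassifier(
        n_estimators=50,
        max_samples=0.8,  # each tree's bootstrap draws 80% of the samples
        ccp_alpha=0.01,   # cost-complexity pruning of each tree
        random_state=0,
    )
    clf.fit(X, y)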

Deprecation
...........
36 changes: 35 additions & 1 deletion imblearn/ensemble/_forest.py
@@ -17,6 +17,7 @@
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble._base import _set_random_states
from sklearn.ensemble._forest import _get_n_samples_bootstrap
from sklearn.ensemble._forest import _parallel_build_trees
from sklearn.exceptions import DataConversionWarning
from sklearn.tree import DecisionTreeClassifier
@@ -44,6 +45,7 @@ def _local_parallel_build_trees(
n_trees,
verbose=0,
class_weight=None,
n_samples_bootstrap=None
):
# resample before to fit the tree
X_resampled, y_resampled = sampler.fit_resample(X, y)
@@ -59,7 +61,7 @@
n_trees,
verbose=verbose,
class_weight=class_weight,
-        n_samples_bootstrap=X_resampled.shape[0],
+        n_samples_bootstrap=n_samples_bootstrap,
)
return sampler, tree

@@ -195,6 +197,27 @@ class BalancedRandomForestClassifier(RandomForestClassifier):
Note that these weights will be multiplied with sample_weight (passed
through the fit method) if sample_weight is specified.


ccp_alpha : non-negative float, optional (default=0.0)
Complexity parameter used for Minimal Cost-Complexity Pruning. The
subtree with the largest cost complexity that is smaller than
``ccp_alpha`` will be chosen. By default, no pruning is performed. See
:ref:`minimal_cost_complexity_pruning` for details.

.. versionadded:: 0.6
   Added in `scikit-learn` in 0.22

max_samples : int or float, default=None
If bootstrap is True, the number of samples to draw from X
to train each base estimator.
- If None (default), then draw `X.shape[0]` samples.
- If int, then draw `max_samples` samples.
- If float, then draw `max_samples * X.shape[0]` samples. Thus,
`max_samples` should be in the interval `(0, 1)`.

.. versionadded:: 0.6
   Added in `scikit-learn` in 0.22

Attributes
----------
estimators_ : list of DecisionTreeClassifier
@@ -281,6 +304,8 @@ def __init__(
verbose=0,
warm_start=False,
class_weight=None,
ccp_alpha=0.0,
max_samples=None,
):
super().__init__(
criterion=criterion,
@@ -299,6 +324,8 @@
max_features=max_features,
max_leaf_nodes=max_leaf_nodes,
min_impurity_decrease=min_impurity_decrease,
ccp_alpha=ccp_alpha,
max_samples=max_samples,
)

self.sampling_strategy = sampling_strategy
@@ -414,6 +441,12 @@ def fit(self, X, y, sample_weight=None):
else:
sample_weight = expanded_class_weight

# Get bootstrap sample size
n_samples_bootstrap = _get_n_samples_bootstrap(
n_samples=X.shape[0],
max_samples=self.max_samples
)

# Check parameters
self._validate_estimator()

@@ -479,6 +512,7 @@
len(trees),
verbose=self.verbose,
class_weight=self.class_weight,
n_samples_bootstrap=n_samples_bootstrap,
)
for i, (s, t) in enumerate(zip(samplers, trees))
)
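
The behavioural core of the diff above: the per-tree bootstrap size is no
longer hard-coded to `X_resampled.shape[0]` but derived once in `fit` from
`max_samples` via scikit-learn's private helper `_get_n_samples_bootstrap`,
then threaded through `_local_parallel_build_trees` to each tree. A rough
sketch of what that helper computes, inferred from the `max_samples`
docstring above (the real implementation in `sklearn.ensemble._forest` also
validates types and bounds more thoroughly):

import numbers


def n_samples_bootstrap_sketch(n_samples, max_samples):
    # None (the default): draw as many samples as the training set holds.
    if max_samples is None:
        return n_samples
    # int: draw exactly `max_samples` samples.
    if isinstance(max_samples, numbers.Integral):
        return max_samples
    # float: draw a fraction of the training set; must lie in (0, 1).
    if not 0.0 < max_samples < 1.0:
        raise ValueError("`max_samples` as a float must be in (0, 1)")
    return int(round(n_samples * max_samples))


# e.g. n_samples_bootstrap_sketch(10000, 0.5) -> 5000
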
42 changes: 42 additions & 0 deletions imblearn/ensemble/tests/test_forest.py
@@ -134,3 +134,45 @@ def test_balanced_random_forest_grid_search(imbalanced_dataset):
brf, {"n_estimators": (1, 2), "max_depth": (1, 2)}, cv=3
)
grid.fit(*imbalanced_dataset)


def test_little_tree_with_small_max_samples():
rng = np.random.RandomState(1)

X = rng.randn(10000, 2)
y = rng.randn(10000) > 0

# First fit with no restriction on max samples
est1 = BalancedRandomForestClassifier(
n_estimators=1,
random_state=rng,
max_samples=None,
)

# Second fit with max samples restricted to just 2
est2 = BalancedRandomForestClassifier(
n_estimators=1,
random_state=rng,
max_samples=2,
)

est1.fit(X, y)
est2.fit(X, y)

tree1 = est1.estimators_[0].tree_
tree2 = est2.estimators_[0].tree_

msg = "Tree without `max_samples` restriction should have more nodes"
assert tree1.node_count > tree2.node_count, msg


def test_balanced_random_forest_pruning(imbalanced_dataset):
brf = BalancedRandomForestClassifier()
brf.fit(*imbalanced_dataset)
n_nodes_no_pruning = brf.estimators_[0].tree_.node_count

brf_pruned = BalancedRandomForestClassifier(ccp_alpha=0.015)
brf_pruned.fit(*imbalanced_dataset)
n_nodes_pruning = brf_pruned.estimators_[0].tree_.node_count

assert n_nodes_no_pruning > n_nodes_pruning
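
Beyond these node-count smoke tests, the two new parameters can also be tuned
jointly. A hedged sketch using a grid search, mirroring
`test_balanced_random_forest_grid_search` earlier in this file (the grid
values are arbitrary, not taken from the PR):

from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV

from imblearn.ensemble import BalancedRandomForestClassifier

# Imbalanced toy problem: roughly 90% majority class.
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=42)

grid = GridSearchCV(
    BalancedRandomForestClassifier(n_estimators=10, random_state=0),
    {"ccp_alpha": [0.0, 0.01, 0.05], "max_samples": [None, 0.5, 0.8]},
    cv=3,
)
grid.fit(X, y)
print(grid.best_params_)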