Skip to content

Commit

Permalink
MAINT compatibility with sklearn 1.4 (#1045)
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored Oct 23, 2023
1 parent 42a7909 commit 0a659af
Show file tree
Hide file tree
Showing 7 changed files with 2,546 additions and 109 deletions.
8 changes: 8 additions & 0 deletions doc/whats_new/v0.12.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ Bug fixes
the number of samples in the minority class.
:pr:`1012` by :user:`Guillaume Lemaitre <glemaitre>`.

Compatibility
.............

- :class:`~imblearn.ensemble.BalancedRandomForestClassifier` now support missing values
and monotonic constraints if scikit-learn >= 1.4 is installed.
- :class:`~imblearn.pipeline.Pipeline` support metadata routing if scikit-learn >= 1.4
is installed.

Deprecations
............

Expand Down
1 change: 1 addition & 0 deletions imblearn/ensemble/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,5 @@ def check(self):
list,
None,
],
"monotonic_cst": ["array-like", None],
}
163 changes: 115 additions & 48 deletions imblearn/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _local_parallel_build_trees(
class_weight=None,
n_samples_bootstrap=None,
forest=None,
missing_values_in_feature_mask=None,
):
# resample before to fit the tree
X_resampled, y_resampled = sampler.fit_resample(X, y)
Expand All @@ -68,33 +69,34 @@ def _local_parallel_build_trees(
if _get_n_samples_bootstrap is not None:
n_samples_bootstrap = min(n_samples_bootstrap, X_resampled.shape[0])

if sklearn_version >= parse_version("1.1"):
tree = _parallel_build_trees(
tree,
bootstrap,
X_resampled,
y_resampled,
sample_weight,
tree_idx,
n_trees,
verbose=verbose,
class_weight=class_weight,
n_samples_bootstrap=n_samples_bootstrap,
)
params_parallel_build_trees = {
"tree": tree,
"X": X_resampled,
"y": y_resampled,
"sample_weight": sample_weight,
"tree_idx": tree_idx,
"n_trees": n_trees,
"verbose": verbose,
"class_weight": class_weight,
"n_samples_bootstrap": n_samples_bootstrap,
}

if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
# TODO: remove when the minimum supported version of scikit-learn will be 1.4
# support for missing values
params_parallel_build_trees[
"missing_values_in_feature_mask"
] = missing_values_in_feature_mask

# TODO: remove when the minimum supported version of scikit-learn will be 1.1
# change of signature in scikit-learn 1.1
if parse_version(sklearn_version.base_version) >= parse_version("1.1"):
params_parallel_build_trees["bootstrap"] = bootstrap
else:
# TODO: remove when the minimum version of scikit-learn supported is 1.1
tree = _parallel_build_trees(
tree,
forest,
X_resampled,
y_resampled,
sample_weight,
tree_idx,
n_trees,
verbose=verbose,
class_weight=class_weight,
n_samples_bootstrap=n_samples_bootstrap,
)
params_parallel_build_trees["forest"] = forest

tree = _parallel_build_trees(**params_parallel_build_trees)

return sampler, tree


Expand Down Expand Up @@ -305,6 +307,25 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassif
.. versionadded:: 0.6
Added in `scikit-learn` in 0.22
monotonic_cst : array-like of int of shape (n_features), default=None
Indicates the monotonicity constraint to enforce on each feature.
- 1: monotonic increase
- 0: no constraint
- -1: monotonic decrease
If monotonic_cst is None, no constraints are applied.
Monotonicity constraints are not supported for:
- multiclass classifications (i.e. when `n_classes > 2`),
- multioutput classifications (i.e. when `n_outputs_ > 1`),
- classifications trained on data with missing values.
The constraints hold over the probability of the positive class.
.. versionadded:: 0.12
Only supported when scikit-learn >= 1.4 is installed. Otherwise, a
`ValueError` is raised.
Attributes
----------
estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier` instance
Expand Down Expand Up @@ -415,7 +436,7 @@ class labels (multi-output problem).
"""

# make a deepcopy to not modify the original dictionary
if sklearn_version >= parse_version("1.3"):
if sklearn_version >= parse_version("1.4"):
_parameter_constraints = deepcopy(RandomForestClassifier._parameter_constraints)
else:
_parameter_constraints = deepcopy(
Expand Down Expand Up @@ -459,27 +480,42 @@ def __init__(
class_weight=None,
ccp_alpha=0.0,
max_samples=None,
monotonic_cst=None,
):
super().__init__(
criterion=criterion,
max_depth=max_depth,
n_estimators=n_estimators,
bootstrap=bootstrap,
oob_score=oob_score,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
class_weight=class_weight,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf,
max_features=max_features,
max_leaf_nodes=max_leaf_nodes,
min_impurity_decrease=min_impurity_decrease,
ccp_alpha=ccp_alpha,
max_samples=max_samples,
)
params_random_forest = {
"criterion": criterion,
"max_depth": max_depth,
"n_estimators": n_estimators,
"bootstrap": bootstrap,
"oob_score": oob_score,
"n_jobs": n_jobs,
"random_state": random_state,
"verbose": verbose,
"warm_start": warm_start,
"class_weight": class_weight,
"min_samples_split": min_samples_split,
"min_samples_leaf": min_samples_leaf,
"min_weight_fraction_leaf": min_weight_fraction_leaf,
"max_features": max_features,
"max_leaf_nodes": max_leaf_nodes,
"min_impurity_decrease": min_impurity_decrease,
"ccp_alpha": ccp_alpha,
"max_samples": max_samples,
}
# TODO: remove when the minimum supported version of scikit-learn will be 1.4
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
# use scikit-learn support for monotonic constraints
params_random_forest["monotonic_cst"] = monotonic_cst
else:
if monotonic_cst is not None:
raise ValueError(
"Monotonic constraints are not supported for scikit-learn "
"version < 1.4."
)
# create an attribute for compatibility with other scikit-learn tools such
# as HTML representation.
self.monotonic_cst = monotonic_cst
super().__init__(**params_random_forest)

self.sampling_strategy = sampling_strategy
self.replacement = replacement
Expand Down Expand Up @@ -591,11 +627,41 @@ def fit(self, X, y, sample_weight=None):
# Validate or convert input data
if issparse(y):
raise ValueError("sparse multilabel-indicator for y is not supported.")

# TODO: remove when the minimum supported version of scipy will be 1.4
# Support for missing values
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
force_all_finite = False
else:
force_all_finite = True

X, y = self._validate_data(
X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE
X,
y,
multi_output=True,
accept_sparse="csc",
dtype=DTYPE,
force_all_finite=force_all_finite,
)

# TODO: remove when the minimum supported version of scikit-learn will be 1.4
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
# _compute_missing_values_in_feature_mask checks if X has missing values and
# will raise an error if the underlying tree base estimator can't handle
# missing values. Only the criterion is required to determine if the tree
# supports missing values.
estimator = type(self.estimator)(criterion=self.criterion)
missing_values_in_feature_mask = (
estimator._compute_missing_values_in_feature_mask(
X, estimator_name=self.__class__.__name__
)
)
else:
missing_values_in_feature_mask = None

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)

self._n_features = X.shape[1]

if issparse(X):
Expand Down Expand Up @@ -713,6 +779,7 @@ def fit(self, X, y, sample_weight=None):
class_weight=self.class_weight,
n_samples_bootstrap=n_samples_bootstrap,
forest=self,
missing_values_in_feature_mask=missing_values_in_feature_mask,
)
for i, (s, t) in enumerate(zip(samplers, trees))
)
Expand Down
97 changes: 97 additions & 0 deletions imblearn/ensemble/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,100 @@ def test_balanced_random_forest_change_behaviour(imbalanced_dataset):
)
with pytest.warns(FutureWarning, match="The default of `bootstrap`"):
estimator.fit(*imbalanced_dataset)


@pytest.mark.skipif(
parse_version(sklearn_version.base_version) < parse_version("1.4"),
reason="scikit-learn should be >= 1.4",
)
def test_missing_values_is_resilient():
"""Check that forest can deal with missing values and has decent performance."""

rng = np.random.RandomState(0)
n_samples, n_features = 1000, 10
X, y = make_classification(
n_samples=n_samples, n_features=n_features, random_state=rng
)

# Create dataset with missing values
X_missing = X.copy()
X_missing[rng.choice([False, True], size=X.shape, p=[0.95, 0.05])] = np.nan
assert np.isnan(X_missing).any()

X_missing_train, X_missing_test, y_train, y_test = train_test_split(
X_missing, y, random_state=0
)

# Train forest with missing values
forest_with_missing = BalancedRandomForestClassifier(
sampling_strategy="all",
replacement=True,
bootstrap=False,
random_state=rng,
n_estimators=50,
)
forest_with_missing.fit(X_missing_train, y_train)
score_with_missing = forest_with_missing.score(X_missing_test, y_test)

# Train forest without missing values
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
forest = BalancedRandomForestClassifier(
sampling_strategy="all",
replacement=True,
bootstrap=False,
random_state=rng,
n_estimators=50,
)
forest.fit(X_train, y_train)
score_without_missing = forest.score(X_test, y_test)

# Score is still 80 percent of the forest's score that had no missing values
assert score_with_missing >= 0.80 * score_without_missing


@pytest.mark.skipif(
parse_version(sklearn_version.base_version) < parse_version("1.4"),
reason="scikit-learn should be >= 1.4",
)
def test_missing_value_is_predictive():
"""Check that the forest learns when missing values are only present for
a predictive feature."""
rng = np.random.RandomState(0)
n_samples = 300

X_non_predictive = rng.standard_normal(size=(n_samples, 10))
y = rng.randint(0, high=2, size=n_samples)

# Create a predictive feature using `y` and with some noise
X_random_mask = rng.choice([False, True], size=n_samples, p=[0.95, 0.05])
y_mask = y.astype(bool)
y_mask[X_random_mask] = ~y_mask[X_random_mask]

predictive_feature = rng.standard_normal(size=n_samples)
predictive_feature[y_mask] = np.nan
assert np.isnan(predictive_feature).any()

X_predictive = X_non_predictive.copy()
X_predictive[:, 5] = predictive_feature

(
X_predictive_train,
X_predictive_test,
X_non_predictive_train,
X_non_predictive_test,
y_train,
y_test,
) = train_test_split(X_predictive, X_non_predictive, y, random_state=0)
forest_predictive = BalancedRandomForestClassifier(
sampling_strategy="all", replacement=True, bootstrap=False, random_state=0
).fit(X_predictive_train, y_train)
forest_non_predictive = BalancedRandomForestClassifier(
sampling_strategy="all", replacement=True, bootstrap=False, random_state=0
).fit(X_non_predictive_train, y_train)

predictive_test_score = forest_predictive.score(X_predictive_test, y_test)

assert predictive_test_score >= 0.75
assert predictive_test_score >= forest_non_predictive.score(
X_non_predictive_test, y_test
)
Loading

0 comments on commit 0a659af

Please sign in to comment.