Merged (29 commits)
6b2ca6e  Draft of the metadata routing for bagging (adam2392, Feb 15, 2024)
6fc50fc  Adding metadata routing (adam2392, Feb 15, 2024)
a14913a  Raise for params (adam2392, Feb 15, 2024)
351f2b2  Fix lint (adam2392, Feb 15, 2024)
52e99a9  Address final test (adam2392, Feb 16, 2024)
9558c18  Fix lint' -s (adam2392, Feb 16, 2024)
fc77399  Fix iforest and print statements (adam2392, Feb 16, 2024)
e7a392c  Fix lint (adam2392, Feb 16, 2024)
f1da30b  Merge branch 'main' into baggingmeta (adam2392, Feb 16, 2024)
9d23eb2  Fix bagging and isolation forest (adam2392, Feb 16, 2024)
57c0881  Merge branch 'baggingmeta' of https://github.com/adam2392/scikit-lear… (adam2392, Feb 16, 2024)
4ad22f3  Merge branch 'main' into baggingmeta (adam2392, Feb 17, 2024)
f0a93e5  Fix test (adam2392, Feb 17, 2024)
294c8f1  Fix lint (adam2392, Feb 17, 2024)
5c4f968  Merge branch 'baggingmeta' of https://github.com/adam2392/scikit-lear… (adam2392, Feb 17, 2024)
a675a65  Apply suggestions from code review (adam2392, Feb 19, 2024)
134a60f  Address adrin comments (adam2392, Feb 19, 2024)
9dbac9b  Merge branch 'main' into baggingmeta (adam2392, Feb 19, 2024)
2ffd7b7  Address adrin's comment (adam2392, Feb 19, 2024)
e2dbc63  clean up test to make pass (adam2392, Feb 19, 2024)
0062566  Fix lint (adam2392, Feb 20, 2024)
7dd4fb5  Cleanup (adam2392, Feb 20, 2024)
c6bd89b  Merge branch 'main' into baggingmeta (adam2392, Feb 20, 2024)
926e5ee  Remove diff (adam2392, Feb 20, 2024)
312be71  Remove diff (adam2392, Feb 20, 2024)
5dc9a3d  Update v1.5.rst (adam2392, Feb 21, 2024)
3a3d444  Update v1.5.rst (adam2392, Feb 21, 2024)
01a64e4  Update _bagging.py (adam2392, Feb 21, 2024)
debe632  Merge branch 'main' into baggingmeta (adam2392, Feb 21, 2024)
4 changes: 2 additions & 2 deletions doc/metadata_routing.rst
@@ -274,6 +274,8 @@ Meta-estimators and functions supporting metadata routing:

- :class:`sklearn.calibration.CalibratedClassifierCV`
- :class:`sklearn.compose.ColumnTransformer`
- :class:`sklearn.ensemble.BaggingClassifier`
- :class:`sklearn.ensemble.BaggingRegressor`
- :class:`sklearn.feature_selection.SelectFromModel`
- :class:`sklearn.impute.IterativeImputer`
- :class:`sklearn.linear_model.ElasticNetCV`
@@ -306,8 +308,6 @@ Meta-estimators and tools not supporting metadata routing yet:
- :class:`sklearn.covariance.GraphicalLassoCV`
- :class:`sklearn.ensemble.AdaBoostClassifier`
- :class:`sklearn.ensemble.AdaBoostRegressor`
- :class:`sklearn.ensemble.BaggingClassifier`
- :class:`sklearn.ensemble.BaggingRegressor`
- :class:`sklearn.ensemble.StackingClassifier`
- :class:`sklearn.ensemble.StackingRegressor`
- :class:`sklearn.ensemble.VotingClassifier`
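With this change, :class:`~sklearn.ensemble.BaggingClassifier` and :class:`~sklearn.ensemble.BaggingRegressor` move from the unsupported list to the supported one above. A minimal sketch of the opt-in flow (the data and weights are illustrative, not taken from the PR):

```python
import numpy as np
import sklearn
from sklearn.ensemble import BaggingClassifier
from sklearn.linear_model import LogisticRegression

# Routing is opt-in: without this flag, passing metadata through a
# meta-estimator raises an error.
sklearn.set_config(enable_metadata_routing=True)

rng = np.random.RandomState(0)
X = rng.normal(size=(20, 3))
y = (X[:, 0] > 0).astype(int)

# The inner estimator must explicitly request the metadata it consumes;
# the bagging meta-estimator then routes it to each inner fit call.
inner = LogisticRegression().set_fit_request(sample_weight=True)
clf = BaggingClassifier(estimator=inner, n_estimators=5, random_state=0)
clf.fit(X, y, sample_weight=np.ones(20))
```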
5 changes: 5 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -42,6 +42,11 @@ more details.
- |Feature| :class:`impute.IterativeImputer` now supports metadata routing in
its `fit` method. :pr:`28187` by :user:`Stefanie Senger <StefanieSenger>`.

- |Feature| :class:`ensemble.BaggingClassifier` and :class:`ensemble.BaggingRegressor`
now support metadata routing. Their `fit` methods now accept ``**fit_params``,
which are routed to the `fit` methods of the underlying estimators.
:pr:`28432` by :user:`Adam Li <adam2392>` and :user:`Benjamin Bossan <BenjaminBossan>`.
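A sketch of the routed ``**fit_params`` described in this entry; the ``FlaggedClassifier`` below is a toy consumer invented for illustration, not part of the PR:

```python
import numpy as np
import sklearn
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.ensemble import BaggingClassifier

sklearn.set_config(enable_metadata_routing=True)

class FlaggedClassifier(ClassifierMixin, BaseEstimator):
    """Toy classifier that records a routed `flag` at fit time."""

    def fit(self, X, y, flag=None):
        self.flag_ = flag
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        return np.full(X.shape[0], self.classes_[0])

X = np.zeros((8, 2))
y = np.array([0, 1] * 4)

# set_fit_request is generated automatically from the signature of fit.
inner = FlaggedClassifier().set_fit_request(flag=True)
bag = BaggingClassifier(estimator=inner, n_estimators=2).fit(X, y, flag="seen")
print([est.flag_ for est in bag.estimators_])  # ['seen', 'seen']
```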

Changelog
---------
180 changes: 139 additions & 41 deletions sklearn/ensemble/_bagging.py
@@ -16,18 +16,34 @@
from ..base import ClassifierMixin, RegressorMixin, _fit_context
from ..metrics import accuracy_score, r2_score
from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
from ..utils import check_random_state, column_or_1d, indices_to_mask
from ..utils import (
Bunch,
_safe_indexing,
check_random_state,
column_or_1d,
indices_to_mask,
)
from ..utils._param_validation import HasMethods, Interval, RealNotInt
from ..utils._tags import _safe_tags
from ..utils.metadata_routing import (
_raise_for_unsupported_routing,
_RoutingNotSupportedMixin,
MetadataRouter,
MethodMapping,
_raise_for_params,
_routing_enabled,
get_routing_for_object,
process_routing,
)
from ..utils.metaestimators import available_if
from ..utils.multiclass import check_classification_targets
from ..utils.parallel import Parallel, delayed
from ..utils.random import sample_without_replacement
from ..utils.validation import _check_sample_weight, check_is_fitted, has_fit_parameter
from ..utils.validation import (
_check_method_params,
_check_sample_weight,
_deprecate_positional_args,
check_is_fitted,
has_fit_parameter,
)
from ._base import BaseEnsemble, _partition_estimators

__all__ = ["BaggingClassifier", "BaggingRegressor"]
@@ -77,11 +93,11 @@ def _parallel_build_estimators(
ensemble,
X,
y,
sample_weight,
seeds,
total_n_estimators,
verbose,
check_input,
fit_params,
):
"""Private function used to build a batch of estimators within a job."""
# Retrieve settings
@@ -90,17 +106,26 @@ def _parallel_build_estimators(
max_samples = ensemble._max_samples
bootstrap = ensemble.bootstrap
bootstrap_features = ensemble.bootstrap_features
support_sample_weight = has_fit_parameter(ensemble.estimator_, "sample_weight")
has_check_input = has_fit_parameter(ensemble.estimator_, "check_input")
requires_feature_indexing = bootstrap_features or max_features != n_features

if not support_sample_weight and sample_weight is not None:
raise ValueError("The base estimator doesn't support sample weight")

# Build estimators
estimators = []
estimators_features = []

request_or_router = get_routing_for_object(ensemble.estimator_)

# TODO(SLEP6): remove if condition for unrouted sample_weight when metadata
# routing can't be disabled.
support_sample_weight = has_fit_parameter(ensemble.estimator_, "sample_weight")
if not _routing_enabled() and (
not support_sample_weight and fit_params.get("sample_weight") is not None
):
raise ValueError(
"The base estimator doesn't support sample weight, but sample_weight is "
"passed to the fit method."
)

for i in range(n_estimators):
if verbose > 1:
print(
@@ -127,12 +152,25 @@
max_samples,
)

# Draw samples, using sample weights, and then fit
if support_sample_weight:
if sample_weight is None:
curr_sample_weight = np.ones((n_samples,))
else:
curr_sample_weight = sample_weight.copy()
fit_params_ = fit_params.copy()

# TODO(SLEP6): remove if condition for unrouted sample_weight when metadata
# routing can't be disabled.
# 1. If routing is enabled, we will check if the routing supports sample
# weight and use it if it does.
# 2. If routing is not enabled, we will check if the base
# estimator supports sample_weight and use it if it does.

# Note: Row sampling can be achieved either through setting sample_weight
# or by indexing. The former is more efficient, so it is preferred whenever
# sample_weight is accepted; otherwise we fall back to indexing.
if (
_routing_enabled() and request_or_router.consumes("fit", ("sample_weight",))
) or (not _routing_enabled() and support_sample_weight):
# Draw sub samples, using sample weights, and then fit
curr_sample_weight = _check_sample_weight(
fit_params_.pop("sample_weight", None), X
).copy()

if bootstrap:
sample_counts = np.bincount(indices, minlength=n_samples)
@@ -141,11 +179,17 @@
not_indices_mask = ~indices_to_mask(indices, n_samples)
curr_sample_weight[not_indices_mask] = 0

fit_params_["sample_weight"] = curr_sample_weight
X_ = X[:, features] if requires_feature_indexing else X
estimator_fit(X_, y, sample_weight=curr_sample_weight)
estimator_fit(X_, y, **fit_params_)
else:
X_ = X[indices][:, features] if requires_feature_indexing else X[indices]
estimator_fit(X_, y[indices])
# cannot use sample_weight, so use indexing
y_ = _safe_indexing(y, indices)
X_ = _safe_indexing(X, indices)
fit_params_ = _check_method_params(X, params=fit_params_, indices=indices)
if requires_feature_indexing:
X_ = X_[:, features]
estimator_fit(X_, y_, **fit_params_)

estimators.append(estimator)
estimators_features.append(features)
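The note above contrasts the two row-sampling strategies; a standalone sketch of why they are interchangeable for a bootstrap draw:

```python
import numpy as np

rng = np.random.RandomState(0)
n_samples = 6
indices = rng.randint(0, n_samples, n_samples)  # one bootstrap draw

# Weight-based: keep every row, weighting it by how often it was drawn.
# Rows never drawn get weight 0, exactly like curr_sample_weight above.
counts = np.bincount(indices, minlength=n_samples)

# Index-based: materialize the resampled rows instead.
X = rng.normal(size=(n_samples, 2))
X_resampled = X[indices]

# Both describe the same effective sample; the weight-based route avoids
# copying X, which is why it is preferred when weights are supported.
assert counts.sum() == X_resampled.shape[0]
```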
@@ -294,11 +338,15 @@ def __init__(
self.random_state = random_state
self.verbose = verbose

# TODO(1.7): remove `sample_weight` from the signature after deprecation
# cycle; pop it from `fit_params` before the `_raise_for_params` check and
# reinsert later, for backwards compatibility
@_deprecate_positional_args(version="1.7")
@_fit_context(
# BaseBagging.estimator is not validated yet
prefer_skip_nested_validation=False
)
def fit(self, X, y, sample_weight=None):
def fit(self, X, y, *, sample_weight=None, **fit_params):
"""Build a Bagging ensemble of estimators from the training set (X, y).

Parameters
@@ -316,12 +364,24 @@ def fit(self, X, y, sample_weight=None):
Note that this is supported only if the base estimator supports
sample weighting.

**fit_params : dict
Parameters to pass to the underlying estimators.

.. versionadded:: 1.5

Only available if `enable_metadata_routing=True`,
which can be set by using
``sklearn.set_config(enable_metadata_routing=True)``.
See :ref:`Metadata Routing User Guide <metadata_routing>` for
more details.

Returns
-------
self : object
Fitted estimator.
"""
_raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight)
_raise_for_params(fit_params, self, "fit")

# Convert data (X is required to be 2d and indexable)
X, y = self._validate_data(
X,
@@ -331,7 +391,12 @@
force_all_finite=False,
multi_output=True,
)
return self._fit(X, y, self.max_samples, sample_weight=sample_weight)

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X, dtype=None)
fit_params["sample_weight"] = sample_weight

return self._fit(X, y, max_samples=self.max_samples, **fit_params)

def _parallel_args(self):
return {}
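The ``@_deprecate_positional_args(version="1.7")`` decorator on ``fit`` keeps positional ``sample_weight`` working through the deprecation window while warning; a small sketch with toy data:

```python
import warnings
import numpy as np
from sklearn.ensemble import BaggingClassifier

X, y, w = np.zeros((4, 1)), np.array([0, 1, 0, 1]), np.ones(4)
clf = BaggingClassifier(n_estimators=2)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    clf.fit(X, y, w)  # positional sample_weight: FutureWarning until 1.7
assert any(issubclass(c.category, FutureWarning) for c in caught)

clf.fit(X, y, sample_weight=w)  # keyword form, no warning
```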
@@ -342,8 +407,8 @@ def _fit(
y,
max_samples=None,
max_depth=None,
sample_weight=None,
check_input=True,
**fit_params,
):
"""Build a Bagging ensemble of estimators from the training
set (X, y).
@@ -365,14 +430,15 @@
Override value used when constructing base estimator. Only
supported if the base estimator has a max_depth parameter.

sample_weight : array-like of shape (n_samples,), default=None
Sample weights. If None, then samples are equally weighted.
Note that this is supported only if the base estimator supports
sample weighting.

check_input : bool, default=True
Override value used when fitting base estimator. Only supported
if the base estimator has a check_input parameter for fit function.
If the meta-estimator already checks the input, set this value to
False to prevent redundant input validation.

**fit_params : dict
Parameters to pass to the :term:`fit` method of the underlying
estimator.

Returns
-------
@@ -381,16 +447,23 @@
"""
random_state = check_random_state(self.random_state)

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X, dtype=None)

# Remap output
n_samples = X.shape[0]
self._n_samples = n_samples
y = self._validate_y(y)

# Check parameters
self._validate_estimator()
self._validate_estimator(self._get_estimator())

if _routing_enabled():
routed_params = process_routing(self, "fit", **fit_params)
else:
routed_params = Bunch()
routed_params.estimator = Bunch(fit=fit_params)
if "sample_weight" in fit_params:
routed_params.estimator.fit["sample_weight"] = fit_params[
"sample_weight"
]

if max_depth is not None:
self.estimator_.max_depth = max_depth
@@ -474,11 +547,11 @@ def _fit(
self,
X,
y,
sample_weight,
seeds[starts[i] : starts[i + 1]],
total_n_estimators,
verbose=self.verbose,
check_input=check_input,
fit_params=routed_params.estimator.fit,
)
for i in range(n_jobs)
)
@@ -537,8 +610,33 @@ def estimators_samples_(self):
"""
return [sample_indices for _, sample_indices in self._get_estimators_indices()]

def get_metadata_routing(self):
"""Get metadata routing of this object.

Please check :ref:`User Guide <metadata_routing>` on how the routing
mechanism works.

.. versionadded:: 1.5

Returns
-------
routing : MetadataRouter
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
routing information.
"""
router = MetadataRouter(owner=self.__class__.__name__)
router.add(
estimator=self._get_estimator(),
method_mapping=MethodMapping().add(callee="fit", caller="fit"),
)
return router

@abstractmethod
def _get_estimator(self):
"""Resolve which estimator to return."""


class BaggingClassifier(_RoutingNotSupportedMixin, ClassifierMixin, BaseBagging):
class BaggingClassifier(ClassifierMixin, BaseBagging):
"""A Bagging classifier.

A Bagging classifier is an ensemble meta-estimator that fits base
@@ -735,9 +833,9 @@ def __init__(
verbose=verbose,
)

def _validate_estimator(self):
"""Check the estimator and set the estimator_ attribute."""
super()._validate_estimator(default=DecisionTreeClassifier())
def _get_estimator(self):
"""Resolve which estimator to return (default is DecisionTreeClassifier)"""
return self.estimator or DecisionTreeClassifier()

def _set_oob_score(self, X, y):
n_samples = y.shape[0]
@@ -970,7 +1068,7 @@ def _more_tags(self):
return {"allow_nan": _safe_tags(estimator, "allow_nan")}


class BaggingRegressor(_RoutingNotSupportedMixin, RegressorMixin, BaseBagging):
class BaggingRegressor(RegressorMixin, BaseBagging):
"""A Bagging regressor.

A Bagging regressor is an ensemble meta-estimator that fits base
@@ -1202,10 +1300,6 @@ def predict(self, X):

return y_hat

def _validate_estimator(self):
"""Check the estimator and set the estimator_ attribute."""
super()._validate_estimator(default=DecisionTreeRegressor())

def _set_oob_score(self, X, y):
n_samples = y.shape[0]

@@ -1240,3 +1334,7 @@ def _more_tags(self):
else:
estimator = self.estimator
return {"allow_nan": _safe_tags(estimator, "allow_nan")}

def _get_estimator(self):
"""Resolve which estimator to return (default is DecisionTreeClassifier)"""
return self.estimator or DecisionTreeRegressor()
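With ``get_metadata_routing`` and ``_get_estimator`` in place, the routing of an unfitted bagging estimator can be inspected directly; a quick sketch:

```python
import sklearn
from sklearn.ensemble import BaggingRegressor

sklearn.set_config(enable_metadata_routing=True)

# estimator=None resolves to DecisionTreeRegressor via _get_estimator,
# so the router is well defined before fit is ever called.
router = BaggingRegressor().get_metadata_routing()
print(router)  # maps this estimator's fit to the inner estimator's fit
```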
12 changes: 9 additions & 3 deletions sklearn/ensemble/_iforest.py
@@ -232,9 +232,7 @@ def __init__(
warm_start=False,
):
super().__init__(
estimator=ExtraTreeRegressor(
max_features=1, splitter="random", random_state=random_state
),
estimator=None,
# here above max_features has no links with self.max_features
bootstrap=bootstrap,
bootstrap_features=False,
@@ -249,6 +247,14 @@

self.contamination = contamination

def _get_estimator(self):
return ExtraTreeRegressor(
# here max_features has no links with self.max_features
max_features=1,
splitter="random",
random_state=self.random_state,
)

def _set_oob_score(self, X, y):
raise NotImplementedError("OOB score not supported by iforest")

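One consequence of building the tree in ``_get_estimator`` rather than in ``__init__`` is that ``random_state`` is now read when the ensemble is built, so a value set after construction is honored; a sketch of the behavior inferred from this diff:

```python
import numpy as np
from sklearn.ensemble import IsolationForest
from sklearn.tree import ExtraTreeRegressor

X = np.random.RandomState(0).normal(size=(64, 2))

iso = IsolationForest(n_estimators=5)
iso.set_params(random_state=42)  # picked up lazily by _get_estimator
iso.fit(X)

# The fitted base estimators are the lazily built ExtraTreeRegressors.
assert isinstance(iso.estimators_[0], ExtraTreeRegressor)
```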