Skip to content

Commit

Permalink
MAINT remove deprecated base_estimator and base_estimator_ (#27826)
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Nov 24, 2023
1 parent 3287570 commit c08afde
Show file tree
Hide file tree
Showing 14 changed files with 11 additions and 481 deletions.
2 changes: 1 addition & 1 deletion examples/ensemble/plot_adaboost_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
# ---------------------------------
# We train the :class:`~sklearn.ensemble.AdaBoostClassifier`. The estimator
# utilizes boosting to improve the classification accuracy. Boosting is a method
# designed to train weak learners (i.e. `base_estimator`) that learn from their
# designed to train weak learners (i.e. `estimator`) that learn from their
# predecessor's mistakes.
#
# Here, we define the weak learner as a
Expand Down
21 changes: 0 additions & 21 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,27 +239,6 @@ def set_params(self, **params):
valid_params[key] = value

for key, sub_params in nested_params.items():
# TODO(1.4): remove specific handling of "base_estimator".
# The "base_estimator" key is special. It was deprecated and
# renamed to "estimator" for several estimators. This means we
# need to translate it here and set sub-parameters on "estimator",
# but only if the user did not explicitly set a value for
# "base_estimator".
if (
key == "base_estimator"
and valid_params[key] == "deprecated"
and self.__module__.startswith("sklearn.")
):
warnings.warn(
(
f"Parameter 'base_estimator' of {self.__class__.__name__} is"
" deprecated in favor of 'estimator'. See"
f" {self.__class__.__name__}'s docstring for more details."
),
FutureWarning,
stacklevel=2,
)
key = "estimator"
valid_params[key].set_params(**sub_params)

return self
Expand Down
38 changes: 3 additions & 35 deletions sklearn/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
)
from .utils._param_validation import (
HasMethods,
Hidden,
Interval,
StrOptions,
validate_params,
Expand Down Expand Up @@ -161,13 +160,6 @@ class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator)
.. versionadded:: 0.24
base_estimator : estimator instance
This parameter is deprecated. Use `estimator` instead.
.. deprecated:: 1.2
The parameter `base_estimator` is deprecated in 1.2 and will be
removed in 1.4. Use `estimator` instead.
Attributes
----------
classes_ : ndarray of shape (n_classes,)
Expand Down Expand Up @@ -266,12 +258,6 @@ class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator)
"cv": ["cv_object", StrOptions({"prefit"})],
"n_jobs": [Integral, None],
"ensemble": ["boolean"],
"base_estimator": [
HasMethods(["fit", "predict_proba"]),
HasMethods(["fit", "decision_function"]),
None,
Hidden(StrOptions({"deprecated"})),
],
}

def __init__(
Expand All @@ -282,41 +268,23 @@ def __init__(
cv=None,
n_jobs=None,
ensemble=True,
base_estimator="deprecated",
):
self.estimator = estimator
self.method = method
self.cv = cv
self.n_jobs = n_jobs
self.ensemble = ensemble
self.base_estimator = base_estimator

def _get_estimator(self):
"""Resolve which estimator to return (default is LinearSVC)"""
# TODO(1.4): Remove when base_estimator is removed
if self.base_estimator != "deprecated":
if self.estimator is not None:
raise ValueError(
"Both `base_estimator` and `estimator` are set. Only set "
"`estimator` since `base_estimator` is deprecated."
)
warnings.warn(
(
"`base_estimator` was renamed to `estimator` in version 1.2 and "
"will be removed in 1.4."
),
FutureWarning,
)
estimator = self.base_estimator
else:
estimator = self.estimator

if estimator is None:
if self.estimator is None:
# we want all classifiers that don't expose a random_state
# to be deterministic (and we don't want to expose this one).
estimator = LinearSVC(random_state=0, dual="auto")
if _routing_enabled():
estimator.set_fit_request(sample_weight=True)
else:
estimator = self.estimator

return estimator

Expand Down
45 changes: 2 additions & 43 deletions sklearn/ensemble/_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ..metrics import accuracy_score, r2_score
from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
from ..utils import check_random_state, column_or_1d, indices_to_mask
from ..utils._param_validation import HasMethods, Interval, RealNotInt, StrOptions
from ..utils._param_validation import HasMethods, Interval, RealNotInt
from ..utils._tags import _safe_tags
from ..utils.metadata_routing import (
_raise_for_unsupported_routing,
Expand Down Expand Up @@ -231,10 +231,8 @@ def _estimator_has(attr):
def check(self):
if hasattr(self, "estimators_"):
return hasattr(self.estimators_[0], attr)
elif self.estimator is not None:
else: # self.estimator is not None
return hasattr(self.estimator, attr)
else: # TODO(1.4): Remove when the base_estimator deprecation cycle ends
return hasattr(self.base_estimator, attr)

return check

Expand Down Expand Up @@ -264,11 +262,6 @@ class BaseBagging(BaseEnsemble, metaclass=ABCMeta):
"n_jobs": [None, Integral],
"random_state": ["random_state"],
"verbose": ["verbose"],
"base_estimator": [
HasMethods(["fit", "predict"]),
StrOptions({"deprecated"}),
None,
],
}

@abstractmethod
Expand All @@ -286,12 +279,10 @@ def __init__(
n_jobs=None,
random_state=None,
verbose=0,
base_estimator="deprecated",
):
super().__init__(
estimator=estimator,
n_estimators=n_estimators,
base_estimator=base_estimator,
)
self.max_samples = max_samples
self.max_features = max_features
Expand Down Expand Up @@ -635,13 +626,6 @@ class BaggingClassifier(_RoutingNotSupportedMixin, ClassifierMixin, BaseBagging)
verbose : int, default=0
Controls the verbosity when fitting and predicting.
base_estimator : object, default="deprecated"
Use `estimator` instead.
.. deprecated:: 1.2
`base_estimator` is deprecated and will be removed in 1.4.
Use `estimator` instead.
Attributes
----------
estimator_ : estimator
Expand All @@ -650,13 +634,6 @@ class BaggingClassifier(_RoutingNotSupportedMixin, ClassifierMixin, BaseBagging)
.. versionadded:: 1.2
`base_estimator_` was renamed to `estimator_`.
base_estimator_ : estimator
The base estimator from which the ensemble is grown.
.. deprecated:: 1.2
`base_estimator_` is deprecated and will be removed in 1.4.
Use `estimator_` instead.
n_features_in_ : int
Number of features seen during :term:`fit`.
Expand Down Expand Up @@ -743,7 +720,6 @@ def __init__(
n_jobs=None,
random_state=None,
verbose=0,
base_estimator="deprecated",
):
super().__init__(
estimator=estimator,
Expand All @@ -757,7 +733,6 @@ def __init__(
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
base_estimator=base_estimator,
)

def _validate_estimator(self):
Expand Down Expand Up @@ -1080,13 +1055,6 @@ class BaggingRegressor(_RoutingNotSupportedMixin, RegressorMixin, BaseBagging):
verbose : int, default=0
Controls the verbosity when fitting and predicting.
base_estimator : object, default="deprecated"
Use `estimator` instead.
.. deprecated:: 1.2
`base_estimator` is deprecated and will be removed in 1.4.
Use `estimator` instead.
Attributes
----------
estimator_ : estimator
Expand All @@ -1095,13 +1063,6 @@ class BaggingRegressor(_RoutingNotSupportedMixin, RegressorMixin, BaseBagging):
.. versionadded:: 1.2
`base_estimator_` was renamed to `estimator_`.
base_estimator_ : estimator
The base estimator from which the ensemble is grown.
.. deprecated:: 1.2
`base_estimator_` is deprecated and will be removed in 1.4.
Use `estimator_` instead.
n_features_in_ : int
Number of features seen during :term:`fit`.
Expand Down Expand Up @@ -1182,7 +1143,6 @@ def __init__(
n_jobs=None,
random_state=None,
verbose=0,
base_estimator="deprecated",
):
super().__init__(
estimator=estimator,
Expand All @@ -1196,7 +1156,6 @@ def __init__(
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
base_estimator=base_estimator,
)

def predict(self, X):
Expand Down
51 changes: 2 additions & 49 deletions sklearn/ensemble/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
# Authors: Gilles Louppe
# License: BSD 3 clause

import warnings
from abc import ABCMeta, abstractmethod
from typing import List

import numpy as np
from joblib import effective_n_jobs

from ..base import BaseEstimator, MetaEstimatorMixin, clone, is_classifier, is_regressor
from ..utils import Bunch, _print_elapsed_time, check_random_state, deprecated
from ..utils import Bunch, _print_elapsed_time, check_random_state
from ..utils.metaestimators import _BaseComposition


Expand Down Expand Up @@ -93,25 +92,11 @@ class BaseEnsemble(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
The list of attributes to use as parameters when instantiating a
new base estimator. If none are given, default parameters are used.
base_estimator : object, default="deprecated"
Use `estimator` instead.
.. deprecated:: 1.2
`base_estimator` is deprecated and will be removed in 1.4.
Use `estimator` instead.
Attributes
----------
estimator_ : estimator
The base estimator from which the ensemble is grown.
base_estimator_ : estimator
The base estimator from which the ensemble is grown.
.. deprecated:: 1.2
`base_estimator_` is deprecated and will be removed in 1.4.
Use `estimator_` instead.
estimators_ : list of estimators
The collection of fitted base estimators.
"""
Expand All @@ -126,15 +111,13 @@ def __init__(
*,
n_estimators=10,
estimator_params=tuple(),
base_estimator="deprecated",
):
# Set parameters
self.estimator = estimator
self.n_estimators = n_estimators
self.estimator_params = estimator_params
self.base_estimator = base_estimator

# Don't instantiate estimators now! Parameters of base_estimator might
# Don't instantiate estimators now! Parameters of estimator might
# still change. Eg., when grid-searching with the nested object syntax.
# self.estimators_ needs to be filled by the derived classes in fit.

Expand All @@ -143,41 +126,11 @@ def _validate_estimator(self, default=None):
Sets the `estimator_` attributes.
"""
if self.estimator is not None and (
self.base_estimator not in [None, "deprecated"]
):
raise ValueError(
"Both `estimator` and `base_estimator` were set. Only set `estimator`."
)

if self.estimator is not None:
self.estimator_ = self.estimator
elif self.base_estimator != "deprecated":
warnings.warn(
(
"`base_estimator` was renamed to `estimator` in version 1.2 and "
"will be removed in 1.4."
),
FutureWarning,
)
if self.base_estimator is not None:
self.estimator_ = self.base_estimator
else:
self.estimator_ = default
else:
self.estimator_ = default

# TODO(1.4): remove
# mypy error: Decorated property not supported
@deprecated( # type: ignore
"Attribute `base_estimator_` was deprecated in version 1.2 and will be removed "
"in 1.4. Use `estimator_` instead."
)
@property
def base_estimator_(self):
"""Estimator used to grow the ensemble."""
return self.estimator_

def _make_estimator(self, append=True, random_state=None):
"""Make and configure a copy of the `estimator_` attribute.
Expand Down
Loading

0 comments on commit c08afde

Please sign in to comment.