Conversation
Still need to add some tests and documentation.
def check_warm_start_ensemble(estimator):
    from sklearn.ensemble import BaseEnsemble
    is_ensemble_subclass = issubclass(type(estimator), BaseEnsemble)
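For context, a completed version of this helper might look roughly like the sketch below; everything past the issubclass check (the warm_start and n_estimators parameter tests and the return value) is an assumption, since the diff is truncated here.

def check_warm_start_ensemble(estimator):
    # Check that the estimator is an ensemble that can be grown incrementally.
    from sklearn.ensemble import BaseEnsemble
    is_ensemble_subclass = issubclass(type(estimator), BaseEnsemble)
    params = estimator.get_params()
    return (is_ensemble_subclass
            and "warm_start" in params
            and "n_estimators" in params)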
Should you update the above? check_warm_start can probably return True if it is an ensemble now.
I chose to separate it because warm start on non-ensemble estimators means we set max_iter=1, but warm start on ensemble estimators means we set n_estimators=1. I thought it'd be easier to handle these different cases with two separate checks.
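To illustrate the distinction with standard scikit-learn estimators (this is illustrative only, not code from this PR):

from sklearn.linear_model import SGDClassifier
from sklearn.ensemble import RandomForestClassifier

# Non-ensemble warm start: each fit() call advances optimization by max_iter iterations.
sgd = SGDClassifier(warm_start=True, max_iter=1)

# Ensemble warm start: each fit() call trains only the newly added estimators.
forest = RandomForestClassifier(warm_start=True, n_estimators=1)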
Can you then rename the above to can_warm_start_iter?
tune_sklearn/_trainable.py (Outdated)
updated_n_estimators = self.estimator_list[i].get_params(
)["n_estimators"] + 1
self.estimator_list[i].set_params(
    **{"n_estimators": updated_n_estimators})
self.estimator_list[i].fit(X_train, y_train)
nit: pull estimator = self.estimator_list[i] out into a local variable and reuse it
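A rough sketch of that refactor, using the names from the diff above (this is a fragment of the training method, not standalone code):

estimator = self.estimator_list[i]
updated_n_estimators = estimator.get_params()["n_estimators"] + 1
estimator.set_params(n_estimators=updated_n_estimators)
estimator.fit(X_train, y_train)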
@@ -92,6 +97,10 @@ def _setup(self, config):
    self.estimator_config["warm_start"] = True
    self.estimator_config["max_iter"] = 1

    if not self._can_partial_fit() and self._can_warm_start_ensemble():
        self.estimator_config["warm_start"] = True
        self.estimator_config["n_estimators"] = 0
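Presumably n_estimators starts at 0 here so that the training step, which increments it by one before each fit (as in the earlier _trainable.py snippet), trains exactly one estimator on the first call; that ordering is an inference from the other diff, not something this hunk shows.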
Add a comment about implementation?
@@ -405,6 +405,10 @@ def _fit(self, X, y=None, groups=None, **fit_params):
    mode="max",
    scope="last")
self.best_params = self._clean_config_dict(best_config)
if not check_partial_fit( |
BTW, I think we should actually log something here. I think this behavior is somewhat of a "surprise", so being upfront that this methodology is being leveraged would be nice to mention.
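For illustration, the kind of log message being asked for might look like the sketch below; the wording and placement are assumptions, not tune-sklearn's actual output.

import logging

logger = logging.getLogger(__name__)

# Hypothetical warning; the message text is an assumption, not the library's real log.
logger.warning(
    "Early stopping an ensemble estimator: n_estimators will be grown "
    "incrementally via warm_start and therefore cannot be tuned directly.")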
Wouldn't it be a good idea to explicitly raise an exception if the user tries to tune n_estimators in that case?
@Yard1 Yes, that's a good idea. Let me make a quick issue.
Add functionality to warm start ensemble estimators in sklearn. This relies on incrementing n_estimators one at a time to get the effect of fitting one estimator at a time in the ensemble. This unfortunately means the user can't cross-validate the n_estimators parameter if they choose to early stop.
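For readers unfamiliar with the mechanism, here is a minimal sketch of the same warm-start pattern using plain scikit-learn; it is illustrative only and not tune-sklearn's internal code.

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=200, random_state=0)

# Start with an "empty" ensemble and grow it one tree per training step.
clf = RandomForestClassifier(warm_start=True, n_estimators=0, random_state=0)
for _ in range(5):
    clf.set_params(n_estimators=clf.get_params()["n_estimators"] + 1)
    clf.fit(X, y)  # with warm_start=True, only the newly added tree is trained

print(len(clf.estimators_))  # 5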