Fix cloning issues with early_stopping (#229)
* Fix cloning issues with early_stopping

* Fix test

* Fix test

* Reenable BOHB tests
Yard1 committed Dec 2, 2021
1 parent 1c69092 commit e8fef98
Showing 4 changed files with 126 additions and 73 deletions.
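Background for the diffs below: scikit-learn's clone() rebuilds an estimator as type(est)(**est.get_params()) and then checks that __init__ stored every parameter unmodified, so resolving early_stopping into a Ray Tune scheduler object at construction time breaks cloning. A minimal sketch of that contract, using a hypothetical GoodEstimator in place of TuneSearchCV:

from sklearn.base import BaseEstimator, clone

class GoodEstimator(BaseEstimator):
    def __init__(self, early_stopping=False):
        # Store the constructor argument untouched; clone() verifies this.
        self.early_stopping = early_stopping

    def fit(self, X, y=None):
        # Derived state belongs in a trailing-underscore attribute,
        # which clone() deliberately does not carry over.
        self.early_stopping_ = (
            "AsyncHyperBandScheduler"
            if self.early_stopping is True else self.early_stopping)
        return self

clone(GoodEstimator(early_stopping=True))  # params round-trip, no error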
tests/test_randomizedsearch.py (69 additions, 19 deletions)
@@ -28,9 +28,34 @@

 class RandomizedSearchTest(unittest.TestCase):
     def test_clone_estimator(self):
-        params = dict(C=expon(scale=10), gamma=expon(scale=0.1))
+        params = dict(lr=tune.loguniform(0.1, 1))
         random_search = TuneSearchCV(
-            SVC(),
+            SGDClassifier(),
             param_distributions=params,
             return_train_score=True,
             n_jobs=2)
         clone(random_search)
+
+        random_search = TuneSearchCV(
+            SGDClassifier(),
+            early_stopping=True,
+            param_distributions=params,
+            return_train_score=True,
+            n_jobs=2)
+        clone(random_search)
+
+        random_search = TuneSearchCV(
+            SGDClassifier(),
+            early_stopping="HyperBandScheduler",
+            param_distributions=params,
+            return_train_score=True,
+            n_jobs=2)
+        clone(random_search)
+
+        random_search = TuneSearchCV(
+            SGDClassifier(),
+            early_stopping=True,
+            search_optimization="bohb",
+            param_distributions=params,
+            return_train_score=True,
+            n_jobs=2)
@@ -406,12 +431,19 @@ def test_warn_reduce_maxiters(self):
             local_dir="./test-result")

     def test_warn_early_stop(self):
+        X, y = make_classification(
+            n_samples=50, n_features=5, n_informative=3, random_state=0)
+
         with self.assertWarnsRegex(UserWarning, "max_iters = 1"):
             TuneSearchCV(
-                LogisticRegression(), {"C": [1, 2]}, early_stopping=True)
+                LogisticRegression(), {
+                    "C": [1, 2]
+                }, early_stopping=True).fit(X, y)
         with self.assertWarnsRegex(UserWarning, "max_iters = 1"):
             TuneSearchCV(
-                SGDClassifier(), {"epsilon": [0.1, 0.2]}, early_stopping=True)
+                SGDClassifier(), {
+                    "epsilon": [0.1, 0.2]
+                }, early_stopping=True).fit(X, y)

     def test_warn_user_params(self):
         X, y = make_classification(
@@ -430,46 +462,67 @@ def test_warn_user_params(self):

     @unittest.skipIf(not has_xgboost(), "xgboost not installed")
     def test_early_stop_xgboost_warn(self):
+        X, y = make_classification(
+            n_samples=50, n_features=5, n_informative=3, random_state=0)
+
         from xgboost.sklearn import XGBClassifier
         with self.assertWarnsRegex(UserWarning, "github.com"):
             TuneSearchCV(
-                XGBClassifier(), {"C": [1, 2]},
+                XGBClassifier(), {
+                    "C": [1, 2]
+                },
                 early_stopping=True,
-                max_iters=10)
+                max_iters=10).fit(X, y)
         with self.assertWarnsRegex(UserWarning, "max_iters"):
             TuneSearchCV(
-                XGBClassifier(), {"C": [1, 2]},
+                XGBClassifier(), {
+                    "C": [1, 2]
+                },
                 early_stopping=True,
-                max_iters=1)
+                max_iters=1).fit(X, y)

     @unittest.skipIf(not has_required_lightgbm_version(),
                      "lightgbm not installed")
     def test_early_stop_lightgbm_warn(self):
+        X, y = make_classification(
+            n_samples=50, n_features=5, n_informative=3, random_state=0)
+
         from lightgbm import LGBMClassifier
         with self.assertWarnsRegex(UserWarning, "lightgbm"):
             TuneSearchCV(
-                LGBMClassifier(), {"learning_rate": [0.1, 0.5]},
+                LGBMClassifier(), {
+                    "learning_rate": [0.1, 0.5]
+                },
                 early_stopping=True,
-                max_iters=10)
+                max_iters=10).fit(X, y)
         with self.assertWarnsRegex(UserWarning, "max_iters"):
             TuneSearchCV(
-                LGBMClassifier(), {"learning_rate": [0.1, 0.5]},
+                LGBMClassifier(), {
+                    "learning_rate": [0.1, 0.5]
+                },
                 early_stopping=True,
-                max_iters=1)
+                max_iters=1).fit(X, y)

     @unittest.skipIf(not has_catboost(), "catboost not installed")
     def test_early_stop_catboost_warn(self):
+        X, y = make_classification(
+            n_samples=50, n_features=5, n_informative=3, random_state=0)
+
         from catboost import CatBoostClassifier
         with self.assertWarnsRegex(UserWarning, "Catboost"):
             TuneSearchCV(
-                CatBoostClassifier(), {"learning_rate": [0.1, 0.5]},
+                CatBoostClassifier(), {
+                    "learning_rate": [0.1, 0.5]
+                },
                 early_stopping=True,
-                max_iters=10)
+                max_iters=10).fit(X, y)
         with self.assertWarnsRegex(UserWarning, "max_iters"):
             TuneSearchCV(
-                CatBoostClassifier(), {"learning_rate": [0.1, 0.5]},
+                CatBoostClassifier(), {
+                    "learning_rate": [0.1, 0.5]
+                },
                 early_stopping=True,
-                max_iters=1)
+                max_iters=1).fit(X, y)

     def test_pipeline_early_stop(self):
         digits = datasets.load_digits()
@@ -654,7 +707,6 @@ def testBayesian(self):
     def testHyperopt(self):
         self._test_method("hyperopt")

-    @unittest.skip("bohb test currently failing")
     def testBohb(self):
         self._test_method("bohb")

@@ -806,7 +858,6 @@ def testHyperoptPointsToEvaluate(self):
             return
         self._test_points_to_evaluate("hyperopt")

-    @unittest.skip("bohb currently failing not installed")
     def testBOHBPointsToEvaluate(self):
         self._test_points_to_evaluate("bohb")

@@ -867,7 +918,6 @@ def test_seed_bayesian(self):
self._test_seed_run("bayesian", seed=1234)
self._test_seed_run("bayesian", seed="1234")

@unittest.skip("BOHB is currently failing")
def test_seed_bohb(self):
self._test_seed_run("bohb", seed=1234)
self._test_seed_run("bohb", seed="1234")
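The new clone() calls above pin down the regression this commit fixes: constructing a TuneSearchCV with early stopping enabled and cloning it, without ever fitting. A sketch of the previously failing call (the alpha search space is assumed for illustration):

from ray import tune
from sklearn.base import clone
from sklearn.linear_model import SGDClassifier
from tune_sklearn import TuneSearchCV

search = TuneSearchCV(
    SGDClassifier(),
    param_distributions={"alpha": tune.loguniform(1e-4, 1e-1)},
    early_stopping=True,
    n_jobs=2)
clone(search)  # raised RuntimeError before this commit; passes after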
tune_sklearn/tune_basesearch.py (50 additions, 47 deletions)
@@ -386,52 +386,6 @@ def __init__(self,

         self._metric_name = "average_test_%s" % self._base_metric_name

-        if early_stopping:
-            if not self._can_early_stop() and is_lightgbm_model(
-                    self.base_estimator):
-                warnings.warn("lightgbm>=3.0.0 required for early_stopping "
-                              "functionality.")
-            assert self._can_early_stop()
-            if max_iters == 1:
-                warnings.warn(
-                    "early_stopping is enabled but max_iters = 1. "
-                    "To enable partial training, set max_iters > 1.",
-                    category=UserWarning)
-            if self.early_stop_type == EarlyStopping.XGB:
-                warnings.warn(
-                    "tune-sklearn implements incremental learning "
-                    "for xgboost models following this: "
-                    "https://github.com/dmlc/xgboost/issues/1686. "
-                    "This may negatively impact performance. To "
-                    "disable, set `early_stopping=False`.",
-                    category=UserWarning)
-            elif self.early_stop_type == EarlyStopping.LGBM:
-                warnings.warn(
-                    "tune-sklearn implements incremental learning "
-                    "for lightgbm models following this: "
-                    "https://lightgbm.readthedocs.io/en/latest/pythonapi/"
-                    "lightgbm.LGBMModel.html#lightgbm.LGBMModel.fit "
-                    "This may negatively impact performance. To "
-                    "disable, set `early_stopping=False`.",
-                    category=UserWarning)
-            elif self.early_stop_type == EarlyStopping.CATBOOST:
-                warnings.warn(
-                    "tune-sklearn implements incremental learning "
-                    "for Catboost models following this: "
-                    "https://catboost.ai/docs/concepts/python-usages-"
-                    "examples.html#training-continuation "
-                    "This may negatively impact performance. To "
-                    "disable, set `early_stopping=False`.",
-                    category=UserWarning)
-            if early_stopping is True:
-                # Override the early_stopping variable so
-                # that it is resolved appropriately in
-                # the next block
-                early_stopping = "AsyncHyperBandScheduler"
-            # Resolve the early stopping object
-            early_stopping = resolve_early_stopping(early_stopping, max_iters,
-                                                    self._metric_name)
-
         self.early_stopping = early_stopping
         self.max_iters = max_iters

@@ -474,6 +428,55 @@ def _fit(self, X, y=None, groups=None, tune_params=None, **fit_params):
"""
self._check_params()
classifier = is_classifier(self.estimator)

early_stopping = self.early_stopping
if early_stopping:
if not self._can_early_stop() and is_lightgbm_model(
self.base_estimator):
warnings.warn("lightgbm>=3.0.0 required for early_stopping "
"functionality.")
assert self._can_early_stop()
if self.max_iters == 1:
warnings.warn(
"early_stopping is enabled but max_iters = 1. "
"To enable partial training, set max_iters > 1.",
category=UserWarning)
if self.early_stop_type == EarlyStopping.XGB:
warnings.warn(
"tune-sklearn implements incremental learning "
"for xgboost models following this: "
"https://github.com/dmlc/xgboost/issues/1686. "
"This may negatively impact performance. To "
"disable, set `early_stopping=False`.",
category=UserWarning)
elif self.early_stop_type == EarlyStopping.LGBM:
warnings.warn(
"tune-sklearn implements incremental learning "
"for lightgbm models following this: "
"https://lightgbm.readthedocs.io/en/latest/pythonapi/"
"lightgbm.LGBMModel.html#lightgbm.LGBMModel.fit "
"This may negatively impact performance. To "
"disable, set `early_stopping=False`.",
category=UserWarning)
elif self.early_stop_type == EarlyStopping.CATBOOST:
warnings.warn(
"tune-sklearn implements incremental learning "
"for Catboost models following this: "
"https://catboost.ai/docs/concepts/python-usages-"
"examples.html#training-continuation "
"This may negatively impact performance. To "
"disable, set `early_stopping=False`.",
category=UserWarning)
if self.early_stopping is True:
# Override the early_stopping variable so
# that it is resolved appropriately in
# the next block
early_stopping = "AsyncHyperBandScheduler"
# Resolve the early stopping object
early_stopping = resolve_early_stopping(
early_stopping, self.max_iters, self._metric_name)
self.early_stopping_ = early_stopping

cv = check_cv(cv=self.cv, y=y, classifier=classifier)
self.n_splits = cv.get_n_splits(X, y, groups)
if not hasattr(self, "_is_multi"):
@@ -519,7 +522,7 @@ def _fit(self, X, y=None, groups=None, tune_params=None, **fit_params):
             y_id = ray.put(y)

         config = {}
-        config["early_stopping"] = bool(self.early_stopping)
+        config["early_stopping"] = bool(self.early_stopping_)
         config["early_stop_type"] = self.early_stop_type
         config["X_id"] = X_id
         config["y_id"] = y_id
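With the block moved into _fit(), the public early_stopping parameter now survives get_params()/clone() untouched, and the resolved scheduler lives only on the fitted attribute early_stopping_. A hedged sketch of what the resolution step amounts to (the real logic is resolve_early_stopping, whose body is assumed here, not copied from the repo):

from ray.tune.schedulers import ASHAScheduler, TrialScheduler

def resolve_early_stopping_sketch(early_stopping, max_iters, metric_name):
    # By this point the caller has already mapped early_stopping=True
    # to the default scheduler name, "AsyncHyperBandScheduler".
    if isinstance(early_stopping, str):
        if early_stopping == "AsyncHyperBandScheduler":
            # ASHAScheduler is Ray Tune's asynchronous HyperBand.
            return ASHAScheduler(
                metric=metric_name, mode="max", max_t=max_iters)
        # Other scheduler names are omitted from this sketch.
        raise ValueError(f"Unknown scheduler name: {early_stopping!r}")
    if isinstance(early_stopping, TrialScheduler):
        # A ready-made scheduler instance is used as-is.
        return early_stopping
    raise TypeError(f"Cannot resolve early_stopping={early_stopping!r}")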
tune_sklearn/tune_gridsearch.py (3 additions, 3 deletions)
@@ -246,10 +246,10 @@ def _tune_run(self, config, resources_per_trial, tune_params=None):
"""
trainable = _Trainable
if self.pipeline_auto_early_stop and check_is_pipeline(
self.estimator) and self.early_stopping:
self.estimator) and self.early_stopping_:
trainable = _PipelineTrainable

if self.early_stopping is not None:
if self.early_stopping_ is not None:
config["estimator_ids"] = [
ray.put(self.estimator) for _ in range(self.n_splits)
]
@@ -261,7 +261,7 @@ def _tune_run(self, config, resources_per_trial, tune_params=None):
             stopper = CombinedStopper(stopper, self.stopper)

         run_args = dict(
-            scheduler=self.early_stopping,
+            scheduler=self.early_stopping_,
             reuse_actors=True,
             verbose=self.verbose,
             stop=stopper,
tune_sklearn/tune_search.py (4 additions, 4 deletions)
@@ -656,15 +656,15 @@ def _tune_run(self, config, resources_per_trial, tune_params=None):

         trainable = _Trainable
         if self.pipeline_auto_early_stop and check_is_pipeline(
-                self.estimator) and self.early_stopping:
+                self.estimator) and self.early_stopping_:
             trainable = _PipelineTrainable

         max_iter = self.max_iters
-        if self.early_stopping is not None:
+        if self.early_stopping_ is not None:
             config["estimator_ids"] = [
                 ray.put(self.estimator) for _ in range(self.n_splits)
             ]
-            if hasattr(self.early_stopping, "_max_t_attr"):
+            if hasattr(self.early_stopping_, "_max_t_attr"):
                 # we want to delegate stopping to schedulers which
                 # support it, but we want it to stop eventually, just in case
                 # the solution is to make the stop condition very big
@@ -676,7 +676,7 @@
         if self.stopper:
             stopper = CombinedStopper(stopper, self.stopper)
         run_args = dict(
-            scheduler=self.early_stopping,
+            scheduler=self.early_stopping_,
             reuse_actors=True,
             verbose=self.verbose,
             stop=stopper,
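Together with the tune_basesearch.py change, these renames keep every consumer of the resolved scheduler reading the fitted attribute. An end-to-end sketch of the fixed behavior (dataset and search space chosen for illustration):

from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from tune_sklearn import TuneSearchCV

X, y = make_classification(
    n_samples=50, n_features=5, n_informative=3, random_state=0)
search = TuneSearchCV(
    SGDClassifier(), {"alpha": [1e-4, 1e-3, 1e-2]},
    early_stopping=True, max_iters=5)
search.fit(X, y)        # resolves the scheduler into search.early_stopping_
fresh = clone(search)   # unfitted copy; constructor params intact
assert fresh.early_stopping is True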
