This repository has been archived by the owner on Nov 14, 2023. It is now read-only.

Fix cloning issues with early_stopping #229

Merged · 4 commits · Dec 2, 2021
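Summary, as read from the diff: TuneSearchCV previously resolved the `early_stopping` argument into a Ray Tune scheduler object, and emitted the associated warnings, inside `__init__`. Since `sklearn.base.clone` rebuilds an estimator from `get_params()` and then verifies that each constructor argument was stored unmodified, constructing a new scheduler there broke cloning. This PR defers resolution and the warnings to `_fit`, stores the result in a fit-time attribute `early_stopping_`, points `_tune_run` in both search classes at that attribute, and re-enables the previously skipped BOHB tests.

The sketch below illustrates the contract being restored; `Scheduler`, `BrokenSearch`, and `FixedSearch` are hypothetical stand-ins, not tune-sklearn code:

```python
from sklearn.base import BaseEstimator, clone


class Scheduler:
    """Hypothetical stand-in for a Ray Tune scheduler object."""


class BrokenSearch(BaseEstimator):
    def __init__(self, early_stopping=False):
        # Bug pattern: replacing a constructor argument with a freshly
        # built object. clone() re-passes get_params() values to
        # __init__ and then checks identity, so the brand-new
        # Scheduler() trips "Cannot clone object ...".
        if early_stopping:
            early_stopping = Scheduler()
        self.early_stopping = early_stopping


class FixedSearch(BaseEstimator):
    def __init__(self, early_stopping=False):
        # Fix pattern: store the argument verbatim ...
        self.early_stopping = early_stopping

    def fit(self, X, y=None):
        # ... and resolve it at fit time into a trailing-underscore
        # attribute, which get_params() and clone() ignore.
        es = self.early_stopping
        self.early_stopping_ = Scheduler() if es is True else es
        return self


clone(FixedSearch(early_stopping=True))   # fine
clone(BrokenSearch(early_stopping=True))  # raises RuntimeError
```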
tests/test_randomizedsearch.py (69 additions, 19 deletions)

@@ -28,9 +28,34 @@
 class RandomizedSearchTest(unittest.TestCase):
     def test_clone_estimator(self):
-        params = dict(C=expon(scale=10), gamma=expon(scale=0.1))
+        params = dict(lr=tune.loguniform(0.1, 1))
         random_search = TuneSearchCV(
-            SVC(),
+            SGDClassifier(),
             param_distributions=params,
             return_train_score=True,
             n_jobs=2)
         clone(random_search)
+
+        random_search = TuneSearchCV(
+            SGDClassifier(),
+            early_stopping=True,
+            param_distributions=params,
+            return_train_score=True,
+            n_jobs=2)
+        clone(random_search)
+
+        random_search = TuneSearchCV(
+            SGDClassifier(),
+            early_stopping="HyperBandScheduler",
+            param_distributions=params,
+            return_train_score=True,
+            n_jobs=2)
+        clone(random_search)
+
+        random_search = TuneSearchCV(
+            SGDClassifier(),
+            early_stopping=True,
+            search_optimization="bohb",
+            param_distributions=params,
+            return_train_score=True,
+            n_jobs=2)
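The new cases above only assert that `clone` does not raise. A slightly stronger manual check, a sketch that is not part of the PR and assumes the standard `get_params` round-trip contract, is that the flag survives cloning by identity:

```python
from ray import tune
from sklearn.base import clone
from sklearn.linear_model import SGDClassifier
from tune_sklearn import TuneSearchCV

search = TuneSearchCV(
    SGDClassifier(),
    param_distributions={"lr": tune.loguniform(0.1, 1)},
    early_stopping=True,
    n_jobs=2)

# With resolution deferred to fit time, the constructor argument
# comes back untouched instead of as a scheduler object.
assert clone(search).get_params()["early_stopping"] is True
```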
@@ -406,12 +431,19 @@ def test_warn_reduce_maxiters(self):
             local_dir="./test-result")
 
     def test_warn_early_stop(self):
+        X, y = make_classification(
+            n_samples=50, n_features=5, n_informative=3, random_state=0)
+
         with self.assertWarnsRegex(UserWarning, "max_iters = 1"):
             TuneSearchCV(
-                LogisticRegression(), {"C": [1, 2]}, early_stopping=True)
+                LogisticRegression(), {
+                    "C": [1, 2]
+                }, early_stopping=True).fit(X, y)
         with self.assertWarnsRegex(UserWarning, "max_iters = 1"):
             TuneSearchCV(
-                SGDClassifier(), {"epsilon": [0.1, 0.2]}, early_stopping=True)
+                SGDClassifier(), {
+                    "epsilon": [0.1, 0.2]
+                }, early_stopping=True).fit(X, y)
 
     def test_warn_user_params(self):
         X, y = make_classification(
@@ -430,46 +462,67 @@ def test_warn_user_params(self):
 
     @unittest.skipIf(not has_xgboost(), "xgboost not installed")
     def test_early_stop_xgboost_warn(self):
+        X, y = make_classification(
+            n_samples=50, n_features=5, n_informative=3, random_state=0)
+
         from xgboost.sklearn import XGBClassifier
         with self.assertWarnsRegex(UserWarning, "github.com"):
             TuneSearchCV(
-                XGBClassifier(), {"C": [1, 2]},
+                XGBClassifier(), {
+                    "C": [1, 2]
+                },
                 early_stopping=True,
-                max_iters=10)
+                max_iters=10).fit(X, y)
         with self.assertWarnsRegex(UserWarning, "max_iters"):
             TuneSearchCV(
-                XGBClassifier(), {"C": [1, 2]},
+                XGBClassifier(), {
+                    "C": [1, 2]
+                },
                 early_stopping=True,
-                max_iters=1)
+                max_iters=1).fit(X, y)
 
     @unittest.skipIf(not has_required_lightgbm_version(),
                      "lightgbm not installed")
     def test_early_stop_lightgbm_warn(self):
+        X, y = make_classification(
+            n_samples=50, n_features=5, n_informative=3, random_state=0)
+
         from lightgbm import LGBMClassifier
         with self.assertWarnsRegex(UserWarning, "lightgbm"):
             TuneSearchCV(
-                LGBMClassifier(), {"learning_rate": [0.1, 0.5]},
+                LGBMClassifier(), {
+                    "learning_rate": [0.1, 0.5]
+                },
                 early_stopping=True,
-                max_iters=10)
+                max_iters=10).fit(X, y)
         with self.assertWarnsRegex(UserWarning, "max_iters"):
             TuneSearchCV(
-                LGBMClassifier(), {"learning_rate": [0.1, 0.5]},
+                LGBMClassifier(), {
+                    "learning_rate": [0.1, 0.5]
+                },
                 early_stopping=True,
-                max_iters=1)
+                max_iters=1).fit(X, y)
 
     @unittest.skipIf(not has_catboost(), "catboost not installed")
     def test_early_stop_catboost_warn(self):
+        X, y = make_classification(
+            n_samples=50, n_features=5, n_informative=3, random_state=0)
+
         from catboost import CatBoostClassifier
         with self.assertWarnsRegex(UserWarning, "Catboost"):
             TuneSearchCV(
-                CatBoostClassifier(), {"learning_rate": [0.1, 0.5]},
+                CatBoostClassifier(), {
+                    "learning_rate": [0.1, 0.5]
+                },
                 early_stopping=True,
-                max_iters=10)
+                max_iters=10).fit(X, y)
         with self.assertWarnsRegex(UserWarning, "max_iters"):
             TuneSearchCV(
-                CatBoostClassifier(), {"learning_rate": [0.1, 0.5]},
+                CatBoostClassifier(), {
+                    "learning_rate": [0.1, 0.5]
+                },
                 early_stopping=True,
-                max_iters=1)
+                max_iters=1).fit(X, y)
 
     def test_pipeline_early_stop(self):
         digits = datasets.load_digits()
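Note the shape of these test updates: each `assertWarnsRegex` block now ends in `.fit(X, y)`, because after this change the `max_iters` and incremental-learning warnings are only emitted once fitting begins, not at construction time. The small `make_classification` fixture exists solely to give `fit` something to run on.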
@@ -654,7 +707,6 @@ def testBayesian(self):
     def testHyperopt(self):
         self._test_method("hyperopt")
 
-    @unittest.skip("bohb test currently failing")
     def testBohb(self):
         self._test_method("bohb")
 
@@ -806,7 +858,6 @@ def testHyperoptPointsToEvaluate(self):
             return
         self._test_points_to_evaluate("hyperopt")
 
-    @unittest.skip("bohb currently failing not installed")
     def testBOHBPointsToEvaluate(self):
         self._test_points_to_evaluate("bohb")
 
@@ -867,7 +918,6 @@ def test_seed_bayesian(self):
         self._test_seed_run("bayesian", seed=1234)
         self._test_seed_run("bayesian", seed="1234")
 
-    @unittest.skip("BOHB is currently failing")
     def test_seed_bohb(self):
         self._test_seed_run("bohb", seed=1234)
         self._test_seed_run("bohb", seed="1234")
tune_sklearn/tune_basesearch.py (50 additions, 47 deletions)

@@ -386,52 +386,6 @@ def __init__(self,

         self._metric_name = "average_test_%s" % self._base_metric_name
 
-        if early_stopping:
-            if not self._can_early_stop() and is_lightgbm_model(
-                    self.base_estimator):
-                warnings.warn("lightgbm>=3.0.0 required for early_stopping "
-                              "functionality.")
-            assert self._can_early_stop()
-            if max_iters == 1:
-                warnings.warn(
-                    "early_stopping is enabled but max_iters = 1. "
-                    "To enable partial training, set max_iters > 1.",
-                    category=UserWarning)
-            if self.early_stop_type == EarlyStopping.XGB:
-                warnings.warn(
-                    "tune-sklearn implements incremental learning "
-                    "for xgboost models following this: "
-                    "https://github.com/dmlc/xgboost/issues/1686. "
-                    "This may negatively impact performance. To "
-                    "disable, set `early_stopping=False`.",
-                    category=UserWarning)
-            elif self.early_stop_type == EarlyStopping.LGBM:
-                warnings.warn(
-                    "tune-sklearn implements incremental learning "
-                    "for lightgbm models following this: "
-                    "https://lightgbm.readthedocs.io/en/latest/pythonapi/"
-                    "lightgbm.LGBMModel.html#lightgbm.LGBMModel.fit "
-                    "This may negatively impact performance. To "
-                    "disable, set `early_stopping=False`.",
-                    category=UserWarning)
-            elif self.early_stop_type == EarlyStopping.CATBOOST:
-                warnings.warn(
-                    "tune-sklearn implements incremental learning "
-                    "for Catboost models following this: "
-                    "https://catboost.ai/docs/concepts/python-usages-"
-                    "examples.html#training-continuation "
-                    "This may negatively impact performance. To "
-                    "disable, set `early_stopping=False`.",
-                    category=UserWarning)
-            if early_stopping is True:
-                # Override the early_stopping variable so
-                # that it is resolved appropriately in
-                # the next block
-                early_stopping = "AsyncHyperBandScheduler"
-            # Resolve the early stopping object
-            early_stopping = resolve_early_stopping(early_stopping, max_iters,
-                                                    self._metric_name)
-
         self.early_stopping = early_stopping
         self.max_iters = max_iters
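The block removed from `__init__` here is not lost: it reappears almost verbatim inside `_fit` in the next hunk, with `max_iters` read from `self.max_iters` and the resolved scheduler stored in `self.early_stopping_` instead of overwriting `self.early_stopping`.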

@@ -474,6 +428,55 @@ def _fit(self, X, y=None, groups=None, tune_params=None, **fit_params):
         """
         self._check_params()
         classifier = is_classifier(self.estimator)
+
+        early_stopping = self.early_stopping
+        if early_stopping:
+            if not self._can_early_stop() and is_lightgbm_model(
+                    self.base_estimator):
+                warnings.warn("lightgbm>=3.0.0 required for early_stopping "
+                              "functionality.")
+            assert self._can_early_stop()
+            if self.max_iters == 1:
+                warnings.warn(
+                    "early_stopping is enabled but max_iters = 1. "
+                    "To enable partial training, set max_iters > 1.",
+                    category=UserWarning)
+            if self.early_stop_type == EarlyStopping.XGB:
+                warnings.warn(
+                    "tune-sklearn implements incremental learning "
+                    "for xgboost models following this: "
+                    "https://github.com/dmlc/xgboost/issues/1686. "
+                    "This may negatively impact performance. To "
+                    "disable, set `early_stopping=False`.",
+                    category=UserWarning)
+            elif self.early_stop_type == EarlyStopping.LGBM:
+                warnings.warn(
+                    "tune-sklearn implements incremental learning "
+                    "for lightgbm models following this: "
+                    "https://lightgbm.readthedocs.io/en/latest/pythonapi/"
+                    "lightgbm.LGBMModel.html#lightgbm.LGBMModel.fit "
+                    "This may negatively impact performance. To "
+                    "disable, set `early_stopping=False`.",
+                    category=UserWarning)
+            elif self.early_stop_type == EarlyStopping.CATBOOST:
+                warnings.warn(
+                    "tune-sklearn implements incremental learning "
+                    "for Catboost models following this: "
+                    "https://catboost.ai/docs/concepts/python-usages-"
+                    "examples.html#training-continuation "
+                    "This may negatively impact performance. To "
+                    "disable, set `early_stopping=False`.",
+                    category=UserWarning)
+            if self.early_stopping is True:
+                # Override the early_stopping variable so
+                # that it is resolved appropriately in
+                # the next block
+                early_stopping = "AsyncHyperBandScheduler"
+            # Resolve the early stopping object
+            early_stopping = resolve_early_stopping(
+                early_stopping, self.max_iters, self._metric_name)
+        self.early_stopping_ = early_stopping
+
         cv = check_cv(cv=self.cv, y=y, classifier=classifier)
         self.n_splits = cv.get_n_splits(X, y, groups)
         if not hasattr(self, "_is_multi"):
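After the change, the two attributes play distinct roles, following the scikit-learn convention that state derived during fitting carries a trailing underscore and is invisible to `get_params()`. A minimal sketch of what one would expect to observe, assuming the fit completes normally (the fixture values are illustrative, not from the PR):

```python
from ray import tune
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from tune_sklearn import TuneSearchCV

X, y = make_classification(n_samples=50, n_features=5, random_state=0)
search = TuneSearchCV(
    SGDClassifier(),
    param_distributions={"alpha": tune.loguniform(1e-4, 1e-1)},
    early_stopping=True,
    max_iters=10)
search.fit(X, y)

print(search.early_stopping)   # True: the verbatim constructor argument
print(search.early_stopping_)  # the resolved scheduler instance
                               # (AsyncHyperBandScheduler by default, per the diff)
```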
@@ -519,7 +522,7 @@ def _fit(self, X, y=None, groups=None, tune_params=None, **fit_params):
             y_id = ray.put(y)
 
         config = {}
-        config["early_stopping"] = bool(self.early_stopping)
+        config["early_stopping"] = bool(self.early_stopping_)
         config["early_stop_type"] = self.early_stop_type
         config["X_id"] = X_id
         config["y_id"] = y_id
tune_sklearn/tune_gridsearch.py (3 additions, 3 deletions)

@@ -246,10 +246,10 @@ def _tune_run(self, config, resources_per_trial, tune_params=None):
         """
         trainable = _Trainable
         if self.pipeline_auto_early_stop and check_is_pipeline(
-                self.estimator) and self.early_stopping:
+                self.estimator) and self.early_stopping_:
             trainable = _PipelineTrainable
 
-        if self.early_stopping is not None:
+        if self.early_stopping_ is not None:
             config["estimator_ids"] = [
                 ray.put(self.estimator) for _ in range(self.n_splits)
             ]
@@ -261,7 +261,7 @@ def _tune_run(self, config, resources_per_trial, tune_params=None):
             stopper = CombinedStopper(stopper, self.stopper)
 
         run_args = dict(
-            scheduler=self.early_stopping,
+            scheduler=self.early_stopping_,
             reuse_actors=True,
             verbose=self.verbose,
             stop=stopper,
tune_sklearn/tune_search.py (4 additions, 4 deletions)

@@ -656,15 +656,15 @@ def _tune_run(self, config, resources_per_trial, tune_params=None):

         trainable = _Trainable
         if self.pipeline_auto_early_stop and check_is_pipeline(
-                self.estimator) and self.early_stopping:
+                self.estimator) and self.early_stopping_:
             trainable = _PipelineTrainable
 
         max_iter = self.max_iters
-        if self.early_stopping is not None:
+        if self.early_stopping_ is not None:
             config["estimator_ids"] = [
                 ray.put(self.estimator) for _ in range(self.n_splits)
             ]
-            if hasattr(self.early_stopping, "_max_t_attr"):
+            if hasattr(self.early_stopping_, "_max_t_attr"):
                 # we want to delegate stopping to schedulers which
                 # support it, but we want it to stop eventually, just in case
                 # the solution is to make the stop condition very big
@@ -676,7 +676,7 @@ def _tune_run(self, config, resources_per_trial, tune_params=None):
             if self.stopper:
                 stopper = CombinedStopper(stopper, self.stopper)
             run_args = dict(
-                scheduler=self.early_stopping,
+                scheduler=self.early_stopping_,
                 reuse_actors=True,
                 verbose=self.verbose,
                 stop=stopper,
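A practical consequence worth noting, as an illustration rather than part of the PR: scikit-learn's model-selection utilities clone estimators internally, so nesting an early-stopping TuneSearchCV inside them depends on the clone contract this PR restores. A hedged sketch, assuming Ray is available and using an illustrative fixture:

```python
from ray import tune
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score
from tune_sklearn import TuneSearchCV

X, y = make_classification(n_samples=100, n_features=5, random_state=0)
search = TuneSearchCV(
    SGDClassifier(),
    param_distributions={"alpha": tune.loguniform(1e-4, 1e-1)},
    early_stopping=True,
    max_iters=5)

# cross_val_score clones `search` for every fold; before this PR the
# clone step itself failed because __init__ had swapped early_stopping
# for a freshly built scheduler object.
scores = cross_val_score(search, X, y, cv=3)
```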