This repository has been archived by the owner on Nov 14, 2023. It is now read-only.

Allow user to pass own tune.run params in fit #212

Merged 8 commits on Jun 17, 2021
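
From the caller's side, the new keyword threads a dict of ``tune.run`` arguments through ``fit``; a minimal usage sketch (the estimator, data, and parameter grid here are illustrative, not taken from this PR):

    from sklearn.datasets import make_classification
    from sklearn.linear_model import SGDClassifier
    from tune_sklearn import TuneSearchCV

    X, y = make_classification(n_samples=50, n_features=10, random_state=0)
    search = TuneSearchCV(SGDClassifier(), {"alpha": [1e-4, 1e-3, 1e-2]})
    # tune_params is forwarded to ray.tune.run; any key that collides with
    # a preset run argument triggers a UserWarning, then the user value wins.
    search.fit(X, y, tune_params={"fail_fast": "raise"})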
17 changes: 16 additions & 1 deletion tests/test_randomizedsearch.py
@@ -22,7 +22,7 @@
from tune_sklearn._detect_booster import (has_xgboost, has_catboost,
has_required_lightgbm_version)
from tune_sklearn.utils import EarlyStopping
from test_utils import SleepClassifier, PlateauClassifier
from test_utils import SleepClassifier, PlateauClassifier, MockClassifier


class RandomizedSearchTest(unittest.TestCase):
@@ -403,6 +403,21 @@ def test_warn_early_stop(self):
TuneSearchCV(
SGDClassifier(), {"epsilon": [0.1, 0.2]}, early_stopping=True)

def test_warn_user_params(self):
X, y = make_classification(
n_samples=50, n_features=50, n_informative=3, random_state=0)

clf = MockClassifier()

search = TuneSearchCV(
clf, {"foo_param": [2.0, 3.0, 4.0]}, cv=2, max_iters=2)

with self.assertWarnsRegex(
UserWarning,
"The following preset tune.run parameters will be overriden "
"by tune_params: fail_fast."):
search.fit(X, y, tune_params={"fail_fast": "raise"})

@unittest.skipIf(not has_xgboost(), "xgboost not installed")
def test_early_stop_xgboost_warn(self):
from xgboost.sklearn import XGBClassifier
32 changes: 27 additions & 5 deletions tune_sklearn/tune_basesearch.py
@@ -483,7 +483,7 @@ def __init__(self,
self.loggers = resolve_loggers(loggers)
assert isinstance(self.n_jobs, int)

def _fit(self, X, y=None, groups=None, **fit_params):
def _fit(self, X, y=None, groups=None, tune_params=None, **fit_params):
"""Helper method to run fit procedure

Args:
@@ -497,6 +497,8 @@ def _fit(self, X, y=None, groups=None, **fit_params):
Group labels for the samples used while splitting the dataset
into train/test set. Only used in conjunction with a "Group"
`cv` instance (e.g., `GroupKFold`).
tune_params (:obj:`dict`, optional):
Parameters passed to ``tune.run`` used for parameter search.
**fit_params (:obj:`dict` of str): Parameters passed to
the ``fit`` method of the estimator.

@@ -564,7 +566,7 @@ def _fit(self, X, y=None, groups=None, **fit_params):
config["metric_name"] = self._metric_name

self._fill_config_hyperparam(config)
analysis = self._tune_run(config, resources_per_trial)
analysis = self._tune_run(config, resources_per_trial, tune_params)

self.cv_results_ = self._format_results(self.n_splits, analysis)

@@ -627,7 +629,7 @@ def _fit(self, X, y=None, groups=None, **fit_params):

return self

def fit(self, X, y=None, groups=None, **fit_params):
def fit(self, X, y=None, groups=None, tune_params=None, **fit_params):
"""Run fit with all sets of parameters.

``tune.run`` is used to perform the fit procedure.
@@ -643,6 +645,8 @@ def fit(self, X, y=None, groups=None, **fit_params):
Group labels for the samples used while splitting the dataset
into train/test set. Only used in conjunction with a "Group"
`cv` instance (e.g., `GroupKFold`).
tune_params (:obj:`dict`, optional):
Parameters passed to ``tune.run`` used for parameter search.
**fit_params (:obj:`dict` of str): Parameters passed to
the ``fit`` method of the estimator.

@@ -670,7 +674,7 @@ def fit(self, X, y=None, groups=None, **fit_params):
logger.info("TIP: Hiding process output by default. "
"To show process output, set verbose=2.")

result = self._fit(X, y, groups, **fit_params)
result = self._fit(X, y, groups, tune_params, **fit_params)

if not ray_init and ray.is_initialized():
ray.shutdown()
@@ -735,18 +739,36 @@ def _fill_config_hyperparam(self, config):
"""
raise NotImplementedError("Define in child class")

def _tune_run(self, config, resources_per_trial):
def _tune_run(self, config, resources_per_trial, tune_params=None):
"""Wrapper to call ``tune.run``. Implement this in a child class.

Args:
config (:obj:`dict`): dictionary to be passed in as the
configuration for `tune.run`.
resources_per_trial (:obj:`dict` of int): dictionary specifying the
number of cpu's and gpu's to use to train the model.
tune_params (dict): User defined parameters passed to
``tune.run``. Parameters inside `tune_params` override
preset parameters.

"""
raise NotImplementedError("Define in child class")

def _override_run_args_with_tune_params(self, run_args, tune_params):
"""Helper to override tune.run run arguments with user supplied dict"""

if tune_params:
user_overrides = {k for k in tune_params if k in run_args}
if user_overrides:
warnings.warn(
    "The following preset tune.run parameters will "
    f"be overridden by tune_params: {', '.join(user_overrides)}"
    ". This may cause unexpected issues! If you experience "
    "issues, please try removing those parameters from "
    "tune_params.")

Collaborator: this is a great error message!
run_args = {**run_args, **tune_params}
return run_args
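
The merge itself is plain dict unpacking, so keys from ``tune_params`` shadow the presets; a standalone sketch of the semantics with made-up values:

    run_args = {"fail_fast": False, "verbose": 1}
    tune_params = {"fail_fast": "raise"}
    # In {**run_args, **tune_params}, the later dict wins on collisions.
    merged = {**run_args, **tune_params}
    assert merged == {"fail_fast": "raise", "verbose": 1}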

def _clean_config_dict(self, config):
"""Helper to remove keys from the ``config`` dictionary returned from
``tune.run``.
8 changes: 7 additions & 1 deletion tune_sklearn/tune_gridsearch.py
@@ -223,7 +223,7 @@ def _list_grid_num_samples(self):
"""
return len(list(ParameterGrid(self.param_grid)))

def _tune_run(self, config, resources_per_trial):
def _tune_run(self, config, resources_per_trial, tune_params=None):
"""Wrapper to call ``tune.run``. Multiple estimators are generated when
early stopping is possible, whereas a single estimator is
generated when early stopping is not possible.
@@ -234,6 +234,9 @@ def _tune_run(self, config, resources_per_trial):
resources_per_trial (dict): Resources to use per trial within Ray.
Accepted keys are `cpu`, `gpu` and custom resources, and values
are integers specifying the number of each resource to use.
tune_params (dict): User defined parameters passed to
``tune.run``. Parameters inside `tune_params` override
preset parameters.

Returns:
analysis (`ExperimentAnalysis`): Object returned by
@@ -277,6 +280,9 @@ def _tune_run(self, config, resources_per_trial):
search_alg=ListSearcher(self.param_grid),
num_samples=self._list_grid_num_samples()))

run_args = self._override_run_args_with_tune_params(
run_args, tune_params)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="fail_fast='raise' "
8 changes: 7 additions & 1 deletion tune_sklearn/tune_search.py
@@ -615,7 +615,7 @@ def _try_import_required_libraries(self, search_optimization):
raise ImportError("It appears that optuna is not installed. "
"Do: pip install optuna") from None

def _tune_run(self, config, resources_per_trial):
def _tune_run(self, config, resources_per_trial, tune_params=None):
"""Wrapper to call ``tune.run``. Multiple estimators are generated when
early stopping is possible, whereas a single estimator is
generated when early stopping is not possible.
@@ -626,6 +626,9 @@ def _tune_run(self, config, resources_per_trial):
resources_per_trial (dict): Resources to use per trial within Ray.
Accepted keys are `cpu`, `gpu` and custom resources, and values
are integers specifying the number of each resource to use.
tune_params (dict): User defined parameters passed to
``tune.run``. Parameters inside `tune_params` override
preset parameters.

Returns:
analysis (`ExperimentAnalysis`): Object returned by
@@ -772,6 +775,9 @@ def _tune_run(self, config, resources_per_trial):
search_algo, max_concurrent=self.n_jobs)
run_args["search_alg"] = search_algo

run_args = self._override_run_args_with_tune_params(
run_args, tune_params)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="fail_fast='raise' "