This repository has been archived by the owner on Nov 14, 2023. It is now read-only.

Simplify TuneBaseSearchCV.fit #215

Merged 1 commit on Aug 2, 2021
42 changes: 10 additions & 32 deletions tune_sklearn/tune_basesearch.py
@@ -39,7 +39,8 @@
                             Logger)
 
 from tune_sklearn.utils import (EarlyStopping, get_early_stop_type,
-                                check_is_pipeline, _check_multimetric_scoring)
+                                check_is_pipeline, _check_multimetric_scoring,
+                                ray_context)
 from tune_sklearn._detect_booster import is_lightgbm_model
 
 logger = logging.getLogger(__name__)
@@ -654,37 +655,14 @@ def fit(self, X, y=None, groups=None, tune_params=None, **fit_params):
             :obj:`TuneBaseSearchCV` child instance, after fitting.
 
         """
-        ray_init = ray.is_initialized()
-        try:
-            if not ray_init:
-                if self.n_jobs == 1:
-                    ray.init(
-                        local_mode=True,
-                        configure_logging=False,
-                        ignore_reinit_error=True,
-                        include_dashboard=False)
-                else:
-                    ray.init(
-                        ignore_reinit_error=True,
-                        configure_logging=False,
-                        include_dashboard=False
-                        # log_to_driver=self.verbose == 2
-                    )
-                    if self.verbose != 2:
-                        logger.info("TIP: Hiding process output by default. "
-                                    "To show process output, set verbose=2.")
-
-            result = self._fit(X, y, groups, tune_params, **fit_params)
-
-            if not ray_init and ray.is_initialized():
-                ray.shutdown()
-
-            return result
-
-        except Exception:
-            if not ray_init and ray.is_initialized():
-                ray.shutdown()
-            raise
+        ray_kwargs = dict(
+            configure_logging=False,
+            ignore_reinit_error=True,
+            include_dashboard=False)
+        if self.n_jobs == 1:
+            ray_kwargs["local_mode"] = True
+        with ray_context(**ray_kwargs):
+            return self._fit(X, y, groups, tune_params, **fit_params)
 
     def score(self, X, y=None):
         """Compute the score(s) of an estimator on a given test set.
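For context, a caller-side sketch of the behavior this refactor preserves: fit() initializes Ray only when no session exists (in local mode for n_jobs=1) and shuts it down again afterwards, while an externally started session is reused and left running. The TuneGridSearchCV/SVC/load_iris usage below is illustrative only and not part of this diff.

import ray
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from tune_sklearn import TuneGridSearchCV

X, y = load_iris(return_X_y=True)
search = TuneGridSearchCV(SVC(), {"C": [1, 10]}, n_jobs=1)

search.fit(X, y)                 # fit() starts Ray (local mode, since n_jobs=1) ...
assert not ray.is_initialized()  # ... and shuts it down again, since it owned the session

ray.init(include_dashboard=False)
search.fit(X, y)                 # an existing Ray session is reused ...
assert ray.is_initialized()      # ... and left running for the caller
ray.shutdown()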
23 changes: 23 additions & 0 deletions tune_sklearn/utils.py
@@ -8,6 +8,8 @@
 import numpy as np
 from enum import Enum, auto
 from collections.abc import Sequence
+import contextlib
+import ray
 
 try:
     from ray.tune.stopper import MaximumIterationStopper
@@ -261,3 +263,24 @@ def _check_param_grid_tune_grid_search(param_grid):
             if len(v) == 0:
                 raise ValueError("Parameter values for parameter ({0}) need "
                                  "to be a non-empty sequence.".format(name))
+
+
+class ray_context(contextlib.AbstractContextManager):
+    """Context to initialize and shutdown Ray."""
+
+    def __init__(self, force_reinit: bool = False, **kwargs) -> None:
+        self.ray_init_kwargs = kwargs
+        self.force_reinit = force_reinit
+
+    def __enter__(self):
+        self.was_ray_initialized_ = ray.is_initialized()
+        if self.force_reinit or not self.was_ray_initialized_:
+            kwargs = self.ray_init_kwargs.copy()
+            if self.force_reinit:
+                kwargs["ignore_reinit_error"] = True
+            ray.init(**kwargs)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if not self.was_ray_initialized_ and ray.is_initialized():
+            ray.shutdown()
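
A minimal standalone sketch of the new helper, mirroring how TuneBaseSearchCV.fit now calls it: the keyword arguments are forwarded to ray.init(), and shutdown happens on exit only if this context was the one that started Ray. The comments assume no Ray session exists before entering.

import ray
from tune_sklearn.utils import ray_context

with ray_context(
        configure_logging=False,
        ignore_reinit_error=True,
        include_dashboard=False,
        local_mode=True):
    assert ray.is_initialized()  # started by the context (unless a session already existed)

assert not ray.is_initialized()  # shut down on exit because the context started it

Passing force_reinit=True additionally forces a fresh ray.init() call on top of an existing session by setting ignore_reinit_error=True.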