Skip to content
This repository has been archived by the owner on Nov 14, 2023. It is now read-only.

Commit

Permalink
Simplify fit (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yard1 committed Aug 2, 2021
1 parent a700d39 commit 0502705
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 32 deletions.
42 changes: 10 additions & 32 deletions tune_sklearn/tune_basesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
Logger)

from tune_sklearn.utils import (EarlyStopping, get_early_stop_type,
check_is_pipeline, _check_multimetric_scoring)
check_is_pipeline, _check_multimetric_scoring,
ray_context)
from tune_sklearn._detect_booster import is_lightgbm_model

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -654,37 +655,14 @@ def fit(self, X, y=None, groups=None, tune_params=None, **fit_params):
:obj:`TuneBaseSearchCV` child instance, after fitting.
"""
ray_init = ray.is_initialized()
try:
if not ray_init:
if self.n_jobs == 1:
ray.init(
local_mode=True,
configure_logging=False,
ignore_reinit_error=True,
include_dashboard=False)
else:
ray.init(
ignore_reinit_error=True,
configure_logging=False,
include_dashboard=False
# log_to_driver=self.verbose == 2
)
if self.verbose != 2:
logger.info("TIP: Hiding process output by default. "
"To show process output, set verbose=2.")

result = self._fit(X, y, groups, tune_params, **fit_params)

if not ray_init and ray.is_initialized():
ray.shutdown()

return result

except Exception:
if not ray_init and ray.is_initialized():
ray.shutdown()
raise
ray_kwargs = dict(
configure_logging=False,
ignore_reinit_error=True,
include_dashboard=False)
if self.n_jobs == 1:
ray_kwargs["local_mode"] = True
with ray_context(**ray_kwargs):
return self._fit(X, y, groups, tune_params, **fit_params)

def score(self, X, y=None):
"""Compute the score(s) of an estimator on a given test set.
Expand Down
23 changes: 23 additions & 0 deletions tune_sklearn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import numpy as np
from enum import Enum, auto
from collections.abc import Sequence
import contextlib
import ray

try:
from ray.tune.stopper import MaximumIterationStopper
Expand Down Expand Up @@ -261,3 +263,24 @@ def _check_param_grid_tune_grid_search(param_grid):
if len(v) == 0:
raise ValueError("Parameter values for parameter ({0}) need "
"to be a non-empty sequence.".format(name))


class ray_context(contextlib.AbstractContextManager):
"""Context to initialize and shutdown Ray."""

def __init__(self, force_reinit: bool = False, **kwargs) -> None:
self.ray_init_kwargs = kwargs
self.force_reinit = force_reinit

def __enter__(self):
self.was_ray_initialized_ = ray.is_initialized()
if self.force_reinit or not self.was_ray_initialized_:
kwargs = self.ray_init_kwargs.copy()
if self.force_reinit:
kwargs["ignore_reinit_error"] = True
ray.init(**kwargs)
return self

def __exit__(self, exc_type, exc_value, traceback):
if not self.was_ray_initialized_ and ray.is_initialized():
ray.shutdown()

0 comments on commit 0502705

Please sign in to comment.