Skip to content

Commit

Permalink
[tune] Add xgboost_ray integration (#12572)
Browse files Browse the repository at this point in the history
  • Loading branch information
krfricke committed Dec 4, 2020
1 parent 219c445 commit 1c0d10f
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions python/ray/tune/integration/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def __init__(self,
metrics = [metrics]
self._metrics = metrics

def __call__(self, env):
def _get_report_dict(self, env):
# Only one worker should report to Tune
result_dict = dict(env.evaluation_result_list)
if not self._metrics:
report_dict = result_dict
Expand All @@ -66,6 +67,10 @@ def __call__(self, env):
else:
metric = key
report_dict[key] = result_dict[metric]
return report_dict

def __call__(self, env):
report_dict = self._get_report_dict(env)
tune.report(**report_dict)


Expand All @@ -81,15 +86,24 @@ class _TuneCheckpointCallback(TuneCallback):
Args:
filename (str): Filename of the checkpoint within the checkpoint
directory. Defaults to "checkpoint".
frequency (int): How often to save checkpoints. Per default, a
checkpoint is saved every five iterations.
"""

def __init__(self, filename: str = "checkpoint"):
def __init__(self, filename: str = "checkpoint", frequency: int = 5):
self._filename = filename
self._frequency = frequency

def __call__(self, env):
@staticmethod
def _create_checkpoint(env, filename: str, frequency: int):
if env.iteration % frequency > 0:
return
with tune.checkpoint_dir(step=env.iteration) as checkpoint_dir:
env.model.save_model(os.path.join(checkpoint_dir, self._filename))
env.model.save_model(os.path.join(checkpoint_dir, filename))

def __call__(self, env):
self._create_checkpoint(env, self._filename, self._frequency)


class TuneReportCheckpointCallback(TuneCallback):
Expand All @@ -108,6 +122,8 @@ class TuneReportCheckpointCallback(TuneCallback):
directory. Defaults to "checkpoint". If this is None,
all metrics will be reported to Tune under their default names as
obtained from XGBoost.
frequency (int): How often to save checkpoints. Per default, a
checkpoint is saved every five iterations.
Example:
Expand All @@ -132,12 +148,15 @@ class TuneReportCheckpointCallback(TuneCallback):
{"loss": "eval-logloss"}, "xgboost.mdl)])
"""
_checkpoint_callback_cls = _TuneCheckpointCallback
_report_callbacks_cls = TuneReportCallback

def __init__(self,
metrics: Union[None, str, List[str], Dict[str, str]] = None,
filename: str = "checkpoint"):
self._checkpoint = _TuneCheckpointCallback(filename)
self._report = TuneReportCallback(metrics)
filename: str = "checkpoint",
frequency: int = 5):
self._checkpoint = self._checkpoint_callback_cls(filename, frequency)
self._report = self._report_callbacks_cls(metrics)

def __call__(self, env):
self._checkpoint(env)
Expand Down

0 comments on commit 1c0d10f

Please sign in to comment.