[train] Simplify ray.train.xgboost/lightgbm (1/n): Align frequency-based and checkpoint_at_end checkpoint formats #42111

Merged
57 commits merged on Feb 13, 2024

Commits
4cb0a9f
allow xgboost checkpoint to take in custom path
justinvyu Dec 26, 2023
d27c12b
switch ckpting logic to use xgboost checkpoint cls
justinvyu Dec 26, 2023
7c4dc27
hard deprecate tune report ckpt
justinvyu Dec 26, 2023
c92eaa1
remove unused methods/attrs
justinvyu Dec 26, 2023
9172933
do the same for lgbm
justinvyu Dec 27, 2023
f2c9758
add configurability of temp ckpt path for lgbm checkpoint
justinvyu Dec 27, 2023
634cacb
fix error msg wording
justinvyu Dec 27, 2023
9b37c58
use lgbm ckpt for ckpt saving in lgbm callback
justinvyu Dec 27, 2023
8751718
fix lint
justinvyu Dec 27, 2023
e0a4afa
fix docstrings
justinvyu Dec 27, 2023
e1f9798
add support for checkpoint at end for lightgbm
justinvyu Dec 27, 2023
95d4f87
standaredize the method of loading xgb/lgbm ckpts as well
justinvyu Dec 27, 2023
8b6c0c1
remove commented code + add todo
justinvyu Dec 27, 2023
f042ca2
remove useless methods
justinvyu Dec 28, 2023
ec4bfb6
fix lint
justinvyu Dec 28, 2023
1adcde6
add back _TuneCheckpointCallback dummy classes for tests to pass
justinvyu Dec 28, 2023
2955b3a
skip tests for now until xgb_ray is merged in
justinvyu Dec 28, 2023
e607a9c
make assertion error more lenient (until xgb_ray is merged in)
justinvyu Dec 28, 2023
3e3ba23
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
justinvyu Feb 1, 2024
9332138
remove different Version imports
justinvyu Feb 1, 2024
45b24cf
fix test with xgboost_ray master
justinvyu Feb 1, 2024
35d997a
test with xgboost_ray master (preparing for other pr)
justinvyu Feb 1, 2024
efae857
add xgboost missing check to minimal test (preparation)
justinvyu Feb 1, 2024
8fa702a
isort
justinvyu Feb 1, 2024
3578b2e
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
justinvyu Feb 4, 2024
a68bbc2
fix xgboost_example + xgb_dynamic_resources
justinvyu Feb 4, 2024
6a869b9
use master lightgbm_ray
justinvyu Feb 4, 2024
b8a4895
fix lint
justinvyu Feb 4, 2024
ccc94e1
remove filename usage
justinvyu Feb 5, 2024
e223a30
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
justinvyu Feb 5, 2024
326451f
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
justinvyu Feb 10, 2024
e00b73b
remove xgboost checkpoint again
justinvyu Feb 12, 2024
4f50f3d
add get_model utility on callback
justinvyu Feb 12, 2024
0c8ade9
fix tune-xgboost
justinvyu Feb 12, 2024
287049c
move callback impl to ray.train.xgboost.RayTrainReportCallback + leav…
justinvyu Feb 13, 2024
75f839e
update XGBoostTrainer.get_model
justinvyu Feb 13, 2024
d01090c
remove xgboost ckpt from xgboost examples
justinvyu Feb 13, 2024
cef5518
update lightgbm integration callback to not use framework ckpt
justinvyu Feb 13, 2024
8e0e9a7
move callback impl to ray.train.lightgbm.RayTrainReportCallback + lea…
justinvyu Feb 13, 2024
8c2306f
remove unneeded legacy callback class
justinvyu Feb 13, 2024
2d5c0d2
remove lightgbm ckpt from LightGBMTrainer.get_model
justinvyu Feb 13, 2024
c4fb911
fix/add todos to test cases
justinvyu Feb 13, 2024
8f00728
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
justinvyu Feb 13, 2024
4ff91f3
remove deprecated apis from tune api listing
justinvyu Feb 13, 2024
9375ffa
fix docstring examples
justinvyu Feb 13, 2024
6d36763
fix ci failures
justinvyu Feb 13, 2024
e21567a
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
justinvyu Feb 13, 2024
b1a1f5c
fix tune-xgboost (asha will terminate the trial before after_training…
justinvyu Feb 13, 2024
4478309
skip testcode
justinvyu Feb 13, 2024
c9a6574
add callbacks to train api ref
justinvyu Feb 13, 2024
db411c8
fix doc build errors due to automatic alias detection
justinvyu Feb 13, 2024
2a1fdd2
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
justinvyu Feb 13, 2024
0be6f71
add test get_model
justinvyu Feb 13, 2024
c9362fa
json to ubj format as default
justinvyu Feb 13, 2024
0388775
remove unnecessary __module__ patches
justinvyu Feb 13, 2024
4e9f6b3
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
justinvyu Feb 13, 2024
2548902
fix lint
justinvyu Feb 13, 2024
2 changes: 2 additions & 0 deletions doc/source/train/api/api.rst
@@ -96,6 +96,7 @@ XGBoost
:toctree: doc/

~train.xgboost.XGBoostTrainer
~train.xgboost.RayTrainReportCallback


LightGBM
@@ -106,6 +107,7 @@ LightGBM
:toctree: doc/

~train.lightgbm.LightGBMTrainer
~train.lightgbm.RayTrainReportCallback


.. _ray-train-configs-api:
5 changes: 2 additions & 3 deletions doc/source/tune/api/integration.rst
@@ -14,7 +14,6 @@ PyTorch Lightning (tune.integration.pytorch_lightning)
:nosignatures:
:toctree: doc/

~tune.integration.pytorch_lightning.TuneReportCallback
~tune.integration.pytorch_lightning.TuneReportCheckpointCallback

.. _tune-integration-xgboost:
@@ -24,9 +23,9 @@ XGBoost (tune.integration.xgboost)

.. autosummary::
:nosignatures:
:template: autosummary/class_without_autosummary.rst
:toctree: doc/

~tune.integration.xgboost.TuneReportCallback
~tune.integration.xgboost.TuneReportCheckpointCallback


@@ -37,7 +36,7 @@ LightGBM (tune.integration.lightgbm)

.. autosummary::
:nosignatures:
:template: autosummary/class_without_autosummary.rst
:toctree: doc/

~tune.integration.lightgbm.TuneReportCallback
~tune.integration.lightgbm.TuneReportCheckpointCallback
676 changes: 53 additions & 623 deletions doc/source/tune/examples/tune-xgboost.ipynb

Large diffs are not rendered by default.

65 changes: 27 additions & 38 deletions python/ray/train/gbdt_trainer.py
@@ -1,14 +1,10 @@
import logging
import os
import tempfile
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from ray import train, tune
from ray._private.dict import flatten_dict
from ray.train import Checkpoint, RunConfig, ScalingConfig
from ray.train.constants import MODEL_KEY, TRAIN_DATASET_KEY
from ray.train.constants import TRAIN_DATASET_KEY
from ray.train.trainer import BaseTrainer, GenDataset
from ray.tune import Trainable
from ray.tune.execution.placement_groups import PlacementGroupFactory
@@ -224,17 +220,28 @@ def _get_dmatrices(
for k, v in self.datasets.items()
}

@classmethod
def get_model(cls, checkpoint: Checkpoint, checkpoint_cls: Type[Any]) -> Any:
raise NotImplementedError

def _load_checkpoint(
self,
checkpoint: Checkpoint,
) -> Any:
raise NotImplementedError
# TODO(justinvyu): [code_removal] Remove in 2.11.
raise DeprecationWarning(
"The internal method `_load_checkpoint` deprecated and will be removed. "
f"See `{self.__class__.__name__}.get_model` instead."
)

def _train(self, **kwargs):
raise NotImplementedError

def _save_model(self, model: Any, path: str):
raise NotImplementedError
# TODO(justinvyu): [code_removal] Remove in 2.11.
raise DeprecationWarning(
"The internal method `_save_model` is deprecated and will be removed."
)

def _model_iteration(self, model: Any) -> int:
raise NotImplementedError
@@ -269,24 +276,6 @@ def _repartition_datasets_to_match_num_actors(self):
self._ray_params.num_actors
)

def _checkpoint_at_end(self, model, evals_result: dict) -> None:
# We need to call session.report to save checkpoints, so we report
# the last received metrics (possibly again).
result_dict = flatten_dict(evals_result, delimiter="-")
for k in list(result_dict):
result_dict[k] = result_dict[k][-1]

if getattr(self._tune_callback_checkpoint_cls, "_report_callbacks_cls", None):
# Deprecate: Remove in Ray 2.8
with tune.checkpoint_dir(step=self._model_iteration(model)) as cp_dir:
self._save_model(model, path=os.path.join(cp_dir, MODEL_KEY))
tune.report(**result_dict)
else:
with tempfile.TemporaryDirectory() as checkpoint_dir:
self._save_model(model, path=checkpoint_dir)
checkpoint = Checkpoint.from_directory(checkpoint_dir)
train.report(result_dict, checkpoint=checkpoint)

def training_loop(self) -> None:
config = self.train_kwargs.copy()
config[self._num_iterations_argument] = self.num_boost_round
@@ -299,21 +288,28 @@ def training_loop(self) -> None:

init_model = None
if self.starting_checkpoint:
init_model = self._load_checkpoint(self.starting_checkpoint)
init_model = self.__class__.get_model(self.starting_checkpoint)

config.setdefault("verbose_eval", False)
config.setdefault("callbacks", [])

if not any(
has_user_supplied_callback = any(
isinstance(cb, self._tune_callback_checkpoint_cls)
for cb in config["callbacks"]
):
# Only add our own callback if it hasn't been added before
)
if not has_user_supplied_callback:
# Only add our own default callback if the user hasn't supplied one.
checkpoint_frequency = (
self.run_config.checkpoint_config.checkpoint_frequency
)

checkpoint_at_end = self.run_config.checkpoint_config.checkpoint_at_end
if checkpoint_at_end is None:
# Defaults to True
checkpoint_at_end = True

callback = self._tune_callback_checkpoint_cls(
filename=MODEL_KEY, frequency=checkpoint_frequency
frequency=checkpoint_frequency, checkpoint_at_end=checkpoint_at_end
)

config["callbacks"] += [callback]
@@ -336,7 +332,7 @@ def training_loop(self) -> None:
f"({self._num_iterations_argument}={num_iterations})."
)

model = self._train(
self._train(
params=self.params,
dtrain=train_dmatrix,
evals_result=evals_result,
@@ -345,13 +341,6 @@ def training_loop(self) -> None:
**config,
)

checkpoint_at_end = self.run_config.checkpoint_config.checkpoint_at_end
if checkpoint_at_end is None:
checkpoint_at_end = True

if checkpoint_at_end:
self._checkpoint_at_end(model, evals_result)

def _generate_trainable_cls(self) -> Type["Trainable"]:
trainable_cls = super()._generate_trainable_cls()
trainer_cls = self.__class__
2 changes: 2 additions & 0 deletions python/ray/train/lightgbm/__init__.py
@@ -1,8 +1,10 @@
from ray.train.lightgbm._lightgbm_utils import RayTrainReportCallback
from ray.train.lightgbm.lightgbm_checkpoint import LightGBMCheckpoint
from ray.train.lightgbm.lightgbm_predictor import LightGBMPredictor
from ray.train.lightgbm.lightgbm_trainer import LightGBMTrainer

__all__ = [
"RayTrainReportCallback",
"LightGBMCheckpoint",
"LightGBMPredictor",
"LightGBMTrainer",
166 changes: 166 additions & 0 deletions python/ray/train/lightgbm/_lightgbm_utils.py
@@ -0,0 +1,166 @@
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

from lightgbm.basic import Booster
from lightgbm.callback import CallbackEnv

from ray import train
from ray.train import Checkpoint
from ray.tune.utils import flatten_dict
from ray.util.annotations import PublicAPI


@PublicAPI(stability="beta")
class RayTrainReportCallback:
justinvyu (Contributor Author) commented: TuneCallback for lgbm was originally an empty class that wasn't referenced anywhere else, so I just removed it.

"""Creates a callback that reports metrics and checkpoints model.

Args:
metrics: Metrics to report. If this is a list,
each item should be a metric key reported by LightGBM,
and it will be reported to Ray Train/Tune under the same name.
This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
which can be used to rename LightGBM default metrics.
filename: Customize the saved checkpoint file type by passing
a filename. Defaults to "model.txt".
frequency: How often to save checkpoints, in terms of iterations.
Defaults to 0 (no checkpoints are saved during training).
checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
results_postprocessing_fn: An optional Callable that takes in
the metrics dict that will be reported (after it has been flattened)
and returns a modified dict.

Examples
--------

Reporting checkpoints and metrics to Ray Tune when running many
independent lightgbm trials (without data parallelism within a trial).

.. testcode::
:skipif: True
Member commented: Are we going to add them back later?

justinvyu (Contributor Author) replied: This used to be a code-block that didn't run 😅 I just wanted to show a mock xgboost.train call with the callback inside, without needing to specify the dataset and everything.


import lightgbm

from ray.train.lightgbm import RayTrainReportCallback

config = {
# ...
"metric": ["binary_logloss", "binary_error"],
}

# Report only log loss to Tune after each validation epoch.
bst = lightgbm.train(
...,
callbacks=[
RayTrainReportCallback(
metrics={"loss": "eval-binary_logloss"}, frequency=1
)
],
)

Loading a model from a checkpoint reported by this callback.

.. testcode::
:skipif: True

from ray.train.lightgbm import RayTrainReportCallback

# Get a `Checkpoint` object that is saved by the callback during training.
result = trainer.fit()
Contributor commented: nit: For consistency with this, should we update the training example to use the LightGBMTrainer? Same for xgboost.

justinvyu (Contributor Author) replied: I want to add the *Trainer examples once I add in a v2 xgboost/lightgbm trainer, since then it'll actually show the callback usage in the training func. Right now the user doesn't need to create the callback themselves.
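For context, a rough sketch of the current flow described in that reply: the trainer builds its default RayTrainReportCallback from CheckpointConfig, so the user never constructs the callback directly. Dataset contents, params, and worker counts below are illustrative placeholders, not taken from this PR:

    import ray
    from ray.train import CheckpointConfig, RunConfig, ScalingConfig
    from ray.train.lightgbm import LightGBMTrainer, RayTrainReportCallback

    # Tiny illustrative datasets; a real workload would load actual data.
    train_ds = ray.data.from_items([{"x": float(i), "y": i % 2} for i in range(32)])
    valid_ds = ray.data.from_items([{"x": float(i), "y": i % 2} for i in range(8)])

    # No callback is passed here: the trainer injects its default
    # RayTrainReportCallback using the CheckpointConfig settings.
    trainer = LightGBMTrainer(
        label_column="y",
        params={"objective": "binary", "metric": "binary_logloss"},
        datasets={"train": train_ds, "valid": valid_ds},
        scaling_config=ScalingConfig(num_workers=1),
        run_config=RunConfig(
            checkpoint_config=CheckpointConfig(
                checkpoint_frequency=1, checkpoint_at_end=True
            )
        ),
    )
    result = trainer.fit()
    booster = RayTrainReportCallback.get_model(result.checkpoint)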

booster = RayTrainReportCallback.get_model(result.checkpoint)

"""

CHECKPOINT_NAME = "model.txt"

def __init__(
self,
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
filename: str = CHECKPOINT_NAME,
frequency: int = 0,
checkpoint_at_end: bool = True,
results_postprocessing_fn: Optional[
Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
] = None,
):
if isinstance(metrics, str):
metrics = [metrics]
self._metrics = metrics
self._filename = filename
self._frequency = frequency
self._checkpoint_at_end = checkpoint_at_end
self._results_postprocessing_fn = results_postprocessing_fn

@classmethod
def get_model(
cls, checkpoint: Checkpoint, filename: str = CHECKPOINT_NAME
) -> Booster:
"""Retrieve the model stored in a checkpoint reported by this callback.

Args:
checkpoint: The checkpoint object returned by a training run.
The checkpoint should be saved by an instance of this callback.
filename: The filename to load the model from, which should match
the filename used when creating the callback.
"""
with checkpoint.as_directory() as checkpoint_path:
return Booster(model_file=Path(checkpoint_path, filename).as_posix())

def _get_report_dict(self, evals_log: Dict[str, Dict[str, list]]) -> dict:
result_dict = flatten_dict(evals_log, delimiter="-")
if not self._metrics:
report_dict = result_dict
else:
report_dict = {}
for key in self._metrics:
if isinstance(self._metrics, dict):
metric = self._metrics[key]
else:
metric = key
report_dict[key] = result_dict[metric]
if self._results_postprocessing_fn:
report_dict = self._results_postprocessing_fn(report_dict)
return report_dict

def _get_eval_result(self, env: CallbackEnv) -> dict:
eval_result = {}
for entry in env.evaluation_result_list:
data_name, eval_name, result = entry[0:3]
if len(entry) > 4:
stdv = entry[4]
suffix = "-mean"
else:
stdv = None
suffix = ""
if data_name not in eval_result:
eval_result[data_name] = {}
eval_result[data_name][eval_name + suffix] = result
if stdv is not None:
eval_result[data_name][eval_name + "-stdv"] = stdv
return eval_result

@contextmanager
def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
yield Checkpoint.from_directory(temp_checkpoint_dir)

def __call__(self, env: CallbackEnv) -> None:
eval_result = self._get_eval_result(env)
report_dict = self._get_report_dict(eval_result)

on_last_iter = env.iteration == env.end_iteration - 1
checkpointing_disabled = self._frequency == 0
# Ex: if frequency=2, checkpoint_at_end=True and num_boost_rounds=10,
# you will checkpoint at iterations 1, 3, 5, ..., and 9 (checkpoint_at_end)
# (counting from 0)
should_checkpoint = (
not checkpointing_disabled and (env.iteration + 1) % self._frequency == 0
) or (on_last_iter and self._checkpoint_at_end)

if should_checkpoint:
with self._get_checkpoint(model=env.model) as checkpoint:
train.report(report_dict, checkpoint=checkpoint)
else:
train.report(report_dict)
15 changes: 12 additions & 3 deletions python/ray/train/lightgbm/lightgbm_checkpoint.py
@@ -23,12 +23,17 @@ def from_model(
booster: lightgbm.Booster,
*,
preprocessor: Optional["Preprocessor"] = None,
path: Optional[str] = None,
Contributor commented: Do we still need these changes if we're centralizing on the Callbacks?

justinvyu (Contributor Author) replied: Nope, I can get rid of it. If anybody does use this, specifying your own temp dir might be useful though if you want it to be cleaned up after.
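For reference, a minimal sketch of that usage with the `path` argument this diff adds (toy data and model are illustrative, and the thread above suggests the argument may be dropped again):

    import tempfile

    import lightgbm
    from sklearn.datasets import make_classification

    from ray.train.lightgbm import LightGBMCheckpoint

    # Toy model, mirroring the docstring example below.
    train_X, train_y = make_classification(n_samples=100, n_features=4)
    model = lightgbm.LGBMClassifier().fit(train_X, train_y)

    # Write the checkpoint into a directory you own so its contents are
    # cleaned up when the context manager exits.
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        checkpoint = LightGBMCheckpoint.from_model(model.booster_, path=checkpoint_dir)
        # Use `checkpoint` here, while the backing directory still exists.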

) -> "LightGBMCheckpoint":
"""Create a :py:class:`~ray.train.Checkpoint` that stores a LightGBM model.

Args:
booster: The LightGBM model to store in the checkpoint.
preprocessor: A fitted preprocessor to be applied before inference.
path: The path to the directory where the checkpoint file will be saved.
This should start as an empty directory, since the *entire*
directory will be treated as the checkpoint when reported.
By default, a temporary directory will be created.

Returns:
A :py:class:`LightGBMCheckpoint` containing the specified ``booster``.
@@ -44,10 +49,14 @@ def from_model(
>>> model = lightgbm.LGBMClassifier().fit(train_X, train_y)
>>> checkpoint = LightGBMCheckpoint.from_model(model.booster_)
"""
tempdir = tempfile.mkdtemp()
booster.save_model(Path(tempdir, cls.MODEL_FILENAME).as_posix())
checkpoint_path = Path(path or tempfile.mkdtemp())

checkpoint = cls.from_directory(tempdir)
if not checkpoint_path.is_dir():
raise ValueError(f"`path` must be a directory, but got: {checkpoint_path}")

booster.save_model(checkpoint_path.joinpath(cls.MODEL_FILENAME).as_posix())

checkpoint = cls.from_directory(checkpoint_path.as_posix())
if preprocessor:
checkpoint.set_preprocessor(preprocessor)

Expand Down