[train] Simplify ray.train.xgboost/lightgbm (1/n): Align frequency-based and checkpoint_at_end checkpoint formats (#42111)
@@ -0,0 +1,166 @@
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

from lightgbm.basic import Booster
from lightgbm.callback import CallbackEnv

from ray import train
from ray.train import Checkpoint
from ray.tune.utils import flatten_dict
from ray.util.annotations import PublicAPI
@PublicAPI(stability="beta")
class RayTrainReportCallback:
    """Creates a callback that reports metrics and checkpoints the model.

    Args:
        metrics: Metrics to report. If this is a list,
            each item should be a metric key reported by LightGBM,
            and it will be reported to Ray Train/Tune under the same name.
            This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
            which can be used to rename LightGBM default metrics.
        filename: Customize the saved checkpoint file name by passing
            a filename. Defaults to "model.txt".
        frequency: How often to save checkpoints, in terms of iterations.
            Defaults to 0 (no checkpoints are saved during training).
        checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
        results_postprocessing_fn: An optional Callable that takes in
            the metrics dict that will be reported (after it has been flattened)
            and returns a modified dict.

    Examples
    --------

    Reporting checkpoints and metrics to Ray Tune when running many
    independent LightGBM trials (without data parallelism within a trial).

    .. testcode::
        :skipif: True
[Review] Are we going to add them back later?
[Reply] This used to be a code-block that didn't run 😅 I just wanted to show a mock.
        import lightgbm

        from ray.train.lightgbm import RayTrainReportCallback

        config = {
            # ...
            "metric": ["binary_logloss", "binary_error"],
        }

        # Report only log loss to Tune after each validation epoch.
        bst = lightgbm.train(
            ...,
            callbacks=[
                RayTrainReportCallback(
                    metrics={"loss": "eval-binary_logloss"}, frequency=1
                )
            ],
        )

    Loading a model from a checkpoint reported by this callback.

    .. testcode::
        :skipif: True

        from ray.train.lightgbm import RayTrainReportCallback

        # Get a `Checkpoint` object that is saved by the callback during training.
        result = trainer.fit()
[Review] nit: For consistency with this, should we update the training example to use the …
[Reply] I want to add the …
        booster = RayTrainReportCallback.get_model(result.checkpoint)

    """

    CHECKPOINT_NAME = "model.txt"

    def __init__(
        self,
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
        filename: str = CHECKPOINT_NAME,
        frequency: int = 0,
        checkpoint_at_end: bool = True,
        results_postprocessing_fn: Optional[
            Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
        ] = None,
    ):
        if isinstance(metrics, str):
            metrics = [metrics]
        self._metrics = metrics
        self._filename = filename
        self._frequency = frequency
        self._checkpoint_at_end = checkpoint_at_end
        self._results_postprocessing_fn = results_postprocessing_fn

    @classmethod
    def get_model(
        cls, checkpoint: Checkpoint, filename: str = CHECKPOINT_NAME
    ) -> Booster:
        """Retrieve the model stored in a checkpoint reported by this callback.

        Args:
            checkpoint: The checkpoint object returned by a training run.
                The checkpoint should be saved by an instance of this callback.
            filename: The filename to load the model from, which should match
                the filename used when creating the callback.
        """
        with checkpoint.as_directory() as checkpoint_path:
            return Booster(model_file=Path(checkpoint_path, filename).as_posix())

    def _get_report_dict(self, evals_log: Dict[str, Dict[str, list]]) -> dict:
        result_dict = flatten_dict(evals_log, delimiter="-")
        if not self._metrics:
            report_dict = result_dict
        else:
            report_dict = {}
            for key in self._metrics:
                if isinstance(self._metrics, dict):
                    metric = self._metrics[key]
                else:
                    metric = key
                report_dict[key] = result_dict[metric]
        if self._results_postprocessing_fn:
            report_dict = self._results_postprocessing_fn(report_dict)
        return report_dict

    def _get_eval_result(self, env: CallbackEnv) -> dict:
        eval_result = {}
        for entry in env.evaluation_result_list:
            data_name, eval_name, result = entry[0:3]
            if len(entry) > 4:
                stdv = entry[4]
                suffix = "-mean"
            else:
                stdv = None
                suffix = ""
            if data_name not in eval_result:
                eval_result[data_name] = {}
            eval_result[data_name][eval_name + suffix] = result
            if stdv is not None:
                eval_result[data_name][eval_name + "-stdv"] = stdv
        return eval_result

    @contextmanager
    def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
            yield Checkpoint.from_directory(temp_checkpoint_dir)

    def __call__(self, env: CallbackEnv) -> None:
        eval_result = self._get_eval_result(env)
        report_dict = self._get_report_dict(eval_result)

        on_last_iter = env.iteration == env.end_iteration - 1
        checkpointing_disabled = self._frequency == 0
        # Ex: if frequency=2, checkpoint_at_end=True and num_boost_rounds=10,
        # you will checkpoint at iterations 1, 3, 5, ..., and 9 (checkpoint_at_end)
        # (counting from 0)
        should_checkpoint = (
            not checkpointing_disabled and (env.iteration + 1) % self._frequency == 0
        ) or (on_last_iter and self._checkpoint_at_end)

        if should_checkpoint:
            with self._get_checkpoint(model=env.model) as checkpoint:
                train.report(report_dict, checkpoint=checkpoint)
        else:
            train.report(report_dict)
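
For orientation, here is a minimal sketch of how the new callback ties the frequency-based and end-of-training checkpoints together. Everything outside the `RayTrainReportCallback(...)` call (the toy dataset, the `tune.Tuner` wiring, and the parameter values) is an illustrative assumption, not part of this diff:

import lightgbm
from sklearn.datasets import load_breast_cancer

from ray import tune
from ray.train.lightgbm import RayTrainReportCallback


def train_fn(config):
    # Toy dataset; any lightgbm.Dataset works here (illustrative assumption).
    X, y = load_breast_cancer(return_X_y=True)
    train_set = lightgbm.Dataset(X, label=y)

    # With frequency=2, checkpoint_at_end=True, and num_boost_round=10,
    # checkpoints land at iterations 1, 3, 5, 7, and 9 (counting from 0),
    # exactly as the comment in __call__ describes. After this PR, every
    # one of them has the same format: a directory containing "model.txt".
    lightgbm.train(
        {"objective": "binary", "metric": ["binary_logloss"]},
        train_set,
        num_boost_round=10,
        valid_sets=[train_set],
        valid_names=["train"],
        callbacks=[RayTrainReportCallback(frequency=2, checkpoint_at_end=True)],
    )


tuner = tune.Tuner(train_fn)
result = tuner.fit()[0]

# Whether the checkpoint came from the periodic path or from
# checkpoint_at_end, it loads back the same way.
booster = RayTrainReportCallback.get_model(result.checkpoint)

Note that `train.report` requires an active Ray Train/Tune session, which is why the training call is wrapped in a trainable here.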
@@ -23,12 +23,17 @@ def from_model(
         booster: lightgbm.Booster,
         *,
         preprocessor: Optional["Preprocessor"] = None,
+        path: Optional[str] = None,
[Review] Do we still need these changes if we're centralizing on the Callbacks?
[Reply] Nope, I can get rid of it. If anybody does use this, specifying your own temp dir might be useful though, if you want it to be cleaned up after.
     ) -> "LightGBMCheckpoint":
         """Create a :py:class:`~ray.train.Checkpoint` that stores a LightGBM model.

         Args:
             booster: The LightGBM model to store in the checkpoint.
             preprocessor: A fitted preprocessor to be applied before inference.
+            path: The path to the directory where the checkpoint file will be saved.
+                This should start as an empty directory, since the *entire*
+                directory will be treated as the checkpoint when reported.
+                By default, a temporary directory will be created.

         Returns:
             An :py:class:`LightGBMCheckpoint` containing the specified ``Estimator``.
@@ -44,10 +49,14 @@ def from_model(
         >>> model = lightgbm.LGBMClassifier().fit(train_X, train_y)
         >>> checkpoint = LightGBMCheckpoint.from_model(model.booster_)
         """
-        tempdir = tempfile.mkdtemp()
-        booster.save_model(Path(tempdir, cls.MODEL_FILENAME).as_posix())
-
-        checkpoint = cls.from_directory(tempdir)
+        checkpoint_path = Path(path or tempfile.mkdtemp())
+
+        if not checkpoint_path.is_dir():
+            raise ValueError(f"`path` must be a directory, but got: {checkpoint_path}")
+
+        booster.save_model(checkpoint_path.joinpath(cls.MODEL_FILENAME).as_posix())
+
+        checkpoint = cls.from_directory(checkpoint_path.as_posix())
         if preprocessor:
             checkpoint.set_preprocessor(preprocessor)
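
Picking up the review exchange above: if you want the checkpoint directory cleaned up automatically, the new `path` argument lets you supply your own. A minimal sketch, assuming a scikit-learn style toy model; only `from_model(..., path=...)` comes from this diff:

import tempfile

import lightgbm
from sklearn.datasets import load_breast_cancer

from ray.train.lightgbm import LightGBMCheckpoint

X, y = load_breast_cancer(return_X_y=True)
model = lightgbm.LGBMClassifier(n_estimators=5).fit(X, y)

# Supplying our own directory means it is removed when the context exits,
# unlike the default tempfile.mkdtemp() directory, which lingers on disk.
with tempfile.TemporaryDirectory() as tmpdir:
    checkpoint = LightGBMCheckpoint.from_model(model.booster_, path=tmpdir)
    # Use the checkpoint while the backing directory still exists,
    # e.g., load the model back out to verify the round trip.
    booster = lightgbm.Booster(
        model_file=f"{tmpdir}/{LightGBMCheckpoint.MODEL_FILENAME}"
    )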
[Review] `TuneCallback` for lgbm was originally an empty class that wasn't referenced anywhere else, so I just removed it.