[train] Simplify ray.train.xgboost/lightgbm (3/n): Re-implement LightGBMTrainer as a lightweight DataParallelTrainer #43244

Merged (24 commits, Feb 23, 2024)

Commits
643c059
DataConfig split instead of streaming_split
justinvyu Feb 15, 2024
04b0620
Merge branch 'master' of https://github.com/ray-project/ray into simp…
justinvyu Feb 15, 2024
5d5f8ab
initial commit of lgbm v2 files
justinvyu Feb 15, 2024
36fb4e0
Merge branch 'master' of https://github.com/ray-project/ray into simp…
justinvyu Feb 16, 2024
b9978df
pipe network_params through global (per process) session
justinvyu Feb 16, 2024
9f13e0b
fix lint
justinvyu Feb 16, 2024
7edf6b6
assign ports correctly
justinvyu Feb 17, 2024
88aaad3
add working lightgbm trainer v2 example
justinvyu Feb 17, 2024
82024ce
fix docstring line lengths
justinvyu Feb 17, 2024
a4f3404
remove unneeded import
justinvyu Feb 17, 2024
ad33441
make a copy of the global var dict attr
justinvyu Feb 17, 2024
583093e
add skeleton for legacy api on new impl
justinvyu Feb 17, 2024
e9fe651
fix tests
justinvyu Feb 17, 2024
bdd1baf
fix lint
justinvyu Feb 17, 2024
99f25b6
revert workspace testing changes
justinvyu Feb 17, 2024
d99ad07
Merge branch 'master' of https://github.com/ray-project/ray into simp…
justinvyu Feb 21, 2024
5fff2ce
detach NETWORK_PARAMS_KEY from lgbm config
justinvyu Feb 21, 2024
b160507
fix some error messages + remove unused methods
justinvyu Feb 21, 2024
e0c42c7
add network params in legacy lgbm trainer train fn
justinvyu Feb 21, 2024
ca1c403
revert data config changes
justinvyu Feb 21, 2024
c774020
add missing dataset_config param
justinvyu Feb 21, 2024
72dd0ae
Merge branch 'master' of https://github.com/ray-project/ray into simp…
justinvyu Feb 22, 2024
f1eb837
remove lgbm config from ray.train.lightgbm for now
justinvyu Feb 22, 2024
b6842ec
fix isort
justinvyu Feb 22, 2024
8 changes: 8 additions & 0 deletions python/ray/train/_internal/session.py
@@ -168,6 +168,13 @@ def __init__(
self.local_ip = self.get_current_ip()

self.accelerator = None
self._state = {}

def get_state(self, key: str) -> Any:
return self._state.get(key)

def set_state(self, key: str, value: Any):
self._state[key] = value

def get_current_ip(self):
self.local_ip = ray.util.get_node_ip_address()
@@ -210,6 +217,7 @@ def reset(
self.loaded_checkpoint = loaded_checkpoint

# Reset state
self._state = {}
self.ignore_report = False
self.training_started = False
self._first_report = True
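
Note on the session-state hooks above: they give backend implementations a small per-process key-value store on the training session. A minimal illustrative sketch of the intended flow; the key name and value below are made up for illustration and are not taken from this PR:

from ray.train._internal.session import get_session

# Inside a function executed on a worker (e.g. from a Backend hook),
# stash arbitrary per-worker state on the session...
session = get_session()
session.set_state("SOME_BACKEND_STATE", {"local_listen_port": 12345})

# ...and read it back later in the same worker process, e.g. from the
# user's training function.
state = session.get_state("SOME_BACKEND_STATE")
assert state == {"local_listen_port": 12345}
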
62 changes: 62 additions & 0 deletions python/ray/train/lightgbm/config.py
@@ -0,0 +1,62 @@
import logging
from dataclasses import dataclass

import ray
from ray.train._internal.utils import get_address_and_port
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import Backend, BackendConfig

logger = logging.getLogger(__name__)


NETWORK_PARAMS_KEY = "LIGHTGBM_NETWORK_PARAMS"


@dataclass
class LightGBMConfig(BackendConfig):
"""Configuration for LightGBM distributed data-parallel training setup.

See the LightGBM docs for more information on the "network parameters"
that Ray Train sets up for you:
https://lightgbm.readthedocs.io/en/latest/Parameters.html#network-parameters
"""

@property
def backend_cls(self):
return _LightGBMBackend


class _LightGBMBackend(Backend):
def on_training_start(
self, worker_group: WorkerGroup, backend_config: LightGBMConfig
):
node_ips_and_ports = worker_group.execute(get_address_and_port)
ports = [port for _, port in node_ips_and_ports]
machines = ",".join(
[f"{node_ip}:{port}" for node_ip, port in node_ips_and_ports]
)
num_machines = len(worker_group)

def set_network_params(
num_machines: int, local_listen_port: int, machines: str
):
from ray.train._internal.session import get_session

session = get_session()
session.set_state(
NETWORK_PARAMS_KEY,
dict(
num_machines=num_machines,
local_listen_port=local_listen_port,
machines=machines,
),
)

ray.get(
[
worker_group.execute_single_async(
rank, set_network_params, num_machines, ports[rank], machines
)
for rank in range(len(worker_group))
]
)
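
The per-worker training function in the trainer diff below retrieves these stored parameters via `ray.train.lightgbm.v2.get_network_params()`. That helper's implementation is not part of this diff; a minimal sketch of how it could be built on top of the session state added above:

from typing import Dict, Union

from ray.train._internal.session import get_session
from ray.train.lightgbm.config import NETWORK_PARAMS_KEY


def get_network_params() -> Dict[str, Union[int, str]]:
    # Return the LightGBM network parameters that _LightGBMBackend stored
    # on this worker's session, or an empty dict when no distributed
    # session is active (e.g. single-process debugging).
    session = get_session()
    if session is None:
        return {}
    return session.get_state(NETWORK_PARAMS_KEY) or {}
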
219 changes: 217 additions & 2 deletions python/ray/train/lightgbm/lightgbm_trainer.py
@@ -1,15 +1,230 @@
from typing import Any, Dict, Union
import logging
from functools import partial
from typing import Any, Dict, Optional, Union

import lightgbm

import ray
from ray.train import Checkpoint
from ray.train.constants import _DEPRECATED_VALUE, TRAIN_DATASET_KEY
from ray.train.gbdt_trainer import GBDTTrainer
from ray.train.lightgbm import RayTrainReportCallback
from ray.train.lightgbm.v2 import LightGBMTrainer as SimpleLightGBMTrainer
from ray.train.trainer import GenDataset
from ray.util.annotations import PublicAPI

logger = logging.getLogger(__name__)


def _lightgbm_train_fn_per_worker(
config: dict,
label_column: str,
num_boost_round: int,
dataset_keys: set,
lightgbm_train_kwargs: dict,
):
checkpoint = ray.train.get_checkpoint()
starting_model = None
remaining_iters = num_boost_round
if checkpoint:
starting_model = RayTrainReportCallback.get_model(checkpoint)
starting_iter = starting_model.current_iteration()
remaining_iters = num_boost_round - starting_iter
logger.info(
f"Model loaded from checkpoint will train for "
f"additional {remaining_iters} iterations (trees) in order "
"to achieve the target number of iterations "
f"({num_boost_round=})."
)

train_ds_iter = ray.train.get_dataset_shard(TRAIN_DATASET_KEY)
train_df = train_ds_iter.materialize().to_pandas()

eval_ds_iters = {
k: ray.train.get_dataset_shard(k)
for k in dataset_keys
if k != TRAIN_DATASET_KEY
}
eval_dfs = {k: d.materialize().to_pandas() for k, d in eval_ds_iters.items()}

train_X, train_y = train_df.drop(label_column, axis=1), train_df[label_column]
train_set = lightgbm.Dataset(train_X, label=train_y)

# NOTE: Include the training dataset in the evaluation datasets.
# This allows `train-*` metrics to be calculated and reported.
valid_sets = [train_set]
valid_names = [TRAIN_DATASET_KEY]

for eval_name, eval_df in eval_dfs.items():
eval_X, eval_y = eval_df.drop(label_column, axis=1), eval_df[label_column]
valid_sets.append(lightgbm.Dataset(eval_X, label=eval_y))
valid_names.append(eval_name)

# Add network params of the worker group to enable distributed training.
config.update(ray.train.lightgbm.v2.get_network_params())

lightgbm.train(
params=config,
train_set=train_set,
num_boost_round=remaining_iters,
valid_sets=valid_sets,
valid_names=valid_names,
init_model=starting_model,
**lightgbm_train_kwargs,
)


@PublicAPI(stability="beta")
class LightGBMTrainer(SimpleLightGBMTrainer):
"""A Trainer for data parallel LightGBM training.

This Trainer runs the LightGBM training loop in a distributed manner
using multiple Ray Actors.

If you would like to take advantage of LightGBM's built-in handling
for features with the categorical data type, consider applying the
:class:`Categorizer` preprocessor to set the dtypes in the dataset.

.. note::
``LightGBMTrainer`` does not modify or otherwise alter the working
of the LightGBM distributed training algorithm.
Ray only provides orchestration, data ingest and fault tolerance.
For more information on LightGBM distributed training, refer to
`LightGBM documentation <https://lightgbm.readthedocs.io/>`__.

Example:
.. testcode::

import ray

from ray.train.lightgbm import LightGBMTrainer
from ray.train import ScalingConfig

train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)]
)
trainer = LightGBMTrainer(
label_column="y",
params={"objective": "regression"},
scaling_config=ScalingConfig(num_workers=3),
datasets={"train": train_dataset},
)
result = trainer.fit()

.. testoutput::
:hide:

...

Args:
datasets: The Ray Datasets to use for training and validation. Must include a
"train" key denoting the training dataset. All non-training datasets will
be used as separate validation sets, each reporting a separate metric.
label_column: Name of the label column. A column with this name
must be present in the training dataset.
params: LightGBM training parameters passed to ``lightgbm.train()``.
Refer to `LightGBM documentation <https://lightgbm.readthedocs.io>`_
for a list of possible parameters.
num_boost_round: Target number of boosting iterations (trees in the model).
Note that unlike in ``lightgbm.train``, this is the target number
of trees, meaning that if you set ``num_boost_round=10`` and pass a model
that has already been trained for 5 iterations, it will be trained for 5
iterations more, instead of 10 more.
scaling_config: Configuration for how to scale data parallel training.
run_config: Configuration for the execution of the training run.
resume_from_checkpoint: A checkpoint to resume training from.
metadata: Dict that should be made available in `checkpoint.get_metadata()`
for checkpoints saved from this Trainer. Must be JSON-serializable.
**train_kwargs: Additional kwargs passed to ``lightgbm.train()`` function.
"""

_handles_checkpoint_freq = True
_handles_checkpoint_at_end = True

def __init__(
self,
*,
datasets: Dict[str, GenDataset],
label_column: str,
params: Dict[str, Any],
num_boost_round: int = 10,
scaling_config: Optional[ray.train.ScalingConfig] = None,
run_config: Optional[ray.train.RunConfig] = None,
dataset_config: Optional[ray.train.DataConfig] = None,
resume_from_checkpoint: Optional[Checkpoint] = None,
metadata: Optional[Dict[str, Any]] = None,
dmatrix_params: Optional[Dict[str, Dict[str, Any]]] = _DEPRECATED_VALUE,
**train_kwargs,
):
# TODO(justinvyu): [Deprecated] Remove in 2.11
if dmatrix_params != _DEPRECATED_VALUE:
raise DeprecationWarning(
"`dmatrix_params` is deprecated, since XGBoostTrainer no longer "
"depends on the `xgboost_ray.RayDMatrix` utility. "
"You can remove this argument and use `dataset_config` instead "
"to customize Ray Dataset ingestion."
)

# Initialize a default Ray Train metrics/checkpoint reporting callback if needed
callbacks = train_kwargs.get("callbacks", [])
user_supplied_callback = any(
isinstance(callback, RayTrainReportCallback) for callback in callbacks
)
callback_kwargs = {}
if run_config:
checkpoint_frequency = run_config.checkpoint_config.checkpoint_frequency
checkpoint_at_end = run_config.checkpoint_config.checkpoint_at_end

callback_kwargs["frequency"] = checkpoint_frequency
# Default `checkpoint_at_end=True` unless the user explicitly sets it.
callback_kwargs["checkpoint_at_end"] = (
checkpoint_at_end if checkpoint_at_end is not None else True
)

if not user_supplied_callback:
callbacks.append(RayTrainReportCallback(**callback_kwargs))
train_kwargs["callbacks"] = callbacks

train_fn_per_worker = partial(
_lightgbm_train_fn_per_worker,
label_column=label_column,
num_boost_round=num_boost_round,
dataset_keys=set(datasets),
lightgbm_train_kwargs=train_kwargs,
)

super(LightGBMTrainer, self).__init__(
train_loop_per_worker=train_fn_per_worker,
train_loop_config=params,
scaling_config=scaling_config,
run_config=run_config,
datasets=datasets,
dataset_config=dataset_config,
resume_from_checkpoint=resume_from_checkpoint,
metadata=metadata,
)

@classmethod
def get_model(
cls,
checkpoint: Checkpoint,
) -> lightgbm.Booster:
"""Retrieve the LightGBM model stored in this checkpoint."""
return RayTrainReportCallback.get_model(checkpoint)

def _validate_attributes(self):
super()._validate_attributes()

if TRAIN_DATASET_KEY not in self.datasets:
raise KeyError(
f"'{TRAIN_DATASET_KEY}' key must be preset in `datasets`. "
f"Got {list(self.datasets.keys())}"
)


# TODO(justinvyu): [simplify_xgb] Remove once lightgbm_ray dep is gone.
@PublicAPI(stability="beta")
class LightGBMTrainer(GBDTTrainer):
class LegacyLightGBMTrainer(GBDTTrainer):
"""A Trainer for data parallel LightGBM training.

This Trainer runs the LightGBM training loop in a distributed manner
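
Usage note on the default-callback wiring in the new `LightGBMTrainer.__init__` above: a `RayTrainReportCallback` is only appended when the user does not supply one, so reporting and checkpointing behavior can be customized by passing the callback explicitly. A sketch; the `frequency` and `checkpoint_at_end` arguments mirror the kwargs the constructor passes to the default callback:

import ray
from ray.train import ScalingConfig
from ray.train.lightgbm import LightGBMTrainer, RayTrainReportCallback

train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])

trainer = LightGBMTrainer(
    label_column="y",
    params={"objective": "regression"},
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": train_dataset},
    # A user-supplied RayTrainReportCallback takes precedence: the trainer
    # detects it and does not append the default callback on top.
    callbacks=[RayTrainReportCallback(frequency=5, checkpoint_at_end=True)],
)
result = trainer.fit()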