
[train] remove preprocessor logic for XGBoost/LightGBM Trainers (#38866)
Signed-off-by: Matthew Deng <matt@anyscale.com>
matthewdeng authored Aug 25, 2023
1 parent 44e5edc commit b37fe98
Showing 3 changed files with 12 additions and 41 deletions.
25 changes: 6 additions & 19 deletions python/ray/train/gbdt_trainer.py
@@ -3,10 +3,9 @@
import tempfile
import warnings
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from ray import train, tune
-from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.air.checkpoint import Checkpoint as LegacyCheckpoint
from ray.train._checkpoint import Checkpoint
from ray.air.config import RunConfig, ScalingConfig
@@ -15,7 +14,6 @@
from ray.train.trainer import BaseTrainer, GenDataset
from ray.tune import Trainable
from ray.tune.execution.placement_groups import PlacementGroupFactory
-from ray.tune.trainable.util import TrainableUtil
from ray.util.annotations import DeveloperAPI
from ray._private.dict import flatten_dict

@@ -116,10 +114,8 @@ class GBDTTrainer(BaseTrainer):
Args:
datasets: Datasets to use for training and validation. Must include a
"train" key denoting the training dataset. If a ``preprocessor``
is provided and has not already been fit, it will be fit on the training
dataset. All datasets will be transformed by the ``preprocessor`` if
one is provided. All non-training datasets will be used as separate
"train" key denoting the training dataset.
All non-training datasets will be used as separate
validation sets, each reporting a separate metric.
label_column: Name of the label column. A column with this name
must be present in the training dataset.
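With the preprocessor argument deprecated, data preparation moves out of the trainer: the Ray Datasets are fit and transformed up front and passed in already preprocessed. A minimal sketch of that pattern, assuming the ray.data.preprocessors API; column names, paths, and configuration values below are illustrative and not taken from this diff.

import ray
from ray.air.config import ScalingConfig
from ray.data.preprocessors import StandardScaler
from ray.train.xgboost import XGBoostTrainer

# Hypothetical input data; any Ray Dataset works here.
train_ds = ray.data.read_csv("s3://example-bucket/train.csv")
valid_ds = ray.data.read_csv("s3://example-bucket/valid.csv")

# Fit the preprocessor on the training split, then transform both splits.
prep = StandardScaler(columns=["feature_a", "feature_b"])
train_ds = prep.fit_transform(train_ds)
valid_ds = prep.transform(valid_ds)

trainer = XGBoostTrainer(
    label_column="target",
    params={"objective": "binary:logistic"},
    datasets={"train": train_ds, "valid": valid_ds},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()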
@@ -165,7 +161,7 @@ def __init__(
num_boost_round: int = _DEFAULT_NUM_ITERATIONS,
scaling_config: Optional[ScalingConfig] = None,
run_config: Optional[RunConfig] = None,
-preprocessor: Optional["Preprocessor"] = None,
+preprocessor: Optional["Preprocessor"] = None, # Deprecated
resume_from_checkpoint: Optional[LegacyCheckpoint] = None,
metadata: Optional[Dict[str, Any]] = None,
**train_kwargs,
@@ -223,7 +219,7 @@ def _get_dmatrices(
def _load_checkpoint(
self,
checkpoint: LegacyCheckpoint,
-) -> Tuple[Any, Optional["Preprocessor"]]:
+) -> Any:
raise NotImplementedError

def _train(self, **kwargs):
@@ -300,7 +296,7 @@ def training_loop(self) -> None:

init_model = None
if self.starting_checkpoint:
-init_model, _ = self._load_checkpoint(self.starting_checkpoint)
+init_model = self._load_checkpoint(self.starting_checkpoint)

config.setdefault("verbose_eval", False)
config.setdefault("callbacks", [])
@@ -361,15 +357,6 @@ def _generate_trainable_cls(self) -> Type["Trainable"]:
default_ray_params = self._default_ray_params

class GBDTTrainable(trainable_cls):
-def save_checkpoint(self, tmp_checkpoint_dir: str = ""):
-checkpoint_path = super().save_checkpoint()
-parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
-
-preprocessor = self._merged_config.get("preprocessor", None)
-if parent_dir and preprocessor:
-save_preprocessor_to_dir(preprocessor, parent_dir)
-return checkpoint_path
-
@classmethod
def default_resource_request(cls, config):
# `config["scaling_config"] is a dataclass when passed via the
14 changes: 3 additions & 11 deletions python/ray/train/lightgbm/lightgbm_trainer.py
@@ -1,5 +1,5 @@
import os
-from typing import Dict, Any, Optional, Tuple, Union, TYPE_CHECKING
+from typing import Dict, Any, Union

try:
from packaging.version import Version
@@ -10,16 +10,12 @@
from ray.air.constants import MODEL_KEY
from ray.train.gbdt_trainer import GBDTTrainer
from ray.util.annotations import PublicAPI
-from ray.train.lightgbm.lightgbm_checkpoint import LegacyLightGBMCheckpoint

import lightgbm
import lightgbm_ray
import xgboost_ray
from lightgbm_ray.tune import TuneReportCheckpointCallback

-if TYPE_CHECKING:
-from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="beta")
class LightGBMTrainer(GBDTTrainer):
@@ -111,12 +107,8 @@ def get_model(checkpoint: Checkpoint) -> lightgbm.Booster:
def _train(self, **kwargs):
return lightgbm_ray.train(**kwargs)

-def _load_checkpoint(
-self, checkpoint: Checkpoint
-) -> Tuple[lightgbm.Booster, Optional["Preprocessor"]]:
-# TODO(matt): Replace this when preprocessor arg is removed.
-checkpoint = LegacyLightGBMCheckpoint.from_checkpoint(checkpoint)
-return checkpoint.get_model(), checkpoint.get_preprocessor()
+def _load_checkpoint(self, checkpoint: Checkpoint) -> lightgbm.Booster:
+return self.__class__.get_model(checkpoint)

def _save_model(self, model: lightgbm.LGBMModel, path: str):
model.booster_.save_model(path)
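With _load_checkpoint reduced to a thin wrapper, get_model is the single path for recovering the trained booster from a checkpoint. A short sketch; `result` is assumed to be the Result returned by trainer.fit() for a LightGBMTrainer run.

from ray.train.lightgbm import LightGBMTrainer

booster = LightGBMTrainer.get_model(result.checkpoint)  # returns lightgbm.Booster
print(booster.num_trees())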
14 changes: 3 additions & 11 deletions python/ray/train/xgboost/xgboost_trainer.py
@@ -1,5 +1,5 @@
import os
-from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
+from typing import Any, Dict

try:
from packaging.version import Version
@@ -9,16 +9,12 @@
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY
from ray.train.gbdt_trainer import GBDTTrainer
-from ray.train.xgboost.xgboost_checkpoint import LegacyXGBoostCheckpoint
from ray.util.annotations import PublicAPI

import xgboost
import xgboost_ray
from xgboost_ray.tune import TuneReportCheckpointCallback

-if TYPE_CHECKING:
-from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="beta")
class XGBoostTrainer(GBDTTrainer):
@@ -104,12 +100,8 @@ def get_model(checkpoint: Checkpoint) -> xgboost.Booster:
def _train(self, **kwargs):
return xgboost_ray.train(**kwargs)

-def _load_checkpoint(
-self, checkpoint: Checkpoint
-) -> Tuple[xgboost.Booster, Optional["Preprocessor"]]:
-# TODO(matt): Replace this when preprocessor arg is removed.
-checkpoint = LegacyXGBoostCheckpoint.from_checkpoint(checkpoint)
-return checkpoint.get_model(), checkpoint.get_preprocessor()
+def _load_checkpoint(self, checkpoint: Checkpoint) -> xgboost.Booster:
+return self.__class__.get_model(checkpoint)

def _save_model(self, model: xgboost.Booster, path: str):
model.save_model(path)
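The XGBoost trainer mirrors the LightGBM change. A sketch of loading the booster and scoring a batch; because the checkpoint no longer carries a preprocessor, the input frame is assumed to already be transformed the same way as the training data, and the column names are illustrative. `result` is again the Result of a completed fit().

import pandas as pd
import xgboost
from ray.train.xgboost import XGBoostTrainer

booster = XGBoostTrainer.get_model(result.checkpoint)  # returns xgboost.Booster
batch = pd.DataFrame({"feature_a": [0.1], "feature_b": [0.2]})
preds = booster.predict(xgboost.DMatrix(batch))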

