[train] remove preprocessor logic for XGBoost/LightGBM Trainers #38866

Merged (1 commit, Aug 25, 2023)
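With this change, the trainers no longer fit or apply a Ray Data preprocessor on the caller's behalf. A minimal migration sketch, under the assumption that datasets are preprocessed up front with the standard ray.data.preprocessors API (data paths and column names below are hypothetical, not taken from this PR):

import ray
from ray.air.config import ScalingConfig
from ray.data.preprocessors import StandardScaler
from ray.train.xgboost import XGBoostTrainer

# Hypothetical data paths and column names, for illustration only.
train_ds = ray.data.read_csv("s3://my-bucket/train.csv")
valid_ds = ray.data.read_csv("s3://my-bucket/valid.csv")

# Previously the trainer could fit/apply a `preprocessor=` argument itself.
# Fit and apply the preprocessor before constructing the trainer instead.
scaler = StandardScaler(columns=["feature_1", "feature_2"])
train_ds = scaler.fit_transform(train_ds)
valid_ds = scaler.transform(valid_ds)

trainer = XGBoostTrainer(
    datasets={"train": train_ds, "valid": valid_ds},
    label_column="target",
    params={"objective": "reg:squarederror"},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()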
python/ray/train/gbdt_trainer.py (25 changes: 6 additions & 19 deletions)
@@ -3,10 +3,9 @@
 import tempfile
 import warnings
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, Optional, Type

 from ray import train, tune
-from ray.air._internal.checkpointing import save_preprocessor_to_dir
 from ray.air.checkpoint import Checkpoint as LegacyCheckpoint
 from ray.train._checkpoint import Checkpoint
 from ray.air.config import RunConfig, ScalingConfig
@@ -15,7 +14,6 @@
 from ray.train.trainer import BaseTrainer, GenDataset
 from ray.tune import Trainable
 from ray.tune.execution.placement_groups import PlacementGroupFactory
-from ray.tune.trainable.util import TrainableUtil
 from ray.util.annotations import DeveloperAPI
 from ray._private.dict import flatten_dict

@@ -116,10 +114,8 @@ class GBDTTrainer(BaseTrainer):

     Args:
         datasets: Datasets to use for training and validation. Must include a
-            "train" key denoting the training dataset. If a ``preprocessor``
-            is provided and has not already been fit, it will be fit on the training
-            dataset. All datasets will be transformed by the ``preprocessor`` if
-            one is provided. All non-training datasets will be used as separate
+            "train" key denoting the training dataset.
+            All non-training datasets will be used as separate
             validation sets, each reporting a separate metric.
         label_column: Name of the label column. A column with this name
             must be present in the training dataset.
@@ -165,7 +161,7 @@ def __init__(
         num_boost_round: int = _DEFAULT_NUM_ITERATIONS,
         scaling_config: Optional[ScalingConfig] = None,
         run_config: Optional[RunConfig] = None,
-        preprocessor: Optional["Preprocessor"] = None,
+        preprocessor: Optional["Preprocessor"] = None,  # Deprecated
         resume_from_checkpoint: Optional[LegacyCheckpoint] = None,
         metadata: Optional[Dict[str, Any]] = None,
         **train_kwargs,
@@ -223,7 +219,7 @@ def _get_dmatrices(
     def _load_checkpoint(
         self,
         checkpoint: LegacyCheckpoint,
-    ) -> Tuple[Any, Optional["Preprocessor"]]:
+    ) -> Any:
        raise NotImplementedError

     def _train(self, **kwargs):
@@ -300,7 +296,7 @@ def training_loop(self) -> None:

         init_model = None
         if self.starting_checkpoint:
-            init_model, _ = self._load_checkpoint(self.starting_checkpoint)
+            init_model = self._load_checkpoint(self.starting_checkpoint)

         config.setdefault("verbose_eval", False)
         config.setdefault("callbacks", [])
@@ -361,15 +357,6 @@ def _generate_trainable_cls(self) -> Type["Trainable"]:
         default_ray_params = self._default_ray_params

         class GBDTTrainable(trainable_cls):
-            def save_checkpoint(self, tmp_checkpoint_dir: str = ""):
-                checkpoint_path = super().save_checkpoint()
-                parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
-
-                preprocessor = self._merged_config.get("preprocessor", None)
-                if parent_dir and preprocessor:
-                    save_preprocessor_to_dir(preprocessor, parent_dir)
-                return checkpoint_path
-
             @classmethod
             def default_resource_request(cls, config):
                 # `config["scaling_config"] is a dataclass when passed via the
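Note that with the GBDTTrainable.save_checkpoint override removed, checkpoints written during training no longer bundle a fitted preprocessor. A rough sketch of one way to keep it around for inference (an assumption about user workflow, not something this PR prescribes):

import pickle

from ray.data.preprocessors import StandardScaler

# Stand-in for a preprocessor fit as in the migration sketch above.
scaler = StandardScaler(columns=["feature_1", "feature_2"])

# Persist it alongside your own artifacts, since the trainer checkpoint no longer does.
with open("preprocessor.pkl", "wb") as f:
    pickle.dump(scaler, f)

# At inference time, load it back and transform inputs before calling the model.
with open("preprocessor.pkl", "rb") as f:
    scaler = pickle.load(f)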
python/ray/train/lightgbm/lightgbm_trainer.py (14 changes: 3 additions & 11 deletions)
@@ -1,5 +1,5 @@
 import os
-from typing import Dict, Any, Optional, Tuple, Union, TYPE_CHECKING
+from typing import Dict, Any, Union

 try:
     from packaging.version import Version
@@ -10,16 +10,12 @@
 from ray.air.constants import MODEL_KEY
 from ray.train.gbdt_trainer import GBDTTrainer
 from ray.util.annotations import PublicAPI
-from ray.train.lightgbm.lightgbm_checkpoint import LegacyLightGBMCheckpoint

 import lightgbm
 import lightgbm_ray
 import xgboost_ray
 from lightgbm_ray.tune import TuneReportCheckpointCallback

-if TYPE_CHECKING:
-    from ray.data.preprocessor import Preprocessor
-

 @PublicAPI(stability="beta")
 class LightGBMTrainer(GBDTTrainer):
@@ -111,12 +107,8 @@ def get_model(checkpoint: Checkpoint) -> lightgbm.Booster:
     def _train(self, **kwargs):
         return lightgbm_ray.train(**kwargs)

-    def _load_checkpoint(
-        self, checkpoint: Checkpoint
-    ) -> Tuple[lightgbm.Booster, Optional["Preprocessor"]]:
-        # TODO(matt): Replace this when preprocessor arg is removed.
-        checkpoint = LegacyLightGBMCheckpoint.from_checkpoint(checkpoint)
-        return checkpoint.get_model(), checkpoint.get_preprocessor()
+    def _load_checkpoint(self, checkpoint: Checkpoint) -> lightgbm.Booster:
+        return self.__class__.get_model(checkpoint)

     def _save_model(self, model: lightgbm.LGBMModel, path: str):
         model.booster_.save_model(path)
python/ray/train/xgboost/xgboost_trainer.py (14 changes: 3 additions & 11 deletions)
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
+from typing import Any, Dict

 try:
     from packaging.version import Version
@@ -9,16 +9,12 @@
 from ray.air.checkpoint import Checkpoint
 from ray.air.constants import MODEL_KEY
 from ray.train.gbdt_trainer import GBDTTrainer
-from ray.train.xgboost.xgboost_checkpoint import LegacyXGBoostCheckpoint
 from ray.util.annotations import PublicAPI

 import xgboost
 import xgboost_ray
 from xgboost_ray.tune import TuneReportCheckpointCallback

-if TYPE_CHECKING:
-    from ray.data.preprocessor import Preprocessor
-

 @PublicAPI(stability="beta")
 class XGBoostTrainer(GBDTTrainer):
@@ -104,12 +100,8 @@ def get_model(checkpoint: Checkpoint) -> xgboost.Booster:
     def _train(self, **kwargs):
         return xgboost_ray.train(**kwargs)

-    def _load_checkpoint(
-        self, checkpoint: Checkpoint
-    ) -> Tuple[xgboost.Booster, Optional["Preprocessor"]]:
-        # TODO(matt): Replace this when preprocessor arg is removed.
-        checkpoint = LegacyXGBoostCheckpoint.from_checkpoint(checkpoint)
-        return checkpoint.get_model(), checkpoint.get_preprocessor()
+    def _load_checkpoint(self, checkpoint: Checkpoint) -> xgboost.Booster:
+        return self.__class__.get_model(checkpoint)

     def _save_model(self, model: xgboost.Booster, path: str):
         model.save_model(path)
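For completeness, a usage sketch of the checkpoint path that remains after this change (get_model appears in the diff above; the trainer and data below are assumed from the earlier sketch): the checkpoint yields only the booster, and any preprocessing is the caller's responsibility.

import pandas as pd
import xgboost

from ray.train.xgboost import XGBoostTrainer

# `trainer` as constructed in the earlier migration sketch (assumed).
result = trainer.fit()
booster = XGBoostTrainer.get_model(result.checkpoint)

# The checkpoint holds only the model, so apply your own preprocessing first.
test_df = pd.DataFrame({"feature_1": [0.1], "feature_2": [1.2]})
preds = booster.predict(xgboost.DMatrix(test_df))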