# gbdt_trainer.py
import logging
import os
import tempfile
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from ray import train, tune
from ray._private.dict import flatten_dict
from ray.train import Checkpoint, RunConfig, ScalingConfig
from ray.train.constants import MODEL_KEY, TRAIN_DATASET_KEY
from ray.train.trainer import BaseTrainer, GenDataset
from ray.tune import Trainable
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
    import xgboost_ray

    from ray.data.preprocessor import Preprocessor

_WARN_REPARTITION_THRESHOLD = 10 * 1024**3
_DEFAULT_NUM_ITERATIONS = 10

logger = logging.getLogger(__name__)


def _convert_scaling_config_to_ray_params(
    scaling_config: ScalingConfig,
    ray_params_cls: Type["xgboost_ray.RayParams"],
    default_ray_params: Optional[Dict[str, Any]] = None,
) -> "xgboost_ray.RayParams":
    """Scaling config parameters have precedence over default ray params.

    Default ray params are defined in the trainers (xgboost/lightgbm),
    but if the user requests something else, that should be respected.
    """
    resources = (scaling_config.resources_per_worker or {}).copy()

    cpus_per_actor = resources.pop("CPU", 0)
    if not cpus_per_actor:
        cpus_per_actor = default_ray_params.get("cpus_per_actor", 0)

    gpus_per_actor = resources.pop("GPU", int(scaling_config.use_gpu))
    if not gpus_per_actor:
        gpus_per_actor = default_ray_params.get("gpus_per_actor", 0)

    resources_per_actor = resources
    if not resources_per_actor:
        resources_per_actor = default_ray_params.get("resources_per_actor", None)

    num_actors = scaling_config.num_workers
    if not num_actors:
        num_actors = default_ray_params.get("num_actors", 0)

    ray_params_kwargs = default_ray_params.copy() or {}
    ray_params_kwargs.update(
        {
            "cpus_per_actor": int(cpus_per_actor),
            "gpus_per_actor": int(gpus_per_actor),
            "resources_per_actor": resources_per_actor,
            "num_actors": int(num_actors),
        }
    )

    # This should be upstreamed to xgboost_ray,
    # but also left here for backwards compatibility.
    if not hasattr(ray_params_cls, "placement_options"):

        @dataclass
        class RayParamsFromScalingConfig(ray_params_cls):
            # Passed as kwargs to PlacementGroupFactory
            placement_options: Dict[str, Any] = None

            def get_tune_resources(self) -> PlacementGroupFactory:
                pgf = super().get_tune_resources()
                placement_options = self.placement_options.copy()
                extended_pgf = PlacementGroupFactory(
                    pgf.bundles,
                    **placement_options,
                )
                extended_pgf._head_bundle_is_empty = pgf._head_bundle_is_empty
                return extended_pgf

        ray_params_cls_extended = RayParamsFromScalingConfig
    else:
        ray_params_cls_extended = ray_params_cls

    placement_options = {
        "strategy": scaling_config.placement_strategy,
    }

    ray_params = ray_params_cls_extended(
        placement_options=placement_options,
        **ray_params_kwargs,
    )
    return ray_params


@DeveloperAPI
class GBDTTrainer(BaseTrainer):
    """Abstract class for scaling gradient-boosting decision tree (GBDT) frameworks.

    Inherited by XGBoostTrainer and LightGBMTrainer.

    Args:
        datasets: Datasets to use for training and validation. Must include a
            "train" key denoting the training dataset.
            All non-training datasets will be used as separate
            validation sets, each reporting a separate metric.
        label_column: Name of the label column. A column with this name
            must be present in the training dataset.
        params: Framework specific training parameters.
        dmatrix_params: Dict of ``dataset name:dict of kwargs`` passed to respective
            :class:`xgboost_ray.RayDMatrix` initializations.
        num_boost_round: Target number of boosting iterations (trees in the model).
        scaling_config: Configuration for how to scale data parallel training.
        run_config: Configuration for the execution of the training run.
        resume_from_checkpoint: A checkpoint to resume training from.
        metadata: Dict that should be made available in `checkpoint.get_metadata()`
            for checkpoints saved from this Trainer. Must be JSON-serializable.
        **train_kwargs: Additional kwargs passed to framework ``train()`` function.
    """

    _scaling_config_allowed_keys = BaseTrainer._scaling_config_allowed_keys + [
        "num_workers",
        "resources_per_worker",
        "use_gpu",
        "placement_strategy",
    ]
    _handles_checkpoint_freq = True
    _handles_checkpoint_at_end = True

    _dmatrix_cls: type
    _ray_params_cls: type
    _tune_callback_checkpoint_cls: type
    _default_ray_params: Dict[str, Any] = {
        "checkpoint_frequency": 1,
        "checkpoint_at_end": True,
    }
    _init_model_arg_name: str
    _num_iterations_argument: str = "num_boost_round"
    _default_num_iterations: int = _DEFAULT_NUM_ITERATIONS

    def __init__(
        self,
        *,
        datasets: Dict[str, GenDataset],
        label_column: str,
        params: Dict[str, Any],
        dmatrix_params: Optional[Dict[str, Dict[str, Any]]] = None,
        num_boost_round: int = _DEFAULT_NUM_ITERATIONS,
        scaling_config: Optional[ScalingConfig] = None,
        run_config: Optional[RunConfig] = None,
        preprocessor: Optional["Preprocessor"] = None,  # Deprecated
        resume_from_checkpoint: Optional[Checkpoint] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **train_kwargs,
    ):
        self.label_column = label_column
        self.params = params
        self.num_boost_round = num_boost_round
        self.train_kwargs = train_kwargs
        self.dmatrix_params = dmatrix_params or {}

        super().__init__(
            scaling_config=scaling_config,
            run_config=run_config,
            datasets=datasets,
            preprocessor=preprocessor,
            resume_from_checkpoint=resume_from_checkpoint,
            metadata=metadata,
        )

        # Datasets should always use distributed loading.
        for dataset_name in self.datasets.keys():
            dataset_params = self.dmatrix_params.get(dataset_name, {})
            dataset_params["distributed"] = True
            self.dmatrix_params[dataset_name] = dataset_params

    def _validate_attributes(self):
        super()._validate_attributes()
        self._validate_config_and_datasets()

    def _validate_config_and_datasets(self) -> None:
        if TRAIN_DATASET_KEY not in self.datasets:
            raise KeyError(
                f"'{TRAIN_DATASET_KEY}' key must be present in `datasets`. "
                f"Got {list(self.datasets.keys())}"
            )
        if self.dmatrix_params:
            for key in self.dmatrix_params:
                if key not in self.datasets:
                    raise KeyError(
                        f"`dmatrix_params` dict contains key '{key}' "
                        f"which is not present in `datasets`."
                    )

    @classmethod
    def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig:
        # TODO: `trainer_resources` should be configurable. Currently it is silently
        # ignored. We catch the error here rather than in
        # `_scaling_config_allowed_keys` because the default of `None` is updated
        # to `{}` by XGBoost-Ray.
        if scaling_config.trainer_resources not in [None, {}]:
            raise ValueError(
                f"The `trainer_resources` attribute for {cls.__name__} "
                f"is currently ignored and defaults to `{{}}`. Remove the "
                f"`trainer_resources` key from your `ScalingConfig` to resolve."
            )
        return super(GBDTTrainer, cls)._validate_scaling_config(
            scaling_config=scaling_config
        )

    def _get_dmatrices(
        self, dmatrix_params: Dict[str, Any]
    ) -> Dict[str, "xgboost_ray.RayDMatrix"]:
        return {
            k: self._dmatrix_cls(
                v, label=self.label_column, **dmatrix_params.get(k, {})
            )
            for k, v in self.datasets.items()
        }

    def _load_checkpoint(
        self,
        checkpoint: Checkpoint,
    ) -> Any:
        raise NotImplementedError

    def _train(self, **kwargs):
        raise NotImplementedError

    def _save_model(self, model: Any, path: str):
        raise NotImplementedError

    def _model_iteration(self, model: Any) -> int:
        raise NotImplementedError

    @property
    def _ray_params(self) -> "xgboost_ray.RayParams":
        scaling_config_dataclass = self._validate_scaling_config(self.scaling_config)
        return _convert_scaling_config_to_ray_params(
            scaling_config_dataclass, self._ray_params_cls, self._default_ray_params
        )

    def _repartition_datasets_to_match_num_actors(self):
        # XGBoost/LightGBM-Ray requires each dataset to have at least as many
        # blocks as there are workers.
        # This is only applicable for xgboost-ray<0.1.16. The version check
        # is done in subclasses to ensure that xgboost-ray doesn't need to be
        # imported here.
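        # For example (hypothetical numbers): with num_actors=4, a dataset that
        # has only 2 blocks is repartitioned to 4 blocks so every worker gets data.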
        for dataset_key, dataset in self.datasets.items():
            if dataset.num_blocks() < self._ray_params.num_actors:
                if dataset.size_bytes() > _WARN_REPARTITION_THRESHOLD:
                    warnings.warn(
                        f"Dataset '{dataset_key}' has {dataset.num_blocks()} blocks, "
                        f"which is less than the `num_workers` "
                        f"{self._ray_params.num_actors}. "
                        f"This dataset will be automatically repartitioned to "
                        f"{self._ray_params.num_actors} blocks. You can disable "
                        "this warning by repartitioning the dataset "
                        "to have at least as many blocks as workers via "
                        "`dataset.repartition(num_workers)`."
                    )
                self.datasets[dataset_key] = dataset.repartition(
                    self._ray_params.num_actors
                )

    def _checkpoint_at_end(self, model, evals_result: dict) -> None:
        # We need to call session.report to save checkpoints, so we report
        # the last received metrics (possibly again).
        result_dict = flatten_dict(evals_result, delimiter="-")
        for k in list(result_dict):
            result_dict[k] = result_dict[k][-1]

        if getattr(self._tune_callback_checkpoint_cls, "_report_callbacks_cls", None):
            # Deprecate: Remove in Ray 2.8
            with tune.checkpoint_dir(step=self._model_iteration(model)) as cp_dir:
                self._save_model(model, path=os.path.join(cp_dir, MODEL_KEY))
            tune.report(**result_dict)
        else:
            with tempfile.TemporaryDirectory() as checkpoint_dir:
                self._save_model(model, path=checkpoint_dir)
                checkpoint = Checkpoint.from_directory(checkpoint_dir)
                train.report(result_dict, checkpoint=checkpoint)

    def training_loop(self) -> None:
        config = self.train_kwargs.copy()
        config[self._num_iterations_argument] = self.num_boost_round

        dmatrices = self._get_dmatrices(
            dmatrix_params=self.dmatrix_params,
        )
        train_dmatrix = dmatrices[TRAIN_DATASET_KEY]
        evals_result = {}

        init_model = None
        if self.starting_checkpoint:
            init_model = self._load_checkpoint(self.starting_checkpoint)

        config.setdefault("verbose_eval", False)
        config.setdefault("callbacks", [])

        if not any(
            isinstance(cb, self._tune_callback_checkpoint_cls)
            for cb in config["callbacks"]
        ):
            # Only add our own callback if it hasn't been added before
            checkpoint_frequency = (
                self.run_config.checkpoint_config.checkpoint_frequency
            )
            callback = self._tune_callback_checkpoint_cls(
                filename=MODEL_KEY, frequency=checkpoint_frequency
            )
            config["callbacks"] += [callback]

        config[self._init_model_arg_name] = init_model
        if init_model:
            # If restoring, make sure that we only create num_boost_round trees,
            # and not init_model_trees + num_boost_round trees.
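            # For example (hypothetical numbers): resuming from a checkpoint
            # saved at iteration 4 with num_boost_round=10 trains 6 more
            # iterations rather than another full 10.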
            last_iteration = self._model_iteration(init_model)
            num_iterations = config.get(
                self._num_iterations_argument, self._default_num_iterations
            )
            new_iterations = num_iterations - last_iteration
            config[self._num_iterations_argument] = new_iterations
            logger.warning(
                f"Model loaded from checkpoint will train for "
                f"additional {new_iterations} iterations (trees) in order "
                "to achieve the target number of iterations "
                f"({self._num_iterations_argument}={num_iterations})."
            )

        model = self._train(
            params=self.params,
            dtrain=train_dmatrix,
            evals_result=evals_result,
            evals=[(dmatrix, k) for k, dmatrix in dmatrices.items()],
            ray_params=self._ray_params,
            **config,
        )

        checkpoint_at_end = self.run_config.checkpoint_config.checkpoint_at_end
        if checkpoint_at_end is None:
            checkpoint_at_end = True

        if checkpoint_at_end:
            self._checkpoint_at_end(model, evals_result)

    def _generate_trainable_cls(self) -> Type["Trainable"]:
        trainable_cls = super()._generate_trainable_cls()
        trainer_cls = self.__class__
        scaling_config = self.scaling_config
        ray_params_cls = self._ray_params_cls
        default_ray_params = self._default_ray_params

        class GBDTTrainable(trainable_cls):
            @classmethod
            def default_resource_request(cls, config):
                # `config["scaling_config"]` is a dataclass when passed via the
                # `scaling_config` argument in `Trainer` and is a dict when passed
                # via the `scaling_config` key of `param_spec`.
                updated_scaling_config = config.get("scaling_config", scaling_config)
                if isinstance(updated_scaling_config, dict):
                    updated_scaling_config = ScalingConfig(**updated_scaling_config)
                validated_scaling_config = trainer_cls._validate_scaling_config(
                    updated_scaling_config
                )
                return _convert_scaling_config_to_ray_params(
                    validated_scaling_config, ray_params_cls, default_ray_params
                ).get_tune_resources()

        return GBDTTrainable