There are four main concepts in the Ray Train library:

- Trainers
  execute distributed training.
- Configuration
  objects are used to configure training.
- Checkpoints
  are returned as the result of training.
- Predictors
  can be used for inference and batch prediction.
Trainers are responsible for executing (distributed) training runs. The output of a Trainer run is a Result <train-key-concepts-results>
that contains metrics from the training run and the latest saved Checkpoint <air-checkpoint-ref>
. Trainers can also be configured with Datasets <air-ingest>
and Preprocessors <air-preprocessors>
for scalable data ingest and preprocessing.
There are three categories of built-in Trainers:
Deep Learning Trainers
Ray Train supports the following deep learning trainers:
TorchTrainer <ray.train.torch.TorchTrainer>
TensorflowTrainer <ray.train.tensorflow.TensorflowTrainer>
HorovodTrainer <ray.train.horovod.HorovodTrainer>
For these trainers, you usually define your own training function that loads the model and executes single-worker training steps. Refer to the following guides for more details:
Deep learning user guide <train-dl-guide>
Quick overview of deep-learning trainers in the Ray AIR documentation <air-trainers-dl>
Tree-Based Trainers
Tree-based trainers utilize gradient-boosted decision trees for training. The most popular libraries for this are XGBoost and LightGBM.
XGBoostTrainer <ray.train.xgboost.XGBoostTrainer>
LightGBMTrainer <ray.train.lightgbm.LightGBMTrainer>
For these trainers, you just pass a dataset and parameters. The training loop is configured automatically.
XGBoost/LightGBM user guide <train-gbdt-guide>
Quick overview of tree-based trainers in the Ray AIR documentation <air-trainers-tree>
Other Trainers
Some trainers don't fit into the other two categories:
HuggingFaceTrainer <ray.train.huggingface.HuggingFaceTrainer>
for NLP
RLTrainer <ray.train.rl.RLTrainer>
for reinforcement learning
SklearnTrainer <ray.train.sklearn.sklearn_trainer.SklearnTrainer>
for (non-distributed) training of sklearn models
Other trainers in the Ray AIR documentation <air-trainers-other>
Trainers are configured with configuration objects. There are two main configuration classes, the ScalingConfig <ray.air.config.ScalingConfig>
and the RunConfig <ray.air.config.RunConfig>
. The latter contains subconfigurations, such as the FailureConfig <ray.air.config.FailureConfig>
, SyncConfig <ray.tune.syncer.SyncConfig>
and CheckpointConfig <ray.air.config.CheckpointConfig>
.
Check out the Configurations User Guide <train-config>
for an in-depth guide on using these configurations.
Calling Trainer.fit()
returns a Result <ray.air.result.Result>
object, which includes information about the run such as the reported metrics and the saved checkpoints.
Checkpoints have the following purposes:
- They can be passed to a Trainer to resume training from the given model state.
- They can be used to create a Predictor / BatchPredictor for scalable batch prediction.
- They can be deployed with Ray Serve.
Predictors are the counterpart to Trainers. A Trainer trains a model on a dataset, and a Predictor uses the resulting model to perform inference on new data.
Each Trainer has a respective Predictor implementation that is compatible with its generated checkpoints.
Example: XGBoostPredictor <ray.train.xgboost.XGBoostPredictor>
/train/doc_code/xgboost_train_predict.py
A Predictor can be passed into a BatchPredictor <ray.train.batch_predictor.BatchPredictor>
, which is used to scale up prediction over a Ray cluster. It takes a Ray Dataset as input.
Example: Batch prediction with XGBoostPredictor <ray.train.xgboost.XGBoostPredictor>
/train/doc_code/xgboost_train_predict.py
See the Predictors user guide <air-predictors>
for more information and examples.