Using Trainers

Ray AIR Trainers provide a way to scale out training with popular machine learning frameworks. As part of Ray Train, Trainers enable users to run distributed multi-node training with fault tolerance. Fully integrated with the Ray ecosystem, Trainers leverage Ray Data <air-ingest> to enable scalable preprocessing and performant distributed data ingestion. Also, Trainers can be composed with Tuners <ray.tune.Tuner> for distributed hyperparameter tuning.

After training completes, Trainers output the trained model in the form of a Checkpoint <ray.air.checkpoint.Checkpoint>, which can be used for batch or online inference.

There are three broad categories of Trainers that AIR offers:

  • Deep Learning Trainers <air-trainers-dl> (PyTorch, TensorFlow, Horovod)
  • Tree-based Trainers <air-trainers-tree> (XGBoost, LightGBM)
  • Other ML frameworks <air-trainers-other> (Hugging Face, Scikit-Learn, RLlib)

Trainer Basics

All trainers inherit from the BaseTrainer <ray.train.base_trainer.BaseTrainer> interface. To construct a Trainer, you can provide:

  • A scaling_config <ray.air.config.ScalingConfig>, which specifies how many parallel training workers and what type of resources (CPUs/GPUs) to use per worker during training.
  • A run_config <ray.air.config.RunConfig>, which configures a variety of runtime parameters such as fault tolerance, logging, and callbacks.
  • A collection of datasets <air-ingest> and a preprocessor <air-preprocessors> for the provided datasets, which specify the data to ingest and how to preprocess it.
  • resume_from_checkpoint, which is a checkpoint path to resume from, should your training run be interrupted.

After instantiating a Trainer, you can invoke it by calling Trainer.fit() <ray.air.trainer.BaseTrainer.fit>.

doc_code/xgboost_trainer.py

Deep Learning Trainers

Ray Train offers three main deep learning trainers: TorchTrainer <ray.train.torch.TorchTrainer>, TensorflowTrainer <ray.train.tensorflow.TensorflowTrainer>, and HorovodTrainer <ray.train.horovod.HorovodTrainer>.

These three trainers all take a train_loop_per_worker parameter, which is a function that defines the main training logic that runs on each training worker.

Under the hood, Ray AIR will use the provided scaling_config to instantiate the correct number of workers.

Upon instantiation, each worker will be able to reference a global Session <air-session-ref> object, which provides functionality for reporting metrics, saving checkpoints, and more.

You can provide multiple datasets to a trainer via the datasets parameter. If datasets includes a training dataset (denoted by the "train" key), then it will be split into multiple dataset shards, with each worker training on a single shard. All other datasets will not be split. You can access the data shard within a worker via ~ray.air.session.get_dataset_shard(), and use ~ray.data.Dataset.to_tf or iter_torch_batches to generate batches of TensorFlow or PyTorch tensors. You can read more about data ingest <air-ingest> here.
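The splitting rule described above can be illustrated with a small, framework-free sketch. It does not use Ray and is not how Ray Data implements sharding: the helper name shard_datasets and the round-robin split are assumptions made purely for illustration. It only shows the behavior stated in the text, namely that the "train" entry is divided among workers while every other dataset is passed to each worker whole.

```python
from typing import Dict, List, Sequence


def shard_datasets(
    datasets: Dict[str, Sequence], num_workers: int
) -> List[Dict[str, list]]:
    """Return one dataset dict per worker (hypothetical helper, not a Ray API).

    Only the "train" entry is split; every other entry is passed whole.
    """
    per_worker = []
    for rank in range(num_workers):
        worker_view = {}
        for name, data in datasets.items():
            if name == "train":
                # Round-robin split, chosen only for illustration.
                worker_view[name] = list(data[rank::num_workers])
            else:
                # Non-training datasets are not split.
                worker_view[name] = list(data)
        per_worker.append(worker_view)
    return per_worker


datasets = {"train": list(range(10)), "validation": [100, 101, 102]}
views = shard_datasets(datasets, num_workers=2)
for rank, view in enumerate(views):
    print(rank, view["train"], view["validation"])
```

In an actual Trainer, each worker would obtain its shard through the session API mentioned above rather than through a helper like this one.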

Read more about Ray Train's Deep Learning Trainers <train-dl-guide>.

Code examples

Torch

doc_code/torch_trainer.py

Tensorflow

doc_code/tf_starter.py

Horovod

doc_code/hvd_trainer.py

How to report metrics and checkpoints?

During model training, you may want to save training metrics and checkpoints for downstream processing (e.g., serving the model).

Use the Session <air-session-ref> API to gather metrics and save checkpoints. Checkpoints are synced to the driver or to cloud storage, depending on the configuration specified in Trainer(run_config=...).

Code example

doc_code/report_metrics_and_save_checkpoints.py

Tree-based Trainers

Ray Train offers two main tree-based trainers: XGBoostTrainer <ray.train.xgboost.XGBoostTrainer> and LightGBMTrainer <ray.train.lightgbm.LightGBMTrainer>.

See here for a more detailed user guide <train-gbdt-guide>.

XGBoost Trainer

Ray AIR also provides an easy-to-use XGBoostTrainer <ray.train.xgboost.XGBoostTrainer> for training XGBoost models at scale.

To use this trainer, you will need to first run pip install -U xgboost-ray.

doc_code/xgboost_trainer.py

LightGBMTrainer

Similarly, Ray AIR comes with a LightGBMTrainer <ray.train.lightgbm.LightGBMTrainer> for training LightGBM models at scale.

To use this trainer, you will need to first run pip install -U lightgbm-ray.

doc_code/lightgbm_trainer.py

Other Trainers

Hugging Face

TransformersTrainer

TransformersTrainer <ray.train.huggingface.TransformersTrainer> extends TorchTrainer <ray.train.torch.TorchTrainer> and is built for interoperability with the Hugging Face Transformers library.

Users are required to provide a trainer_init_per_worker function which returns a transformers.Trainer object. The trainer_init_per_worker function will have access to preprocessed train and evaluation datasets.

Upon calling TransformersTrainer.fit(), multiple workers (Ray actors) will be spawned, and each worker will create its own copy of a transformers.Trainer.

Each worker will then invoke transformers.Trainer.train(), which will perform distributed training via PyTorch DDP.

Code example

doc_code/hf_trainer.py

AccelerateTrainer

If you prefer a more fine-grained Hugging Face API than what Transformers provides, you can use AccelerateTrainer <ray.train.huggingface.AccelerateTrainer> to run training functions making use of Hugging Face Accelerate. Similarly to TransformersTrainer <ray.train.huggingface.TransformersTrainer>, AccelerateTrainer <ray.train.huggingface.AccelerateTrainer> is also an extension of TorchTrainer <ray.train.torch.TorchTrainer>.

AccelerateTrainer <ray.train.huggingface.AccelerateTrainer> allows you to pass an Accelerate configuration file generated with accelerate config to be applied on all training workers. This ensures that the worker environments are set up correctly for Accelerate, allowing you to take advantage of Accelerate APIs and integrations such as DeepSpeed and FSDP just as you would if you were running Accelerate without Ray.

Note

AccelerateTrainer will override some settings set with accelerate config, mainly related to the topology and networking. See the AccelerateTrainer <ray.train.huggingface.AccelerateTrainer> API reference for more details.

Aside from Accelerate support, the usage is identical to TorchTrainer <ray.train.torch.TorchTrainer>: you define your own training function and use the Session <air-session-ref> API to report metrics, save checkpoints, and so on.

Code example

doc_code/accelerate_trainer.py

Scikit-Learn Trainer

Note

This trainer is not distributed.

The Scikit-Learn Trainer is a thin wrapper to launch scikit-learn training within Ray AIR. Even though this trainer is not distributed, you can still benefit from its integration with Ray Tune for distributed hyperparameter tuning and scalable batch/online prediction.

doc_code/sklearn_trainer.py

RLlib Trainer

RLTrainer provides an interface to RL Trainables. This enables you to use the same abstractions as in the other trainers to define the scaling behavior, and to use Ray Data for offline training.

Please note that some scaling behavior still has to be defined separately. The scaling_config <ray.air.config.ScalingConfig> will set the number of training workers ("Rollout workers"). To set the number of other workers, such as evaluation workers, you will have to specify this in the config parameter of the RLTrainer:

doc_code/rl_trainer.py

How to interpret training results?

Calling Trainer.fit() returns a Result <ray.air.Result>, providing you access to metrics, checkpoints, and errors. You can interact with a Result object as follows:

result = trainer.fit()

# returns the last saved checkpoint
result.checkpoint

# returns the N best saved checkpoints, as configured in ``RunConfig.CheckpointConfig``
result.best_checkpoints

# returns the final metrics as reported
result.metrics

# returns the Exception if training failed.
result.error

# returns a pandas DataFrame of all reported results
result.metrics_dataframe

See the Result docstring <ray.air.result.Result> for more details.
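
To make the idea behind result.best_checkpoints more concrete, the following framework-free sketch shows what "keeping the N best checkpoints" amounts to. It does not use Ray: the function keep_best, its parameters (num_to_keep, score_attribute, order), and the string checkpoint labels are all hypothetical names chosen for illustration, and the actual retention logic in Ray may differ.

```python
import heapq
from typing import Any, Dict, List, Tuple

# A (checkpoint label, reported metrics) pair; labels stand in for real checkpoints.
Reported = Tuple[str, Dict[str, Any]]


def keep_best(
    reported: List[Reported],
    num_to_keep: int,
    score_attribute: str,
    order: str = "max",
) -> List[Reported]:
    """Return the `num_to_keep` best pairs, best first (hypothetical helper)."""
    # "max" keeps the highest scores, anything else keeps the lowest.
    pick = heapq.nlargest if order == "max" else heapq.nsmallest
    return pick(num_to_keep, reported, key=lambda pair: pair[1][score_attribute])


history = [
    ("ckpt_0", {"accuracy": 0.71}),
    ("ckpt_1", {"accuracy": 0.84}),
    ("ckpt_2", {"accuracy": 0.79}),
    ("ckpt_3", {"accuracy": 0.88}),
]
best = keep_best(history, num_to_keep=2, score_attribute="accuracy")
print([name for name, _ in best])
```

With a loss-like metric, passing order="min" in this sketch would keep the checkpoints with the lowest values instead.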