The ~ray.train.torch.TorchTrainer
can help you easily launch your DeepSpeed training across a distributed Ray cluster.
You only need to run your existing training code with a TorchTrainer. You can expect the final code to look like this:
```python
import deepspeed
from deepspeed.accelerator import get_accelerator


def train_func(config):
    # Instantiate your model, datasets, and collate function
    model = ...
    train_dataset = ...
    eval_dataset = ...
    collate_fn = ...
    deepspeed_config = {...}  # Your DeepSpeed config

    # Prepare everything for distributed training
    model, optimizer, train_dataloader, lr_scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        training_data=train_dataset,
        collate_fn=collate_fn,
        config=deepspeed_config,
    )

    # Define the GPU device for the current worker
    device = get_accelerator().device_name(model.local_rank)

    # Start training
    ...


from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(...),
    # ...
)
trainer.fit()
```
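For reference, a ZeRO-3 `deepspeed_config` might look like the following sketch. All field values here are illustrative assumptions; tune the batch size, optimizer settings, precision flags, and offload devices for your model and hardware:

```python
# A minimal sketch of a ZeRO-3 DeepSpeed config. The values below are
# illustrative assumptions, not recommendations.
deepspeed_config = {
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 2e-5},
    },
    "fp16": {"enabled": True},
    "zero_optimization": {
        # Stage 3: partition optimizer states, gradients, and parameters
        "stage": 3,
        "offload_optimizer": {"device": "none"},
        "offload_param": {"device": "none"},
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 16,
}
```

Similarly, a typical scaling configuration for four GPU workers would be `ScalingConfig(num_workers=4, use_gpu=True)`.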
Below is a simple example of ZeRO-3 training with DeepSpeed only.
- Example with Ray Data: `/../../python/ray/train/examples/deepspeed/deepspeed_torch_trainer.py`
- Example with PyTorch DataLoader: `/../../python/ray/train/examples/deepspeed/deepspeed_torch_trainer_no_raydata.py`
Tip

To run DeepSpeed with pure PyTorch, you don't need to provide any additional Ray Train utilities like ~ray.train.torch.prepare_model or ~ray.train.torch.prepare_data_loader in your training function. Instead, keep using deepspeed.initialize() as usual to prepare everything for distributed training.
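To illustrate that tip, the training loop can drive the DeepSpeed engine directly. This is a minimal sketch, assuming `model` is the engine returned by `deepspeed.initialize()`, the model's outputs expose a `.loss` attribute (as Hugging Face Transformers models do), and each batch is a dict of tensors:

```python
# Minimal training-loop sketch using the DeepSpeed engine directly.
# Assumes `model`, `train_dataloader`, and `device` come from the
# train_func above; `num_epochs` is an illustrative placeholder.
for epoch in range(num_epochs):
    model.train()
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        model.backward(loss)  # DeepSpeed scales the loss and handles accumulation
        model.step()          # optimizer step + LR scheduler step + gradient zeroing
```

You can still report metrics to Ray Train from inside this loop with ray.train.report(), as the linked examples above do.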
Many deep learning frameworks have integrated with DeepSpeed, including Lightning, Transformers, Accelerate, and more. You can run all these combinations in Ray Train.
Check the examples below for more details:
Framework | Example
---|---
Accelerate (User Guide <train-hf-accelerate>) | Fine-tune Llama-2 series models with DeepSpeed, Accelerate, and Ray Train
Transformers (User Guide <train-pytorch-transformers>) | Fine-tune GPT-J-6b with DeepSpeed and Hugging Face Transformers <gptj_deepspeed_finetune>
Lightning (User Guide <train-pytorch-lightning>) | Fine-tune vicuna-13b with DeepSpeed and PyTorch Lightning <vicuna_lightning_deepspeed_finetuning>