[Unified TorchTrainer] Add PyTorch Lightning Trainer Utilities (#37989)
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com>
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
woshiyyya and matthewdeng committed Aug 10, 2023
1 parent 194ebf7 commit 0dd32ed
Showing 17 changed files with 1,340 additions and 796 deletions.
2 changes: 2 additions & 0 deletions doc/source/_toc.yml
@@ -72,6 +72,8 @@ parts:
- file: train/distributed-pytorch/checkpoints
- file: train/distributed-pytorch/experiment-tracking
- file: train/distributed-pytorch/fault-tolerance
- file: train/distributed-pytorch/migration-guides
title: Migration Guides
- file: train/distributed-pytorch/advanced
sections:
- file: train/distributed-pytorch/reproducibility
12 changes: 12 additions & 0 deletions doc/source/train/api/api.rst
@@ -137,6 +137,18 @@ PyTorch Lightning
~train.lightning.LightningConfigBuilder
~train.lightning.LightningCheckpoint
~train.lightning.LightningPredictor
~train.lightning.prepare_trainer
~train.lightning.RayLightningEnvironment
~train.lightning.RayDDPStrategy
~train.lightning.RayFSDPStrategy
~train.lightning.RayDeepSpeedStrategy
~train.lightning.RayTrainReportCallback

.. note::

We will deprecate `LightningTrainer`, `LightningConfigBuilder`,
`LightningCheckpoint`, and `LightningPredictor` in Ray 2.8. Please
refer to the :ref:`migration guides <migration-guide>` for more info.

Tensorflow/Keras
~~~~~~~~~~~~~~~~
144 changes: 139 additions & 5 deletions doc/source/train/distributed-pytorch/checkpoints.rst
@@ -86,8 +86,96 @@ appropriately in distributed training.
print(result.checkpoint.to_dict())
# {'epoch': 4, 'model_weights': OrderedDict([('bias', tensor([-0.1215])), ('weight', tensor([[0.3253, 0.1979, 0.4525, 0.2850]]))]), '_timestamp': 1656107095, '_preprocessor': None, '_current_checkpoint_id': 4}
By default, checkpoints will be persisted to local disk in the :ref:`log
directory <train-log-dir>` of each run.
.. tab-item:: PyTorch Lightning

Ray Train leverages PyTorch Lightning's ``Callback`` interface to report metrics
and checkpoints. We provide :class:`~ray.train.lightning.RayTrainReportCallback`,
a simple callback implementation that reports ``on_train_epoch_end``.

Specifically, on each train epoch end, it

- collects all the logged metrics from ``trainer.callback_metrics``
- saves a checkpoint via ``trainer.save_checkpoint``
- reports to Ray Train via ``ray.train.report(metrics, checkpoint)``

.. code-block:: python
    :emphasize-lines: 2,11,20,28,29,30,31,32

    import pytorch_lightning as pl
    from ray.train.lightning import RayTrainReportCallback
    from ray.train.torch import TorchTrainer
    from ray.train import CheckpointConfig, RunConfig, ScalingConfig

    class MyLightningModule(pl.LightningModule):
        ...
        def on_validation_epoch_end(self):
            ...
            mean_acc = calculate_accuracy()
            self.log("mean_accuracy", mean_acc, sync_dist=True)

    def train_func_per_worker():
        ...
        model = MyLightningModule(...)
        datamodule = MyLightningDataModule(...)

        trainer = pl.Trainer(
            # ...
            callbacks=[RayTrainReportCallback()]
        )
        trainer.fit(model, datamodule=datamodule)

    ray_trainer = TorchTrainer(
        train_func_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
        run_config=RunConfig(
            checkpoint_config=CheckpointConfig(
                num_to_keep=2,
                checkpoint_score_attribute="mean_accuracy",
                checkpoint_score_order="max",
            ),
        )
    )
    result = ray_trainer.fit()

You can always retrieve the saved checkpoints from ``result.checkpoint`` and
``result.best_checkpoints``.
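
For example, a minimal sketch of inspecting the results (the ``mean_accuracy``
metric is the one reported above):

.. code-block:: python

    result = ray_trainer.fit()

    # The latest reported checkpoint.
    latest_checkpoint = result.checkpoint

    # A list of (Checkpoint, metrics) tuples for the retained checkpoints,
    # ranked according to CheckpointConfig.
    for checkpoint, metrics in result.best_checkpoints:
        print(checkpoint, metrics["mean_accuracy"])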

For more advanced use cases (e.g., reporting at a different frequency or reporting
customized checkpoint files), you can implement your own custom callback.
Here is a simple example that reports a checkpoint every 3 epochs:

.. code-block:: python

    import os
    import ray
    from ray.train import Checkpoint
    from tempfile import TemporaryDirectory
    from pytorch_lightning.callbacks import Callback

    class CustomRayTrainReportCallback(Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            if trainer.current_epoch % 3 != 0:
                return

            with TemporaryDirectory() as tmpdir:
                # Fetch metrics
                metrics = trainer.callback_metrics
                metrics = {k: v.item() for k, v in metrics.items()}

                # Add customized metrics
                metrics["epoch"] = trainer.current_epoch
                metrics["custom_metric"] = 123

                # Save model checkpoint file to tmpdir
                ckpt_path = os.path.join(tmpdir, "ckpt.pt")
                trainer.save_checkpoint(ckpt_path, weights_only=False)

                # Report to train session
                checkpoint = Checkpoint.from_directory(tmpdir)
                ray.train.report(metrics=metrics, checkpoint=checkpoint)

By default, checkpoints will be persisted to the :ref:`log directory <train-log-dir>` of each run.


.. _train-dl-configure-checkpoints:
@@ -96,7 +184,7 @@ Configuring checkpoints
-----------------------

For more configurability of checkpointing behavior (specifically saving
checkpoints to disk), a :py:class:`~ray.air.config.CheckpointConfig` can be passed into
checkpoints to disk), a :py:class:`~ray.train.CheckpointConfig` can be passed into
``Trainer``.

.. literalinclude:: ../doc_code/key_concepts.py
@@ -107,12 +195,18 @@ checkpoints to disk), a :py:class:`~ray.air.config.CheckpointConfig` can be pass

.. seealso::

See the :class:`~ray.air.CheckpointConfig` API reference.
See the :class:`~ray.train.CheckpointConfig` API reference.

.. note::

If you want to save the top-k checkpoints with respect to a metric via
:py:class:`~ray.train.CheckpointConfig`,
please ensure that the metric is always reported together with the checkpoints.
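
For example (a minimal sketch, assuming ``mean_acc`` and ``checkpoint`` are computed
as in the callback examples above), report the ranking metric in the same call that
reports the checkpoint:

.. code-block:: python

    import ray.train

    # Report the ranking metric together with the checkpoint so that
    # CheckpointConfig(checkpoint_score_attribute="mean_accuracy", ...)
    # can rank the saved checkpoints.
    ray.train.report(
        metrics={"mean_accuracy": mean_acc},
        checkpoint=checkpoint,
    )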

**[Experimental] Distributed Checkpoints**: For model-parallel workloads where the model does not fit on a single GPU worker,
it is important to save and upload the model checkpoint that is partitioned across workers. You
can enable this in :class:`~ray.train.CheckpointConfig` by setting `_checkpoint_keep_all_ranks=True` to retain the model checkpoints from all workers,
and `_checkpoint_upload_from_workers=True` to upload the workers' checkpoints directly to cloud storage. This functionality works for any trainer that inherits from :class:`~ray.train.data_parallel_trainer.DataParallelTrainer`.
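
A minimal sketch of enabling these experimental options (option names as listed above;
behavior may change in future Ray releases):

.. code-block:: python

    from ray.train import CheckpointConfig, RunConfig

    run_config = RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=2,
            # Experimental: keep checkpoint shards from all worker ranks.
            _checkpoint_keep_all_ranks=True,
            # Experimental: upload each worker's checkpoint directly to cloud storage.
            _checkpoint_upload_from_workers=True,
        ),
    )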



@@ -204,3 +298,43 @@ Checkpoints can be loaded into the training function in 2 steps:
print(result.checkpoint.to_dict())
# {'epoch': 3, 'model_weights': OrderedDict([('bias', tensor([0.0902])), ('weight', tensor([[-0.1549, -0.0861, 0.4353, -0.4116]]))]), '_timestamp': 1656108265, '_preprocessor': None, '_current_checkpoint_id': 2}
.. tab-item:: PyTorch Lightning

.. code-block:: python
    :emphasize-lines: 11-17

    from ray import train
    from ray.train import Checkpoint, ScalingConfig
    from ray.train.torch import TorchTrainer
    from ray.train.lightning import RayTrainReportCallback
    from os.path import join
    def train_func_per_worker():
        model = MyLightningModule(...)
        datamodule = MyLightningDataModule(...)
        trainer = pl.Trainer(
            ...
            callbacks=[RayTrainReportCallback()]
        )
        checkpoint = train.get_checkpoint()
        if checkpoint:
            with checkpoint.as_directory() as ckpt_dir:
                ckpt_path = join(ckpt_dir, "checkpoint.ckpt")
                trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
        else:
            trainer.fit(model, datamodule=datamodule)

    # Build a Ray Train Checkpoint
    # Suppose we have a Lightning checkpoint at ./ckpt_dir/checkpoint.ckpt
    checkpoint = Checkpoint.from_directory("./ckpt_dir")

    # Resume training from the checkpoint
    ray_trainer = TorchTrainer(
        train_func_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
        resume_from_checkpoint=checkpoint,
    )
    result = ray_trainer.fit()

@@ -64,8 +64,8 @@ training.
import torch
from torch.utils.data import DataLoader, DistributedSampler
+from ray import train
+import ray.train.torch
+ from ray import train
+ import ray.train.torch
def train_func():
@@ -91,6 +91,71 @@ training.
.. code-block:: python

    global_batch_size = worker_batch_size * train.get_context().get_world_size()

.. tab-item:: PyTorch Lightning

Ray Train will set up your distributed process group on each worker. You only need to
make a few changes to your Lightning Trainer definition.

.. code-block:: diff

      import pytorch_lightning as pl
    + from ray.train.lightning import (
    +     prepare_trainer,
    +     RayDDPStrategy,
    +     RayLightningEnvironment,
    + )

      def train_func(config):
          ...
          model = MyLightningModule(...)
          datamodule = MyLightningDataModule(...)

          trainer = pl.Trainer(
    -         devices=[0,1,2,3],
    -         strategy=DDPStrategy(),
    -         plugins=[LightningEnvironment()],
    +         devices="auto",
    +         strategy=RayDDPStrategy(),
    +         plugins=[RayLightningEnvironment()]
          )
    +     trainer = prepare_trainer(trainer)

          trainer.fit(model, datamodule=datamodule)

**Step 1: Configure Distributed Strategy**

Ray Train offers several subclassed distributed strategies for Lightning.
These strategies retain the same argument list as their base strategy classes.
Internally, they configure the root device and the distributed sampler arguments
(see the sketch after this list).

- :class:`~ray.train.lightning.RayDDPStrategy`
- :class:`~ray.train.lightning.RayFSDPStrategy`
- :class:`~ray.train.lightning.RayDeepSpeedStrategy`
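
For instance, a minimal sketch (the specific arguments are illustrative and are
simply forwarded to the underlying Lightning strategy classes):

.. code-block:: python

    from ray.train.lightning import RayDDPStrategy, RayDeepSpeedStrategy

    # Arguments are passed through to the base Lightning strategies.
    ddp_strategy = RayDDPStrategy(find_unused_parameters=True)
    deepspeed_strategy = RayDeepSpeedStrategy(stage=3)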

**Step 2: Configure Ray Cluster Environment Plugin**

Ray Train also provides :class:`~ray.train.lightning.RayLightningEnvironment`,
a cluster environment specification for Ray clusters. This utility class configures
each worker's local rank, global rank, node rank, and world size.

**Step 3: Configure Parallel Devices**

In addition, Ray's ``TorchTrainer`` already configures ``CUDA_VISIBLE_DEVICES``
correctly for each worker. Always use all visible GPUs by setting
``devices="auto"``.

**Step 4: Prepare your Lightning Trainer**

Finally, pass your Lightning Trainer into
:meth:`~ray.train.lightning.prepare_trainer` to validate
your configurations.

**Step 5: Define a Ray TorchTrainer**
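
Finally, wrap your training function in a Ray :class:`~ray.train.torch.TorchTrainer`.
A minimal sketch, assuming the ``train_func`` from the diff above and an illustrative
two-GPU-worker scaling configuration:

.. code-block:: python

    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    # Launch the Lightning training function on 2 GPU workers.
    ray_trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
    )
    result = ray_trainer.fit()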



Creating a :class:`~ray.train.torch.TorchTrainer`
-------------------------------------------------
@@ -50,12 +50,57 @@ Basics

Let's use a single Torch training workload as a running example. A very basic example of using Ray Data with TorchTrainer looks like this:

.. literalinclude:: ../doc_code/data_ingest_torch_new.py
:language: python
:start-after: __basic__
:end-before: __basic_end__
.. tab-set::

.. tab-item:: PyTorch

.. literalinclude:: ../doc_code/data_ingest_torch_new.py
:language: python
:start-after: __basic__
:end-before: __basic_end__

In this basic example, the `train_ds` object is created in your Ray script before the Trainer is even instantiated. The `train_ds` object is passed to the Trainer via the `datasets` argument, and is accessible to the `train_loop_per_worker` function via the :meth:`train.get_dataset_shard <ray.train.get_dataset_shard>` method.
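
The included example follows this general pattern; a minimal sketch (the file path,
batch size, and worker count are illustrative):

.. code-block:: python

    import ray
    from ray import train
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    train_ds = ray.data.read_csv("./train.csv")  # illustrative path

    def train_loop_per_worker():
        # The "train" dataset is sharded across workers by default.
        shard = train.get_dataset_shard("train")
        for batch in shard.iter_torch_batches(batch_size=16):
            ...  # forward/backward pass

    trainer = TorchTrainer(
        train_loop_per_worker,
        datasets={"train": train_ds},
        scaling_config=ScalingConfig(num_workers=2),
    )
    trainer.fit()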

.. tab-item:: PyTorch Lightning

.. code-block:: python
    :emphasize-lines: 9,10,13,14,25,26

    from ray import train

    train_data = ray.data.read_csv("./train.csv")
    val_data = ray.data.read_csv("./validation.csv")

    def train_func_per_worker():
        # Access Ray datasets in your train_func via ``get_dataset_shard``.
        # The "train" dataset gets sharded across workers by default.
        train_ds = train.get_dataset_shard("train")
        val_ds = train.get_dataset_shard("validation")

        # Create Ray dataset iterables via ``iter_torch_batches``.
        train_dataloader = train_ds.iter_torch_batches(batch_size=16)
        val_dataloader = val_ds.iter_torch_batches(batch_size=16)

        ...

        trainer = pl.Trainer(
            # ...
        )

        # Feed the Ray dataset iterables to ``pl.Trainer.fit``.
        trainer.fit(
            model,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader
        )

    trainer = TorchTrainer(
        train_func_per_worker,
        datasets={"train": train_data, "validation": val_data},
        scaling_config=ScalingConfig(num_workers=4),
    )
    trainer.fit()

In this basic example, the `train_ds` object is created in your Ray script before the Trainer is even instantiated. The `train_ds` object is passed to the Trainer via the `datasets` argument, and is accessible to the `train_loop_per_worker` function via the :meth:`train.get_dataset_shard <ray.train.get_dataset_shard>` method.
Splitting data across workers
-----------------------------