[Unified TorchTrainer] Add PyTorch Lightning Trainer Utilities #37989

Merged
Changes shown from 21 of 34 commits
- `3a7dc6e` init (woshiyyya, Aug 1, 2023)
- `c7b91a7` Merge remote-tracking branch 'upstream/master' into train/unified-api… (woshiyyya, Aug 1, 2023)
- `4ef88c9` add docstring (woshiyyya, Aug 1, 2023)
- `7b5da30` Apply suggestions from code review (woshiyyya, Aug 2, 2023)
- `95e2b9b` add deepspeed tests (woshiyyya, Aug 2, 2023)
- `4df1ec6` fix import error (woshiyyya, Aug 2, 2023)
- `3914a21` Merge remote-tracking branch 'upstream/master' into train/unified-api… (woshiyyya, Aug 2, 2023)
- `21a8ce3` make report callback public (woshiyyya, Aug 2, 2023)
- `abdf842` fix lint (woshiyyya, Aug 3, 2023)
- `d3b01f0` Merge remote-tracking branch 'upstream/master' into train/unified-api… (woshiyyya, Aug 3, 2023)
- `0ecd61a` use new ray.train api (woshiyyya, Aug 3, 2023)
- `0de4bc4` switch to new api in LightningTrainer (woshiyyya, Aug 7, 2023)
- `2b2589e` Merge remote-tracking branch 'upstream/master' into train/unified-api… (woshiyyya, Aug 7, 2023)
- `7c1b1e2` WIP: add lightning user guides (woshiyyya, Aug 7, 2023)
- `e8285ab` change default ckpt name (woshiyyya, Aug 7, 2023)
- `6c326a0` add migration guides and api (woshiyyya, Aug 8, 2023)
- `2f9f5a4` Apply suggestions from code review (woshiyyya, Aug 8, 2023)
- `4f1412f` Update doc/source/train/distributed-pytorch/migration-guides.rst (woshiyyya, Aug 8, 2023)
- `c9f3c03` address comments (woshiyyya, Aug 8, 2023)
- `8264549` change the resume training snippet (woshiyyya, Aug 8, 2023)
- `ba81c15` add mnist example (woshiyyya, Aug 8, 2023)
- `89c7388` fix doc lint (woshiyyya, Aug 8, 2023)
- `0d9bdaf` Merge branch 'master' into train/unified-api/add_lightning_utilities (woshiyyya, Aug 8, 2023)
- `c186f5b` fix (woshiyyya, Aug 8, 2023)
- `6925a05` fixing (woshiyyya, Aug 8, 2023)
- `079ddec` Merge remote-tracking branch 'upstream/master' into train/unified-api… (woshiyyya, Aug 8, 2023)
- `7c19ffe` fix ut (woshiyyya, Aug 8, 2023)
- `8a65e36` Merge remote-tracking branch 'upstream/master' into train/unified-api… (woshiyyya, Aug 9, 2023)
- `9d9ed17` update semgrep (woshiyyya, Aug 9, 2023)
- `87aa763` Apply suggestions from code review (woshiyyya, Aug 9, 2023)
- `74717af` address comments (woshiyyya, Aug 9, 2023)
- `a4213a2` Merge branch 'master' into train/unified-api/add_lightning_utilities (woshiyyya, Aug 9, 2023)
- `4e06c89` fix func name (woshiyyya, Aug 9, 2023)
- `1396883` fix ckpt path (woshiyyya, Aug 9, 2023)
2 changes: 2 additions & 0 deletions doc/source/_toc.yml
@@ -72,6 +72,8 @@ parts:
- file: train/distributed-pytorch/checkpoints
- file: train/distributed-pytorch/experiment-tracking
- file: train/distributed-pytorch/fault-tolerance
- file: train/distributed-pytorch/migration-guides
title: Migration Guides
- file: train/distributed-pytorch/advanced
sections:
- file: train/distributed-pytorch/reproducibility
12 changes: 12 additions & 0 deletions doc/source/train/api/api.rst
@@ -137,6 +137,18 @@ PyTorch Lightning
~train.lightning.LightningConfigBuilder
~train.lightning.LightningCheckpoint
~train.lightning.LightningPredictor
~train.lightning.prepare_trainer
~train.lightning.RayLightningEnvironment
~train.lightning.RayDDPStrategy
~train.lightning.RayFSDPStrategy
~train.lightning.RayDeepSpeedStrategy
~train.lightning.RayTrainReportCallback

.. note::

We will deprecate `LightningTrainer`, `LightningConfigBuilder`,
`LightningCheckpoint`, and `LightningPredictor` in Ray 2.8. Please
refer to the migration guides for more info.
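
As a rough sketch of the replacement pattern (the migration guides are the
authoritative reference), the new-style API wraps a Lightning training loop in a
``TorchTrainer`` using the utilities added in this PR; ``MyLightningModule`` is
a placeholder for your own module:

.. code-block:: python

    import pytorch_lightning as pl
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer
    from ray.train.lightning import (
        prepare_trainer,
        RayDDPStrategy,
        RayLightningEnvironment,
        RayTrainReportCallback,
    )

    def train_func_per_worker():
        model = MyLightningModule()  # placeholder for your LightningModule
        trainer = pl.Trainer(
            devices="auto",
            strategy=RayDDPStrategy(),
            plugins=[RayLightningEnvironment()],
            callbacks=[RayTrainReportCallback()],
        )
        trainer = prepare_trainer(trainer)
        trainer.fit(model)

    ray_trainer = TorchTrainer(
        train_func_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
    )
    result = ray_trainer.fit()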

Tensorflow/Keras
~~~~~~~~~~~~~~~~
144 changes: 139 additions & 5 deletions doc/source/train/distributed-pytorch/checkpoints.rst
@@ -86,8 +86,96 @@ appropriately in distributed training.
print(result.checkpoint.to_dict())
# {'epoch': 4, 'model_weights': OrderedDict([('bias', tensor([-0.1215])), ('weight', tensor([[0.3253, 0.1979, 0.4525, 0.2850]]))]), '_timestamp': 1656107095, '_preprocessor': None, '_current_checkpoint_id': 4}

By default, checkpoints will be persisted to local disk in the :ref:`log
directory <train-log-dir>` of each run.
.. tab-item:: PyTorch Lightning

Ray Train leverages PyTorch Lightning's Callback interface to report metrics
and checkpoints. We provide a simple callback implementation,
:class:`~ray.train.lightning.RayTrainReportCallback`, that reports
``on_train_epoch_end``.

Specifically, on each train epoch end, it

- collects all the logged metrics from ``trainer.callback_metrics``
- saves a checkpoint via ``trainer.save_checkpoint``
- reports to Ray Train via ``ray.train.report(metrics, checkpoint)``

.. code-block:: python
   :emphasize-lines: 2,11,20,28,29,30,31,32

   import pytorch_lightning as pl
   from ray.train.lightning import RayTrainReportCallback
   from ray.train.torch import TorchTrainer
   from ray.train import CheckpointConfig, RunConfig, ScalingConfig

   class MyLightningModule(pl.LightningModule):
       ...
       def on_validation_epoch_end(self):
           ...
           mean_acc = calculate_accuracy()
           self.log("mean_accuracy", mean_acc, sync_dist=True)

   def train_func_per_worker():
       ...
       model = MyLightningModule(...)
       datamodule = MyLightningDataModule(...)

       trainer = pl.Trainer(
           ...
           callbacks=[RayTrainReportCallback()],
       )
       trainer.fit(model, datamodule=datamodule)

   ray_trainer = TorchTrainer(
       train_func_per_worker,
       scaling_config=ScalingConfig(num_workers=2),
       run_config=RunConfig(
           checkpoint_config=CheckpointConfig(
               num_to_keep=2,
               checkpoint_score_attribute="mean_accuracy",
               checkpoint_score_order="max",
           ),
       ),
   )
   result = ray_trainer.fit()

Review thread on the ``...`` placeholder inside ``pl.Trainer(...)``:

Contributor: Nit: let's use comments here so that the python code at least would parse correctly (even if it isn't executed)

Suggested change: ``...`` → ``# ...``

Contributor: Technically this actually works in python!


You can always get the saved checkpoint path from ``result.checkpoint`` and
``result.best_checkpoints``.
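
For instance, a minimal sketch of inspecting the results returned by
``ray_trainer.fit()`` above (``best_checkpoints`` holds ``(Checkpoint, metrics)``
pairs):

.. code-block:: python

    # A minimal sketch, reusing `result` from `ray_trainer.fit()` above.
    # The latest reported checkpoint:
    print(result.checkpoint)

    # The top-k checkpoints retained by CheckpointConfig, as
    # (Checkpoint, metrics) pairs:
    for checkpoint, metrics in result.best_checkpoints:
        print(metrics["mean_accuracy"], checkpoint)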

For more advanced usage (e.g., reporting at a different frequency or reporting
customized checkpoint files), you can implement your own callback.
Here is a simple example that reports a checkpoint every 3 epochs:

.. code-block:: python

   import os
   import ray.train
   from ray.train import Checkpoint
   from tempfile import TemporaryDirectory
   from pytorch_lightning.callbacks import Callback

   class CustomRayTrainReportCallback(Callback):
       def on_train_epoch_end(self, trainer, pl_module):
           # Only report every 3 epochs
           if trainer.current_epoch % 3 != 0:
               return

           with TemporaryDirectory() as tmpdir:
               # Fetch metrics
               metrics = trainer.callback_metrics
               metrics = {k: v.item() for k, v in metrics.items()}

               # Add customized metrics
               metrics["epoch"] = trainer.current_epoch
               metrics["custom_metric"] = 123

               # Save model checkpoint file to tmpdir
               ckpt_path = os.path.join(tmpdir, "ckpt.pt")
               trainer.save_checkpoint(ckpt_path, weights_only=False)

               # Report to train session
               checkpoint = Checkpoint.from_directory(tmpdir)
               ray.train.report(metrics=metrics, checkpoint=checkpoint)

By default, checkpoints will be persisted to the :ref:`log directory <train-log-dir>` of each run.


.. _train-dl-configure-checkpoints:
@@ -96,7 +184,7 @@

Configuring checkpoints
-----------------------

For more configurability of checkpointing behavior (specifically saving
checkpoints to disk), a :py:class:`~ray.air.config.CheckpointConfig` can be passed into
checkpoints to disk), a :py:class:`~ray.train.CheckpointConfig` can be passed into
``Trainer``.

.. literalinclude:: ../doc_code/key_concepts.py
@@ -107,12 +195,18 @@ checkpoints to disk), a :py:class:`~ray.air.config.CheckpointConfig` can be pass

.. seealso::

See the :class:`~ray.air.CheckpointConfig` API reference.
See the :class:`~ray.train.CheckpointConfig` API reference.

.. note::

If you want to save the top-k checkpoints with respect to a metric via
:py:class:`~ray.train.CheckpointConfig`,
please ensure that the metric is always reported together with the checkpoints.
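
For example, a minimal sketch of reporting the scored metric in the same call
as the checkpoint (``mean_accuracy`` matches the ``checkpoint_score_attribute``
used earlier):

.. code-block:: python

    # A minimal sketch: the metric named in checkpoint_score_attribute
    # must accompany every checkpoint report. Here `mean_acc` and
    # `checkpoint` come from your training loop.
    ray.train.report(
        metrics={"mean_accuracy": mean_acc},
        checkpoint=checkpoint,
    )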

**[Experimental] Distributed Checkpoints**: For model parallel workloads where the models do not fit in a single GPU worker,
it will be important to save and upload the model that is partitioned across different workers. You
can enable this by setting `_checkpoint_keep_all_ranks=True` to retain the model checkpoints across workers,
and `_checkpoint_upload_from_workers=True` to upload their checkpoints to cloud directly in :class:`~ray.air.CheckpointConfig`. This functionality works for any trainer that inherits from :class:`~ray.train.data_parallel_trainer.DataParallelTrainer`.
and `_checkpoint_upload_from_workers=True` to upload their checkpoints to cloud directly in :class:`~ray.train.CheckpointConfig`. This functionality works for any trainer that inherits from :class:`~ray.train.data_parallel_trainer.DataParallelTrainer`.
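
A hedged sketch of enabling these experimental flags; the ``storage_path``
destination is an assumption here, and both flags are private and subject to
change:

.. code-block:: python

    from ray.train import CheckpointConfig, RunConfig

    run_config = RunConfig(
        # Hypothetical cloud destination for uploaded checkpoints.
        storage_path="s3://my-bucket/experiments",
        checkpoint_config=CheckpointConfig(
            # Experimental: retain checkpoints from all worker ranks.
            _checkpoint_keep_all_ranks=True,
            # Experimental: upload directly from workers to cloud storage.
            _checkpoint_upload_from_workers=True,
        ),
    )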



@@ -204,3 +298,43 @@ Checkpoints can be loaded into the training function in 2 steps:

print(result.checkpoint.to_dict())
# {'epoch': 3, 'model_weights': OrderedDict([('bias', tensor([0.0902])), ('weight', tensor([[-0.1549, -0.0861, 0.4353, -0.4116]]))]), '_timestamp': 1656108265, '_preprocessor': None, '_current_checkpoint_id': 2}

.. tab-item:: PyTorch Lightning

.. code-block:: python
   :emphasize-lines: 12-18

   import pytorch_lightning as pl
   from ray import train
   from ray.train import Checkpoint, ScalingConfig
   from ray.train.torch import TorchTrainer
   from ray.train.lightning import RayTrainReportCallback
   from os.path import join

   def train_func_per_worker():
       model = MyLightningModule(...)
       datamodule = MyLightningDataModule(...)
       trainer = pl.Trainer(
           ...
           callbacks=[RayTrainReportCallback()],
       )

       checkpoint = train.get_checkpoint()
       if checkpoint:
           with checkpoint.as_directory() as ckpt_dir:
               ckpt_path = join(ckpt_dir, "checkpoint.ckpt")
               trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
       else:
           trainer.fit(model, datamodule=datamodule)

   # Build a Ray Train Checkpoint
   # Suppose we have a Lightning checkpoint under ./ckpt_dir/checkpoint.ckpt
   checkpoint = Checkpoint.from_directory("./ckpt_dir")

   # Resume training from the checkpoint file
   ray_trainer = TorchTrainer(
       train_func_per_worker,
       scaling_config=ScalingConfig(num_workers=2),
       resume_from_checkpoint=checkpoint,
   )
   result = ray_trainer.fit()

@@ -64,8 +64,8 @@ training.

import torch
from torch.utils.data import DataLoader, DistributedSampler
+from ray import train
+import ray.train.torch
+ from ray import train
+ import ray.train.torch


def train_func():
@@ -91,6 +91,71 @@
.. code-block:: python

global_batch_size = worker_batch_size * train.get_context().get_world_size()

.. tab-item:: PyTorch Lightning

Ray Train sets up your distributed process group on each worker. You only need
to add a few Ray Train utilities to your Lightning Trainer definition.

.. code-block:: diff

      import pytorch_lightning as pl
    + from ray.train.lightning import (
    +     prepare_trainer,
    +     RayDDPStrategy,
    +     RayLightningEnvironment,
    + )

      def train_func_per_worker():
          ...
          model = MyLightningModule(...)
          datamodule = MyLightningDataModule(...)

          trainer = pl.Trainer(
    -         devices=[0,1,2,3],
    -         strategy=DDPStrategy(),
    -         plugins=[LightningEnvironment()],
    +         devices="auto",
    +         strategy=RayDDPStrategy(),
    +         plugins=[RayLightningEnvironment()]
          )
    +     trainer = prepare_trainer(trainer)

          trainer.fit(model, datamodule=datamodule)

We provide the following utilities:

Step 1: Configure Distributed Strategy
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Ray Train offers several subclassed distributed strategies for Lightning.
These strategies retain the same argument lists as their base strategy
classes. Internally, they correctly configure the root device and the
distributed sampler arguments, as the sketch after the list below illustrates.

- :class:`~ray.train.lightning.RayDDPStrategy`
- :class:`~ray.train.lightning.RayFSDPStrategy`
- :class:`~ray.train.lightning.RayDeepSpeedStrategy`
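
As an illustrative sketch, swapping in one of these strategies only changes the
``strategy`` argument; ``stage=3`` below is an example of a base-class
``DeepSpeedStrategy`` argument passed through unchanged:

.. code-block:: python

    import pytorch_lightning as pl
    from ray.train.lightning import RayDeepSpeedStrategy, RayLightningEnvironment

    trainer = pl.Trainer(
        devices="auto",
        accelerator="gpu",
        # Same argument list as Lightning's DeepSpeedStrategy.
        strategy=RayDeepSpeedStrategy(stage=3),
        plugins=[RayLightningEnvironment()],
    )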

Step 2: Configure Ray Cluster Environment Plugin
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Ray Train also provides :class:`~ray.train.lightning.RayLightningEnvironment`,
a cluster environment specification for Ray clusters. This utility class
configures each worker's local rank, global rank, node rank, and world size.

In addition, Ray's TorchTrainer has already configured the correct
``CUDA_VISIBLE_DEVICES`` for you. Always use all available
GPUs by setting ``devices="auto"``, as sketched below.
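
A minimal sketch of this step, mirroring the diff above:

.. code-block:: python

    import pytorch_lightning as pl
    from ray.train.lightning import RayLightningEnvironment

    trainer = pl.Trainer(
        # Use all GPUs that TorchTrainer assigned to this worker.
        devices="auto",
        # Rank and world size come from the Ray Train session.
        plugins=[RayLightningEnvironment()],
    )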

Step 3: Prepare your Lightning Trainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Finally, pass your Lightning trainer into
:meth:`~ray.train.lightning.prepare_trainer` to validate
your configurations.
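
A one-line sketch of this step:

.. code-block:: python

    from ray.train.lightning import prepare_trainer

    # Validates that the Lightning Trainer is configured correctly
    # for Ray Train before calling `fit`.
    trainer = prepare_trainer(trainer)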



Creating a :class:`~ray.train.torch.TorchTrainer`
-------------------------------------------------
@@ -50,12 +50,55 @@ Basics

Let's use a single Torch training workload as a running example. A very basic example of using Ray Data with TorchTrainer looks like this:

.. literalinclude:: ../doc_code/data_ingest_torch_new.py
:language: python
:start-after: __basic__
:end-before: __basic_end__
.. tab-set::

.. tab-item:: PyTorch

.. literalinclude:: ../doc_code/data_ingest_torch_new.py
:language: python
:start-after: __basic__
:end-before: __basic_end__

In this basic example, the `train_ds` object is created in your Ray script before the Trainer is even instantiated. The `train_ds` object is passed to the Trainer via the `datasets` argument, and is accessible to the `train_loop_per_worker` function via the :meth:`train.get_dataset_shard <ray.train.get_dataset_shard>` method.

.. tab-item:: PyTorch Lightning

.. code-block:: python
   :emphasize-lines: 13,14,17,18,27,28

   import ray
   import pytorch_lightning as pl
   from ray import train
   from ray.train import ScalingConfig
   from ray.train.torch import TorchTrainer

   train_data = ray.data.read_csv("./train.csv")
   val_data = ray.data.read_csv("./validation.csv")

   def train_func_per_worker():
       # Access Ray datasets in your train_func via ``get_dataset_shard``.
       # The "train" dataset gets sharded across workers by default
       train_ds = train.get_dataset_shard("train")
       val_ds = train.get_dataset_shard("validation")

       # Create Ray dataset iterables via ``iter_torch_batches``.
       train_dataloader = train_ds.iter_torch_batches(batch_size=16)
       val_dataloader = val_ds.iter_torch_batches(batch_size=16)

       ...

       trainer = pl.Trainer(...)

       # Feed the Ray dataset iterables to ``pl.Trainer.fit``.
       trainer.fit(
           model,
           train_dataloaders=train_dataloader,
           val_dataloaders=val_dataloader,
       )

   trainer = TorchTrainer(
       train_func_per_worker,
       datasets={"train": train_data, "validation": val_data},
       scaling_config=ScalingConfig(num_workers=4),
   )
   trainer.fit()

In this basic example, the `train_ds` object is created in your Ray script before the Trainer is even instantiated. The `train_ds` object is passed to the Trainer via the `datasets` argument, and is accessible to the `train_loop_per_worker` function via the :meth:`train.get_dataset_shard <ray.train.get_dataset_shard>` method.

Splitting data across workers
-----------------------------