Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unified TorchTrainer] Add PyTorch Lightning Trainer Utilities #37989

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
3a7dc6e
init
woshiyyya Aug 1, 2023
c7b91a7
Merge remote-tracking branch 'upstream/master' into train/unified-api…
woshiyyya Aug 1, 2023
4ef88c9
add docstring
woshiyyya Aug 1, 2023
7b5da30
Apply suggestions from code review
woshiyyya Aug 2, 2023
95e2b9b
add deepspeed tests
woshiyyya Aug 2, 2023
4df1ec6
fix import error
woshiyyya Aug 2, 2023
3914a21
Merge remote-tracking branch 'upstream/master' into train/unified-api…
woshiyyya Aug 2, 2023
21a8ce3
make report callback public
woshiyyya Aug 2, 2023
abdf842
fix lint
woshiyyya Aug 3, 2023
d3b01f0
Merge remote-tracking branch 'upstream/master' into train/unified-api…
woshiyyya Aug 3, 2023
0ecd61a
use new ray.train api
woshiyyya Aug 3, 2023
0de4bc4
switch to new api in LightningTrainer
woshiyyya Aug 7, 2023
2b2589e
Merge remote-tracking branch 'upstream/master' into train/unified-api…
woshiyyya Aug 7, 2023
7c1b1e2
WIP: add lightning user guides
woshiyyya Aug 7, 2023
e8285ab
change default ckpt name
woshiyyya Aug 7, 2023
6c326a0
add migration guides and api
woshiyyya Aug 8, 2023
2f9f5a4
Apply suggestions from code review
woshiyyya Aug 8, 2023
4f1412f
Update doc/source/train/distributed-pytorch/migration-guides.rst
woshiyyya Aug 8, 2023
c9f3c03
address comments
woshiyyya Aug 8, 2023
8264549
change the resume training snippet
woshiyyya Aug 8, 2023
ba81c15
add mnist example
woshiyyya Aug 8, 2023
89c7388
fix doc lint
woshiyyya Aug 8, 2023
0d9bdaf
Merge branch 'master' into train/unified-api/add_lightning_utilities
woshiyyya Aug 8, 2023
c186f5b
fix
woshiyyya Aug 8, 2023
6925a05
fixing
woshiyyya Aug 8, 2023
079ddec
Merge remote-tracking branch 'upstream/master' into train/unified-api…
woshiyyya Aug 8, 2023
7c19ffe
fix ut
woshiyyya Aug 8, 2023
8a65e36
Merge remote-tracking branch 'upstream/master' into train/unified-api…
woshiyyya Aug 9, 2023
9d9ed17
update semgrep
woshiyyya Aug 9, 2023
87aa763
Apply suggestions from code review
woshiyyya Aug 9, 2023
74717af
address comments
woshiyyya Aug 9, 2023
a4213a2
Merge branch 'master' into train/unified-api/add_lightning_utilities
woshiyyya Aug 9, 2023
4e06c89
fix func name
woshiyyya Aug 9, 2023
1396883
fix ckpt path
woshiyyya Aug 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/ray/train/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,14 @@ py_test(
deps = [":train_lib"]
)

# GPU test target for the new PyTorch Lightning utilities used with the
# unified TorchTrainer (see tests/test_torch_lightning_train.py).
py_test(
    name = "test_torch_lightning_train",
    size = "large",
    srcs = ["tests/test_torch_lightning_train.py"],
    tags = ["team:ml", "exclusive", "ray_air", "gpu", "ptl_v2"],
    deps = [":train_lib"]
)

py_test(
name = "test_minimal",
size = "small",
Expand Down
20 changes: 16 additions & 4 deletions python/ray/train/lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,26 @@

from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
from ray.train.lightning.lightning_predictor import LightningPredictor
from ray.train.lightning.lightning_trainer import (
LightningTrainer,
LightningConfigBuilder,
from ray.train.lightning.lightning_utils import (
get_devices,
prepare_trainer,
RayDDPStrategy,
RayFSDPStrategy,
RayDeepSpeedStrategy,
RayLightningEnvironment,
RayModelCheckpoint,
)

__all__ = [
"LightningTrainer",
"LightningConfigBuilder",
"LightningCheckpoint",
"LightningPredictor",
]
"get_devices",
"prepare_trainer",
"RayDDPStrategy",
"RayFSDPStrategy",
"RayDeepSpeedStrategy",
"RayLightningEnvironment",
"RayModelCheckpoint",
]
6 changes: 3 additions & 3 deletions python/ray/train/lightning/lightning_trainer.py
woshiyyya marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
from ray.train.torch import TorchTrainer
from ray.train.torch.config import TorchConfig
from ray.util import PublicAPI
from ray.train.lightning._lightning_utils import (
from ray.train.lightning.lightning_utils import (
RayDDPStrategy,
RayFSDPStrategy,
RayDeepSpeedStrategy,
RayEnvironment,
RayLightningEnvironment,
RayDataModule,
RayModelCheckpoint,
get_worker_root_device,
Expand Down Expand Up @@ -580,7 +580,7 @@ def _lightning_train_loop_per_worker(config):
for plugin in trainer_config.get("plugins", [])
if not isinstance(plugin, ClusterEnvironment)
]
trainer_config["plugins"].append(RayEnvironment())
trainer_config["plugins"].append(RayLightningEnvironment())

# Setup ddp strategy for ray orchestration
if "strategy" in trainer_config:
woshiyyya marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's leave this with the original _lightning_utils name.

Copy link
Member Author

@woshiyyya woshiyyya Aug 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it because there are still private methods/classes in that file? Should we remove the underscore after we have fully deprecated LightningTrainer?

Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from ray.air.constants import MODEL_KEY
from ray.data.dataset import DataIterator
from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
from ray.util import PublicAPI

import logging
import shutil
import torch
import tempfile
from packaging.version import Version
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, List, Union
from torch.utils.data import IterableDataset, DataLoader

import pytorch_lightning as pl
Expand Down Expand Up @@ -49,6 +50,36 @@ def get_worker_root_device():
return devices


@PublicAPI(stability="alpha")
def get_devices() -> Optional[Union[List[int], str]]:
    """Returns the parallel devices for Lightning Trainer on each Ray Train worker.

    This method returns the list of CUDA device indexes of a GPU worker, and returns
    "auto" if called inside a CPU worker. Note that you can only call this method
    within the training function of the
    :class:`TorchTrainer <ray.train.torch.TorchTrainer>` class.

    Example:
        .. testcode::

            import ray
            import pytorch_lightning as pl

            def train_loop_per_worker():
                devices = ray.train.lightning.get_devices()
                # devices == [device_id] in GPU workers
                # devices == "auto" in CPU workers

                trainer = pl.Trainer(
                    ...,
                    devices=devices
                )

    """
    device = get_worker_root_device()
    # `device.index` can legitimately be 0 (the first visible CUDA device),
    # which is falsy — a plain truthiness check would misreport a GPU worker
    # pinned to device 0 as a CPU worker. Compare against None explicitly.
    return [device.index] if device.index is not None else "auto"


@PublicAPI(stability="alpha")
class RayDDPStrategy(DDPStrategy):
"""Subclass of DDPStrategy to ensure compatibility with Ray orchestration."""

Expand All @@ -64,6 +95,7 @@ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
)


@PublicAPI(stability="alpha")
class RayFSDPStrategy(FSDPStrategy):
"""Subclass of FSDPStrategy to ensure compatibility with Ray orchestration."""

Expand Down Expand Up @@ -98,6 +130,7 @@ def lightning_module_state_dict(self) -> Dict[str, Any]:
return super().lightning_module_state_dict()


@PublicAPI(stability="alpha")
class RayDeepSpeedStrategy(DeepSpeedStrategy):
"""Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration."""

Expand All @@ -123,7 +156,8 @@ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
)


class RayEnvironment(LightningEnvironment):
@PublicAPI(stability="alpha")
class RayLightningEnvironment(LightningEnvironment):
"""Setup Lightning DDP training environment for Ray cluster."""

def world_size(self) -> int:
Expand All @@ -150,6 +184,35 @@ def teardown(self):
pass


@PublicAPI(stability="alpha")
def prepare_trainer(trainer: pl.Trainer) -> pl.Trainer:
    """Prepare the PyTorch Lightning Trainer for distributed execution.

    Validates that the trainer was configured with Ray-compatible plugins:
    the strategy must be one of the Ray strategy subclasses, and the cluster
    environment (if present) must be a ``RayLightningEnvironment``.

    Args:
        trainer: The ``pl.Trainer`` instance configured by the user.

    Returns:
        The same trainer instance, unchanged.

    Raises:
        RuntimeError: If the strategy or the cluster environment plugin is
            not one of the Ray-provided classes.
    """

    # Check strategy class
    valid_strategy_classes = [RayDDPStrategy, RayFSDPStrategy, RayDeepSpeedStrategy]

    if not any(isinstance(trainer.strategy, cls) for cls in valid_strategy_classes):
        raise RuntimeError(
            f"Invalid strategy class: {type(trainer.strategy)}. To use "
            "PyTorch Lightning with Ray, the strategy object should be one of "
            # Fixed casing: the actual class is `RayDeepSpeedStrategy`,
            # not `RayDeepspeedStrategy`.
            "[RayDDPStrategy, RayFSDPStrategy, RayDeepSpeedStrategy] class "
            "or its subclass."
        )

    # Check cluster environment
    cluster_environment = getattr(trainer.strategy, "cluster_environment", None)
    if cluster_environment and not isinstance(
        cluster_environment, RayLightningEnvironment
    ):
        raise RuntimeError(
            # Trailing space added so the concatenated message reads
            # "... expected class is `ray.train...`" instead of "is`ray...".
            "Invalid cluster environment plugin. The expected class is "
            "`ray.train.lightning.RayLightningEnvironment` "
            f"but got {type(cluster_environment)}!"
        )

    return trainer


class RayIterableDataset(IterableDataset):
def __init__(self, dataset: "DataIterator", config: Dict[str, Any]) -> None:
super().__init__()
Expand Down
28 changes: 27 additions & 1 deletion python/ray/train/tests/lightning_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from pytorch_lightning.callbacks import Callback
from tempfile import TemporaryDirectory
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from ray.air import session

import ray
from ray.air import session, Checkpoint


class LinearModule(pl.LightningModule):
Expand Down Expand Up @@ -167,3 +173,23 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
x = batch
logits = self.forward(x)
return torch.argmax(logits, dim=-1)


class RayTrainReportCallback(Callback):
    """Lightning callback that reports metrics and a checkpoint to the Ray
    Train session at the end of each training epoch."""

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        with TemporaryDirectory() as checkpoint_dir:
            # Collect the latest logged metrics as plain Python scalars.
            metrics = {
                name: value.item()
                for name, value in trainer.callback_metrics.items()
            }

            # (Optional) Add customized metrics
            metrics["epoch"] = trainer.current_epoch
            metrics["steps"] = trainer.global_step

            # Persist a full checkpoint into the temporary directory.
            checkpoint_path = os.path.join(
                checkpoint_dir, f"ckpt_epoch_{trainer.current_epoch}.pth"
            )
            trainer.save_checkpoint(checkpoint_path, weights_only=False)

            # Hand the metrics and checkpoint off to the Ray Train session.
            ray.train.report(
                metrics=metrics,
                checkpoint=Checkpoint.from_directory(checkpoint_dir),
            )
154 changes: 154 additions & 0 deletions python/ray/train/tests/test_torch_lightning_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import pytest
import numpy as np

import ray
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
get_devices,
RayDDPStrategy,
RayFSDPStrategy,
RayLightningEnvironment,
)

from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.tests.lightning_test_utils import (
LinearModule,
DummyDataModule,
RayTrainReportCallback,
)
import pytorch_lightning as pl


@pytest.fixture
def ray_start_6_cpus_2_gpus():
    """Start a local Ray cluster with 6 CPUs and 2 GPUs for the test,
    shutting it down afterwards even if the test body raises."""
    address_info = ray.init(num_cpus=6, num_gpus=2)
    try:
        yield address_info
    finally:
        # Run teardown unconditionally: without try/finally, a test failure
        # thrown into the generator would skip `ray.shutdown()` and leak the
        # cluster into subsequent tests.
        ray.shutdown()


@pytest.mark.parametrize("strategy_name", ["ddp", "fsdp"])
@pytest.mark.parametrize("accelerator", ["cpu", "gpu"])
@pytest.mark.parametrize("datasource", ["dataloader", "datamodule"])
def test_trainer_with_native_dataloader(
    ray_start_6_cpus_2_gpus, strategy_name, accelerator, datasource
):
    """Test basic ddp and fsdp training with dataloader and datamodule."""

    if accelerator == "cpu" and strategy_name == "fsdp":
        # FSDP requires GPUs. Use pytest.skip so the combination is reported
        # as skipped rather than silently "passing" via a bare return.
        pytest.skip("FSDP is not supported on CPU.")

    num_workers = 2
    num_epochs = 4
    batch_size = 8
    dataset_size = 256

    strategy_map = {"ddp": RayDDPStrategy(), "fsdp": RayFSDPStrategy()}

    def train_loop():
        model = LinearModule(input_dim=32, output_dim=4, strategy=strategy_name)

        strategy = strategy_map[strategy_name]

        trainer = pl.Trainer(
            max_epochs=num_epochs,
            devices=get_devices(),
            accelerator=accelerator,
            strategy=strategy,
            plugins=[RayLightningEnvironment()],
            callbacks=[RayTrainReportCallback()],
        )

        datamodule = DummyDataModule(batch_size, dataset_size)

        if datasource == "dataloader":
            trainer.fit(
                model,
                train_dataloaders=datamodule.train_dataloader(),
                val_dataloaders=datamodule.val_dataloader(),
            )
        if datasource == "datamodule":
            trainer.fit(model, datamodule=datamodule)

    trainer = TorchTrainer(
        train_loop_per_worker=train_loop,
        # Use the `num_workers` variable (instead of a hard-coded 2) so the
        # step-count assertion below stays consistent with the worker count.
        scaling_config=ScalingConfig(
            num_workers=num_workers, use_gpu=(accelerator == "gpu")
        ),
    )

    results = trainer.fit()
    assert results.metrics["epoch"] == num_epochs - 1
    assert (
        results.metrics["steps"] == num_epochs * dataset_size / num_workers / batch_size
    )
    assert "loss" in results.metrics
    assert "val_loss" in results.metrics


@pytest.mark.parametrize("strategy_name", ["ddp", "fsdp"])
@pytest.mark.parametrize("accelerator", ["cpu", "gpu"])
def test_trainer_with_ray_data(ray_start_6_cpus_2_gpus, strategy_name, accelerator):
    """Test Ray Data integration with ddp and fsdp."""

    if accelerator == "cpu" and strategy_name == "fsdp":
        # FSDP requires GPUs. Use pytest.skip so the combination is reported
        # as skipped rather than silently "passing" via a bare return.
        pytest.skip("FSDP is not supported on CPU.")

    num_epochs = 4
    batch_size = 8
    num_workers = 2
    dataset_size = 256

    strategy_map = {"ddp": RayDDPStrategy(), "fsdp": RayFSDPStrategy()}

    dataset = np.random.rand(dataset_size, 32).astype(np.float32)
    train_dataset = ray.data.from_numpy(dataset)
    val_dataset = ray.data.from_numpy(dataset)

    def train_loop():
        model = LinearModule(input_dim=32, output_dim=4, strategy=strategy_name)

        strategy = strategy_map[strategy_name]

        trainer = pl.Trainer(
            max_epochs=num_epochs,
            devices=get_devices(),
            accelerator=accelerator,
            strategy=strategy,
            plugins=[RayLightningEnvironment()],
            callbacks=[RayTrainReportCallback()],
        )

        # NOTE(review): `session.get_dataset_shard` is the ray.air API;
        # confirm whether this should migrate to `ray.train.get_dataset_shard`
        # along with the rest of this PR's unified-API changes.
        train_data_iterable = session.get_dataset_shard("train").iter_torch_batches(
            batch_size=batch_size
        )
        val_data_iterable = session.get_dataset_shard("val").iter_torch_batches(
            batch_size=batch_size
        )

        trainer.fit(
            model,
            train_dataloaders=train_data_iterable,
            val_dataloaders=val_data_iterable,
        )

    trainer = TorchTrainer(
        train_loop_per_worker=train_loop,
        # Use the `num_workers` variable (instead of a hard-coded 2) so the
        # step-count assertion below stays consistent with the worker count.
        scaling_config=ScalingConfig(
            num_workers=num_workers, use_gpu=(accelerator == "gpu")
        ),
        datasets={"train": train_dataset, "val": val_dataset},
    )

    results = trainer.fit()
    assert results.metrics["epoch"] == num_epochs - 1
    # RayTrainReportCallback reports the step count under the key "steps"
    # (not "step"); asserting on "step" would always raise KeyError.
    assert (
        results.metrics["steps"] == num_epochs * dataset_size / num_workers / batch_size
    )
    assert "loss" in results.metrics
    assert "val_loss" in results.metrics


if __name__ == "__main__":
    import sys

    # `pytest` is already imported at module level; no need to re-import it here.
    sys.exit(pytest.main(["-v", "-x", __file__]))
Loading