[Unified TorchTrainer] Add PyTorch Lightning Trainer Utilities #37989

Merged
Commits (34):
3a7dc6e  init (woshiyyya, Aug 1, 2023)
c7b91a7  Merge remote-tracking branch 'upstream/master' into train/unified-api… (woshiyyya, Aug 1, 2023)
4ef88c9  add docstring (woshiyyya, Aug 1, 2023)
7b5da30  Apply suggestions from code review (woshiyyya, Aug 2, 2023)
95e2b9b  add deepspeed tests (woshiyyya, Aug 2, 2023)
4df1ec6  fix import error (woshiyyya, Aug 2, 2023)
3914a21  Merge remote-tracking branch 'upstream/master' into train/unified-api… (woshiyyya, Aug 2, 2023)
21a8ce3  make report callback public (woshiyyya, Aug 2, 2023)
abdf842  fix lint (woshiyyya, Aug 3, 2023)
d3b01f0  Merge remote-tracking branch 'upstream/master' into train/unified-api… (woshiyyya, Aug 3, 2023)
0ecd61a  use new ray.train api (woshiyyya, Aug 3, 2023)
0de4bc4  switch to new api in LightningTrainer (woshiyyya, Aug 7, 2023)
2b2589e  Merge remote-tracking branch 'upstream/master' into train/unified-api… (woshiyyya, Aug 7, 2023)
7c1b1e2  WIP: add lightning user guides (woshiyyya, Aug 7, 2023)
e8285ab  change default ckpt name (woshiyyya, Aug 7, 2023)
6c326a0  add migration guides and api (woshiyyya, Aug 8, 2023)
2f9f5a4  Apply suggestions from code review (woshiyyya, Aug 8, 2023)
4f1412f  Update doc/source/train/distributed-pytorch/migration-guides.rst (woshiyyya, Aug 8, 2023)
c9f3c03  address comments (woshiyyya, Aug 8, 2023)
8264549  change the resume training snippet (woshiyyya, Aug 8, 2023)
ba81c15  add mnist example (woshiyyya, Aug 8, 2023)
89c7388  fix doc lint (woshiyyya, Aug 8, 2023)
0d9bdaf  Merge branch 'master' into train/unified-api/add_lightning_utilities (woshiyyya, Aug 8, 2023)
c186f5b  fix (woshiyyya, Aug 8, 2023)
6925a05  fixing (woshiyyya, Aug 8, 2023)
079ddec  Merge remote-tracking branch 'upstream/master' into train/unified-api… (woshiyyya, Aug 8, 2023)
7c19ffe  fix ut (woshiyyya, Aug 8, 2023)
8a65e36  Merge remote-tracking branch 'upstream/master' into train/unified-api… (woshiyyya, Aug 9, 2023)
9d9ed17  update semgrep (woshiyyya, Aug 9, 2023)
87aa763  Apply suggestions from code review (woshiyyya, Aug 9, 2023)
74717af  address comments (woshiyyya, Aug 9, 2023)
a4213a2  Merge branch 'master' into train/unified-api/add_lightning_utilities (woshiyyya, Aug 9, 2023)
4e06c89  fix func name (woshiyyya, Aug 9, 2023)
1396883  fix ckpt path (woshiyyya, Aug 9, 2023)
Files changed:
8 changes: 8 additions & 0 deletions python/ray/train/BUILD
@@ -454,6 +454,14 @@ py_test(
     deps = [":train_lib"]
 )

+py_test(
+    name = "test_torch_lightning_train",
+    size = "large",
+    srcs = ["tests/test_torch_lightning_train.py"],
+    tags = ["team:ml", "exclusive", "ray_air", "gpu", "ptl_v2"],
+    deps = [":train_lib"]
+)
+
 py_test(
     name = "test_minimal",
     size = "small",
20 changes: 17 additions & 3 deletions python/ray/train/lightning/__init__.py
@@ -8,16 +8,30 @@
 )
 # isort: on

+from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
+from ray.train.lightning.lightning_predictor import LightningPredictor
 from ray.train.lightning.lightning_trainer import (
-    LightningTrainer,
     LightningConfigBuilder,
+    LightningTrainer,
 )
-from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
-from ray.train.lightning.lightning_predictor import LightningPredictor
+from ray.train.lightning.lightning_utils import (
+    prepare_trainer,
+    RayDDPStrategy,
+    RayFSDPStrategy,
+    RayDeepSpeedStrategy,
+    RayLightningEnvironment,
+    RayTrainReportCallback,
+)

 __all__ = [
     "LightningTrainer",
     "LightningConfigBuilder",
     "LightningCheckpoint",
     "LightningPredictor",
+    "prepare_trainer",
+    "RayDDPStrategy",
+    "RayFSDPStrategy",
+    "RayDeepSpeedStrategy",
+    "RayLightningEnvironment",
+    "RayTrainReportCallback",
 ]
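
For context, a minimal end-to-end sketch of how these newly exported utilities are meant to compose with TorchTrainer. This is not part of the diff: ToyModule and the random tensors are hypothetical stand-ins, and exact pl.Trainer flags may vary by Lightning version.

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)


class ToyModule(pl.LightningModule):
    """Hypothetical stand-in for a real LightningModule."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-2)


def train_func(config):
    # Runs on every Ray Train worker.
    data = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    loader = DataLoader(data, batch_size=16)

    trainer = pl.Trainer(
        max_epochs=config["max_epochs"],
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),             # Ray-aware DDP strategy
        plugins=[RayLightningEnvironment()],   # ranks/world size come from the Ray session
        callbacks=[RayTrainReportCallback()],  # reports metrics + checkpoint each epoch
        enable_checkpointing=False,            # RayTrainReportCallback handles checkpoints
    )
    trainer = prepare_trainer(trainer)         # validates strategy and environment
    trainer.fit(ToyModule(), train_dataloaders=loader)


trainer = TorchTrainer(
    train_func,
    train_loop_config={"max_epochs": 2},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```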
6 changes: 3 additions & 3 deletions python/ray/train/lightning/lightning_trainer.py
@@ -16,11 +16,11 @@
 from ray.train.torch import TorchTrainer
 from ray.train.torch.config import TorchConfig
 from ray.util import PublicAPI
-from ray.train.lightning._lightning_utils import (
+from ray.train.lightning.lightning_utils import (
     RayDDPStrategy,
     RayFSDPStrategy,
     RayDeepSpeedStrategy,
-    RayEnvironment,
+    RayLightningEnvironment,
     RayDataModule,
     RayModelCheckpoint,
     get_worker_root_device,
@@ -586,7 +586,7 @@ def _lightning_train_loop_per_worker(config):
         for plugin in trainer_config.get("plugins", [])
         if not isinstance(plugin, ClusterEnvironment)
     ]
-    trainer_config["plugins"].append(RayEnvironment())
+    trainer_config["plugins"].append(RayLightningEnvironment())

     # Setup ddp strategy for ray orchestration
     if "strategy" in trainer_config:
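
As a side note, the plugin-normalization idiom in the second hunk can be read as this standalone sketch (hypothetical helper, not in the diff):

```python
from pytorch_lightning.plugins.environments import ClusterEnvironment
from ray.train.lightning import RayLightningEnvironment


def normalize_plugins(plugins: list) -> list:
    """Hypothetical standalone version of the idiom in the hunk above."""
    # Drop any user-supplied cluster environment...
    plugins = [p for p in plugins if not isinstance(p, ClusterEnvironment)]
    # ...and install the Ray-aware one, so Lightning derives ranks from Ray Train.
    plugins.append(RayLightningEnvironment())
    return plugins
```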
Review comment (Contributor):
Let's leave this with the original _lightning_utils name.

Reply (woshiyyya, Member and PR author, Aug 8, 2023):
Is it because there are still private methods/classes in that file? Should we remove the underscore after we fully deprecate LightningTrainer?

python/ray/train/lightning/lightning_utils.py (renamed from _lightning_utils.py)
@@ -4,17 +4,20 @@
 from ray.air.constants import MODEL_KEY
 from ray.data.dataset import DataIterator
 from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
+from ray.util import PublicAPI

 import logging
 import shutil
 import torch
 import tempfile
+from tempfile import TemporaryDirectory
+from ray.train import Checkpoint
 from packaging.version import Version
 from typing import Any, Dict, Optional
 from torch.utils.data import IterableDataset, DataLoader

 import pytorch_lightning as pl
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, Callback
 from pytorch_lightning.plugins.environments import LightningEnvironment
 from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy

@@ -49,6 +52,7 @@ def get_worker_root_device():
     return devices


+@PublicAPI(stability="alpha")
 class RayDDPStrategy(DDPStrategy):
     """Subclass of DDPStrategy to ensure compatibility with Ray orchestration."""

@@ -64,6 +68,7 @@ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
     )


+@PublicAPI(stability="alpha")
 class RayFSDPStrategy(FSDPStrategy):
     """Subclass of FSDPStrategy to ensure compatibility with Ray orchestration."""

@@ -98,19 +103,10 @@ def lightning_module_state_dict(self) -> Dict[str, Any]:
         return super().lightning_module_state_dict()


+@PublicAPI(stability="alpha")
 class RayDeepSpeedStrategy(DeepSpeedStrategy):
     """Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration."""

-    def setup_distributed(self):
-        # We have to set the device ids for each node
-        # e.g. CUDA_VISIBLE_DEVICES = 2,3
-        # worker 0: LOCAL_RANK=0, parallel devices = [cuda:0, cuda:1]
-        # worker 1: LOCAL_RANK=1, parallel devices = [cuda:0, cuda:1]
-        self.parallel_devices = [
-            torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())
-        ]
-        super().setup_distributed()
-
     @property
     def root_device(self) -> torch.device:
         return get_worker_root_device()

@@ -123,7 +119,8 @@ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
     )


-class RayEnvironment(LightningEnvironment):
+@PublicAPI(stability="alpha")
+class RayLightningEnvironment(LightningEnvironment):
     """Setup Lightning DDP training environment for Ray cluster."""

     def world_size(self) -> int:
@@ -150,6 +147,58 @@ def teardown(self):
         pass


+@PublicAPI(stability="alpha")
+def prepare_trainer(trainer: pl.Trainer) -> pl.Trainer:
+    """Prepare the PyTorch Lightning Trainer for distributed execution."""
+
+    # Check strategy class
+    valid_strategy_class = [RayDDPStrategy, RayFSDPStrategy, RayDeepSpeedStrategy]
+
+    if not any(isinstance(trainer.strategy, cls) for cls in valid_strategy_class):
+        raise RuntimeError(
+            f"Invalid strategy class: {type(trainer.strategy)}. To use "
+            "PyTorch Lightning with Ray, the strategy object should be one of "
+            f"{[cls.__name__ for cls in valid_strategy_class]} class "
+            "or its subclass."
+        )
+
+    # Check cluster environment
+    cluster_environment = getattr(trainer.strategy, "cluster_environment", None)
+    if cluster_environment and not isinstance(
+        cluster_environment, RayLightningEnvironment
+    ):
+        raise RuntimeError(
+            "Invalid cluster environment plugin. The expected class is "
+            "`ray.train.lightning.RayLightningEnvironment` "
+            f"but got {type(cluster_environment)}!"
+        )
+
+    return trainer
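
To make the validation concrete, a small sketch of what prepare_trainer accepts and rejects (hypothetical snippet, not part of the diff; intended to run inside a Ray Train worker):

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    prepare_trainer,
)

# A stock Lightning strategy is not one of the Ray-aware subclasses, so
# prepare_trainer raises RuntimeError.
bad = pl.Trainer(strategy=DDPStrategy(), plugins=[RayLightningEnvironment()])
try:
    prepare_trainer(bad)
except RuntimeError as err:
    print(err)  # "Invalid strategy class: ..."

# The Ray-aware strategy and environment pass validation; the trainer is
# returned unchanged.
good = pl.Trainer(strategy=RayDDPStrategy(), plugins=[RayLightningEnvironment()])
good = prepare_trainer(good)
```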


+@PublicAPI(stability="alpha")
+class RayTrainReportCallback(Callback):
+    """A simple callback that reports checkpoints to Ray on train epoch end."""
+
+    def on_train_epoch_end(self, trainer, pl_module) -> None:
+        with TemporaryDirectory() as tmpdir:
+            # Fetch metrics
+            metrics = trainer.callback_metrics
+            metrics = {k: v.item() for k, v in metrics.items()}
+
+            # (Optional) Add customized metrics
+            metrics["epoch"] = trainer.current_epoch
+            metrics["steps"] = trainer.global_step
+
+            # Save checkpoint to local
+            ckpt_path = os.path.join(tmpdir, f"ckpt_epoch_{trainer.current_epoch}")
+            trainer.save_checkpoint(ckpt_path, weights_only=False)
+
+            # Report to train session
+            checkpoint = Checkpoint.from_directory(tmpdir)
+            ray.train.report(metrics=metrics, checkpoint=checkpoint)


 class RayIterableDataset(IterableDataset):
     def __init__(self, dataset: "DataIterator", config: Dict[str, Any]) -> None:
         super().__init__()
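
On the consuming side, a sketch of restoring the checkpoint reported by RayTrainReportCallback (hypothetical snippet: it assumes the `result` object and ToyModule from the earlier TorchTrainer sketch, and the file name follows the f"ckpt_epoch_{epoch}" pattern used by the callback above):

```python
import os

# `result` comes from TorchTrainer.fit() in the earlier sketch.
with result.checkpoint.as_directory() as ckpt_dir:
    # Last epoch of a 2-epoch run; the name pattern comes from the callback.
    ckpt_file = os.path.join(ckpt_dir, "ckpt_epoch_1")
    model = ToyModule.load_from_checkpoint(ckpt_file)
```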
2 changes: 2 additions & 0 deletions python/ray/train/tests/lightning_test_utils.py
@@ -2,8 +2,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import pytorch_lightning as pl
+
 from torch.utils.data import DataLoader
 from torchmetrics import Accuracy
+
 from ray import train