
[Train] Update PyTorch Lightning import path #39841

Merged
13 changes: 8 additions & 5 deletions python/ray/train/lightning/__init__.py
@@ -1,11 +1,14 @@
# isort: off
try:
    import pytorch_lightning  # noqa: F401
    import lightning  # noqa: F401
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        "PyTorch Lightning isn't installed. To install PyTorch Lightning, "
        "please run 'pip install pytorch-lightning'"
    )
try:
    import pytorch_lightning  # noqa: F401
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        "PyTorch Lightning isn't installed. To install PyTorch Lightning, "
        "please run 'pip install lightning'"
    )
# isort: on

from ray.train.lightning.lightning_trainer import (
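The same fallback import is applied to every module touched below. As a hedged, minimal sketch of the pattern the PR standardizes on (illustration only, not itself part of the diff):

# Prefer the unified `lightning` package (Lightning 2.x layout), and fall back
# to the standalone `pytorch_lightning` distribution when it isn't installed.
try:
    import lightning.pytorch as pl
except ModuleNotFoundError:
    import pytorch_lightning as pl

# Either import path exposes the same `pl` namespace, e.g. pl.Trainer,
# pl.callbacks.ModelCheckpoint, and pl.strategies.DDPStrategy.
print(pl.__version__)

Keeping a single `pl` alias means downstream references such as `pl.strategies.DeepSpeedStrategy` work with either distribution.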
22 changes: 12 additions & 10 deletions python/ray/train/lightning/_lightning_utils.py
@@ -20,19 +20,21 @@
from typing import Any, Dict, Optional
from torch.utils.data import IterableDataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
try:
    import lightning.pytorch as pl
except ModuleNotFoundError:
    import pytorch_lightning as pl

_LIGHTNING_GREATER_EQUAL_2_0 = Version(pl.__version__) >= Version("2.0.0")
_TORCH_GREATER_EQUAL_1_12 = Version(torch.__version__) >= Version("1.12.0")
_TORCH_FSDP_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available()

if _LIGHTNING_GREATER_EQUAL_2_0:
    from pytorch_lightning.strategies import FSDPStrategy
    from lightning.pytorch.strategies import FSDPStrategy
    from lightning.pytorch.plugins.environments import LightningEnvironment
else:
    from pytorch_lightning.strategies import DDPFullyShardedStrategy as FSDPStrategy
    from pytorch_lightning.plugins.environments import LightningEnvironment

if _TORCH_FSDP_AVAILABLE:
    from torch.distributed.fsdp import (
@@ -57,7 +59,7 @@ def get_worker_root_device():


@PublicAPI(stability="beta")
class RayDDPStrategy(DDPStrategy):
class RayDDPStrategy(pl.strategies.DDPStrategy):
"""Subclass of DDPStrategy to ensure compatibility with Ray orchestration.

For a full list of initialization arguments, please refer to:
@@ -124,7 +126,7 @@ def lightning_module_state_dict(self) -> Dict[str, Any]:


@PublicAPI(stability="beta")
class RayDeepSpeedStrategy(DeepSpeedStrategy):
class RayDeepSpeedStrategy(pl.strategies.DeepSpeedStrategy):
"""Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration.

For a full list of initialization arguments, please refer to:
@@ -210,7 +212,7 @@ def prepare_trainer(trainer: pl.Trainer) -> pl.Trainer:


@PublicAPI(stability="beta")
class RayTrainReportCallback(Callback):
class RayTrainReportCallback(pl.callbacks.Callback):
"""A simple callback that reports checkpoints to Ray on train epoch end."""

    def __init__(self) -> None:
@@ -288,7 +290,7 @@ def _val_dataloader() -> DataLoader:
        self.val_dataloader = _val_dataloader


class RayModelCheckpoint(ModelCheckpoint):
class RayModelCheckpoint(pl.callbacks.ModelCheckpoint):
"""
AIR customized ModelCheckpoint callback.

@@ -306,7 +308,7 @@ def setup(
        super().setup(trainer, pl_module, stage)
        self.is_checkpoint_step = False

        if isinstance(trainer.strategy, DeepSpeedStrategy):
        if isinstance(trainer.strategy, pl.strategies.DeepSpeedStrategy):
            # For DeepSpeed, each node has a unique set of param and optimizer states,
            # so the local rank 0 workers report the checkpoint shards for all workers
            # on their node.
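For context, a rough usage sketch of the classes this file modifies, wired into a Ray Train worker function. The ToyModule model, toy dataset, and worker count are placeholder assumptions for illustration, not part of the PR:

import torch
from torch.utils.data import DataLoader, TensorDataset

try:
    import lightning.pytorch as pl
except ModuleNotFoundError:
    import pytorch_lightning as pl

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)


class ToyModule(pl.LightningModule):
    # Placeholder model: a single linear layer trained with an MSE loss.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


def train_func():
    dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
    trainer = pl.Trainer(
        max_epochs=1,
        accelerator="cpu",
        strategy=RayDDPStrategy(),             # DDP subclass updated in this file
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],  # reports checkpoints back to Ray
        enable_checkpointing=False,
    )
    trainer = prepare_trainer(trainer)         # validates the Ray-specific setup
    trainer.fit(ToyModule(), train_dataloaders=DataLoader(dataset, batch_size=8))


# Run the training function on two Ray workers.
TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=2)).fit()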
7 changes: 6 additions & 1 deletion python/ray/train/lightning/lightning_predictor.py
@@ -5,7 +5,12 @@
from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
from ray.train.torch.torch_predictor import TorchPredictor
from ray.util.annotations import PublicAPI
import pytorch_lightning as pl

try:
    import lightning.pytorch as pl
except ModuleNotFoundError:
    import pytorch_lightning as pl


logger = logging.getLogger(__name__)

7 changes: 6 additions & 1 deletion python/ray/train/lightning/lightning_trainer.py
@@ -1,5 +1,10 @@
import os
import pytorch_lightning as pl

try:
    import lightning.pytorch as pl
except ModuleNotFoundError:
    import pytorch_lightning as pl


from copy import copy
from inspect import isclass
6 changes: 5 additions & 1 deletion python/ray/train/tests/lightning_test_utils.py
@@ -1,7 +1,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

try:
    import lightning.pytorch as pl
except ModuleNotFoundError:
    import pytorch_lightning as pl

from torch.utils.data import DataLoader
from torchmetrics import Accuracy
6 changes: 5 additions & 1 deletion python/ray/train/tests/test_lightning_checkpoint.py
@@ -1,4 +1,8 @@
import pytorch_lightning as pl
try:
    import lightning.pytorch as pl
except ModuleNotFoundError:
    import pytorch_lightning as pl

import torch
import torch.nn as nn
import tempfile
7 changes: 6 additions & 1 deletion python/ray/train/tests/test_lightning_predictor.py
@@ -3,9 +3,14 @@
import torch

import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader

try:
    import lightning.pytorch as pl
except ModuleNotFoundError:
    import pytorch_lightning as pl


from ray.air.constants import MAX_REPR_LENGTH, MODEL_KEY
from ray.train.tests.conftest import * # noqa
from ray.train.lightning import LightningCheckpoint, LightningPredictor
9 changes: 6 additions & 3 deletions python/ray/train/tests/test_lightning_trainer_restore.py
@@ -2,8 +2,11 @@
import numpy as np
from pathlib import Path
import pytest
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

try:
    import lightning.pytorch as pl
except ModuleNotFoundError:
    import pytorch_lightning as pl

import ray
from ray.train import RunConfig, CheckpointConfig
@@ -147,7 +150,7 @@ def test_air_trainer_restore(

    if resume_from_ckpt_path:
        ckpt_dir = f"{tmpdir}/ckpts"
        callback = ModelCheckpoint(dirpath=ckpt_dir, save_last=True)
        callback = pl.callbacks.ModelCheckpoint(dirpath=ckpt_dir, save_last=True)
        pl_trainer = pl.Trainer(
            max_epochs=init_epoch, accelerator="cpu", callbacks=[callback]
        )
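Outside of Ray, the pattern this restore test exercises is ordinary Lightning checkpoint-and-resume. A small illustrative sketch, with a made-up TinyModule, toy data, and an arbitrary /tmp path (none of these appear in the PR):

import torch
from torch.utils.data import DataLoader, TensorDataset

try:
    import lightning.pytorch as pl
except ModuleNotFoundError:
    import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    # Minimal placeholder model, just enough for the checkpoint round trip.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


loader = DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=4)

# Write a "last" checkpoint with the stock ModelCheckpoint callback ...
ckpt_dir = "/tmp/ckpts"
callback = pl.callbacks.ModelCheckpoint(dirpath=ckpt_dir, save_last=True)
pl.Trainer(max_epochs=1, accelerator="cpu", callbacks=[callback]).fit(
    TinyModule(), train_dataloaders=loader
)

# ... then resume training from it in a fresh Trainer via `ckpt_path`.
pl.Trainer(max_epochs=2, accelerator="cpu").fit(
    TinyModule(), train_dataloaders=loader, ckpt_path=f"{ckpt_dir}/last.ckpt"
)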
6 changes: 5 additions & 1 deletion python/ray/train/tests/test_torch_lightning_train.py
@@ -17,7 +17,11 @@
    LinearModule,
    DummyDataModule,
)
import pytorch_lightning as pl

try:
    import lightning.pytorch as pl
except ModuleNotFoundError:
    import pytorch_lightning as pl


@pytest.fixture