[Train] Update PyTorch Lightning import path #39841

Merged
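
This PR moves Ray Train's Lightning integration, docs, and examples from the standalone `pytorch_lightning` namespace to the `lightning.pytorch` namespace that ships with the `lightning` 2.x package, falling back to `pytorch_lightning` where only the standalone package is installed. A minimal sketch of the pattern the PR adopts (the real helper, `import_lightning()`, appears in `_lightning_utils.py` below):

```python
# In Lightning 2.x the canonical import path is `lightning.pytorch`; the
# standalone `pytorch_lightning` distribution exposes the same API, so we
# resolve whichever one is installed.
try:
    import lightning.pytorch as pl  # preferred with lightning>=2.0
except ModuleNotFoundError:
    import pytorch_lightning as pl  # standalone fallback

model_cls = pl.LightningModule  # identical surface under either namespace
```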
@@ -27,7 +27,7 @@
"## Set up ray cluster \n",
"In this example, we are using a Ray cluster with a `g4dn.8xlarge` head node and 15 `g4dn.4xlarge` worker nodes. Each instance has one Tesla T4 GPU (16GiB Memory). \n",
"\n",
"We define a `runtime_env` to install the necessary Python libraries on each node. You can skip this step if you have already installed all the required packages in your workers' base image. We tested this example with `pytorch_lightning==2.0.2` and `transformers==4.29.2`."
"We define a `runtime_env` to install the necessary Python libraries on each node. You can skip this step if you have already installed all the required packages in your workers' base image. We tested this example with `lightning==2.0.2` and `transformers==4.29.2`."
]
},
{
@@ -47,7 +47,7 @@
" \"evaluate\",\n",
" \"transformers>=4.26.0\",\n",
" \"torch>=1.12.0\",\n",
" \"pytorch_lightning>=2.0\",\n",
" \"lightning>=2.0\",\n",
" ]\n",
" }\n",
")"
@@ -225,7 +225,7 @@
"outputs": [],
"source": [
"import torch\n",
"import pytorch_lightning as pl\n",
"import lightning.pytorch as pl\n",
"\n",
"class DollyV2Model(pl.LightningModule):\n",
" def __init__(self, lr=2e-5, eps=1e-8):\n",
@@ -284,7 +284,7 @@
"outputs": [],
"source": [
"import functools\n",
"import pytorch_lightning as pl \n",
"import lightning.pytorch as pl \n",
"\n",
"from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy\n",
"from torch.distributed.fsdp import ShardingStrategy, BackwardPrefetch\n",
@@ -415,7 +415,7 @@
}
],
"source": [
"from pytorch_lightning.callbacks import TQDMProgressBar\n",
"from lightning.pytorch.callbacks import TQDMProgressBar\n",
"\n",
"# Create a customized progress bar for Ray Data Iterable Dataset\n",
"class DollyV2ProgressBar(TQDMProgressBar):\n",
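
The body of `DollyV2ProgressBar` is collapsed above. A minimal hypothetical version of the idea (names and internals are assumptions, not the notebook's actual code): Ray Data iterable datasets expose no `len()`, so the bar needs an explicit total.

```python
from lightning.pytorch.callbacks import TQDMProgressBar

class IterableDatasetProgressBar(TQDMProgressBar):
    """Progress bar with a fixed total for datasets that report no length."""

    def __init__(self, num_steps_per_epoch: int):
        super().__init__()
        self._num_steps = num_steps_per_epoch

    def init_train_tqdm(self):
        bar = super().init_train_tqdm()
        bar.total = self._num_steps  # otherwise TQDM shows an unknown total
        return bar
```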
@@ -116,7 +116,7 @@
" \"deepspeed==0.9.4\",\n",
" \"accelerate==0.20.3\",\n",
" \"transformers==4.30.2\",\n",
" \"pytorch_lightning==2.0.3\",\n",
" \"lightning==2.0.3\",\n",
" ],\n",
" \"env_vars\": {\"RAY_AIR_NEW_PERSISTENCE_MODE\": \"1\"} # TODO(@justinvyu): Remove it\n",
" }\n",
@@ -275,7 +275,7 @@
"source": [
"import torch\n",
"import transformers\n",
"import pytorch_lightning as pl\n",
"import lightning.pytorch as pl\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"from deepspeed.ops.adam import DeepSpeedCPUAdam\n",
"\n",
@@ -1180,7 +1180,7 @@
"source": [
"import torch\n",
"import ray\n",
"import pytorch_lightning as pl\n",
"import lightning.pytorch as pl\n",
"from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM\n",
"from accelerate import (\n",
" init_empty_weights,\n",
22 changes: 13 additions & 9 deletions doc/source/train/getting-started-pytorch-lightning.rst
@@ -46,7 +46,7 @@ Compare a PyTorch Lightning training script with and without Ray Train.
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
-import pytorch_lightning as pl
+import lightning.pytorch as pl

# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
@@ -91,7 +91,7 @@ Compare a PyTorch Lightning training script with and without Ray Train.
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
-import pytorch_lightning as pl
+import lightning.pytorch as pl

from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
@@ -167,7 +167,7 @@ make a few changes to your Lightning Trainer definition.

.. code-block:: diff

-import pytorch_lightning as pl
+import lightning.pytorch as pl
-from pl.strategies import DDPStrategy
-from pl.plugins.environments import LightningEnvironment
+import ray.train.lightning
@@ -207,7 +207,7 @@ sampler arguments.

.. code-block:: diff

-import pytorch_lightning as pl
+import lightning.pytorch as pl
-from pl.strategies import DDPStrategy
+import ray.train.lightning

@@ -231,7 +231,7 @@ local, global, and node rank and world size.

.. code-block:: diff

-import pytorch_lightning as pl
+import lightning.pytorch as pl
-from pl.plugins.environments import LightningEnvironment
+import ray.train.lightning

@@ -256,7 +256,7 @@ GPUs by setting ``devices="auto"`` and ``accelerator="auto"``.

.. code-block:: diff

-import pytorch_lightning as pl
+import lightning.pytorch as pl

def train_func(config):
...
@@ -280,7 +280,7 @@ To persist your checkpoints and monitor training progress, add a

.. code-block:: diff

-import pytorch_lightning as pl
+import lightning.pytorch as pl
from ray.train.lightning import RayTrainReportCallback

def train_func(config):
@@ -306,7 +306,7 @@ your configurations.

.. code-block:: diff

-import pytorch_lightning as pl
+import lightning.pytorch as pl
import ray.train.lightning

def train_func(config):
@@ -378,6 +378,10 @@ Ray Train is tested with `pytorch_lightning` versions `1.6.5` and `2.0.4`. For f
Earlier versions aren't prohibited but may result in unexpected issues. If you run into any compatibility issues, consider upgrading your PyTorch Lightning version or
`file an issue <https://github.com/ray-project/ray/issues>`_.

+.. note::
+
+    If you are using Lightning 2.x, please use the import path `lightning.pytorch.xxx` instead of `pytorch_lightning.xxx`.
+
.. _lightning-trainer-migration-guide:

LightningTrainer Migration Guide
@@ -447,7 +451,7 @@ control over their native Lightning code.

.. code-block:: python

-import pytorch_lightning as pl
+import lightning.pytorch as pl
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
RayDDPStrategy,
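
Condensing the migration steps documented above into a single sketch (`MyLightningModule` and `MyDataModule` are placeholders; the Ray Train names are the ones introduced in this file):

```python
import lightning.pytorch as pl

from ray.train import ScalingConfig
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)
from ray.train.torch import TorchTrainer

def train_func(config):
    model = MyLightningModule()  # placeholder LightningModule
    datamodule = MyDataModule()  # placeholder LightningDataModule
    trainer = pl.Trainer(
        max_epochs=config["max_epochs"],
        devices="auto",                        # Ray assigns one device per worker
        accelerator="auto",
        strategy=RayDDPStrategy(),             # replaces DDPStrategy
        plugins=[RayLightningEnvironment()],   # replaces LightningEnvironment
        callbacks=[RayTrainReportCallback()],  # reports metrics and checkpoints to Ray
    )
    trainer = prepare_trainer(trainer)
    trainer.fit(model, datamodule=datamodule)

trainer = TorchTrainer(
    train_func,
    train_loop_config={"max_epochs": 4},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
```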
13 changes: 8 additions & 5 deletions python/ray/train/lightning/__init__.py
@@ -1,11 +1,14 @@
# isort: off
try:
-    import pytorch_lightning  # noqa: F401
+    import lightning  # noqa: F401
except ModuleNotFoundError:
-    raise ModuleNotFoundError(
-        "PyTorch Lightning isn't installed. To install PyTorch Lightning, "
-        "please run 'pip install pytorch-lightning'"
-    )
+    try:
+        import pytorch_lightning  # noqa: F401
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError(
+            "PyTorch Lightning isn't installed. To install PyTorch Lightning, "
+            "please run 'pip install lightning'"
+        )
# isort: on

from ray.train.lightning._lightning_utils import (
32 changes: 21 additions & 11 deletions python/ray/train/lightning/_lightning_utils.py
@@ -4,12 +4,8 @@
import tempfile
from typing import Any, Dict, Optional

-import pytorch_lightning as pl
import torch
from packaging.version import Version
-from pytorch_lightning.callbacks import Callback, ModelCheckpoint
-from pytorch_lightning.plugins.environments import LightningEnvironment
-from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
from torch.utils.data import DataLoader, IterableDataset

import ray
@@ -18,16 +14,28 @@
from ray.air.constants import MODEL_KEY
from ray.data.dataset import DataIterator
from ray.train import Checkpoint
-from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
from ray.util import PublicAPI


+def import_lightning():  # noqa: F402
+    try:
+        import lightning.pytorch as pl
+    except ModuleNotFoundError:
+        import pytorch_lightning as pl
+    return pl


+pl = import_lightning()

_LIGHTNING_GREATER_EQUAL_2_0 = Version(pl.__version__) >= Version("2.0.0")
_TORCH_GREATER_EQUAL_1_12 = Version(torch.__version__) >= Version("1.12.0")
_TORCH_FSDP_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available()

if _LIGHTNING_GREATER_EQUAL_2_0:
-    from pytorch_lightning.strategies import FSDPStrategy
+    from lightning.pytorch.plugins.environments import LightningEnvironment
+    from lightning.pytorch.strategies import FSDPStrategy
else:
+    from pytorch_lightning.plugins.environments import LightningEnvironment
    from pytorch_lightning.strategies import DDPFullyShardedStrategy as FSDPStrategy

if _TORCH_FSDP_AVAILABLE:
@@ -53,7 +61,7 @@ def get_worker_root_device():


@PublicAPI(stability="beta")
-class RayDDPStrategy(DDPStrategy):
+class RayDDPStrategy(pl.strategies.DDPStrategy):
"""Subclass of DDPStrategy to ensure compatibility with Ray orchestration.

For a full list of initialization arguments, please refer to:
@@ -120,7 +128,7 @@ def lightning_module_state_dict(self) -> Dict[str, Any]:


@PublicAPI(stability="beta")
-class RayDeepSpeedStrategy(DeepSpeedStrategy):
+class RayDeepSpeedStrategy(pl.strategies.DeepSpeedStrategy):
"""Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration.

For a full list of initialization arguments, please refer to:
@@ -206,7 +214,7 @@ def prepare_trainer(trainer: pl.Trainer) -> pl.Trainer:


@PublicAPI(stability="beta")
-class RayTrainReportCallback(Callback):
+class RayTrainReportCallback(pl.callbacks.Callback):
"""A simple callback that reports checkpoints to Ray on train epoch end."""

def __init__(self) -> None:
@@ -284,7 +292,7 @@ def _val_dataloader() -> DataLoader:
self.val_dataloader = _val_dataloader


-class RayModelCheckpoint(ModelCheckpoint):
+class RayModelCheckpoint(pl.callbacks.ModelCheckpoint):
"""
AIR customized ModelCheckpoint callback.

@@ -302,7 +310,7 @@ def setup(
super().setup(trainer, pl_module, stage)
self.is_checkpoint_step = False

-if isinstance(trainer.strategy, DeepSpeedStrategy):
+if isinstance(trainer.strategy, pl.strategies.DeepSpeedStrategy):
# For DeepSpeed, each node has a unique set of param and optimizer states,
# so the local rank 0 workers report the checkpoint shards for all workers
# on their node.
@@ -320,6 +328,8 @@ def _session_report(self, trainer: "pl.Trainer", stage: str):
the latest metrics.
"""

+from ray.train.lightning.lightning_checkpoint import LightningCheckpoint

# Align the frequency of checkpointing and logging
if not self.is_checkpoint_step:
return
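
Two patterns in this file are worth noting. First, with `pl` resolved once by `import_lightning()`, base classes are referenced as attributes (`pl.strategies.DDPStrategy`, `pl.callbacks.Callback`) so the same subclass works against whichever namespace was imported. Second, the `LightningCheckpoint` import moves from module level into `_session_report`; deferring it this way is the standard fix for an import cycle, since `lightning_checkpoint.py` now imports `import_lightning` from this module (the cycle is inferred from the diff, not stated in it). A condensed sketch of both:

```python
# Condensed illustration, not Ray's actual classes.
from ray.train.lightning._lightning_utils import import_lightning

pl = import_lightning()  # lightning.pytorch if available, else pytorch_lightning

class MyDDPStrategy(pl.strategies.DDPStrategy):
    """Attribute access binds the base class to the resolved namespace."""

def report_checkpoint():
    # Deferred import: resolved at call time rather than module import time,
    # which breaks the circular dependency between the two modules.
    from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
    ...
```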
5 changes: 3 additions & 2 deletions python/ray/train/lightning/lightning_checkpoint.py
@@ -5,13 +5,14 @@
from inspect import isclass
from typing import Any, Dict, Optional, Type

-import pytorch_lightning as pl

from ray.air.constants import MODEL_KEY
from ray.data import Preprocessor
from ray.train._internal.framework_checkpoint import FrameworkCheckpoint
+from ray.train.lightning._lightning_utils import import_lightning
from ray.util.annotations import PublicAPI

+pl = import_lightning()

logger = logging.getLogger(__name__)


6 changes: 4 additions & 2 deletions python/ray/train/lightning/lightning_predictor.py
@@ -1,13 +1,15 @@
import logging
from typing import Optional, Type

-import pytorch_lightning as pl

from ray.data.preprocessor import Preprocessor
+from ray.train.lightning._lightning_utils import import_lightning
from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
from ray.train.torch.torch_predictor import TorchPredictor
from ray.util.annotations import PublicAPI

+pl = import_lightning()


logger = logging.getLogger(__name__)


5 changes: 3 additions & 2 deletions python/ray/train/lightning/lightning_trainer.py
@@ -4,8 +4,6 @@
from inspect import isclass
from typing import Any, Dict, Optional, Type

-import pytorch_lightning as pl

from ray.air import session
from ray.air.constants import MODEL_KEY
from ray.data.preprocessor import Preprocessor
@@ -17,6 +15,7 @@
RayFSDPStrategy,
RayLightningEnvironment,
RayModelCheckpoint,
+import_lightning,
prepare_trainer,
)
from ray.train.torch import TorchTrainer
@@ -26,6 +25,8 @@

logger = logging.getLogger(__name__)

+pl = import_lightning()


LIGHTNING_CONFIG_BUILDER_DEPRECATION_MESSAGE = (
"The LightningConfigBuilder will be hard deprecated in Ray 2.8. "
4 changes: 3 additions & 1 deletion python/ray/train/tests/lightning_test_utils.py
@@ -1,11 +1,13 @@
-import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchmetrics import Accuracy

from ray import train
+from ray.train.lightning._lightning_utils import import_lightning

+pl = import_lightning()


class LinearModule(pl.LightningModule):
4 changes: 3 additions & 1 deletion python/ray/train/tests/test_lightning_checkpoint.py
@@ -1,6 +1,5 @@
import tempfile

-import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
@@ -12,8 +11,11 @@
LightningConfigBuilder,
LightningTrainer,
)
+from ray.train.lightning._lightning_utils import import_lightning
from ray.train.tests.lightning_test_utils import DummyDataModule, LinearModule

+pl = import_lightning()


class Net(pl.LightningModule):
def __init__(self, input_dim, output_dim) -> None:
4 changes: 3 additions & 1 deletion python/ray/train/tests/test_lightning_predictor.py
@@ -2,16 +2,18 @@

import numpy as np
import pytest
-import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from ray.air.constants import MAX_REPR_LENGTH, MODEL_KEY
from ray.train.lightning import LightningCheckpoint, LightningPredictor
+from ray.train.lightning._lightning_utils import import_lightning
from ray.train.tests.conftest import * # noqa
from ray.train.tests.dummy_preprocessor import DummyPreprocessor
from ray.train.tests.lightning_test_utils import LightningMNISTClassifier

+pl = import_lightning()


def test_repr():
model = pl.LightningModule()