[Train] Update PyTorch Lightning import path #39841

Merged
@@ -27,7 +27,7 @@
"## Set up ray cluster \n",
"In this example, we are using a Ray cluster with a `g4dn.8xlarge` head node and 15 `g4dn.4xlarge` worker nodes. Each instance has one Tesla T4 GPU (16GiB Memory). \n",
"\n",
"We define a `runtime_env` to install the necessary Python libraries on each node. You can skip this step if you have already installed all the required packages in your workers' base image. We tested this example with `pytorch_lightning==2.0.2` and `transformers==4.29.2`."
"We define a `runtime_env` to install the necessary Python libraries on each node. You can skip this step if you have already installed all the required packages in your workers' base image. We tested this example with `lightning==2.0.2` and `transformers==4.29.2`."
]
},
{
@@ -47,7 +47,7 @@
" \"evaluate\",\n",
" \"transformers>=4.26.0\",\n",
" \"torch>=1.12.0\",\n",
" \"pytorch_lightning>=2.0\",\n",
" \"lightning>=2.0\",\n",
" ]\n",
" }\n",
")"
@@ -225,7 +225,7 @@
"outputs": [],
"source": [
"import torch\n",
"import pytorch_lightning as pl\n",
"import lightning.pytorch as pl\n",
"\n",
"class DollyV2Model(pl.LightningModule):\n",
" def __init__(self, lr=2e-5, eps=1e-8):\n",
@@ -284,7 +284,7 @@
"outputs": [],
"source": [
"import functools\n",
"import pytorch_lightning as pl \n",
"import lightning.pytorch as pl \n",
"\n",
"from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy\n",
"from torch.distributed.fsdp import ShardingStrategy, BackwardPrefetch\n",
@@ -415,7 +415,7 @@
}
],
"source": [
"from pytorch_lightning.callbacks import TQDMProgressBar\n",
"from lightning.pytorch.callbacks import TQDMProgressBar\n",
"\n",
"# Create a customized progress bar for Ray Data Iterable Dataset\n",
"class DollyV2ProgressBar(TQDMProgressBar):\n",
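
The class body is cut off by the diff view. A representative sketch of such a customized bar (an illustration of the pattern, not the notebook's actual implementation):

from lightning.pytorch.callbacks import TQDMProgressBar

class DollyV2ProgressBar(TQDMProgressBar):
    # A Ray Data iterable dataset has no known epoch length, so give the
    # bar an explicit total instead of letting Lightning infer one.
    def __init__(self, num_iters_per_epoch):
        super().__init__()
        self.num_iters_per_epoch = num_iters_per_epoch

    def init_train_tqdm(self):
        bar = super().init_train_tqdm()
        bar.total = self.num_iters_per_epoch
        return bar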
@@ -116,7 +116,7 @@
" \"deepspeed==0.9.4\",\n",
" \"accelerate==0.20.3\",\n",
" \"transformers==4.30.2\",\n",
" \"pytorch_lightning==2.0.3\",\n",
" \"lightning==2.0.3\",\n",
" ],\n",
" \"env_vars\": {\"RAY_AIR_NEW_PERSISTENCE_MODE\": \"1\"} # TODO(@justinvyu): Remove it\n",
" }\n",
@@ -275,7 +275,7 @@
"source": [
"import torch\n",
"import transformers\n",
"import pytorch_lightning as pl\n",
"import lightning.pytorch as pl\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"from deepspeed.ops.adam import DeepSpeedCPUAdam\n",
"\n",
@@ -1180,7 +1180,7 @@
"source": [
"import torch\n",
"import ray\n",
"import pytorch_lightning as pl\n",
"import lightning.pytorch as pl\n",
"from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM\n",
"from accelerate import (\n",
" init_empty_weights,\n",
22 changes: 13 additions & 9 deletions doc/source/train/getting-started-pytorch-lightning.rst
@@ -46,7 +46,7 @@ Compare a PyTorch Lightning training script with and without Ray Train.
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import lightning.pytorch as pl

# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
@@ -91,7 +91,7 @@ Compare a PyTorch Lightning training script with and without Ray Train.
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import lightning.pytorch as pl

from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
@@ -167,7 +167,7 @@ make a few changes to your Lightning Trainer definition.

.. code-block:: diff

import pytorch_lightning as pl
import lightning.pytorch as pl
-from pl.strategies import DDPStrategy
-from pl.plugins.environments import LightningEnvironment
+import ray.train.lightning
@@ -207,7 +207,7 @@ sampler arguments.

.. code-block:: diff

import pytorch_lightning as pl
import lightning.pytorch as pl
-from pl.strategies import DDPStrategy
+import ray.train.lightning

@@ -231,7 +231,7 @@ local, global, and node rank and world size.

.. code-block:: diff

import pytorch_lightning as pl
import lightning.pytorch as pl
-from pl.plugins.environments import LightningEnvironment
+import ray.train.lightning

@@ -256,7 +256,7 @@ GPUs by setting ``devices="auto"`` and ``accelerator="auto"``.

.. code-block:: diff

import pytorch_lightning as pl
import lightning.pytorch as pl

def train_func(config):
...
@@ -280,7 +280,7 @@ To persist your checkpoints and monitor training progress, add a

.. code-block:: diff

import pytorch_lightning as pl
import lightning.pytorch as pl
from ray.train.lightning import RayTrainReportCallback

def train_func(config):
@@ -306,7 +306,7 @@ your configurations.

.. code-block:: diff

import pytorch_lightning as pl
import lightning.pytorch as pl
import ray.train.lightning

def train_func(config):
@@ -378,6 +378,10 @@ Ray Train is tested with `pytorch_lightning` versions `1.6.5` and `2.0.4`.
Earlier versions aren't prohibited but may result in unexpected issues. If you run into any compatibility issues, consider upgrading your PyTorch Lightning version or
`file an issue <https://github.com/ray-project/ray/issues>`_.

.. note::

If you are using Lightning 2.x, please use the import path `lightning.pytorch.xxx` instead of `pytorch_lightning.xxx`.
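
For example, a minimal version-agnostic import shim (a sketch mirroring the `import_lightning` helper added elsewhere in this PR):

.. code-block:: python

try:
    # Lightning 2.x ships PyTorch Lightning as `lightning.pytorch`.
    import lightning.pytorch as pl
except ModuleNotFoundError:
    # Fall back to the standalone 1.x `pytorch_lightning` package.
    import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1)  # same API under either import path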

.. _lightning-trainer-migration-guide:

LightningTrainer Migration Guide
@@ -447,7 +451,7 @@ control over their native Lightning code.

.. code-block:: python

import pytorch_lightning as pl
import lightning.pytorch as pl
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
RayDDPStrategy,
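
The import list above is truncated by the diff view. For context, a sketch of how these pieces typically compose (`MyLightningModule` and `train_loader` are hypothetical placeholders for your own module and data):

.. code-block:: python

import lightning.pytorch as pl
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)

def train_func(config):
    model = MyLightningModule()
    trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),            # Ray-aware DDP strategy
        plugins=[RayLightningEnvironment()],  # Ray supplies rank/world size
        callbacks=[RayTrainReportCallback()], # report metrics and checkpoints
    )
    trainer = prepare_trainer(trainer)        # validate the Trainer setup
    trainer.fit(model, train_dataloaders=train_loader)

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()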
13 changes: 8 additions & 5 deletions python/ray/train/lightning/__init__.py
@@ -1,11 +1,14 @@
# isort: off
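# NOTE: After this change, both the unified `lightning` package (2.x) and the
# standalone `pytorch_lightning` distribution must be importable; each guard
# below names the package whose installation fixes the failure.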
try:
import pytorch_lightning # noqa: F401
import lightning # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"PyTorch Lightning isn't installed. To install PyTorch Lightning, "
"please run 'pip install pytorch-lightning'"
)
try:
import pytorch_lightning # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"PyTorch Lightning isn't installed. To install PyTorch Lightning, "
"please run 'pip install lightning'"
)
# isort: on

from ray.train.lightning.lightning_trainer import (
48 changes: 29 additions & 19 deletions python/ray/train/lightning/_lightning_utils.py
@@ -1,38 +1,43 @@
import os
import logging
import ray
import shutil
import tempfile
import torch

from ray import train
from ray.air.constants import MODEL_KEY
from ray.data.dataset import DataIterator
from ray.util import PublicAPI
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag

import logging
import shutil
import torch
import tempfile
from ray.train import Checkpoint
from ray.train._internal.storage import _use_storage_context
from ray.train.lightning.lightning_checkpoint import (
LightningCheckpoint,
LegacyLightningCheckpoint,
)

from packaging.version import Version
from typing import Any, Dict, Optional
from torch.utils.data import IterableDataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy

def import_lightning(): # noqa: F402
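# Prefer the Lightning 2.x import path; fall back to the standalone
# `pytorch_lightning` package so either install flavor works.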
try:
import lightning.pytorch as pl
except ModuleNotFoundError:
import pytorch_lightning as pl
return pl


pl = import_lightning()

_LIGHTNING_GREATER_EQUAL_2_0 = Version(pl.__version__) >= Version("2.0.0")
_TORCH_GREATER_EQUAL_1_12 = Version(torch.__version__) >= Version("1.12.0")
_TORCH_FSDP_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available()

if _LIGHTNING_GREATER_EQUAL_2_0:
from pytorch_lightning.strategies import FSDPStrategy
from lightning.pytorch.strategies import FSDPStrategy
from lightning.pytorch.plugins.environments import LightningEnvironment
else:
from pytorch_lightning.strategies import DDPFullyShardedStrategy as FSDPStrategy
from pytorch_lightning.plugins.environments import LightningEnvironment

if _TORCH_FSDP_AVAILABLE:
from torch.distributed.fsdp import (
@@ -57,7 +62,7 @@ def get_worker_root_device():


@PublicAPI(stability="beta")
class RayDDPStrategy(DDPStrategy):
class RayDDPStrategy(pl.strategies.DDPStrategy):
"""Subclass of DDPStrategy to ensure compatibility with Ray orchestration.

For a full list of initialization arguments, please refer to:
@@ -124,7 +129,7 @@ def lightning_module_state_dict(self) -> Dict[str, Any]:


@PublicAPI(stability="beta")
class RayDeepSpeedStrategy(DeepSpeedStrategy):
class RayDeepSpeedStrategy(pl.strategies.DeepSpeedStrategy):
"""Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration.

For a full list of initialization arguments, please refer to:
@@ -210,7 +215,7 @@ def prepare_trainer(trainer: pl.Trainer) -> pl.Trainer:


@PublicAPI(stability="beta")
class RayTrainReportCallback(Callback):
class RayTrainReportCallback(pl.callbacks.Callback):
"""A simple callback that reports checkpoints to Ray on train epoch end."""

def __init__(self) -> None:
@@ -288,7 +293,7 @@ def _val_dataloader() -> DataLoader:
self.val_dataloader = _val_dataloader


class RayModelCheckpoint(ModelCheckpoint):
class RayModelCheckpoint(pl.callbacks.ModelCheckpoint):
"""
AIR customized ModelCheckpoint callback.

@@ -306,7 +311,7 @@ def setup(
super().setup(trainer, pl_module, stage)
self.is_checkpoint_step = False

if isinstance(trainer.strategy, DeepSpeedStrategy):
if isinstance(trainer.strategy, pl.strategies.DeepSpeedStrategy):
# For DeepSpeed, each node has a unique set of param and optimizer states,
# so the local rank 0 workers report the checkpoint shards for all workers
# on their node.
@@ -324,6 +329,11 @@ def _session_report(self, trainer: "pl.Trainer", stage: str):
the latest metrics.
"""

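# Imported here rather than at module top level to avoid a circular import:
# lightning_checkpoint.py imports `import_lightning` from this module.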
from ray.train.lightning.lightning_checkpoint import (
LightningCheckpoint,
LegacyLightningCheckpoint,
)

# Align the frequency of checkpointing and logging
if not self.is_checkpoint_step:
return
4 changes: 3 additions & 1 deletion python/ray/train/lightning/lightning_checkpoint.py
@@ -1,6 +1,5 @@
import os
import logging
import pytorch_lightning as pl
import tempfile
import shutil

@@ -13,6 +12,9 @@
from ray.train._internal.framework_checkpoint import FrameworkCheckpoint
from ray.train.torch import LegacyTorchCheckpoint
from ray.util.annotations import PublicAPI
from ray.train.lightning._lightning_utils import import_lightning

pl = import_lightning()

logger = logging.getLogger(__name__)

5 changes: 4 additions & 1 deletion python/ray/train/lightning/lightning_predictor.py
@@ -3,9 +3,12 @@

from ray.data.preprocessor import Preprocessor
from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
from ray.train.lightning._lightning_utils import import_lightning
from ray.train.torch.torch_predictor import TorchPredictor
from ray.util.annotations import PublicAPI
import pytorch_lightning as pl

pl = import_lightning()


logger = logging.getLogger(__name__)

5 changes: 3 additions & 2 deletions python/ray/train/lightning/lightning_trainer.py
@@ -1,5 +1,4 @@
import os
import pytorch_lightning as pl

from copy import copy
from inspect import isclass
@@ -23,13 +22,15 @@
RayDataModule,
RayModelCheckpoint,
prepare_trainer,
import_lightning,
)


import logging

logger = logging.getLogger(__name__)

pl = import_lightning()


LIGHTNING_CONFIG_BUILDER_DEPRECATION_MESSAGE = (
"The LightningConfigBuilder will be hard deprecated in Ray 2.8. "
4 changes: 3 additions & 1 deletion python/ray/train/tests/lightning_test_utils.py
@@ -1,12 +1,14 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from torch.utils.data import DataLoader
from torchmetrics import Accuracy

from ray import train
from ray.train.lightning._lightning_utils import import_lightning

pl = import_lightning()


class LinearModule(pl.LightningModule):
4 changes: 3 additions & 1 deletion python/ray/train/tests/test_lightning_checkpoint.py
@@ -1,4 +1,3 @@
import pytorch_lightning as pl
import torch
import torch.nn as nn
import tempfile
@@ -12,6 +11,9 @@
LightningConfigBuilder,
LightningTrainer,
)
from ray.train.lightning._lightning_utils import import_lightning

pl = import_lightning()


class Net(pl.LightningModule):