From 50f2f3d0f6dd1aa7d18c5ee0e36d37f3f82deb98 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Wed, 11 Oct 2023 10:49:20 -0700 Subject: [PATCH 1/3] init Signed-off-by: woshiyyya --- python/ray/train/lightning/_lightning_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/train/lightning/_lightning_utils.py b/python/ray/train/lightning/_lightning_utils.py index cc11f3a96426f..d23f9beb6626a 100644 --- a/python/ray/train/lightning/_lightning_utils.py +++ b/python/ray/train/lightning/_lightning_utils.py @@ -31,10 +31,10 @@ def import_lightning(): # noqa: F402 _TORCH_GREATER_EQUAL_1_12 = Version(torch.__version__) >= Version("1.12.0") _TORCH_FSDP_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available() -if _LIGHTNING_GREATER_EQUAL_2_0: +try: from lightning.pytorch.plugins.environments import LightningEnvironment from lightning.pytorch.strategies import FSDPStrategy -else: +except ModuleNotFoundError: from pytorch_lightning.plugins.environments import LightningEnvironment from pytorch_lightning.strategies import DDPFullyShardedStrategy as FSDPStrategy From 2d54e972b770ccb1b67f525a14aaafe2f20184c0 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Wed, 11 Oct 2023 12:55:31 -0700 Subject: [PATCH 2/3] dynamically import modules Signed-off-by: woshiyyya --- python/ray/train/lightning/_lightning_utils.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/python/ray/train/lightning/_lightning_utils.py b/python/ray/train/lightning/_lightning_utils.py index d23f9beb6626a..b931b6a345ea6 100644 --- a/python/ray/train/lightning/_lightning_utils.py +++ b/python/ray/train/lightning/_lightning_utils.py @@ -31,12 +31,14 @@ def import_lightning(): # noqa: F402 _TORCH_GREATER_EQUAL_1_12 = Version(torch.__version__) >= Version("1.12.0") _TORCH_FSDP_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available() -try: - from lightning.pytorch.plugins.environments import LightningEnvironment - from lightning.pytorch.strategies import FSDPStrategy -except ModuleNotFoundError: - from pytorch_lightning.plugins.environments import LightningEnvironment - from pytorch_lightning.strategies import DDPFullyShardedStrategy as FSDPStrategy +# Dynamically load lightning modules +exec(f"from {pl.__name__}.plugins.environments import LightningEnvironment") +if _LIGHTNING_GREATER_EQUAL_2_0: + exec(f"from {pl.__name__}.strategies import FSDPStrategy") +else: + exec( + f"from {pl.__name__}.strategies import DDPFullyShardedStrategy as FSDPStrategy" + ) if _TORCH_FSDP_AVAILABLE: from torch.distributed.fsdp import ( @@ -85,7 +87,7 @@ def distributed_sampler_kwargs(self) -> Dict[str, Any]: @PublicAPI(stability="beta") -class RayFSDPStrategy(FSDPStrategy): +class RayFSDPStrategy(FSDPStrategy): # noqa: F821 """Subclass of FSDPStrategy to ensure compatibility with Ray orchestration. For a full list of initialization arguments, please refer to: @@ -152,7 +154,7 @@ def distributed_sampler_kwargs(self) -> Dict[str, Any]: @PublicAPI(stability="beta") -class RayLightningEnvironment(LightningEnvironment): +class RayLightningEnvironment(LightningEnvironment): # noqa: F821 """Setup Lightning DDP training environment for Ray cluster.""" def __init__(self, *args, **kwargs): From 96b4473baa77e8126fe99083050e9ffc0285b817 Mon Sep 17 00:00:00 2001 From: woshiyyya Date: Fri, 13 Oct 2023 14:18:43 -0700 Subject: [PATCH 3/3] change it back Signed-off-by: woshiyyya --- python/ray/train/lightning/_lightning_utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/ray/train/lightning/_lightning_utils.py b/python/ray/train/lightning/_lightning_utils.py index b931b6a345ea6..47311a1f431d5 100644 --- a/python/ray/train/lightning/_lightning_utils.py +++ b/python/ray/train/lightning/_lightning_utils.py @@ -31,14 +31,15 @@ def import_lightning(): # noqa: F402 _TORCH_GREATER_EQUAL_1_12 = Version(torch.__version__) >= Version("1.12.0") _TORCH_FSDP_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available() -# Dynamically load lightning modules -exec(f"from {pl.__name__}.plugins.environments import LightningEnvironment") +try: + from lightning.pytorch.plugins.environments import LightningEnvironment +except ModuleNotFoundError: + from pytorch_lightning.plugins.environments import LightningEnvironment + if _LIGHTNING_GREATER_EQUAL_2_0: - exec(f"from {pl.__name__}.strategies import FSDPStrategy") + FSDPStrategy = pl.strategies.FSDPStrategy else: - exec( - f"from {pl.__name__}.strategies import DDPFullyShardedStrategy as FSDPStrategy" - ) + FSDPStrategy = pl.strategies.DDPFullyShardedStrategy if _TORCH_FSDP_AVAILABLE: from torch.distributed.fsdp import (