diff --git a/python/ray/train/lightning/_lightning_utils.py b/python/ray/train/lightning/_lightning_utils.py
index 4d2c90987b3dd..1495fdcf21dce 100644
--- a/python/ray/train/lightning/_lightning_utils.py
+++ b/python/ray/train/lightning/_lightning_utils.py
@@ -1,26 +1,38 @@
+import ray
+from ray.air import session
+from ray.air.constants import MODEL_KEY
+from ray.data.datastream import DataIterator
+from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
+
 import logging
 import shutil
 import torch
 import tempfile
 from packaging.version import Version
 from typing import Any, Dict, Optional
+from torch.utils.data import IterableDataset, DataLoader
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.plugins.environments import LightningEnvironment
 from pytorch_lightning.strategies import DDPStrategy
 
-if Version(pl.__version__) >= Version("2.0.0"):
+_LIGHTNING_GREATER_EQUAL_2_0 = Version(pl.__version__) >= Version("2.0.0")
+_TORCH_GREATER_EQUAL_1_12 = Version(torch.__version__) >= Version("1.12.0")
+_TORCH_FSDP_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available()
+
+if _LIGHTNING_GREATER_EQUAL_2_0:
     from pytorch_lightning.strategies import FSDPStrategy
 else:
     from pytorch_lightning.strategies import DDPFullyShardedStrategy as FSDPStrategy
 
-import ray
-from ray.air import session
-from ray.air.constants import MODEL_KEY
-from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
-from torch.utils.data import IterableDataset, DataLoader
-from ray.data.datastream import DataIterator
+if _TORCH_FSDP_AVAILABLE:
+    from torch.distributed.fsdp import (
+        FullStateDictConfig,
+        FullyShardedDataParallel,
+        StateDictType,
+    )
+
 
 logger = logging.getLogger(__name__)
 
@@ -65,6 +77,25 @@ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
             rank=self.global_rank,
         )
 
+    def lightning_module_state_dict(self) -> Dict[str, Any]:
+        """Gathers the full state dict to rank 0 on CPU."""
+        assert self.model is not None, "Failed to get the state dict for a None model!"
+
+        if _LIGHTNING_GREATER_EQUAL_2_0 and _TORCH_FSDP_AVAILABLE:
+            with FullyShardedDataParallel.state_dict_type(
+                module=self.model,
+                state_dict_type=StateDictType.FULL_STATE_DICT,
+                state_dict_config=FullStateDictConfig(
+                    offload_to_cpu=True, rank0_only=True
+                ),
+            ):
+                state_dict = self.model.state_dict()
+                prefix_len = len("_forward_module.")
+                return {k[prefix_len:]: v for k, v in state_dict.items()}
+        else:
+            # Otherwise Lightning uses FairScale FSDP; no need to unshard ourselves.
+            return super().lightning_module_state_dict()
+
 
 class RayEnvironment(LightningEnvironment):
     """Setup Lightning DDP training environment for Ray cluster."""
diff --git a/python/ray/train/tests/test_lightning_checkpoint.py b/python/ray/train/tests/test_lightning_checkpoint.py
index 64bcd40b32bec..5109fb0a051b3 100644
--- a/python/ray/train/tests/test_lightning_checkpoint.py
+++ b/python/ray/train/tests/test_lightning_checkpoint.py
@@ -4,9 +4,15 @@
 import torch.nn as nn
 import tempfile
 
-from ray.train.lightning import LightningCheckpoint
+import ray
 from ray.air.constants import MODEL_KEY
 from torch.utils.data import DataLoader
+from ray.train.tests.lightning_test_utils import LinearModule, DummyDataModule
+from ray.train.lightning import (
+    LightningCheckpoint,
+    LightningConfigBuilder,
+    LightningTrainer,
+)
 
 
 class Net(pl.LightningModule):
@@ -100,6 +106,42 @@ def test_from_directory():
     assert torch.equal(output, checkpoint_output)
 
 
+def test_fsdp_checkpoint():
+    num_epochs = 1
+    batch_size = 8
+    input_dim = 32
+    output_dim = 4
+    dataset_size = 256
+
+    datamodule = DummyDataModule(batch_size, dataset_size)
+
+    config_builder = (
+        LightningConfigBuilder()
+        .module(
+            LinearModule, input_dim=input_dim, output_dim=output_dim, strategy="fsdp"
+        )
+        .trainer(max_epochs=num_epochs, accelerator="gpu")
+        .strategy("fsdp")
+        .checkpointing(save_last=True)
+        .fit_params(datamodule=datamodule)
+    )
+
+    scaling_config = ray.air.ScalingConfig(num_workers=2, use_gpu=True)
+
+    trainer = LightningTrainer(
+        lightning_config=config_builder.build(), scaling_config=scaling_config
+    )
+
+    results = trainer.fit()
+
+    with results.checkpoint.as_directory() as checkpoint_dir:
+        checkpoint = torch.load(f"{checkpoint_dir}/{MODEL_KEY}")
+        model = LinearModule(input_dim=input_dim, output_dim=output_dim)
+
+        for key in model.state_dict().keys():
+            assert key in checkpoint["state_dict"]
+
+
 if __name__ == "__main__":
     import sys
 
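
For reviewers, here is a minimal standalone sketch (not part of this diff) of the FSDP full-state-dict gathering pattern that the new `lightning_module_state_dict` override relies on. It assumes torch >= 1.12, a CUDA device, and a `torchrun --nproc_per_node=1` launch; the module and dimensions are illustrative only.

```python
# Illustrative sketch only: gather a full, unsharded state dict on rank 0,
# offloaded to CPU, mirroring the override added in _lightning_utils.py.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def gather_full_state_dict(model: FSDP) -> dict:
    # FULL_STATE_DICT materializes the unsharded parameters; rank0_only keeps
    # memory usage bounded to a single rank, and offload_to_cpu avoids holding
    # the gathered tensors on GPU.
    with FSDP.state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        return model.state_dict()


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
    model = FSDP(torch.nn.Linear(32, 4).cuda())
    state_dict = gather_full_state_dict(model)
    if dist.get_rank() == 0:
        # Keys keep the wrapped module's names; the "_forward_module." prefix
        # stripped in the diff is specific to Lightning's wrapping.
        print(sorted(state_dict.keys()))
    dist.destroy_process_group()
```

The `offload_to_cpu=True, rank0_only=True` combination matches what the override passes, so the checkpoint written on rank 0 is a plain CPU state dict that `torch.load` can read without FSDP, which is what `test_fsdp_checkpoint` asserts.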
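
As a follow-up usage note (not in the diff), the full checkpoint written by `test_fsdp_checkpoint` should be loadable into an unwrapped module without FSDP at all. The sketch below assumes the `LinearModule(input_dim=..., output_dim=...)` constructor used in the test; `strict=False` mirrors the test's one-directional key check rather than asserting an exact key match.

```python
# Hypothetical continuation of test_fsdp_checkpoint: restore the gathered
# checkpoint into an unwrapped module on a single process (no FSDP needed).
import torch
from ray.air.constants import MODEL_KEY
from ray.train.tests.lightning_test_utils import LinearModule


def load_unsharded(checkpoint_dir: str) -> LinearModule:
    # The checkpoint was written from a full, CPU-offloaded state dict with the
    # "_forward_module." prefix already stripped, so the keys line up with a
    # plain LinearModule instance.
    checkpoint = torch.load(f"{checkpoint_dir}/{MODEL_KEY}", map_location="cpu")
    model = LinearModule(input_dim=32, output_dim=4)
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    return model
```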