
[Train] Unified TorchTrainer API use with lightning vs pytorch_lightning #39715

Closed
CMGeldenhuys opened this issue Sep 16, 2023 · 3 comments
Labels
bug Something that is supposed to be working; but isn't P1 Issue that should be fixed within a few weeks ray 2.8 train Ray Train Related Issue

Comments

@CMGeldenhuys

What happened + What you expected to happen

TL;DR
Root cause: use pytorch_lightning instead of lightning.pytorch.

When using the new unified TorchTrainer API with the lightning package instead of pytorch_lightning, the following error is produced:

...
  File ".../lib/python3.11/site-packages/ray/train/_internal/worker_group.py", line 33, in __execute
    raise skipped from exception_cause(skipped)
  File ".../lib/python3.11/site-packages/ray/train/_internal/utils.py", line 129, in discard_return_wrapper
    train_func(*args, **kwargs)
  File ".../bug.py", line 44, in train_func
    trainer = pl.Trainer(
              ^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/lightning/pytorch/utilities/argparse.py", line 70, in insert_env_defaults
    return fn(self, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 399, in __init__
    self._accelerator_connector = _AcceleratorConnector(
                                  ^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py", line 140, in __init__
    self._check_config_and_set_final_flags(
  File ".../lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py", line 210, in _check_config_and_set_final_flags
    raise ValueError(
ValueError: You selected an invalid strategy name: `strategy=<ray.train.lightning._lightning_utils.RayDDPStrategy object at 0x7f7327186290>`. It must be either a string or an instance of `lightning.pytorch.strategies.Strategy`. Example choices: auto, ddp, ddp_spawn, deepspeed, ... Find a complete list of options in our documentation at https://lightning.ai

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../bug.py", line 60, in <module>
    result = trainer.fit()
...

The error complains that RayDDPStrategy() is not an instance of Strategy.

A simple test shows that this is likely a namespace issue, caused by importing pl.Trainer from lightning.pytorch instead of the pytorch_lightning module.

In [1]: import pytorch_lightning
In [2]: import lightning
In [3]: from ray.train.lightning import RayDDPStrategy

In [4]: isinstance(RayDDPStrategy(), pytorch_lightning.strategies.Strategy)
Out[4]: True

In [5]: isinstance(RayDDPStrategy(), lightning.pytorch.strategies.Strategy)
Out[5]: False

I feel this is a common trap for plenty of people to walk into, as the lightning docs make use of lightning.pytorch rather than pytorch_lightning. Perhaps the ray docs should state explicitly that pytorch_lightning must be used throughout one's own code (including when importing other lightning classes, such as DataModule, or the same error will just surface elsewhere). Alternatively, ray could use lightning.pytorch, but that means pulling in the whole lightning ecosystem as a dependency, which doesn't seem logical. I am uncertain whether there is a hacky Python way of getting this to work with both pytorch_lightning and lightning.pytorch.
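One possible (untested) sketch of such a compatibility check, purely for illustration: look up the Strategy base class in whichever of the two namespaces is importable, and validate against all of them. The helper names below are hypothetical and not part of any real Ray or Lightning API.

```python
import importlib

# The two namespaces lightning has shipped under at the time of this issue.
_STRATEGY_MODULES = ("pytorch_lightning.strategies", "lightning.pytorch.strategies")

def strategy_base_classes():
    """Return a tuple of every importable Strategy base class."""
    bases = []
    for path in _STRATEGY_MODULES:
        try:
            bases.append(importlib.import_module(path).Strategy)
        except ImportError:
            # That namespace is not installed; skip it.
            pass
    return tuple(bases)

def is_valid_strategy(obj):
    """True if obj is a Strategy under at least one installed namespace."""
    bases = strategy_base_classes()
    return bool(bases) and isinstance(obj, bases)
```

Note this only helps on the side doing the isinstance check; the error in the traceback above is raised by lightning's own Trainer, not by ray, so a check like this would not fix the reported failure by itself.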

Versions / Dependencies

ray 2.7.0rc0
pytorch-lightning 2.0.8
lightning 2.0.8
torch 2.0.1
cuda 11.7
python 3.11.5

Reproduction script

# Example from ray.io docs
import torch
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
import lightning.pytorch as pl  # <-- Problematic
# import pytorch_lightning as pl

from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
import ray.train.lightning

# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.model = resnet18(num_classes=10)
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        self.log("loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)


def train_func(config):

    # Data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    train_data = FashionMNIST(root='./data-tmp', train=True, download=True, transform=transform)
    train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

    # Training
    model = ImageClassifier()
    # [1] Configure PyTorch Lightning Trainer.
    trainer = pl.Trainer(
        max_epochs=10,
        devices="auto",
        accelerator="auto",
        strategy=ray.train.lightning.RayDDPStrategy(),
        plugins=[ray.train.lightning.RayLightningEnvironment()],
        callbacks=[ray.train.lightning.RayTrainReportCallback()],
    )
    trainer = ray.train.lightning.prepare_trainer(trainer)
    trainer.fit(model, train_dataloaders=train_dataloader)

# [2] Configure scaling and resource requirements.
scaling_config = ScalingConfig(num_workers=1, use_gpu=True)

# [3] Launch distributed training job.
trainer = TorchTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()

Issue Severity

Low: It annoys or frustrates me.

@CMGeldenhuys CMGeldenhuys added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Sep 16, 2023
@matthewdeng matthewdeng added P1 Issue that should be fixed within a few weeks train Ray Train Related Issue and removed triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Sep 20, 2023
@matthewdeng
Contributor

Thanks for the detailed report, we'll look to find a solution for this for Ray 2.8.

@CMGeldenhuys
Author

All good, just thought of putting it up somewhere in case someone else ran into the same issue as I did.

@woshiyyya
Member

Resolved by this PR: #39841

You should be able to use lightning.pytorch with Ray Train in 2.8.
