[tune] Update pytorch-lightning integration API (#38883) #38985

Merged 1 commit on Aug 28, 2023
4 changes: 3 additions & 1 deletion python/ray/tune/integration/lightgbm.py
@@ -15,6 +15,7 @@

from lightgbm.callback import CallbackEnv
from lightgbm.basic import Booster
from ray.util.annotations import Deprecated


class TuneCallback:
@@ -195,6 +196,7 @@ def __init__(self, *args, **kwargs):
)


@Deprecated
class TuneReportCallback(TuneReportCheckpointCallback):
def __init__(
self,
@@ -203,7 +205,7 @@ def __init__(
Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
] = None,
):
if log_once("tune_report_deprecated"):
if log_once("tune_lightgbm_report_deprecated"):
warnings.warn(
"`ray.tune.integration.lightgbm.TuneReportCallback` is deprecated. "
"Use `ray.tune.integration.lightgbm.TuneCheckpointReportCallback` "
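Review note: after this change, importing `TuneReportCallback` from `ray.tune.integration.lightgbm` still works but only emits the deprecation warning above; new code should construct `TuneReportCheckpointCallback` directly. A minimal migration sketch, assuming a synthetic dataset and the `<valid_name>-<metric_name>` key convention the LightGBM integration reports under (the data, parameters, and search space are illustrative, not from this PR):

```python
import numpy as np
import lightgbm as lgb
from ray import tune
from ray.tune.integration.lightgbm import TuneReportCheckpointCallback


def train_fn(config):
    # Tiny synthetic binary-classification problem so the sketch is
    # self-contained.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 4))
    y = (X[:, 0] > 0).astype(int)
    train_set = lgb.Dataset(X[:150], label=y[:150])
    valid_set = lgb.Dataset(X[150:], label=y[150:])

    lgb.train(
        {
            "objective": "binary",
            "metric": "binary_error",
            "learning_rate": config["lr"],
        },
        train_set,
        valid_sets=[valid_set],
        valid_names=["eval"],
        # Replaces the deprecated TuneReportCallback: reports metrics to
        # Tune and saves a checkpoint on each evaluation.
        callbacks=[TuneReportCheckpointCallback({"binary_error": "eval-binary_error"})],
    )


tune.Tuner(train_fn, param_space={"lr": tune.loguniform(1e-3, 1e-1)}).fit()
```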
174 changes: 69 additions & 105 deletions python/ray/tune/integration/pytorch_lightning.py
@@ -1,12 +1,19 @@
import inspect
import logging
import os
import tempfile
import warnings
from contextlib import contextmanager
from typing import Dict, List, Optional, Type, Union

from pytorch_lightning import Callback, Trainer, LightningModule
from ray import tune
from ray.util import PublicAPI
from ray import train
from ray.util import log_once
from ray.util.annotations import PublicAPI, Deprecated
from ray.air.checkpoint import Checkpoint as LegacyCheckpoint
from ray.train._checkpoint import Checkpoint
from ray.train._internal.storage import _use_storage_context

import os

logger = logging.getLogger(__name__)
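Review note: the rewritten import block is the visible half of an API swap: the old callback called `tune.report(**metrics)`, while the merged callback below calls `train.report(metrics, checkpoint=...)`, so metric reporting and checkpoint registration happen in one call. A minimal sketch of the call-site difference, assuming a Ray 2.x `Tuner` session (the trainable and its metric are illustrative):

```python
from ray import train, tune


def trainable(config):
    for step in range(3):
        # Old style, removed in this diff:
        #     tune.report(loss=1.0 / (step + 1))
        # New style, as used by the rewritten _handle() below: a metrics
        # dict, optionally paired with checkpoint=... in the same call.
        train.report({"loss": 1.0 / (step + 1)})


tune.Tuner(trainable).fit()
```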

@@ -72,55 +79,56 @@ def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):


@PublicAPI
class TuneReportCallback(TuneCallback):
"""PyTorch Lightning to Ray Tune reporting callback

Reports metrics to Ray Tune.

.. note::
In Ray 2.4, we introduced
:class:`LightningTrainer <ray.train.lightning.LightningTrainer>`,
which provides native integration with PyTorch Lightning. Here is
:ref:`a simple example <lightning_mnist_example>` of how to use
``LightningTrainer``.
class TuneReportCheckpointCallback(TuneCallback):
"""PyTorch Lightning report and checkpoint callback

Saves checkpoints after each validation step. Also reports metrics to Tune,
which is needed for checkpoint registration.

Args:
metrics: Metrics to report to Tune. If this is a list,
each item describes the metric key reported to PyTorch Lightning,
and it will be reported under the same name to Tune. If this is a
dict, each key will be the name reported to Tune and the respective
value will be the metric key reported to PyTorch Lightning.
on: When to trigger checkpoint creations. Must be one of
filename: Filename of the checkpoint within the checkpoint
directory. Defaults to "checkpoint".
save_checkpoints: If True (default), checkpoints will be saved and
reported to Ray. If False, only metrics will be reported.
on: When to trigger checkpoint creations and metric reports. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"train_batch_start", or "train_end". Defaults to "validation_end".


Example:

.. code-block:: python

import pytorch_lightning as pl
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray.tune.integration.pytorch_lightning import (
TuneReportCheckpointCallback)

# Report loss and accuracy to Tune after each validation epoch:
trainer = pl.Trainer(callbacks=[TuneReportCallback(
["val_loss", "val_acc"], on="validation_end")])
# Save checkpoint after each training batch and after each
# validation epoch.
trainer = pl.Trainer(callbacks=[TuneReportCheckpointCallback(
metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
filename="trainer.ckpt", on="validation_end")])

# Same as above, but report as `loss` and `mean_accuracy`:
trainer = pl.Trainer(callbacks=[TuneReportCallback(
{"loss": "val_loss", "mean_accuracy": "val_acc"},
on="validation_end")])

"""

def __init__(
self,
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
filename: str = "checkpoint",
save_checkpoints: bool = True,
on: Union[str, List[str]] = "validation_end",
):
super(TuneReportCallback, self).__init__(on=on)
super(TuneReportCheckpointCallback, self).__init__(on=on)
if isinstance(metrics, str):
metrics = [metrics]
self._save_checkpoints = save_checkpoints
self._filename = filename
self._metrics = metrics

def _get_report_dict(self, trainer: Trainer, pl_module: LightningModule):
@@ -146,102 +154,58 @@ def _get_report_dict(self, trainer: Trainer, pl_module: LightningModule):

return report_dict

def _handle(self, trainer: Trainer, pl_module: LightningModule):
report_dict = self._get_report_dict(trainer, pl_module)
if report_dict is not None:
tune.report(**report_dict)


class _TuneCheckpointCallback(TuneCallback):
"""PyTorch Lightning checkpoint callback

Saves checkpoints after each validation step.

.. note::
In Ray 2.4, we introduced
:class:`LightningTrainer <ray.train.lightning.LightningTrainer>`,
which provides native integration with PyTorch Lightning. Here is
:ref:`a simple example <lightning_mnist_example>` of how to use
``LightningTrainer``.

Checkpoints are currently not registered if no ``tune.report()`` call
is made afterwards. Consider using ``TuneReportCheckpointCallback``
instead.

Args:
filename: Filename of the checkpoint within the checkpoint
directory. Defaults to "checkpoint".
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"train_batch_start", or "train_end". Defaults to "validation_end".
@contextmanager
def _get_checkpoint(
self, trainer: Trainer
) -> Optional[Union[Checkpoint, LegacyCheckpoint]]:
if not self._save_checkpoints:
yield None
return

with tempfile.TemporaryDirectory() as checkpoint_dir:
trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename))

"""
if _use_storage_context():
checkpoint = Checkpoint.from_directory(checkpoint_dir)
else:
checkpoint = LegacyCheckpoint.from_directory(checkpoint_dir)

def __init__(
self, filename: str = "checkpoint", on: Union[str, List[str]] = "validation_end"
):
super(_TuneCheckpointCallback, self).__init__(on)
self._filename = filename
yield checkpoint

def _handle(self, trainer: Trainer, pl_module: LightningModule):
if trainer.sanity_checking:
return
step = f"epoch={trainer.current_epoch}-step={trainer.global_step}"
with tune.checkpoint_dir(step=step) as checkpoint_dir:
trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename))


@PublicAPI
class TuneReportCheckpointCallback(TuneCallback):
"""PyTorch Lightning report and checkpoint callback

Saves checkpoints after each validation step. Also reports metrics to Tune,
which is needed for checkpoint registration.

Args:
metrics: Metrics to report to Tune. If this is a list,
each item describes the metric key reported to PyTorch Lightning,
and it will be reported under the same name to Tune. If this is a
dict, each key will be the name reported to Tune and the respective
value will be the metric key reported to PyTorch Lightning.
filename: Filename of the checkpoint within the checkpoint
directory. Defaults to "checkpoint".
on: When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"train_batch_start", or "train_end". Defaults to "validation_end".

report_dict = self._get_report_dict(trainer, pl_module)
if not report_dict:
return

Example:

.. code-block:: python

import pytorch_lightning as pl
from ray.tune.integration.pytorch_lightning import (
TuneReportCheckpointCallback)

# Save checkpoint after each training batch and after each
# validation epoch.
trainer = pl.Trainer(callbacks=[TuneReportCheckpointCallback(
metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
filename="trainer.ckpt", on="validation_end")])
with self._get_checkpoint(trainer) as checkpoint:
train.report(report_dict, checkpoint=checkpoint)


"""
class _TuneCheckpointCallback(TuneCallback):
def __init__(self, *args, **kwargs):
raise DeprecationWarning(
"`ray.tune.integration.pytorch_lightning._TuneCheckpointCallback` "
"is deprecated."
)

_checkpoint_callback_cls = _TuneCheckpointCallback
_report_callbacks_cls = TuneReportCallback

@Deprecated
class TuneReportCallback(TuneReportCheckpointCallback):
def __init__(
self,
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
filename: str = "checkpoint",
on: Union[str, List[str]] = "validation_end",
):
super(TuneReportCheckpointCallback, self).__init__(on)
self._checkpoint = self._checkpoint_callback_cls(filename, on)
self._report = self._report_callbacks_cls(metrics, on)

def _handle(self, trainer: Trainer, pl_module: LightningModule):
self._checkpoint._handle(trainer, pl_module)
self._report._handle(trainer, pl_module)
if log_once("tune_ptl_report_deprecated"):
warnings.warn(
"`ray.tune.integration.pytorch_lightning.TuneReportCallback` "
"is deprecated. Use "
"`ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback`"
" instead."
)
super(TuneReportCallback, self).__init__(
metrics=metrics, save_checkpoints=False, on=on
)
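Review note: the net effect of this file is that a single public callback now covers both reporting and checkpointing, with `save_checkpoints=False` as the metrics-only escape hatch. A usage sketch with a deliberately tiny LightningModule so it stays self-contained (the module, data, and search space are illustrative, not from this PR):

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback


class TinyModule(pl.LightningModule):
    # Minimal stand-in model; any module that logs "val_loss" works.
    def __init__(self, lr: float):
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # Logged metrics are what the callback picks up and reports.
        self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)


def train_fn(config):
    data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    loader = DataLoader(data, batch_size=16)
    trainer = pl.Trainer(
        max_epochs=2,
        enable_progress_bar=False,
        callbacks=[
            # One callback for both jobs; pass save_checkpoints=False
            # to report metrics without writing checkpoints.
            TuneReportCheckpointCallback(
                metrics={"loss": "val_loss"},
                filename="trainer.ckpt",
                on="validation_end",
            )
        ],
    )
    trainer.fit(TinyModule(config["lr"]), loader, loader)


tune.Tuner(train_fn, param_space={"lr": tune.grid_search([1e-2, 1e-1])}).fit()
```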
4 changes: 3 additions & 1 deletion python/ray/tune/integration/xgboost.py
@@ -13,6 +13,7 @@
from ray.train._internal.storage import _use_storage_context
from ray.tune.utils import flatten_dict
from ray.util import log_once
from ray.util.annotations import Deprecated
from xgboost.core import Booster

try:
@@ -188,6 +189,7 @@ def __init__(self, *args, **kwargs):
)


@Deprecated
class TuneReportCallback(TuneReportCheckpointCallback):
def __init__(
self,
@@ -196,7 +198,7 @@ def __init__(
Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
] = None,
):
if log_once("tune_report_deprecated"):
if log_once("tune_xgboost_report_deprecated"):
warnings.warn(
"`ray.tune.integration.xgboost.TuneReportCallback` is deprecated. "
"Use `ray.tune.integration.xgboost.TuneCheckpointReportCallback` "
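Review note: the key rename here (and its twin in lightgbm.py) matters because `log_once` deduplicates per key for the lifetime of the process; with the shared "tune_report_deprecated" key, whichever integration warned first would have silenced the other's notice. A small sketch of that behavior, assuming `ray.util.log_once` returns truthy only on a key's first use, which is what the call sites above rely on:

```python
from ray.util import log_once

# First sighting of a key fires; repeats of the same key are suppressed.
assert log_once("tune_xgboost_report_deprecated")
assert not log_once("tune_xgboost_report_deprecated")

# A different key is tracked independently, so the lightgbm deprecation
# warning can still fire after the xgboost one has.
assert log_once("tune_lightgbm_report_deprecated")
```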
17 changes: 9 additions & 8 deletions python/ray/tune/tests/test_integration_pytorch_lightning.py
@@ -13,7 +13,6 @@
from ray.tune.integration.pytorch_lightning import (
TuneReportCallback,
TuneReportCheckpointCallback,
_TuneCheckpointCallback,
)


@@ -89,7 +88,7 @@ def train(config):
max_epochs=1,
callbacks=[
TuneReportCallback(
{"tune_loss": "avg_val_loss"}, on="validation_end"
metrics={"tune_loss": "avg_val_loss"}, on="validation_end"
)
],
)
@@ -106,10 +105,10 @@ def testCheckpointCallback(self):
def train(config):
module = _MockModule(10.0, 20.0)
trainer = pl.Trainer(
max_epochs=1,
max_epochs=10,
callbacks=[
_TuneCheckpointCallback(
"trainer.ckpt", on=["batch_end", "train_end"]
TuneReportCheckpointCallback(
filename="trainer.ckpt", on=["train_epoch_end"]
)
],
)
@@ -128,8 +127,8 @@ def train(config):
for dir in os.listdir(analysis.trials[0].local_path)
if dir.startswith("checkpoint")
]
# 10 checkpoints after each batch, 1 checkpoint at end
self.assertEqual(len(checkpoints), 11)
# 1 checkpoint per epoch
self.assertEqual(len(checkpoints), 10)

def testReportCheckpointCallback(self):
tmpdir = tempfile.mkdtemp()
@@ -141,7 +140,9 @@ def train(config):
max_epochs=1,
callbacks=[
TuneReportCheckpointCallback(
["avg_val_loss"], "trainer.ckpt", on="validation_end"
metrics=["avg_val_loss"],
filename="trainer.ckpt",
on="validation_end",
)
],
)
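Review note: the checkpoint-count arithmetic changes with the callback: the old private callback fired on every batch plus once at `train_end` (11 checkpoints), while the public callback with `on=["train_epoch_end"]` and `max_epochs=10` writes exactly one checkpoint per epoch. A helper restating the updated assertion (the path argument stands in for `analysis.trials[0].local_path`):

```python
import os


def count_checkpoints(trial_local_path: str) -> int:
    # With on=["train_epoch_end"], the callback writes one checkpoint_*
    # directory per epoch into the trial directory.
    return sum(1 for d in os.listdir(trial_local_path) if d.startswith("checkpoint"))


# The updated test expects:
#     count_checkpoints(analysis.trials[0].local_path) == 10
```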