
Commit

[train] Legacy interface cleanup (air.Checkpoint, `LegacyExperimentAnalysis`) (#39289)

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Co-authored-by: matthewdeng <matt@anyscale.com>
justinvyu and matthewdeng committed Sep 8, 2023
1 parent fddde50 commit 2913e9b
Showing 24 changed files with 81 additions and 242 deletions.
@@ -8,7 +8,7 @@ Saving and Loading your RL Algorithms and Policies
##################################################


-You can use :py:class:`~ray.air.checkpoint.Checkpoint` objects to store
+You can use :py:class:`~ray.train.Checkpoint` objects to store
and load the current state of your :py:class:`~ray.rllib.algorithms.algorithm.Algorithm`
or :py:class:`~ray.rllib.policy.policy.Policy` and the neural networks (weights)
within these structures. In the following, we will cover how you can create these
@@ -26,7 +26,7 @@ or a single :py:class:`~ray.rllib.policy.policy.Policy` instance.
The Algorithm- or Policy instances that were used to create the checkpoint in the first place
may or may not have been trained prior to this.

-RLlib uses the :py:class:`~ray.air.checkpoint.Checkpoint` class to create checkpoints and
+RLlib uses the :py:class:`~ray.train.Checkpoint` class to create checkpoints and
restore objects from them.

The main file in a checkpoint directory, containing the state information, is currently
@@ -50,7 +50,7 @@ How do I create an Algorithm checkpoint?
----------------------------------------

The :py:class:`~ray.rllib.algorithms.algorithm.Algorithm` ``save()`` method creates a new checkpoint
-(directory with files in it) and returns the path to that directory.
+(directory with files in it).

Let's take a look at a simple example of how to create such an
Algorithm checkpoint:
@@ -69,8 +69,6 @@ like this:
$ ls -la
.
..
-.is_checkpoint
-.tune_metadata
policies/
algorithm_state.pkl
rllib_checkpoint.json
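
Editor's note: a minimal sketch of the workflow these docs describe. It assumes Ray >= 2.7 with RLlib installed; the environment name and checkpoint path are illustrative, and since `Algorithm.save()`'s return value changed (the doc line above was trimmed for that reason), the directory is passed explicitly rather than taken from the return value.

```python
# Hedged sketch of creating and restoring an Algorithm checkpoint.
# Assumes Ray >= 2.7 with RLlib; env name and path are illustrative.
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig

algo = PPOConfig().environment("CartPole-v1").build()
algo.train()

# save() writes the checkpoint files (algorithm_state.pkl, policies/, ...).
algo.save(checkpoint_dir="/tmp/rllib_checkpoint")

# Restore a fresh Algorithm instance from the checkpoint directory.
restored = Algorithm.from_checkpoint("/tmp/rllib_checkpoint")
```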
@@ -821,12 +821,18 @@
"Enabling checkpointing is pretty easy - we just need to pass a `Checkpoint` object with the model state to the `ray.train.report()` API.\n",
"\n",
"```python\n",
" from ray import train\n",
" from ray.train import Checkpoint\n",
"\n",
" with TemporaryDirectory() as tmpdir:\n",
" torch.save(model.state_dict(), os.path.join(tmpdir, \"checkpoint.pt\"))\n",
" train.report(dict(loss=test_loss), \n",
" checkpoint=Checkpoint.from_directory(tmpdir))\n",
" torch.save(\n",
" {\n",
" \"epoch\": epoch,\n",
" \"model\": model.module.state_dict()\n",
" },\n",
" os.path.join(tmpdir, \"checkpoint.pt\")\n",
" )\n",
" train.report(dict(loss=test_loss), checkpoint=Checkpoint.from_directory(tmpdir))\n",
"```\n",
"\n",
"### Move the data loader to the training function\n",
@@ -888,11 +894,17 @@
" loss_fn = nn.CrossEntropyLoss()\n",
" optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
" \n",
" for t in range(epochs):\n",
" for epoch in range(epochs):\n",
" train_epoch(train_dataloader, model, loss_fn, optimizer)\n",
" test_loss = test_epoch(test_dataloader, model, loss_fn)\n",
" with TemporaryDirectory() as tmpdir:\n",
" torch.save(model.state_dict(), os.path.join(tmpdir, \"checkpoint.pt\"))\n",
" torch.save(\n",
" {\n",
" \"epoch\": epoch,\n",
" \"model\": model.module.state_dict()\n",
" },\n",
" os.path.join(tmpdir, \"checkpoint.pt\")\n",
" )\n",
" train.report(dict(loss=test_loss), checkpoint=Checkpoint.from_directory(tmpdir))\n",
"\n",
" print(\"Done!\")"
@@ -356,7 +356,7 @@
"metadata": {},
"outputs": [],
"source": [
"from ray.train.torch import TorchTrainer, LegacyTorchCheckpoint\n",
"from ray.train.torch import TorchTrainer\n",
"from ray.train import ScalingConfig, RunConfig, CheckpointConfig\n",
"\n",
"# Scale out model training across 4 GPUs.\n",
6 changes: 3 additions & 3 deletions doc/source/tune/tutorials/tune-storage.rst
@@ -166,7 +166,7 @@ that implements saving and loading checkpoints.
import os
import ray
-from ray import air, tune
+from ray import train, tune
from your_module import my_trainable
# Look for the existing cluster and connect to it
@@ -179,14 +179,14 @@ that implements saving and loading checkpoints.
tuner = tune.Tuner(
my_trainable,
-    run_config=air.RunConfig(
+    run_config=train.RunConfig(
# Name of your experiment
name="my-tune-exp",
# Configure how experiment data and checkpoints are persisted.
# We recommend cloud storage checkpointing as it survives the cluster when
# instances are terminated and has better performance.
storage_path="s3://my-checkpoints-bucket/path/",
-        checkpoint_config=air.CheckpointConfig(
+        checkpoint_config=train.CheckpointConfig(
# We'll keep the best five checkpoints at all times
# (with the highest AUC scores, a metric reported by the trainable)
checkpoint_score_attribute="max-auc",
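
Editor's note: a self-contained sketch of the migrated configuration. The option names come from the diff above; the trainable, metric value, and bucket path are illustrative, and `num_to_keep=5` is an assumption matching the "best five checkpoints" comment.

```python
# Hedged sketch of the ray.train-namespaced config shown above.
# Assumes Ray >= 2.7; trainable, metric value, and bucket are illustrative.
from ray import train, tune

def my_trainable(config):
    train.report({"max-auc": 0.9})  # placeholder metric

tuner = tune.Tuner(
    my_trainable,
    run_config=train.RunConfig(
        name="my-tune-exp",
        storage_path="s3://my-checkpoints-bucket/path/",
        checkpoint_config=train.CheckpointConfig(
            num_to_keep=5,  # keep the best five checkpoints
            checkpoint_score_attribute="max-auc",
        ),
    ),
)
tuner.fit()
```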
32 changes: 13 additions & 19 deletions python/ray/air/checkpoint.py
@@ -28,7 +28,7 @@
)
from ray.air._internal.util import _copy_dir_ignore_conflicts
from ray.air.constants import PREPROCESSOR_KEY, CHECKPOINT_ID_ATTR
-from ray.util.annotations import DeveloperAPI, PublicAPI
+from ray.util.annotations import Deprecated, DeveloperAPI

if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@@ -63,9 +63,9 @@ class _CheckpointMetadata:
checkpoint_state: Dict[str, Any]


-@PublicAPI(stability="beta")
+@Deprecated
class Checkpoint:
-    """Ray AIR Checkpoint.
+    """[Deprecated] Ray AIR Checkpoint.

    An AIR Checkpoint is a common interface for accessing models across
    different AIR components and libraries. A Checkpoint can have its data
@@ -166,6 +166,16 @@ def __init__(
        data_dict: Optional[dict] = None,
        uri: Optional[str] = None,
    ):
+        from ray.train._internal.storage import _use_storage_context
+
+        if _use_storage_context():
+            raise DeprecationWarning(
+                "`ray.air.Checkpoint` is deprecated. "
+                "Please use `ray.train.Checkpoint` instead. "
+                "See the `Checkpoint: New API` section in "
+                "https://github.com/ray-project/ray/issues/37868 for a migration guide."
+            )
+
        # First, resolve file:// URIs to local paths
        if uri:
            local_path = _get_local_path(uri)
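
Editor's note: what the new guard asks callers to do, as a hedged migration sketch (the checkpoint path is illustrative).

```python
# Hedged migration sketch: ray.air.Checkpoint -> ray.train.Checkpoint.
# Assumes Ray >= 2.7; the path is illustrative.

# Before (now raises DeprecationWarning under the new storage context):
#     from ray.air import Checkpoint
#     ckpt = Checkpoint.from_directory("/tmp/my_checkpoint")

# After:
from ray.train import Checkpoint

ckpt = Checkpoint.from_directory("/tmp/my_checkpoint")
with ckpt.as_directory() as ckpt_dir:
    pass  # read checkpoint files from ckpt_dir
```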
@@ -269,14 +279,6 @@ def path(self) -> Optional[str]:
In all other cases, this will return None.
-        Example:
-
-            >>> from ray.air import Checkpoint
-            >>> checkpoint = Checkpoint.from_uri("s3://some-bucket/some-location")
-            >>> assert checkpoint.path == "s3://some-bucket/some-location"
-            >>> checkpoint = Checkpoint.from_dict({"data": 1})
-            >>> assert checkpoint.path == None
Returns:
Checkpoint path if this checkpoint is reachable from the current node (e.g.
cloud storage or locally available directory).
@@ -302,14 +304,6 @@ def uri(self) -> Optional[str]:
persist to cloud with
:meth:`Checkpoint.to_uri() <ray.air.Checkpoint.to_uri>`.
-        Example:
-
-            >>> from ray.air import Checkpoint
-            >>> checkpoint = Checkpoint.from_uri("s3://some-bucket/some-location")
-            >>> assert checkpoint.uri == "s3://some-bucket/some-location"
-            >>> checkpoint = Checkpoint.from_dict({"data": 1})
-            >>> assert checkpoint.uri == None
Returns:
Checkpoint URI if this URI is reachable from the current node (e.g.
cloud storage or locally available file URI).
2 changes: 1 addition & 1 deletion python/ray/air/integrations/comet.py
@@ -3,7 +3,7 @@

import pyarrow.fs

-from ray.train import _use_storage_context
+from ray.train._internal.storage import _use_storage_context
from ray.tune.logger import LoggerCallback
from ray.tune.experiment import Trial
from ray.tune.utils import flatten_dict
2 changes: 1 addition & 1 deletion python/ray/air/integrations/wandb.py
@@ -18,7 +18,7 @@
from ray.air._internal import usage as air_usage
from ray.air.util.node import _force_on_current_node

-from ray.train import _use_storage_context
+from ray.train._internal.storage import _use_storage_context
from ray.tune.logger import LoggerCallback
from ray.tune.utils import flatten_dict
from ray.tune.experiment import Trial
2 changes: 1 addition & 1 deletion python/ray/air/tests/test_keras_callback.py
@@ -138,7 +138,7 @@ def test_report_and_checkpoint_on_different_events(self, mock_report, model):
assert second_metric == {"loss": 1}
assert second_checkpoint is not None

-    def parse_call(self, call) -> Tuple[Dict, ray.air.Checkpoint]:
+    def parse_call(self, call) -> Tuple[Dict, train.Checkpoint]:
(metrics,), kwargs = call
checkpoint = kwargs["checkpoint"]
return metrics, checkpoint
9 changes: 2 additions & 7 deletions python/ray/train/__init__.py
@@ -13,14 +13,8 @@

from ray._private.usage import usage_lib

-from ray.train._internal.storage import _use_storage_context
-
-# Import this first so it can be used in other modules
-if _use_storage_context():
-    from ray.train._checkpoint import Checkpoint
-else:
-    from ray.air import Checkpoint
-
+from ray.train._checkpoint import Checkpoint
from ray.train._internal.data_config import DataConfig
from ray.train._internal.session import get_checkpoint, get_dataset_shard, report
from ray.train._internal.syncer import SyncConfig
@@ -34,6 +28,7 @@

usage_lib.record_library_usage("train")

+Checkpoint.__module__ = "ray.train"

__all__ = [
"get_checkpoint",
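Editor's note on the `__module__` reassignment above: a hedged illustration of its effect; the assertion simply restates the assignment made in this diff.

```python
# Sketch: after the reassignment, the class advertises its public import path
# rather than the private module it is defined in (visible in repr, pickling,
# and generated docs).
from ray.train import Checkpoint

assert Checkpoint.__module__ == "ray.train"  # not "ray.train._checkpoint"
```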
13 changes: 3 additions & 10 deletions python/ray/train/_internal/session.py
@@ -17,7 +17,6 @@
import ray
from ray.air._internal.session import _get_session
from ray.air._internal.util import StartTraceback, RunnerThread
-from ray.air.checkpoint import Checkpoint
from ray.air.constants import (
_RESULT_FETCH_TIMEOUT,
_ERROR_FETCH_TIMEOUT,
@@ -26,7 +25,7 @@
TIME_THIS_ITER_S,
)
from ray.data import Dataset, DatasetPipeline
-from ray.train._checkpoint import Checkpoint as NewCheckpoint
+from ray.train import Checkpoint
from ray.train._internal.accelerator import Accelerator
from ray.train._internal.storage import _use_storage_context, StorageContext
from ray.train.constants import (
@@ -79,6 +78,7 @@ class TrialInfo:
experiment_name: Optional[str] = None


+# TODO(justinvyu): [code_removal]
@dataclass
class TrainingResult:
type: TrainingResultType
@@ -576,20 +576,13 @@ def _report_training_result(self, training_result: _TrainingResult) -> None:
sys.exit(0)

def new_report(
-        self, metrics: Dict, checkpoint: Optional[NewCheckpoint] = None
+        self, metrics: Dict, checkpoint: Optional[Checkpoint] = None
) -> None:
if self.ignore_report:
return

persisted_checkpoint = None
if checkpoint:
-            # TODO(justinvyu): [code_removal]
-            if not isinstance(checkpoint, NewCheckpoint):
-                raise ValueError(
-                    "You must pass a `ray.train.Checkpoint` "
-                    "object to `train.report`. `ray.air.Checkpoint` is deprecated."
-                )

# Persist the reported checkpoint files to storage.
persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)

42 changes: 0 additions & 42 deletions python/ray/train/lightning/lightning_checkpoint.py
@@ -216,48 +216,6 @@ def get_model(
) -> pl.LightningModule:
"""Retrieve the model stored in this checkpoint.
-        Example:
-
-            .. testcode::
-
-                import pytorch_lightning as pl
-                from ray.train.lightning import LightningCheckpoint, LightningPredictor
-
-                class MyLightningModule(pl.LightningModule):
-                    def __init__(self, input_dim, output_dim) -> None:
-                        super().__init__()
-                        self.linear = nn.Linear(input_dim, output_dim)
-                        self.save_hyperparameters()
-                    # ...
-
-                # After the training is finished, LightningTrainer saves AIR
-                # checkpoints in the result directory, for example:
-                # ckpt_dir = "{storage_path}/LightningTrainer_.*/checkpoint_000000"
-
-                # You can load model checkpoint with model init arguments
-                def load_checkpoint(ckpt_dir):
-                    ckpt = LightningCheckpoint.from_directory(ckpt_dir)
-
-                    # `get_model()` takes the argument list of
-                    # `LightningModule.load_from_checkpoint()` as additional kwargs.
-                    # Please refer to PyTorch Lightning API for more details.
-                    return checkpoint.get_model(
-                        model_class=MyLightningModule,
-                        input_dim=32,
-                        output_dim=10,
-                    )
-
-                # You can also load checkpoint with a hyperparameter file
-                def load_checkpoint_with_hparams(
-                    ckpt_dir, hparam_file="./hparams.yaml"
-                ):
-                    ckpt = LightningCheckpoint.from_directory(ckpt_dir)
-                    return ckpt.get_model(
-                        model_class=MyLightningModule,
-                        hparams_file=hparam_file
-                    )
Args:
model_class: A subclass of ``pytorch_lightning.LightningModule`` that
defines your model and training logic.
6 changes: 3 additions & 3 deletions python/ray/train/predictor.py
@@ -4,7 +4,7 @@
import numpy as np
import pandas as pd

-from ray.air.checkpoint import Checkpoint
+from ray.train import Checkpoint
from ray.air.data_batch_type import DataBatchType
from ray.air.util.data_batch_conversion import (
BatchFormat,
@@ -108,13 +108,13 @@ def from_pandas_udf(

class PandasUDFPredictor(Predictor):
@classmethod
-            def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs):
+            def from_checkpoint(cls, checkpoint: Checkpoint, **kwargs) -> "Predictor":
return PandasUDFPredictor()

def _predict_pandas(self, df, **kwargs) -> "pd.DataFrame":
return pandas_udf(df, **kwargs)

-        return PandasUDFPredictor.from_checkpoint(Checkpoint.from_dict({"dummy": 1}))
+        return PandasUDFPredictor()

def get_preprocessor(self) -> Optional[Preprocessor]:
"""Get the preprocessor to use prior to executing predictions."""
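Editor's note: with the dict-based `Checkpoint.from_dict` path gone, `from_pandas_udf` constructs the predictor directly. A hedged usage sketch; the UDF and data frame are illustrative.

```python
# Hedged sketch of Predictor.from_pandas_udf after this change.
# Assumes Ray >= 2.7; the UDF and input batch are illustrative.
import pandas as pd
from ray.train.predictor import Predictor

def add_one(df: pd.DataFrame) -> pd.DataFrame:
    return df + 1

predictor = Predictor.from_pandas_udf(add_one)
print(predictor.predict(pd.DataFrame({"x": [1, 2, 3]})))
```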
17 changes: 0 additions & 17 deletions python/ray/train/tensorflow/tensorflow_checkpoint.py
@@ -207,23 +207,6 @@ def from_model(
Returns:
A :py:class:`TensorflowCheckpoint` containing the specified model.
-        Examples:
-
-            .. testcode::
-
-                from ray.train.tensorflow import TensorflowCheckpoint
-                import tensorflow as tf
-
-                model = tf.keras.applications.resnet.ResNet101()
-                checkpoint = TensorflowCheckpoint.from_model(model)
-
-            .. testoutput::
-                :options: +MOCK
-                :hide:
-
-                ...  # Model may or may not be downloaded
"""
checkpoint = cls.from_dict(
{PREPROCESSOR_KEY: preprocessor, MODEL_KEY: model.get_weights()}