[train] enable new persistence mode for doctests (#38876)
* [train] enable new persistence mode for doctests

Signed-off-by: Matthew Deng <matt@anyscale.com>
matthewdeng committed Aug 25, 2023
1 parent 585bf1f commit 8c72ec0
Showing 7 changed files with 46 additions and 26 deletions.
1 change: 1 addition & 0 deletions .buildkite/pipeline.build.yml
@@ -227,6 +227,7 @@
- ./ci/env/env_info.sh
- bazel test --config=ci $(./scripts/bazel_export_options)
--test_tag_filters=doctest,-gpu
+ --test_env=RAY_AIR_NEW_PERSISTENCE_MODE=1
python/ray/... doc/...

- label: ":python: Ray on Spark Test"
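Note: RAY_AIR_NEW_PERSISTENCE_MODE is the feature flag gating Ray's new directory-based checkpoint persistence, which the doctest jobs above now opt into via Bazel's --test_env. A minimal sketch of opting in locally, assuming the flag is read when Ray Train is first imported:

    import os

    # Mirror the CI job: enable the new persistence mode before
    # Ray Train is imported.
    os.environ["RAY_AIR_NEW_PERSISTENCE_MODE"] = "1"

    from ray import train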
4 changes: 3 additions & 1 deletion .buildkite/pipeline.gpu_large.yml
@@ -75,7 +75,9 @@
- pip install transformers==4.30.2 datasets==2.14.0
- ./ci/env/env_info.sh
- bazel test --config=ci $(./scripts/bazel_export_options)
- --test_tag_filters=doctest,-cpu python/ray/... doc/...
+ --test_tag_filters=doctest,-cpu
+ --test_env=RAY_AIR_NEW_PERSISTENCE_MODE=1
+ python/ray/... doc/...

- label: ":zap: :python: Lightning 2.0 Train GPU tests"
conditions:
12 changes: 7 additions & 5 deletions doc/source/data/batch_inference.rst
@@ -465,9 +465,12 @@ Models that have been trained with :ref:`Ray Train <train-docs>` can then be used

checkpoint = result.checkpoint

- **Step 3:** Use Ray Data for batch inference. To load in the model from the :class:`Checkpoint <ray.train.Checkpoint>` inside the Python class, use one of the framework-specific Checkpoint classes.
+ **Step 3:** Use Ray Data for batch inference. To load in the model from the :class:`Checkpoint <ray.train.Checkpoint>` inside the Python class, use the methodology corresponding to the Trainer used to train the model.

- In this case, use :class:`XGBoostCheckpoint <ray.train.xgboost.XGBoostCheckpoint>` to load the model.
+ - **Deep Learning Trainers:** :ref:`train-checkpointing`
+ - **Tree-Based Trainers:** :ref:`train-gbdt-checkpoints`
+
+ In this case, use :meth:`XGBoostTrainer.get_model() <ray.train.xgboost.XGBoostTrainer.get_model>` to load the model.

The rest of the logic looks the same as in the `Quickstart <#quickstart>`_.

@@ -479,14 +482,13 @@
import xgboost

from ray.train import Checkpoint
- from ray.train.xgboost import LegacyXGBoostCheckpoint
+ from ray.train.xgboost import XGBoostTrainer

test_dataset = valid_dataset.drop_columns(["target"])

class XGBoostPredictor:
def __init__(self, checkpoint: Checkpoint):
- xgboost_checkpoint = LegacyXGBoostCheckpoint.from_checkpoint(checkpoint)
- self.model = xgboost_checkpoint.get_model()
+ self.model = XGBoostTrainer.get_model(checkpoint)

def __call__(self, data: pd.DataFrame) -> Dict[str, np.ndarray]:
dmatrix = xgboost.DMatrix(data)
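As a hedged sketch (not part of this diff), the predictor class above is then applied with Ray Data's map_batches, following the Quickstart pattern the page references; the actor pool size here is illustrative:

    import ray

    # fn_constructor_args forwards the checkpoint to XGBoostPredictor.__init__;
    # batch_format="pandas" matches the pd.DataFrame batch signature above.
    predictions = test_dataset.map_batches(
        XGBoostPredictor,
        compute=ray.data.ActorPoolStrategy(size=2),  # pool size is an assumption
        batch_format="pandas",
        fn_constructor_args=[checkpoint],
    )
    predictions.show(2)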
3 changes: 3 additions & 0 deletions doc/source/train/distributed-xgboost-lightgbm.rst
@@ -55,6 +55,9 @@ training parameters are passed as the ``params`` dictionary.

Ray-specific params are passed in through the trainer constructors.


+ .. _train-gbdt-checkpoints:

Saving and Loading XGBoost and LightGBM Checkpoints
---------------------------------------------------

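The new _train-gbdt-checkpoints label is what batch_inference.rst above cross-references via :ref:`train-gbdt-checkpoints`. In code terms, the anchored section boils down to the following hedged sketch, where `result` is assumed to come from a fitted trainer:

    from ray.train.xgboost import XGBoostTrainer

    # Recover the native xgboost.Booster from a Train checkpoint; the
    # analogous LightGBMTrainer.get_model() covers LightGBM checkpoints.
    booster = XGBoostTrainer.get_model(result.checkpoint)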
18 changes: 9 additions & 9 deletions python/ray/air/checkpoint.py
@@ -83,7 +83,7 @@ class Checkpoint:
.. code-block:: python
- from ray.train import Checkpoint
+ from ray.air import Checkpoint
# Create checkpoint data dict
checkpoint_data = {"data": 123}
@@ -270,7 +270,7 @@ def path(self) -> Optional[str]:
Example:
- >>> from ray.train import Checkpoint
+ >>> from ray.air import Checkpoint
>>> checkpoint = Checkpoint.from_uri("s3://some-bucket/some-location")
>>> assert checkpoint.path == "s3://some-bucket/some-location"
>>> checkpoint = Checkpoint.from_dict({"data": 1})
@@ -299,11 +299,11 @@ def uri(self) -> Optional[str]:
In all other cases, this will return None. Users can then choose to
persist to cloud with
- :meth:`Checkpoint.to_uri() <ray.train.Checkpoint.to_uri>`.
+ :meth:`Checkpoint.to_uri() <ray.air.Checkpoint.to_uri>`.
Example:
- >>> from ray.train import Checkpoint
+ >>> from ray.air import Checkpoint
>>> checkpoint = Checkpoint.from_uri("s3://some-bucket/some-location")
>>> assert checkpoint.uri == "s3://some-bucket/some-location"
>>> checkpoint = Checkpoint.from_dict({"data": 1})
@@ -330,7 +330,7 @@ def from_bytes(cls, data: bytes) -> "Checkpoint":
data: Data object containing pickled checkpoint data.
Returns:
- ray.train.Checkpoint: checkpoint object.
+ ray.air.Checkpoint: checkpoint object.
"""
bytes_data = pickle.loads(data)
if isinstance(bytes_data, dict):
@@ -359,7 +359,7 @@ def from_dict(cls, data: dict) -> "Checkpoint":
data: Dictionary containing checkpoint data.
Returns:
- ray.train.Checkpoint: checkpoint object.
+ ray.air.Checkpoint: checkpoint object.
"""
state = {}
if _METADATA_KEY in data:
@@ -454,7 +454,7 @@ def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint":
Checkpoint).
Returns:
- ray.train.Checkpoint: checkpoint object.
+ ray.air.Checkpoint: checkpoint object.
"""
state = {}

@@ -473,7 +473,7 @@ def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint":
@classmethod
@DeveloperAPI
def from_checkpoint(cls, other: "Checkpoint") -> "Checkpoint":
"""Create a checkpoint from a generic :class:`ray.train.Checkpoint`.
"""Create a checkpoint from a generic :class:`ray.air.Checkpoint`.
This method can be used to create a framework-specific checkpoint from a
generic :class:`Checkpoint` object.
@@ -714,7 +714,7 @@ def from_uri(cls, uri: str) -> "Checkpoint":
uri: Source location URI to read data from.
Returns:
- ray.train.Checkpoint: checkpoint object.
+ ray.air.Checkpoint: checkpoint object.
"""
state = {}
try:
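All nine hunks in this file make the same one-line docstring fix: the legacy class defined in python/ray/air/checkpoint.py now refers to itself as ray.air.Checkpoint, since under the new persistence mode ray.train.Checkpoint is the new directory-based class. A hedged sketch of the distinction the corrected docstrings draw (the temporary directory is illustrative):

    import tempfile

    from ray.air import Checkpoint as LegacyCheckpoint
    from ray.train import Checkpoint as NewCheckpoint

    # The legacy class supports in-memory dict payloads...
    legacy = LegacyCheckpoint.from_dict({"data": 123})

    # ...while the new class is constructed from a directory on disk.
    new = NewCheckpoint.from_directory(tempfile.mkdtemp())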
18 changes: 12 additions & 6 deletions python/ray/train/huggingface/accelerate/accelerate_trainer.py
@@ -111,6 +111,8 @@ def train_loop_per_worker():
Example:
.. testcode::
+ import os
+ import tempfile
import torch
import torch.nn as nn
@@ -184,16 +186,20 @@ def train_loop_per_worker():
if epoch % 20 == 0:
print(f"epoch: {epoch}/{num_epochs}, loss: {loss:.3f}")
+ # Create checkpoint.
+ base_model = accelerator.unwrap_model(model)
+ checkpoint_dir = tempfile.mkdtemp()
+ torch.save(
+     {"model_state_dict": base_model.state_dict()},
+     os.path.join(checkpoint_dir, "model.pt"),
+ )
+ checkpoint = Checkpoint.from_directory(checkpoint_dir)
# Report and record metrics, checkpoint model at end of each
# epoch
train.report(
{"loss": loss.item(), "epoch": epoch},
-     checkpoint=Checkpoint.from_dict(
-         dict(
-             epoch=epoch,
-             model=accelerator.unwrap_model(model).state_dict(),
-         )
-     ),
+     checkpoint=checkpoint
)
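A hedged companion sketch, not part of this diff: restoring the model from the directory-based checkpoint reported above. `result` is assumed to be the output of the trainer's fit(), `model` a module with the same architecture, and "model.pt" matches the filename used in the loop:

    import os

    import torch

    with result.checkpoint.as_directory() as checkpoint_dir:
        state = torch.load(os.path.join(checkpoint_dir, "model.pt"))
        model.load_state_dict(state["model_state_dict"])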
16 changes: 11 additions & 5 deletions python/ray/train/tensorflow/tensorflow_trainer.py
@@ -89,6 +89,8 @@ def train_loop_per_worker():
.. testcode::
+ import os
+ import tempfile
import tensorflow as tf
import ray
@@ -118,13 +120,17 @@ def train_loop_per_worker(config):
)
for epoch in range(config["num_epochs"]):
model.fit(tf_dataset)
- # You can also use ray.air.integrations.keras.Callback
- # for reporting and checkpointing instead of reporting manually.
+ # Create checkpoint.
+ checkpoint_dir = tempfile.mkdtemp()
+ model.save_weights(
+     os.path.join(checkpoint_dir, "my_checkpoint")
+ )
+ checkpoint = Checkpoint.from_directory(checkpoint_dir)
train.report(
{},
-     checkpoint=Checkpoint.from_dict(
-         dict(epoch=epoch, model=model.get_weights())
-     ),
+     checkpoint=checkpoint,
)
train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
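The matching hedged sketch for the TensorFlow example: restoring the weights saved with save_weights. Assumes `result` is the trainer's fit() output and `model` was rebuilt with the same architecture; "my_checkpoint" is the prefix used above:

    import os

    with result.checkpoint.as_directory() as checkpoint_dir:
        model.load_weights(os.path.join(checkpoint_dir, "my_checkpoint"))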
