[Train][Docs] Test code snippets in session.py (#37588)
Many of the code snippets in session.py aren't tested, and some of them are broken. This PR fixes and tests the examples.

---------

Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
bveeramani committed Jul 21, 2023
1 parent f34f087 commit 8ad889f
Showing 3 changed files with 148 additions and 83 deletions.
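For context, the pattern this PR adopts throughout is Sphinx's doctest extension: each example becomes a ``.. testcode::`` block, optionally paired with a ``.. testoutput::`` block that asserts on (or hides) the output. A minimal sketch, assuming ellipsis matching is enabled in the project's doctest configuration:

    .. testcode::

        print("training complete")

    .. testoutput::
        :hide:

        ...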
@@ -13,8 +13,6 @@ This functionality can be easily enabled by setting ``auto_transfer=True`` in :f
from torch.utils.data import DataLoader
from ray import train
data_loader = DataLoader(my_dataset, batch_size)
train_loader = train.torch.prepare_data_loader(
data_loader=train_loader, move_to_device=True, auto_transfer=True
5 changes: 5 additions & 0 deletions python/ray/train/BUILD
@@ -9,6 +9,8 @@ doctest(
# GPU tests
"tensorflow/tensorflow_trainer.py",
"huggingface/transformers/transformers_trainer.py",
"_internal/session.py",
"context.py"
]
),
size = "large",
@@ -19,7 +21,10 @@ doctest(
files = [
"tensorflow/tensorflow_trainer.py",
"huggingface/transformers/transformers_trainer.py",
"_internal/session.py",
"context.py"
],
size = "large",
tags = ["team:ml"],
gpu = True
)
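Reassembled from the hunk above, the GPU doctest target plausibly reads as follows (a sketch pieced together from the visible lines; surrounding attributes may differ):

    doctest(
        files = [
            "tensorflow/tensorflow_trainer.py",
            "huggingface/transformers/transformers_trainer.py",
            "_internal/session.py",
            "context.py"
        ],
        size = "large",
        tags = ["team:ml"],
        gpu = True
    )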
224 changes: 143 additions & 81 deletions python/ray/train/_internal/session.py
@@ -599,19 +599,22 @@ def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
longer be accessible to the caller after the report call.
Example:
.. code-block: python
.. testcode::
import tensorflow as tf
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
######## Using it in the *per worker* train loop (TrainSession) #######
def train_func():
model = build_model()
model = tf.keras.applications.resnet50.ResNet50()
model.save("my_model", overwrite=True)
session.report(
metrics={"foo": "bar"},
checkpoint=Checkpoint.from_directory(temp_dir.name)
checkpoint=Checkpoint.from_directory("my_model")
)
# Air guarantees by this point, you can safely write new stuff to
# "my_model" directory.
@@ -622,10 +625,15 @@ def train_func():
)
result = trainer.fit()
# If you navigate to result.checkpoint's path, you will find the
content of ``model.save()`` under it.
# content of ``model.save()`` under it.
# If you have `SyncConfig` configured, the content should also
# show up in the corresponding cloud storage path.
.. testoutput::
:hide:
...
Args:
metrics: The metrics you want to report.
checkpoint: The optional checkpoint you want to report.
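For a simpler, framework-agnostic view of the same API, a minimal sketch (the TorchTrainer choice and the metric name are illustrative, not part of this diff):

    from ray.air import session
    from ray.air.config import ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_func():
        # Metrics-only report; the checkpoint argument is optional.
        session.report(metrics={"loss": 0.1})

    trainer = TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=1))
    result = trainer.fit()
    print(result.metrics["loss"])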
Expand All @@ -643,21 +651,23 @@ def get_checkpoint() -> Optional[Checkpoint]:
Checkpoint object if the session is currently being resumed.
Otherwise, return None.
.. code-block:: python
.. testcode::
import tensorflow as tf
######## Using it in the *per worker* train loop (TrainSession) ######
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
def train_func():
ckpt = session.get_checkpoint()
if ckpt:
with ckpt.as_directory() as loaded_checkpoint_dir:
import tensorflow as tf
model = tf.keras.models.load_model(loaded_checkpoint_dir)
else:
model = build_model()
model = tf.keras.applications.resnet50.ResNet50()
model.save("my_model", overwrite=True)
session.report(
@@ -680,6 +690,11 @@ def train_func():
resume_from_checkpoint=result.checkpoint,
)
result2 = trainer2.fit()
.. testoutput::
:hide:
...
"""

return _get_session().loaded_checkpoint
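Pieced together from the fragments above, the save/resume round trip looks roughly like this (a sketch, not the docstring's exact example; ``Checkpoint.from_dict`` stands in for the TensorFlow model directory):

    from ray.air import session
    from ray.air.checkpoint import Checkpoint
    from ray.air.config import ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_func():
        ckpt = session.get_checkpoint()
        if ckpt:
            # Resuming: restore state written by an earlier run.
            state = ckpt.to_dict()
        session.report(
            metrics={"iter": 1},
            checkpoint=Checkpoint.from_dict({"iter": 1}),
        )

    trainer = TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=1))
    result = trainer.fit()

    # A second run resumes from the first run's checkpoint.
    trainer2 = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=1),
        resume_from_checkpoint=result.checkpoint,
    )
    trainer2.fit()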
@@ -720,18 +735,21 @@ def get_trial_dir() -> str:
If calling from a Train session, this will give the trial directory of its parent
Tune session.
.. code-block:: python
.. testcode::
from ray import tune
from ray.air import session
def train_func():
# Example:
# >>> session.get_trial_dir()
# ~/ray_results/<exp-name>/<trial-dir>
def train_func(config):
print(session.get_trial_dir())
tuner = tune.Tuner(train_func)
tuner.fit()
.. testoutput::
:options: +MOCK
/Users/root/ray_results/train_func_2023-07-19_15-01-37/train_func_d620c_00000_0_2023-07-19_15-01-40
"""
return _get_session().trial_dir

@@ -741,21 +759,30 @@ def train_func():
def get_world_size() -> int:
"""Get the current world size (i.e. total number of workers) for this run.
.. code-block:: python
.. testcode::
import time
import ray
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
NUM_WORKERS = 2
def train_loop_per_worker(config):
assert session.get_world_size() == 4
assert session.get_world_size() == NUM_WORKERS
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=1),
datasets={"train": train_dataset})
train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
trainer = TensorflowTrainer(
train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=NUM_WORKERS),
datasets={"train": train_dataset}
)
trainer.fit()
.. testoutput::
:hide:
...
"""
session = _get_session()
if not hasattr(session, "world_size"):
@@ -772,24 +799,29 @@ def train_loop_per_worker(config):
def get_world_rank() -> int:
"""Get the world rank of this worker.
.. code-block:: python
.. testcode::
import time
import ray
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
def train_loop_per_worker(config):
if session.get_world_rank() == 0:
print("Worker 0")
def train_loop_per_worker():
for iter in range(100):
time.sleep(1)
if session.get_world_rank() == 0:
print("Worker 0")
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=1),
datasets={"train": train_dataset})
train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
trainer = TensorflowTrainer(
train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=2),
datasets={"train": train_dataset}
)
trainer.fit()
.. testoutput::
:hide:
...
"""
session = _get_session()
if not hasattr(session, "world_rank"):
@@ -806,23 +838,32 @@ def train_loop_per_worker():
def get_local_rank() -> int:
"""Get the local rank of this worker (rank of the worker on its node).
.. code-block:: python
.. testcode::
import torch
import time
import ray
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer
def train_loop_per_worker():
def train_loop_per_worker(config):
if torch.cuda.is_available():
torch.cuda.set_device(session.get_local_rank())
...
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=1),
datasets={"train": train_dataset})
train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
trainer = TorchTrainer(
train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
datasets={"train": train_dataset}
)
trainer.fit()
.. testoutput::
:hide:
...
"""
session = _get_session()
if not hasattr(session, "local_rank"):
@@ -840,20 +881,28 @@ def get_local_world_size() -> int:
"""Get the local world size of this node (i.e. number of workers on this node).
Example:
>>> import ray
>>> from ray.air import session
>>> from ray.air.config import ScalingConfig
>>> from ray.train.torch import TorchTrainer
>>>
>>> def train_loop_per_worker():
... return session.get_local_world_size()
>>>
>>> train_dataset = ray.data.from_items(
... [{"x": x, "y": x + 1} for x in range(32)])
>>> trainer = TorchTrainer(train_loop_per_worker,
... scaling_config=ScalingConfig(num_workers=1),
... datasets={"train": train_dataset})
>>> trainer.fit() # doctest: +SKIP
.. testcode::
import ray
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer
def train_loop_per_worker():
print(session.get_local_world_size())
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TorchTrainer(train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=1),
datasets={"train": train_dataset})
trainer.fit()
.. testoutput::
:hide:
...
"""
session = _get_session()
if not hasattr(session, "local_world_size"):
@@ -871,20 +920,28 @@ def get_node_rank() -> int:
"""Get the rank of this node.
Example:
>>> import ray
>>> from ray.air import session
>>> from ray.air.config import ScalingConfig
>>> from ray.train.torch import TorchTrainer
>>>
>>> def train_loop_per_worker():
... return session.get_node_rank()
>>>
>>> train_dataset = ray.data.from_items(
... [{"x": x, "y": x + 1} for x in range(32)])
>>> trainer = TorchTrainer(train_loop_per_worker,
... scaling_config=ScalingConfig(num_workers=1),
... datasets={"train": train_dataset})
>>> trainer.fit() # doctest: +SKIP
.. testcode::
import ray
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer
def train_loop_per_worker():
print(session.get_node_rank())
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TorchTrainer(train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=1),
datasets={"train": train_dataset})
trainer.fit()
.. testoutput::
:hide:
...
"""
session = _get_session()
if not hasattr(session, "node_rank"):
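Since get_world_size, get_world_rank, get_local_rank, get_local_world_size, and get_node_rank all follow the same pattern, one loop can exercise them together; a minimal sketch (the worker count is illustrative):

    from ray.air import session
    from ray.air.config import ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_loop_per_worker():
        print(
            session.get_world_size(),        # total number of workers
            session.get_world_rank(),        # global rank of this worker
            session.get_local_rank(),        # rank of this worker on its node
            session.get_local_world_size(),  # number of workers on this node
            session.get_node_rank(),         # rank of this node
        )

    trainer = TorchTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
    )
    trainer.fit()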
@@ -907,29 +964,34 @@ def get_dataset_shard(
:meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the
appropriate framework-specific data type.
.. code-block:: python
.. testcode::
import ray
from ray import train
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer
def train_loop_per_worker():
model = Net()
for iter in range(100):
def train_loop_per_worker(config):
...
for epoch in range(2):
# Trainer will automatically handle sharding.
data_shard = session.get_dataset_shard("train")
for batch in data_shard.iter_torch_batches():
# ...
return model
...
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TorchTrainer(train_loop_per_worker,
train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
trainer = TorchTrainer(
train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=2),
datasets={"train": train_dataset})
datasets={"train": train_dataset}
)
trainer.fit()
.. testoutput::
:hide:
...
Args:
dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
specifies which dataset shard to return.
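The docstring names :meth:`~ray.data.DataIterator.to_tf` as the TensorFlow-side counterpart to ``iter_torch_batches``; a hedged sketch of that path (the iris column names are assumptions about the example dataset):

    import ray
    from ray.air import session
    from ray.air.config import ScalingConfig
    from ray.train.tensorflow import TensorflowTrainer

    def train_loop_per_worker(config):
        data_shard = session.get_dataset_shard("train")
        # Convert the shard to a tf.data.Dataset; column names assumed.
        tf_dataset = data_shard.to_tf(
            feature_columns="sepal length (cm)",
            label_columns="target",
            batch_size=4,
        )
        for features, labels in tf_dataset:
            ...

    train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
    trainer = TensorflowTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
        datasets={"train": train_dataset},
    )
    trainer.fit()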
