[train] update Train API references & annotations (#39294)
Signed-off-by: Matthew Deng <matt@anyscale.com>
matthewdeng committed Sep 8, 2023
1 parent b6edccf commit fb4dd92
Showing 20 changed files with 136 additions and 138 deletions.
103 changes: 28 additions & 75 deletions doc/source/train/api/api.rst
@@ -4,33 +4,24 @@
Ray Train API
=============

This page covers framework-specific integrations with Ray Train and Ray Train Developer APIs.

.. _train-integration-api:
.. _train-framework-specific-ckpts:

.. currentmodule:: ray

Ray Train Integrations
----------------------

.. _train-pytorch-integration:

PyTorch Ecosystem
~~~~~~~~~~~~~~~~~

Scale out your PyTorch, Lightning, and Hugging Face code with Ray TorchTrainer.
-----------------

.. autosummary::
:toctree: doc/

~train.torch.TorchTrainer
~train.torch.TorchConfig
~train.torch.TorchCheckpoint
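
A minimal sketch of how these pieces fit together, assuming a placeholder
``train_func`` and a small two-worker CPU setup:

.. code-block:: python

    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_func():
        # Per-worker training loop: build the model, iterate over data,
        # and call ray.train.report(...) with metrics and checkpoints.
        ...

    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
    )
    result = trainer.fit()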

.. _train-pytorch-integration:

PyTorch
*******
~~~~~~~

.. autosummary::
:toctree: doc/
@@ -43,7 +34,7 @@ PyTorch
.. _train-lightning-integration:

PyTorch Lightning
*****************
~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: doc/
@@ -55,49 +46,20 @@ PyTorch Lightning
~train.lightning.RayDeepSpeedStrategy
~train.lightning.RayTrainReportCallback
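
A rough sketch of wiring these utilities into a Lightning ``Trainer`` inside a
Ray Train worker function; ``MyLightningModule`` is a placeholder, and
``prepare_trainer`` (assumed to be exported from ``ray.train.lightning``) is used
to finalize the setup:

.. code-block:: python

    import lightning.pytorch as pl

    from ray.train.lightning import (
        RayDDPStrategy,
        RayLightningEnvironment,
        RayTrainReportCallback,
        prepare_trainer,
    )

    def train_func():
        model = MyLightningModule()  # placeholder LightningModule
        trainer = pl.Trainer(
            devices="auto",
            accelerator="auto",
            strategy=RayDDPStrategy(),
            plugins=[RayLightningEnvironment()],
            callbacks=[RayTrainReportCallback()],
        )
        trainer = prepare_trainer(trainer)
        trainer.fit(model)

    # train_func is then passed to a TorchTrainer (see the PyTorch Ecosystem section above).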

.. note::

We will deprecate `LightningTrainer`, `LightningConfigBuilder`,
`LightningCheckpoint`, and `LightningPredictor` in Ray 2.8. Please
refer to the :ref:`migration guide <lightning-trainer-migration-guide>` for more info.

.. autosummary::
:toctree: doc/

~train.lightning.LightningTrainer
~train.lightning.LightningConfigBuilder
~train.lightning.LightningCheckpoint
~train.lightning.LightningPredictor

.. _train-transformers-integration:

Hugging Face Transformers
*************************
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: doc/

~train.huggingface.transformers.prepare_trainer
~train.huggingface.transformers.RayTrainReportCallback
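
A sketch of the intended usage inside a Ray Train worker function;
``build_hf_trainer`` is a hypothetical helper standing in for the usual
``transformers.Trainer`` construction:

.. code-block:: python

    import ray.train.huggingface.transformers as ray_transformers

    def train_func():
        # Construct a transformers.Trainer as usual (model, args, datasets, ...).
        trainer = build_hf_trainer()  # hypothetical helper

        # Report metrics and checkpoints to Ray Train, then patch the Trainer
        # for distributed execution.
        trainer.add_callback(ray_transformers.RayTrainReportCallback())
        trainer = ray_transformers.prepare_trainer(trainer)
        trainer.train()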

.. note::

We will deprecate `TransformersTrainer` and `TransformersCheckpoint` in Ray 2.8. Please
refer to the :ref:`migration guide <transformers-trainer-migration-guide>` for more info.

.. autosummary::
:toctree: doc/

~train.huggingface.TransformersTrainer
~train.huggingface.TransformersCheckpoint

Hugging Face Accelerate
***********************

.. autosummary::
:toctree: doc/

~train.huggingface.AccelerateTrainer
More Frameworks
---------------

Tensorflow/Keras
~~~~~~~~~~~~~~~~
@@ -107,21 +69,8 @@ Tensorflow/Keras

~train.tensorflow.TensorflowTrainer
~train.tensorflow.TensorflowConfig
~train.tensorflow.TensorflowCheckpoint


Tensorflow/Keras Training Loop Utilities
****************************************

.. autosummary::
:toctree: doc/

~train.tensorflow.prepare_dataset_shard

.. autosummary::

~air.integrations.keras.ReportCheckpointCallback

~train.tensorflow.keras.ReportCheckpointCallback

Horovod
~~~~~~~
@@ -140,7 +89,6 @@ XGBoost
:toctree: doc/

~train.xgboost.XGBoostTrainer
~train.xgboost.XGBoostCheckpoint


LightGBM
@@ -150,32 +98,42 @@ LightGBM
:toctree: doc/

~train.lightgbm.LightGBMTrainer
~train.lightgbm.LightGBMCheckpoint
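
A small sketch of the trainer-based API these entries refer to, using a toy
in-memory dataset; ``LightGBMTrainer`` follows the same pattern:

.. code-block:: python

    import ray
    from ray.train import ScalingConfig
    from ray.train.xgboost import XGBoostTrainer

    # Toy dataset with a binary label column "y".
    train_ds = ray.data.from_items([{"x": float(i), "y": i % 2} for i in range(100)])

    trainer = XGBoostTrainer(
        label_column="y",
        params={"objective": "binary:logistic"},
        scaling_config=ScalingConfig(num_workers=2),
        datasets={"train": train_ds},
    )
    result = trainer.fit()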


.. _ray-train-configs-api:

Ray Train Config
----------------
Ray Train Configuration
-----------------------

.. autosummary::
:toctree: doc/

~train.ScalingConfig
~train.RunConfig
~train.CheckpointConfig
~train.FailureConfig
~train.DataConfig
~train.FailureConfig
~train.RunConfig
~train.ScalingConfig
~train.SyncConfig
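
A sketch of how these configuration objects are typically combined and handed
to a trainer (the trainer and training function themselves are elided):

.. code-block:: python

    from ray.train import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig

    scaling_config = ScalingConfig(num_workers=4, use_gpu=True)
    run_config = RunConfig(
        name="example-run",  # placeholder experiment name
        checkpoint_config=CheckpointConfig(num_to_keep=2),
        failure_config=FailureConfig(max_failures=2),
    )
    # Both objects are then passed to a trainer, for example:
    # TorchTrainer(train_func, scaling_config=scaling_config, run_config=run_config)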

.. _train-loop-api:

Ray Train Loop
--------------
Ray Train Utilities
-------------------

**Classes**

.. autosummary::
:toctree: doc/

~train.Checkpoint
~train.context.TrainContext

**Functions**

.. autosummary::
:toctree: doc/

~train.get_checkpoint
~train.get_context
~train.get_dataset_shard
~train.report
Expand All @@ -190,14 +148,9 @@ Ray Train Output

~train.Result

.. autosummary::
:toctree: doc/

~train.Checkpoint


Ray Train Base Classes (Developer APIs)
---------------------------------------
Ray Train Developer APIs
------------------------

.. _train-base-trainer:

4 changes: 2 additions & 2 deletions doc/source/train/distributed-tensorflow-keras.rst
@@ -191,11 +191,11 @@ to Ray Train. This reporting logs the results to the console output and appends
local log files. The logging also triggers :ref:`checkpoint bookkeeping <train-dl-configure-checkpoints>`.

The easiest way to report your results with Keras is by using the
:class:`~air.integrations.keras.ReportCheckpointCallback`:
:class:`~ray.train.tensorflow.keras.ReportCheckpointCallback`:

.. code-block:: python
from ray.air.integrations.keras import ReportCheckpointCallback
from ray.train.tensorflow.keras import ReportCheckpointCallback
def train_func(config: dict):
# ...
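
A fuller sketch of the callback in use, assuming ``model`` and ``dataset`` are a
compiled Keras model and a dataset prepared earlier in the function:

.. code-block:: python

    from ray.train.tensorflow.keras import ReportCheckpointCallback

    def train_func(config: dict):
        # ... build and compile the Keras `model` and prepare `dataset` ...
        model.fit(
            dataset,
            epochs=config.get("epochs", 3),
            # Reports metrics and saves a checkpoint at the end of each epoch.
            callbacks=[ReportCheckpointCallback()],
        )
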
2 changes: 1 addition & 1 deletion doc/source/tune/api/syncing.rst
@@ -12,6 +12,6 @@ Tune Syncing Configuration
--------------------------

.. autosummary::
:toctree: doc/

ray.train.SyncConfig
:noindex:
49 changes: 21 additions & 28 deletions python/ray/air/config.py
@@ -90,7 +90,7 @@ def _repr_dataclass(obj, *, default_values: Optional[Dict[str, Any]] = None) ->


@dataclass
@PublicAPI(stability="beta")
@PublicAPI(stability="stable")
class ScalingConfig:
"""Configuration for scaling training.
@@ -112,13 +112,6 @@ class ScalingConfig:
placement_strategy: The placement strategy to use for the
placement group of the Ray actors. See :ref:`Placement Group
Strategies <pgroup-strategy>` for the possible options.
_max_cpu_fraction_per_node: [Experimental] The max fraction of CPUs per node
that Train will use for scheduling training actors. The remaining CPUs
can be used for dataset tasks. It is highly recommended that you set this
to less than 1.0 (e.g., 0.8) when passing datasets to trainers, to avoid
hangs / CPU starvation of dataset tasks. Warning: this feature is
experimental and is not recommended for use with autoscaling (scale-up will
not trigger properly).
"""

# If adding new attributes here, please also update
@@ -523,7 +516,7 @@ def _merge(self, other: "DatasetConfig") -> "DatasetConfig":


@dataclass
@PublicAPI(stability="beta")
@PublicAPI(stability="stable")
class FailureConfig:
"""Configuration related to failure handling of each training/tuning run.
@@ -574,7 +567,7 @@ def _repr_html_(self):


@dataclass
@PublicAPI(stability="beta")
@PublicAPI(stability="stable")
class CheckpointConfig:
"""Configurable parameters for defining the checkpointing strategy.
@@ -723,7 +716,7 @@ def _tune_legacy_checkpoint_score_attr(self) -> Optional[str]:


@dataclass
@PublicAPI(stability="beta")
@PublicAPI(stability="stable")
class RunConfig:
"""Runtime configuration for training and tuning runs.
@@ -734,32 +727,32 @@ class RunConfig:
Args:
name: Name of the trial or experiment. If not provided, will be deduced
from the Trainable.
storage_path: Path to store results at. Can be a local directory or
storage_path: [Beta] Path to store results at. Can be a local directory or
a destination on cloud storage. If Ray storage is set up,
defaults to the storage location. Otherwise, this defaults to
the local ``~/ray_results`` directory.
failure_config: Failure mode configuration.
checkpoint_config: Checkpointing configuration.
sync_config: Configuration object for syncing. See train.SyncConfig.
verbose: 0, 1, or 2. Verbosity mode.
0 = silent, 1 = default, 2 = verbose. Defaults to 1.
If the ``RAY_AIR_NEW_OUTPUT=1`` environment variable is set,
uses the old verbosity settings:
0 = silent, 1 = only status updates, 2 = status and brief
results, 3 = status and detailed results.
stop: Stop conditions to consider. Refer to ray.tune.stopper.Stopper
for more info. Stoppers should be serializable.
callbacks: Callbacks to invoke.
callbacks: [DeveloperAPI] Callbacks to invoke.
Refer to ray.tune.callback.Callback for more info.
Callbacks should be serializable.
Currently only stateless callbacks are supported for resumed runs.
(any state of the callback will not be checkpointed by Tune
and thus will not take effect in resumed runs).
failure_config: Failure mode configuration.
sync_config: Configuration object for syncing. See train.SyncConfig.
checkpoint_config: Checkpointing configuration.
progress_reporter: Progress reporter for reporting
progress_reporter: [DeveloperAPI] Progress reporter for reporting
intermediate experiment progress. Defaults to CLIReporter if
running in command-line, or JupyterNotebookReporter if running in
a Jupyter notebook.
verbose: 0, 1, or 2. Verbosity mode.
0 = silent, 1 = default, 2 = verbose. Defaults to 1.
If the ``RAY_AIR_NEW_OUTPUT=1`` environment variable is set,
uses the old verbosity settings:
0 = silent, 1 = only status updates, 2 = status and brief
results, 3 = status and detailed results.
log_to_file: Log stdout and stderr to files in
log_to_file: [DeveloperAPI] Log stdout and stderr to files in
trial directories. If this is `False` (default), no files
are written. If `True`, outputs are written to `trialdir/stdout`
and `trialdir/stderr`, respectively. If this is a single string,
@@ -773,13 +766,13 @@ class RunConfig:
name: Optional[str] = None
storage_path: Optional[str] = None
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None
callbacks: Optional[List["Callback"]] = None
stop: Optional[Union[Mapping, "Stopper", Callable[[str, Mapping], bool]]] = None
failure_config: Optional[FailureConfig] = None
sync_config: Optional["SyncConfig"] = None
checkpoint_config: Optional[CheckpointConfig] = None
progress_reporter: Optional["ProgressReporter"] = None
sync_config: Optional["SyncConfig"] = None
verbose: Optional[Union[int, "AirVerbosity", "Verbosity"]] = None
stop: Optional[Union[Mapping, "Stopper", Callable[[str, Mapping], bool]]] = None
callbacks: Optional[List["Callback"]] = None
progress_reporter: Optional["ProgressReporter"] = None
log_to_file: Union[bool, str, Tuple[str, str]] = False

# Deprecated
2 changes: 1 addition & 1 deletion python/ray/air/result.py
@@ -22,7 +22,7 @@
logger = logging.getLogger(__name__)


@PublicAPI(stability="beta")
@PublicAPI(stability="stable")
@dataclass
class Result:
"""The final result of a ML training run or a Tune trial.
15 changes: 15 additions & 0 deletions python/ray/train/__init__.py
@@ -52,3 +52,18 @@
"TrainingIterator",
"TRAIN_DATASET_KEY",
]

get_checkpoint.__module__ = "ray.train"
get_context.__module__ = "ray.train"
get_dataset_shard.__module__ = "ray.train"
report.__module__ = "ray.train"
BackendConfig.__module__ = "ray.train"
Checkpoint.__module__ = "ray.train"
CheckpointConfig.__module__ = "ray.train"
DataConfig.__module__ = "ray.train"
FailureConfig.__module__ = "ray.train"
Result.__module__ = "ray.train"
RunConfig.__module__ = "ray.train"
ScalingConfig.__module__ = "ray.train"
SyncConfig.__module__ = "ray.train"
TrainingIterator.__module__ = "ray.train"
2 changes: 1 addition & 1 deletion python/ray/train/_internal/data_config.py
@@ -24,7 +24,7 @@
from typing_extensions import Literal


@PublicAPI
@PublicAPI(stability="stable")
class DataConfig:
"""Class responsible for configuring Train dataset preprocessing.
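
A brief sketch, assuming a trainer that receives both a "train" and an "eval"
dataset and should only shard the training data across workers:

.. code-block:: python

    from ray.train import DataConfig

    # Split only the "train" dataset across workers; any other datasets are
    # provided in full to every worker.
    data_config = DataConfig(datasets_to_split=["train"])
    # Passed to a trainer via, for example, TorchTrainer(..., dataset_config=data_config).
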
6 changes: 3 additions & 3 deletions python/ray/train/_internal/session.py
@@ -846,7 +846,7 @@ def wrapper(*args, **kwargs):
return inner


@PublicAPI(stability="beta")
@PublicAPI(stability="stable")
@_warn_session_misuse()
def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
"""Report metrics and optionally save a checkpoint.
@@ -907,7 +907,7 @@ def train_func():
_get_session().report(metrics, checkpoint=checkpoint)


@PublicAPI(stability="beta")
@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_checkpoint() -> Optional[Checkpoint]:
"""Access the session's last checkpoint to resume from if applicable.
@@ -1223,7 +1223,7 @@ def train_loop_per_worker():
return session.node_rank


@PublicAPI(stability="beta")
@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_dataset_shard(
dataset_name: Optional[str] = None,
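
A sketch of how these now-stable utilities are typically used together inside a
worker training function (model and optimizer code is elided):

.. code-block:: python

    import ray.train

    def train_loop_per_worker(config: dict):
        # Resume from the latest checkpoint if the run was restarted.
        checkpoint = ray.train.get_checkpoint()
        if checkpoint:
            with checkpoint.as_directory() as ckpt_dir:
                ...  # restore model/optimizer state from ckpt_dir

        # Iterate over this worker's shard of the "train" dataset.
        train_shard = ray.train.get_dataset_shard("train")
        for epoch in range(config.get("num_epochs", 1)):
            for batch in train_shard.iter_batches(batch_size=32):
                ...  # one optimization step
            # Report metrics (and optionally a checkpoint) back to Ray Train.
            ray.train.report({"epoch": epoch})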