diff --git a/doc/source/train/api/api.rst b/doc/source/train/api/api.rst
index 17807fe9494fe..59dc83c5a5142 100644
--- a/doc/source/train/api/api.rst
+++ b/doc/source/train/api/api.rst
@@ -4,33 +4,24 @@
 Ray Train API
 =============
 
-This page covers framework specific integrations with Ray Train and Ray Train Developer APIs.
-
 .. _train-integration-api:
 .. _train-framework-specific-ckpts:
 
 .. currentmodule:: ray
 
-Ray Train Integrations
-----------------------
-
-.. _train-pytorch-integration:
-
 PyTorch Ecosystem
-~~~~~~~~~~~~~~~~~
-
-Scale out your PyTorch, Lightning, Hugging Face code with Ray TorchTrainer.
+-----------------
 
 .. autosummary::
     :toctree: doc/
 
     ~train.torch.TorchTrainer
     ~train.torch.TorchConfig
-    ~train.torch.TorchCheckpoint
 
+.. _train-pytorch-integration:
 
 PyTorch
-*******
+~~~~~~~
 
 .. autosummary::
     :toctree: doc/
@@ -43,7 +34,7 @@ PyTorch
 .. _train-lightning-integration:
 
 PyTorch Lightning
-*****************
+~~~~~~~~~~~~~~~~~
 
 .. autosummary::
     :toctree: doc/
@@ -55,24 +46,10 @@ PyTorch Lightning
     ~train.lightning.RayDeepSpeedStrategy
     ~train.lightning.RayTrainReportCallback
 
-.. note::
-
-    We will deprecate `LightningTrainer`, `LightningConfigBuilder`,
-    `LightningCheckpoint`, and `LightningPredictor` in Ray 2.8. Please
-    refer to the :ref:`migration guide ` for more info.
-
-.. autosummary::
-    :toctree: doc/
-
-    ~train.lightning.LightningTrainer
-    ~train.lightning.LightningConfigBuilder
-    ~train.lightning.LightningCheckpoint
-    ~train.lightning.LightningPredictor
-
 .. _train-transformers-integration:
 
 Hugging Face Transformers
-*************************
+~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autosummary::
     :toctree: doc/
@@ -80,24 +57,9 @@ Hugging Face Transformers
     ~train.huggingface.transformers.prepare_trainer
     ~train.huggingface.transformers.RayTrainReportCallback
 
-.. note::
-
-    We will deprecate `TransformersTrainer`, `TransformersCheckpoint` in Ray 2.8. Please
-    refer to the :ref:`migration guide ` for more info.
-
-.. autosummary::
-    :toctree: doc/
-
-    ~train.huggingface.TransformersTrainer
-    ~train.huggingface.TransformersCheckpoint
 
-Hugging Face Accelerate
-***********************
-
-.. autosummary::
-    :toctree: doc/
-
-    ~train.huggingface.AccelerateTrainer
+More Frameworks
+---------------
 
 Tensorflow/Keras
 ~~~~~~~~~~~~~~~~
@@ -107,21 +69,8 @@ Tensorflow/Keras
 
     ~train.tensorflow.TensorflowTrainer
     ~train.tensorflow.TensorflowConfig
-    ~train.tensorflow.TensorflowCheckpoint
-
-
-Tensorflow/Keras Training Loop Utilities
-****************************************
-
-.. autosummary::
-    :toctree: doc/
-
-    ~train.tensorflow.prepare_dataset_shard
-
-.. autosummary::
-
-    ~air.integrations.keras.ReportCheckpointCallback
-
+    ~train.tensorflow.keras.ReportCheckpointCallback
 
 Horovod
 ~~~~~~~
@@ -140,7 +89,6 @@ XGBoost
     :toctree: doc/
 
     ~train.xgboost.XGBoostTrainer
-    ~train.xgboost.XGBoostCheckpoint
 
 
 LightGBM
@@ -150,32 +98,42 @@ LightGBM
     :toctree: doc/
 
     ~train.lightgbm.LightGBMTrainer
-    ~train.lightgbm.LightGBMCheckpoint
 
 
 .. _ray-train-configs-api:
 
-Ray Train Config
-----------------
+Ray Train Configuration
+-----------------------
 
 .. autosummary::
     :toctree: doc/
 
-    ~train.ScalingConfig
-    ~train.RunConfig
     ~train.CheckpointConfig
-    ~train.FailureConfig
     ~train.DataConfig
+    ~train.FailureConfig
+    ~train.RunConfig
+    ~train.ScalingConfig
+    ~train.SyncConfig
 
 .. _train-loop-api:
 
-Ray Train Loop
---------------
+Ray Train Utilities
+-------------------
+
+**Classes**
 
 .. autosummary::
     :toctree: doc/
 
+    ~train.Checkpoint
     ~train.context.TrainContext
+
+**Functions**
+
+.. autosummary::
+    :toctree: doc/
+
+    ~train.get_checkpoint
     ~train.get_context
     ~train.get_dataset_shard
     ~train.report
@@ -190,14 +148,9 @@ Ray Train Output
 
     ~train.Result
 
-.. autosummary::
-    :toctree: doc/
-
-    ~train.Checkpoint
-
 
-Ray Train Base Classes (Developer APIs)
----------------------------------------
+Ray Train Developer APIs
+------------------------
 
 .. _train-base-trainer:
 
diff --git a/doc/source/train/distributed-tensorflow-keras.rst b/doc/source/train/distributed-tensorflow-keras.rst
index 485b4b7212ac7..e13e780ef94aa 100644
--- a/doc/source/train/distributed-tensorflow-keras.rst
+++ b/doc/source/train/distributed-tensorflow-keras.rst
@@ -191,11 +191,11 @@ to Ray Train. This reporting logs the results to the console output and appends
 local log files. The logging also triggers :ref:`checkpoint bookkeeping
 `. The easiest way to report your results with Keras is by using the
-:class:`~air.integrations.keras.ReportCheckpointCallback`:
+:class:`~ray.train.tensorflow.keras.ReportCheckpointCallback`:
 
 .. code-block:: python
 
-    from ray.air.integrations.keras import ReportCheckpointCallback
+    from ray.train.tensorflow.keras import ReportCheckpointCallback
 
     def train_func(config: dict):
         # ...
 
diff --git a/doc/source/tune/api/syncing.rst b/doc/source/tune/api/syncing.rst
index 36517271c65ed..7233cb0277de4 100644
--- a/doc/source/tune/api/syncing.rst
+++ b/doc/source/tune/api/syncing.rst
@@ -12,6 +12,6 @@ Tune Syncing Configuration
 --------------------------
 
 .. autosummary::
-    :toctree: doc/
 
     ray.train.SyncConfig
+    :noindex:
diff --git a/python/ray/air/config.py b/python/ray/air/config.py
index cea3012c02a71..c0b04a921ef84 100644
--- a/python/ray/air/config.py
+++ b/python/ray/air/config.py
@@ -90,7 +90,7 @@ def _repr_dataclass(obj, *, default_values: Optional[Dict[str, Any]] = None) ->
 
 
 @dataclass
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 class ScalingConfig:
     """Configuration for scaling training.
 
@@ -112,13 +112,6 @@ class ScalingConfig:
         placement_strategy: The placement strategy to use for the
             placement group of the Ray actors. See :ref:`Placement Group
             Strategies ` for the possible options.
-        _max_cpu_fraction_per_node: [Experimental] The max fraction of CPUs per node
-            that Train will use for scheduling training actors. The remaining CPUs
-            can be used for dataset tasks. It is highly recommended that you set this
-            to less than 1.0 (e.g., 0.8) when passing datasets to trainers, to avoid
-            hangs / CPU starvation of dataset tasks. Warning: this feature is
-            experimental and is not recommended for use with autoscaling (scale-up will
-            not trigger properly).
     """
 
     # If adding new attributes here, please also update
@@ -523,7 +516,7 @@ def _merge(self, other: "DatasetConfig") -> "DatasetConfig":
 
 
 @dataclass
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 class FailureConfig:
     """Configuration related to failure handling of each training/tuning run.
 
@@ -574,7 +567,7 @@ def _repr_html_(self):
 
 
 @dataclass
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 class CheckpointConfig:
     """Configurable parameters for defining the checkpointing strategy.
 
@@ -723,7 +716,7 @@ def _tune_legacy_checkpoint_score_attr(self) -> Optional[str]:
 
 
 @dataclass
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 class RunConfig:
     """Runtime configuration for training and tuning runs.
 
@@ -734,32 +727,32 @@ class RunConfig:
     Args:
         name: Name of the trial or experiment. If not provided,
             will be deduced from the Trainable.
-        storage_path: Path to store results at. Can be a local directory or
+        storage_path: [Beta] Path to store results at. Can be a local directory or
             a destination on cloud storage. If Ray storage is set up,
             defaults to the storage location. Otherwise, this defaults to
             the local ``~/ray_results`` directory.
+        failure_config: Failure mode configuration.
+        checkpoint_config: Checkpointing configuration.
+        sync_config: Configuration object for syncing. See train.SyncConfig.
+        verbose: 0, 1, or 2. Verbosity mode.
+            0 = silent, 1 = default, 2 = verbose. Defaults to 1.
+            If the ``RAY_AIR_NEW_OUTPUT=1`` environment variable is set,
+            uses the old verbosity settings:
+            0 = silent, 1 = only status updates, 2 = status and brief
+            results, 3 = status and detailed results.
         stop: Stop conditions to consider. Refer to ray.tune.stopper.Stopper
             for more info. Stoppers should be serializable.
-        callbacks: Callbacks to invoke.
+        callbacks: [DeveloperAPI] Callbacks to invoke.
             Refer to ray.tune.callback.Callback for more info.
             Callbacks should be serializable.
             Currently only stateless callbacks are supported for resumed runs.
             (any state of the callback will not be checkpointed by Tune
             and thus will not take effect in resumed runs).
-        failure_config: Failure mode configuration.
-        sync_config: Configuration object for syncing. See train.SyncConfig.
-        checkpoint_config: Checkpointing configuration.
-        progress_reporter: Progress reporter for reporting
+        progress_reporter: [DeveloperAPI] Progress reporter for reporting
            intermediate experiment progress. Defaults to CLIReporter if
            running in command-line, or JupyterNotebookReporter if running in
            a Jupyter notebook.
-        verbose: 0, 1, or 2. Verbosity mode.
-            0 = silent, 1 = default, 2 = verbose. Defaults to 1.
-            If the ``RAY_AIR_NEW_OUTPUT=1`` environment variable is set,
-            uses the old verbosity settings:
-            0 = silent, 1 = only status updates, 2 = status and brief
-            results, 3 = status and detailed results.
-        log_to_file: Log stdout and stderr to files in
+        log_to_file: [DeveloperAPI] Log stdout and stderr to files in
            trial directories. If this is `False` (default), no files
            are written. If `true`, outputs are written to `trialdir/stdout`
            and `trialdir/stderr`, respectively. If this is a single string,
@@ -773,13 +766,13 @@ class RunConfig:
     name: Optional[str] = None
     storage_path: Optional[str] = None
     storage_filesystem: Optional[pyarrow.fs.FileSystem] = None
-    callbacks: Optional[List["Callback"]] = None
-    stop: Optional[Union[Mapping, "Stopper", Callable[[str, Mapping], bool]]] = None
     failure_config: Optional[FailureConfig] = None
-    sync_config: Optional["SyncConfig"] = None
     checkpoint_config: Optional[CheckpointConfig] = None
-    progress_reporter: Optional["ProgressReporter"] = None
+    sync_config: Optional["SyncConfig"] = None
     verbose: Optional[Union[int, "AirVerbosity", "Verbosity"]] = None
+    stop: Optional[Union[Mapping, "Stopper", Callable[[str, Mapping], bool]]] = None
+    callbacks: Optional[List["Callback"]] = None
+    progress_reporter: Optional["ProgressReporter"] = None
     log_to_file: Union[bool, str, Tuple[str, str]] = False
 
     # Deprecated
diff --git a/python/ray/air/result.py b/python/ray/air/result.py
index 43d19793839ce..cecf235ba947f 100644
--- a/python/ray/air/result.py
+++ b/python/ray/air/result.py
@@ -22,7 +22,7 @@
 logger = logging.getLogger(__name__)
 
 
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 @dataclass
 class Result:
     """The final result of a ML training run or a Tune trial.
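
Editor's note (not part of the patch): the configuration classes promoted to ``stable`` above are composed together when constructing a trainer. A minimal sketch, assuming the Ray 2.7 ``ray.train`` API surface documented in this diff; the run name, storage path, and training-loop body are illustrative:

.. code-block:: python

    from ray.train import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig
    from ray.train.torch import TorchTrainer


    def train_loop_per_worker(config: dict):
        ...  # model setup, training loop, ray.train.report(...), etc.


    trainer = TorchTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
        run_config=RunConfig(
            name="example-run",            # illustrative run name
            storage_path="~/ray_results",  # default local results directory
            checkpoint_config=CheckpointConfig(num_to_keep=2),
            failure_config=FailureConfig(max_failures=1),
        ),
    )
    result = trainer.fit()
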
diff --git a/python/ray/train/__init__.py b/python/ray/train/__init__.py
index 85d2c85553203..bef2f190a02eb 100644
--- a/python/ray/train/__init__.py
+++ b/python/ray/train/__init__.py
@@ -52,3 +52,18 @@
     "TrainingIterator",
     "TRAIN_DATASET_KEY",
 ]
+
+get_checkpoint.__module__ = "ray.train"
+get_context.__module__ = "ray.train"
+get_dataset_shard.__module__ = "ray.train"
+report.__module__ = "ray.train"
+BackendConfig.__module__ = "ray.train"
+Checkpoint.__module__ = "ray.train"
+CheckpointConfig.__module__ = "ray.train"
+DataConfig.__module__ = "ray.train"
+FailureConfig.__module__ = "ray.train"
+Result.__module__ = "ray.train"
+RunConfig.__module__ = "ray.train"
+ScalingConfig.__module__ = "ray.train"
+SyncConfig.__module__ = "ray.train"
+TrainingIterator.__module__ = "ray.train"
diff --git a/python/ray/train/_internal/data_config.py b/python/ray/train/_internal/data_config.py
index f1721a42a244c..59a6b19d8d7b0 100644
--- a/python/ray/train/_internal/data_config.py
+++ b/python/ray/train/_internal/data_config.py
@@ -24,7 +24,7 @@
     from typing_extensions import Literal
 
 
-@PublicAPI
+@PublicAPI(stability="stable")
 class DataConfig:
     """Class responsible for configuring Train dataset preprocessing.
 
diff --git a/python/ray/train/_internal/session.py b/python/ray/train/_internal/session.py
index a7fcf91966434..20b27d82d5332 100644
--- a/python/ray/train/_internal/session.py
+++ b/python/ray/train/_internal/session.py
@@ -846,7 +846,7 @@ def wrapper(*args, **kwargs):
     return inner
 
 
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 @_warn_session_misuse()
 def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
     """Report metrics and optionally save a checkpoint.
 
@@ -907,7 +907,7 @@ def train_func():
     _get_session().report(metrics, checkpoint=checkpoint)
 
 
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 @_warn_session_misuse()
 def get_checkpoint() -> Optional[Checkpoint]:
     """Access the session's last checkpoint to resume from if applicable.
 
@@ -1223,7 +1223,7 @@ def train_loop_per_worker():
     return session.node_rank
 
 
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 @_warn_session_misuse()
 def get_dataset_shard(
     dataset_name: Optional[str] = None,
diff --git a/python/ray/train/_internal/syncer.py b/python/ray/train/_internal/syncer.py
index 2ee1b43b80de9..cac5e86de8a00 100644
--- a/python/ray/train/_internal/syncer.py
+++ b/python/ray/train/_internal/syncer.py
@@ -40,7 +40,7 @@
 DEFAULT_SYNC_TIMEOUT = 1800
 
 
-@PublicAPI
+@PublicAPI(stability="stable")
 @dataclass
 class SyncConfig:
     """Configuration object for Train/Tune file syncing to `RunConfig(storage_path)`.
 
@@ -68,7 +68,7 @@ class SyncConfig:
         sync_timeout: Maximum time in seconds to wait for a sync process
             to finish running. A sync operation will run for at most this long
            before raising a `TimeoutError`. Defaults to 30 minutes.
-        sync_artifacts: Whether or not to sync artifacts that are saved to the
+        sync_artifacts: [Beta] Whether or not to sync artifacts that are saved to the
            trial directory (accessed via `train.get_context().get_trial_dir()`)
            to the persistent storage configured via `train.RunConfig(storage_path)`.
            The trial or remote worker will try to launch an artifact syncing
@@ -81,12 +81,12 @@
            Defaults to True.
""" - upload_dir: Optional[str] = _DEPRECATED_VALUE - syncer: Optional[Union[str, "Syncer"]] = _DEPRECATED_VALUE sync_period: int = DEFAULT_SYNC_PERIOD sync_timeout: int = DEFAULT_SYNC_TIMEOUT sync_artifacts: bool = False sync_artifacts_on_checkpoint: bool = True + upload_dir: Optional[str] = _DEPRECATED_VALUE + syncer: Optional[Union[str, "Syncer"]] = _DEPRECATED_VALUE sync_on_checkpoint: bool = _DEPRECATED_VALUE def _deprecation_warning(self, attr_name: str, extra_msg: str): diff --git a/python/ray/train/context.py b/python/ray/train/context.py index 565717d723dd8..0503ff73d0b5f 100644 --- a/python/ray/train/context.py +++ b/python/ray/train/context.py @@ -23,7 +23,7 @@ def wrapped(func): return wrapped -@PublicAPI(stability="beta") +@PublicAPI(stability="stable") class TrainContext: """Context for Ray training executions.""" @@ -77,7 +77,7 @@ def get_storage(self) -> StorageContext: return session.get_storage() -@PublicAPI(stability="beta") +@PublicAPI(stability="stable") def get_context() -> TrainContext: """Get or create a singleton training context. diff --git a/python/ray/train/data_parallel_trainer.py b/python/ray/train/data_parallel_trainer.py index 7139b1c84649e..da0e3e383b6b9 100644 --- a/python/ray/train/data_parallel_trainer.py +++ b/python/ray/train/data_parallel_trainer.py @@ -21,7 +21,7 @@ from ray.train._internal.utils import construct_train_func from ray.train.constants import TRAIN_DATASET_KEY, WILDCARD_KEY from ray.train.trainer import BaseTrainer, GenDataset -from ray.util.annotations import DeveloperAPI +from ray.util.annotations import DeveloperAPI, PublicAPI from ray.widgets import Template from ray.widgets.util import repr_with_fallback @@ -333,6 +333,7 @@ def __init__( resume_from_checkpoint=resume_from_checkpoint, ) + @PublicAPI(stability="beta") @classmethod def restore( cls: Type["DataParallelTrainer"], diff --git a/python/ray/train/huggingface/accelerate/accelerate_trainer.py b/python/ray/train/huggingface/accelerate/accelerate_trainer.py index c62a2ded49950..06d15dc331b8a 100644 --- a/python/ray/train/huggingface/accelerate/accelerate_trainer.py +++ b/python/ray/train/huggingface/accelerate/accelerate_trainer.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Tuple, Union from ray import train +from ray.util.annotations import Deprecated from ray.train import Checkpoint, RunConfig, ScalingConfig from ray.train import DataConfig from ray.train.torch import TorchConfig @@ -34,6 +35,15 @@ from ray.tune.trainable import Trainable +ACCELERATE_TRAINER_DEPRECATION_MESSAGE = ( + "The AccelerateTrainer will be hard deprecated in Ray 2.8. " + "Use TorchTrainer instead. " + "See https://docs.ray.io/en/releases-2.7.0/train/huggingface-accelerate.html#acceleratetrainer-migration-guide " # noqa: E501 + "for more details." +) + + +@Deprecated(message=ACCELERATE_TRAINER_DEPRECATION_MESSAGE, warning=True) class AccelerateTrainer(TorchTrainer): """A Trainer for data parallel HuggingFace Accelerate training with PyTorch. 
diff --git a/python/ray/train/huggingface/transformers/_transformers_utils.py b/python/ray/train/huggingface/transformers/_transformers_utils.py
index 1b7b11dbca605..a5e50570d242f 100644
--- a/python/ray/train/huggingface/transformers/_transformers_utils.py
+++ b/python/ray/train/huggingface/transformers/_transformers_utils.py
@@ -233,7 +233,7 @@ def on_train_end(self, args, state, control, **kwargs):
         self._report()
 
 
-@PublicAPI(stability="alpha")
+@PublicAPI(stability="beta")
 class RayTrainReportCallback(TrainerCallback):
     """A simple callback to report checkpoints and metrics to Ray Tarin.
 
@@ -297,7 +297,7 @@ def __iter__(self) -> Iterator:
         return iter(self.data_iterable)
 
 
-@PublicAPI(stability="alpha")
+@PublicAPI(stability="beta")
 def prepare_trainer(trainer: "Trainer") -> "Trainer":
     """Prepare your HuggingFace Transformer Trainer for Ray Train.
 
diff --git a/python/ray/train/huggingface/transformers/transformers_trainer.py b/python/ray/train/huggingface/transformers/transformers_trainer.py
index 50b995fda1b10..81072367334da 100644
--- a/python/ray/train/huggingface/transformers/transformers_trainer.py
+++ b/python/ray/train/huggingface/transformers/transformers_trainer.py
@@ -20,7 +20,7 @@
 from ray.train.data_parallel_trainer import DataParallelTrainer
 from ray.train.torch import TorchConfig, TorchTrainer
 from ray.train.trainer import GenDataset
-from ray.util import PublicAPI
+from ray.util.annotations import Deprecated
 
 TRANSFORMERS_IMPORT_ERROR: Optional[ImportError] = None
 
@@ -72,8 +72,15 @@
 
 TRAINER_INIT_FN_KEY = "_trainer_init_per_worker"
 
+TRANSFORMERS_TRAINER_DEPRECATION_MESSAGE = (
+    "The TransformersTrainer will be hard deprecated in Ray 2.8. "
+    "Use TorchTrainer instead. "
+    "See https://docs.ray.io/en/releases-2.7.0/train/getting-started-transformers.html#transformerstrainer-migration-guide "  # noqa: E501
+    "for more details."
+)
+
 
-@PublicAPI(stability="alpha")
+@Deprecated(message=TRANSFORMERS_TRAINER_DEPRECATION_MESSAGE, warning=True)
 class TransformersTrainer(TorchTrainer):
     """A Trainer for data parallel HuggingFace Transformers on PyTorch training.
 
diff --git a/python/ray/train/lightning/_lightning_utils.py b/python/ray/train/lightning/_lightning_utils.py
index 5f6d5b1e4a55c..4cb2432e67c2b 100644
--- a/python/ray/train/lightning/_lightning_utils.py
+++ b/python/ray/train/lightning/_lightning_utils.py
@@ -56,7 +56,7 @@ def get_worker_root_device():
     return devices
 
 
-@PublicAPI(stability="alpha")
+@PublicAPI(stability="beta")
 class RayDDPStrategy(DDPStrategy):
     """Subclass of DDPStrategy to ensure compatibility with Ray orchestration.
 
@@ -80,7 +80,7 @@ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
         )
 
 
-@PublicAPI(stability="alpha")
+@PublicAPI(stability="beta")
 class RayFSDPStrategy(FSDPStrategy):
     """Subclass of FSDPStrategy to ensure compatibility with Ray orchestration.
 
@@ -123,7 +123,7 @@ def lightning_module_state_dict(self) -> Dict[str, Any]:
         return super().lightning_module_state_dict()
 
 
-@PublicAPI(stability="alpha")
+@PublicAPI(stability="beta")
 class RayDeepSpeedStrategy(DeepSpeedStrategy):
     """Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration.
 
@@ -147,7 +147,7 @@ def distributed_sampler_kwargs(self) -> Dict[str, Any]:
         )
 
 
-@PublicAPI(stability="alpha")
+@PublicAPI(stability="beta")
 class RayLightningEnvironment(LightningEnvironment):
     """Setup Lightning DDP training environment for Ray cluster."""
 
@@ -179,7 +179,7 @@ def teardown(self):
         pass
 
 
-@PublicAPI(stability="alpha")
+@PublicAPI(stability="beta")
 def prepare_trainer(trainer: pl.Trainer) -> pl.Trainer:
     """Prepare the PyTorch Lightning Trainer for distributed execution."""
 
@@ -209,7 +209,7 @@ def prepare_trainer(trainer: pl.Trainer) -> pl.Trainer:
     return trainer
 
 
-@PublicAPI(stability="alpha")
+@PublicAPI(stability="beta")
 class RayTrainReportCallback(Callback):
     """A simple callback that reports checkpoints to Ray on train epoch end."""
 
diff --git a/python/ray/train/lightning/lightning_trainer.py b/python/ray/train/lightning/lightning_trainer.py
index ddcb1158a0002..e4ef7a9a67a55 100644
--- a/python/ray/train/lightning/lightning_trainer.py
+++ b/python/ray/train/lightning/lightning_trainer.py
@@ -14,7 +14,7 @@
 from ray.train.trainer import GenDataset
 from ray.train.torch import TorchTrainer
 from ray.train.torch.config import TorchConfig
-from ray.util import PublicAPI
+from ray.util.annotations import Deprecated
 from ray.train.lightning._lightning_utils import (
     RayDDPStrategy,
     RayFSDPStrategy,
@@ -31,7 +31,15 @@
 
 logger = logging.getLogger(__name__)
 
 
-@PublicAPI(stability="alpha")
+LIGHTNING_CONFIG_BUILDER_DEPRECATION_MESSAGE = (
+    "The LightningConfigBuilder will be hard deprecated in Ray 2.8. "
+    "Use TorchTrainer instead. "
+    "See https://docs.ray.io/en/releases-2.7.0/train/getting-started-pytorch-lightning.html#lightningtrainer-migration-guide "  # noqa: E501
+    "for more details."
+)
+
+
+@Deprecated(message=LIGHTNING_CONFIG_BUILDER_DEPRECATION_MESSAGE, warning=True)
 class LightningConfigBuilder:
     """Configuration Class to pass into LightningTrainer.
 
@@ -221,7 +229,15 @@ def build(self) -> Dict["str", Any]:
         return config_dict
 
 
-@PublicAPI(stability="alpha")
+LIGHTNING_TRAINER_DEPRECATION_MESSAGE = (
+    "The LightningTrainer will be hard deprecated in Ray 2.8. "
+    "Use TorchTrainer instead. "
+    "See https://docs.ray.io/en/releases-2.7.0/train/getting-started-pytorch-lightning.html#lightningtrainer-migration-guide "  # noqa: E501
+    "for more details."
+)
+
+
+@Deprecated(message=LIGHTNING_TRAINER_DEPRECATION_MESSAGE, warning=True)
 class LightningTrainer(TorchTrainer):
     """A Trainer for data parallel PyTorch Lightning training.
 
@@ -399,6 +415,7 @@ def __init__(
         resume_from_checkpoint: Optional[Checkpoint] = None,
         metadata: Optional[Dict[str, Any]] = None,
     ):
+        run_config = copy(run_config) or RunConfig()
         lightning_config = lightning_config or LightningConfigBuilder().build()
 
diff --git a/python/ray/train/tensorflow/keras.py b/python/ray/train/tensorflow/keras.py
new file mode 100644
index 0000000000000..3594779c8db18
--- /dev/null
+++ b/python/ray/train/tensorflow/keras.py
@@ -0,0 +1,3 @@
+from ray.air.integrations.keras import ReportCheckpointCallback
+
+ReportCheckpointCallback.__module__ = "ray.train.tensorflow.keras"
diff --git a/python/ray/train/torch/config.py b/python/ray/train/torch/config.py
index 727f8ebe4199d..278c984ac3cc7 100644
--- a/python/ray/train/torch/config.py
+++ b/python/ray/train/torch/config.py
@@ -18,7 +18,7 @@
 logger = logging.getLogger(__name__)
 
 
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 @dataclass
 class TorchConfig(BackendConfig):
     """Configuration for torch process group setup.
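
Editor's note (not part of the patch): with ``LightningTrainer`` deprecated above, the Lightning utilities promoted to beta (``RayDDPStrategy``, ``RayLightningEnvironment``, ``RayTrainReportCallback``, ``prepare_trainer``) are used inside a ``TorchTrainer`` training loop instead. A rough migration sketch; the toy module and data below are placeholders, not the official migration-guide code:

.. code-block:: python

    import torch
    import pytorch_lightning as pl

    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer
    from ray.train.lightning import (
        RayDDPStrategy,
        RayLightningEnvironment,
        RayTrainReportCallback,
        prepare_trainer,
    )


    class ToyModule(pl.LightningModule):
        """Minimal LightningModule used only to show the wiring."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)


    def train_loop_per_worker(config: dict):
        dataset = torch.utils.data.TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
        loader = torch.utils.data.DataLoader(dataset, batch_size=8)

        trainer = pl.Trainer(
            max_epochs=2,
            devices="auto",
            accelerator="auto",
            strategy=RayDDPStrategy(),
            plugins=[RayLightningEnvironment()],
            callbacks=[RayTrainReportCallback()],
            enable_progress_bar=False,
        )
        trainer = prepare_trainer(trainer)
        trainer.fit(ToyModule(), train_dataloaders=loader)


    TorchTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
    ).fit()
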
diff --git a/python/ray/train/torch/torch_trainer.py b/python/ray/train/torch/torch_trainer.py
index afe1500c89a54..e61d87e3386fd 100644
--- a/python/ray/train/torch/torch_trainer.py
+++ b/python/ray/train/torch/torch_trainer.py
@@ -12,7 +12,7 @@
     from ray.data.preprocessor import Preprocessor
 
 
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 class TorchTrainer(DataParallelTrainer):
     """A Trainer for data parallel PyTorch training.
 
diff --git a/python/ray/train/torch/train_loop_utils.py b/python/ray/train/torch/train_loop_utils.py
index d026bcc766cce..34a94266b1776 100644
--- a/python/ray/train/torch/train_loop_utils.py
+++ b/python/ray/train/torch/train_loop_utils.py
@@ -38,7 +38,7 @@
 logger = logging.getLogger(__name__)
 
 
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 def get_device() -> Union[torch.device, List[torch.device]]:
     """Gets the correct torch device configured for this process.
 
@@ -70,8 +70,7 @@ def get_device() -> Union[torch.device, List[torch.device]]:
     return torch_utils.get_device()
 
 
-# TODO: Deprecation: Hard-deprecate args in Ray 2.2.
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 def prepare_model(
     model: torch.nn.Module,
     move_to_device: Union[bool, torch.device] = True,
@@ -113,7 +112,7 @@ def prepare_model(
     )
 
 
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 def prepare_data_loader(
     data_loader: torch.utils.data.DataLoader,
     add_dist_sampler: bool = True,
@@ -192,7 +191,7 @@ def backward(tensor: torch.Tensor) -> None:
     get_accelerator(_TorchAccelerator).backward(tensor)
 
 
-@PublicAPI(stability="beta")
+@PublicAPI(stability="stable")
 def enable_reproducibility(seed: int = 0) -> None:
     """Limits sources of nondeterministic behavior.
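
Editor's note (not part of the patch): the ``train_loop_utils`` functions stabilized in the last hunks wrap the model and data loader for distributed execution inside the training loop. A minimal sketch; the toy model, data, and hyperparameters are illustrative:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from ray import train
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer, prepare_data_loader, prepare_model


    def train_loop_per_worker(config: dict):
        # Wrap the model in DistributedDataParallel and move it to this worker's device.
        model = prepare_model(torch.nn.Linear(4, 1))

        dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
        # Add a DistributedSampler and worker-aware device placement to the loader.
        loader = prepare_data_loader(DataLoader(dataset, batch_size=8))

        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        loss_fn = torch.nn.MSELoss()

        for epoch in range(2):
            for x, y in loader:
                optimizer.zero_grad()
                loss = loss_fn(model(x), y)
                loss.backward()
                optimizer.step()
            train.report({"epoch": epoch, "loss": loss.item()})


    TorchTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
    ).fit()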