From f33b8ebba683ed58e47ff540bf120e7d9a07d942 Mon Sep 17 00:00:00 2001
From: Kai Fricke
Date: Thu, 7 Sep 2023 13:51:52 +0200
Subject: [PATCH] [train] Fix issues in migration of tune_cifar_torch_pbt_example (#39158)

Resolves three issues that come up when migrating the
`tune_cifar_torch_pbt_example` from Ray 2.6 to Ray 2.7:

1. PBT uses the private `_schedule_trial_save` interface, which triggers
   a warning message. The method is now added to the whitelisted
   attributes so the warning no longer comes up.
2. PBT malfunctions with the old persistence mode in Ray 2.7, so instead
   of silently failing, we raise an error and ask users to migrate.
3. When users call old `ray.air.Checkpoint` APIs on the new
   `ray.train.Checkpoint`, we now raise an actionable error message.

Signed-off-by: Kai Fricke
---
 python/ray/train/_checkpoint.py              | 74 +++++++++++++++++++-
 python/ray/tune/execution/tune_controller.py |  1 +
 python/ray/tune/schedulers/pbt.py            | 11 +++
 3 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/python/ray/train/_checkpoint.py b/python/ray/train/_checkpoint.py
index 2d8f6a72e0314..a82959992ace4 100644
--- a/python/ray/train/_checkpoint.py
+++ b/python/ray/train/_checkpoint.py
@@ -26,8 +26,33 @@
 _CHECKPOINT_TEMP_DIR_PREFIX = "checkpoint_tmp_"
 
 
+class _CheckpointMetaClass(type):
+    def __getattr__(self, item):
+        try:
+            return super().__getattribute__(item)
+        except AttributeError as exc:
+            if item in {
+                "from_dict",
+                "to_dict",
+                "from_bytes",
+                "to_bytes",
+                "get_internal_representation",
+            }:
+                raise _get_migration_error(item) from exc
+            elif item in {
+                "from_uri",
+                "to_uri",
+                "uri",
+            }:
+                raise _get_uri_error(item) from exc
+            elif item in {"get_preprocessor", "set_preprocessor"}:
+                raise _get_preprocessor_error(item) from exc
+
+            raise exc
+
+
 @PublicAPI(stability="beta")
-class Checkpoint:
+class Checkpoint(metaclass=_CheckpointMetaClass):
     """A reference to data persisted as a directory in local or remote storage.
 
     Access checkpoint contents locally using ``checkpoint.to_directory()``.
@@ -301,3 +326,50 @@ def _list_existing_del_locks(path: str) -> List[str]:
     then this should return a list of 2 deletion lock files.
     """
     return list(glob.glob(f"{_get_del_lock_path(path, suffix='*')}"))
+
+
+def _get_migration_error(name: str):
+    return AttributeError(
+        f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
+        f"Instead, only directories are supported.\n\n"
+        f"Example to store a dictionary in a checkpoint:\n\n"
+        f"import os, tempfile\n"
+        f"import ray.cloudpickle as pickle\n"
+        f"from ray import train\n"
+        f"from ray.train import Checkpoint\n\n"
+        f"with tempfile.TemporaryDirectory() as checkpoint_dir:\n"
+        f"    with open(os.path.join(checkpoint_dir, 'data.pkl'), 'wb') as fp:\n"
+        f"        pickle.dump({{'data': 'value'}}, fp)\n\n"
+        f"    checkpoint = Checkpoint.from_directory(checkpoint_dir)\n"
+        f"    train.report(..., checkpoint=checkpoint)\n\n"
+        f"Example to load a dictionary from a checkpoint:\n\n"
+        f"if train.get_checkpoint():\n"
+        f"    with train.get_checkpoint().as_directory() as checkpoint_dir:\n"
+        f"        with open(os.path.join(checkpoint_dir, 'data.pkl'), 'rb') as fp:\n"
+        f"            data = pickle.load(fp)"
+    )
+
+
+def _get_uri_error(name: str):
+    return AttributeError(
+        f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
" + f"To create a checkpoint from remote storage, create a `Checkpoint` using its " + f"constructor instead of `from_directory`.\n" + f'Example: `Checkpoint(path="s3://a/b/c")`.\n' + f"Then, access the contents of the checkpoint with " + f"`checkpoint.as_directory()` / `checkpoint.to_directory()`.\n" + f"To upload data to remote storage, use e.g. `pyarrow.fs.FileSystem` " + f"or your client of choice." + ) + + +def _get_preprocessor_error(name: str): + return AttributeError( + f"The new `ray.train.Checkpoint` class does not support `{name}()`. " + f"To include preprocessor information in checkpoints, " + f"pass it as metadata in the Trainer constructor.\n" + f"Example: `TorchTrainer(..., metadata={{...}})`.\n" + f"After training, access it in the checkpoint via `checkpoint.get_metadata()`. " + f"See here: https://docs.ray.io/en/master/train/user-guides/" + f"data-loading-preprocessing.html#preprocessing-structured-data" + ) diff --git a/python/ray/tune/execution/tune_controller.py b/python/ray/tune/execution/tune_controller.py index 7344e73c851df..61241ca03cbab 100644 --- a/python/ray/tune/execution/tune_controller.py +++ b/python/ray/tune/execution/tune_controller.py @@ -351,6 +351,7 @@ def _wrapped(self): "_set_trial_status", "pause_trial", "stop_trial", + "_schedule_trial_save", }, executor_whitelist_attr={ "has_resources_for_trial", diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index 0d8679538fa92..8b68e418de399 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -361,6 +361,17 @@ def __init__( require_attrs: bool = True, synch: bool = False, ): + if not _use_storage_context(): + raise RuntimeError( + "Due to breaking API changes, PBT does not work with the old " + "persistence mode (enabled via RAY_AIR_NEW_PERSISTENCE_MODE=0). " + "Migrate your script to use the new APIs instead and disabled the " + "environment variable flag. " + "See this migration guide: " + "https://docs.google.com/document/d/" + "1J-09US8cXc-tpl2A1BpOrlHLTEDMdIJp6Ah1ifBUw7Y/view" + ) + hyperparam_mutations = hyperparam_mutations or {} for value in hyperparam_mutations.values(): if not isinstance(value, (dict, list, tuple, Domain, Callable)):