Skip to content

Commit

Permalink
[train] Fix issues in migration of tune_cifar_torch_pbt_example (#39158)
Browse files Browse the repository at this point in the history
Resolves three issues that come up when migrating the `tune_cifar_torch_pbt_example` from Ray 2.6 to Ray 2.7:

1. There is a warning message because PBT uses the `_schedule_trial_save` interface. This is added to the white list attributes so it doesn't come up anymore.
2. PBT malfunctions in Python 2.7, so instead of silently failing, we raise an error and ask users to migrate
3. When users use old `ray.air.Checkpoint` APIs on `ray.train.Checkpoint`, we should raise an actionable error message.

Signed-off-by: Kai Fricke <kai@anyscale.com>
  • Loading branch information
krfricke committed Sep 7, 2023
1 parent 8b7fcd7 commit f33b8eb
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 1 deletion.
74 changes: 73 additions & 1 deletion python/ray/train/_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,33 @@
_CHECKPOINT_TEMP_DIR_PREFIX = "checkpoint_tmp_"


class _CheckpointMetaClass(type):
def __getattr__(self, item):
try:
return super().__getattribute__(item)
except AttributeError as exc:
if item in {
"from_dict",
"to_dict",
"from_bytes",
"to_bytes",
"get_internal_representation",
}:
raise _get_migration_error(item) from exc
elif item in {
"from_uri",
"to_uri",
"uri",
}:
raise _get_uri_error(item) from exc
elif item in {"get_preprocessor", "set_preprocessor"}:
raise _get_preprocessor_error(item) from exc

raise exc


@PublicAPI(stability="beta")
class Checkpoint:
class Checkpoint(metaclass=_CheckpointMetaClass):
"""A reference to data persisted as a directory in local or remote storage.
Access checkpoint contents locally using ``checkpoint.to_directory()``.
Expand Down Expand Up @@ -301,3 +326,50 @@ def _list_existing_del_locks(path: str) -> List[str]:
then this should return a list of 2 deletion lock files.
"""
return list(glob.glob(f"{_get_del_lock_path(path, suffix='*')}"))


def _get_migration_error(name: str):
return AttributeError(
f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
f"Instead, only directories are supported.\n\n"
f"Example to store a dictionary in a checkpoint:\n\n"
f"import os, tempfile\n"
f"import ray.cloudpickle as pickle\n"
f"from ray import train\n"
f"from ray.train import Checkpoint\n\n"
f"with tempfile.TemporaryDirectory() as checkpoint_dir:\n"
f" with open(os.path.join(checkpoint_dir, 'data.pkl'), 'wb') as fp:\n"
f" pickle.dump({{'data': 'value'}}, fp)\n\n"
f" checkpoint = Checkpoint.from_directory(checkpoint_dir)\n"
f" train.report(..., checkpoint=checkpoint)\n\n"
f"Example to load a dictionary from a checkpoint:\n\n"
f"if train.get_checkpoint():\n"
f" with train.get_checkpoint().as_directory() as checkpoint_dir:\n"
f" with open(os.path.join(checkpoint_dir, 'data.pkl'), 'rb') as fp:\n"
f" data = pickle.load(fp)"
)


def _get_uri_error(name: str):
return AttributeError(
f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
f"To create a checkpoint from remote storage, create a `Checkpoint` using its "
f"constructor instead of `from_directory`.\n"
f'Example: `Checkpoint(path="s3://a/b/c")`.\n'
f"Then, access the contents of the checkpoint with "
f"`checkpoint.as_directory()` / `checkpoint.to_directory()`.\n"
f"To upload data to remote storage, use e.g. `pyarrow.fs.FileSystem` "
f"or your client of choice."
)


def _get_preprocessor_error(name: str):
return AttributeError(
f"The new `ray.train.Checkpoint` class does not support `{name}()`. "
f"To include preprocessor information in checkpoints, "
f"pass it as metadata in the <Framework>Trainer constructor.\n"
f"Example: `TorchTrainer(..., metadata={{...}})`.\n"
f"After training, access it in the checkpoint via `checkpoint.get_metadata()`. "
f"See here: https://docs.ray.io/en/master/train/user-guides/"
f"data-loading-preprocessing.html#preprocessing-structured-data"
)
1 change: 1 addition & 0 deletions python/ray/tune/execution/tune_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def _wrapped(self):
"_set_trial_status",
"pause_trial",
"stop_trial",
"_schedule_trial_save",
},
executor_whitelist_attr={
"has_resources_for_trial",
Expand Down
11 changes: 11 additions & 0 deletions python/ray/tune/schedulers/pbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,17 @@ def __init__(
require_attrs: bool = True,
synch: bool = False,
):
if not _use_storage_context():
raise RuntimeError(
"Due to breaking API changes, PBT does not work with the old "
"persistence mode (enabled via RAY_AIR_NEW_PERSISTENCE_MODE=0). "
"Migrate your script to use the new APIs instead and disabled the "
"environment variable flag. "
"See this migration guide: "
"https://docs.google.com/document/d/"
"1J-09US8cXc-tpl2A1BpOrlHLTEDMdIJp6Ah1ifBUw7Y/view"
)

hyperparam_mutations = hyperparam_mutations or {}
for value in hyperparam_mutations.values():
if not isinstance(value, (dict, list, tuple, Domain, Callable)):
Expand Down

0 comments on commit f33b8eb

Please sign in to comment.