[part 1/2] [train] Add metadata argument to Trainer (ray-project#38481)
Signed-off-by: Victor <vctr.y.m@example.com>
ericl authored and Victor committed Oct 11, 2023
1 parent f44e241 commit 34e3182
Showing 9 changed files with 146 additions and 20 deletions.
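For context, a minimal usage sketch of the new argument (assuming `DataParallelTrainer` as the concrete Trainer and a 2-worker scaling config; only the `metadata` parameter and the `get_metadata()` accessor come from this commit, everything else is the standard Ray Train API):

```python
from ray import train
from ray.train import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer


def train_loop_per_worker():
    # The metadata dict passed to the Trainer constructor is visible on every worker.
    meta = train.get_context().get_metadata()
    print(meta["preprocessor_version"])


trainer = DataParallelTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    metadata={"preprocessor_version": 1},  # must be JSON-serializable
)
result = trainer.fit()
```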
6 changes: 5 additions & 1 deletion python/ray/train/_internal/backend_executor.py
@@ -1,7 +1,7 @@
import logging
import os
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Any

import ray
from ray.data import Dataset
@@ -346,6 +346,7 @@ def start_training(
self,
train_func: Callable[[], T],
datasets: Dict[str, Dataset],
metadata: Dict[str, Any],
data_config: DataConfig,
checkpoint: Optional[Checkpoint] = None,
on_session_init: Callable[[], None] = None,
@@ -379,6 +380,7 @@ def initialize_session(
trial_info,
checkpoint,
dataset_shard,
metadata,
encode_data_fn,
checkpoint_keep_all_ranks,
checkpoint_upload_from_workers,
@@ -394,6 +396,7 @@ def initialize_session(
world_size=world_size,
trial_info=trial_info,
dataset_shard=dataset_shard,
metadata=metadata,
checkpoint=checkpoint,
encode_data_fn=encode_data_fn,
detailed_autofilled_metrics=use_detailed_autofilled_metrics,
@@ -446,6 +449,7 @@ def initialize_session(
trial_info=self._trial_info,
train_func=train_func,
dataset_shard=self.dataset_shards[index],
metadata=metadata,
checkpoint=checkpoint,
encode_data_fn=self._backend._encode_data,
checkpoint_keep_all_ranks=self._checkpoint_keep_all_ranks,
20 changes: 20 additions & 0 deletions python/ray/train/_internal/session.py
@@ -107,6 +107,7 @@ def __init__(
# TODO(xwjiang): Legacy Ray Train trainer clean up!
trial_info: Optional[TrialInfo] = None,
dataset_shard: Optional[Union[Dataset, DatasetPipeline]] = None,
metadata: Dict[str, Any] = None,
# TODO(xwjiang): Legacy Ray Train trainer clean up!
checkpoint: Optional[Checkpoint] = None,
# Deprecated
@@ -136,6 +137,8 @@ def __init__(

# Ray Train worker properties
self.dataset_shard = dataset_shard
self.metadata = metadata

self.world_rank = world_rank
self.local_rank = local_rank
self.node_rank = node_rank
@@ -555,6 +558,16 @@ def new_report(

metrics = self._auto_fill_metrics(metrics)

# Set additional user metadata from the Trainer.
if persisted_checkpoint and self.metadata:
user_metadata = persisted_checkpoint.get_metadata()
for k, v in self.metadata.items():
# Update keys not already set by the user. This gives user-set keys
# precedence over keys set at the Trainer level.
if k not in user_metadata:
user_metadata[k] = v
persisted_checkpoint.set_metadata(user_metadata)

result = _TrainingResult(
checkpoint=persisted_checkpoint,
metrics=metrics,
@@ -837,6 +850,13 @@ def train_func():
return _get_session().loaded_checkpoint


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_metadata() -> Dict[str, Any]:
"""User metadata dict passed to the Trainer constructor."""
return _get_session().metadata


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_experiment_name() -> str:
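The `new_report` hunk above merges Trainer-level metadata into each persisted checkpoint while letting keys the user already set on the checkpoint win. A standalone sketch of that precedence rule (plain dicts, no Ray objects; the helper name is ours):

```python
def merge_trainer_metadata(user_metadata: dict, trainer_metadata: dict) -> dict:
    """Fill in Trainer-level keys without overwriting keys set by the user."""
    merged = dict(user_metadata)
    for key, value in trainer_metadata.items():
        if key not in merged:
            merged[key] = value
    return merged


# The user-set "version" wins; the Trainer-level "git_sha" is filled in.
assert merge_trainer_metadata(
    {"version": 2}, {"version": 1, "git_sha": "abc123"}
) == {"version": 2, "git_sha": "abc123"}
```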
24 changes: 24 additions & 0 deletions python/ray/train/base_trainer.py
@@ -1,6 +1,7 @@
import abc
import copy
import inspect
import json
import logging
import os
from pathlib import Path
@@ -170,6 +171,9 @@ def training_loop(self):
run_config: Configuration for the execution of the training run.
datasets: Any Datasets to use for training. Use the key "train"
to denote which dataset is the training dataset.
metadata: Dict that should be made available via
`train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
for checkpoints saved from this Trainer. Must be JSON-serializable.
resume_from_checkpoint: A checkpoint to resume training from.
"""

@@ -190,6 +194,7 @@ def __init__(
scaling_config: Optional[ScalingConfig] = None,
run_config: Optional[RunConfig] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
metadata: Optional[Dict[str, Any]] = None,
resume_from_checkpoint: Optional[Checkpoint] = None,
# Deprecated.
preprocessor: Optional["Preprocessor"] = None,
@@ -198,6 +203,7 @@ def __init__(
scaling_config if scaling_config is not None else ScalingConfig()
)
self.run_config = run_config if run_config is not None else RunConfig()
self.metadata = metadata
self.datasets = datasets if datasets is not None else {}
self.preprocessor = preprocessor
self.starting_checkpoint = resume_from_checkpoint
@@ -468,6 +474,19 @@ def _validate_attributes(self):
"`ray.data.Dataset`. "
f"Received {dataset} instead."
)
# Metadata.
self.metadata = self.metadata or {}
if not isinstance(self.metadata, dict):
raise TypeError(
f"The provided metadata must be a dict, was {type(self.metadata)}."
)
try:
self.metadata = json.loads(json.dumps(self.metadata))
except Exception as e:
raise ValueError(
"The provided metadata must be JSON-serializable: "
f"{self.metadata}: {e}"
)

# Preprocessor
if self.preprocessor is not None and not isinstance(
@@ -724,8 +743,13 @@ def _generate_trainable_cls(self) -> Type["Trainable"]:
trainer_cls = self.__class__
scaling_config = self.scaling_config
restored = bool(self._restore_path)
metadata = self.metadata

def train_func(config):
assert metadata is not None, metadata
# Propagate user metadata from the Trainer constructor.
session._get_session().metadata = metadata

# config already contains merged values.
# Instantiate new Trainer in Trainable.
trainer = trainer_cls(**config)
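In `_validate_attributes` above, the `json.loads(json.dumps(...))` round trip is what enforces the "JSON-serializable" requirement, and it also normalizes values (tuples come back as lists, non-string keys as strings). A small illustration of the same check in isolation (the helper name is ours, not part of the diff):

```python
import json


def check_metadata(metadata):
    """Mirror of the validation in _validate_attributes: dict check + JSON round trip."""
    if not isinstance(metadata, dict):
        raise TypeError(f"The provided metadata must be a dict, was {type(metadata)}.")
    try:
        return json.loads(json.dumps(metadata))
    except Exception as e:
        raise ValueError(
            f"The provided metadata must be JSON-serializable: {metadata}: {e}"
        )


print(check_metadata({"epochs": 10, "tags": ("a", "b")}))  # tuple comes back as a list
try:
    check_metadata({"model": object()})  # arbitrary objects are rejected
except ValueError as err:
    print(err)
```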
6 changes: 5 additions & 1 deletion python/ray/train/context.py
@@ -1,5 +1,5 @@
import threading
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Dict, Any

from ray.train._internal import session
from ray.util.annotations import PublicAPI
@@ -26,6 +26,10 @@ def wrapped(func):
class TrainContext:
"""Context for Ray training executions."""

@_copy_doc(session.get_metadata)
def get_metadata(self) -> Dict[str, Any]:
return session.get_metadata()

@_copy_doc(session.get_experiment_name)
def get_experiment_name(self) -> str:
return session.get_experiment_name()
6 changes: 6 additions & 0 deletions python/ray/train/data_parallel_trainer.py
@@ -232,6 +232,9 @@ def __init__(self, train_loop_per_worker, my_backend_config:
dataset. If a ``preprocessor`` is provided and has not already been fit,
it will be fit on the training dataset. All datasets will be transformed
by the ``preprocessor`` if one is provided.
metadata: Dict that should be made available via
`train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
for checkpoints saved from this Trainer. Must be JSON-serializable.
preprocessor: A ray.data.Preprocessor to preprocess the
provided datasets.
resume_from_checkpoint: A checkpoint to resume training from.
@@ -270,6 +273,7 @@ def __init__(
dataset_config: Optional[DataConfig] = None,
run_config: Optional[RunConfig] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
metadata: Optional[Dict[str, Any]] = None,
resume_from_checkpoint: Optional[Checkpoint] = None,
# Deprecated.
preprocessor: Optional["Preprocessor"] = None,
Expand Down Expand Up @@ -328,6 +332,7 @@ def __init__(
scaling_config=scaling_config,
run_config=run_config,
datasets=datasets,
metadata=metadata,
preprocessor=preprocessor,
resume_from_checkpoint=resume_from_checkpoint,
)
@@ -560,6 +565,7 @@ def clear_lazy_checkpoint_marker():
backend_config=self._backend_config,
train_func=train_loop_per_worker,
datasets=self.datasets,
metadata=self.metadata,
data_config=self._data_config,
checkpoint_manager=checkpoint_manager,
checkpoint=self.starting_checkpoint,
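On the checkpoint side of the docstring above, metadata set on the Trainer ends up on checkpoints reported from the training loop. A sketch of such a loop (the file contents are placeholders; reading the merged metadata back from the result is shown as a comment):

```python
import os
import tempfile

from ray import train
from ray.train import Checkpoint


def train_loop_per_worker():
    with tempfile.TemporaryDirectory() as tmpdir:
        # Write some (placeholder) model state and report it as a checkpoint.
        with open(os.path.join(tmpdir, "model.txt"), "w") as f:
            f.write("weights")
        train.report({"loss": 0.1}, checkpoint=Checkpoint.from_directory(tmpdir))


# After trainer.fit() with metadata={"preprocessor_version": 1}, the merged
# metadata is readable from the saved checkpoint:
#     result.checkpoint.get_metadata()  # -> {"preprocessor_version": 1, ...}
```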
44 changes: 26 additions & 18 deletions python/ray/train/tests/test_backend.py
@@ -84,7 +84,7 @@ def test_start(ray_start_2_cpus):
config = TestConfig()
e = BackendExecutor(config, num_workers=2)
with pytest.raises(InactiveWorkerGroupError):
e.start_training(lambda: 1, datasets={}, data_config=DataConfig())
e.start_training(lambda: 1, datasets={}, data_config=DataConfig(), metadata={})
e.start()
assert len(e.worker_group) == 2

@@ -105,7 +105,7 @@ def check():

return os.getenv("TEST", "0")

e.start_training(check, datasets={}, data_config=DataConfig())
e.start_training(check, datasets={}, data_config=DataConfig(), metadata={})
assert e.finish_training() == ["1", "1"]


@@ -116,15 +116,15 @@ def test_shutdown(ray_start_2_cpus):
assert len(e.worker_group) == 2
e.shutdown()
with pytest.raises(InactiveWorkerGroupError):
e.start_training(lambda: 1, datasets={}, data_config=DataConfig())
e.start_training(lambda: 1, datasets={}, data_config=DataConfig(), metadata={})


def test_train(ray_start_2_cpus):
config = TestConfig()
e = BackendExecutor(config, num_workers=2)
e.start()

e.start_training(lambda: 1, datasets={}, data_config=DataConfig())
e.start_training(lambda: 1, datasets={}, data_config=DataConfig(), metadata={})
assert e.finish_training() == [1, 1]


@@ -136,7 +136,7 @@ def test_local_ranks(ray_start_2_cpus):
def train_func():
return train.get_context().get_local_rank()

e.start_training(train_func, datasets={}, data_config=DataConfig())
e.start_training(train_func, datasets={}, data_config=DataConfig(), metadata={})
assert set(e.finish_training()) == {0, 1}


@@ -149,7 +149,7 @@ def test_local_world_size(ray_2_node_2_cpu):
def train_func():
return train.get_context().get_local_world_size()

e.start_training(train_func, datasets={}, data_config=DataConfig())
e.start_training(train_func, datasets={}, data_config=DataConfig(), metadata={})
assert list(e.finish_training()) == [2, 2, 1]


@@ -162,7 +162,7 @@ def test_node_ranks(ray_2_node_2_cpu):
def train_func():
return train.get_context().get_node_rank()

e.start_training(train_func, datasets={}, data_config=DataConfig())
e.start_training(train_func, datasets={}, data_config=DataConfig(), metadata={})
assert list(e.finish_training()) == [0, 0, 1]


@@ -183,10 +183,10 @@ def test_train_failure(ray_start_2_cpus):
e.finish_training()
assert isinstance(exc.value.__cause__, TrainBackendError)

e.start_training(lambda: 1, datasets={}, data_config=DataConfig())
e.start_training(lambda: 1, datasets={}, data_config=DataConfig(), metadata={})

with pytest.raises(StartTraceback) as exc:
e.start_training(lambda: 2, datasets={}, data_config=DataConfig())
e.start_training(lambda: 2, datasets={}, data_config=DataConfig(), metadata={})
assert isinstance(exc.value.__cause__, TrainBackendError)

assert e.finish_training() == [1, 1]
@@ -204,7 +204,9 @@ def single_worker_fail():
else:
time.sleep(1000000)

e.start_training(single_worker_fail, datasets={}, data_config=DataConfig())
e.start_training(
single_worker_fail, datasets={}, data_config=DataConfig(), metadata={}
)

with pytest.raises(StartTraceback) as exc:
e.get_next_results()
@@ -222,7 +224,9 @@ def train_fail():
new_execute_func = gen_execute_special(train_fail)
with patch.object(WorkerGroup, "execute_async", new_execute_func):
with pytest.raises(TrainingWorkerError):
e.start_training(lambda: 1, datasets={}, data_config=DataConfig())
e.start_training(
lambda: 1, datasets={}, data_config=DataConfig(), metadata={}
)
e.finish_training()


@@ -238,7 +242,7 @@ def get_tf_config():

return json.loads(os.environ["TF_CONFIG"])

e.start_training(get_tf_config, datasets={}, data_config=DataConfig())
e.start_training(get_tf_config, datasets={}, data_config=DataConfig(), metadata={})
results = e.finish_training()
assert len(results) == num_workers

@@ -263,12 +267,16 @@ def check_process_group():
and torch.distributed.get_world_size() == 2
)

e.start_training(check_process_group, datasets={}, data_config=DataConfig())
e.start_training(
check_process_group, datasets={}, data_config=DataConfig(), metadata={}
)
assert all(e.finish_training())

e._backend.on_shutdown(e.worker_group, e._backend_config)

e.start_training(check_process_group, datasets={}, data_config=DataConfig())
e.start_training(
check_process_group, datasets={}, data_config=DataConfig(), metadata={}
)
assert not any(e.finish_training())


@@ -305,7 +313,7 @@ def get_resources():
config, num_workers=num_workers, num_cpus_per_worker=0, num_gpus_per_worker=1
)
e.start()
e.start_training(get_resources, datasets={}, data_config=DataConfig())
e.start_training(get_resources, datasets={}, data_config=DataConfig(), metadata={})
results = e.finish_training()
results.sort()
assert results == expected_results
@@ -351,7 +359,7 @@ def get_resources():
config, num_workers=num_workers, num_cpus_per_worker=0, num_gpus_per_worker=0.5
)
e.start()
e.start_training(get_resources, datasets={}, data_config=DataConfig())
e.start_training(get_resources, datasets={}, data_config=DataConfig(), metadata={})
results = e.finish_training()
results.sort()
assert results == expected_results
@@ -390,7 +398,7 @@ def get_resources():
config, num_workers=num_workers, num_cpus_per_worker=0, num_gpus_per_worker=2
)
e.start()
e.start_training(get_resources, datasets={}, data_config=DataConfig())
e.start_training(get_resources, datasets={}, data_config=DataConfig(), metadata={})
results = e.finish_training()
results.sort()
assert results == expected_results
@@ -441,7 +449,7 @@ def test():
config = TestConfig()
e = BackendExecutor(config, num_workers=2)
e.start()
e.start_training(train_func, datasets={}, data_config=DataConfig())
e.start_training(train_func, datasets={}, data_config=DataConfig(), metadata={})
return e.finish_training()

results_future = test.options(