[Train] Colocate Trainer and rank 0 worker (#43115)
This PR automatically merges the trainer bundle with the rank 0 worker bundle, so that the trainer and the rank 0 worker are always colocated on the same node.

### Benefits:
- Enables users to specify additional resources for the rank 0 worker (see the example after this list).
- Always colocates the trainer and the rank 0 worker, making the scheduling behavior deterministic.
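
For example, extra resources for the rank 0 node can be requested through `trainer_resources`. A minimal sketch (the resource amounts are illustrative, not taken from this PR):

```python
from ray.train import ScalingConfig

# Reserve 2 extra CPUs and ~10 GB of memory on the node that hosts the
# training coordinator and the rank 0 worker, on top of each worker's own
# resources. The exact numbers are placeholders.
scaling_config = ScalingConfig(
    num_workers=4,
    use_gpu=True,
    trainer_resources={"CPU": 2, "memory": 10 * 1024**3},
    resources_per_worker={"CPU": 4},
)
```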

### Major changes:
#### 1. Merge the trainer bundle and the first worker bundle.

Specifically, we build a placement group with bundles `[{}, {trainer+worker}, {worker}, ..., {worker}]` and schedule the `TrainTrainable` on the first non-empty bundle. When assigning worker ranks, we designate the worker with the smallest GPU ID on the trainer's node as rank 0.
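
The bundle construction roughly follows this sketch (simplified from `ScalingConfig.as_placement_group_factory` in this PR; names are shortened for clarity):

```python
from collections import Counter


def build_bundles(trainer_resources: dict, resources_per_worker: dict, num_workers: int) -> list:
    # Merge the trainer (coordinator) bundle into the rank 0 worker bundle so
    # that both always land on the same node.
    combined = dict(Counter(trainer_resources) + Counter(resources_per_worker))
    # The leading empty bundle lets the TrainTrainable be scheduled against the
    # combined bundle without consuming any of its resources.
    return [{}, combined] + [resources_per_worker] * (num_workers - 1)


# A 1-CPU trainer with two 1-GPU workers:
# [{}, {'CPU': 1, 'GPU': 1}, {'GPU': 1}]
print(build_bundles({"CPU": 1}, {"GPU": 1}, 2))
```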

#### 2. Set `num_workers=1` by default in `ScalingConfig`.
Previously, setting `num_workers` to `None` resulted in launching a single `TrainTrainable` with zero workers. This no longer applies to the current Ray Train, since all Trainers now require at least one worker to execute the `train_func`.

Additionally, this approach led to undefined behavior when merging and splitting the first bundle. To ensure consistent behavior, the default value of `num_workers` is now 1.
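
With this change, constructing a `ScalingConfig` with no arguments now requests a single worker:

```python
from ray.train import ScalingConfig

# num_workers defaults to 1 instead of None.
assert ScalingConfig().num_workers == 1
```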

#### 3. Forbid using `ScalingConfig` with `tune.with_resources`.

`ScalingConfig` should be a Ray Train-only utility and should not be used for Tune trainables. For example, it doesn't make sense to provide a `ScalingConfig` for a function trainable, since there are no trainer and worker concepts there.
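
For Tune function trainables, resources should instead be specified with a plain dict or a `PlacementGroupFactory`. A minimal sketch (the resource amounts are arbitrary):

```python
from ray import tune
from ray.tune.execution.placement_groups import PlacementGroupFactory


def train_fn(config):
    ...


tuner = tune.Tuner(
    tune.with_resources(
        train_fn,
        # A single bundle with the resources this trial's actor needs.
        resources=PlacementGroupFactory([{"CPU": 2, "GPU": 0.5}]),
    )
)
tuner.fit()
```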


Passed release test: https://buildkite.com/ray-project/release/builds/9650#018dee6e-e3ce-4376-9f3d-5ad7e250e513

### Related PRs:
The two PRs below enable actors with empty resource requirements to be launched on the node of a specific bundle in a placement group.
- #43269
- #43448

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com>
woshiyyya committed Feb 28, 2024
1 parent 2946e79 commit 4a73957
Showing 13 changed files with 186 additions and 138 deletions.
14 changes: 0 additions & 14 deletions doc/source/tune/doc_code/faq.py
@@ -93,20 +93,6 @@ def train_func(config):
tuner.fit()
# __resources_pgf_end__

# __resources_scalingconfig_start__
tuner = tune.Tuner(
tune.with_resources(
train_fn,
resources=ScalingConfig(
trainer_resources={"CPU": 2, "GPU": 0.5, "hdd": 80},
num_workers=2,
resources_per_worker={"CPU": 1},
),
)
)
tuner.fit()
# __resources_scalingconfig_end__

# __resources_lambda_start__
tuner = tune.Tuner(
tune.with_resources(
8 changes: 0 additions & 8 deletions doc/source/tune/faq.rst
Expand Up @@ -280,14 +280,6 @@ on other nodes as well. Please refer to the
:ref:`placement groups documentation <ray-placement-group-doc-ref>` to learn more
about these placement strategies.

You can also use the :class:`~ray.tune.ScalingConfig` to achieve the same results:

.. literalinclude:: doc_code/faq.py
:dedent:
:language: python
:start-after: __resources_scalingconfig_start__
:end-before: __resources_scalingconfig_end__

You can also allocate specific resources to a trial based on a custom rule via lambda functions.
For instance, if you want to allocate GPU resources to trials based on a setting in your param space:

82 changes: 41 additions & 41 deletions python/ray/air/config.py
@@ -1,5 +1,5 @@
import logging
from collections import defaultdict
from collections import defaultdict, Counter
from dataclasses import _MISSING_TYPE, dataclass, fields
from pathlib import Path
from typing import (
@@ -102,8 +102,13 @@ class ScalingConfig:
"""Configuration for scaling training.
Args:
trainer_resources: Resources to allocate for the trainer. If None is provided,
will default to 1 CPU for most trainers.
trainer_resources: Resources to allocate for the training coordinator.
The training coordinator launches the worker group and executes
the training function per worker, and this process does NOT require
GPUs. The coordinator is always scheduled on the same node as the
rank 0 worker, so one example use case is to set a minimum amount
of resources (e.g. CPU memory) required by the rank 0 node.
By default, this assigns 1 CPU to the training coordinator.
num_workers: The number of workers (Ray actors) to launch.
Each worker will reserve 1 CPU by default. The number of CPUs
reserved by each worker can be overridden with the
@@ -139,7 +144,7 @@ class ScalingConfig:
"""

trainer_resources: Optional[Union[Dict, SampleRange]] = None
num_workers: Optional[Union[int, SampleRange]] = None
num_workers: Union[int, SampleRange] = 1
use_gpu: Union[bool, SampleRange] = False
resources_per_worker: Optional[Union[Dict, SampleRange]] = None
placement_strategy: Union[str, SampleRange] = "PACK"
@@ -185,7 +190,8 @@ def _resources_per_worker_not_none(self):
resources_per_worker = {
k: v for k, v in self.resources_per_worker.items() if v != 0
}
resources_per_worker.setdefault("GPU", int(self.use_gpu))
if self.use_gpu:
resources_per_worker.setdefault("GPU", 1)
return resources_per_worker

@property
@@ -216,9 +222,8 @@ def _trainer_resources_not_none(self):
def total_resources(self):
"""Map of total resources required for the trainer."""
total_resource_map = defaultdict(float, self._trainer_resources_not_none)
num_workers = self.num_workers or 0
for k, value in self._resources_per_worker_not_none.items():
total_resource_map[k] += value * num_workers
total_resource_map[k] += value * self.num_workers
return dict(total_resource_map)

@property
@@ -244,49 +249,44 @@ def as_placement_group_factory(self) -> "PlacementGroupFactory":
"""Returns a PlacementGroupFactory to specify resources for Tune."""
from ray.tune.execution.placement_groups import PlacementGroupFactory

trainer_resources = self._trainer_resources_not_none
trainer_bundle = [trainer_resources]
worker_resources = {
"CPU": self.num_cpus_per_worker,
"GPU": self.num_gpus_per_worker,
}
worker_resources_extra = (
{} if self.resources_per_worker is None else self.resources_per_worker
)
worker_bundles = [
{**worker_resources, **worker_resources_extra}
for _ in range(self.num_workers if self.num_workers else 0)
]
bundles = trainer_bundle + worker_bundles
trainer_bundle = self._trainer_resources_not_none
worker_bundle = self._resources_per_worker_not_none

# Colocate Trainer and rank0 worker by merging their bundles
# Note: This empty bundle is required so that the Tune actor manager schedules
# the Trainable onto the combined bundle while taking none of its resources,
# rather than a non-empty head bundle.
combined_bundle = dict(Counter(trainer_bundle) + Counter(worker_bundle))
bundles = [{}, combined_bundle] + [worker_bundle] * (self.num_workers - 1)
return PlacementGroupFactory(bundles, strategy=self.placement_strategy)

@classmethod
def from_placement_group_factory(
cls, pgf: "PlacementGroupFactory"
) -> "ScalingConfig":
"""Create a ScalingConfig from a Tune's PlacementGroupFactory"""
if pgf.head_bundle_is_empty:
trainer_resources = {}
worker_bundles = pgf.bundles
else:
trainer_resources = pgf.bundles[0]
worker_bundles = pgf.bundles[1:]
"""Create a ScalingConfig from a Tune's PlacementGroupFactory
use_gpu = False
Note that this is only needed for ResourceChangingScheduler, which
modifies a trial's PlacementGroupFactory but doesn't propagate
the changes to ScalingConfig. TrainTrainable needs to reconstruct
a ScalingConfig from the trial's PlacementGroupFactory.
"""

# pgf.bundles = [{trainer + worker}, {worker}, ..., {worker}]
num_workers = len(pgf.bundles)
combined_resources = pgf.bundles[0]
resources_per_worker = pgf.bundles[-1]
use_gpu = bool(resources_per_worker.get("GPU", False))
placement_strategy = pgf.strategy
resources_per_worker = None
num_workers = None

if worker_bundles:
first_bundle = worker_bundles[0]
if not all(bundle == first_bundle for bundle in worker_bundles[1:]):
raise ValueError(
"All worker bundles (any other than the first one) "
"must be equal to each other."
)
use_gpu = bool(first_bundle.get("GPU"))
num_workers = len(worker_bundles)
resources_per_worker = first_bundle
# In `as_placement_group_factory`, we merged the trainer resource into the
# first worker resources bundle. We need to calculate the resources diff to
# get the trainer resources.
# Note: If there's only one worker, we won't be able to calculate the diff.
# We'll have empty trainer bundle and assign all resources to the worker.
trainer_resources = dict(
Counter(combined_resources) - Counter(resources_per_worker)
)

return ScalingConfig(
trainer_resources=trainer_resources,
5 changes: 3 additions & 2 deletions python/ray/air/tests/test_api.py
@@ -143,7 +143,6 @@ def test_scaling_config_validate_config_bad_allowed_keys():
@pytest.mark.parametrize(
"trainer_resources", [None, {}, {"CPU": 1}, {"CPU": 2, "GPU": 1}, {"CPU": 0}]
)
@pytest.mark.parametrize("num_workers", [None, 1, 2])
@pytest.mark.parametrize(
"resources_per_worker_and_use_gpu",
[
@@ -157,8 +156,10 @@ def test_scaling_config_validate_config_bad_allowed_keys():
)
@pytest.mark.parametrize("placement_strategy", ["PACK", "SPREAD"])
def test_scaling_config_pgf_equivalance(
trainer_resources, resources_per_worker_and_use_gpu, num_workers, placement_strategy
trainer_resources, resources_per_worker_and_use_gpu, placement_strategy
):
num_workers = 2

resources_per_worker, use_gpu = resources_per_worker_and_use_gpu
scaling_config = ScalingConfig(
trainer_resources=trainer_resources,
16 changes: 7 additions & 9 deletions python/ray/air/tests/test_errors.py
@@ -55,12 +55,11 @@

@pytest.fixture
def cluster_setup(ray_start_cluster_head: Cluster):
# Sets up a cluster with 4 nodes: head node + 3 workers
# Sets up a cluster with 3 nodes: head node + 2 workers
cluster = ray_start_cluster_head
nodes = []
nodes.append(cluster.add_node(resources={"worker1": 1, "coordinator": 1}))
nodes.append(cluster.add_node(resources={"worker1": 1, "cpu": 1, "coordinator": 1}))
nodes.append(cluster.add_node(resources={"worker2": 1, "cpu": 1}))
nodes.append(cluster.add_node(resources={"worker3": 1, "cpu": 1}))
cluster.wait_for_nodes()

@ray.remote
@@ -69,15 +68,13 @@ def get_node_id():

worker1_node_id = ray.get(get_node_id.options(resources={"worker1": 1}).remote())
worker2_node_id = ray.get(get_node_id.options(resources={"worker2": 1}).remote())
worker3_node_id = ray.get(get_node_id.options(resources={"worker3": 1}).remote())
wait_for_condition(
lambda: len({node["NodeID"] for node in ray.nodes() if (node["Alive"])}) == 4
lambda: len({node["NodeID"] for node in ray.nodes() if (node["Alive"])}) == 3
)

yield cluster, nodes, [
worker1_node_id,
worker2_node_id,
worker3_node_id,
]


@@ -155,6 +152,7 @@ def test_trainable_error_with_trainer(ray_start_4_cpus, tmp_path, fail_fast):
name=name,
failure_config=FailureConfig(fail_fast=fail_fast),
),
scaling_config=ScalingConfig(num_workers=1),
)

if fail_fast in [False, True]:
@@ -242,9 +240,9 @@ def test_preemption_handling(
"""Integration test for node preemption handling in Ray Train/Tune.
Even though `max_failures=0`, preemption errors should still be retried."""
cluster, nodes, node_ids = cluster_setup
# node 1 = coordinator, node 2 = worker, node 3 = worker
coordinator_node, worker_node, _ = nodes
coordinator_node_id, worker_node_id, _ = node_ids
# node 1 = coordinator and worker, node 2 = worker
coordinator_node, worker_node = nodes
coordinator_node_id, worker_node_id = node_ids

num_workers = 2
tmp_path.joinpath("markers").mkdir()
25 changes: 10 additions & 15 deletions python/ray/train/base_trainer.py
@@ -732,7 +732,9 @@ def setup(self, config, **kwargs):
run_config = base_config.pop("run_config", None)
self._merged_config = merge_dicts(base_config, self.config)
self._merged_config["run_config"] = run_config
merged_scaling_config = self._merged_config.get("scaling_config")
merged_scaling_config = self._merged_config.get(
"scaling_config", ScalingConfig()
)
if isinstance(merged_scaling_config, dict):
merged_scaling_config = ScalingConfig(**merged_scaling_config)
self._merged_config[
@@ -763,21 +765,14 @@ def _reconcile_scaling_config_with_trial_resources(
if not isinstance(trial_resources, PlacementGroupFactory):
return scaling_config

if scaling_config:
scaling_config = trainer_cls._validate_scaling_config(
scaling_config
)
scaling_config_from_trial_resources = (
ScalingConfig.from_placement_group_factory(trial_resources)
)
# Ignore ResourceChangingScheduler workaround when resource bundles
# are unchanged
if self.trial_resources == scaling_config.as_placement_group_factory():
return scaling_config

# This check should always pass if ResourceChangingScheduler is not
# used.
if scaling_config_from_trial_resources != scaling_config:
scaling_config = trainer_cls._validate_scaling_config(
scaling_config_from_trial_resources
)
return scaling_config
trainer_cls._validate_scaling_config(scaling_config)

return ScalingConfig.from_placement_group_factory(trial_resources)

def _trainable_func(self, config):
# We ignore the config passed by Tune and instead use the merged
7 changes: 5 additions & 2 deletions python/ray/train/tests/test_base_trainer.py
@@ -57,7 +57,10 @@ def check_cpus(self):

assert ray.available_resources()["CPU"] == 4
trainer = DummyTrainer(
check_cpus, scaling_config=ScalingConfig(trainer_resources={"CPU": 2})
check_cpus,
scaling_config=ScalingConfig(
trainer_resources={"CPU": 2}, resources_per_worker={}
),
)
trainer.fit()

@@ -70,7 +73,7 @@ def check_override(self):
assert self.custom_arg["outer"]["fixed"] == 1

pg = get_current_placement_group()
assert len(pg.bundle_specs) == 2 # 1 trainer, 1 worker
assert len(pg.bundle_specs) == 1 # Merged trainer and worker bundle

scale_config = ScalingConfig(num_workers=4)
trainer = DummyTrainer(
24 changes: 0 additions & 24 deletions python/ray/train/tests/test_data_parallel_trainer.py
@@ -188,30 +188,6 @@ def train_func(config):
assert trainer._train_loop_config["x"] == 100


def test_scaling_config_validation(ray_start_4_cpus):
def train_func(config):
train.report({"loss": config["x"]})

# Should be able to create a DataParallelTrainer w/o scaling_config,
# but it should fail on fit
trainer = DataParallelTrainer(
train_loop_per_worker=train_func,
train_loop_config={"x": 100},
)
with pytest.raises(ValueError):
trainer.fit()

# Scaling config must be passed in through Tuner param space if not
# included in the initial trainer
tuner = Tuner(trainer)
with pytest.raises(ValueError):
tuner.fit()

tuner = Tuner(trainer, param_space={"scaling_config": ScalingConfig(num_workers=1)})
results = tuner.fit()
assert not results.errors


def test_fast_slow(ray_start_4_cpus):
def train_func():
for i in range(2):
17 changes: 1 addition & 16 deletions python/ray/tune/tests/test_api.py
@@ -15,7 +15,7 @@

import ray
from ray import train, tune
from ray.train import CheckpointConfig, ScalingConfig
from ray.train import CheckpointConfig
from ray.air.constants import TIME_THIS_ITER_S, TRAINING_ITERATION
from ray.rllib import _register_all
from ray.train._internal.session import shutdown_session
@@ -1389,21 +1389,6 @@ def train_fn(config):
assert trial.last_result["_metric"] == num_gpus


@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_scaling_config(ray_start_2_cpus_2_gpus, num_gpus):
def train_fn(config):
return len(ray.get_gpu_ids())

[trial] = tune.run(
tune.with_resources(
train_fn,
resources=ScalingConfig(trainer_resources={"GPU": num_gpus}, num_workers=0),
)
).trials

assert trial.last_result["_metric"] == num_gpus


@pytest.mark.parametrize("num_gpus", [1, 2])
def test_with_resources_fn(ray_start_2_cpus_2_gpus, num_gpus):
def train_fn(config):
