[Train] Add backend-specific context manager for train_func. #43209

21 changes: 18 additions & 3 deletions python/ray/train/_internal/utils.py
@@ -4,7 +4,17 @@
import logging
import os
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)

import ray
from ray.actor import ActorHandle
@@ -88,6 +98,7 @@ def update_env_vars(env_vars: Dict[str, Any]):
def construct_train_func(
    train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
    config: Optional[Dict[str, Any]],
    train_func_context: Optional[ContextManager],
    fn_arg_name: Optional[str] = "train_func",
    discard_returns: bool = False,
) -> Callable[[], T]:
@@ -97,6 +108,8 @@
            This can either take in no arguments or a ``config`` dict.
        config (Optional[Dict]): Configurations to pass into
            ``train_func``. If None then an empty Dict will be created.
        train_func_context: Context manager for the user's `train_func`, which
            executes backend-specific logic before and after the training function.
        fn_arg_name (Optional[str]): The name of training function to use for error
            messages.
        discard_returns: Whether to discard any returns from train_func or not.
@@ -135,7 +148,8 @@ def discard_return_wrapper(*args, **kwargs):
        @functools.wraps(wrapped_train_func)
        def train_fn():
            try:
-               return wrapped_train_func(config)
+               with train_func_context:
+                   return wrapped_train_func(config)
            except Exception as e:
                raise StartTraceback from e

@@ -144,7 +158,8 @@ def train_fn():
        @functools.wraps(wrapped_train_func)
        def train_fn():
            try:
-               return wrapped_train_func()
+               with train_func_context:
+                   return wrapped_train_func()
            except Exception as e:
                raise StartTraceback from e

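For illustration only (not part of the diff): a minimal sketch of how a backend-specific context manager plugs into the helper above. The names example_backend_context and train_loop_per_worker are hypothetical; the construct_train_func signature is the one shown in this diff, used directly here purely for demonstration.

from contextlib import contextmanager

from ray.train._internal.utils import construct_train_func


# Hypothetical backend-specific context manager: setup runs before the user's
# training function and teardown runs after it (or on error).
@contextmanager
def example_backend_context():
    print("backend setup")  # e.g. set a default device, init a communicator
    try:
        yield
    finally:
        print("backend teardown")


def train_loop_per_worker(config):
    print("training with", config)


wrapped = construct_train_func(
    train_loop_per_worker,
    config={"lr": 1e-3},
    train_func_context=example_backend_context(),
    discard_returns=True,
)
wrapped()  # -> backend setup, training with {'lr': 0.001}, backend teardown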
5 changes: 5 additions & 0 deletions python/ray/train/backend.py
@@ -1,4 +1,5 @@
import logging
from contextlib import nullcontext
from typing import TypeVar

from ray.train._internal.utils import Singleton
@@ -19,6 +20,10 @@ class BackendConfig:
    def backend_cls(self):
        return Backend

    @property
    def train_func_context(self):
        return nullcontext

    def _repr_html_(self) -> str:
        return make_table_html_repr(obj=self, title=type(self).__name__)

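As a sketch of how a backend would use this hook (the class MyBackendConfig and its body are hypothetical, not from the PR), a custom backend config can override the new property, while the base class above keeps nullcontext as the default:

from contextlib import contextmanager

from ray.train.backend import BackendConfig


class MyBackendConfig(BackendConfig):
    # Hypothetical backend config that injects setup/teardown around train_func.

    @property
    def train_func_context(self):
        @contextmanager
        def my_train_func_context():
            # Backend-specific prologue, e.g. configure a device or communicator.
            yield
            # Backend-specific epilogue, e.g. cleanup.

        return my_train_func_context


# The trainer then does roughly: `with MyBackendConfig().train_func_context(): ...`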
1 change: 1 addition & 0 deletions python/ray/train/data_parallel_trainer.py
@@ -435,6 +435,7 @@ def training_loop(self) -> None:
        train_loop_per_worker = construct_train_func(
            self._train_loop_per_worker,
            self._train_loop_config,
            train_func_context=self._backend_config.train_func_context(),
fn_arg_name="train_loop_per_worker",
discard_returns=True,
)
4 changes: 4 additions & 0 deletions python/ray/train/tests/test_gpu.py
@@ -64,6 +64,8 @@ def test_torch_get_device(
    ray.init(num_cpus=4, num_gpus=2)

    def train_fn():
        # Confirm that the TorchConfig prologue is effective.
        assert torch.cuda.current_device() == train.torch.get_device().index
        # Make sure environment variable is being set correctly.
        if cuda_visible_devices:
            visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
@@ -102,6 +104,8 @@ def train_fn():
def test_torch_get_device_dist(ray_2_node_2_gpu, num_gpus_per_worker, tmp_path):
    @patch("torch.cuda.is_available", lambda: True)
    def train_fn():
        # Confirm that the TorchConfig prologue is effective.
        assert torch.cuda.current_device() == train.torch.get_device().index
        devices = sorted([device.index for device in train.torch.get_devices()])
        write_rank_data(tmp_path, devices)

3 changes: 2 additions & 1 deletion python/ray/train/tests/test_training_iterator.py
@@ -1,5 +1,6 @@
import functools
import time
from contextlib import nullcontext
from unittest.mock import patch

import pytest
@@ -91,7 +92,7 @@ def create_iterator(
):
    # Similar logic to the old Trainer.run_iterator().

-    train_func = construct_train_func(train_func, None)
+    train_func = construct_train_func(train_func, None, train_func_context=nullcontext)

    backend_executor = backend_executor_cls(
        backend_config=backend_config, num_workers=num_workers, max_retries=MAX_RETRIES
17 changes: 17 additions & 0 deletions python/ray/train/torch/config.py
@@ -16,6 +16,19 @@
logger = logging.getLogger(__name__)


class TorchConfigContextManager:
Contributor:
Can we actually swap to the function-style context manager so that it's easier to reuse existing contexts?

import contextlib

@contextlib.contextmanager
def torch_context_manager():
    # some other setup
    with torch.device(ray.train.torch.get_device()):
        yield
    # some other teardown


@contextlib.contextmanager
def xgboost_context_manager():
    # some other setup
    with CommunicatorContext():
        yield
    # some other teardown

Member Author (@woshiyyya, Feb 16, 2024):
I think it's fine, since you can return either a function-based or a class-based context manager:

@property
def train_func_context(self):
    @contextlib.contextmanager
    def func_based_ctx_mgr():
        ...
        yield
        ...

    return func_based_ctx_mgr

Alternatively, to reuse an existing context manager, we can subclass it as below:

class InnerContextManager:
    def __enter__(self):
        print("Entering InnerContextManager")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("Exiting InnerContextManager")
        return False

class OuterContextManager(InnerContextManager):
    def __enter__(self):
        print("Entering OuterContextManager")
        super().__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        print("Exiting OuterContextManager")
        return False 

    def __enter__(self):
        # Set default cuda device
        if torch.cuda.is_available():
            device = ray.train.torch.get_device()
            if device.type == "cuda":
                torch.cuda.set_device(device)

    def __exit__(self, type, value, traceback):
        # Propagate exceptions if any
Contributor:
nit: I think only return True is needed if you want to suppress exceptions https://docs.python.org/3/reference/datamodel.html#object.__exit__

Member Author:
PR description updated!

Oh, actually we are not suppressing the exceptions, since they're caught in the outer layer here: https://github.com/ray-project/ray/pull/43209/files#diff-8b259b33153d078b025da24134ff3b897aa1227d287d2ad38a1d2f11afb7d213R154

        return False
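For reference (not part of the diff), the __exit__ return-value semantics discussed in the thread above, as a self-contained snippet: returning a falsy value, as the PR does, lets exceptions propagate to the StartTraceback wrapper in utils.py, whereas returning True would suppress them.

class SuppressingContext:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning True suppresses an exception raised in the `with` body;
        # returning False or None lets it propagate to the caller.
        return True


with SuppressingContext():
    raise ValueError("this is swallowed")

print("still running")  # reached, because __exit__ returned True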


@PublicAPI(stability="stable")
@dataclass
class TorchConfig(BackendConfig):
@@ -43,6 +56,10 @@ class TorchConfig(BackendConfig):
    def backend_cls(self):
        return _TorchBackend

    @property
    def train_func_context(self):
        return TorchConfigContextManager


def _setup_torch_process_group(
    backend: str,
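Putting the pieces together, a rough sketch of the flow each training worker ends up running once TorchConfig supplies the context manager. This is simplified and illustrative rather than a standalone script: ray.train.torch.get_device() in the prologue requires an active Train session, and the wrapping itself is done by construct_train_func on the worker.

from ray.train.torch import TorchConfig

backend_config = TorchConfig()

# The trainer reads the new property and instantiates the backend-specific
# context manager (TorchConfigContextManager here, nullcontext by default).
train_func_context = backend_config.train_func_context()


def train_func():
    # User code; on GPU workers the default CUDA device has already been set
    # by the prologue in TorchConfigContextManager.__enter__.
    ...


# construct_train_func wraps the user function so that it executes inside
# the backend-specific context on each worker:
with train_func_context:
    train_func()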