
[Train] Add backend-specific context manager for train_func. #43209

Conversation

@woshiyyya (Member) commented on Feb 15, 2024:

Why are these changes needed?

This PR provides a way to inject a backend-specific context manager for train_func. It's a developer API (not for end users), which enables us to inject backend-specific setup and teardown logic around the training function.

Use case 1: PyTorch sets the default CUDA device

Set the default torch CUDA device to the device allocated to this worker.

Previously, Ray Train did not automatically set torch.cuda.current_device; it was only set when the user called train.torch.prepare_model in the training function. If the user does not call prepare_model, the default CUDA device on every worker is "cuda:0", which is not ideal and may cause problems (all tensors get moved to device 0).

def train_func():
    model.to("cuda")  # -> moves the model on every rank to device 0
    ...

We add this behavior internally rather than asking users to call it themselves, because we want training to behave the same when scaling from 1 GPU to multiple GPUs without changing the user code.
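For reference, the injected Torch setup looks roughly like the sketch below (based on the TorchConfigContextManager discussed in the review comments; the exact implementation in the PR may differ slightly). The backend exposes it through the train_func_context hook on its BackendConfig.

import torch

import ray.train.torch


class TorchConfigContextManager:
    def __enter__(self):
        # Point the default CUDA device at the GPU assigned to this worker,
        # so that `model.to("cuda")` lands on the right device on every rank.
        device = ray.train.torch.get_device()
        if device.type == "cuda":
            torch.cuda.set_device(device)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Return False so that any exception raised in train_func propagates.
        return False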

Use case 2: XGBoost CommunicatorContext

Previous discussions: #42767 (comment)

To make XGBoost training distributed, users have to run the training function under XGBoost's CommunicatorContext context manager.
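For reference, that pattern looks roughly like the sketch below (assuming xgboost's collective.CommunicatorContext, available in xgboost >= 1.7; the config keys and communicator arguments are placeholders that the backend would normally assemble per worker):

import xgboost as xgb
from xgboost.collective import CommunicatorContext

def train_func(config):
    # Placeholder: tracker address, world size, rank, etc. would be
    # provided per worker by the backend, not hard-coded by the user.
    communicator_args = config["communicator_args"]
    with CommunicatorContext(**communicator_args):
        dtrain = xgb.DMatrix(config["train_dataset_path"])
        booster = xgb.train(config["xgboost_params"], dtrain)

With this PR, the backend can enter CommunicatorContext around train_func instead of asking users to write this wrapper themselves.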

Use case 3: LightGBM sets environment variables
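The description stops at the heading here; as a sketch of the same pattern, a LightGBM backend context manager could set worker-specific environment variables on entry and restore them on exit (the variable name below is purely hypothetical):

import contextlib
import os

@contextlib.contextmanager
def lightgbm_context_manager():
    # Hypothetical variable; the real backend would set whatever
    # LightGBM's distributed setup needs for this worker.
    new_vars = {"LIGHTGBM_EXAMPLE_VAR": "1"}
    old_vars = {k: os.environ.get(k) for k in new_vars}
    os.environ.update(new_vars)
    try:
        yield
    finally:
        # Restore the previous environment on teardown.
        for k, v in old_vars.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v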

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@woshiyyya changed the title from "[Train] Enable calling backend0specific prologue before executing train_func." to "[Train] Enable calling backend-specific prologue before executing train_func." on Feb 15, 2024
@woshiyyya changed the title from "[Train] Enable calling backend-specific prologue before executing train_func." to "[Train] Enable calling backend-specific setup function before executing train_func." on Feb 16, 2024
@woshiyyya changed the title from "[Train] Enable calling backend-specific setup function before executing train_func." to "[Train] Add backend-specific context manager for train_func." on Feb 16, 2024
@woshiyyya marked this pull request as ready for review on February 16, 2024 19:37
@justinvyu (Contributor) left a comment:

Thanks!

I do like this and am in favor of adding it to the BackendConfig developer API, but I am also worried that we are doing some magical stuff behind the scenes that is not explicit to the user. We should be strict in terms of what we put in this default context manager -- otherwise we'll end up with a bunch of implicit behavior that users aren't aware of.

Some alternatives to consider and discuss pros/cons before merging this PR:

  1. Have these decorators as utilities that users should call explicitly.
  2. Don't have any default setup/teardown and show users how to achieve certain things like setting default cuda device in documentation.

(Outdated review threads on python/ray/train/_internal/utils.py and python/ray/train/data_parallel_trainer.py were resolved.)
@@ -16,6 +16,19 @@
logger = logging.getLogger(__name__)


class TorchConfigContextManager:
A Contributor commented:

Can we actually swap to the function-style context manager so that it's easier to reuse existing contexts?

@contextlib.contextmanager
def torch_context_manager():
    # some other setup
    with torch.device(ray.train.torch.get_device()):
        yield
    # some other teardown

@contextlib.contextmanager
def xgboost_context_manager():
    # some other setup
    with CommunicatorContext():
        yield
    # some other teardown

@woshiyyya (Member, Author) replied on Feb 16, 2024:

I think it's fine, since you can return either a function-based or a class-based context manager.

def train_func_context(self):
    # Decorate the inner generator so it can be used in a `with` statement.
    @contextlib.contextmanager
    def func_based_ctx_mgr():
        ...
        yield
        ...
    return func_based_ctx_mgr

Alternatively, to reuse an existing context manager, we can subclass it as below:

class InnerContextManager:
    def __enter__(self):
        print("Entering InnerContextManager")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("Exiting InnerContextManager")
        return False

class OuterContextManager(InnerContextManager):
    def __enter__(self):
        print("Entering OuterContextManager")
        super().__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        print("Exiting OuterContextManager")
        return False 
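
For illustration, nesting then works as expected with this subclassing approach:

with OuterContextManager():
    print("running train_func")

# Output:
# Entering OuterContextManager
# Entering InnerContextManager
# running train_func
# Exiting InnerContextManager
# Exiting OuterContextManager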

(See the review thread on python/ray/train/backend.py.)
woshiyyya and others added 4 commits February 16, 2024 14:36
@justinvyu (Contributor) left a comment:

Possible to add a quick sentence to the PR description about the decision to call this for the user rather than expose it as a utility that the user can call themselves?

        torch.cuda.set_device(device)

    def __exit__(self, type, value, traceback):
        # Propagate exceptions if any
A Contributor commented:

nit: I think returning True is only needed if you want to suppress exceptions: https://docs.python.org/3/reference/datamodel.html#object.__exit__
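
For reference, a minimal illustration of the __exit__ return-value contract (standard Python data model, not code from this PR):

class SuppressErrors:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning True marks the exception as handled; returning False
        # (or None) lets it propagate out of the with block.
        return True


with SuppressErrors():
    raise ValueError("this is swallowed")
print("execution continues")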

@woshiyyya (Member, Author) replied:

PR description updated!

Oh, actually we are not suppressing the exceptions, since they're caught in the outer layer here: https://github.com/ray-project/ray/pull/43209/files#diff-8b259b33153d078b025da24134ff3b897aa1227d287d2ad38a1d2f11afb7d213R154

@woshiyyya (Member, Author) commented:

I updated the PR description.

BackendConfig is a developer API, so it's safe and won't be exposed to users.

@woshiyyya (Member, Author) commented on Feb 21, 2024:

Currently we have multiple ways to do initialization before calling the user's train_func:

  • Using this backend-specific context manager
  • Using a predefined training loop (e.g. LightGBM, XGBoost)
  • Using Backend.on_start + Backend.on_training_start

In the future design, we need a more unified way to do this initialization. One possible approach is to store all state in a global context and wrap the initialization logic in a context manager around train_func. This ensures that the initialization logic executes in the same thread as train_func.
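
Concretely, the worker-side wiring could look roughly like the sketch below (run_train_func and backend_config are illustrative names, following the train_func_context shape discussed above, not the exact implementation in this PR):

def run_train_func(backend_config, train_func):
    # The backend-specific context manager wraps the user's train_func,
    # so its setup and teardown run in the same thread as the training code.
    train_func_context = backend_config.train_func_context()
    with train_func_context():
        train_func()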

@matthewdeng merged commit 852e9f0 into ray-project:master on Feb 21, 2024 (9 checks passed).