[Train] Add backend-specific context manager for train_func #43209
Conversation
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Thanks!
I do like this and am in favor of adding it to the BackendConfig developer API, but I'm also worried that we're doing some magical stuff behind the scenes that isn't explicit to the user. We should be strict about what we put in this default context manager -- otherwise we'll end up with a bunch of implicit behavior that users aren't aware of.
Some alternatives to consider and discuss pros/cons before merging this PR:
- Have these decorators as utilities that users call explicitly (a sketch follows below).
- Don't have any default setup/teardown, and instead document how users can achieve things like setting the default CUDA device themselves.
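For concreteness, a minimal sketch of the first alternative, where the user calls the device-scoping utility explicitly inside their own train_func instead of Ray Train applying it implicitly:

import torch
import ray.train.torch

def train_func():
    # Explicit opt-in: the user scopes their code to this worker's device.
    with torch.device(ray.train.torch.get_device()):
        model = torch.nn.Linear(8, 1)  # parameters land on the worker's device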
@@ -16,6 +16,19 @@
logger = logging.getLogger(__name__)


class TorchConfigContextManager:
Can we actually swap to the function-style context manager so that it's easier to reuse existing contexts?

import contextlib

import torch
import ray.train.torch
from xgboost.collective import CommunicatorContext

@contextlib.contextmanager
def torch_context_manager():
    # some other setup
    with torch.device(ray.train.torch.get_device()):
        yield
    # some other teardown


@contextlib.contextmanager
def xgboost_context_manager():
    # some other setup
    with CommunicatorContext():
        yield
    # some other teardown
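(For illustration only, not part of the suggestion: the trainer-side call site would then apply whichever factory the backend returns.)

def run_with_backend_context(ctx_factory, train_func):
    # ctx_factory is one of the generator-based context managers above.
    with ctx_factory():
        train_func()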
I think it's fine, since you can return either a function-based or a class-based context manager:

def train_func_context(self):
    @contextlib.contextmanager
    def func_based_ctx_mgr():
        ...  # setup
        yield
        ...  # teardown

    return func_based_ctx_mgr
alternatively, to reuse an existing context manager, we can subclass it as below:
class InnerContextManager:
    def __enter__(self):
        print("Entering InnerContextManager")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("Exiting InnerContextManager")
        return False


class OuterContextManager(InnerContextManager):
    def __enter__(self):
        print("Entering OuterContextManager")
        super().__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        print("Exiting OuterContextManager")
        return False
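A quick usage check of the subclassing approach; the ordering below follows from where the super() calls sit:

with OuterContextManager():
    print("inside train_func")

# Output:
# Entering OuterContextManager
# Entering InnerContextManager
# inside train_func
# Exiting InnerContextManager
# Exiting OuterContextManager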
Co-authored-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Possible to add a quick sentence to the PR description about the decision to call this for the user rather than expose it as a utility that the user can call themselves?
        torch.cuda.set_device(device)

    def __exit__(self, type, value, traceback):
        # Propagate exceptions if any
nit: I think returning True is only needed if you want to suppress exceptions: https://docs.python.org/3/reference/datamodel.html#object.__exit__
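A minimal illustration of the rule from the linked data model docs: the exception propagates unless __exit__ returns a truthy value:

class SuppressAll:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # A truthy return suppresses the exception; False/None lets it propagate.
        return True

with SuppressAll():
    raise ValueError("swallowed by __exit__")
print("still running")  # reached only because __exit__ returned True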
PR description updated!
Oh, actually we are not suppressing the exceptions, since they're captured in the outer layer here: https://github.com/ray-project/ray/pull/43209/files#diff-8b259b33153d078b025da24134ff3b897aa1227d287d2ad38a1d2f11afb7d213R154
I updated the PR description. BackendConfig is a developer API, so it's safe and won't be exposed to users.
Currently we have multiple ways to do initialization before calling the user's train_func.
In the future design, we need a more unified way to do this initialization. One possible approach is to store all states in a global context and wrap the initialization logic in a context manager around train_func. This ensures that the initialization logic executes in the same thread as train_func.
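A rough sketch of that unified flow, assuming (per the review discussion above) that the backend config exposes a hypothetical train_func_context() hook returning a context-manager factory:

import threading

def run_in_worker_thread(backend_config, train_func):
    def wrapped():
        # Entering the context in this thread guarantees that setup such as
        # torch.cuda.set_device applies to the thread actually running train_func.
        ctx_factory = backend_config.train_func_context()  # hypothetical developer API
        with ctx_factory():
            train_func()

    t = threading.Thread(target=wrapped)
    t.start()
    t.join()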
Why are these changes needed?
This PR provides a way to inject a backend-specific context manager for train_func. It's a developer API (not for users), which enables us to inject backend-specific setup and teardown logic around the training function.

### Use case 1: PyTorch sets default CUDA device
Set torch default device to the current device allocated to this worker.
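A minimal sketch of what that default context does, simplified from the actual diff (the class name here is illustrative):

import torch
import ray.train.torch

class _TorchDeviceContext:
    # Illustrative simplification of the behavior described above.
    def __enter__(self):
        device = ray.train.torch.get_device()
        if device.type == "cuda":
            # Make this worker's assigned GPU the default CUDA device.
            torch.cuda.set_device(device)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Do not suppress exceptions raised inside train_func.
        return False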
Previously, Ray Train did not automatically set torch.cuda.current_device; it only set it when the user called train.torch.prepare_model in the training function. If the user does not call prepare_model, the default CUDA device on all workers will be "cuda:0", which is not ideal and may cause problems (e.g., moving all tensors to device 0).

### Use case 2: XGBoost CommunicatorContext

Previous discussions: #42767 (comment)

To make XGBoost training distributed, users have to call the training function under a context manager.

### Use case 3: LightGBM sets env vars
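Since the LightGBM use case is only named above, here is a hedged sketch of how a backend context manager could scope env vars around train_func; the variable names are illustrative placeholders, not the real keys Ray Train sets:

import contextlib
import os

@contextlib.contextmanager
def lightgbm_env_vars(machines, local_rank):
    # Set network-setup variables for the duration of train_func, then restore.
    saved = dict(os.environ)
    os.environ["LIGHTGBM_MACHINES"] = ",".join(machines)  # illustrative key
    os.environ["LIGHTGBM_LOCAL_RANK"] = str(local_rank)   # illustrative key
    try:
        yield
    finally:
        os.environ.clear()
        os.environ.update(saved)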
Related issue number

Checks

- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I'm adding a new method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.