[Train] Update docstring and user guides for train_loop_config
#43691
Conversation
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
@@ -23,7 +23,7 @@ For reference, the final code is as follows:

    from ray.train.torch import TorchTrainer
    from ray.train import ScalingConfig

    def train_func(config):
Let's not show the `config` argument in the first place, since we didn't specify `train_loop_config` in `TorchTrainer` in this code snippet. Users will be confused about where to put the `train_func` arguments.
Thanks. I agree, the config should not be promoted since it's mostly unnecessary for Train.
You can specify the input argument for `train_func` via the Trainer's `train_loop_config` parameter.

.. warning::

    Avoid passing large data objects through `train_loop_config` to reduce the
    serialization and deserialization overhead. Instead, it's preferred to
    initialize large objects (e.g. datasets, models) directly in `train_func`.
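To illustrate the pattern the warning describes, here is a minimal pure-Python sketch. It does not import Ray; the dict keys and the direct call to `train_func` are stand-ins for what `TorchTrainer` does when it forwards `train_loop_config` to each worker.

```python
# Sketch only: in Ray Train, the Trainer (not this script) calls
# train_func(train_loop_config) on every distributed worker.

def train_func(config):
    # Small hyperparameters arrive through the config dict.
    lr = config["lr"]
    batch_size = config["batch_size"]
    # Large objects (datasets, models) should be created here, on the
    # worker, instead of being passed through train_loop_config.
    model = [0.0] * 4  # stand-in for a model built inside train_func
    return {"lr": lr, "batch_size": batch_size, "params": len(model)}

# Stand-in for: TorchTrainer(train_func, train_loop_config=..., ...)
train_loop_config = {"lr": 1e-3, "batch_size": 32}
result = train_func(train_loop_config)
```

The key point is that `train_loop_config` stays a small, cheaply serializable dict of hyperparameters.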
Add a code snippet to show how to populate these? I think we want to show that it's a dictionary.

    def train_func(config):
        config[...]

    config = {...}
    trainer = TorchTrainer(train_func, train_loop_config=config, ...)

In the warning we can show an example as well.
Added two examples to:
- highlight the config format
- show the good and bad practices of setting `train_loop_config`
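As a hedged sketch of those good and bad practices (pure Python, no Ray import; the serialization cost is only simulated with `pickle`, and all names here are illustrative, not the examples actually added in this PR):

```python
import pickle

big_dataset = list(range(100_000))  # stand-in for a large data object

# Bad: the large object rides through train_loop_config, so it would be
# serialized and shipped to every training worker.
bad_config = {"dataset": big_dataset}
bad_payload = len(pickle.dumps(bad_config))

# Good: keep the config small and build the large object inside train_func.
def train_func(config):
    dataset = list(range(config["dataset_size"]))  # created on the worker
    return len(dataset)

good_config = {"dataset_size": 100_000}
good_payload = len(pickle.dumps(good_config))

assert good_payload < bad_payload  # small config -> cheap serialization
```

Besides the serialization cost, building objects inside `train_func` also sidesteps issues like deserializing device-specific objects (e.g. CUDA tensors) in the wrong process.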
@@ -190,6 +190,13 @@ Begin by wrapping your code in a :ref:`training function <train-overview-trainin

    Each distributed training worker executes this function.

    You can specify the input argument for `train_func` via the Trainer's `train_loop_config` parameter.
Optionally, we can extract this section out to a separate file and include it, similar to what's being done here.
In the future we may just have a full separate user guide for this.
Good idea. I've extracted the common paragraph into a separate doc.
Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com>
LGTM from data side.
Why are these changes needed?
It's been a common issue that Ray Train users try to pass large data/model objects through `train_loop_config`, which introduces large serialization overhead and may cause deserialization issues (e.g. deserializing a CUDA tensor on a CPU actor (`TrainTrainable`)). This PR adds notes in the user guide and docstring to warn users against similar attempts.
Related issue number
Checks
- I've signed off every commit (`git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a new method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.