[AIR] Add RLTrainer interface, implementation, and examples #23465
Conversation
@@ -4,7 +4,7 @@
 from ray import tune
 from ray.ml.train.integrations.tensorflow import TensorflowTrainer

-from ray.ml.examples.tensorflow.tensorflow_mnist_example import train_func
+from ray.ml.examples.tf.tensorflow_mnist_example import train_func
The reason for this rename is that otherwise the examples won't work if executed from the examples working directory: the scripts do "import tensorflow as tf", and Python would resolve the local "tensorflow" directory instead of the installed library.
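The shadowing problem described above is generic Python behavior, not specific to TensorFlow. A minimal sketch (using the stdlib json module as a stand-in for tensorflow, since a real TensorFlow install isn't assumed here) shows how a same-named module in the working directory wins over the installed one:

```python
import os
import sys
import tempfile

# Create a directory containing a module that shadows the stdlib "json",
# the same way a local "tensorflow" directory would shadow the real library.
shadow_dir = tempfile.mkdtemp()
with open(os.path.join(shadow_dir, "json.py"), "w") as f:
    f.write("SHADOWED = True\n")

sys.path.insert(0, shadow_dir)   # mimics running a script from that directory
sys.modules.pop("json", None)    # forget any previously imported copy

import json                      # now resolves to the local file

print(getattr(json, "SHADOWED", False))  # → True: the local module won
```

Renaming the examples directory sidesteps the collision entirely, which is cheaper than fiddling with sys.path in every example script.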
looks pretty good for RLlib stuff. 2 minor questions.
trainer = RLTrainer(
    run_config=RunConfig(stop={"training_iteration": 5}),
    scaling_config={
Sorry for being out of touch with the latest AIR API: why do we have a RunConfig object, but scaling_config is a plain dict?
Yeah, that's currently the case - agree it's a bit inconsistent here. We should probably at least accept dicts for RunConfig, too.
FYI, ScalingConfig is soon to be a dataclass as well. After we (I) make the change so that Tune can construct a search space from it.
@krfricke Can you add a PR description?
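For illustration, the dict-or-dataclass pattern discussed here could look roughly like the sketch below. The field names (num_workers, use_gpu) and the resolve helper are hypothetical stand-ins, not the actual AIR API:

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass
class ScalingConfig:
    # Hypothetical fields for illustration only.
    num_workers: int = 1
    use_gpu: bool = False


def resolve_scaling_config(config: Union[ScalingConfig, Dict]) -> ScalingConfig:
    # Accept plain dicts as a convenience and normalize them to the
    # dataclass, in the spirit of "we should probably accept dicts, too".
    if isinstance(config, dict):
        return ScalingConfig(**config)
    return config


print(resolve_scaling_config({"num_workers": 4, "use_gpu": True}))
# → ScalingConfig(num_workers=4, use_gpu=True)
```

Normalizing at the boundary keeps user-facing calls flexible while the internals work with a single typed representation, which also makes the dataclass usable as a Tune search-space source later.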
# Conflicts:
#   python/ray/tune/trial_runner.py
#   python/ray/tune/tune.py
#   python/ray/ml/config.py
return config

def training_loop(self) -> None:
This makes me think that we should reconsider our interface for Trainer. Currently we make the assumption that training_loop will be overridden by all subclasses (hence why it is an abstract method), but that is not the case here.
(this is just a thought- it should not block this PR!)
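To make the concern concrete, here is a minimal sketch of the pattern under discussion. The class and method names follow the thread; the method bodies are hypothetical, not the real implementations. Declaring training_loop abstract forces every subclass to override it, even one like RLTrainer where the wrapped RLlib trainable already owns the loop:

```python
import abc


class Trainer(abc.ABC):
    @abc.abstractmethod
    def training_loop(self) -> None:
        """The interface assumes every subclass overrides this."""


class RLTrainer(Trainer):
    def training_loop(self) -> None:
        # Hypothetical body: for RLlib the underlying trainable already
        # drives training, so this override is essentially boilerplate -
        # the situation the comment above suggests reconsidering.
        pass


# The abstract declaration forbids instantiating the base class directly...
try:
    Trainer()
except TypeError as exc:
    print("cannot instantiate:", exc)

# ...while the subclass works once the (possibly empty) override exists.
RLTrainer().training_loop()
```

One alternative would be to give training_loop a non-abstract default and let subclasses opt in to overriding it, at the cost of losing the "must implement" signal for trainers that do need a custom loop.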
# Conflicts:
#   python/ray/tune/impl/tuner_internal.py
Why are these changes needed?
This PR adds an RLTrainer to Ray AIR. It works for both offline and online use cases. In offline training, it leverages the datasets key of the Trainer API to specify a dataset reader input, used e.g. in Behavioral Cloning (BC). In online training, it is a wrapper around the RLlib trainables, making use of the parameter layering enabled by the Trainer API.
Related issue number
Checks
- I've run scripts/format.sh to lint the changes in this PR.