[RLlib] Trainer sub-class PPO/DDPPO (instead of build_trainer()). #20571
Conversation
Looks totally awesome, love it.
Just to double check, you didn't make any logic change with this PR, right?
@@ -21,14 +21,16 @@
import time

import ray
from ray.rllib.agents.ppo import ppo
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_DEFAULT_CONFIG, \
    PPOTrainer
this doesn't fit in the last line?
Not sure what you mean. The line is too long for the linter and needs to be split with a \
ah sorry, it looked really strange on my laptop. my bad.
rllib/agents/ppo/ddppo.py (Outdated)
@@ -41,8 +43,8 @@

# Adds the following updates to the `PPOTrainer` config in
# rllib/agents/ppo/ppo.py.
DEFAULT_CONFIG = ppo.PPOTrainer.merge_trainer_configs(
    ppo.DEFAULT_CONFIG,
DEFAULT_CONFIG = PPOTrainer.merge_trainer_configs(
Kinda feel like merge_trainer_configs() should be a util on the Trainer class?
So all these agents would just do:
Trainer.merge_trainer_configs(
    ....
)
It is defined in Trainer (not PPOTrainer), but since PPOTrainer is-a Trainer, it works like this, too. But yeah, we should probably call it like Trainer.merge_trainer_configs().
done
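For illustration, a rough sketch of what calling the util on the base class would look like (based on the import names in the diff above; exact merge semantics assumed: the second dict overlays the first):

from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_DEFAULT_CONFIG

# Start from the PPO defaults and overlay only the DDPPO-specific deltas.
DEFAULT_CONFIG = Trainer.merge_trainer_configs(
    PPO_DEFAULT_CONFIG,
    {
        "num_gpus": 0,  # example override: DDPPO does all optimization on the workers
    },
)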
raise ValueError("Only gloo, mpi, or nccl is supported for " | ||
"the backend of PyTorch distributed.") | ||
# `num_gpus` must be 0/None, since all optimization happens on Workers. | ||
if config["num_gpus"]: |
I am just curious: if my eval worker runs on the head node and needs a GPU, which param do I use to configure it?
config:
  evaluation_config:
    num_gpus_per_worker: ...
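A minimal sketch of how that could look in a full trainer config (assuming the Ray 1.x RLlib config keys; the values and the extra keys here are illustrative):

from ray.rllib.agents.ppo import PPOTrainer

config = {
    "env": "CartPole-v0",
    "num_gpus": 0,                 # local (head-node) driver: no GPU
    "num_gpus_per_worker": 0,      # regular rollout workers: CPU only
    "evaluation_num_workers": 1,
    "evaluation_config": {
        "num_gpus_per_worker": 1,  # each evaluation worker requests one GPU
    },
}
trainer = PPOTrainer(config=config)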
👌
    .batch_across_shards()  # List[(grad_info, count)]
    .for_each(RecordStats()))

train_op = train_op.for_each(update_worker_global_vars)
minor minor question, maybe just chain this call right above?
Didn't actually touch this code, just moved it into the method. Not sure about too much chaining. Honestly, we should probably chain less rather than more, as it makes the already complex execution plans even more confusing.
ok, I can't argue against it 😆
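As a side note on that style question, a tiny self-contained illustration (a stand-in Pipeline class, not RLlib code) of fully chained vs. step-wise building of an op pipeline:

class Pipeline:
    """Minimal stand-in for an iterator-style op pipeline."""
    def __init__(self, items):
        self.items = list(items)

    def for_each(self, fn):
        return Pipeline(fn(x) for x in self.items)

# Fully chained: compact, but individual steps are harder to comment on or debug.
chained = Pipeline(range(3)).for_each(lambda x: x + 1).for_each(lambda x: x * 2)

# Step-wise: each intermediate op gets a name, which keeps complex
# execution plans readable (the style preferred above).
train_op = Pipeline(range(3))
train_op = train_op.for_each(lambda x: x + 1)
train_op = train_op.for_each(lambda x: x * 2)

print(chained.items, train_op.items)  # [2, 4, 6] [2, 4, 6]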
config["rollout_fragment_length"] | ||
if config["train_batch_size"] > 0 and \ | ||
config["train_batch_size"] % calculated_min_rollout_size != 0: | ||
new_rollout_fragment_length = config["train_batch_size"] // ( |
High-level question: should we do this for all on-policy agents?
Train batch size is such a mysterious thing for us.
Logic like this living in a specific agent makes things less consistent.
I know some agents don't work like this, but for the ones that do, should we put this in a util function, so they can all do this at the beginning of a run?
This definitely doesn't belong in this PR, just curious what you think.
I completely agree with you that train batch sizes should be handled not by individual trainers, but in a more generic way, as we discussed offline. This just bubbled up here b/c I had to move that block of code. But yeah, we should make a separate PR in which we implement "guaranteed batch sizes" for all algos.
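To make that concrete, a rough sketch of what such a shared util could look like (a hypothetical helper, not part of this PR; the config keys follow the snippet above):

def adjust_rollout_fragment_length(config: dict) -> None:
    """Adjust rollout_fragment_length so train_batch_size is met exactly."""
    num_workers = max(config.get("num_workers", 0), 1)
    num_envs = config.get("num_envs_per_worker", 1)
    # Minimum number of samples collected per sampling round across all workers.
    calculated_min_rollout_size = (
        num_workers * num_envs * config["rollout_fragment_length"])
    if config["train_batch_size"] > 0 and \
            config["train_batch_size"] % calculated_min_rollout_size != 0:
        # Round up so one sampling round covers the requested train batch size.
        config["rollout_fragment_length"] = -(
            -config["train_batch_size"] // (num_workers * num_envs))

config = {
    "num_workers": 2,
    "num_envs_per_worker": 1,
    "rollout_fragment_length": 200,
    "train_batch_size": 500,
}
adjust_rollout_fragment_length(config)
print(config["rollout_fragment_length"])  # -> 250 (2 workers x 250 = 500)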
Hey @gjoliver, correct, no logic change. Just sometimes have to make a few "adjustments" to keep stuff backward compatible wrt …
Thanks, thanks!
Trainer sub-class for PPOTrainer and DDPPOTrainer (instead of build_trainer()).

Why are these changes needed?

Related issue number

Checks
I've run scripts/format.sh to lint the changes in this PR.
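For context on the title, a hypothetical, self-contained illustration of the structural change, moving from a build_trainer()-style factory to a plain sub-class (stand-in classes only, not the actual RLlib Trainer API):

class TrainerBase:
    def __init__(self, config=None):
        # Merge user overrides onto the algorithm's defaults, then validate.
        self.config = {**self.get_default_config(), **(config or {})}
        self.validate_config(self.config)

    @classmethod
    def get_default_config(cls):
        return {"lr": 1e-3}

    def validate_config(self, config):
        pass

# Old style: a factory assembles a Trainer class from loose pieces.
def build_trainer(name, default_config, validate_config_fn):
    return type(name, (TrainerBase,), {
        "get_default_config": classmethod(lambda cls: dict(default_config)),
        "validate_config": lambda self, config: validate_config_fn(config),
    })

OldStylePPO = build_trainer(
    "OldStylePPO",
    default_config={"lr": 5e-5, "train_batch_size": 4000},
    validate_config_fn=lambda config: None,
)

# New style (this PR): a plain sub-class; easier to read, type-check, and override.
class PPOLikeTrainer(TrainerBase):
    @classmethod
    def get_default_config(cls):
        return {"lr": 5e-5, "train_batch_size": 4000}

    def validate_config(self, config):
        if config["train_batch_size"] <= 0:
            raise ValueError("train_batch_size must be > 0!")

print(OldStylePPO({"lr": 1e-4}).config)     # {'lr': 0.0001, 'train_batch_size': 4000}
print(PPOLikeTrainer({"lr": 1e-4}).config)  # {'lr': 0.0001, 'train_batch_size': 4000}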