[RLlib] New ConnectorV3 API #05: PPO runs in single-agent mode in this API stack #42272
Conversation
…runner_support_connectors_04_learner_api_changes
@@ -550,24 +638,3 @@ def training_step(self) -> ResultDict:
self.workers.local_worker().set_global_vars(global_vars)

return train_results

def postprocess_episodes(
No longer needed here. Episodes are sent directly to Learner(s) as-is.
@@ -39,6 +47,78 @@ def build(self) -> None:
)
)

@override(Learner)
def _preprocess_train_data(
Note: Only called on the new API stack + EnvRunners.
if not episodes:
    return batch, episodes

# Make all episodes one ts longer in order to just have a single batch
New way to do GAE:
- Elongate all episodes by one artificial timestep.
- Perform the vf-predictions AND the bootstrap value predictions in one single batch (because we have the extra timestep!).
- Use the learner connector to make sure this forward pass is done using the correct (custom?) batch format.
- Remove the extra timesteps from the episodes (and from the computed advantages).
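To make this concrete, here is a self-contained sketch (plain NumPy, not the actual RLlib connector code; the function name and shapes are made up for illustration): a single value-function pass over the elongated episode yields T + 1 value predictions, including the bootstrap value, from which GAE advantages of the original length T are computed.

```python
import numpy as np

def gae_from_elongated_values(rewards, values_plus_bootstrap, gamma=0.99, lambda_=0.95):
    """Compute GAE advantages for one episode.

    `values_plus_bootstrap` has len(rewards) + 1 entries: the value predictions
    for every observation in the episode PLUS one extra prediction for the
    artificial final timestep (the bootstrap value). This is what a single
    forward pass over the "elongated" episode yields. For a terminated episode,
    the bootstrap value would simply be 0.0.
    """
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float32)
    last_gae = 0.0
    for t in reversed(range(T)):
        # TD error uses the "next" value, which for t == T - 1 is the bootstrap value.
        delta = rewards[t] + gamma * values_plus_bootstrap[t + 1] - values_plus_bootstrap[t]
        last_gae = delta + gamma * lambda_ * last_gae
        advantages[t] = last_gae
    # The extra timestep only existed in the value predictions; the advantages
    # are already of the original episode length T.
    return advantages

# Example: a 3-step episode plus one bootstrap value from the elongated forward pass.
rewards = np.array([1.0, 0.0, 1.0])
values = np.array([0.5, 0.4, 0.6, 0.2])  # len == 4 == T + 1
print(gae_from_elongated_values(rewards, values))
```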
    SampleBatch.VF_PREDS,
    SampleBatch.ACTION_DIST_INPUTS,
]
return self.output_specs_inference()
Simplified: the exploration output specs now just reuse the inference specs.
@@ -40,6 +40,11 @@ def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]:
    the policy distribution to be used for computing KL divergence between the old
    policy and the new policy during training.
    """
    # TODO (sven): Make this the only behavior once PPO has been migrated
    # to the new API stack (including EnvRunners!).
    if self.config.model_config_dict.get("uses_new_env_runners"):
Temporary hack to make sure the RLModule knows when it still has to compute vf-preds via `forward_exploration` (old and hybrid API stacks).
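A rough, self-contained sketch of the branching this flag enables (plain PyTorch; the module below is a simplified stand-in, not the actual PPOTorchRLModule, and the flag handling is only illustrative):

```python
import torch
import torch.nn as nn

class TinyPPOModule(nn.Module):
    """Simplified stand-in for a PPO RLModule's exploration forward pass."""

    def __init__(self, obs_dim=4, num_actions=2, uses_new_env_runners=False):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, 16)
        self.pi_head = nn.Linear(16, num_actions)
        self.vf_head = nn.Linear(16, 1)
        # Mirrors the temporary `uses_new_env_runners` flag from this PR.
        self.uses_new_env_runners = uses_new_env_runners

    def forward_exploration(self, batch):
        features = torch.relu(self.encoder(batch["obs"]))
        out = {"action_dist_inputs": self.pi_head(features)}
        if not self.uses_new_env_runners:
            # Old/hybrid stacks: sampling must also produce vf-preds, because
            # GAE is computed from the sampled batch on the rollout side.
            out["vf_preds"] = self.vf_head(features).squeeze(-1)
        # New stack: no value-function pass here; the Learner's connector
        # pipeline runs the value function itself (see the GAE comment above).
        return out

module = TinyPPOModule(uses_new_env_runners=True)
print(module.forward_exploration({"obs": torch.randn(5, 4)}).keys())
```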
@@ -272,6 +281,40 @@ def __init__(
# the final results dict in the `self.compile_update_results()` method.
self._metrics = defaultdict(dict)

@OverrideToImplementCustomLogic_CallToSuperRecommended
Moved here for a better ordering of methods (it used to be all the way at the bottom of the class).
# Build learner connector pipeline used on this Learner worker.
# TODO (sven): Support multi-agent cases.
if self.config.uses_new_env_runners and not self.config.is_multi_agent():
For now, the Learner connector is only used on the new API stack with EnvRunners in the single-agent case (without this, PPO on the new stack would not learn).
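For readers unfamiliar with the term, here is a minimal conceptual sketch of what a Learner connector pipeline does (plain Python, not RLlib's ConnectorV2 classes; all names below are made up): a chain of callables that the Learner applies in order to turn raw episodes into a train batch.

```python
from typing import Any, Callable, Dict, List

# A connector piece maps (episodes, batch) -> batch; the pipeline chains them.
ConnectorPiece = Callable[[List[Dict[str, Any]], Dict[str, Any]], Dict[str, Any]]

def add_observations(episodes, batch):
    batch["obs"] = [o for ep in episodes for o in ep["observations"]]
    return batch

def add_rewards(episodes, batch):
    batch["rewards"] = [r for ep in episodes for r in ep["rewards"]]
    return batch

class LearnerConnectorPipeline:
    def __init__(self, pieces: List[ConnectorPiece]):
        self.pieces = pieces

    def __call__(self, episodes):
        batch: Dict[str, Any] = {}
        for piece in self.pieces:
            batch = piece(episodes, batch)
        return batch

# Per the comment above: only built on the new API stack + EnvRunners + single-agent.
pipeline = LearnerConnectorPipeline([add_observations, add_rewards])
episodes = [{"observations": [0, 1, 2], "rewards": [0.0, 1.0, 0.5]}]
print(pipeline(episodes))
```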
rllib/utils/minibatch_utils.py
@@ -87,7 +86,13 @@ def __iter__(self):
def get_len(b):
    return len(b[SampleBatch.SEQ_LENS])

n_steps = int(
Bug fix. When slicing a BxT batch, we should slice properly along the B-axis (with the correct slice size!).
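A self-contained illustration of the fixed slicing logic (NumPy only; the variable names are made up and this is not the actual MiniBatchCyclicIterator code): for a zero-padded [B, T] batch, a minibatch of N timesteps has to be taken as N / T rows along the B-axis, not as N entries along a flattened time axis.

```python
import numpy as np

B, T = 8, 20                      # 8 padded sequences, max_seq_len = 20
obs = np.random.randn(B, T, 4)    # [B, T, obs_dim] zero-padded batch
seq_lens = np.array([20, 20, 15, 20, 10, 20, 20, 5])

minibatch_size = 60               # desired number of (padded) timesteps

# Convert the requested number of timesteps into a number of rows (sequences)
# along the B-axis; each row contributes T (padded) timesteps.
rows_per_minibatch = minibatch_size // T          # 3 rows of T timesteps each

start = 0
minibatch_obs = obs[start : start + rows_per_minibatch]            # [3, T, 4]
minibatch_seq_lens = seq_lens[start : start + rows_per_minibatch]  # [3]
print(minibatch_obs.shape, minibatch_seq_lens)
```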
return value

data = tree.map_structure(map_, self)
infos = self.pop(SampleBatch.INFOS, None)
Simplifications.
# we return the values here and slice them separately
# TODO(Artur): Clean this hack up.
return value
return value[start_padded:stop_padded]
Simplifications.
Stamp.
…runner_support_connectors_05_ppo_w_connectorv2s
@sven1977 Could you speak more to why GAE support was dropped for APPO in this release?
EnvRunners support new ConnectorV3 API; PPO runs in single-agent mode in this API stack
This PR:

- Introduces `train_batch_size_per_learner` to better distinguish between the total effective batch size and the batch size per (GPU) Learner worker (see the config sketch below).
- On the new stack, `forward_exploration` no longer needs to perform a value-function pass. This is an essential improvement in code quality, as we now have a full separation between the sampling and learning worlds: the EnvRunner (sampling world) is no longer concerned with what the PPOLearner (learning world) might need and only has to compute actions for the next env step.
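A minimal config sketch for the per-Learner batch-size setting (assuming the option name introduced by this PR; the exact flags for enabling the new API stack and EnvRunners differ across Ray versions and are omitted here):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .training(
        # Batch size *per (GPU) Learner worker*, as introduced by this PR.
        # Total effective batch size =
        #     train_batch_size_per_learner * number of Learner workers.
        train_batch_size_per_learner=4000,
    )
)
# algo = config.build()  # would start Ray and build the Algorithm
```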
Benchmark results:

Learns Pong in ~5 min via the examples/connectors/connector_v2_frame_stacking.py example script.
Args: --num-gpus=8 --num-env-runners=95 --framework=torch
On commit: 790a537
Why are these changes needed?

Related issue number

Checks

- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.