[RLlib] Trajectory view API (prep PR for switching on by default across all RLlib; plumbing only) #11717
Conversation
…ectory_view_api_plumbing_only
Did you mean to assign? Will review.
Yes, please review. Fixing the remaining tests right now. ...
…ectory_view_api_plumbing_only
Tests should be all fixed now, but I'll keep checking on 'em.
…ectory_view_api_plumbing_only (Conflicts: rllib/policy/eager_tf_policy.py)
rllib/policy/eager_tf_policy.py
Outdated
@@ -636,7 +637,8 @@ def _stats(self, outputs, samples, grads):
         })
         return fetches

-    def _initialize_loss_with_dummy_batch(self):
+    @override(Policy)
+    def _initialize_loss_dynamically(self):
Why the name change? The previous one seemed more clear.
done
rllib/models/modelv2.py
Outdated
@@ -318,6 +319,24 @@ def is_time_major(self) -> bool:
         """
         return self.time_major is True

+    @PublicAPI
+    def update_view_requirements_from_init_state(self):
I don't think we should be exposing this as a public API of model. Instead, why not have the policies add the state view requirements internally, without mutating or otherwise requiring the model to do anything.
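The suggestion above can be sketched as follows. This is a hypothetical illustration (the names `add_state_view_requirements`, `data_col`, and `shift` are assumptions, not RLlib's actual API): the policy derives the state view requirements from the model's initial state itself, so the model does not need to expose a public mutation method.

```python
# Hypothetical sketch: policy-side derivation of state view
# requirements from the model's init state (names are illustrative,
# not RLlib's actual API).

def add_state_view_requirements(view_requirements, initial_state):
    """Add state_in_i/state_out_i entries for each init-state tensor."""
    for i, _ in enumerate(initial_state):
        # state_in_i at step t is state_out_i from step t - 1.
        view_requirements["state_in_{}".format(i)] = {
            "data_col": "state_out_{}".format(i),
            "shift": -1,
        }
        view_requirements["state_out_{}".format(i)] = {}
    return view_requirements

# Two hidden-state tensors (e.g. an LSTM's h and c states).
reqs = add_state_view_requirements({}, initial_state=[[0.0] * 4, [0.0] * 4])
```

This keeps the model passive: the policy inspects `initial_state` and builds the shifted view entries itself.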
+1
done
Main comment is to move update_view_requirements_from_init_state into implementations instead of exposing it publicly.
rllib/policy/policy.py
Outdated
SampleBatch.AGENT_INDEX: ViewRequirement(),
SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(),
SampleBatch.ACTION_LOGP: ViewRequirement(),
SampleBatch.VF_PREDS: ViewRequirement(),
It's kind of odd to see agent-specific keys like VF_PREDS here. Can't we infer these dynamically always, and omit them from this dict?
I can remove the extra-action-fetch keys. These can indeed be added (if required) after the model's test call.
The others need to stay because they may be required by the loss function. If any field is not required, it'll be removed automatically, so the user shouldn't really care. We always get the slimmest possible final view-requirements dict.
done
Note: This is the base/maximum requirement dict, from which later
some requirements will be subtracted again automatically to streamline
data collection, batch creation, and data transfer.
Why not keep this as the empty dict and instead infer columns to add?
It's not that easy:
- The model may require some unknown inputs (such as prev-actions, etc.), which are only visible from the model's self.inference_view_requirements.
- The policy may need some standard inputs, such as "t" or "episode_id", in its loss (or learn_on_batch) methods.
So we do need to fill in the standard values at first (then remove those that turn out not to be needed). The only exception is the model's extra-action fetches, which probably shouldn't be in the initial dict and can be added after the model's test call.
done.
So we are initially only adding the following base columns:
OBS
NEXT_OBS
ACTIONS
REWARDS
DONES
INFOS
EPS_ID
AGENT_INDEX
t (<- time step)
We then add the model's own inference requirements, including inferring some requirements from the model's init-state (done in the policy now, the model does not do this anymore).
Then we do the model forward test-pass and add the returned extra-action outs to the view reqs.
Then we call postprocessing and the loss, after which we erase all columns that are not needed, differentiating between postprocessing and loss (some columns are only needed for postprocessing).
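The pipeline above can be sketched roughly as follows. This is an illustration only, with hypothetical names (`build_view_requirements`, `BASE_COLUMNS`, `used_columns` are assumptions, not RLlib's actual implementation): start from the base columns, merge in the model's own inference requirements, then prune every column the dummy postprocessing/loss calls never touched.

```python
# Illustrative sketch of the view-requirements pipeline (hypothetical
# names, not RLlib's actual API).

BASE_COLUMNS = [
    "obs", "new_obs", "actions", "rewards", "dones",
    "infos", "eps_id", "agent_index", "t",
]

def build_view_requirements(model_requirements, used_columns):
    """Return the slimmest final view-requirements dict."""
    # Start from the base/maximum set of columns.
    requirements = {col: {} for col in BASE_COLUMNS}
    # Merge in model-declared inputs, e.g. prev-actions or state columns.
    requirements.update(model_requirements)
    # Keep only columns actually accessed during the dummy test calls.
    return {c: r for c, r in requirements.items() if c in used_columns}

reqs = build_view_requirements(
    model_requirements={"prev_actions": {"shift": -1}},
    used_columns={"obs", "actions", "rewards", "dones", "prev_actions"},
)
```

The key property is the final filter step: whatever the loss and postprocessing did not read is dropped, so data collection and transfer stay as slim as possible.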
""" | ||
sample_batch_size = max(self.batch_divisibility_req, 2) | ||
B = 2 # For RNNs, have B=2, T=[depends on sample_batch_size] | ||
self._dummy_batch = self._get_dummy_batch_from_view_requirements( |
Why not create them from a sample batch?
Not sure what you mean? We need to wrap the dummy batch into a tracking dict anyway (so it doesn't really matter what's underneath, a SampleBatch or a plain dict).
done, _dummy_batch is now a SampleBatch, which will be wrapped into a tracking dict prior to calling postprocessing_fn and loss.
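The "tracking dict" idea can be sketched minimally like this. The class name `AccessTrackingDict` and all details are illustrative assumptions, not RLlib's actual implementation: a plain dict (standing in for a SampleBatch) is wrapped so that every key read during the dummy postprocessing/loss calls is recorded, and requirements for keys that were never read can then be dropped.

```python
# Minimal sketch of a tracking dict (illustrative only, not RLlib's
# actual implementation).

class AccessTrackingDict(dict):
    """Dict that remembers which keys were read."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accessed_keys = set()

    def __getitem__(self, key):
        # Record every read access before returning the value.
        self.accessed_keys.add(key)
        return super().__getitem__(key)

tracked = AccessTrackingDict(obs=[0.0], actions=[1], vf_preds=[0.5])

# Simulate a loss function that only reads "obs" and "actions".
_ = tracked["obs"], tracked["actions"]

# "vf_preds" was never read; its view requirement could be removed.
unused = set(tracked) - tracked.accessed_keys
```

Because only read accesses matter for pruning, it's irrelevant whether the wrapped object is a SampleBatch or a plain dict, which is the point made in the comment above.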
…ectory_view_api_plumbing_only
Tue Nov 3 10:53:28 UTC 2020 Flake8....
Suggested renaming: I think this makes the supersetting a bit clearer (inference includes the model forward pass, as does training).
Trajectory view API: prep PR for switching on by default across all RLlib; plumbing only.
Why are these changes needed?
Related issue number
Checks
- I've run scripts/format.sh to lint the changes in this PR.