[RLlib] Trajectory view API: Enable by default for PPO, IMPALA, PG, A3C (tf and torch). #11747
Conversation
episode=None):
# not used, so save some bandwidth
del sample_batch.data[SampleBatch.NEXT_OBS]
return sample_batch
Can we keep this? What would happen if a user had this in a custom copy of Impala?
Keeping it. However, in the test run this will be a TrackingDict, not a SampleBatch, so it won't have the data property; I had to add a key check.
How about adding this to TrackingDict for backwards compat?
@property
def data(self):
return self # backwards compat with SampleBatch
I really want to make sure there are zero lines of code change in the policy files.
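To make the suggested shim concrete, here is a minimal sketch of how such a `data` property keeps legacy `batch.data[key]` call sites working. This is a toy stand-in, not RLlib's actual TrackingDict implementation:

```python
class TrackingDict(dict):
    """Toy stand-in for RLlib's tracking dict (illustrative only)."""

    @property
    def data(self):
        # Backwards compat with SampleBatch: old policy code accesses
        # batch.data[key]; returning self lets those call sites keep
        # working with zero changes in the policy files.
        return self


batch = TrackingDict(new_obs=[1, 2, 3])
del batch.data["new_obs"]  # legacy-style access still works
```

With this property in place, the `del sample_batch.data[...]` lines in custom policies behave exactly as they did against a SampleBatch.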
Sounds good.
Done.
This doesn't seem fixed
T = policy.config["rollout_fragment_length"]
B = tensor.shape[0] // T
# Cover cases where we send a (small) test batch through this loss
# function.
if B == 0:
Can we retain compatibility here by sending a bigger batch?
OK, yeah, maybe the test batch should be large anyway.
done, removed.
made test batch large enough (32).
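The `B == 0` special case above exists because B is computed by integer division; a sketch with hypothetical sizes (T = 4 is an assumed value, not from the PR) shows why a tiny dummy batch collapses to zero:

```python
# B is computed by floor division, so any test batch shorter than
# rollout_fragment_length yields B == 0 and an empty reshaped tensor.
T = 4                # hypothetical rollout_fragment_length
tiny_batch_len = 2   # a dummy test batch that is too small
full_batch_len = 32  # a test batch made "large enough"
print(tiny_batch_len // T, full_batch_len // T)  # -> 0 8
```

Making the test batch comfortably larger than any configured fragment length removes the need for the special case in the loss.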
A few compatibility issues left, I think.
@sven1977 also, please make sure to assign PRs to reviewers, otherwise they will not show up on their dashboards.
Very close, last round of comments
rllib/agents/ppo/appo_tf_policy.py
Outdated
@@ -358,9 +358,6 @@ def postprocess_trajectory(
        use_critic=policy.config["use_critic"])
else:
    batch = sample_batch
# TODO: (sven) remove this del once we have trajectory view API fully in
# place.
del batch.data["new_obs"]  # not used, so save some bandwidth
Can we keep this for now? Will it crash?
Reverted.
rllib/policy/tf_policy_template.py
Outdated
@@ -62,8 +63,12 @@ def build_tf_policy(
        Policy, ModelV2, TensorType, TensorType, TensorType
    ], Tuple[TensorType, type, List[TensorType]]]] = None,
    mixins: Optional[List[type]] = None,
    view_requirements_fn: Optional[Callable[[Policy], Dict[
        str, ViewRequirement]]] = None,
I thought the only way to specify view reqs would be through custom models. So we should remove this right?
Yes, but I still wanted to leave users some opportunity to add new ones. It's not needed by any algos right now, though.
@@ -174,8 +174,8 @@ def build_torch_policy(
    mixins (Optional[List[type]]): Optional list of any class mixins for
        the returned policy class. These mixins will be applied in order
        and will have higher precedence than the TorchPolicy class.
    view_requirements_fn (Callable[[],
        Dict[str, ViewRequirement]]): An optional callable to retrieve
    view_requirements_fn (Optional[Callable[[Policy],
Remove this arg?
Done.
LGTM; please fix the 3 comments prior to merge
rllib/agents/dqn/dqn_tf_policy.py
Outdated
@@ -424,7 +424,7 @@ def postprocess_nstep_and_prio(policy: Policy,
        batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
    new_priorities = (np.abs(convert_to_numpy(td_errors)) +
                      policy.config["prioritized_replay_eps"])
    batch.data[PRIO_WEIGHTS] = new_priorities
    batch[PRIO_WEIGHTS] = new_priorities
Please revert this change.
+1
@@ -209,7 +208,7 @@ def make_time_major(policy, seq_lens, tensor, drop_last=False):
        T = tensor.shape[0] // B
    else:
        # Important: chop the tensor into batches at known episode cut
        # boundaries. TODO(ekl) this is kind of a hack
        # boundaries.
Isn't it still a hack? Please restore the comment.
I don't think it's a "hack": IMPALA explicitly requires that the batch length be divisible by rollout_fragment_length, so the chop points are well defined.
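The batch-major to time-major split being discussed can be sketched with plain Python lists (hypothetical sizes; the real make_time_major also handles seq_lens and tensor frameworks):

```python
# A flattened [B * T] batch is chopped into T time steps of B rollouts.
# This is only well defined when len(flat) is divisible by T, which is
# the divisibility requirement IMPALA imposes.
T = 3                   # rollout_fragment_length (assumed value)
flat = list(range(6))   # flattened batch, here B = 2 rollouts of length 3
B = len(flat) // T
time_major = [[flat[b * T + t] for b in range(B)] for t in range(T)]
print(time_major)  # -> [[0, 3], [1, 4], [2, 5]]
```

Each inner list holds one time step across all rollouts, which is the layout the V-trace loss consumes.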
This PR is based on #11717 (which needs to be merged first!)
View requirements now live in two places:
a) the model (self.inference_view_requirements), used for model forward passes, and
b) the policy (self.view_requirements), which holds a superset of its model's view requirements plus its own (loss, postprocessing) view requirements.
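The model/policy split above can be illustrated with a toy stand-in for the view-requirement mapping (this is not RLlib's actual ViewRequirement class; names and fields here are simplified assumptions):

```python
from dataclasses import dataclass


@dataclass
class ViewRequirement:
    """Toy stand-in: which data column a consumer wants, at what shift."""
    data_col: str
    shift: int = 0  # e.g. shift=1 means "next timestep's value"


# a) The model declares only what its forward pass needs ...
inference_view_requirements = {"obs": ViewRequirement("obs")}

# b) ... while the policy holds a superset, adding loss/postprocessing needs.
view_requirements = {
    **inference_view_requirements,
    "new_obs": ViewRequirement("obs", shift=1),
    "rewards": ViewRequirement("rewards"),
}
print(sorted(view_requirements))  # -> ['new_obs', 'obs', 'rewards']
```

Keeping the inference set small is what lets the sampler collect only the columns a forward pass actually consumes.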
Why are these changes needed?
Related issue number
Checks
I've run scripts/format.sh to lint the changes in this PR.