[BugFix] Improve collector buffer initialisation when policy spec is unavailable #1547
Signed-off-by: Matteo Bettini <matbet@meta.com>
if key in self._tensordict_out.keys(isinstance(key, tuple)):
    continue
self._tensordict_out.set(key, spec.zero())
else:
Above this point I only refactored the existing code.
else:
    # otherwise, we perform a small number of steps with the policy to
    # determine the relevant keys with which to pre-populate _tensordict_out.
    # This is the safest thing to do if the spec has None fields or if there is
    # no spec at all.
    # See #505 for additional context.
    self._tensordict_out.update(self._tensordict)
We update self._tensordict_out with the real data coming from env.reset()
.zero_()
)
# in addition to outputs of the policy, we add traj_ids and step_count to
self._tensordict_out = self.policy(self._tensordict_out.to(self.device))
and then feed that to the policy.
Tests are passing, so LGTM :) Thanks a mil for this!
…unavailable (pytorch#1547) Signed-off-by: Matteo Bettini <matbet@meta.com> Co-authored-by: vmoens <vincentmoens@gmail.com>
Depends on #1539
Fixes #1565
This PR makes the collector run a sample policy forward pass on the reset data to build its buffer when the policy spec is unavailable or only partially specified.
The prior approach initialised the policy keys by passing the policy an env.fake_tensordict() full of zeros. This made certain policies that use action masks or other masks throw errors, because those masks were all False. This solution makes the initialisation more general and ensures that the policy is fed data that suits it.
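The failure mode described above can be illustrated with a minimal, library-free sketch. It does not use the TorchRL API: masked_policy, fake_zero_data, reset_data and init_buffer_keys are hypothetical stand-ins, with plain dicts playing the role of tensordicts. It only shows why an all-False mask breaks a masked policy, and why probing the policy with reset data avoids the problem.

```python
import random


def masked_policy(data):
    """Hypothetical policy: sample uniformly among actions the mask allows."""
    valid = [i for i, allowed in enumerate(data["action_mask"]) if allowed]
    if not valid:
        # With an all-False mask there is nothing to sample from.
        raise ValueError("action mask has no valid action")
    return {**data, "action": random.choice(valid)}


def fake_zero_data(n_actions):
    # Stand-in for zero-filled placeholder data: every mask entry is False.
    return {"observation": [0.0] * 4, "action_mask": [False] * n_actions}


def reset_data(n_actions):
    # Stand-in for real reset data: at least one action is valid.
    return {
        "observation": [0.1, -0.2, 0.0, 0.3],
        "action_mask": [True] + [False] * (n_actions - 1),
    }


def init_buffer_keys(policy, data):
    # Run the policy once and record the keys the buffer would need.
    return sorted(policy(data).keys())


# Old behaviour (sketch): probing the policy with zeros fails.
try:
    init_buffer_keys(masked_policy, fake_zero_data(3))
    zero_init_failed = False
except ValueError:
    zero_init_failed = True

# New behaviour (sketch): probing the policy with reset data succeeds.
keys_from_reset = init_buffer_keys(masked_policy, reset_data(3))
```

In this sketch the zero-filled probe raises, while the reset-data probe yields the full set of keys, which mirrors the motivation given for the change.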