[BugFix] Fix collector reset with truncation #1021
Conversation
torchrl/collectors/collectors.py
Outdated
if (_reset is None and done.any()) or (
    _reset is not None and done[_reset].any()
):
    reset_idx = done_or_terminated.squeeze(-1)
Here let's keep in mind that the shape of done is not always [*batch_size, 1]; it follows the done_spec, which could be [*batch_size, *F].
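For instance (a hypothetical sketch in a multi-agent setting; the spec shape below is illustrative, not taken from this PR), the done entry can carry trailing dims beyond the collector's batch size, so squeeze(-1) alone does not bring it down to [*batch_size]:

```python
import torch

# Illustrative shapes: the tensordict batch_size is [3], but done follows a
# done_spec of shape [3, 4, 1] (e.g. 4 agents per env).
done = torch.zeros(3, 4, 1, dtype=torch.bool)
done[0, 2, 0] = True  # one agent in env 0 is done

# squeeze(-1) only removes the last singleton dim: the result still has
# shape [3, 4], which cannot index a tensordict with batch_size [3].
print(done.squeeze(-1).shape)  # torch.Size([3, 4])
```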
got it
Then what should we do? Reduce done until its shape matches the tensordict's?
- reset_idx = done_or_terminated.squeeze(-1)
+ reset_idx = done_or_terminated
+ while reset_idx.ndim > self._tensordict.ndim:
+     reset_idx = reset_idx.any(-1)
We do it a few lines later:
done_or_terminated.sum(
    tuple(range(self._tensordict.batch_dims, done_or_terminated.ndim)),
    dtype=torch.bool,
)
We can do it once for these two use cases.
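A minimal sketch of that reduction (names and shapes are illustrative, assuming done_or_terminated has trailing dims beyond the tensordict batch dims):

```python
import torch

batch_dims = 1  # stands in for self._tensordict.batch_dims
done_or_terminated = torch.zeros(3, 4, 5, dtype=torch.bool)
done_or_terminated[0, 0, 0] = True

# Collapse every dim beyond the batch dims into one per-env flag, mirroring
# the sum(..., dtype=torch.bool) reduction quoted above.
traj_done_or_terminated = done_or_terminated.sum(
    tuple(range(batch_dims, done_or_terminated.ndim)), dtype=torch.bool
)
print(traj_done_or_terminated)  # tensor([ True, False, False])

# The any(-1) loop from the earlier suggestion produces the same mask.
reduced = done_or_terminated
while reduced.ndim > batch_dims:
    reduced = reduced.any(-1)
assert torch.equal(reduced, traj_done_or_terminated)
```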
got it, done
torchrl/collectors/collectors.py
Outdated
    td_reset[traj_done_or_terminated], inplace=True
)
done = self._tensordict[traj_done_or_terminated].get("done")
if (_reset is None and done.any()) or (_reset is not None and done.any()):
This check here might be worth keeping like before, because
self._tensordict.get("done")[_reset]
checks all the done_spec dims, while
self._tensordict[traj_done_or_terminated].get("done")
uses only traj_done_or_terminated,
which is smaller than _reset.
Not sure I get it.
Say done.shape = [3, 4, 5] and traj_done_or_terminated.shape = [3].
self._tensordict[traj_done_or_terminated].get("done").any() tells us if any (leaf) env is done.
What are you suggesting we do instead?
Taking your example: if _reset.shape = [3, 4, 5] but only _reset[0, 0, 0] = True, then when you compute traj_done_or_terminated you get [True, False, False].
Now when you do self._tensordict[traj_done_or_terminated].get("done") you get a done of shape [4, 5], and you check that none of it is True, while you should check only [0, 0].
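A minimal sketch of that scenario (shapes follow the example above; names are illustrative and this is not the actual collector code):

```python
import torch

# done follows a done_spec of shape [3, 4, 5]; only the leaf [0, 0, 0] is done.
done = torch.zeros(3, 4, 5, dtype=torch.bool)
done[0, 0, 0] = True
_reset = done.clone()  # a reset is requested only for that single leaf

# Reduce to the tensordict batch dims to decide which envs to reset.
traj_done_or_terminated = done.sum((1, 2), dtype=torch.bool)  # [True, False, False]

# Indexing with the reduced mask selects every leaf of env 0 ...
picked = done[traj_done_or_terminated]  # shape [1, 4, 5]: the whole [4, 5] done of env 0
# ... whereas indexing with _reset selects only the single leaf that was reset.
print(picked.shape, done[_reset].shape)  # torch.Size([1, 4, 5]) torch.Size([1])
```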
I'm not sure I see why that is necessary, since both options will raise the exception if anything is done.
If in MARL you have agents that can be done at init, both these options will raise an error (because an agent is done).
In other words: if in MARL an agent can be done after init, then there's a chance that you'll bump into this error because it can be "naturally" done after a reset.
I guess my question is: in the scenario where there are bags of envs with a batch size that is richer than the tensordict batch size, is there any rationale for checking only that the envs that were done are not done anymore, rather than checking that, after we call reset (presumably only on these envs), nothing from these bags of envs is done anymore?
To rephrase: if in your example I just call reset on [0, 0, 0] and after reset I still have a done, it can only come from [0, 0, 0], since I did not touch the other envs.
I was thinking the same thing.
What about this: we take the simplest option (for now) and leave it as is until we figure out what to do with envs that can start with a done or envs that keep being executed when done?
Sure!
BTW, the if in this PR has an or which can be removed; not sure if you saw it.
Sorry, I did not get that.
No problem, it is fixed now.
Description
Alternative to #1015