
[BugFix] Fix collector reset with truncation #1021

Merged: 5 commits merged into main on Apr 4, 2023
Conversation

@vmoens (Contributor) commented Apr 4, 2023

Description

Alternative to #1015

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 4, 2023
if (_reset is None and done.any()) or (
    _reset is not None and done[_reset].any()
):
    reset_idx = done_or_terminated.squeeze(-1)
Contributor
Here let's keep in mind that the shape of done is not always [*batch_size, 1]; it follows the done_spec, which could be [*batch_size, *F].

Contributor Author

got it
Then what should we do? Reduce done until its shape matches the tensordict's?

Suggested change
reset_idx = done_or_terminated.squeeze(-1)
reset_idx = done_or_terminated
while reset_idx.ndim > self._tensordict.ndim:
    reset_idx = reset_idx.any(-1)
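A minimal sketch of what this suggestion does, with hypothetical shapes: the collector's tensordict is batched over 3 envs, while done follows a richer done_spec of shape [3, 4, 5] (tensordict_ndim and the shapes here are assumptions for illustration, not values from the PR):

```python
import torch

# Hypothetical: the tensordict has 1 batch dim, done_spec adds two more.
tensordict_ndim = 1
done_or_terminated = torch.zeros(3, 4, 5, dtype=torch.bool)
done_or_terminated[0, 0, 0] = True  # only one leaf env is done

# Reduce trailing dims with any() until the mask matches the tensordict batch.
reset_idx = done_or_terminated
while reset_idx.ndim > tensordict_ndim:
    reset_idx = reset_idx.any(-1)

print(reset_idx)  # tensor([ True, False, False])
```

The loop collapses one trailing dim per iteration, so a single done leaf anywhere inside an env's done_spec marks that whole env for reset.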

Contributor

We do it a few lines later:

done_or_terminated.sum(
                tuple(range(self._tensordict.batch_dims, done_or_terminated.ndim)),
                dtype=torch.bool,
            )

We can do it once for these two use cases.
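For reference, here is a sketch of that sum-based reduction with assumed shapes (batch_dims stands in for self._tensordict.batch_dims): summing a boolean tensor over the trailing dims with dtype=torch.bool behaves like any() over those dims.

```python
import torch

batch_dims = 1  # hypothetical: self._tensordict.batch_dims
done_or_terminated = torch.zeros(3, 4, 5, dtype=torch.bool)
done_or_terminated[0, 0, 0] = True

# Collapse every dim beyond the tensordict batch dims in one call;
# a bool-dtype sum saturates at True, i.e. it acts as any().
traj_done_or_terminated = done_or_terminated.sum(
    tuple(range(batch_dims, done_or_terminated.ndim)),
    dtype=torch.bool,
)
print(traj_done_or_terminated)  # tensor([ True, False, False])
```

This gives the same per-env mask as the while-loop suggestion above, in a single reduction.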

Contributor Author

got it, done

td_reset[traj_done_or_terminated], inplace=True
)
done = self._tensordict[traj_done_or_terminated].get("done")
if (_reset is None and done.any()) or (_reset is not None and done.any()):
Contributor

This check here might be worth keeping like before, because
self._tensordict.get("done")[_reset]
checks all the done_spec dims, while
self._tensordict[traj_done_or_terminated].get("done")
uses only traj_done_or_terminated, which is smaller than _reset

Contributor Author

> This check here might be worth keeping like before, because self._tensordict.get("done")[_reset] checks all the done_spec dims, while self._tensordict[traj_done_or_terminated].get("done") uses only traj_done_or_terminated, which is smaller than _reset

Not sure I get it
Say done.shape = [3, 4, 5] and traj_done_or_terminated.shape = [3]

self._tensordict[traj_done_or_terminated].get("done").any() tells us if any (leaf) env is done.
What are you suggesting we do instead?

@matteobettini (Contributor) commented Apr 4, 2023

Taking your example: if _reset.shape = [3, 4, 5] but only _reset[0, 0, 0] = True, then when you compute traj_done_or_terminated you get [True, False, False].

Now when you do self._tensordict[traj_done_or_terminated].get("done") you get a done of shape [4, 5] and you check that none of it is True, while you should check only [0, 0].

Contributor Author

I'm not sure I see why that is necessary, since both options will raise the exception if anything is done.

If in MARL you have agents that can be done at init, both options will raise an error (because an agent is done). In other words: if in MARL an agent can be done after init, then there's a chance you'll bump into this error, because it can be "naturally" done after a reset.

I guess my question is: in the scenario where there are bags of envs with a batch size richer than the tensordict batch size, is there any rationale to check only that the envs that were done are not done anymore, rather than checking that, after we call reset (presumably only on these envs), nothing from these bags of envs is done anymore?

Contributor Author

To rephrase: if in your example I just call reset on [0, 0, 0] and after reset I still have a done, it can only come from [0, 0, 0], since I did not touch the other envs.

Contributor Author

I was thinking the same thing.
What about this: we take the simplest option (for now) and leave it as is, until we figure out what to do with envs that can start with a done, or envs that keep being executed when done?

Contributor

Sure!

Contributor

BTW, the if in this PR has an or that can be removed; not sure if you saw it.

Contributor Author

Sorry, I did not get that.

Contributor

No problem, it is fixed now.

@vmoens vmoens added the bug Something isn't working label Apr 4, 2023
@vmoens vmoens merged commit 164a438 into main Apr 4, 2023
albertbou92 pushed a commit to PyTorchRL/rl that referenced this pull request Apr 12, 2023
@vmoens vmoens deleted the fix_collector_reset branch May 12, 2023 09:35
Labels
bug (Something isn't working), CLA Signed (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.)
3 participants