[BugFix] SyncDataCollector init when device and env_device are different #765
Conversation
vmoens left a comment:
Thanks for this! Can you add a test where device and passing_device differ, to check for non-regression of this bug fix?
Added tests for all device/passing_device combinations for SyncDataCollector, MultiSyncDataCollector and MultiaSyncDataCollector. I noticed that for MultiSyncDataCollector, the yielded TensorDicts had their device attribute set to None. As a sanity check, I added a line of code to make sure the TensorDicts yielded by MultiSyncDataCollector are cast to the correct device.
@albertbou92 interesting. Does this happen when the passing devices match too? Or just when there's more than one?
Yes, it happens for all device combinations.
exactly
Ok, then it's a tensordict bug; I'll have a look into it.
Should I remove the line of code that makes sure the TensorDicts yielded by MultiSyncDataCollector are cast to the correct device? We could keep it as a sanity check; it won't add overhead if the TensorDicts are already on the correct device.
Which line of code are we talking about? By the way, can you check if the device bug with nested tensordicts is still present?
The device bug with nested TensorDicts seems to be fixed now! In my code I added a device-casting line to the MultiSyncDataCollector iterator method; I can either remove it or we can keep it as a sanity check.
If it's not necessary I would remove it. Those things take little time individually, but they quickly pile up into significant overhead.
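For what it's worth, at the plain-tensor level a redundant cast is essentially free, since `Tensor.to` returns the same object when nothing changes; the overhead being discussed would come from the TensorDict-level cast, which still has to walk its keys (my reading of the thread, not verified against tensordict internals):

```python
import torch

t = torch.ones(3)  # already lives on the CPU
# A no-op cast returns the very same tensor object, not a copy.
assert t.to(torch.device("cpu")) is t
```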
Ok, removed that checking line. The PR now just fixes the bug in SyncDataCollector when the policy device and passing_device are different, and adds tests to make sure all collectors work with all device combinations.
del collector

@pytest.mark.parametrize("device", ["cuda", "cpu"])
You should probably use `get_available_devices` here, as some tests run on machines that don't have a cuda device.
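The suggested helper could look roughly like this (a sketch only; torchrl's actual `get_available_devices` lives in its test utilities and may differ):

```python
import torch

def get_available_devices():
    """Return the devices the current machine can actually run tests on."""
    devices = [torch.device("cpu")]
    if torch.cuda.is_available():
        # One entry per visible GPU, so multi-GPU setups are covered too.
        devices += [
            torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())
        ]
    return devices
```

Parametrizing with `@pytest.mark.parametrize("device", get_available_devices())` then skips the cuda cases automatically on CPU-only machines instead of failing.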
del collector

@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("passing_device", ["cuda", "cpu"])
Same as above
test/test_collector.py (outdated)

del collector

@pytest.mark.parametrize(
Does it make sense to test on cpu only? Maybe we can skip it in that case
Ok, yes, I can remove it in that case.
Description
In the init method of the SyncDataCollector class, a small number of steps is taken with the policy to determine the relevant keys of the output TensorDict. When the policy device and the environment device are different, this can raise a RuntimeError, since the input provided to the policy is located on the environment device.
This PR simply makes sure that the TensorDict provided to the policy is on the policy device, and then moves the output TensorDict back to the environment device.
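The shape of the fix can be sketched as follows (hypothetical names; the actual change lives inside SyncDataCollector's init, and plain tensors stand in for TensorDicts here):

```python
import torch

def dry_run_policy(policy, td, policy_device, env_device):
    """Run the policy once to discover its output keys, tolerating a device mismatch.

    `td` comes from the environment and therefore lives on `env_device`,
    while the policy's parameters live on `policy_device`.
    """
    # Cast the input to the policy's device before the dry run...
    out = policy(td.to(policy_device))
    # ...then cast the result back so the environment can keep using it.
    return out.to(env_device)

# Minimal usage: a linear policy and a fake observation, both on CPU here,
# but the same code path handles cpu/cuda mismatches.
policy = torch.nn.Linear(4, 2)
obs = torch.zeros(4)  # pretend this came from the env
action = dry_run_policy(policy, obs, torch.device("cpu"), torch.device("cpu"))
print(action.shape)  # torch.Size([2])
```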
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!