Conversation

@albertbou92
Contributor

@albertbou92 albertbou92 commented Dec 24, 2022

Description

In the init method of the SyncDataCollector class, a small number of steps is taken with the policy to determine the relevant keys of the output TensorDict. When the policy device and the environment device differ, this can raise a RuntimeError, since the input provided to the policy lives on the environment device.

This PR simply ensures that the TensorDict provided to the policy is on the policy device, and then moves the output TensorDict back to the environment device.
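A minimal sketch of the idea (hypothetical function and names, not the actual SyncDataCollector code; .to() works the same way on a TensorDict as on a plain tensor):

```python
import torch

def dry_run_policy(policy, env_output, policy_device, env_device):
    """Sketch of the fix: colocate the policy input with the policy before
    the dry run, then move the result back to the environment device."""
    td = env_output.to(policy_device)  # cast the env's output to the policy device
    td = policy(td)                    # no device mismatch: input and policy colocated
    return td.to(env_device)           # restore the environment device for the collector

# Toy usage with a plain tensor standing in for a TensorDict:
out = dry_run_policy(lambda t: t + 1, torch.zeros(3),
                     torch.device("cpu"), torch.device("cpu"))
```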

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 24, 2022
Collaborator

@vmoens vmoens left a comment

Thanks for this! Can you add a test where device and passing_device differ, to check for non-regression of this bug fix?

@albertbou92
Contributor Author

albertbou92 commented Dec 27, 2022

Added tests for all device/passing_device combinations for SyncDataCollector, MultiSyncDataCollector and MultiaSyncDataCollector.

I noticed that for MultiSyncDataCollector, the yielded TensorDicts have their device attribute set to None.
I think the problem may be in the TensorDict.cat method, which returns a TensorDict whose nested "next" TensorDict has its device set to None.

As a sanity check, I added a line of code to make sure the TensorDicts yielded by MultiSyncDataCollector are cast to the correct device.

@vmoens
Collaborator

vmoens commented Dec 28, 2022

@albertbou92 interesting

I noticed that for MultiSyncDataCollector, the yielded TensorDicts have their device attribute set to None.

Does this happen when the passing devices match too? Or just when there's more than one?
torch.cat should preserve the device (since we're sure that all devices match). If it doesn't, it's a bug. When the devices don't match, we first cast things to CPU before calling cat. It's definitely not ultra fast, so we may want to parametrise that feature.

which returns a TensorDict whose "next" TensorDict has device set to None.

You mean that the tensordict has a device but the nested tensordict doesn't? Weird...
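The expectation about torch.cat can be checked with plain CPU tensors (a trivial sketch; the issue reported above concerns TensorDict's cat handling of nested tensordicts, not torch.cat on tensors):

```python
import torch

# torch.cat preserves the common device of its inputs: concatenating
# CPU tensors yields a CPU tensor (likewise for tensors on one CUDA device).
a = torch.ones(2, 3)
b = torch.zeros(1, 3)
c = torch.cat([a, b], dim=0)  # shape (3, 3), still on the CPU
```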

@albertbou92 albertbou92 reopened this Dec 28, 2022
@albertbou92
Contributor Author

Does this happen when the passing devices match too?

Yes, it happens for all device combinations.

@albertbou92
Contributor Author

You mean that the tensordict has a device but the nested tensordict doesn't? Weird...

Exactly.

@vmoens
Collaborator

vmoens commented Dec 28, 2022

You mean that the tensordict has a device but the nested tensordict doesn't? Weird...

Exactly.

OK, then it's a tensordict bug. I'll look into it.

@albertbou92
Contributor Author

albertbou92 commented Jan 2, 2023

Should I remove the line of code that makes sure the TensorDicts yielded by MultiSyncDataCollector are cast to the correct device?

We could keep it as a sanity check; it won't add overhead if the TensorDict is already on the correct device.

@vmoens
Collaborator

vmoens commented Jan 2, 2023

Should I remove the line of code that makes sure the TensorDicts yielded by MultiSyncDataCollector are cast to the correct device?

We could keep it as a sanity check; it won't add overhead if the TensorDict is already on the correct device.

Which line of code are we talking about?

By the way: can you check if the device bug with nested tensordicts is still present?

@albertbou92
Contributor Author

albertbou92 commented Jan 2, 2023

The device bug with nested TensorDicts seems to be fixed now!

In my code, I added an out_buffer = out_buffer.to(prev_device) call to the MultiSyncDataCollector iterator method, after concatenating all the individual TensorDicts, to make sure the yielded TensorDict was placed on the passing_device. Now that the cat method is fixed, it might be unnecessary.

I can either remove it or keep it as a sanity check.

@vmoens
Collaborator

vmoens commented Jan 2, 2023

If it's not necessary, I would remove it. These things take little time individually, but they quickly pile up into a considerable overhead.

@albertbou92
Contributor Author

OK, removed that checking line.

The PR now just fixes the bug in SyncDataCollector when the policy device and passing_device differ, and adds tests to make sure all collectors work with all device combinations.

del collector


@pytest.mark.parametrize("device", ["cuda", "cpu"])
Collaborator

You should probably use `get_available_devices` here, as some tests run on machines that don't have a CUDA device.
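A helper along these lines (a sketch of what such a utility might look like; torchrl's actual get_available_devices may differ) parametrizes only over devices that exist on the current machine:

```python
import torch

def get_available_devices():
    """Enumerate test devices: always the CPU, plus any visible CUDA devices."""
    devices = [torch.device("cpu")]
    if torch.cuda.is_available():
        devices += [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
    return devices

# Used as: @pytest.mark.parametrize("device", get_available_devices())
# so CPU-only CI machines skip nothing and CUDA machines get full coverage.
```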



@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("passing_device", ["cuda", "cpu"])
Collaborator

Same as above

del collector


@pytest.mark.parametrize(
Collaborator

Does it make sense to test on CPU only? Maybe we can skip it in that case.

Contributor Author

OK, yes, I can remove it in that case.

@vmoens vmoens merged commit de9f488 into pytorch:main Jan 2, 2023
@albertbou92 albertbou92 deleted the bugfix_collector_init branch January 18, 2024 10:08