
Conversation

@riiswa
Contributor

@riiswa riiswa commented Jan 3, 2023

Description

I rewrote the code of PR #783 to support nvec with several axes (sorry for the double PR, I pushed it late ^^). Related to #781.

Example:

>>> ts = MultiDiscreteTensorSpec([[4, 2], [6, 9]])
>>> ts.rand()
tensor([[0, 1],
        [3, 6]])
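
From the discussion below, batched sampling then prepends the batch dims to the nvec dims (an illustrative example inferred from this thread, not taken from the PR description):

>>> ts.rand(torch.Size([5])).shape
torch.Size([5, 2, 2])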

The code is a bit complicated, but that's necessary to generalize to n dimensions. Don't hesitate to tell me if you have any questions or if I should add some comments.

Motivation and Context

Improve compatibility with Gym spaces that use MultiDiscrete with an nvec spanning several axes.

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot facebook-github-bot added the CLA Signed label Jan 3, 2023
@riiswa riiswa force-pushed the feature/multdiscretetensorspec branch from 4362e67 to af2eb52 Compare January 3, 2023 19:42
@riiswa
Contributor Author

riiswa commented Jan 4, 2023

Should I add the expected behavior of to_categorical and to_one_hot to this PR? If so, I will keep it in draft form until then.

@riiswa riiswa force-pushed the feature/multdiscretetensorspec branch from 81b6e3e to 8280a72 Compare January 4, 2023 14:48
@vmoens vmoens added the enhancement label Jan 5, 2023
Collaborator

@vmoens vmoens left a comment

I left a bunch of comments; not sure they all make sense.

        return x.permute(*torch.arange(x.ndim - 1, -1, -1)).reshape([*shape, *_size])

    def _project(self, val: torch.Tensor) -> torch.Tensor:
        if val.dtype not in (torch.int, torch.long):
Collaborator

Can we use this class with a non-integer dtype?
If no, we could just check that val.dtype matches the dtype of the object.
If yes, we could check self.dtype instead and always call round.

Contributor Author

Something like that?

if not self.dtype.is_floating_point:
    val = torch.round(val)
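
For what it's worth, a minimal standalone sketch of that round-then-cast idea (the project helper and the clamp-to-bounds step are my assumptions for illustration, not the actual TensorSpec code):

import torch

def project(val: torch.Tensor, nvec: torch.Tensor, dtype=torch.long) -> torch.Tensor:
    # round only when the target dtype is integral, as discussed above
    if not dtype.is_floating_point:
        val = torch.round(val)
    val = val.to(dtype)
    # clamp each component into [0, n - 1] (assumed projection rule)
    return val.clamp(torch.zeros_like(nvec), nvec - 1)

For example, project(torch.tensor([3.7, -1.2]), torch.tensor([4, 2])) returns tensor([3, 0]).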

@vmoens vmoens changed the title [Feature] MultiDiscreteTensorSpec nvec with several axes. [Feature] MultiDiscreteTensorSpec nvec with several axes Jan 5, 2023
@riiswa
Contributor Author

riiswa commented Jan 5, 2023

I've rewritten the code in line with your comments. Thank you!

@riiswa riiswa marked this pull request as draft January 5, 2023 21:12
@riiswa riiswa force-pushed the feature/multdiscretetensorspec branch from b152de3 to f575f51 Compare January 5, 2023 21:16
@riiswa riiswa marked this pull request as ready for review January 5, 2023 21:23
@riiswa riiswa force-pushed the feature/multdiscretetensorspec branch from 8ec4325 to 62e324d Compare January 6, 2023 08:13
Collaborator

@vmoens vmoens left a comment

I'm still not sure I grasp everything; can you comment a bit further?

return [val] if self._size < 2 else val.split(1, -1)

x = self._rand(self.space, shape)
if self.nvec.ndim > 1:
    x = x.transpose(len(shape), -1)
Collaborator

What does this do? Can you explain?

Say spec.shape = [3, 4] and shape = [1, 2]:
you want x to have shape [1, 2, 3, 4]. From what I understand, the output of _rand already has that shape.
Why do you swap the dim with size 4 with the one with size 2?

Contributor Author

No, in this case x has shape torch.Size([1, 2, 4, 3]) without the transpose, so I have to swap the dims of size 4 and 3 (not the one with size 2: len(shape) == 2 indexes the dim of size 4, since indexing starts at 0).

This is the log of the _rand algorithm (before and after stacking):

[torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([1, 2])]
torch.Size([1, 2, 4])
[torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([1, 2])]
torch.Size([1, 2, 4])
[torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([1, 2])]
torch.Size([1, 2, 4])
[torch.Size([1, 2, 4]), torch.Size([1, 2, 4]), torch.Size([1, 2, 4])]
torch.Size([1, 2, 4, 3])

That is because at the end I should stack along dim -2.
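
Concretely, the transpose fixes the last two axes (an illustrative snippet, not from the PR):

>>> x = torch.zeros(1, 2, 4, 3)   # the shape _rand produces here
>>> x.transpose(2, -1).shape      # len(shape) == 2
torch.Size([1, 2, 3, 4])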

What do you think about this solution:

def _rand(self, space: Box, shape: torch.Size, i: int):
    ....
        x.append(self._rand(_s, shape, i - 1))
    ....
    return torch.stack(x, -i)

Instead of stacking along the last dimension, I stack along a dimension determined by the depth in the recursive box list; the tests pass with this solution.
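
To make the stack-at-dim--i idea concrete, here is a minimal standalone sketch (rand_nested is a hypothetical helper mimicking the recursion, not the PR code):

import torch

def rand_nested(nvec, shape, depth):
    # nvec is a nested list mirroring the recursive box structure
    if isinstance(nvec[0], int):
        # leaf level: one discrete sample with batch shape `shape` per entry
        samples = [torch.randint(n, shape) for n in nvec]
    else:
        samples = [rand_nested(sub, shape, depth - 1) for sub in nvec]
    # stacking at -depth puts this level's axis in its final position,
    # so no transpose is needed afterwards
    return torch.stack(samples, -depth)

>>> rand_nested([[5] * 4] * 3, torch.Size([1, 2]), depth=2).shape
torch.Size([1, 2, 3, 4])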

@vmoens

x = self._rand(self.space, shape)
if self.nvec.ndim > 1:
    x = x.transpose(len(shape), -1)
return x.squeeze(-1)
Collaborator

If the shape is Size([1]), shouldn't we keep the last dim?

Contributor Author

It's an oversight: when ts.shape == self.shape == torch.Size([1]), the previous computation adds an empty dimension, so I added this case:

if self.shape == torch.Size([1]):
    x = x.squeeze(-1)
return x
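
If I understand the fix, the expected behavior then becomes (illustrative, assuming this constructor gives ts.shape == torch.Size([1])):

>>> ts = MultiDiscreteTensorSpec([4])
>>> ts.rand().shape
torch.Size([1])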

@riiswa
Contributor Author

riiswa commented Jan 6, 2023

The CI looks broken :/

Collaborator

@vmoens vmoens left a comment

LGTM! Thanks!

@vmoens
Collaborator

vmoens commented Jan 6, 2023

The CI looks broken :/

It's weird; can you try pushing an empty commit?
git commit --allow-empty -m empty

@riiswa
Contributor Author

riiswa commented Jan 6, 2023

Done, but it's still broken. Maybe a key was removed from (or expired on) GitHub.

@vmoens
Collaborator

vmoens commented Jan 6, 2023

I managed to get it to run!

@vmoens vmoens merged commit 6daedd6 into pytorch:main Jan 7, 2023