[Feature] MultiDiscreteTensorSpec nvec with several axes #789
Conversation
Force-pushed from 4362e67 to af2eb52.
Should I add the expected behavior of to_categorical and to_one_hot to this PR? If so, I will put it in draft form until then.
Force-pushed from 81b6e3e to 8280a72.
vmoens left a comment:
I left a bunch of comments. Not sure they all make sense.
torchrl/data/tensor_specs.py (Outdated)

    return x.permute(*torch.arange(x.ndim - 1, -1, -1)).reshape([*shape, *_size])

    def _project(self, val: torch.Tensor) -> torch.Tensor:
        if val.dtype not in (torch.int, torch.long):
Can we use this class with a non-integer dtype?
If no, we could just check that val.dtype matches the dtype of the object.
If yes, we could check self.dtype instead and always call round.
Something like this?

    if not self.dtype.is_floating_point:
        val = torch.round(val)
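For context, here is a minimal runnable sketch (not the PR's actual code; the function name, the clamp step, and the nvec broadcasting are assumptions) of how such a projection could round and clamp values into a multi-discrete space:

    import torch

    def project(val: torch.Tensor, nvec: torch.Tensor, dtype: torch.dtype = torch.long) -> torch.Tensor:
        # If the target dtype is integral but val is floating point, round first.
        if val.dtype.is_floating_point and not dtype.is_floating_point:
            val = torch.round(val)
        # Clamp each axis into [0, nvec - 1]; nvec broadcasts over val's trailing dims.
        val = torch.minimum(val.clamp_min(0), (nvec - 1).to(val.dtype))
        return val.to(dtype)

    val = torch.tensor([[2.6, -1.0, 7.2]])
    print(project(val, nvec=torch.tensor([3, 4, 5])))  # tensor([[2, 0, 4]])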
I tried to rewrite the code in agreement with your comments. Thank you!
Force-pushed from b152de3 to f575f51.
Force-pushed from 8ec4325 to 62e324d.
vmoens left a comment:
I'm still not sure I grasp everything; can you comment a bit further?
torchrl/data/tensor_specs.py (Outdated)

    return [val] if self._size < 2 else val.split(1, -1)

    x = self._rand(self.space, shape)
    if self.nvec.ndim > 1:
        x = x.transpose(len(shape), -1)
What does this do? Can you explain?
Say spec.shape = [3, 4] and shape = [1, 2]: you want x to have shape [1, 2, 3, 4]. From what I understand, the output of _rand already has that shape.
Why do you swap the dim of size 4 with the one of size 2?
No, in this case x has a shape of torch.Size([1, 2, 4, 3]) without the transpose, so I have to swap 4 and 3 (not the one of size 2, because indexing starts at 0).
This is the log of the _rand algorithm (before and after stacking):

    [torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([1, 2])]
    torch.Size([1, 2, 4])
    [torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([1, 2])]
    torch.Size([1, 2, 4])
    [torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([1, 2])]
    torch.Size([1, 2, 4])
    [torch.Size([1, 2, 4]), torch.Size([1, 2, 4]), torch.Size([1, 2, 4])]
    torch.Size([1, 2, 4, 3])

Because at the end I have to stack in the -2 dim.
What do you think about this solution:

    def _rand(self, space: Box, shape: torch.Size, i: int):
        ...
        x.append(self._rand(_s, shape, i - 1))
        ...
        return torch.stack(x, -i)

Instead of stacking in the last dimension, I stack in a dimension that depends on the depth in the recursive box list; the tests pass with this solution.
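To make the shape bookkeeping concrete, here is a small runnable sketch (shapes as in the discussion above; the variable names are mine) showing why stacking at the last dim needs a transpose while stacking at a depth-dependent dim does not:

    import torch

    # spec.shape == [3, 4], sampling shape == [1, 2]
    leaves = [torch.zeros(1, 2) for _ in range(4)]  # innermost draws
    inner = torch.stack(leaves, -1)                 # torch.Size([1, 2, 4])
    rows = [inner.clone() for _ in range(3)]

    naive = torch.stack(rows, -1)                   # torch.Size([1, 2, 4, 3]): wrong order
    fixed = naive.transpose(2, -1)                  # len(shape) == 2, so swap dims 2 and 3

    depth_aware = torch.stack(rows, -2)             # torch.Size([1, 2, 3, 4]) directly
    assert fixed.shape == depth_aware.shape == torch.Size([1, 2, 3, 4])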
torchrl/data/tensor_specs.py (Outdated)

    x = self._rand(self.space, shape)
    if self.nvec.ndim > 1:
        x = x.transpose(len(shape), -1)
    return x.squeeze(-1)
If the shape is Size([1]), shouldn't we keep the last dim?
It's an oversight: when ts.shape == self.shape == torch.Size([1]), the previous computation adds an empty dimension, so I added this case:

    if self.shape == torch.Size([1]):
        x = x.squeeze(-1)
    return x
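A quick shape-only illustration of that edge case (tensor contents and names are mine):

    import torch

    spec_shape = torch.Size([1])  # stands in for self.shape
    x = torch.zeros(1, 1)         # draw left with an extra trailing dim of size 1
    if spec_shape == torch.Size([1]):
        x = x.squeeze(-1)
    assert x.shape == spec_shape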
The CI looks broken :/
vmoens left a comment:
LGTM! Thanks!
It's weird; can you try pushing an empty commit?
Done, but it's still broken. Maybe a key was removed from (or expired on) GitHub.
I managed to get it to run!
Description
I rewrote the code of PR #783 to support nvec with several axes (sorry for the double PR, I pushed late ^^). Related to #781.
Example:
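A minimal sketch (assuming MultiDiscreteTensorSpec accepts a 2-D nvec and exposes the usual rand/is_in API of the other specs; the exact values are illustrative):

    import torch
    from torchrl.data import MultiDiscreteTensorSpec

    # nvec with two axes: a 2 x 3 grid of discrete dimensions, each entry
    # giving the number of categories at that position.
    nvec = torch.tensor([[3, 4, 5], [2, 2, 2]])
    spec = MultiDiscreteTensorSpec(nvec)

    sample = spec.rand(torch.Size([10]))  # batch of 10 draws, shape [10, 2, 3]
    assert spec.is_in(sample)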
The code is a bit complicated, but it's necessary to generalize to n dimensions; don't hesitate to tell me if you have any questions or if I should add some comments.
Motivation and Context
Improve compatibility with Gym Spaces that have MultiDiscrete, which supports nvec with several axes.

Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!