-
Notifications
You must be signed in to change notification settings - Fork 418
[Feature] MultiDiscreteTensorSpec #783
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
8f12b5c to
21983b1
Compare
Codecov Report
@@ Coverage Diff @@
## main #783 +/- ##
==========================================
+ Coverage 88.74% 88.81% +0.06%
==========================================
Files 123 123
Lines 21170 21256 +86
==========================================
+ Hits 18787 18878 +91
+ Misses 2383 2378 -5
Flags with carried forward coverage won't be shown. Click here to find out more.
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
19ebdf5 to
542b474
Compare
82efab3 to
435d1e2
Compare
da9eb40 to
1147efc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
I left some minor comments
Can you elaborate a bit more the description of the PR?
@matteobettini curious to see if that suits your purpose
| ] | ||
| ).squeeze() | ||
| _size = [self._size] if self._size > 1 else [] | ||
| return x.T.reshape([*shape, *_size]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This leads to the following warning if the number of dims of x is greater than 2
<string>:3: UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matricesor `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at /Users/distiller/project/conda/conda-bld/pytorch_1646756029501/work/aten/src/ATen/native/TensorShape.cpp:2318.)
| ).squeeze() | ||
|
|
||
| def is_in(self, val: torch.Tensor) -> bool: | ||
| vals = self._split(val) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should also check the dtype here
(note to myself: we should check that is_in always checks the dtype)
Yep it seems to do what we want. Maybe one little thig is. with the This is what currently happens Instead I think it would be nicer |
|
You're right |
|
I will work on it :) and I'm also working on multi-dimensional nvec (I forgot this case): # Gym Example
>> d = MultiDiscrete(np.array([[1, 2], [3, 4]]))
>> d.sample()
array([[0, 0],
[2, 3]]) |
Description
The goal of this PR is to add MultiDiscreteTensorSpec for n categorical actions (#781)
Example:
This PR don't support yet
nvecwith several axes.TODO (
still in progressReady for reviewing):to_onehot()andto_categorical()Motivation and Context
This will close the issue #781
Types of changes
Checklist
Go over all the following points, and put an
xin all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!