-
Notifications
You must be signed in to change notification settings - Fork 418
[Feature] Add Step Counter transform #756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Hi @riiswa! Thank you for your pull request and welcome to our community. Action RequiredIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks! |
|
@vmoens I'm not sure I fully understood how to edit the observation_spec, in any case here is my PR, I'm waiting for your changes suggestions :) |
17f5ba3 to
c6754bc
Compare
|
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks! |
c6754bc to
715b4f7
Compare
Codecov Report
@@ Coverage Diff @@
## main #756 +/- ##
==========================================
+ Coverage 88.72% 88.74% +0.02%
==========================================
Files 123 123
Lines 20944 21000 +56
==========================================
+ Hits 18583 18637 +54
- Misses 2361 2363 +2
Flags with carried forward coverage won't be shown. Click here to find out more.
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
vmoens
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good work!
Have a look at my comments
|
|
||
| @_apply_to_composite | ||
| def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: | ||
| self._transform_spec(observation_spec) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe i'm missing something, but I think you should do something like
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
observation_spec["step_count"] = UnboundedDiscreteTensorSpec(dtype=torch.int64)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for all your suggestions! I made the changes, but I would like to know how can I test the transform_observation_spec method
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a test function you can run
from torchrl.envs.utils import check_env_specs
env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(50))
check_env_specs(env)
It will run a small rollout + get fake data from your specs, and check that they match.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I get this error, I don't really know what I missed, it comes from the shape of the tensor is []:
test/test_transforms.py:1600 (TestTransforms.test_step_counter_observation_spec)
Traceback (most recent call last):
File "/Users/waris/Projects/rl/test/test_transforms.py", line 1604, in test_step_counter_observation_spec
check_env_specs(env)
File "/Users/waris/Projects/rl/torchrl/envs/utils.py", line 168, in check_env_specs
fake_tensordict = env.fake_tensordict().flatten_keys(".")
File "/Users/waris/Projects/rl/torchrl/envs/common.py", line 636, in fake_tensordict
fake_obs = observation_spec.rand(self.batch_size)
File "/Users/waris/Projects/rl/torchrl/data/tensor_specs.py", line 1286, in rand
_dict = {
File "/Users/waris/Projects/rl/torchrl/data/tensor_specs.py", line 1287, in <dictcomp>
key: self[key].rand(shape)
File "/Users/waris/Projects/rl/torchrl/data/tensor_specs.py", line 655, in rand
r = torch.rand(*shape, *interval.shape, device=interval.device)
TypeError: rand() missing 1 required positional arguments: "size"There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On line 655, try putting the shape in between brackets
Like rand([*shape1, *shape2],...)
Also make sure you've merged main into your branch!
4a93ac0 to
543ac75
Compare
a4af5ff to
6c05635
Compare
vmoens
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This should be adapted to the Nd*Spec removal though :)
6c05635 to
922d61a
Compare
Description
Add a transform that counts the steps from a reset and sets the done state to True after a certain number of steps.
This transform can be used in parallel / multi env settings, as the counter has the size of the env.batch_size. The "step_counter" key will be added to the output of tensordict when reset or step in the environment.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
xin all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!