[Feature] Append, init and insert transforms in ReplayBuffer #695
Conversation
The remaining failure: FAILED test/test_libs.py::TestCollectorLib::test_collector_run[device1-GymEnv-env_args1-env_kwargs1] - EOFError (https://app.circleci.com/pipelines/github/pytorch/rl/4652/workflows/88023457-c917-479c-bcff-9ebd04e46af6/jobs/114173?invite=true#step-108-3532) does not seem to be related to my changes, I think?
Nope, those seem to be flaky tests.
vmoens
left a comment
Let's keep the errors as they were, #697 will implement more informative error messages.
test/test_rb.py
Outdated
@pytest.mark.parametrize("transform", transforms)
def test_smoke_replay_buffer_transform(transform):
    rb = rb_prototype.ReplayBuffer(
        collate_fn=lambda x: torch.stack(x, 0),
we should be able to remove the collate_fn as of #688
Ok, merged upstream.
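For context, here is a minimal plain-torch sketch of what the explicit collate_fn above was doing; per the comment, #688 is assumed to make equivalent stacking the default, so the argument can be dropped (illustration only, not the buffer's actual implementation):

```python
import torch

# The explicit collate_fn stacked a list of sampled items along a new leading
# batch dimension. With the default collation from #688 (assumption based on
# this thread), the argument can simply be omitted.
samples = [torch.randn(4) for _ in range(3)]
batch = torch.stack(samples, 0)
assert batch.shape == (3, 4)
```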
if self._unsqueeze_dim_orig < 0:
    self._unsqueeze_dim = self._unsqueeze_dim_orig
else:
elif self.parent:
The three errors of #692 are legit and should be raised.
They come from the fact that one should not assume that data comes with a certain batch size when creating a transform without a parent, i.e. the number of leading dimensions cannot be determined beforehand.
Hence `transform = UnsqueezeTransform(0)` should raise an error, because we don't know the batch size of the tensors that will come in, and hard-to-debug errors would occur otherwise. Most of the time, users know the last dimensions of their tensors and can perfectly well write `transform = UnsqueezeTransform(-n)` to get the desired behaviour.
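To make the ambiguity concrete, here is a plain-torch sketch (not the transform API itself; the shapes are hypothetical):

```python
import torch

# The same "observation" may reach an orphan transform with or without extra
# leading batch dimensions, e.g. from a single step vs. a batched rollout.
obs_single = torch.randn(3, 84, 84)        # (C, H, W)
obs_batched = torch.randn(32, 3, 84, 84)   # (B, C, H, W)

# A positive dim means different things depending on how many leading dims exist:
print(obs_single.unsqueeze(0).shape)    # torch.Size([1, 3, 84, 84])
print(obs_batched.unsqueeze(0).shape)   # torch.Size([1, 32, 3, 84, 84]) -- not the channel axis

# A negative dim is anchored to the trailing dims the user actually knows about:
print(obs_single.unsqueeze(-4).shape)   # torch.Size([1, 3, 84, 84])
print(obs_batched.unsqueeze(-4).shape)  # torch.Size([32, 1, 3, 84, 84])
```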
if self.last_dim >= 0:
    self.last_dim = self.last_dim - len(observation_spec.shape)
break
if isinstance(self.parent, EnvBase):
See #697 and my comment below regarding this
Ok, reverted. This will break the test until #692 is merged.
Codecov Report
@@ Coverage Diff @@
## main #695 +/- ##
==========================================
+ Coverage 88.60% 88.70% +0.10%
==========================================
Files 121 121
Lines 20690 20759 +69
==========================================
+ Hits 18333 18415 +82
+ Misses 2357 2344 -13
vmoens
left a comment
LGTM!
* amend
* [BugFix] ConvNet forward method with tensors of more than 4 dimensions (#686)
* cnn forward fix
* more general code
* cnn testing
* precommit run check
* convnet tests
* [Feature] add `standard_normal` for RewardScaling (#682)
* Add standard_normal
* give attribute access
* Update standard_normal
* Update tests
* Fix tests
* Address in-place scaling of reward
* Improvise tests
* [Feature] Jumanji envs (#674)
* [Feature] Default collate_fn (#688)
* init
* [BugFix] Fix Examples (#687)
* [Refactoring] Replace direct gym version checks with decorated functions (#691)
* Refactoring in gym.py
* More refactoring in gym.py
* Completed refactoring
* Version 0.0.3 (#696)
* [Docs] Host TensorDict docs inside TorchRL docs (#693)
* Pull tensordict docs into TorchRL docs
* Add banner for tensordict docs
* [BugFix] Fix docs build (#698)
* [BugFix] Proper error messages for orphan transform creation (#697)
* [Feature] Append, init and insert transforms in ReplayBuffer (#695)
* [Feature] MPPI Planner (#694)
* tests1
* run examples in tests
* lint

Co-authored-by: albertbou92 <albertbou92@users.noreply.github.com>
Co-authored-by: Aditya Gandhamal <61016383+adityagandhamal@users.noreply.github.com>
Co-authored-by: yingchenlin <yc.jon.lin@gmail.com>
Co-authored-by: Sergey Ordinskiy <113687736+ordinskiy@users.noreply.github.com>
Co-authored-by: Tom Begley <tomcbegley@gmail.com>
Co-authored-by: Alan Schelten <alan@schelten.net>
Description
Implemented passing transforms at init, and appending and inserting transforms via methods on ReplayBuffer.
Closes #612 and closes #692.
Skips any code in transforms that accesses the parent without it being defined.
I think this approach is preferable to #690.
Inserting at an out-of-bounds index throws an error, as implemented in Compose. Note that this is not the same behavior as a native Python list, which simply inserts at the beginning or end of the list and doesn't throw.
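For contrast, a small sketch of that built-in list behavior (plain Python, unrelated to TorchRL):

```python
# Native Python list.insert clamps out-of-range indices instead of raising.
lst = [1, 2, 3]
lst.insert(100, 4)    # index past the end: item is appended, no error
lst.insert(-100, 0)   # index before the start: item is prepended, no error
assert lst == [0, 1, 2, 3, 4]
# The Compose-based insert described above raises instead of silently clamping.
```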
I'm not quite sure what is meant by transform_*_spec returning an error, since these methods are defined as public on transform subclasses and are not called in this implementation.
We could prevent their use by subclassing Compose, but I don't think that's necessary?
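A hedged usage sketch of the feature described above; the import paths, the `transform` keyword argument, and the method names `append_transform` / `insert_transform` are assumptions inferred from the PR title and description, not copied from the diff:

```python
import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import rb_prototype
from torchrl.envs.transforms import UnsqueezeTransform

# Sketch only: names below follow the PR description; exact signatures may differ.
rb = rb_prototype.ReplayBuffer(
    transform=UnsqueezeTransform(-1, in_keys=["observation"]),  # transform passed at init
)
# Append a transform to the end of the chain:
rb.append_transform(UnsqueezeTransform(-1, in_keys=["observation"]))
# Insert a transform at a given position; an out-of-bounds index raises:
rb.insert_transform(0, UnsqueezeTransform(-1, in_keys=["observation"]))

data = TensorDict({"observation": torch.randn(3, 4)}, batch_size=[])
rb.add(data)
sample = rb.sample(1)  # the composed transforms are applied to the sampled data
```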
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!