-
Notifications
You must be signed in to change notification settings - Fork 418
[Feature] Migrate to tensordict.nn.TensorDictModule
#700
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
I believe the tests are failing because they require the changes in pytorch/tensordict#66. Not sure what the best way to test this is? Since afaik currently TorchRL does not import anything from |
# Conflicts: # torchrl/modules/__init__.py # torchrl/modules/functional_modules.py # torchrl/modules/tensordict_module/common.py # torchrl/modules/tensordict_module/probabilistic.py # torchrl/modules/tensordict_module/sequence.py # torchrl/modules/utils/__init__.py # torchrl/trainers/helpers/collectors.py # torchrl/trainers/helpers/trainers.py
|
|
||
|
|
||
| class TensorDictModule(nn.Module): | ||
| class TensorDictModule(_TensorDictModule): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was wondering whether we should consider changing the name here, as it could be confusing if both tensordict and TorchRL have TensorDictModule classes.
I don't have a good name, but I'm thinking something like TensorDictModuleWithSpec to convey the difference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes that's something I thought about too
Maybe we can drop the TensorDict part (it's implicitly the case since torchrl is all-in into adopting tensordict)
ModuleWithSpec? SpecModule? SafeModule? RLModule?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dropping TensorDict part sounds good to me.
Of your suggestions I'm most drawn to SafeModule and SpecModule in that order I think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the slow answer
We can also think about user experience, and for that RLModule seems good
In ReAgent they have a (improperly named) ModelBase class that all models inherit from. It's not only bad because Model-Based in a subfield in RL, but also because it fails to account for what it is (the base class of all policies and value networks).
SafeModule is also good as the main difference is mainly that putting a spec allows for safe sampling / mapping (it's more about safety than about the spec IMO). In that sense RLModule is less self-explanatory.
TLDR let's go for SafeModule :p
Codecov Report
@@ Coverage Diff @@
## main #700 +/- ##
==========================================
+ Coverage 88.77% 88.81% +0.04%
==========================================
Files 122 121 -1
Lines 21151 20490 -661
==========================================
- Hits 18776 18198 -578
+ Misses 2375 2292 -83
Flags with carried forward coverage won't be shown. Click here to find out more.
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
vmoens
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great stuff! A lot less code: the best code is no code at all
vmoens
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple of suggestions for the docstrings
It's ok that things are copied for the rest i guess
vmoens
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool!
Description
This PR adopts
tensordict.nn.TensorDictModuleand its variants and eliminates duplicated code. We don't useTensorDictModuleas a drop-in replacement sincetensordictlibrary has no concept of TensorSpecs. Instead we inherit and let the base classes do the heavy lifting, while the corresponding classes in TorchRL largely add spec validation.See supporting changes in pytorch/tensordict#66