[Refactor] Sequential probabilistic changes #719
Conversation
```python
    )
    return SafeSequential(self.module[0], self.module[1])

def get_value_operator(self) -> SafeSequential:
```
Could `self.module[2]` here be a `SafeProbabilisticModule`, and if so, does that mean we need to make a similar change to `get_value_operator`?
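If `self.module[2]` can indeed be a `SafeProbabilisticModule`, one possible shape for that guard is sketched below. The classes are minimal pure-Python stand-ins for the TorchRL ones, and the unwrapping details (the `inner` attribute, index `2` for the value head) are hypothetical illustrations, not the merged implementation.

```python
class SafeProbabilisticModule:
    """Stand-in for torchrl's probabilistic wrapper (illustration only)."""

    def __init__(self, inner):
        self.inner = inner


class SafeSequential:
    """Stand-in: stores submodules in a tuple, like the diff's self.module."""

    def __init__(self, *modules):
        self.module = modules


def get_value_operator(ac: SafeSequential) -> SafeSequential:
    # Hypothetical guard mirroring the change made to get_policy_operator:
    # if the value head is wrapped, unwrap it before rebuilding the sequence.
    value_head = ac.module[2]
    if isinstance(value_head, SafeProbabilisticModule):
        value_head = value_head.inner
    return SafeSequential(ac.module[0], value_head)
```

Either way, the returned operator would then have the same structure as in the unwrapped case.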
```python
if isinstance(loss.actor_network, SafeProbabilisticModule):
    key = ("module", *key) if isinstance(key, tuple) else ("module", key)
```
This feels a bit messy: the actor network's parameters have a different structure to the q-value network's because they've been wrapped with a `SafeProbabilisticModule`.
I think that extra layer is avoidable, but maybe we could wrap the value network with a `SafeProbabilisticModule` too, using `Delta` distributions, so that the parameter structures remain in sync?
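A toy illustration of the mismatch being discussed. Only the `("module", *key)` prefixing comes from the diff above; the parameter key tuples are made up for the example.

```python
def align_key(key):
    # The workaround from the diff: prefix q-value keys with "module" so
    # they line up with the wrapped actor's parameter structure.
    return ("module", *key) if isinstance(key, tuple) else ("module", key)


# Hypothetical parameter keys: wrapping the actor in a probabilistic
# module adds one extra "module" level to every key, while the unwrapped
# q-value network's keys stay flat.
actor_keys = [("module", "0", "weight"), ("module", "0", "bias")]
qvalue_keys = [("0", "weight"), ("0", "bias")]

aligned = [align_key(k) for k in qvalue_keys]
```

If the value network were wrapped too (e.g. with a `Delta` distribution, as suggested), both key sets would already carry the extra `"module"` level and this prefixing would be unnecessary.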
```python
if isinstance(self.actor_network, TensorDictSequential):
    sample_key = self.actor_network[-1].sample_out_key[0]
    tensordict_actor_dist = self.actor_network[-1].build_dist_from_params(
        td_params
    )
else:
    sample_key = self.actor_network.sample_out_key[0]
    tensordict_actor_dist = self.actor_network.build_dist_from_params(
        td_params
    )
```
I think under these changes the actor network will never be sequential (it will always be wrapped with a probabilistic module), so this should be safe to delete. Does that seem right?
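For reference, if the actor is always wrapped, the branch collapses to its else-arm. The stub class below only mimics the two members the diff uses (`sample_out_key` and `build_dist_from_params`); it is not the TorchRL class, and the tuple return value is a placeholder for a real distribution object.

```python
class WrappedActorStub:
    """Stand-in exposing the two members used in the diff above."""

    sample_out_key = ["action"]

    def build_dist_from_params(self, td_params):
        # Real code would build a distribution from the parameters;
        # a tagged tuple is enough for this sketch.
        return ("dist", td_params)


def actor_dist(actor, td_params):
    # The surviving else-arm, once the TensorDictSequential case is gone.
    sample_key = actor.sample_out_key[0]
    tensordict_actor_dist = actor.build_dist_from_params(td_params)
    return sample_key, tensordict_actor_dist
```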
@vmoens comment from tcbegley#5, for reference.
Codecov Report
```
@@            Coverage Diff             @@
##             main     #719      +/-   ##
==========================================
- Coverage   88.66%   88.65%   -0.01%
==========================================
  Files         120      120
  Lines       20185    20181       -4
==========================================
- Hits        17897    17892       -5
- Misses       2288     2289       +1
```
Description
pytorch/tensordict#88 deprecates the `get_dist` and `get_dist_params` methods of `TensorDictSequential`, and hence of `torchrl.modules.SafeSequential`, which subclasses it. This PR adapts the code and tests in TorchRL for that change.

This supersedes an earlier PR, tcbegley#5 (rebased onto new commits and targeting TorchRL main rather than a branch on my fork).