Conversation

@tcbegley (Contributor)

Description

pytorch/tensordict#88 deprecates the `get_dist` and `get_dist_params` methods of `TensorDictSequential`, and hence of `torchrl.modules.SafeSequential`, which subclasses it. This PR adapts TorchRL's code and tests to that change.

This supersedes an earlier PR, tcbegley#5 (rebased onto new commits, and targeting TorchRL main rather than a branch on my fork).

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 30, 2022
```python
        )
        return SafeSequential(self.module[0], self.module[1])

    def get_value_operator(self) -> SafeSequential:
```
@tcbegley (Contributor, Author):
Could `self.module[2]` here be a `SafeProbabilisticModule`, and if so, does that mean we need to make a similar change to `get_value_operator`?

Comment on lines +2924 to +2925
```python
if isinstance(loss.actor_network, SafeProbabilisticModule):
    key = ("module", *key) if isinstance(key, tuple) else ("module", key)
```
@tcbegley (Contributor, Author):

This feels a bit messy: the actor network's parameters have a different structure from the q-value network's, because they've been wrapped in a `SafeProbabilisticModule`.

I think that extra layer is avoidable, but maybe we could wrap the value network in a `SafeProbabilisticModule` too, using `Delta` distributions, so that the parameter structures remain in sync?
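A minimal sketch of the mismatch being discussed, using plain dicts with hypothetical parameter layouts (not the real TorchRL parameter containers). The wrapper adds one extra `"module"` level to the actor's parameter tree, which is what the key-prefixing above compensates for:

```python
# Hypothetical parameter layouts: the actor's parameters carry an extra
# "module" level because the network is wrapped in a probabilistic module,
# while the q-value network's parameters do not.
actor_params = {"module": {"0": {"weight": [1.0]}}}
qvalue_params = {"0": {"weight": [2.0]}}

def align_key(key):
    """Prefix a q-value key so it addresses the actor's nested structure."""
    return ("module", *key) if isinstance(key, tuple) else ("module", key)

# A flat key works directly on the q-value params...
assert qvalue_params["0"]["weight"] == [2.0]

# ...but must be prefixed to reach the corresponding slot in the actor params.
module_key, sub_key = align_key("0")
assert actor_params[module_key][sub_key]["weight"] == [1.0]
```

Wrapping the value network too (e.g. with a `Delta` distribution, so sampling is deterministic) would give both trees the same depth and make the prefixing unnecessary.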

Comment on lines -194 to -203
```python
if isinstance(self.actor_network, TensorDictSequential):
    sample_key = self.actor_network[-1].sample_out_key[0]
    tensordict_actor_dist = self.actor_network[-1].build_dist_from_params(
        td_params
    )
else:
    sample_key = self.actor_network.sample_out_key[0]
    tensordict_actor_dist = self.actor_network.build_dist_from_params(
        td_params
    )
```
@tcbegley (Contributor, Author):

I think under these changes the actor network will never be sequential (it will always be wrapped with a probabilistic module), so this should be safe to delete. Does that seem right?
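A toy illustration of why the `isinstance` branch above would become dead code, using hypothetical stand-in classes rather than the real TorchRL types:

```python
class TensorDictSequential:
    """Stand-in for the real container; only used for the isinstance check."""
    def __init__(self, *modules):
        self.modules = modules

class SafeProbabilisticModule:
    """Stand-in wrapper: the actor exposes sample_out_key directly."""
    def __init__(self, module, sample_out_key=("action",)):
        self.module = module
        self.sample_out_key = sample_out_key

# If construction guarantees the actor is always the wrapper...
actor = SafeProbabilisticModule(TensorDictSequential(object(), object()))

# ...the TensorDictSequential branch can never be taken, and only the
# else-branch of the deleted code remains reachable.
assert not isinstance(actor, TensorDictSequential)
sample_key = actor.sample_out_key[0]
```

Under that invariant, deleting the first branch is behavior-preserving.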

@tcbegley (Contributor, Author):

@vmoens's comment from tcbegley#5, for reference:

> It feels like we cleaned up one module but added a lot of complexity elsewhere. What do you think? Maybe the ProbabilisticModule should be constructed above the Sequential in all cases?
>
> e.g. ActorCritic-related objects would not care about it, since they would be wrapped by a ProbabilisticModule anyway.
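A hedged sketch of the wrapping order suggested here, with toy classes standing in for the real types: when the probabilistic module is constructed above the sequential, downstream containers can rely on a single outer wrapper instead of searching inside the sequential for it.

```python
class Sequential:
    """Stand-in for a sequential container of arbitrary inner depth."""
    def __init__(self, *modules):
        self.modules = modules

class ProbabilisticModule:
    """Stand-in for the outermost probabilistic wrapper."""
    def __init__(self, module):
        self.module = module

def dist_module(actor):
    # With the wrapper on the outside, callers never unwrap the sequential
    # to find the distribution-building module: the actor itself is it.
    if isinstance(actor, ProbabilisticModule):
        return actor
    raise TypeError("actor must be wrapped in a ProbabilisticModule")

backbone = Sequential(object(), object())  # arbitrary inner structure
actor = ProbabilisticModule(backbone)      # wrapper above the sequential
assert dist_module(actor) is actor
```

The design payoff is that ActorCritic-style objects need no special-casing: they see one wrapper type regardless of what the sequential contains.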

@codecov

codecov bot commented Dec 1, 2022

Codecov Report

Merging #719 (3e3ddf3) into main (dc6b736) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

```diff
@@            Coverage Diff             @@
##             main     #719      +/-   ##
==========================================
- Coverage   88.66%   88.65%   -0.01%
==========================================
  Files         120      120
  Lines       20185    20181       -4
==========================================
- Hits        17897    17892       -5
- Misses       2288     2289       +1
```
| Flag | Coverage Δ |
| --- | --- |
| habitat-gpu | 25.03% <20.00%> (-0.02%) ⬇️ |
| linux-cpu | 85.58% <100.00%> (?) |
| linux-gpu | 86.53% <100.00%> (-0.01%) ⬇️ |
| linux-jumanji | 30.25% <20.00%> (-0.02%) ⬇️ |
| linux-outdeps-gpu | 72.20% <96.72%> (-0.02%) ⬇️ |
| linux-stable-cpu | 85.43% <100.00%> (-0.01%) ⬇️ |
| linux-stable-gpu | 86.17% <100.00%> (-0.01%) ⬇️ |
| linux_examples-gpu | 43.07% <20.00%> (-0.02%) ⬇️ |
| macos-cpu | 85.25% <100.00%> (-0.01%) ⬇️ |
| olddeps-gpu | 75.99% <86.88%> (-0.01%) ⬇️ |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
| --- | --- |
| test/test_cost.py | 96.18% <100.00%> (+0.01%) ⬆️ |
| test/test_helpers.py | 91.59% <100.00%> (ø) |
| test/test_tensordictmodules.py | 96.84% <100.00%> (-0.04%) ⬇️ |
| torchrl/modules/tensordict_module/actors.py | 89.47% <100.00%> (-0.26%) ⬇️ |
| torchrl/objectives/redq.py | 91.42% <100.00%> (-0.32%) ⬇️ |


@vmoens vmoens changed the title Sequential probabilistic changes [Refactor] Sequential probabilistic changes Dec 12, 2022
@tcbegley (Contributor, Author):

Superseded by #728 and #738.

@tcbegley tcbegley closed this Dec 16, 2022
@tcbegley tcbegley deleted the seq-prob branch March 29, 2023 17:04