
Commit a4506df

tcbegley and vmoens authored
[Feature] Migrate to tensordict.nn.TensorDictModule (#700)
* Migrate TensorDictModule
* Migrate functional modules
* Migrate probabilistic modules
* Patch set_exploration_mode
* Lint and format
* Migrate sequential modules
* Adopt tensordict.nn.utils where possible
* Rerun CI
* Delete tests duplicated from tensordict
* minor
* Remove references to torchrl.modules.TensorDictWrapper
* Rename TensorDictModule -> SafeModule
* Delete redundant methods following inheritance fixes
* Some docstring improvements

Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent c52caac commit a4506df


44 files changed: +506 -1945 lines
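The user-facing core of the change is a rename: the `TensorDict*` wrapper classes exported by `torchrl.modules` become `Safe*` classes, with the generic machinery now provided upstream by `tensordict.nn`. A minimal before/after sketch of the import change (the one-layer `model` is a hypothetical stand-in):

```python
import torch.nn as nn

model = nn.LazyLinear(1)  # hypothetical stand-in for any nn.Module

# Before this commit:
# from torchrl.modules import TensorDictModule
# policy = TensorDictModule(model, in_keys=["observation"], out_keys=["action"])

# After this commit:
from torchrl.modules import SafeModule

policy = SafeModule(model, in_keys=["observation"], out_keys=["action"])
```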

README.md

Lines changed: 15 additions & 15 deletions
````diff
@@ -36,7 +36,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
 ```diff
 - obs, done = env.reset()
 + tensordict = env.reset()
-policy = TensorDictModule(
+policy = SafeModule(
     model,
     in_keys=["observation_pixels", "observation_vector"],
     out_keys=["action"],
````
````diff
@@ -106,14 +106,14 @@ Here's another example of an off-policy training loop in TorchRL (assuming that

 Check our TorchRL-specific [TensorDict tutorial](tutorials/tensordict.ipynb) for more information.

-The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/common.py) which is [functorch](https://github.com/pytorch/functorch)-compatible!
-
+The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py) which is [functorch](https://github.com/pytorch/functorch)-compatible!
+
 <details>
   <summary>Code</summary>

 ```diff
 transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
-+ td_module = TensorDictModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
++ td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
 src = torch.rand((10, 32, 512))
 tgt = torch.rand((20, 32, 512))
 + tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[20, 32])
````
````diff
@@ -122,19 +122,19 @@ The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/comm
 + out = tensordict["out"]
 ```

-The `TensorDictSequential` class allows to branch sequences of `nn.Module` instances in a highly modular way.
+The `SafeSequential` class allows to branch sequences of `nn.Module` instances in a highly modular way.
 For instance, here is an implementation of a transformer using the encoder and decoder blocks:
 ```python
 encoder_module = TransformerEncoder(...)
-encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
+encoder = SafeModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
 decoder_module = TransformerDecoder(...)
-decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
-transformer = TensorDictSequential(encoder, decoder)
+decoder = SafeModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
+transformer = SafeSequential(encoder, decoder)
 assert transformer.in_keys == ["src", "src_mask", "tgt"]
 assert transformer.out_keys == ["memory", "output"]
 ```

-`TensorDictSequential` allows to isolate subgraphs by querying a set of desired input / output keys:
+`SafeSequential` allows to isolate subgraphs by querying a set of desired input / output keys:
 ```python
 transformer.select_subsequence(out_keys=["memory"])  # returns the encoder
 transformer.select_subsequence(in_keys=["tgt", "memory"])  # returns the decoder
````
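A hedged, runnable miniature of the subgraph selection shown above, with linear layers standing in for the transformer blocks (sizes are arbitrary):

```python
import torch
import torch.nn as nn
from tensordict import TensorDict
from torchrl.modules import SafeModule, SafeSequential

encoder = SafeModule(nn.Linear(4, 8), in_keys=["src"], out_keys=["memory"])
decoder = SafeModule(nn.Linear(8, 4), in_keys=["memory"], out_keys=["output"])
model = SafeSequential(encoder, decoder)

# Keep only the part of the graph needed to produce "memory".
sub = model.select_subsequence(out_keys=["memory"])
td = sub(TensorDict({"src": torch.randn(3, 4)}, batch_size=[3]))
assert "memory" in td.keys()
assert "output" not in td.keys()
```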
```diff
@@ -261,9 +261,9 @@ The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/comm
     kernel_sizes=[8, 4, 3],
     strides=[4, 2, 1],
 )
-# Wrap it in a TensorDictModule, indicating what key to read in and where to
+# Wrap it in a SafeModule, indicating what key to read in and where to
 # write out the output
-common_module = TensorDictModule(
+common_module = SafeModule(
     common_module,
     in_keys=["pixels"],
     out_keys=["hidden"],
```
```diff
@@ -277,10 +277,10 @@ The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/comm
         activation=nn.ELU,
     )
 )
-# Wrap the nn.Module in a ProbabilisticTensorDictModule, indicating how
+# Wrap the nn.Module in a SafeProbabilisticModule, indicating how
 # to build the torch.distribution.Distribution object and what to do with it
-policy_module = ProbabilisticTensorDictModule(  # stochastic policy
-    TensorDictModule(
+policy_module = SafeProbabilisticModule(  # stochastic policy
+    SafeModule(
         policy_module,
         in_keys=["hidden"],
         out_keys=["loc", "scale"],
```
````diff
@@ -409,7 +409,7 @@ pip3 install torchrl
 This should work on linux and MacOs (not M1). For Windows and M1/M2 machines, one
 should install the library locally (see below).

-The **nightly build** can be installed via
+The **nightly build** can be installed via
 ```
 pip install torchrl-nightly
 ```
````

(The removed and added lines are textually identical; this hunk appears to be a trailing-whitespace fix only.)

docs/source/reference/envs.rst

Lines changed: 1 addition & 1 deletion
```diff
@@ -50,7 +50,7 @@ With these, the following methods are implemented:
   having reproducible results.
 - :obj:`env.rollout(max_steps, policy)`: executes a rollout in the environment for
   a maximum number of steps :obj:`max_steps` and using a policy :obj:`policy`.
-  The policy should be coded using a :obj:`TensorDictModule` (or any other
+  The policy should be coded using a :obj:`SafeModule` (or any other
   :obj:`TensorDict`-compatible module).
```
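The parenthetical matters: any TensorDict-compatible module can drive a rollout, not only `SafeModule`. A hedged sketch of a hand-written policy that reads and writes the keys itself, assuming the rollout loop simply calls `policy(tensordict)`:

```python
import torch
import torch.nn as nn


class RandomTensorDictPolicy(nn.Module):
    """Hypothetical policy: reads the batch size, writes a random "action"."""

    def __init__(self, action_dim: int):
        super().__init__()
        self.action_dim = action_dim

    def forward(self, tensordict):
        action = torch.randn(*tensordict.batch_size, self.action_dim)
        tensordict.set("action", action)
        return tensordict


# e.g. env.rollout(max_steps=10, policy=RandomTensorDictPolicy(action_dim=1))
```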
docs/source/reference/modules.rst

Lines changed: 3 additions & 4 deletions
```diff
@@ -11,10 +11,9 @@ TensorDict modules
     :toctree: generated/
     :template: rl_template_noinherit.rst

-    TensorDictModule
-    ProbabilisticTensorDictModule
-    TensorDictSequential
-    TensorDictModuleWrapper
+    SafeModule
+    SafeProbabilisticModule
+    SafeSequential
     Actor
     ProbabilisticActor
     ValueOperator
```

test/smoke_test.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,5 +6,5 @@ def test_imports():
     )  # noqa: F401
     from torchrl.envs import Transform, TransformedEnv  # noqa: F401
     from torchrl.envs.gym_like import GymLikeEnv  # noqa: F401
-    from torchrl.modules import TensorDictModule  # noqa: F401
+    from torchrl.modules import SafeModule  # noqa: F401
     from torchrl.objectives.common import LossModule  # noqa: F401
```

test/test_collector.py

Lines changed: 5 additions & 10 deletions
```diff
@@ -35,12 +35,7 @@
 from torchrl.envs import EnvCreator, ParallelEnv, SerialEnv
 from torchrl.envs.libs.gym import _has_gym, GymEnv
 from torchrl.envs.transforms import TransformedEnv, VecNorm
-from torchrl.modules import (
-    Actor,
-    LSTMNet,
-    OrnsteinUhlenbeckProcessWrapper,
-    TensorDictModule,
-)
+from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule

 # torch.set_default_dtype(torch.double)

@@ -754,7 +749,7 @@ def create_env():
         return ContinuousActionVecMockEnv()

     n_actions = ContinuousActionVecMockEnv().action_spec.shape[-1]
-    policy = TensorDictModule(
+    policy = SafeModule(
         torch.nn.LazyLinear(n_actions), in_keys=["observation"], out_keys=["action"]
     )
     policy(create_env().reset())
@@ -898,7 +893,7 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe
         next=CompositeSpec(hidden1=hidden_spec, hidden2=hidden_spec),
     )

-    policy = TensorDictModule(**policy_kwargs)
+    policy = SafeModule(**policy_kwargs)

     env_maker = lambda: GymEnv(PENDULUM_VERSIONED)

@@ -985,12 +980,12 @@ def test_auto_wrap_modules(self, collector_class, multiple_outputs, env_maker):

     if collector_class is not SyncDataCollector:
         assert all(
-            isinstance(p, TensorDictModule) for p in collector._policy_dict.values()
+            isinstance(p, SafeModule) for p in collector._policy_dict.values()
         )
         assert all(p.out_keys == out_keys for p in collector._policy_dict.values())
         assert all(p.module is policy for p in collector._policy_dict.values())
     else:
-        assert isinstance(collector.policy, TensorDictModule)
+        assert isinstance(collector.policy, SafeModule)
         assert collector.policy.out_keys == out_keys
         assert collector.policy.module is policy
```
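`test_auto_wrap_modules` above implies that collectors wrap a raw `nn.Module` policy into a `SafeModule` automatically. A hedged, constructor-level sketch of that behaviour; the env factory is an assumed example and exact constructor defaults may differ by version:

```python
import torch.nn as nn
from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import SafeModule

env_maker = lambda: GymEnv("Pendulum-v1")  # assumed example env
policy = nn.LazyLinear(1)  # a plain nn.Module, not a SafeModule

collector = SyncDataCollector(
    env_maker,
    policy=policy,
    frames_per_batch=16,
    total_frames=64,
)
# The collector should have auto-wrapped the module.
assert isinstance(collector.policy, SafeModule)
assert collector.policy.module is policy
```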
test/test_cost.py

Lines changed: 26 additions & 40 deletions
```diff
@@ -6,7 +6,7 @@
 import argparse
 from copy import deepcopy

-from torchrl.modules.functional_modules import FunctionalModuleWithBuffers
+from tensordict.nn.functional_modules import FunctionalModuleWithBuffers

 _has_functorch = True
 try:
@@ -41,10 +41,10 @@
 from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
 from torchrl.modules import (
     DistributionalQValueActor,
-    ProbabilisticTensorDictModule,
     QValueActor,
-    TensorDictModule,
-    TensorDictSequential,
+    SafeModule,
+    SafeProbabilisticModule,
+    SafeSequential,
     WorldModelWrapper,
 )
 from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
@@ -787,9 +787,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
         -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
     )
     net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
-    module = TensorDictModule(
-        net, in_keys=["observation"], out_keys=["loc", "scale"]
-    )
+    module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
     actor = ProbabilisticActor(
         spec=CompositeSpec(action=action_spec, loc=None, scale=None),
         module=module,
@@ -1112,9 +1110,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
         -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
     )
     net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
-    module = TensorDictModule(
-        net, in_keys=["observation"], out_keys=["loc", "scale"]
-    )
+    module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
     actor = ProbabilisticActor(
         module=module,
         distribution_class=TanhNormal,
@@ -1167,13 +1163,9 @@ def __init__(self):
     def forward(self, hidden, act):
         return self.linear(torch.cat([hidden, act], -1))

-    common = TensorDictModule(
-        CommonClass(), in_keys=["observation"], out_keys=["hidden"]
-    )
+    common = SafeModule(CommonClass(), in_keys=["observation"], out_keys=["hidden"])
     actor_subnet = ProbabilisticActor(
-        TensorDictModule(
-            ActorClass(), in_keys=["hidden"], out_keys=["loc", "scale"]
-        ),
+        SafeModule(ActorClass(), in_keys=["hidden"], out_keys=["loc", "scale"]),
         dist_in_keys=["loc", "scale"],
         distribution_class=TanhNormal,
         return_log_prob=True,
@@ -1544,9 +1536,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
         -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
     )
     net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
-    module = TensorDictModule(
-        net, in_keys=["observation"], out_keys=["loc", "scale"]
-    )
+    module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
     actor = ProbabilisticActor(
         module=module,
         distribution_class=TanhNormal,
@@ -1779,9 +1769,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
         -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
     )
     net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
-    module = TensorDictModule(
-        net, in_keys=["observation"], out_keys=["loc", "scale"]
-    )
+    module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
     actor = ProbabilisticActor(
         module=module,
         distribution_class=TanhNormal,
@@ -2005,9 +1993,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value):
     gamma = 0.9
     value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
     net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
-    module = TensorDictModule(
-        net, in_keys=["observation"], out_keys=["loc", "scale"]
-    )
+    module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
     actor_net = ProbabilisticActor(
         module,
         distribution_class=TanhNormal,
@@ -2154,7 +2140,7 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20

     # World Model and reward model
     rssm_rollout = RSSMRollout(
-        TensorDictModule(
+        SafeModule(
             rssm_prior,
             in_keys=["state", "belief", "action"],
             out_keys=[
@@ -2164,7 +2150,7 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20
                 ("next", "belief"),
             ],
         ),
-        TensorDictModule(
+        SafeModule(
             rssm_posterior,
             in_keys=[("next", "belief"), ("next", "encoded_latents")],
             out_keys=[
@@ -2178,20 +2164,20 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20
         out_features=1, depth=2, num_cells=mlp_num_units, activation_class=nn.ELU
     )
     # World Model and reward model
-    world_modeler = TensorDictSequential(
-        TensorDictModule(
+    world_modeler = SafeSequential(
+        SafeModule(
             obs_encoder,
             in_keys=[("next", "pixels")],
             out_keys=[("next", "encoded_latents")],
         ),
         rssm_rollout,
-        TensorDictModule(
+        SafeModule(
             obs_decoder,
             in_keys=[("next", "state"), ("next", "belief")],
             out_keys=[("next", "reco_pixels")],
         ),
     )
-    reward_module = TensorDictModule(
+    reward_module = SafeModule(
         reward_module,
         in_keys=[("next", "state"), ("next", "belief")],
         out_keys=["reward"],
@@ -2225,8 +2211,8 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
     reward_module = MLP(
         out_features=1, depth=2, num_cells=mlp_num_units, activation_class=nn.ELU
     )
-    transition_model = TensorDictSequential(
-        TensorDictModule(
+    transition_model = SafeSequential(
+        SafeModule(
             rssm_prior,
             in_keys=["state", "belief", "action"],
             out_keys=[
@@ -2237,7 +2223,7 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
             ],
         ),
     )
-    reward_model = TensorDictModule(
+    reward_model = SafeModule(
         reward_module,
         in_keys=["state", "belief"],
         out_keys=["reward"],
@@ -2271,8 +2257,8 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
         num_cells=mlp_num_units,
         activation_class=nn.ELU,
     )
-    actor_model = ProbabilisticTensorDictModule(
-        TensorDictModule(
+    actor_model = SafeProbabilisticModule(
+        SafeModule(
             actor_module,
             in_keys=["state", "belief"],
             out_keys=["loc", "scale"],
@@ -2294,7 +2280,7 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
     return actor_model

 def _create_value_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
-    value_model = TensorDictModule(
+    value_model = SafeModule(
         MLP(
             out_features=1,
             depth=3,
@@ -2396,7 +2382,7 @@ def test_dreamer_env(self, device, imagination_horizon, discount_loss):
     # test reconstruction
     with pytest.raises(ValueError, match="No observation decoder provided"):
         mb_env.decode_obs(rollout)
-    mb_env.obs_decoder = TensorDictModule(
+    mb_env.obs_decoder = SafeModule(
         nn.LazyLinear(4, device=device),
         in_keys=["state"],
         out_keys=["reco_observation"],
@@ -2915,13 +2901,13 @@ def test_shared_params(dest, expected_dtype, expected_device):
     if torch.cuda.device_count() == 0 and dest == "cuda":
         pytest.skip("no cuda device available")
     module_hidden = torch.nn.Linear(4, 4)
-    td_module_hidden = TensorDictModule(
+    td_module_hidden = SafeModule(
         module=module_hidden,
         spec=None,
         in_keys=["observation"],
         out_keys=["hidden"],
     )
-    module_action = TensorDictModule(
+    module_action = SafeModule(
         NormalParamWrapper(torch.nn.Linear(4, 8)),
         in_keys=["hidden"],
         out_keys=["loc", "scale"],
```
