66import argparse
77from copy import deepcopy
88
9- from torchrl . modules .functional_modules import FunctionalModuleWithBuffers
9+ from tensordict . nn .functional_modules import FunctionalModuleWithBuffers
1010
1111_has_functorch = True
1212try :
4141from torchrl .envs .transforms import TensorDictPrimer , TransformedEnv
4242from torchrl .modules import (
4343 DistributionalQValueActor ,
44- ProbabilisticTensorDictModule ,
4544 QValueActor ,
46- TensorDictModule ,
47- TensorDictSequential ,
45+ SafeModule ,
46+ SafeProbabilisticModule ,
47+ SafeSequential ,
4848 WorldModelWrapper ,
4949)
5050from torchrl .modules .distributions .continuous import NormalParamWrapper , TanhNormal
@@ -787,9 +787,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
787787 - torch .ones (action_dim ), torch .ones (action_dim ), (action_dim ,)
788788 )
789789 net = NormalParamWrapper (nn .Linear (obs_dim , 2 * action_dim ))
790- module = TensorDictModule (
791- net , in_keys = ["observation" ], out_keys = ["loc" , "scale" ]
792- )
790+ module = SafeModule (net , in_keys = ["observation" ], out_keys = ["loc" , "scale" ])
793791 actor = ProbabilisticActor (
794792 spec = CompositeSpec (action = action_spec , loc = None , scale = None ),
795793 module = module ,
@@ -1112,9 +1110,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
11121110 - torch .ones (action_dim ), torch .ones (action_dim ), (action_dim ,)
11131111 )
11141112 net = NormalParamWrapper (nn .Linear (obs_dim , 2 * action_dim ))
1115- module = TensorDictModule (
1116- net , in_keys = ["observation" ], out_keys = ["loc" , "scale" ]
1117- )
1113+ module = SafeModule (net , in_keys = ["observation" ], out_keys = ["loc" , "scale" ])
11181114 actor = ProbabilisticActor (
11191115 module = module ,
11201116 distribution_class = TanhNormal ,
@@ -1167,13 +1163,9 @@ def __init__(self):
11671163 def forward (self , hidden , act ):
11681164 return self .linear (torch .cat ([hidden , act ], - 1 ))
11691165
1170- common = TensorDictModule (
1171- CommonClass (), in_keys = ["observation" ], out_keys = ["hidden" ]
1172- )
1166+ common = SafeModule (CommonClass (), in_keys = ["observation" ], out_keys = ["hidden" ])
11731167 actor_subnet = ProbabilisticActor (
1174- TensorDictModule (
1175- ActorClass (), in_keys = ["hidden" ], out_keys = ["loc" , "scale" ]
1176- ),
1168+ SafeModule (ActorClass (), in_keys = ["hidden" ], out_keys = ["loc" , "scale" ]),
11771169 dist_in_keys = ["loc" , "scale" ],
11781170 distribution_class = TanhNormal ,
11791171 return_log_prob = True ,
@@ -1544,9 +1536,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
15441536 - torch .ones (action_dim ), torch .ones (action_dim ), (action_dim ,)
15451537 )
15461538 net = NormalParamWrapper (nn .Linear (obs_dim , 2 * action_dim ))
1547- module = TensorDictModule (
1548- net , in_keys = ["observation" ], out_keys = ["loc" , "scale" ]
1549- )
1539+ module = SafeModule (net , in_keys = ["observation" ], out_keys = ["loc" , "scale" ])
15501540 actor = ProbabilisticActor (
15511541 module = module ,
15521542 distribution_class = TanhNormal ,
@@ -1779,9 +1769,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
17791769 - torch .ones (action_dim ), torch .ones (action_dim ), (action_dim ,)
17801770 )
17811771 net = NormalParamWrapper (nn .Linear (obs_dim , 2 * action_dim ))
1782- module = TensorDictModule (
1783- net , in_keys = ["observation" ], out_keys = ["loc" , "scale" ]
1784- )
1772+ module = SafeModule (net , in_keys = ["observation" ], out_keys = ["loc" , "scale" ])
17851773 actor = ProbabilisticActor (
17861774 module = module ,
17871775 distribution_class = TanhNormal ,
@@ -2005,9 +1993,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value):
20051993 gamma = 0.9
20061994 value_net = ValueOperator (nn .Linear (n_obs , 1 ), in_keys = ["observation" ])
20071995 net = NormalParamWrapper (nn .Linear (n_obs , 2 * n_act ))
2008- module = TensorDictModule (
2009- net , in_keys = ["observation" ], out_keys = ["loc" , "scale" ]
2010- )
1996+ module = SafeModule (net , in_keys = ["observation" ], out_keys = ["loc" , "scale" ])
20111997 actor_net = ProbabilisticActor (
20121998 module ,
20131999 distribution_class = TanhNormal ,
@@ -2154,7 +2140,7 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20
21542140
21552141 # World Model and reward model
21562142 rssm_rollout = RSSMRollout (
2157- TensorDictModule (
2143+ SafeModule (
21582144 rssm_prior ,
21592145 in_keys = ["state" , "belief" , "action" ],
21602146 out_keys = [
@@ -2164,7 +2150,7 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20
21642150 ("next" , "belief" ),
21652151 ],
21662152 ),
2167- TensorDictModule (
2153+ SafeModule (
21682154 rssm_posterior ,
21692155 in_keys = [("next" , "belief" ), ("next" , "encoded_latents" )],
21702156 out_keys = [
@@ -2178,20 +2164,20 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20
21782164 out_features = 1 , depth = 2 , num_cells = mlp_num_units , activation_class = nn .ELU
21792165 )
21802166 # World Model and reward model
2181- world_modeler = TensorDictSequential (
2182- TensorDictModule (
2167+ world_modeler = SafeSequential (
2168+ SafeModule (
21832169 obs_encoder ,
21842170 in_keys = [("next" , "pixels" )],
21852171 out_keys = [("next" , "encoded_latents" )],
21862172 ),
21872173 rssm_rollout ,
2188- TensorDictModule (
2174+ SafeModule (
21892175 obs_decoder ,
21902176 in_keys = [("next" , "state" ), ("next" , "belief" )],
21912177 out_keys = [("next" , "reco_pixels" )],
21922178 ),
21932179 )
2194- reward_module = TensorDictModule (
2180+ reward_module = SafeModule (
21952181 reward_module ,
21962182 in_keys = [("next" , "state" ), ("next" , "belief" )],
21972183 out_keys = ["reward" ],
@@ -2225,8 +2211,8 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
22252211 reward_module = MLP (
22262212 out_features = 1 , depth = 2 , num_cells = mlp_num_units , activation_class = nn .ELU
22272213 )
2228- transition_model = TensorDictSequential (
2229- TensorDictModule (
2214+ transition_model = SafeSequential (
2215+ SafeModule (
22302216 rssm_prior ,
22312217 in_keys = ["state" , "belief" , "action" ],
22322218 out_keys = [
@@ -2237,7 +2223,7 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
22372223 ],
22382224 ),
22392225 )
2240- reward_model = TensorDictModule (
2226+ reward_model = SafeModule (
22412227 reward_module ,
22422228 in_keys = ["state" , "belief" ],
22432229 out_keys = ["reward" ],
@@ -2271,8 +2257,8 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
22712257 num_cells = mlp_num_units ,
22722258 activation_class = nn .ELU ,
22732259 )
2274- actor_model = ProbabilisticTensorDictModule (
2275- TensorDictModule (
2260+ actor_model = SafeProbabilisticModule (
2261+ SafeModule (
22762262 actor_module ,
22772263 in_keys = ["state" , "belief" ],
22782264 out_keys = ["loc" , "scale" ],
@@ -2294,7 +2280,7 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
22942280 return actor_model
22952281
22962282 def _create_value_model (self , rssm_hidden_dim , state_dim , mlp_num_units = 200 ):
2297- value_model = TensorDictModule (
2283+ value_model = SafeModule (
22982284 MLP (
22992285 out_features = 1 ,
23002286 depth = 3 ,
@@ -2396,7 +2382,7 @@ def test_dreamer_env(self, device, imagination_horizon, discount_loss):
23962382 # test reconstruction
23972383 with pytest .raises (ValueError , match = "No observation decoder provided" ):
23982384 mb_env .decode_obs (rollout )
2399- mb_env .obs_decoder = TensorDictModule (
2385+ mb_env .obs_decoder = SafeModule (
24002386 nn .LazyLinear (4 , device = device ),
24012387 in_keys = ["state" ],
24022388 out_keys = ["reco_observation" ],
@@ -2915,13 +2901,13 @@ def test_shared_params(dest, expected_dtype, expected_device):
29152901 if torch .cuda .device_count () == 0 and dest == "cuda" :
29162902 pytest .skip ("no cuda device available" )
29172903 module_hidden = torch .nn .Linear (4 , 4 )
2918- td_module_hidden = TensorDictModule (
2904+ td_module_hidden = SafeModule (
29192905 module = module_hidden ,
29202906 spec = None ,
29212907 in_keys = ["observation" ],
29222908 out_keys = ["hidden" ],
29232909 )
2924- module_action = TensorDictModule (
2910+ module_action = SafeModule (
29252911 NormalParamWrapper (torch .nn .Linear (4 , 8 )),
29262912 in_keys = ["hidden" ],
29272913 out_keys = ["loc" , "scale" ],
0 commit comments