diff --git a/README.md b/README.md
index 194472a3736..9ae69a24cda 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL:
 ```diff
 - obs, done = env.reset()
 + tensordict = env.reset()
-policy = TensorDictModule(
+policy = SafeModule(
     model,
     in_keys=["observation_pixels", "observation_vector"],
     out_keys=["action"],
@@ -106,14 +106,14 @@ Here's another example of an off-policy training loop in TorchRL (assuming that
 Check our TorchRL-specific [TensorDict tutorial](tutorials/tensordict.ipynb) for more information.
 
-The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/common.py) which is [functorch](https://github.com/pytorch/functorch)-compatible!
-
+The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py) is [functorch](https://github.com/pytorch/functorch)-compatible!
+
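A short illustration of the functorch compatibility advertised in the hunk above, pieced together from the `make_functional_with_buffers` calls exercised in the test hunks further down; the module and key names here are illustrative, not part of the patch:

```python
import torch
from tensordict import TensorDict
from torch import nn
from torchrl.modules import SafeModule

# Wrap a plain nn.Module, then strip its parameters and buffers out into
# tensordicts so the module can be called functionally (and vmapped over).
tdmodule = SafeModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True)

td = TensorDict({"x": torch.randn(10, 3)}, batch_size=[10])
tdmodule(td, params=params, buffers=buffers)  # functional call
assert td["y"].shape == torch.Size([10, 4])
```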
 ```diff
 transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
-+ td_module = TensorDictModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
++ td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
 src = torch.rand((10, 32, 512))
 tgt = torch.rand((20, 32, 512))
 + tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[20, 32])
@@ -122,19 +122,19 @@ The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/comm
 + out = tensordict["out"]
 ```
-The `TensorDictSequential` class allows to branch sequences of `nn.Module` instances in a highly modular way.
+The `SafeSequential` class allows you to branch sequences of `nn.Module` instances in a highly modular way.
 For instance, here is an implementation of a transformer using the encoder and decoder blocks:
 ```python
 encoder_module = TransformerEncoder(...)
-encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
+encoder = SafeModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
 decoder_module = TransformerDecoder(...)
-decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
-transformer = TensorDictSequential(encoder, decoder)
+decoder = SafeModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
+transformer = SafeSequential(encoder, decoder)
 assert transformer.in_keys == ["src", "src_mask", "tgt"]
 assert transformer.out_keys == ["memory", "output"]
 ```
-`TensorDictSequential` allows you to isolate subgraphs by querying a set of desired input / output keys:
+`SafeSequential` allows you to isolate subgraphs by querying a set of desired input / output keys:
 ```python
 transformer.select_subsequence(out_keys=["memory"])  # returns the encoder
 transformer.select_subsequence(in_keys=["tgt", "memory"])  # returns the decoder
@@ -261,9 +261,9 @@ The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/comm
         kernel_sizes=[8, 4, 3],
         strides=[4, 2, 1],
     )
-    # Wrap it in a TensorDictModule, indicating what key to read in and where to
+    # Wrap it in a SafeModule, indicating what key to read in and where to
     # write out the output
-    common_module = TensorDictModule(
+    common_module = SafeModule(
         common_module,
         in_keys=["pixels"],
         out_keys=["hidden"],
@@ -277,10 +277,10 @@ The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/comm
             activation=nn.ELU,
         )
     )
-    # Wrap the nn.Module in a ProbabilisticTensorDictModule, indicating how
+    # Wrap the nn.Module in a SafeProbabilisticModule, indicating how
     # to build the torch.distribution.Distribution object and what to do with it
-    policy_module = ProbabilisticTensorDictModule(  # stochastic policy
-        TensorDictModule(
+    policy_module = SafeProbabilisticModule(  # stochastic policy
+        SafeModule(
             policy_module,
             in_keys=["hidden"],
             out_keys=["loc", "scale"],
@@ -409,7 +409,7 @@ pip3 install torchrl
 This should work on linux and MacOs (not M1). For Windows and M1/M2 machines, one
 should install the library locally (see below).
 
-The **nightly build** can be installed via 
+The **nightly build** can be installed via
 ```
 pip install torchrl-nightly
 ```
diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index ad84d38bdc1..748f65d6b68 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -50,7 +50,7 @@ With these, the following methods are implemented:
   having reproducible results.
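The `envs.rst` hunk continuing below documents `env.rollout(max_steps, policy)` in terms of the renamed class. A minimal sketch of that call, mirroring the lazy-layer initialization pattern used in the collector tests further down (the environment choice is illustrative):

```python
import torch
from torch import nn
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import SafeModule

env = GymEnv("Pendulum-v1")  # illustrative environment
# A TensorDict-compatible policy: reads "observation", writes "action".
policy = SafeModule(
    nn.LazyLinear(env.action_spec.shape[-1]),
    in_keys=["observation"],
    out_keys=["action"],
)
policy(env.reset())  # materialize the lazy layer on a reset tensordict
rollout = env.rollout(max_steps=10, policy=policy)
```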
- :obj:`env.rollout(max_steps, policy)`: executes a rollout in the environment for a maximum number of steps :obj:`max_steps` and using a policy :obj:`policy`. - The policy should be coded using a :obj:`TensorDictModule` (or any other + The policy should be coded using a :obj:`SafeModule` (or any other :obj:`TensorDict`-compatible module). diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 5915c000f96..bf0992be77b 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -11,10 +11,9 @@ TensorDict modules :toctree: generated/ :template: rl_template_noinherit.rst - TensorDictModule - ProbabilisticTensorDictModule - TensorDictSequential - TensorDictModuleWrapper + SafeModule + SafeProbabilisticModule + SafeSequential Actor ProbabilisticActor ValueOperator diff --git a/test/smoke_test.py b/test/smoke_test.py index 630171d4082..f0db69def86 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -6,5 +6,5 @@ def test_imports(): ) # noqa: F401 from torchrl.envs import Transform, TransformedEnv # noqa: F401 from torchrl.envs.gym_like import GymLikeEnv # noqa: F401 - from torchrl.modules import TensorDictModule # noqa: F401 + from torchrl.modules import SafeModule # noqa: F401 from torchrl.objectives.common import LossModule # noqa: F401 diff --git a/test/test_collector.py b/test/test_collector.py index 9d384665106..4b8b70d8444 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -35,12 +35,7 @@ from torchrl.envs import EnvCreator, ParallelEnv, SerialEnv from torchrl.envs.libs.gym import _has_gym, GymEnv from torchrl.envs.transforms import TransformedEnv, VecNorm -from torchrl.modules import ( - Actor, - LSTMNet, - OrnsteinUhlenbeckProcessWrapper, - TensorDictModule, -) +from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule # torch.set_default_dtype(torch.double) @@ -754,7 +749,7 @@ def create_env(): return ContinuousActionVecMockEnv() n_actions = ContinuousActionVecMockEnv().action_spec.shape[-1] - policy = TensorDictModule( + policy = SafeModule( torch.nn.LazyLinear(n_actions), in_keys=["observation"], out_keys=["action"] ) policy(create_env().reset()) @@ -898,7 +893,7 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe next=CompositeSpec(hidden1=hidden_spec, hidden2=hidden_spec), ) - policy = TensorDictModule(**policy_kwargs) + policy = SafeModule(**policy_kwargs) env_maker = lambda: GymEnv(PENDULUM_VERSIONED) @@ -985,12 +980,12 @@ def test_auto_wrap_modules(self, collector_class, multiple_outputs, env_maker): if collector_class is not SyncDataCollector: assert all( - isinstance(p, TensorDictModule) for p in collector._policy_dict.values() + isinstance(p, SafeModule) for p in collector._policy_dict.values() ) assert all(p.out_keys == out_keys for p in collector._policy_dict.values()) assert all(p.module is policy for p in collector._policy_dict.values()) else: - assert isinstance(collector.policy, TensorDictModule) + assert isinstance(collector.policy, SafeModule) assert collector.policy.out_keys == out_keys assert collector.policy.module is policy diff --git a/test/test_cost.py b/test/test_cost.py index 233bcce5788..83bf01f49cb 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6,7 +6,7 @@ import argparse from copy import deepcopy -from torchrl.modules.functional_modules import FunctionalModuleWithBuffers +from tensordict.nn.functional_modules import FunctionalModuleWithBuffers _has_functorch = True try: @@ -41,10 +41,10 @@ from 
torchrl.envs.transforms import TensorDictPrimer, TransformedEnv from torchrl.modules import ( DistributionalQValueActor, - ProbabilisticTensorDictModule, QValueActor, - TensorDictModule, - TensorDictSequential, + SafeModule, + SafeProbabilisticModule, + SafeSequential, WorldModelWrapper, ) from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal @@ -777,9 +777,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) - module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] - ) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) actor = ProbabilisticActor( spec=CompositeSpec(action=action_spec, loc=None, scale=None), module=module, @@ -1096,9 +1094,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) - module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] - ) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) actor = ProbabilisticActor( module=module, distribution_class=TanhNormal, @@ -1151,13 +1147,9 @@ def __init__(self): def forward(self, hidden, act): return self.linear(torch.cat([hidden, act], -1)) - common = TensorDictModule( - CommonClass(), in_keys=["observation"], out_keys=["hidden"] - ) + common = SafeModule(CommonClass(), in_keys=["observation"], out_keys=["hidden"]) actor_subnet = ProbabilisticActor( - TensorDictModule( - ActorClass(), in_keys=["hidden"], out_keys=["loc", "scale"] - ), + SafeModule(ActorClass(), in_keys=["hidden"], out_keys=["loc", "scale"]), dist_in_keys=["loc", "scale"], distribution_class=TanhNormal, return_log_prob=True, @@ -1528,9 +1520,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) - module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] - ) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) actor = ProbabilisticActor( module=module, distribution_class=TanhNormal, @@ -1763,9 +1753,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) - module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] - ) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) actor = ProbabilisticActor( module=module, distribution_class=TanhNormal, @@ -1989,9 +1977,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value): gamma = 0.9 value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"]) net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) - module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] - ) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) actor_net = ProbabilisticActor( module, distribution_class=TanhNormal, @@ -2138,7 +2124,7 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20 # World Model and reward model rssm_rollout = RSSMRollout( - TensorDictModule( + SafeModule( rssm_prior, 
in_keys=["state", "belief", "action"], out_keys=[ @@ -2148,7 +2134,7 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20 ("next", "belief"), ], ), - TensorDictModule( + SafeModule( rssm_posterior, in_keys=[("next", "belief"), ("next", "encoded_latents")], out_keys=[ @@ -2162,20 +2148,20 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20 out_features=1, depth=2, num_cells=mlp_num_units, activation_class=nn.ELU ) # World Model and reward model - world_modeler = TensorDictSequential( - TensorDictModule( + world_modeler = SafeSequential( + SafeModule( obs_encoder, in_keys=[("next", "pixels")], out_keys=[("next", "encoded_latents")], ), rssm_rollout, - TensorDictModule( + SafeModule( obs_decoder, in_keys=[("next", "state"), ("next", "belief")], out_keys=[("next", "reco_pixels")], ), ) - reward_module = TensorDictModule( + reward_module = SafeModule( reward_module, in_keys=[("next", "state"), ("next", "belief")], out_keys=["reward"], @@ -2209,8 +2195,8 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200): reward_module = MLP( out_features=1, depth=2, num_cells=mlp_num_units, activation_class=nn.ELU ) - transition_model = TensorDictSequential( - TensorDictModule( + transition_model = SafeSequential( + SafeModule( rssm_prior, in_keys=["state", "belief", "action"], out_keys=[ @@ -2221,7 +2207,7 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200): ], ), ) - reward_model = TensorDictModule( + reward_model = SafeModule( reward_module, in_keys=["state", "belief"], out_keys=["reward"], @@ -2255,8 +2241,8 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200): num_cells=mlp_num_units, activation_class=nn.ELU, ) - actor_model = ProbabilisticTensorDictModule( - TensorDictModule( + actor_model = SafeProbabilisticModule( + SafeModule( actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], @@ -2278,7 +2264,7 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200): return actor_model def _create_value_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200): - value_model = TensorDictModule( + value_model = SafeModule( MLP( out_features=1, depth=3, @@ -2380,7 +2366,7 @@ def test_dreamer_env(self, device, imagination_horizon, discount_loss): # test reconstruction with pytest.raises(ValueError, match="No observation decoder provided"): mb_env.decode_obs(rollout) - mb_env.obs_decoder = TensorDictModule( + mb_env.obs_decoder = SafeModule( nn.LazyLinear(4, device=device), in_keys=["state"], out_keys=["reco_observation"], @@ -2896,13 +2882,13 @@ def test_shared_params(dest, expected_dtype, expected_device): if torch.cuda.device_count() == 0 and dest == "cuda": pytest.skip("no cuda device available") module_hidden = torch.nn.Linear(4, 4) - td_module_hidden = TensorDictModule( + td_module_hidden = SafeModule( module=module_hidden, spec=None, in_keys=["observation"], out_keys=["hidden"], ) - module_action = TensorDictModule( + module_action = SafeModule( NormalParamWrapper(torch.nn.Linear(4, 8)), in_keys=["hidden"], out_keys=["loc", "scale"], diff --git a/test/test_env.py b/test/test_env.py index 97b1cd5f8e8..c4379ec203d 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -46,13 +46,7 @@ ) from torchrl.envs.utils import step_mdp from torchrl.envs.vec_env import ParallelEnv, SerialEnv -from torchrl.modules import ( - Actor, - ActorCriticOperator, - MLP, - TensorDictModule, - ValueOperator, -) +from torchrl.modules import Actor, 
ActorCriticOperator, MLP, SafeModule, ValueOperator from torchrl.modules.tensordict_module import WorldModelWrapper gym_version = None @@ -305,12 +299,12 @@ def test_mb_rollout(self, device, seed=0): torch.manual_seed(seed) np.random.seed(seed) world_model = WorldModelWrapper( - TensorDictModule( + SafeModule( ActionObsMergeLinear(5, 4), in_keys=["hidden_observation", "action"], out_keys=["hidden_observation"], ), - TensorDictModule( + SafeModule( nn.Linear(4, 1), in_keys=["hidden_observation"], out_keys=["reward"], @@ -331,12 +325,12 @@ def test_mb_env_batch_lock(self, device, seed=0): torch.manual_seed(seed) np.random.seed(seed) world_model = WorldModelWrapper( - TensorDictModule( + SafeModule( ActionObsMergeLinear(5, 4), in_keys=["hidden_observation", "action"], out_keys=["hidden_observation"], ), - TensorDictModule( + SafeModule( nn.Linear(4, 1), in_keys=["hidden_observation"], out_keys=["reward"], @@ -551,13 +545,13 @@ def test_parallel_env_with_policy( ) policy = ActorCriticOperator( - TensorDictModule( + SafeModule( spec=None, module=nn.LazyLinear(12), in_keys=["observation"], out_keys=["hidden"], ), - TensorDictModule( + SafeModule( spec=None, module=nn.LazyLinear(env0.action_spec.shape[-1]), in_keys=["hidden"], diff --git a/test/test_exploration.py b/test/test_exploration.py index 40c3be7d78b..17f64983865 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -14,7 +14,7 @@ from torchrl.data import CompositeSpec, NdBoundedTensorSpec from torchrl.envs.transforms.transforms import gSDENoise from torchrl.envs.utils import set_exploration_mode -from torchrl.modules import TensorDictModule, TensorDictSequential +from torchrl.modules import SafeModule, SafeSequential from torchrl.modules.distributions import TanhNormal from torchrl.modules.distributions.continuous import ( IndependentNormal, @@ -60,7 +60,7 @@ def test_ou(device, seed=0): def test_ou_wrapper(device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0): torch.manual_seed(seed) net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) - module = TensorDictModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) action_spec = NdBoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,)) policy = ProbabilisticActor( spec=action_spec, @@ -112,7 +112,7 @@ def test_additivegaussian_sd( (d_act,), device=device, ) - module = TensorDictModule( + module = SafeModule( net, in_keys=["observation"], out_keys=["loc", "scale"], @@ -172,9 +172,7 @@ def test_additivegaussian_wrapper( ): torch.manual_seed(seed) net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) - module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] - ) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) action_spec = NdBoundedTensorSpec( -torch.ones(d_act, device=device), torch.ones(d_act, device=device), @@ -229,9 +227,9 @@ def test_gsde( if gSDE: model = torch.nn.LazyLinear(action_dim, device=device) in_keys = ["observation"] - module = TensorDictSequential( - TensorDictModule(model, in_keys=in_keys, out_keys=["action"]), - TensorDictModule( + module = SafeSequential( + SafeModule(model, in_keys=in_keys, out_keys=["action"]), + SafeModule( LazygSDEModule(device=device), in_keys=["action", "observation", "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -243,7 +241,7 @@ def test_gsde( in_keys = ["observation"] model = torch.nn.LazyLinear(action_dim * 2, device=device) 
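The exploration-test hunks around this point all repeat one pattern: a `NormalParamWrapper` splits a network's output into `loc`/`scale`, a `SafeModule` routes those entries through the tensordict, and a probabilistic actor turns them into a `TanhNormal` action. A condensed sketch of that pattern, with shapes and key names as in the tests (treat it as illustrative, not as part of the patch):

```python
import torch
from tensordict import TensorDict
from torch import nn
from torchrl.modules import (
    NormalParamWrapper,
    ProbabilisticActor,
    SafeModule,
    TanhNormal,
)

d_obs, d_act = 4, 6
# NormalParamWrapper chunks the network output into loc and scale entries.
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act))
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
policy = ProbabilisticActor(
    module=module,
    dist_in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
)
td = policy(TensorDict({"observation": torch.randn(8, d_obs)}, batch_size=[8]))
assert td["action"].shape == torch.Size([8, d_act])
```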
wrapper = NormalParamWrapper(model) - module = TensorDictModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"]) + module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"]) distribution_class = TanhNormal distribution_kwargs = {"min": -bound, "max": bound} spec = NdBoundedTensorSpec( diff --git a/test/test_functorch.py b/test/test_functorch.py index e84b41a8679..7b043968afb 100644 --- a/test/test_functorch.py +++ b/test/test_functorch.py @@ -10,12 +10,12 @@ except ImportError: _has_functorch = False from tensordict import TensorDict -from torch import nn -from torchrl.modules import TensorDictModule, TensorDictSequential -from torchrl.modules.functional_modules import ( +from tensordict.nn.functional_modules import ( FunctionalModule, FunctionalModuleWithBuffers, ) +from torch import nn +from torchrl.modules import SafeModule, SafeSequential @pytest.mark.skipif( @@ -77,7 +77,7 @@ def test_vmap_tdmodule(moduletype, batch_params): raise NotImplementedError if moduletype == "linear": fmodule, params = FunctionalModule._create_from(module) - tdmodule = TensorDictModule(fmodule, in_keys=["x"], out_keys=["y"]) + tdmodule = SafeModule(fmodule, in_keys=["x"], out_keys=["y"]) x = torch.randn(10, 1, 3) td = TensorDict({"x": x}, [10]) if batch_params: @@ -89,7 +89,7 @@ def test_vmap_tdmodule(moduletype, batch_params): assert y.shape == torch.Size([10, 1, 4]) elif moduletype == "bn1": fmodule, params, buffers = FunctionalModuleWithBuffers._create_from(module) - tdmodule = TensorDictModule(fmodule, in_keys=["x"], out_keys=["y"]) + tdmodule = SafeModule(fmodule, in_keys=["x"], out_keys=["y"]) x = torch.randn(10, 2, 3) td = TensorDict({"x": x}, [10]) if batch_params: @@ -121,7 +121,7 @@ def test_vmap_tdmodule_nativebuilt(moduletype, batch_params): else: raise NotImplementedError if moduletype == "linear": - tdmodule = TensorDictModule(module, in_keys=["x"], out_keys=["y"]) + tdmodule = SafeModule(module, in_keys=["x"], out_keys=["y"]) tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) x = torch.randn(10, 1, 3) td = TensorDict({"x": x}, [10]) @@ -134,7 +134,7 @@ def test_vmap_tdmodule_nativebuilt(moduletype, batch_params): y = td["y"] assert y.shape == torch.Size([10, 1, 4]) elif moduletype == "bn1": - tdmodule = TensorDictModule(module, in_keys=["x"], out_keys=["y"]) + tdmodule = SafeModule(module, in_keys=["x"], out_keys=["y"]) tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) x = torch.randn(10, 2, 3) td = TensorDict({"x": x}, [10]) @@ -173,10 +173,10 @@ def test_vmap_tdsequence(moduletype, batch_params): else: raise NotImplementedError if moduletype == "linear": - tdmodule1 = TensorDictModule(fmodule1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = TensorDictModule(fmodule2, in_keys=["y"], out_keys=["z"]) + tdmodule1 = SafeModule(fmodule1, in_keys=["x"], out_keys=["y"]) + tdmodule2 = SafeModule(fmodule2, in_keys=["y"], out_keys=["z"]) params = TensorDict({"0": params1, "1": params2}, []) - tdmodule = TensorDictSequential(tdmodule1, tdmodule2) + tdmodule = SafeSequential(tdmodule1, tdmodule2) assert {"0", "1"} == set(params.keys()) x = torch.randn(10, 1, 3) td = TensorDict({"x": x}, [10]) @@ -188,11 +188,11 @@ def test_vmap_tdsequence(moduletype, batch_params): z = td["z"] assert z.shape == torch.Size([10, 1, 5]) elif moduletype == "bn1": - tdmodule1 = TensorDictModule(fmodule1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = TensorDictModule(fmodule2, in_keys=["y"], out_keys=["z"]) + tdmodule1 = SafeModule(fmodule1, 
in_keys=["x"], out_keys=["y"]) + tdmodule2 = SafeModule(fmodule2, in_keys=["y"], out_keys=["z"]) params = TensorDict({"0": params1, "1": params2}, []) buffers = TensorDict({"0": buffers1, "1": buffers2}, []) - tdmodule = TensorDictSequential(tdmodule1, tdmodule2) + tdmodule = SafeSequential(tdmodule1, tdmodule2) assert {"0", "1"} == set(params.keys()) assert {"0", "1"} == set(buffers.keys()) x = torch.randn(10, 2, 3) @@ -228,9 +228,9 @@ def test_vmap_tdsequence_nativebuilt(moduletype, batch_params): else: raise NotImplementedError if moduletype == "linear": - tdmodule1 = TensorDictModule(module1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = TensorDictModule(module2, in_keys=["y"], out_keys=["z"]) - tdmodule = TensorDictSequential(tdmodule1, tdmodule2) + tdmodule1 = SafeModule(module1, in_keys=["x"], out_keys=["y"]) + tdmodule2 = SafeModule(module2, in_keys=["y"], out_keys=["z"]) + tdmodule = SafeSequential(tdmodule1, tdmodule2) tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) assert {"0", "1"} == set(params.keys()) x = torch.randn(10, 1, 3) @@ -244,9 +244,9 @@ def test_vmap_tdsequence_nativebuilt(moduletype, batch_params): z = td["z"] assert z.shape == torch.Size([10, 1, 5]) elif moduletype == "bn1": - tdmodule1 = TensorDictModule(module1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = TensorDictModule(module2, in_keys=["y"], out_keys=["z"]) - tdmodule = TensorDictSequential(tdmodule1, tdmodule2) + tdmodule1 = SafeModule(module1, in_keys=["x"], out_keys=["y"]) + tdmodule2 = SafeModule(module2, in_keys=["y"], out_keys=["z"]) + tdmodule = SafeSequential(tdmodule1, tdmodule2) tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) assert {"0", "1"} == set(params.keys()) assert {"0", "1"} == set(buffers.keys()) diff --git a/test/test_modules.py b/test/test_modules.py index 59822843835..3a83f48c18c 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -11,6 +11,10 @@ from mocking_classes import MockBatchedUnLockedEnv from packaging import version from tensordict import TensorDict +from tensordict.nn.functional_modules import ( + FunctionalModule, + FunctionalModuleWithBuffers, +) from torch import nn from torchrl.data.tensor_specs import ( DiscreteTensorSpec, @@ -23,13 +27,9 @@ LSTMNet, ProbabilisticActor, QValueActor, - TensorDictModule, + SafeModule, ValueOperator, ) -from torchrl.modules.functional_modules import ( - FunctionalModule, - FunctionalModuleWithBuffers, -) from torchrl.modules.models import ConvNet, MLP, NoisyLazyLinear, NoisyLinear from torchrl.modules.models.model_based import ( DreamerActor, @@ -259,10 +259,10 @@ def make_net(): @pytest.mark.parametrize("device", get_available_devices()) def test_actorcritic(device): - common_module = TensorDictModule( + common_module = SafeModule( spec=None, module=nn.Linear(3, 4), in_keys=["obs"], out_keys=["hidden"] ).to(device) - module = TensorDictModule(nn.Linear(4, 5), in_keys=["hidden"], out_keys=["param"]) + module = SafeModule(nn.Linear(4, 5), in_keys=["hidden"], out_keys=["param"]) policy_operator = ProbabilisticActor( spec=None, module=module, dist_in_keys=["param"], return_log_prob=True ).to(device) @@ -613,7 +613,7 @@ def test_rssm_rollout( ).to(device) rssm_rollout = RSSMRollout( - TensorDictModule( + SafeModule( rssm_prior, in_keys=["state", "belief", "action"], out_keys=[ @@ -623,7 +623,7 @@ def test_rssm_rollout( ("next", "belief"), ], ), - TensorDictModule( + SafeModule( rssm_posterior, in_keys=[("next", "belief"), ("next", "encoded_latents")], out_keys=[ diff 
--git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 5b80a7872cf..60f513ae213 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -15,7 +15,7 @@ _has_functorch = True except ImportError: - from torchrl.modules.functional_modules import ( + from tensordict.nn.functional_modules import ( FunctionalModule, FunctionalModuleWithBuffers, ) @@ -30,15 +30,13 @@ NdUnboundedContinuousTensorSpec, ) from torchrl.envs.utils import set_exploration_mode -from torchrl.modules import NormalParamWrapper, TanhNormal, TensorDictModule +from torchrl.modules import NormalParamWrapper, SafeModule, TanhNormal from torchrl.modules.tensordict_module.common import ( ensure_tensordict_compatible, is_tensordict_compatible, ) -from torchrl.modules.tensordict_module.probabilistic import ( - ProbabilisticTensorDictModule, -) -from torchrl.modules.tensordict_module.sequence import TensorDictSequential +from torchrl.modules.tensordict_module.probabilistic import SafeProbabilisticModule +from torchrl.modules.tensordict_module.sequence import SafeSequential class TestTDModule: @@ -53,7 +51,7 @@ def __init__(self, in_1, out_1, out_2, out_3): def forward(self, x): return self.linear_1(x), self.linear_2(x), self.linear_3(x) - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( MultiHeadLinear(5, 4, 3, 2), in_keys=["input"], out_keys=["out_1", "out_2", "out_3"], @@ -68,7 +66,7 @@ def forward(self, x): assert td.get("out_3").shape == torch.Size([3, 2]) # Using "_" key to ignore some output - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( MultiHeadLinear(5, 4, 3, 2), in_keys=["input"], out_keys=["_", "_", "out_3"], @@ -98,7 +96,7 @@ def forward(self, x): # warning due to "_" in spec keys with pytest.warns(UserWarning, match='got a spec with key "_"'): - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( MultiHeadLinear(5, 4, 3), in_keys=["input"], out_keys=["_", "out_2"], @@ -129,7 +127,7 @@ def test_stateful(self, safe, spec_type, lazy): match="is not a valid configuration as the tensor specs are not " "specified", ): - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( module=net, spec=spec, in_keys=["in"], @@ -138,7 +136,7 @@ def test_stateful(self, safe, spec_type, lazy): ) return else: - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( module=net, spec=spec, in_keys=["in"], @@ -171,7 +169,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) net = nn.Linear(3, 4 * param_multiplier) in_keys = ["in"] - net = TensorDictModule( + net = SafeModule( module=NormalParamWrapper(net), spec=None, in_keys=in_keys, @@ -206,7 +204,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) match="is not a valid configuration as the tensor specs are not " "specified", ): - tensordict_module = ProbabilisticTensorDictModule( + tensordict_module = SafeProbabilisticModule( module=net, spec=spec, dist_in_keys=dist_in_keys, @@ -216,7 +214,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) ) return else: - tensordict_module = ProbabilisticTensorDictModule( + tensordict_module = SafeProbabilisticModule( module=net, spec=spec, dist_in_keys=dist_in_keys, @@ -260,7 +258,7 @@ def test_functional(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( spec=spec, module=fnet, 
in_keys=["in"], @@ -269,7 +267,7 @@ def test_functional(self, safe, spec_type): ) return else: - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( spec=spec, module=fnet, in_keys=["in"], @@ -298,7 +296,7 @@ def test_functional_probabilistic(self, safe, spec_type): in_keys = ["in"] net = NormalParamWrapper(net) fnet, params = make_functional(net) - tdnet = TensorDictModule( + tdnet = SafeModule( module=fnet, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] ) @@ -322,7 +320,7 @@ def test_functional_probabilistic(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tensordict_module = ProbabilisticTensorDictModule( + tensordict_module = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -332,7 +330,7 @@ def test_functional_probabilistic(self, safe, spec_type): ) return else: - tensordict_module = ProbabilisticTensorDictModule( + tensordict_module = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -361,7 +359,7 @@ def test_functional_probabilistic_laterconstruct(self, safe, spec_type): net = nn.Linear(3, 4 * param_multiplier) in_keys = ["in"] net = NormalParamWrapper(net) - tdnet = TensorDictModule( + tdnet = SafeModule( module=net, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] ) @@ -385,7 +383,7 @@ def test_functional_probabilistic_laterconstruct(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tensordict_module = ProbabilisticTensorDictModule( + tensordict_module = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -395,7 +393,7 @@ def test_functional_probabilistic_laterconstruct(self, safe, spec_type): ) return else: - tensordict_module = ProbabilisticTensorDictModule( + tensordict_module = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -442,7 +440,7 @@ def test_functional_with_buffer(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tdmodule = TensorDictModule( + tdmodule = SafeModule( spec=spec, module=fnet, in_keys=["in"], @@ -451,7 +449,7 @@ def test_functional_with_buffer(self, safe, spec_type): ) return else: - tdmodule = TensorDictModule( + tdmodule = SafeModule( spec=spec, module=fnet, in_keys=["in"], @@ -480,7 +478,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): in_keys = ["in"] net = NormalParamWrapper(net) fnet, params, buffers = make_functional_with_buffers(net) - tdnet = TensorDictModule( + tdnet = SafeModule( module=fnet, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] ) @@ -504,7 +502,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tdmodule = ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -514,7 +512,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): ) return else: - tdmodule = ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -543,7 +541,7 @@ def test_functional_with_buffer_probabilistic_laterconstruct(self, safe, spec_ty net = nn.BatchNorm1d(32 * param_multiplier) in_keys = ["in"] net = NormalParamWrapper(net) - tdnet = TensorDictModule( + tdnet = SafeModule( module=net, spec=None, in_keys=in_keys, out_keys=["loc", 
"scale"] ) @@ -567,7 +565,7 @@ def test_functional_with_buffer_probabilistic_laterconstruct(self, safe, spec_ty match="is not a valid configuration as the tensor specs are not " "specified", ): - ProbabilisticTensorDictModule( + SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -577,7 +575,7 @@ def test_functional_with_buffer_probabilistic_laterconstruct(self, safe, spec_ty ) return else: - tdmodule = ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -624,7 +622,7 @@ def test_vmap(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tdmodule = TensorDictModule( + tdmodule = SafeModule( spec=spec, module=fnet, in_keys=["in"], @@ -633,7 +631,7 @@ def test_vmap(self, safe, spec_type): ) return else: - tdmodule = TensorDictModule( + tdmodule = SafeModule( spec=spec, module=fnet, in_keys=["in"], @@ -690,7 +688,7 @@ def test_vmap_probabilistic(self, safe, spec_type): net = NormalParamWrapper(net) in_keys = ["in"] fnet, params = make_functional(net) - tdnet = TensorDictModule( + tdnet = SafeModule( module=fnet, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] ) @@ -714,7 +712,7 @@ def test_vmap_probabilistic(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tdmodule = ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -724,7 +722,7 @@ def test_vmap_probabilistic(self, safe, spec_type): ) return else: - tdmodule = ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -781,7 +779,7 @@ def test_vmap_probabilistic_laterconstruct(self, safe, spec_type): net = nn.Linear(3, 4 * param_multiplier) net = NormalParamWrapper(net) in_keys = ["in"] - tdnet = TensorDictModule( + tdnet = SafeModule( module=net, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] ) @@ -805,7 +803,7 @@ def test_vmap_probabilistic_laterconstruct(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tdmodule = ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -815,7 +813,7 @@ def test_vmap_probabilistic_laterconstruct(self, safe, spec_type): ) return else: - tdmodule = ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -865,28 +863,14 @@ def test_vmap_probabilistic_laterconstruct(self, safe, spec_type): class TestTDSequence: def test_in_key_warning(self): with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"] ) with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"] ) - def test_key_exclusion(self): - module1 = TensorDictModule( - nn.Linear(3, 4), in_keys=["key1", "key2"], out_keys=["foo1"] - ) - module2 = TensorDictModule( - nn.Linear(3, 4), in_keys=["key1", "key3"], out_keys=["key1"] - ) - module3 = TensorDictModule( - nn.Linear(3, 4), in_keys=["foo1", "key3"], out_keys=["key2"] - ) - seq = TensorDictSequential(module1, module2, module3) - assert set(seq.in_keys) == 
{"key1", "key2", "key3"} - assert set(seq.out_keys) == {"foo1", "key1", "key2"} - @pytest.mark.parametrize("safe", [True, False]) @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) @pytest.mark.parametrize("lazy", [True, False]) @@ -914,21 +898,21 @@ def test_stateful(self, safe, spec_type, lazy): if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False, ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = TensorDictModule( + tdmodule2 = SafeModule( spec=spec, module=net2, in_keys=["hidden"], @@ -936,7 +920,7 @@ def test_stateful(self, safe, spec_type, lazy): safe=False, **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -981,9 +965,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy): dummy_net = nn.Linear(4, 4) net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) - net2 = TensorDictModule( - module=net2, in_keys=["hidden"], out_keys=["loc", "scale"] - ) + net2 = SafeModule(module=net2, in_keys=["hidden"], out_keys=["loc", "scale"]) if spec_type is None: spec = None @@ -1002,21 +984,21 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy): if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False, ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = ProbabilisticTensorDictModule( + tdmodule2 = SafeProbabilisticModule( spec=spec, module=net2, dist_in_keys=["loc", "scale"], @@ -1024,7 +1006,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy): safe=False, **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -1082,24 +1064,24 @@ def test_functional(self, safe, spec_type): if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( fdummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = TensorDictModule( + tdmodule2 = SafeModule( fnet2, spec=spec, in_keys=["hidden"], out_keys=["out"], safe=safe, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -1143,9 +1125,7 @@ def test_functional_probabilistic(self, safe, spec_type): fnet1, params1 = make_functional(net1) fdummy_net, _ = make_functional(dummy_net) fnet2, params2 = make_functional(net2) - fnet2 = TensorDictModule( - module=fnet2, in_keys=["hidden"], out_keys=["loc", "scale"] - ) + fnet2 = SafeModule(module=fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) if isinstance(params1, TensorDictBase): params = 
TensorDict({"0": params1, "1": params2}, []) else: @@ -1168,17 +1148,17 @@ def test_functional_probabilistic(self, safe, spec_type): if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( fdummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = ProbabilisticTensorDictModule( + tdmodule2 = SafeProbabilisticModule( fnet2, spec=spec, dist_in_keys=["loc", "scale"], @@ -1186,7 +1166,7 @@ def test_functional_probabilistic(self, safe, spec_type): safe=safe, **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -1255,24 +1235,24 @@ def test_functional_with_buffer( if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( fdummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = TensorDictModule( + tdmodule2 = SafeModule( fnet2, spec=spec, in_keys=["hidden"], out_keys=["out"], safe=safe, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -1323,8 +1303,8 @@ def test_functional_with_buffer_probabilistic( fnet1, params1, buffers1 = make_functional_with_buffers(net1) fdummy_net, _, _ = make_functional_with_buffers(dummy_net) # fnet2, params2, buffers2 = make_functional_with_buffers(net2) - # fnet2 = TensorDictModule(fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) - net2 = TensorDictModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) + # fnet2 = SafeModule(fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) + net2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) fnet2, (params2, buffers2) = net2.make_functional_with_buffers() if isinstance(params1, TensorDictBase): @@ -1353,17 +1333,17 @@ def test_functional_with_buffer_probabilistic( if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( fdummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = ProbabilisticTensorDictModule( + tdmodule2 = SafeProbabilisticModule( fnet2, spec=spec, dist_in_keys=["loc", "scale"], @@ -1371,7 +1351,7 @@ def test_functional_with_buffer_probabilistic( safe=safe, **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -1417,7 +1397,7 @@ def test_functional_with_buffer_probabilistic_laterconstruct( nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) ) net2 = NormalParamWrapper(net2) - net2 = TensorDictModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) + net2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", 
"scale"]) if spec_type is None: spec = None @@ -1436,10 +1416,10 @@ def test_functional_with_buffer_probabilistic_laterconstruct( if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) - tdmodule2 = ProbabilisticTensorDictModule( + tdmodule2 = SafeProbabilisticModule( net2, spec=spec, dist_in_keys=["loc", "scale"], @@ -1447,7 +1427,7 @@ def test_functional_with_buffer_probabilistic_laterconstruct( safe=safe, **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, tdmodule2) + tdmodule = SafeSequential(tdmodule1, tdmodule2) tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers() @@ -1494,28 +1474,28 @@ def test_vmap(self, safe, spec_type): if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False, ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( fdummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = TensorDictModule( + tdmodule2 = SafeModule( fnet2, spec=spec, in_keys=["hidden"], out_keys=["out"], safe=safe, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -1582,7 +1562,7 @@ def test_vmap_probabilistic(self, safe, spec_type): net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) fnet2, params2 = make_functional(net2) - fnet2 = TensorDictModule(fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) + fnet2 = SafeModule(fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) params = params1 + params2 @@ -1603,14 +1583,14 @@ def test_vmap_probabilistic(self, safe, spec_type): if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False, ) - tdmodule2 = ProbabilisticTensorDictModule( + tdmodule2 = SafeProbabilisticModule( fnet2, spec=spec, sample_out_key=["out"], @@ -1618,7 +1598,7 @@ def test_vmap_probabilistic(self, safe, spec_type): safe=safe, **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, tdmodule2) + tdmodule = SafeSequential(tdmodule1, tdmodule2) # vmap = True params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] @@ -1656,53 +1636,6 @@ def test_vmap_probabilistic(self, safe, spec_type): elif safe and spec_type == "bounded": assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - @pytest.mark.parametrize("functional", [True, False]) - def test_submodule_sequence(self, functional): - td_module_1 = TensorDictModule( - nn.Linear(3, 2), - in_keys=["in"], - out_keys=["hidden"], - ) - td_module_2 = TensorDictModule( - nn.Linear(2, 4), - in_keys=["hidden"], - out_keys=["out"], - ) - td_module = TensorDictSequential(td_module_1, td_module_2) - - if functional: - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - sub_seq_1, (params, buffers) = sub_seq_1.make_functional_with_buffers() - sub_seq_1( - td_1, - params=params, - buffers=buffers, - ) - assert "hidden" in td_1.keys() - assert "out" not in td_1.keys() - td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) - sub_seq_2 = 
td_module.select_subsequence(in_keys=["hidden"]) - sub_seq_2, (params, buffers) = sub_seq_2.make_functional_with_buffers() - sub_seq_2( - td_2, - params=params, - buffers=buffers, - ) - assert "out" in td_2.keys() - assert td_2.get("out").shape == torch.Size([5, 4]) - else: - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - sub_seq_1(td_1) - assert "hidden" in td_1.keys() - assert "out" not in td_1.keys() - td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) - sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) - sub_seq_2(td_2) - assert "out" in td_2.keys() - assert td_2.get("out").shape == torch.Size([5, 4]) - @pytest.mark.parametrize("stack", [True, False]) @pytest.mark.parametrize("functional", [True, False]) def test_sequential_partial(self, stack, functional): @@ -1723,7 +1656,7 @@ def test_sequential_partial(self, stack, functional): else: fnet2 = net2 params2 = None - fnet2 = TensorDictModule(fnet2, in_keys=["b"], out_keys=["loc", "scale"]) + fnet2 = SafeModule(fnet2, in_keys=["b"], out_keys=["loc", "scale"]) net3 = nn.Linear(4, 4 * param_multiplier) net3 = NormalParamWrapper(net3) @@ -1732,21 +1665,21 @@ def test_sequential_partial(self, stack, functional): else: fnet3 = net3 params3 = None - fnet3 = TensorDictModule(fnet3, in_keys=["c"], out_keys=["loc", "scale"]) + fnet3 = SafeModule(fnet3, in_keys=["c"], out_keys=["loc", "scale"]) spec = NdBoundedTensorSpec(-0.1, 0.1, 4) spec = CompositeSpec(out=spec, loc=None, scale=None) kwargs = {"distribution_class": TanhNormal} - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["a"], out_keys=["hidden"], safe=False, ) - tdmodule2 = ProbabilisticTensorDictModule( + tdmodule2 = SafeProbabilisticModule( fnet2, spec=spec, sample_out_key=["out"], @@ -1754,7 +1687,7 @@ def test_sequential_partial(self, stack, functional): safe=True, **kwargs, ) - tdmodule3 = ProbabilisticTensorDictModule( + tdmodule3 = SafeProbabilisticModule( fnet3, spec=spec, sample_out_key=["out"], @@ -1762,7 +1695,7 @@ def test_sequential_partial(self, stack, functional): safe=True, **kwargs, ) - tdmodule = TensorDictSequential( + tdmodule = SafeSequential( tdmodule1, tdmodule2, tdmodule3, partial_tolerant=True ) @@ -1817,36 +1750,6 @@ def test_sequential_partial(self, stack, functional): assert "out" in td.keys() assert "b" in td.keys() - def test_subsequence_weight_update(self): - td_module_1 = TensorDictModule( - nn.Linear(3, 2), - in_keys=["in"], - out_keys=["hidden"], - ) - td_module_2 = TensorDictModule( - nn.Linear(2, 4), - in_keys=["hidden"], - out_keys=["out"], - ) - td_module = TensorDictSequential(td_module_1, td_module_2) - - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - copy = sub_seq_1[0].module.weight.clone() - - opt = torch.optim.SGD(td_module.parameters(), lr=0.1) - opt.zero_grad() - td_1 = td_module(td_1) - td_1["out"].mean().backward() - opt.step() - - assert not torch.allclose(copy, sub_seq_1[0].module.weight) - assert torch.allclose(td_module[0].module.weight, sub_seq_1[0].module.weight) - - if __name__ == "__main__": - args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) - def test_is_tensordict_compatible(): class MultiHeadLinear(nn.Module): @@ -1859,7 +1762,7 @@ def __init__(self, in_1, out_1, out_2, out_3): def forward(self, x): return self.linear_1(x), self.linear_2(x), self.linear_3(x) - 
td_module = TensorDictModule( + td_module = SafeModule( MultiHeadLinear(5, 4, 3, 2), in_keys=["in_1", "in_2"], out_keys=["out_1", "out_2"], @@ -1914,7 +1817,7 @@ def __init__(self, in_1, out_1, out_2, out_3): def forward(self, x): return self.linear_1(x), self.linear_2(x), self.linear_3(x) - td_module = TensorDictModule( + td_module = SafeModule( MultiHeadLinear(5, 4, 3, 2), in_keys=["in_1", "in_2"], out_keys=["out_1", "out_2"], @@ -1952,4 +1855,9 @@ def forward(self, in_1, in_2): out_keys=["out_1", "out_2", "out_3"], ) assert set(ensured_module.in_keys) == {"x"} - assert isinstance(ensured_module, TensorDictModule) + assert isinstance(ensured_module, SafeModule) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 72d63332ca9..82dd40cd7d9 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -29,7 +29,7 @@ from ..data.utils import CloudpickleWrapper, DEVICE_TYPING from ..envs.common import EnvBase from ..envs.vec_env import _BatchedEnv -from ..modules.tensordict_module import ProbabilisticTensorDictModule, TensorDictModule +from ..modules.tensordict_module import SafeModule, SafeProbabilisticModule from .utils import split_trajectories _TIMEOUT = 1.0 @@ -84,21 +84,21 @@ def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: def _policy_is_tensordict_compatible(policy: nn.Module): sig = inspect.signature(policy.forward) - if isinstance(policy, TensorDictModule) or ( + if isinstance(policy, SafeModule) or ( len(sig.parameters) == 1 and hasattr(policy, "in_keys") and hasattr(policy, "out_keys") ): - # if the policy is a TensorDictModule or takes a single argument and defines + # if the policy is a SafeModule or takes a single argument and defines # in_keys and out_keys then we assume it can already deal with TensorDict input # to forward and we return True return True elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"): - # if it's not a TensorDictModule, and in_keys and out_keys are not defined then + # if it's not a SafeModule, and in_keys and out_keys are not defined then # we assume no TensorDict compatibility and will try to wrap it. return False - # if in_keys or out_keys were defined but policy is not a TensorDictModule or + # if in_keys or out_keys were defined but policy is not a SafeModule or # accepts multiple arguments then it's likely the user is trying to do something # that will have undetermined behaviour, we raise an error raise TypeError( @@ -107,7 +107,7 @@ def _policy_is_tensordict_compatible(policy: nn.Module): "should take a single argument of type TensorDict to policy.forward and define " "both in_keys and out_keys. Alternatively, policy.forward can accept " "arbitrarily many tensor inputs and leave in_keys and out_keys undefined and " - "TorchRL will attempt to automatically wrap the policy with a TensorDictModule." + "TorchRL will attempt to automatically wrap the policy with a SafeModule." 
) @@ -116,15 +116,13 @@ def _get_policy_and_device( self, policy: Optional[ Union[ - ProbabilisticTensorDictModule, + SafeProbabilisticModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, device: Optional[DEVICE_TYPING] = None, observation_spec: TensorSpec = None, - ) -> Tuple[ - ProbabilisticTensorDictModule, torch.device, Union[None, Callable[[], dict]] - ]: + ) -> Tuple[SafeProbabilisticModule, torch.device, Union[None, Callable[[], dict]]]: """Util method to get a policy and its device given the collector __init__ inputs. From a policy and a device, assigns the self.device attribute to @@ -135,7 +133,7 @@ def _get_policy_and_device( create_env_fn (Callable or list of callables): an env creator function (or a list of creators) create_env_kwargs (dictionary): kwargs for the env creator - policy (ProbabilisticTensorDictModule, optional): a policy to be used + policy (SafeProbabilisticModule, optional): a policy to be used device (int, str or torch.device, optional): device where to place the policy observation_spec (TensorSpec, optional): spec of the observations @@ -163,13 +161,13 @@ def _get_policy_and_device( # callables should be supported as policies. if not _policy_is_tensordict_compatible(policy): # policy is a nn.Module that doesn't operate on tensordicts directly - # so we attempt to auto-wrap policy with TensorDictModule + # so we attempt to auto-wrap policy with SafeModule if observation_spec is None: raise ValueError( "Unable to read observation_spec from the environment. This is " "required to check compatibility of the environment and policy " "since the policy is a nn.Module that operates on tensors " - "rather than a TensorDictModule or a nn.Module that accepts a " + "rather than a SafeModule or a nn.Module that accepts a " "TensorDict as input and defines in_keys and out_keys." ) sig = inspect.signature(policy.forward) @@ -183,18 +181,18 @@ def _get_policy_and_device( if isinstance(output, tuple): out_keys.extend(f"output{i+1}" for i in range(len(output) - 1)) - policy = TensorDictModule( + policy = SafeModule( policy, in_keys=list(sig.parameters), out_keys=out_keys ) else: raise TypeError( "Arguments to policy.forward are incompatible with entries in " "env.observation_spec. If you want TorchRL to automatically " - "wrap your policy with a TensorDictModule then the arguments " + "wrap your policy with a SafeModule then the arguments " "to policy.forward must correspond one-to-one with entries in " "env.observation_spec that are prefixed with 'next_'. For more " "complex behaviour and more control you can consider writing " - "your own TensorDictModule." + "your own SafeModule." 
) try: @@ -307,7 +305,7 @@ def __init__( ], # noqa: F821 policy: Optional[ Union[ - ProbabilisticTensorDictModule, + SafeProbabilisticModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, @@ -520,7 +518,7 @@ def iterator(self) -> Iterator[TensorDictBase]: def _cast_to_policy(self, td: TensorDictBase) -> TensorDictBase: policy_device = self.device if hasattr(self.policy, "in_keys"): - # some keys may be absent -- TensorDictModule is resilient to missing keys + # some keys may be absent -- SafeModule is resilient to missing keys td = td.select(*self.policy.in_keys, strict=False) if self._td_policy is None: self._td_policy = td.to(policy_device) @@ -720,7 +718,7 @@ class _MultiDataCollector(_DataCollector): Args: create_env_fn (list of Callabled): list of Callables, each returning an instance of EnvBase - policy (Callable, optional): Instance of ProbabilisticTensorDictModule class. + policy (Callable, optional): Instance of SafeProbabilisticModule class. Must accept TensorDictBase object as input. total_frames (int): lower bound of the total number of frames returned by the collector. In parallel settings, the actual number of frames may well be greater than this as the closing signals are sent to the @@ -779,7 +777,7 @@ def __init__( create_env_fn: Sequence[Callable[[], EnvBase]], policy: Optional[ Union[ - ProbabilisticTensorDictModule, + SafeProbabilisticModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, @@ -1306,7 +1304,7 @@ class aSyncDataCollector(MultiaSyncDataCollector): Args: create_env_fn (Callabled): Callable returning an instance of EnvBase - policy (Callable, optional): Instance of ProbabilisticTensorDictModule class. + policy (Callable, optional): Instance of SafeProbabilisticModule class. Must accept TensorDictBase object as input. total_frames (int): lower bound of the total number of frames returned by the collector. In parallel settings, the actual number of @@ -1361,7 +1359,7 @@ def __init__( create_env_fn: Callable[[], EnvBase], policy: Optional[ Union[ - ProbabilisticTensorDictModule, + SafeProbabilisticModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index 128408dd6f2..1ff0cd03712 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -13,7 +13,7 @@ from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase -from torchrl.modules.tensordict_module import TensorDictModule +from torchrl.modules.tensordict_module import SafeModule class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): @@ -53,12 +53,12 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): >>> import torch.nn as nn >>> from torchrl.modules import MLP, WorldModelWrapper >>> world_model = WorldModelWrapper( - ... TensorDictModule( + ... SafeModule( ... MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0), ... in_keys=["hidden_observation", "action"], ... out_keys=["hidden_observation"], ... ), - ... TensorDictModule( + ... SafeModule( ... nn.Linear(4, 1), ... in_keys=["hidden_observation"], ... 
out_keys=["reward"], @@ -114,7 +114,7 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): def __init__( self, - world_model: TensorDictModule, + world_model: SafeModule, params: Optional[List[torch.Tensor]] = None, buffers: Optional[List[torch.Tensor]] = None, device: DEVICE_TYPING = "cpu", diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index 432682812b2..fb902d692f7 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -13,7 +13,7 @@ from torchrl.data.utils import DEVICE_TYPING from torchrl.envs import EnvBase from torchrl.envs.model_based import ModelBasedEnvBase -from torchrl.modules.tensordict_module import TensorDictModule +from torchrl.modules.tensordict_module import SafeModule class DreamerEnv(ModelBasedEnvBase): @@ -21,10 +21,10 @@ class DreamerEnv(ModelBasedEnvBase): def __init__( self, - world_model: TensorDictModule, + world_model: SafeModule, prior_shape: Tuple[int, ...], belief_shape: Tuple[int, ...], - obs_decoder: TensorDictModule = None, + obs_decoder: SafeModule = None, device: DEVICE_TYPING = "cpu", dtype: Optional[Union[torch.dtype, np.dtype]] = None, batch_size: Optional[torch.Size] = None, diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 36a5cd8df36..ef86b45b0c4 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -3,11 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Union - import pkg_resources +from tensordict.nn.probabilistic import ( # noqa + interaction_mode as exploration_mode, + set_interaction_mode as set_exploration_mode, +) from tensordict.tensordict import TensorDictBase -from torch.autograd.grad_mode import _DecoratorContextManager AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set} @@ -150,38 +151,3 @@ def _check_dmlab(): # "screeps": None, # https://github.com/screeps/screeps # "ml-agents": None, } - -EXPLORATION_MODE = None - - -class set_exploration_mode(_DecoratorContextManager): - """Sets the exploration mode of all ProbabilisticTDModules to the desired mode. - - Args: - mode (str): mode to use when the policy is being called. 
- - Examples: - >>> policy = Actor(action_spec, module=network, default_interaction_mode="mode") - >>> env.rollout(policy=policy, max_steps=100) # rollout with the "mode" interaction mode - >>> with set_exploration_mode("random"): - >>> env.rollout(policy=policy, max_steps=100) # rollout with the "random" interaction mode - - """ - - def __init__(self, mode: str = "mode"): - super().__init__() - self.mode = mode - - def __enter__(self) -> None: - global EXPLORATION_MODE - self.prev = EXPLORATION_MODE - EXPLORATION_MODE = self.mode - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - global EXPLORATION_MODE - EXPLORATION_MODE = self.prev - - -def exploration_mode() -> Union[str, None]: - """Returns the exploration mode currently set.""" - return EXPLORATION_MODE diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 8460e40160b..8c6ca5d8593 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -13,12 +13,13 @@ TanhNormal, TruncatedNormal, ) -from .functional_modules import ( - extract_buffers, - extract_weights, - FunctionalModule, - FunctionalModuleWithBuffers, -) + +# from .functional_modules import ( +# FunctionalModule, +# FunctionalModuleWithBuffers, +# extract_weights, +# extract_buffers, +# ) from .models import ( ConvNet, DdpgCnnActor, @@ -50,11 +51,10 @@ EGreedyWrapper, OrnsteinUhlenbeckProcessWrapper, ProbabilisticActor, - ProbabilisticTensorDictModule, QValueActor, - TensorDictModule, - TensorDictModuleWrapper, - TensorDictSequential, + SafeModule, + SafeProbabilisticModule, + SafeSequential, ValueOperator, WorldModelWrapper, ) diff --git a/torchrl/modules/functional_modules.py b/torchrl/modules/functional_modules.py deleted file mode 100644 index b40a6a22242..00000000000 --- a/torchrl/modules/functional_modules.py +++ /dev/null @@ -1,295 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from copy import deepcopy - -import torch -from tensordict import TensorDict -from tensordict.tensordict import TensorDictBase -from torch import nn - -_RESET_OLD_TENSORDICT = True -try: - import functorch._src.vmap - - _has_functorch = True -except ImportError: - _has_functorch = False - -# Monky-patch functorch, mainly for cases where a "isinstance(obj, Tensor) is invoked -if _has_functorch: - from functorch._src.vmap import ( - _add_batch_dim, - _broadcast_to_and_flatten, - _get_name, - _remove_batch_dim, - _validate_and_get_batch_size, - Tensor, - tree_flatten, - tree_unflatten, - ) - - # Monkey-patches - - def _process_batched_inputs(in_dims, args, func): - if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): - raise ValueError( - f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): -expected `in_dims` to be int or a (potentially nested) tuple -matching the structure of inputs, got: {type(in_dims)}.""" - ) - if len(args) == 0: - raise ValueError( - f"""vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add -inputs, or you are trying to vmap over a function with no inputs. -The latter is unsupported.""" - ) - - flat_args, args_spec = tree_flatten(args) - flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec) - if flat_in_dims is None: - raise ValueError( - f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): -in_dims is not compatible with the structure of `inputs`. 
-in_dims has structure {tree_flatten(in_dims)[1]} but inputs -has structure {args_spec}.""" - ) - - for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)): - if not isinstance(in_dim, int) and in_dim is not None: - raise ValueError( - f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): -Got in_dim={in_dim} for an input but in_dim must be either -an integer dimension or None.""" - ) - if isinstance(in_dim, int) and not isinstance( - arg, (Tensor, TensorDictBase) - ): - raise ValueError( - f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): -Got in_dim={in_dim} for an input but the input is of type -{type(arg)}. We cannot vmap over non-Tensor arguments, -please use None as the respective in_dim""" - ) - if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()): - raise ValueError( - f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): -Got in_dim={in_dim} for some input, but that input is a Tensor -of dimensionality {arg.dim()} so expected in_dim to satisfy --{arg.dim()} <= in_dim < {arg.dim()}.""" - ) - if in_dim is not None and in_dim < 0: - flat_in_dims[i] = in_dim % arg.dim() - - return ( - _validate_and_get_batch_size(flat_in_dims, flat_args), - flat_in_dims, - flat_args, - args_spec, - ) - - functorch._src.vmap._process_batched_inputs = _process_batched_inputs - - def _create_batched_inputs(flat_in_dims, flat_args, vmap_level: int, args_spec): - # See NOTE [Ignored _remove_batch_dim, _add_batch_dim] - # If tensordict, we remove the dim at batch_size[in_dim] such that the TensorDict can accept - # the batched tensors. This will be added in _unwrap_batched - batched_inputs = [ - arg - if in_dim is None - else arg.apply( - lambda _arg, in_dim=in_dim: _add_batch_dim(_arg, in_dim, vmap_level), - batch_size=[b for i, b in enumerate(arg.batch_size) if i != in_dim], - ) - if isinstance(arg, TensorDictBase) - else _add_batch_dim(arg, in_dim, vmap_level) - for in_dim, arg in zip(flat_in_dims, flat_args) - ] - return tree_unflatten(batched_inputs, args_spec) - - functorch._src.vmap._create_batched_inputs = _create_batched_inputs - - def _unwrap_batched( - batched_outputs, out_dims, vmap_level: int, batch_size: int, func - ): - flat_batched_outputs, output_spec = tree_flatten(batched_outputs) - - for out in flat_batched_outputs: - # Change here: - if isinstance(out, (TensorDictBase, torch.Tensor)): - continue - raise ValueError( - f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return " - f"Tensors, got type {type(out)} as a return." - ) - - def incompatible_error(): - raise ValueError( - f"vmap({_get_name(func)}, ..., out_dims={out_dims})(): " - f"out_dims is not compatible with the structure of `outputs`. " - f"out_dims has structure {tree_flatten(out_dims)[1]} but outputs " - f"has structure {output_spec}." 
- ) - - # Here: - if isinstance(batched_outputs, (TensorDictBase, torch.Tensor)): - # Some weird edge case requires us to spell out the following - # see test_out_dims_edge_case - if isinstance(out_dims, int): - flat_out_dims = [out_dims] - elif isinstance(out_dims, tuple) and len(out_dims) == 1: - flat_out_dims = out_dims - out_dims = out_dims[0] - else: - incompatible_error() - else: - flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec) - if flat_out_dims is None: - incompatible_error() - - flat_outputs = [] - for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims): - if not isinstance(batched_output, TensorDictBase): - out = _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim) - else: - out = batched_output.apply( - lambda x, out_dim=out_dim: _remove_batch_dim( - x, vmap_level, batch_size, out_dim - ), - batch_size=[batch_size, *batched_output.batch_size], - ) - flat_outputs.append(out) - return tree_unflatten(flat_outputs, output_spec) - - functorch._src.vmap._unwrap_batched = _unwrap_batched - -# Tensordict-compatible Functional modules - - -class FunctionalModule(nn.Module): - """This is the callable object returned by :func:`make_functional`.""" - - def __init__(self, stateless_model): - super(FunctionalModule, self).__init__() - self.stateless_model = stateless_model - - @staticmethod - def _create_from(model, disable_autograd_tracking=False): - # TODO: We don't need to copy the model to create a stateless copy - model_copy = deepcopy(model) - param_tensordict = extract_weights(model_copy) - if disable_autograd_tracking: - param_tensordict.apply(lambda x: x.requires_grad_(False), inplace=True) - return FunctionalModule(model_copy), param_tensordict - - def forward(self, params, *args, **kwargs): - # Temporarily load the state back onto self.stateless_model - old_state = _swap_state( - self.stateless_model, params, return_old_tensordict=_RESET_OLD_TENSORDICT - ) - try: - return self.stateless_model(*args, **kwargs) - finally: - # Remove the loaded state on self.stateless_model - if _RESET_OLD_TENSORDICT: - _swap_state(self.stateless_model, old_state) - - -class FunctionalModuleWithBuffers(nn.Module): - """This is the callable object returned by :func:`make_functional`.""" - - def __init__(self, stateless_model): - super(FunctionalModuleWithBuffers, self).__init__() - self.stateless_model = stateless_model - - @staticmethod - def _create_from(model, disable_autograd_tracking=False): - # TODO: We don't need to copy the model to create a stateless copy - model_copy = deepcopy(model) - param_tensordict = extract_weights(model_copy) - buffers = extract_buffers(model_copy) - if buffers is None: - buffers = TensorDict( - {}, param_tensordict.batch_size, device=param_tensordict.device - ) - if disable_autograd_tracking: - param_tensordict.apply(lambda x: x.requires_grad_(False), inplace=True) - return FunctionalModuleWithBuffers(model_copy), param_tensordict, buffers - - def forward(self, params, buffers, *args, **kwargs): - # Temporarily load the state back onto self.stateless_model - old_state = _swap_state( - self.stateless_model, params, return_old_tensordict=_RESET_OLD_TENSORDICT - ) - old_state_buffers = _swap_state( - self.stateless_model, buffers, return_old_tensordict=_RESET_OLD_TENSORDICT - ) - - try: - return self.stateless_model(*args, **kwargs) - finally: - # Remove the loaded state on self.stateless_model - if _RESET_OLD_TENSORDICT: - _swap_state(self.stateless_model, old_state) - _swap_state(self.stateless_model, old_state_buffers) - 
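These two classes now live in `tensordict.nn.functional_modules` (see the import change in `torchrl/modules/tensordict_module/common.py` below). A minimal sketch of the call pattern, assuming the tensordict version keeps the `_create_from` constructor defined in the code removed here:

```python
import torch
from torch import nn
from tensordict.nn.functional_modules import FunctionalModule

# _create_from extracts the weights into a TensorDict and returns a
# stateless copy of the module; parameters are passed explicitly per call.
fmodule, params = FunctionalModule._create_from(nn.Linear(3, 4))
out = fmodule(params, torch.randn(10, 3))
print(out.shape)  # torch.Size([10, 4])
```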
- -# Some utils for these - - -def extract_weights(model: nn.Module): - """Extracts the weights of a model in a tensordict.""" - tensordict = TensorDict({}, []) - for name, param in list(model.named_parameters(recurse=False)): - setattr(model, name, None) - tensordict[name] = param - for name, module in model.named_children(): - module_tensordict = extract_weights(module) - if module_tensordict is not None: - tensordict[name] = module_tensordict - if len(tensordict.keys()): - return tensordict - else: - return None - - -def extract_buffers(model: nn.Module): - """Extracts the buffers of a model in a tensordict.""" - tensordict = TensorDict({}, []) - for name, param in list(model.named_buffers(recurse=False)): - setattr(model, name, None) - tensordict[name] = param - for name, module in model.named_children(): - module_tensordict = extract_buffers(module) - if module_tensordict is not None: - tensordict[name] = module_tensordict - if len(tensordict.keys()): - return tensordict - else: - return None - - -def _swap_state(model, tensordict, return_old_tensordict=False): - # if return_old_tensordict: - # old_tensordict = tensordict.clone(recurse=False) - # old_tensordict.batch_size = [] - - if return_old_tensordict: - old_tensordict = TensorDict({}, [], device=tensordict.device) - - for key, value in list(tensordict.items()): - if isinstance(value, TensorDictBase): - _swap_state(getattr(model, key), value) - else: - if return_old_tensordict: - old_attr = getattr(model, key) - if old_attr is None: - old_attr = torch.tensor([]).view(*value.shape, 0) - delattr(model, key) - setattr(model, key, value) - if return_old_tensordict: - old_tensordict.set(key, old_attr) - if return_old_tensordict: - return old_tensordict diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index a18454fbbdf..7de0a7c9caa 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -264,18 +264,18 @@ class gSDEModule(nn.Module): Examples: >>> from tensordict import TensorDict - >>> from torchrl.modules import TensorDictModule, TensorDictSequential, ProbabilisticActor, TanhNormal + >>> from torchrl.modules import SafeModule, SafeSequential, ProbabilisticActor, TanhNormal >>> batch, state_dim, action_dim = 3, 7, 5 >>> model = nn.Linear(state_dim, action_dim) - >>> deterministic_policy = TensorDictModule(model, in_keys=["obs"], out_keys=["action"]) - >>> stochatstic_part = TensorDictModule( + >>> deterministic_policy = SafeModule(model, in_keys=["obs"], out_keys=["action"]) + >>> stochatstic_part = SafeModule( ... gSDEModule(action_dim, state_dim), ... in_keys=["action", "obs", "_eps_gSDE"], ... out_keys=["loc", "scale", "action", "_eps_gSDE"]) >>> stochatstic_part = ProbabilisticActor(stochatstic_part, ... dist_in_keys=["loc", "scale"], ... 
distribution_class=TanhNormal) - >>> stochatstic_policy = TensorDictSequential(deterministic_policy, stochatstic_part) + >>> stochatstic_policy = SafeSequential(deterministic_policy, stochatstic_part) >>> tensordict = TensorDict({'obs': torch.randn(state_dim), '_epx_gSDE': torch.zeros(1)}, []) >>> _ = stochatstic_policy(tensordict) >>> print(tensordict) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 9b94ed4912b..064565ccc79 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -11,8 +11,8 @@ from torchrl.envs.utils import step_mdp from torchrl.modules.distributions import NormalParamWrapper from torchrl.modules.models.models import MLP -from torchrl.modules.tensordict_module.common import TensorDictModule -from torchrl.modules.tensordict_module.sequence import TensorDictSequential +from torchrl.modules.tensordict_module.common import SafeModule +from torchrl.modules.tensordict_module.sequence import SafeSequential class DreamerActor(nn.Module): @@ -151,15 +151,15 @@ class RSSMRollout(nn.Module): Reference: https://arxiv.org/abs/1811.04551 Args: - rssm_prior (TensorDictModule): Prior network. - rssm_posterior (TensorDictModule): Posterior network. + rssm_prior (SafeModule): Prior network. + rssm_posterior (SafeModule): Posterior network. """ - def __init__(self, rssm_prior: TensorDictModule, rssm_posterior: TensorDictModule): + def __init__(self, rssm_prior: SafeModule, rssm_posterior: SafeModule): super().__init__() - _module = TensorDictSequential(rssm_prior, rssm_posterior) + _module = SafeSequential(rssm_prior, rssm_posterior) self.in_keys = _module.in_keys self.out_keys = _module.out_keys self.rssm_prior = rssm_prior diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index d11c9ab12fd..dd69c8b4e16 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -47,7 +47,7 @@ class CEMPlanner(MPCPlannerBase): >>> from tensordict import TensorDict >>> from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec >>> from torchrl.envs.model_based import ModelBasedEnvBase - >>> from torchrl.modules import TensorDictModule + >>> from torchrl.modules import SafeModule >>> class MyMBEnv(ModelBasedEnvBase): ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) @@ -71,12 +71,12 @@ class CEMPlanner(MPCPlannerBase): >>> from torchrl.modules import MLP, WorldModelWrapper >>> import torch.nn as nn >>> world_model = WorldModelWrapper( - ... TensorDictModule( + ... SafeModule( ... MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0), ... in_keys=["hidden_observation", "action"], ... out_keys=["hidden_observation"], ... ), - ... TensorDictModule( + ... SafeModule( ... nn.Linear(4, 1), ... in_keys=["hidden_observation"], ... out_keys=["reward"], diff --git a/torchrl/modules/planners/common.py b/torchrl/modules/planners/common.py index 63ecba7991c..a9d1e4ca942 100644 --- a/torchrl/modules/planners/common.py +++ b/torchrl/modules/planners/common.py @@ -9,13 +9,13 @@ from tensordict.tensordict import TensorDictBase from torchrl.envs import EnvBase -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule -class MPCPlannerBase(TensorDictModule, metaclass=abc.ABCMeta): +class MPCPlannerBase(SafeModule, metaclass=abc.ABCMeta): """MPCPlannerBase abstract Module. 
- This class inherits from :obj:`TensorDictModule`. Provided a :obj:`TensorDict`, this module will perform a Model Predictive Control (MPC) planning step. + This class inherits from :obj:`SafeModule`. Provided a :obj:`TensorDict`, this module will perform a Model Predictive Control (MPC) planning step. At the end of the planning step, the :obj:`MPCPlanner` will return a proposed action. Args: diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 06cc365104f..a94b8eeb12b 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -13,12 +13,12 @@ QValueActor, ValueOperator, ) -from .common import TensorDictModule, TensorDictModuleWrapper +from .common import SafeModule from .exploration import ( AdditiveGaussianWrapper, EGreedyWrapper, OrnsteinUhlenbeckProcessWrapper, ) -from .probabilistic import ProbabilisticTensorDictModule -from .sequence import TensorDictSequential +from .probabilistic import SafeProbabilisticModule +from .sequence import SafeSequential from .world_models import WorldModelWrapper diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 1092064eab9..dba80fc67a5 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -6,21 +6,17 @@ from typing import Optional, Sequence, Tuple, Union import torch +from tensordict.nn import TensorDictModuleWrapper from torch import nn from torchrl.data import CompositeSpec, TensorSpec, UnboundedContinuousTensorSpec from torchrl.modules.models.models import DistributionalDQNnet -from torchrl.modules.tensordict_module.common import ( - TensorDictModule, - TensorDictModuleWrapper, -) -from torchrl.modules.tensordict_module.probabilistic import ( - ProbabilisticTensorDictModule, -) -from torchrl.modules.tensordict_module.sequence import TensorDictSequential +from torchrl.modules.tensordict_module.common import SafeModule +from torchrl.modules.tensordict_module.probabilistic import SafeProbabilisticModule +from torchrl.modules.tensordict_module.sequence import SafeSequential -class Actor(TensorDictModule): +class Actor(SafeModule): """General class for deterministic actors in RL. The Actor class comes with default values for the out_keys (["action"]) @@ -72,7 +68,7 @@ def __init__( ) -class ProbabilisticActor(ProbabilisticTensorDictModule): +class ProbabilisticActor(SafeProbabilisticModule): """General class for probabilistic actors in RL. The Actor class comes with default values for the out_keys (["action"]) @@ -91,7 +87,7 @@ class ProbabilisticActor(ProbabilisticTensorDictModule): >>> module = NormalParamWrapper(torch.nn.Linear(4, 8)) >>> fmodule, params, buffers = functorch.make_functional_with_buffers( ... module) - >>> tensordict_module = TensorDictModule(fmodule, in_keys=["observation"], out_keys=["loc", "scale"]) + >>> tensordict_module = SafeModule(fmodule, in_keys=["observation"], out_keys=["loc", "scale"]) >>> td_module = ProbabilisticActor( ... module=tensordict_module, ... spec=action_spec, @@ -114,7 +110,7 @@ class ProbabilisticActor(ProbabilisticTensorDictModule): def __init__( self, - module: TensorDictModule, + module: SafeModule, dist_in_keys: Union[str, Sequence[str]], sample_out_key: Optional[Sequence[str]] = None, spec: Optional[TensorSpec] = None, @@ -138,7 +134,7 @@ def __init__( ) -class ValueOperator(TensorDictModule): +class ValueOperator(SafeModule): """General class for value functions in RL. 
The ValueOperator class comes with default values for the in_keys and @@ -531,7 +527,7 @@ def __init__( ) -class ActorValueOperator(TensorDictSequential): +class ActorValueOperator(SafeSequential): """Actor-value operator. This class wraps together an actor and a value model that share a common observation embedding network: @@ -563,9 +559,9 @@ class ActorValueOperator(TensorDictSequential): will both return a stand-alone TDModule with the dedicated functionality. Args: - common_operator (TensorDictModule): a common operator that reads observations and produces a hidden variable - policy_operator (TensorDictModule): a policy operator that reads the hidden variable and returns an action - value_operator (TensorDictModule): a value operator, that reads the hidden variable and returns a value + common_operator (SafeModule): a common operator that reads observations and produces a hidden variable + policy_operator (SafeModule): a policy operator that reads the hidden variable and returns an action + value_operator (SafeModule): a value operator, that reads the hidden variable and returns a value Examples: >>> import torch @@ -575,14 +571,14 @@ class ActorValueOperator(TensorDictSequential): >>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamWrapper >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) >>> module_hidden = torch.nn.Linear(4, 4) - >>> td_module_hidden = TensorDictModule( + >>> td_module_hidden = SafeModule( ... module=module_hidden, ... spec=spec_hidden, ... in_keys=["observation"], ... out_keys=["hidden"], ... ) >>> spec_action = NdBoundedTensorSpec(-1, 1, torch.Size([8])) - >>> module_action = TensorDictModule( + >>> module_action = SafeModule( ... NormalParamWrapper(torch.nn.Linear(4, 8)), ... in_keys=["hidden"], ... out_keys=["loc", "scale"], @@ -638,9 +634,9 @@ class ActorValueOperator(TensorDictSequential): def __init__( self, - common_operator: TensorDictModule, - policy_operator: TensorDictModule, - value_operator: TensorDictModule, + common_operator: SafeModule, + policy_operator: SafeModule, + value_operator: SafeModule, ): super().__init__( common_operator, @@ -648,13 +644,13 @@ def __init__( value_operator, ) - def get_policy_operator(self) -> TensorDictSequential: + def get_policy_operator(self) -> SafeSequential: """Returns a stand-alone policy operator that maps an observation to an action.""" - return TensorDictSequential(self.module[0], self.module[1]) + return SafeSequential(self.module[0], self.module[1]) - def get_value_operator(self) -> TensorDictSequential: + def get_value_operator(self) -> SafeSequential: """Returns a stand-alone value network operator that maps an observation to a value estimate.""" - return TensorDictSequential(self.module[0], self.module[2]) + return SafeSequential(self.module[0], self.module[2]) class ActorCriticOperator(ActorValueOperator): @@ -689,9 +685,9 @@ class ActorCriticOperator(ActorValueOperator): parent object, as the value is computed based on the policy output. 
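As a usage sketch for the renamed return types: `get_policy_operator` and `get_value_operator` now return `SafeSequential` subgraphs that can be run independently. The module sizes below are arbitrary, and the `ProbabilisticActor`/`ValueOperator` defaults (out_keys `["action"]` and `["state_value"]`) are assumed from their docstrings:

```python
import torch
from tensordict import TensorDict
from torchrl.modules import (
    ActorValueOperator,
    NormalParamWrapper,
    ProbabilisticActor,
    SafeModule,
    TanhNormal,
    ValueOperator,
)

common = SafeModule(torch.nn.Linear(4, 4), in_keys=["observation"], out_keys=["hidden"])
actor = ProbabilisticActor(
    SafeModule(
        NormalParamWrapper(torch.nn.Linear(4, 8)),
        in_keys=["hidden"],
        out_keys=["loc", "scale"],
    ),
    dist_in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
)
value = ValueOperator(torch.nn.Linear(4, 1), in_keys=["hidden"])
op = ActorValueOperator(common, actor, value)

policy_only = op.get_policy_operator()  # SafeSequential(common, actor)
td = policy_only(TensorDict({"observation": torch.randn(3, 4)}, [3]))
assert "action" in td.keys()  # the value head was not executed
```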
Args: - common_operator (TensorDictModule): a common operator that reads observations and produces a hidden variable - policy_operator (TensorDictModule): a policy operator that reads the hidden variable and returns an action - value_operator (TensorDictModule): a value operator, that reads the hidden variable and returns a value + common_operator (SafeModule): a common operator that reads observations and produces a hidden variable + policy_operator (SafeModule): a policy operator that reads the hidden variable and returns an action + value_operator (SafeModule): a value operator, that reads the hidden variable and returns a value Examples: >>> import torch @@ -701,7 +697,7 @@ class ActorCriticOperator(ActorValueOperator): >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticOperator, NormalParamWrapper, MLP >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) >>> module_hidden = torch.nn.Linear(4, 4) - >>> td_module_hidden = TensorDictModule( + >>> td_module_hidden = SafeModule( ... module=module_hidden, ... spec=spec_hidden, ... in_keys=["observation"], @@ -709,7 +705,7 @@ class ActorCriticOperator(ActorValueOperator): ... ) >>> spec_action = NdBoundedTensorSpec(-1, 1, torch.Size([8])) >>> module_action = NormalParamWrapper(torch.nn.Linear(4, 8)) - >>> module_action = TensorDictModule(module_action, in_keys=["hidden"], out_keys=["loc", "scale"]) + >>> module_action = SafeModule(module_action, in_keys=["hidden"], out_keys=["loc", "scale"]) >>> td_module_action = ProbabilisticActor( ... module=module_action, ... spec=spec_action, @@ -793,7 +789,7 @@ def get_value_operator(self) -> TensorDictModuleWrapper: ) -class ActorCriticWrapper(TensorDictSequential): +class ActorCriticWrapper(SafeSequential): """Actor-value operator without common module. This class wraps together an actor and a value model that do not share a common observation embedding network: @@ -820,8 +816,8 @@ class ActorCriticWrapper(TensorDictSequential): will both return a stand-alone TDModule with the dedicated functionality. 
Args: - policy_operator (TensorDictModule): a policy operator that reads the hidden variable and returns an action - value_operator (TensorDictModule): a value operator, that reads the hidden variable and returns a value + policy_operator (SafeModule): a policy operator that reads the hidden variable and returns an action + value_operator (SafeModule): a value operator, that reads the hidden variable and returns a value Examples: >>> import torch @@ -877,18 +873,18 @@ class ActorCriticWrapper(TensorDictSequential): def __init__( self, - policy_operator: TensorDictModule, - value_operator: TensorDictModule, + policy_operator: SafeModule, + value_operator: SafeModule, ): super().__init__( policy_operator, value_operator, ) - def get_policy_operator(self) -> TensorDictSequential: + def get_policy_operator(self) -> SafeSequential: """Returns a stand-alone policy operator that maps an observation to an action.""" return self.module[0] - def get_value_operator(self) -> TensorDictSequential: + def get_value_operator(self) -> SafeSequential: """Returns a stand-alone value network operator that maps an observation to a value estimate.""" return self.module[1] diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 82294617a67..c092197eb7c 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -7,20 +7,15 @@ import inspect import warnings -from copy import deepcopy -from textwrap import indent -from typing import Any, Iterable, List, Optional, Sequence, Type, Union +from typing import Iterable, Optional, Type, Union import torch from torchrl.data.utils import DEVICE_TYPING -from torchrl.modules import functional_modules _has_functorch = False try: - import functorch - from functorch import FunctionalModule, FunctionalModuleWithBuffers, vmap - from functorch._src.make_functional import _swap_state + from functorch import FunctionalModule, FunctionalModuleWithBuffers _has_functorch = True except ImportError: @@ -29,19 +24,16 @@ "functional programming should work, but functionality and performance " "may be affected. Consider installing functorch and/or upgrating pytorch." ) - from torchrl.modules.functional_modules import ( + from tensordict.nn.functional_modules import ( FunctionalModule, FunctionalModuleWithBuffers, ) +from tensordict.nn import TensorDictModule from tensordict.tensordict import TensorDictBase -from torch import nn, Tensor +from torch import nn from torchrl.data import CompositeSpec, TensorSpec -from torchrl.modules.functional_modules import ( - FunctionalModule as rlFunctionalModule, - FunctionalModuleWithBuffers as rlFunctionalModuleWithBuffers, -) def _check_all_str(list_of_str, first_level=True): @@ -62,7 +54,7 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out): spec = module.spec if len(module.out_keys) > 1 and not isinstance(spec, CompositeSpec): raise RuntimeError( - "safe TensorDictModules with multiple out_keys require a CompositeSpec with matching keys. Got " + "safe SafeModules with multiple out_keys require a CompositeSpec with matching keys. Got " f"keys {module.out_keys}." ) elif not isinstance(spec, CompositeSpec): @@ -88,8 +80,8 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out): ) -class TensorDictModule(nn.Module): - """A TensorDictModule, is a python wrapper around a :obj:`nn.Module` that reads and writes to a TensorDict. 
+class SafeModule(TensorDictModule):
+    """A :obj:`SafeModule` is a :obj:`tensordict.nn.TensorDictModule` subclass that accepts a :obj:`TensorSpec` as argument to control the output domain.
 
     Args:
         module (nn.Module): a nn.Module used to map the input to the output parameter space. Can be a functional
@@ -106,8 +98,8 @@ class TensorDictModule(nn.Module):
         If this value is out of bounds, it is projected back onto the
         desired space using the :obj:`TensorSpec.project`
         method. Default is :obj:`False`.
 
-    Embedding a neural network in a TensorDictModule only requires to specify the input and output keys. The domain spec can
-    be passed along if needed. TensorDictModule support functional and regular :obj:`nn.Module` objects. In the functional
+    Embedding a neural network in a SafeModule only requires specifying the input and output keys. The domain spec can
+    be passed along if needed. SafeModule supports functional and regular :obj:`nn.Module` objects. In the functional
     case, the 'params' (and 'buffers') keyword argument must be specified:
 
     Examples:
@@ -115,12 +107,12 @@ class TensorDictModule(nn.Module):
         >>> import torch
         >>> from tensordict import TensorDict
         >>> from torchrl.data import NdUnboundedContinuousTensorSpec
-        >>> from torchrl.modules import TensorDictModule
+        >>> from torchrl.modules import SafeModule
         >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
         >>> spec = NdUnboundedContinuousTensorSpec(8)
         >>> module = torch.nn.GRUCell(4, 8)
         >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module)
-        >>> td_fmodule = TensorDictModule(
+        >>> td_fmodule = SafeModule(
         ...    module=fmodule,
         ...    spec=spec,
         ...    in_keys=["input", "hidden"],
@@ -137,7 +129,7 @@ class TensorDictModule(nn.Module):
             device=cpu)
 
     In the stateful case:
-        >>> td_module = TensorDictModule(
+        >>> td_module = SafeModule(
         ...    module=module,
         ...    spec=spec,
         ...    in_keys=["input", "hidden"],
@@ -173,36 +165,21 @@ class TensorDictModule(nn.Module):
     def __init__(
         self,
         module: Union[
-            FunctionalModule, FunctionalModuleWithBuffers, TensorDictModule, nn.Module
+            FunctionalModule, FunctionalModuleWithBuffers, SafeModule, nn.Module
         ],
         in_keys: Iterable[str],
         out_keys: Iterable[str],
         spec: Optional[TensorSpec] = None,
         safe: bool = False,
     ):
-
-        super().__init__()
-
-        if not out_keys:
-            raise RuntimeError(f"out_keys were not passed to {self.__class__.__name__}")
-        if not in_keys:
-            raise RuntimeError(f"in_keys were not passed to {self.__class__.__name__}")
-        self.out_keys = out_keys
-        _check_all_str(self.out_keys)
-        self.in_keys = in_keys
-        _check_all_str(self.in_keys)
-
-        if "_" in in_keys:
-            warnings.warn(
-                'key "_" is for ignoring output, it should not be used in input keys'
-            )
+        super().__init__(module, in_keys, out_keys)
         if spec is not None and not isinstance(spec, TensorSpec):
             raise TypeError("spec must be a TensorSpec subclass")
         elif spec is not None and not isinstance(spec, CompositeSpec):
             if len(self.out_keys) > 1:
                 raise RuntimeError(
-                    f"got more than one out_key for the TensorDictModule: {self.out_keys},\nbut only one spec. "
+                    f"got more than one out_key for the SafeModule: {self.out_keys},\nbut only one spec. "
                     "Consider using a CompositeSpec object or no spec at all."
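A short sketch of the `safe` behaviour described in the docstring above, assuming `NdBoundedTensorSpec` as used in the other docstring examples: with `safe=True`, out-of-bounds outputs are projected back onto the spec's domain via `TensorSpec.project`.

```python
import torch
from tensordict import TensorDict
from torchrl.data import NdBoundedTensorSpec
from torchrl.modules import SafeModule

# The forward hook registered by safe=True projects values that fall
# outside the spec back into [-1, 1].
module = SafeModule(
    torch.nn.Linear(3, 4),
    in_keys=["obs"],
    out_keys=["action"],
    spec=NdBoundedTensorSpec(-1, 1, torch.Size([4])),
    safe=True,
)
td = module(TensorDict({"obs": 10 * torch.randn(2, 3)}, [2]))
assert (td["action"].abs() <= 1).all()
```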
) spec = CompositeSpec(**{self.out_keys[0]: spec}) @@ -231,28 +208,11 @@ def __init__( and all(_spec is None for _spec in spec.values()) ): raise RuntimeError( - "`TensorDictModule(spec=None, safe=True)` is not a valid configuration as the tensor " + "`SafeModule(spec=None, safe=True)` is not a valid configuration as the tensor " "specs are not specified" ) self.register_forward_hook(_forward_hook_safe_action) - self.module = module - - @property - def is_functional(self): - if not _has_functorch: - return isinstance( - self.module, - ( - functional_modules.FunctionalModule, - functional_modules.FunctionalModuleWithBuffers, - ), - ) - return isinstance( - self.module, - (functorch.FunctionalModule, functorch.FunctionalModuleWithBuffers), - ) - @property def spec(self) -> CompositeSpec: return self._spec @@ -265,144 +225,6 @@ def spec(self, spec: CompositeSpec) -> None: ) self._spec = spec - def _write_to_tensordict( - self, - tensordict: TensorDictBase, - tensors: List, - tensordict_out: Optional[TensorDictBase] = None, - out_keys: Optional[Iterable[str]] = None, - vmap: Optional[int] = None, - ) -> TensorDictBase: - - if out_keys is None: - out_keys = self.out_keys - if ( - (tensordict_out is None) - and vmap - and (isinstance(vmap, bool) or vmap[-1] is None) - ): - dim = tensors[0].shape[0] - tensordict_out = tensordict.expand(dim, *tensordict.batch_size).contiguous() - elif tensordict_out is None: - tensordict_out = tensordict - for _out_key, _tensor in zip(out_keys, tensors): - if _out_key != "_": - tensordict_out.set(_out_key, _tensor) - return tensordict_out - - def _make_vmap(self, buffers, kwargs, n_input): - if "vmap" in kwargs and kwargs["vmap"]: - if not isinstance(kwargs["vmap"], (tuple, bool)): - raise RuntimeError( - "vmap argument must be a boolean or a tuple of dim expensions." - ) - # if vmap is a tuple, we make sure the number of inputs after params and buffers match - if isinstance(kwargs["vmap"], (tuple, list)): - err_msg = f"the vmap argument had {len(kwargs['vmap'])} elements, but the module has {len(self.in_keys)} inputs" - if isinstance( - self.module, - (FunctionalModuleWithBuffers, rlFunctionalModuleWithBuffers), - ): - if len(kwargs["vmap"]) == 3: - _vmap = ( - *kwargs["vmap"][:2], - *[kwargs["vmap"][2]] * len(self.in_keys), - ) - elif len(kwargs["vmap"]) == 2 + len(self.in_keys): - _vmap = kwargs["vmap"] - else: - raise RuntimeError(err_msg) - elif isinstance(self.module, (FunctionalModule, rlFunctionalModule)): - if len(kwargs["vmap"]) == 2: - _vmap = ( - *kwargs["vmap"][:1], - *[kwargs["vmap"][1]] * len(self.in_keys), - ) - elif len(kwargs["vmap"]) == 1 + len(self.in_keys): - _vmap = kwargs["vmap"] - else: - raise RuntimeError(err_msg) - else: - raise TypeError( - f"vmap not compatible with modules of type {type(self.module)}" - ) - else: - _vmap = ( - (0, 0, *(None,) * n_input) - if buffers is not None - else (0, *(None,) * n_input) - ) - return _vmap - - def _call_module( - self, - tensors: Sequence[Tensor], - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ) -> Union[Tensor, Sequence[Tensor]]: - err_msg = "Did not find the {0} keyword argument to be used with the functional module. Check it was passed to the TensorDictModule method." 
- if isinstance( - self.module, - ( - FunctionalModule, - FunctionalModuleWithBuffers, - rlFunctionalModule, - rlFunctionalModuleWithBuffers, - ), - ): - _vmap = self._make_vmap(buffers, kwargs, len(tensors)) - if _vmap: - module = vmap(self.module, _vmap) - else: - module = self.module - - if isinstance(self.module, (FunctionalModule, rlFunctionalModule)): - if params is None: - raise KeyError(err_msg.format("params")) - kwargs_pruned = { - key: item for key, item in kwargs.items() if key not in ("vmap") - } - out = module(params, *tensors, **kwargs_pruned) - return out - - elif isinstance( - self.module, (FunctionalModuleWithBuffers, rlFunctionalModuleWithBuffers) - ): - if params is None: - raise KeyError(err_msg.format("params")) - if buffers is None: - raise KeyError(err_msg.format("buffers")) - - kwargs_pruned = { - key: item for key, item in kwargs.items() if key not in ("vmap") - } - out = module(params, buffers, *tensors, **kwargs_pruned) - return out - else: - out = self.module(*tensors, **kwargs) - return out - - def forward( - self, - tensordict: TensorDictBase, - tensordict_out: Optional[TensorDictBase] = None, - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ) -> TensorDictBase: - tensors = tuple(tensordict.get(in_key, None) for in_key in self.in_keys) - tensors = self._call_module(tensors, params=params, buffers=buffers, **kwargs) - if not isinstance(tensors, tuple): - tensors = (tensors,) - tensordict_out = self._write_to_tensordict( - tensordict, - tensors, - tensordict_out, - vmap=kwargs.get("vmap", False), - ) - return tensordict_out - def random(self, tensordict: TensorDictBase) -> TensorDictBase: """Samples a random element in the target space, irrespective of any input. @@ -420,210 +242,18 @@ def random(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict def random_sample(self, tensordict: TensorDictBase) -> TensorDictBase: - """See :obj:`TensorDictModule.random(...)`.""" + """See :obj:`SafeModule.random(...)`.""" return self.random(tensordict) - @property - def device(self): - for p in self.parameters(): - return p.device - return torch.device("cpu") - - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> TensorDictModule: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> SafeModule: if hasattr(self, "spec") and self.spec is not None: self.spec = self.spec.to(dest) out = super().to(dest) return out - def __repr__(self) -> str: - fields = indent( - f"module={self.module}, \n" - f"device={self.device}, \n" - f"in_keys={self.in_keys}, \n" - f"out_keys={self.out_keys}", - 4 * " ", - ) - - return f"{self.__class__.__name__}(\n{fields})" - - def make_functional_with_buffers(self, clone: bool = True, native: bool = False): - """Transforms a stateful module in a functional module and returns its parameters and buffers. - - Unlike functorch.make_functional_with_buffers, this method supports lazy modules. - - Args: - clone (bool, optional): if True, a clone of the module is created before it is returned. - This is useful as it prevents the original module to be scraped off of its - parameters and buffers. - Defaults to True - native (bool, optional): if True, TorchRL's functional modules will be used. 
- Defaults to True - - Returns: - A tuple of parameter and buffer tuples - - Examples: - >>> from tensordict import TensorDict - >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> lazy_module = nn.LazyLinear(4) - >>> spec = NdUnboundedContinuousTensorSpec(18) - >>> td_module = TensorDictModule(lazy_module, spec, ["some_input"], - ... ["some_output"]) - >>> _, (params, buffers) = td_module.make_functional_with_buffers() - >>> print(params[0].shape) # the lazy module has been initialized - torch.Size([4, 18]) - >>> print(td_module( - ... TensorDict({'some_input': torch.randn(18)}, batch_size=[]), - ... params=params, - ... buffers=buffers)) - TensorDict( - fields={ - some_input: Tensor(torch.Size([18]), dtype=torch.float32), - some_output: Tensor(torch.Size([4]), dtype=torch.float32)}, - batch_size=torch.Size([]), - device=cpu, - is_shared=False) - - """ - native = native or not _has_functorch - if clone: - self_copy = deepcopy(self) - else: - self_copy = self - - if isinstance( - self_copy.module, - ( - TensorDictModule, - FunctionalModule, - FunctionalModuleWithBuffers, - rlFunctionalModule, - rlFunctionalModuleWithBuffers, - ), - ): - raise RuntimeError( - "TensorDictModule.make_functional_with_buffers requires the " - "module to be a regular nn.Module. " - f"Found type {type(self_copy.module)}" - ) - - # check if there is a non-initialized lazy module - for m in self_copy.module.modules(): - if hasattr(m, "has_uninitialized_params") and m.has_uninitialized_params(): - pseudo_input = self_copy.spec.rand() - self_copy.module(pseudo_input) - break - - module = self_copy.module - if native: - fmodule, params, buffers = rlFunctionalModuleWithBuffers._create_from( - module - ) - else: - fmodule, params, buffers = functorch.make_functional_with_buffers(module) - self_copy.module = fmodule - - # Erase meta params - for _ in fmodule.parameters(): - none_state = [None for _ in params + buffers] - if hasattr(fmodule, "all_names_map"): - # functorch >= 0.2.0 - _swap_state(fmodule.stateless_model, fmodule.all_names_map, none_state) - else: - # functorch < 0.2.0 - _swap_state(fmodule.stateless_model, fmodule.split_names, none_state) - - break - - return self_copy, (params, buffers) - - @property - def num_params(self): - if _has_functorch and isinstance( - self.module, - (functorch.FunctionalModule, functorch.FunctionalModuleWithBuffers), - ): - return len(self.module.param_names) - else: - return 0 - - @property - def num_buffers(self): - if _has_functorch and isinstance( - self.module, (functorch.FunctionalModuleWithBuffers,) - ): - return len(self.module.buffer_names) - else: - return 0 - - -class TensorDictModuleWrapper(nn.Module): - """Wrapper calss for TensorDictModule objects. - - Once created, a TensorDictModuleWrapper will behave exactly as the TensorDictModule it contains except for the methods that are - overwritten. - - Args: - td_module (TensorDictModule): operator to be wrapped. - - Examples: - >>> # This class can be used for exploration wrappers - >>> import functorch - >>> import torch - >>> from tensordict import TensorDict - >>> from tensordict.utils import expand_as_right - >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> from torchrl.modules import TensorDictModuleWrapper, TensorDictModule - >>> - >>> class EpsilonGreedyExploration(TensorDictModuleWrapper): - ... eps = 0.5 - ... def forward(self, tensordict, params, buffers): - ... rand_output_clone = self.random(tensordict.clone()) - ... 
det_output_clone = self.td_module(tensordict.clone(), params, buffers) - ... rand_output_idx = torch.rand(tensordict.shape, device=rand_output_clone.device) < self.eps - ... for key in self.out_keys: - ... _rand_output = rand_output_clone.get(key) - ... _det_output = det_output_clone.get(key) - ... rand_output_idx_expand = expand_as_right(rand_output_idx, _rand_output).to(_rand_output.dtype) - ... tensordict.set(key, - ... rand_output_idx_expand * _rand_output + (1-rand_output_idx_expand) * _det_output) - ... return tensordict - >>> - >>> td = TensorDict({"input": torch.zeros(10, 4)}, [10]) - >>> module = torch.nn.Linear(4, 4, bias=False) # should return a zero tensor if input is a zero tensor - >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) - >>> spec = NdUnboundedContinuousTensorSpec(4) - >>> tensordict_module = TensorDictModule(module=fmodule, spec=spec, in_keys=["input"], out_keys=["output"]) - >>> tensordict_module_wrapped = EpsilonGreedyExploration(tensordict_module) - >>> tensordict_module_wrapped(td, params=params, buffers=buffers) - >>> print(td.get("output")) - - """ - - def __init__(self, td_module: TensorDictModule): - super().__init__() - self.td_module = td_module - if len(self.td_module._forward_hooks): - for pre_hook in self.td_module._forward_hooks: - self.register_forward_hook(self.td_module._forward_hooks[pre_hook]) - - def __getattr__(self, name: str) -> Any: - try: - return super().__getattr__(name) - except AttributeError: - if name not in self.__dict__ and not name.startswith("__"): - return getattr(self._modules["td_module"], name) - else: - raise AttributeError( - f"attribute {name} not recognised in {type(self).__name__}" - ) - - def forward(self, *args, **kwargs): - return self.td_module.forward(*args, **kwargs) - -def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]): - """Returns `True` if a module can be used as a TensorDictModule, and False if it can't. +def is_tensordict_compatible(module: Union[SafeModule, nn.Module]): + """Returns `True` if a module can be used as a SafeModule, and False if it can't. If the signature is misleading an error is raised. @@ -661,21 +291,21 @@ def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]): """ sig = inspect.signature(module.forward) - if isinstance(module, TensorDictModule) or ( + if isinstance(module, SafeModule) or ( len(sig.parameters) == 1 and hasattr(module, "in_keys") and hasattr(module, "out_keys") ): - # if the module is a TensorDictModule or takes a single argument and defines + # if the module is a SafeModule or takes a single argument and defines # in_keys and out_keys then we assume it can already deal with TensorDict input # to forward and we return True return True elif not hasattr(module, "in_keys") and not hasattr(module, "out_keys"): - # if it's not a TensorDictModule, and in_keys and out_keys are not defined then + # if it's not a SafeModule, and in_keys and out_keys are not defined then # we assume no TensorDict compatibility and will try to wrap it. 
return False - # if in_keys or out_keys were defined but module is not a TensorDictModule or + # if in_keys or out_keys were defined but module is not a SafeModule or # accepts multiple arguments then it's likely the user is trying to do something # that will have undetermined behaviour, we raise an error raise TypeError( @@ -684,18 +314,16 @@ def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]): "should take a single argument of type TensorDict to module.forward and define " "both in_keys and out_keys. Alternatively, module.forward can accept " "arbitrarily many tensor inputs and leave in_keys and out_keys undefined and " - "TorchRL will attempt to automatically wrap the module with a TensorDictModule." + "TorchRL will attempt to automatically wrap the module with a SafeModule." ) def ensure_tensordict_compatible( - module: Union[ - FunctionalModule, FunctionalModuleWithBuffers, TensorDictModule, nn.Module - ], + module: Union[FunctionalModule, FunctionalModuleWithBuffers, SafeModule, nn.Module], in_keys: Optional[Iterable[str]] = None, out_keys: Optional[Iterable[str]] = None, safe: bool = False, - wrapper_type: Optional[Type] = TensorDictModule, + wrapper_type: Optional[Type] = SafeModule, ): """Checks and ensures an object with forward method is TensorDict compatible.""" if is_tensordict_compatible(module): @@ -715,7 +343,7 @@ def ensure_tensordict_compatible( if not isinstance(module, nn.Module): raise TypeError( "Argument to ensure_tensordict_compatible should be either " - "a TensorDictModule or an nn.Module" + "a SafeModule or an nn.Module" ) sig = inspect.signature(module.forward) @@ -723,10 +351,10 @@ def ensure_tensordict_compatible( raise TypeError( "Arguments to module.forward are incompatible with entries in " "env.observation_spec. If you want TorchRL to automatically " - "wrap your module with a TensorDictModule then the arguments " + "wrap your module with a SafeModule then the arguments " "to module must correspond one-to-one with entries in " "in_keys. For more complex behaviour and more control you can " - "consider writing your own TensorDictModule." + "consider writing your own SafeModule." ) # TODO: Check whether out_keys match (at least in number) if they are provided. diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index a07b13aeccc..fe3aac62df9 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -7,6 +7,7 @@ import numpy as np import torch +from tensordict.nn import TensorDictModuleWrapper from tensordict.tensordict import TensorDictBase from tensordict.utils import expand_as_right @@ -14,8 +15,7 @@ from torchrl.envs.utils import exploration_mode from torchrl.modules.tensordict_module.common import ( _forward_hook_safe_action, - TensorDictModule, - TensorDictModuleWrapper, + SafeModule, ) @@ -30,7 +30,7 @@ class EGreedyWrapper(TensorDictModuleWrapper): """Epsilon-Greedy PO wrapper. Args: - policy (TensorDictModule): a deterministic policy. + policy (SafeModule): a deterministic policy. eps_init (scalar, optional): initial epsilon value. default: 1.0 eps_end (scalar, optional): final epsilon value. @@ -71,7 +71,7 @@ class EGreedyWrapper(TensorDictModuleWrapper): def __init__( self, - policy: TensorDictModule, + policy: SafeModule, eps_init: float = 1.0, eps_end: float = 0.1, annealing_num_steps: int = 1000, @@ -139,7 +139,7 @@ class AdditiveGaussianWrapper(TensorDictModuleWrapper): """Additive Gaussian PO wrapper. 
     Args:
-        policy (TensorDictModule): a policy.
+        policy (SafeModule): a policy.
         sigma_init (scalar, optional): initial epsilon value.
             default: 1.0
         sigma_end (scalar, optional): final epsilon value.
@@ -162,7 +162,7 @@ def __init__(
         self,
-        policy: TensorDictModule,
+        policy: SafeModule,
         sigma_init: float = 1.0,
         sigma_end: float = 0.1,
         annealing_num_steps: int = 1000,
@@ -250,7 +250,7 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper):
         zeroing the tensordict at reset time.
 
     Args:
-        policy (TensorDictModule): a policy
+        policy (SafeModule): a policy
         eps_init (scalar): initial epsilon value, determining the amount of noise to be added.
             default: 1.0
         eps_end (scalar): final epsilon value, determining the amount of noise to be added.
@@ -293,7 +293,7 @@ def __init__(
         self,
-        policy: TensorDictModule,
+        policy: SafeModule,
         eps_init: float = 1.0,
         eps_end: float = 0.1,
         annealing_num_steps: int = 1000,
diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py
index 316f9536d25..3061d1017fa 100644
--- a/torchrl/modules/tensordict_module/probabilistic.py
+++ b/torchrl/modules/tensordict_module/probabilistic.py
@@ -3,34 +3,25 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import re
-from copy import deepcopy
-from textwrap import indent
-from typing import List, Optional, Sequence, Tuple, Type, Union
+from typing import Optional, Sequence, Type, Union
 
-from tensordict.tensordict import TensorDictBase
-from torch import distributions as d, Tensor
+from tensordict.nn import ProbabilisticTensorDictModule
 
 from torchrl.data import TensorSpec
-from torchrl.envs.utils import exploration_mode, set_exploration_mode
-from torchrl.modules.distributions import Delta, distributions_maps
-from torchrl.modules.tensordict_module.common import _check_all_str, TensorDictModule
+from torchrl.modules.distributions import Delta
+from torchrl.modules.tensordict_module.common import SafeModule
 
 
-class ProbabilisticTensorDictModule(TensorDictModule):
-    """A probabilistic TD Module.
-
-    `ProbabilisticTDModule` is a special case of a TDModule where the output is
-    sampled given some rule, specified by the input :obj:`default_interaction_mode`
-    argument and the :obj:`exploration_mode()` global function.
+class SafeProbabilisticModule(ProbabilisticTensorDictModule, SafeModule):
+    """A :obj:`SafeProbabilisticModule` is a :obj:`tensordict.nn.ProbabilisticTensorDictModule` subclass that accepts a :obj:`TensorSpec` as argument to control the output domain.
 
     It consists in a wrapper around another TDModule that returns a tensordict
-    updated with the distribution parameters. :obj:`ProbabilisticTensorDictModule` is
+    updated with the distribution parameters. :obj:`SafeProbabilisticModule` is
     responsible for constructing the distribution (through the :obj:`get_dist()` method)
     and/or sampling from this distribution (through a regular :obj:`__call__()` to the
     module).
 
-    A :obj:`ProbabilisticTensorDictModule` instance has two main features:
+    A :obj:`SafeProbabilisticModule` instance has two main features:
     - It reads and writes TensorDict objects
     - It uses a real mapping R^n -> R^m to create a distribution in R^d from
     which values can be sampled or computed.
@@ -39,8 +30,8 @@ class ProbabilisticTensorDictModule(TensorDictModule):
     the 'rsample', 'sample' method).
The sampling step is skipped if the inner TDModule has already created the desired key-value pair. - By default, ProbabilisticTensorDictModule distribution class is a Delta - distribution, making ProbabilisticTensorDictModule a simple wrapper around + By default, SafeProbabilisticModule distribution class is a Delta + distribution, making SafeProbabilisticModule a simple wrapper around a deterministic mapping function. Args: @@ -90,13 +81,13 @@ class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for the Normal distribut >>> import torch >>> from tensordict import TensorDict >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> from torchrl.modules import ProbabilisticTensorDictModule, TanhNormal, NormalParamWrapper + >>> from torchrl.modules import SafeProbabilisticModule, TanhNormal, NormalParamWrapper >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) >>> spec = NdUnboundedContinuousTensorSpec(4) >>> net = NormalParamWrapper(torch.nn.GRUCell(4, 8)) >>> fnet, params, buffers = functorch.make_functional_with_buffers(net) - >>> module = TensorDictModule(fnet, in_keys=["input", "hidden"], out_keys=["loc", "scale"]) - >>> td_module = ProbabilisticTensorDictModule( + >>> module = SafeModule(fnet, in_keys=["input", "hidden"], out_keys=["loc", "scale"]) + >>> td_module = SafeProbabilisticModule( ... module=module, ... spec=spec, ... dist_in_keys=["loc", "scale"], @@ -139,7 +130,7 @@ class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for the Normal distribut def __init__( self, - module: TensorDictModule, + module: SafeModule, dist_in_keys: Union[str, Sequence[str], dict], sample_out_key: Union[str, Sequence[str]], spec: Optional[TensorSpec] = None, @@ -151,203 +142,21 @@ def __init__( cache_dist: bool = False, n_empirical_estimate: int = 1000, ): - in_keys = module.in_keys - - # if the module returns the sampled key we wont be sampling it again - # then ProbabilisticTensorDictModule is presumably used to return the distribution using `get_dist` - if isinstance(dist_in_keys, str): - dist_in_keys = [dist_in_keys] - if isinstance(sample_out_key, str): - sample_out_key = [sample_out_key] - if not isinstance(dist_in_keys, dict): - dist_in_keys = {param_key: param_key for param_key in dist_in_keys} - for key in dist_in_keys.values(): - if key not in module.out_keys: - raise RuntimeError( - f"The key {key} could not be found in the wrapped module `{type(module)}.out_keys`." 
- ) - module_out_keys = module.out_keys - self.sample_out_key = sample_out_key - _check_all_str(self.sample_out_key) - sample_out_key = [key for key in sample_out_key if key not in module_out_keys] - self._requires_sample = bool(len(sample_out_key)) - out_keys = sample_out_key + module_out_keys super().__init__( - module=module, spec=spec, in_keys=in_keys, out_keys=out_keys, safe=safe - ) - self.dist_in_keys = dist_in_keys - _check_all_str(self.dist_in_keys.keys()) - _check_all_str(self.dist_in_keys.values()) - - self.default_interaction_mode = default_interaction_mode - if isinstance(distribution_class, str): - distribution_class = distributions_maps.get(distribution_class.lower()) - self.distribution_class = distribution_class - self.distribution_kwargs = ( - distribution_kwargs if distribution_kwargs is not None else {} + module=module, + dist_in_keys=dist_in_keys, + sample_out_key=sample_out_key, + default_interaction_mode=default_interaction_mode, + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=return_log_prob, + cache_dist=cache_dist, + n_empirical_estimate=n_empirical_estimate, ) - self.n_empirical_estimate = n_empirical_estimate - self._dist = None - self.cache_dist = cache_dist if hasattr(distribution_class, "update") else False - self.return_log_prob = return_log_prob - - def _call_module( - self, - tensordict: TensorDictBase, - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ) -> TensorDictBase: - return self.module(tensordict, params=params, buffers=buffers, **kwargs) - - def make_functional_with_buffers(self, clone: bool = True, native: bool = False): - module_params = self.parameters(recurse=False) - if len(list(module_params)): - raise RuntimeError( - "make_functional_with_buffers cannot be called on ProbabilisticTensorDictModule" - "that contain parameters on the outer level." - ) - if clone: - self_copy = deepcopy(self) - else: - self_copy = self - - self_copy.module, other = self_copy.module.make_functional_with_buffers( - clone=True, - native=native, + super(ProbabilisticTensorDictModule, self).__init__( + module=module, + spec=spec, + in_keys=self.in_keys, + out_keys=self.out_keys, + safe=safe, ) - return self_copy, other - - def get_dist( - self, - tensordict: TensorDictBase, - tensordict_out: Optional[TensorDictBase] = None, - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ) -> Tuple[d.Distribution, TensorDictBase]: - interaction_mode = exploration_mode() - if interaction_mode is None: - interaction_mode = self.default_interaction_mode - with set_exploration_mode(interaction_mode): - tensordict_out = self._call_module( - tensordict, - tensordict_out=tensordict_out, - params=params, - buffers=buffers, - **kwargs, - ) - dist = self.build_dist_from_params(tensordict_out) - return dist, tensordict_out - - def build_dist_from_params(self, tensordict_out: TensorDictBase) -> d.Distribution: - try: - selected_td_out = tensordict_out.select(*self.dist_in_keys.values()) - dist_kwargs = { - dist_key: selected_td_out[td_key] - for dist_key, td_key in self.dist_in_keys.items() - } - dist = self.distribution_class(**dist_kwargs) - except TypeError as err: - if "an unexpected keyword argument" in str(err): - raise TypeError( - "distribution keywords and tensordict keys indicated by ProbabilisticTensorDictModule.dist_in_keys must match." 
- f"Got this error message: \n{indent(str(err), 4 * ' ')}\nwith dist_in_keys={self.dist_in_keys}" - ) - elif re.search(r"missing.*required positional arguments", str(err)): - raise TypeError( - f"TensorDict with keys {tensordict_out.keys()} does not match the distribution {self.distribution_class} keywords." - ) - else: - raise err - return dist - - def forward( - self, - tensordict: TensorDictBase, - tensordict_out: Optional[TensorDictBase] = None, - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ) -> TensorDictBase: - - dist, tensordict_out = self.get_dist( - tensordict, - tensordict_out=tensordict_out, - params=params, - buffers=buffers, - **kwargs, - ) - if self._requires_sample: - out_tensors = self._dist_sample(dist, interaction_mode=exploration_mode()) - if isinstance(out_tensors, Tensor): - out_tensors = (out_tensors,) - tensordict_out.update( - {key: value for key, value in zip(self.sample_out_key, out_tensors)} - ) - if self.return_log_prob: - log_prob = dist.log_prob(*out_tensors) - tensordict_out.set("sample_log_prob", log_prob) - elif self.return_log_prob: - out_tensors = [tensordict_out.get(key) for key in self.sample_out_key] - log_prob = dist.log_prob(*out_tensors) - tensordict_out.set("sample_log_prob", log_prob) - # raise RuntimeError( - # "ProbabilisticTensorDictModule.return_log_prob = True is incompatible with settings in which " - # "the submodule is responsible for sampling. To manually gather the log-probability, call first " - # "\n>>> dist, tensordict = tensordict_module.get_dist(tensordict)" - # "\n>>> tensordict.set('sample_log_prob', dist.log_prob(tensordict.get(sample_key))" - # ) - return tensordict_out - - def _dist_sample( - self, - dist: d.Distribution, - *tensors: Tensor, - interaction_mode: bool = None, - ) -> Union[Tuple[Tensor], Tensor]: - if interaction_mode is None or interaction_mode == "": - interaction_mode = self.default_interaction_mode - if not isinstance(dist, d.Distribution): - raise TypeError(f"type {type(dist)} not recognised by _dist_sample") - - if interaction_mode == "mode": - if hasattr(dist, "mode"): - return dist.mode - else: - raise NotImplementedError( - f"method {type(dist)}.mode is not implemented" - ) - - elif interaction_mode == "median": - if hasattr(dist, "median"): - return dist.median - else: - raise NotImplementedError( - f"method {type(dist)}.median is not implemented" - ) - - elif interaction_mode == "mean": - try: - return dist.mean - except (AttributeError, NotImplementedError): - if dist.has_rsample: - return dist.rsample((self.n_empirical_estimate,)).mean(0) - else: - return dist.sample((self.n_empirical_estimate,)).mean(0) - - elif interaction_mode == "random": - if dist.has_rsample: - return dist.rsample() - else: - return dist.sample() - else: - raise NotImplementedError(f"unknown interaction_mode {interaction_mode}") - - @property - def num_params(self): - return self.module.num_params - - @property - def num_buffers(self): - return self.module.num_buffers diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index 057ad1e9afd..bbc3323630f 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -5,35 +5,15 @@ from __future__ import annotations -from copy import copy, deepcopy -from typing import Iterable, List, Optional, Tuple, Union - -_has_functorch = False -try: - import functorch - - _has_functorch = True 
-except ImportError: - print( - "failed to import functorch. TorchRL's features that do not require " - "functional programming should work, but functionality and performance " - "may be affected. Consider installing functorch and/or upgrating pytorch." - ) - FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality." - -import torch -from tensordict.tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase -from torch import nn, Tensor +from tensordict.nn import TensorDictSequential +from torch import nn from torchrl.data import CompositeSpec -from torchrl.modules.tensordict_module.common import TensorDictModule -from torchrl.modules.tensordict_module.probabilistic import ( - ProbabilisticTensorDictModule, -) +from torchrl.modules.tensordict_module.common import SafeModule -class TensorDictSequential(TensorDictModule): - """A sequence of TensorDictModules. +class SafeSequential(TensorDictSequential, SafeModule): + """A sequence of SafeModules. Similarly to :obj:`nn.Sequence` which passes a tensor through a chain of mappings that read and write a single tensor each, this module will read and write over a tensordict by querying each of the input modules. @@ -41,12 +21,12 @@ class TensorDictSequential(TensorDictModule): buffers) will be concatenated in a single list. Args: - modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially. + modules (iterable of SafeModules): ordered sequence of SafeModule instances to be run sequentially. partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys. If so, the only module that will be executed are those who can be executed given the keys that are present. Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the - stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts + stack does not have the required keys, then SafeSequential will scan through the sub-tensordicts looking for those that have the required keys, if any. TensorDictSequence supports functional, modular and vmap coding: @@ -55,14 +35,14 @@ class TensorDictSequential(TensorDictModule): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> from torchrl.modules import TanhNormal, TensorDictSequential, NormalParamWrapper - >>> from torchrl.modules.tensordict_module import ProbabilisticTensorDictModule + >>> from torchrl.modules import TanhNormal, SafeSequential, NormalParamWrapper + >>> from torchrl.modules.tensordict_module import SafeProbabilisticModule >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,]) >>> spec1 = NdUnboundedContinuousTensorSpec(4) >>> net1 = NormalParamWrapper(torch.nn.Linear(4, 8)) >>> fnet1, params1, buffers1 = functorch.make_functional_with_buffers(net1) - >>> fmodule1 = TensorDictModule(fnet1, in_keys=["input"], out_keys=["loc", "scale"]) - >>> td_module1 = ProbabilisticTensorDictModule( + >>> fmodule1 = SafeModule(fnet1, in_keys=["input"], out_keys=["loc", "scale"]) + >>> td_module1 = SafeProbabilisticModule( ... module=fmodule1, ... spec=spec1, ... 
dist_in_keys=["loc", "scale"], @@ -73,13 +53,13 @@ class TensorDictSequential(TensorDictModule): >>> spec2 = NdUnboundedContinuousTensorSpec(8) >>> module2 = torch.nn.Linear(4, 8) >>> fmodule2, params2, buffers2 = functorch.make_functional_with_buffers(module2) - >>> td_module2 = TensorDictModule( + >>> td_module2 = SafeModule( ... module=fmodule2, ... spec=spec2, ... in_keys=["hidden"], ... out_keys=["output"], ... ) - >>> td_module = TensorDictSequential(td_module1, td_module2) + >>> td_module = SafeSequential(td_module1, td_module2) >>> params = params1 + params2 >>> buffers = buffers1 + buffers2 >>> _ = td_module(td, params=params, buffers=buffers) @@ -128,380 +108,23 @@ class TensorDictSequential(TensorDictModule): def __init__( self, - *modules: TensorDictModule, + *modules: SafeModule, partial_tolerant: bool = False, ): + self.partial_tolerant = partial_tolerant + in_keys, out_keys = self._compute_in_and_out_keys(modules) spec = CompositeSpec() for module in modules: - if isinstance(module, TensorDictModule) or hasattr(module, "spec"): + if isinstance(module, SafeModule) or hasattr(module, "spec"): spec.update(module.spec) else: spec.update(CompositeSpec({key: None for key in module.out_keys})) - super().__init__( + + super(TensorDictSequential, self).__init__( spec=spec, module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys, ) - - self.partial_tolerant = partial_tolerant - - def _compute_in_and_out_keys(self, modules: List[TensorDictModule]) -> Tuple[List]: - in_keys = [] - out_keys = [] - for module in modules: - # we sometimes use in_keys to select keys of a tensordict that are - # necessary to run a TensorDictModule. If a key is an intermediary in - # the chain, there is no reason why it should belong to the input - # TensorDict. 
- for in_key in module.in_keys: - if in_key not in (out_keys + in_keys): - in_keys.append(in_key) - out_keys += module.out_keys - - out_keys = [ - out_key - for i, out_key in enumerate(out_keys) - if out_key not in out_keys[i + 1 :] - ] - return in_keys, out_keys - - @staticmethod - def _find_functional_module(module: TensorDictModule) -> nn.Module: - if not _has_functorch: - raise ImportError(FUNCTORCH_ERROR) - fmodule = module - while not isinstance( - fmodule, (functorch.FunctionalModule, functorch.FunctionalModuleWithBuffers) - ): - try: - fmodule = fmodule.module - except AttributeError: - raise AttributeError( - f"couldn't find a functional module in module of type {type(module)}" - ) - return fmodule - - @property - def num_params(self): - return self.param_len[-1] - - @property - def num_buffers(self): - return self.buffer_len[-1] - - @property - def param_len(self) -> List[int]: - param_list = [] - prev = 0 - for module in self.module: - param_list.append(module.num_params + prev) - prev = param_list[-1] - return param_list - - @property - def buffer_len(self) -> List[int]: - buffer_list = [] - prev = 0 - for module in self.module: - buffer_list.append(module.num_buffers + prev) - prev = buffer_list[-1] - return buffer_list - - def _split_param( - self, param_list: Iterable[Tensor], params_or_buffers: str - ) -> Iterable[Iterable[Tensor]]: - if params_or_buffers == "params": - list_out = self.param_len - elif params_or_buffers == "buffers": - list_out = self.buffer_len - list_in = [0] + list_out[:-1] - out = [] - for a, b in zip(list_in, list_out): - out.append(param_list[a:b]) - return out - - def select_subsequence( - self, in_keys: Iterable[str] = None, out_keys: Iterable[str] = None - ) -> "TensorDictSequential": - """Returns a new TensorDictSequential with only the modules that are necessary to compute the given output keys with the given input keys. - - Args: - in_keys: input keys of the subsequence we want to select - out_keys: output keys of the subsequence we want to select - - Returns: - A new TensorDictSequential with only the modules that are necessary acording to the given input and output keys. - """ - if in_keys is None: - in_keys = deepcopy(self.in_keys) - if out_keys is None: - out_keys = deepcopy(self.out_keys) - id_to_keep = set(range(len(self.module))) - for i, module in enumerate(self.module): - if all(key in in_keys for key in module.in_keys): - in_keys.extend(module.out_keys) - else: - id_to_keep.remove(i) - for i, module in reversed(list(enumerate(self.module))): - if i in id_to_keep: - if any(key in out_keys for key in module.out_keys): - out_keys.extend(module.in_keys) - else: - id_to_keep.remove(i) - id_to_keep = sorted(id_to_keep) - - modules = [self.module[i] for i in id_to_keep] - - if modules == []: - raise ValueError( - "No modules left after selection. Make sure that in_keys and out_keys are coherent." 
- ) - - return TensorDictSequential(*modules) - - def _run_module( - self, - module, - tensordict, - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ): - tensordict_keys = set(tensordict.keys()) - if not self.partial_tolerant or all( - key in tensordict_keys for key in module.in_keys - ): - if params is not None or buffers is not None: - tensordict = module( - tensordict, params=params, buffers=buffers, **kwargs - ) - else: - tensordict = module(tensordict, **kwargs) - elif self.partial_tolerant and isinstance(tensordict, LazyStackedTensorDict): - for sub_td in tensordict.tensordicts: - tensordict_keys = set(sub_td.keys()) - if all(key in tensordict_keys for key in module.in_keys): - if params is not None or buffers is not None: - module(sub_td, params=params, buffers=buffers, **kwargs) - else: - module(sub_td, **kwargs) - tensordict._update_valid_keys() - return tensordict - - def forward( - self, - tensordict: TensorDictBase, - tensordict_out=None, - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ) -> TensorDictBase: - if params is not None and buffers is not None: - if isinstance(params, TensorDictBase): - # TODO: implement sorted values and items - param_splits = list(zip(*sorted(params.items())))[1] - buffer_splits = list(zip(*sorted(buffers.items())))[1] - else: - param_splits = self._split_param(params, "params") - buffer_splits = self._split_param(buffers, "buffers") - for i, (module, param, buffer) in enumerate( - zip(self.module, param_splits, buffer_splits) - ): - if "vmap" in kwargs and i > 0: - # the tensordict is already expended - if not isinstance(kwargs["vmap"], tuple): - kwargs["vmap"] = (0, 0, *(0,) * len(module.in_keys)) - else: - kwargs["vmap"] = ( - *kwargs["vmap"][:2], - *(0,) * len(module.in_keys), - ) - tensordict = self._run_module( - module, tensordict, params=param, buffers=buffer, **kwargs - ) - - elif params is not None: - if isinstance(params, TensorDictBase): - # TODO: implement sorted values and items - param_splits = list(zip(*sorted(params.items())))[1] - else: - param_splits = self._split_param(params, "params") - for i, (module, param) in enumerate(zip(self.module, param_splits)): - if "vmap" in kwargs and i > 0: - # the tensordict is already expended - if not isinstance(kwargs["vmap"], tuple): - kwargs["vmap"] = (0, *(0,) * len(module.in_keys)) - else: - kwargs["vmap"] = ( - *kwargs["vmap"][:1], - *(0,) * len(module.in_keys), - ) - tensordict = self._run_module( - module, tensordict, params=param, **kwargs - ) - - elif not len(kwargs): - for module in self.module: - tensordict = self._run_module(module, tensordict, **kwargs) - else: - raise RuntimeError( - "TensorDictSequential does not support keyword arguments other than 'tensordict_out', 'in_keys', 'out_keys' 'params', 'buffers' and 'vmap'" - ) - if tensordict_out is not None: - tensordict_out.update(tensordict, inplace=True) - return tensordict_out - return tensordict - - def __len__(self): - return len(self.module) - - def __getitem__(self, index: Union[int, slice]) -> TensorDictModule: - if isinstance(index, int): - return self.module.__getitem__(index) - else: - return TensorDictSequential(*self.module.__getitem__(index)) - - def __setitem__(self, index: int, tensordict_module: TensorDictModule) -> None: - return self.module.__setitem__(idx=index, module=tensordict_module) - - def __delitem__(self, index: 
Union[int, slice]) -> None: - self.module.__delitem__(idx=index) - - def make_functional_with_buffers(self, clone: bool = True, native: bool = False): - """Transforms a stateful module in a functional module and returns its parameters and buffers. - - Unlike functorch.make_functional_with_buffers, this method supports lazy modules. - - Args: - clone (bool, optional): if True, a clone of the module is created before it is returned. - This is useful as it prevents the original module to be scraped off of its - parameters and buffers. - Defaults to True - native (bool, optional): if True, TorchRL's functional modules will be used. - Defaults to True - - Returns: - A tuple of parameter and buffer tuples - - Examples: - >>> from tensordict import TensorDict - >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> lazy_module1 = nn.LazyLinear(4) - >>> lazy_module2 = nn.LazyLinear(3) - >>> spec1 = NdUnboundedContinuousTensorSpec(18) - >>> spec2 = NdUnboundedContinuousTensorSpec(4) - >>> td_module1 = TensorDictModule(spec=spec1, module=lazy_module1, in_keys=["some_input"], out_keys=["hidden"]) - >>> td_module2 = TensorDictModule(spec=spec2, module=lazy_module2, in_keys=["hidden"], out_keys=["some_output"]) - >>> td_module = TensorDictSequential(td_module1, td_module2) - >>> _, (params, buffers) = td_module.make_functional_with_buffers() - >>> print(params[0].shape) # the lazy module has been initialized - torch.Size([4, 18]) - >>> print(td_module( - ... TensorDict({'some_input': torch.randn(18)}, batch_size=[]), - ... params=params, - ... buffers=buffers)) - TensorDict( - fields={ - some_input: Tensor(torch.Size([18]), dtype=torch.float32), - hidden: Tensor(torch.Size([4]), dtype=torch.float32), - some_output: Tensor(torch.Size([3]), dtype=torch.float32)}, - batch_size=torch.Size([]), - device=cpu, - is_shared=False) - - """ - native = native or not _has_functorch - if clone: - self_copy = deepcopy(self) - self_copy.module = copy(self_copy.module) - else: - self_copy = self - params = [] if not native else TensorDict({}, []) - buffers = [] if not native else TensorDict({}, []) - for i, module in enumerate(self.module): - self_copy.module[i], ( - _params, - _buffers, - ) = module.make_functional_with_buffers(clone=True, native=native) - if native or not _has_functorch: - params[str(i)] = _params - buffers[str(i)] = _buffers - else: - params.extend(_params) - buffers.extend(_buffers) - return self_copy, (params, buffers) - - def get_dist( - self, - tensordict: TensorDictBase, - **kwargs, - ) -> Tuple[torch.distributions.Distribution, ...]: - L = len(self.module) - - if isinstance(self.module[-1], ProbabilisticTensorDictModule): - if "params" in kwargs and "buffers" in kwargs: - params = kwargs["params"] - buffers = kwargs["buffers"] - if isinstance(params, TensorDictBase): - param_splits = list(zip(*sorted(params.items())))[1] - buffer_splits = list(zip(*sorted(buffers.items())))[1] - else: - param_splits = self._split_param(kwargs["params"], "params") - buffer_splits = self._split_param(kwargs["buffers"], "buffers") - kwargs_pruned = { - key: item - for key, item in kwargs.items() - if key not in ("params", "buffers") - } - for i, (module, param, buffer) in enumerate( - zip(self.module, param_splits, buffer_splits) - ): - if "vmap" in kwargs_pruned and i > 0: - # the tensordict is already expended - kwargs_pruned["vmap"] = (0, 0, *(0,) * len(module.in_keys)) - if i < L - 1: - tensordict = module( - tensordict, params=param, buffers=buffer, **kwargs_pruned - ) - else: - out = 
module.get_dist(
-                        tensordict, params=param, buffers=buffer, **kwargs_pruned
-                    )
-
-        elif "params" in kwargs:
-            params = kwargs["params"]
-            if isinstance(params, TensorDictBase):
-                param_splits = list(zip(*sorted(params.items())))[1]
-            else:
-                param_splits = self._split_param(kwargs["params"], "params")
-            kwargs_pruned = {
-                key: item for key, item in kwargs.items() if key not in ("params",)
-            }
-            for i, (module, param) in enumerate(zip(self.module, param_splits)):
-                if "vmap" in kwargs_pruned and i > 0:
-                    # the tensordict is already expended
-                    kwargs_pruned["vmap"] = (0, *(0,) * len(module.in_keys))
-                if i < L - 1:
-                    tensordict = module(tensordict, params=param, **kwargs_pruned)
-                else:
-                    out = module.get_dist(tensordict, params=param, **kwargs_pruned)
-
-        elif not len(kwargs):
-            for i, module in enumerate(self.module):
-                if i < L - 1:
-                    tensordict = module(tensordict)
-                else:
-                    out = module.get_dist(tensordict)
-        else:
-            raise RuntimeError(
-                "TensorDictSequential does not support keyword arguments other than 'params', 'buffers' and 'vmap'"
-            )
-
-            return out
-        else:
-            raise RuntimeError(
-                "Cannot call get_dist on a sequence of tensordicts that does not end with a probabilistic TensorDict"
-            )
diff --git a/torchrl/modules/tensordict_module/world_models.py b/torchrl/modules/tensordict_module/world_models.py
index 10b8e5b9a5a..0243c3806a3 100644
--- a/torchrl/modules/tensordict_module/world_models.py
+++ b/torchrl/modules/tensordict_module/world_models.py
@@ -4,10 +4,10 @@
 # LICENSE file in the root directory of this source tree.


-from torchrl.modules.tensordict_module import TensorDictModule, TensorDictSequential
+from torchrl.modules.tensordict_module import SafeModule, SafeSequential


-class WorldModelWrapper(TensorDictSequential):
+class WorldModelWrapper(SafeSequential):
     """World model wrapper.

     This module wraps together a transition model and a reward model.
@@ -15,25 +15,18 @@ class WorldModelWrapper(TensorDictSequential):
     The reward model is used to predict the reward of the imagined transition.

     Args:
-        transition_model (TensorDictModule): a transition model that generates a new world states.
-        reward_model (TensorDictModule): a reward model, that reads the world state and returns a reward.
+        transition_model (SafeModule): a transition model that generates a new world state.
+        reward_model (SafeModule): a reward model that reads the world state and returns a reward.

     """

-    def __init__(
-        self,
-        transition_model: TensorDictModule,
-        reward_model: TensorDictModule,
-    ):
-        super().__init__(
-            transition_model,
-            reward_model,
-        )
-
-    def get_transition_model_operator(self) -> TensorDictSequential:
+    def __init__(self, transition_model: SafeModule, reward_model: SafeModule):
+        super().__init__(transition_model, reward_model)
+
+    def get_transition_model_operator(self) -> SafeSequential:
        """Returns a transition operator that maps either an observation to a world state or a world state to the next world state."""
        return self.module[0]

-    def get_reward_operator(self) -> TensorDictSequential:
+    def get_reward_operator(self) -> SafeSequential:
        """Returns a reward operator that maps a world state to a reward."""
        return self.module[1]
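For reference, a minimal sketch of how the renamed `WorldModelWrapper` composes two `SafeModule`s. The keys (`state`, `next_state`, `reward`) and the single-layer networks are illustrative stand-ins, not taken from this diff:

```python
import torch
from torch import nn
from tensordict import TensorDict
from torchrl.modules import SafeModule
from torchrl.modules.tensordict_module.world_models import WorldModelWrapper

# Stand-in transition and reward networks; real models would be far richer.
transition_model = SafeModule(nn.Linear(4, 4), in_keys=["state"], out_keys=["next_state"])
reward_model = SafeModule(nn.Linear(4, 1), in_keys=["next_state"], out_keys=["reward"])
world_model = WorldModelWrapper(transition_model, reward_model)

td = TensorDict({"state": torch.randn(3, 4)}, batch_size=[3])
td = world_model(td)  # runs both modules, writing "next_state" then "reward"
assert world_model.get_reward_operator() is world_model.module[1]
```

diff --git a/torchrl/modules/utils/mappings.py b/torchrl/modules/utils/mappings.py
index 1af962e857f..406c79b304b 100644
--- a/torchrl/modules/utils/mappings.py
+++ b/torchrl/modules/utils/mappings.py
@@ -3,48 +3,12 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.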
-from typing import Callable, Union
+from typing import Callable

 import torch
-from torch import nn
+from tensordict.nn.utils import biased_softplus, inv_softplus

-
-def inv_softplus(bias: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
-    """Inverse softplus function.
-
-    Args:
-        bias (float or tensor): the value to be softplus-inverted.
-    """
-    is_tensor = True
-    if not isinstance(bias, torch.Tensor):
-        is_tensor = False
-        bias = torch.tensor(bias)
-    out = bias.expm1().clamp_min(1e-6).log()
-    if not is_tensor and out.numel() == 1:
-        return out.item()
-    return out
-
-
-class biased_softplus(nn.Module):
-    """A biased softplus module.
-
-    The bias indicates the value that is to be returned when a zero-tensor is
-    passed through the transform.
-
-    Args:
-        bias (scalar): 'bias' of the softplus transform. If bias=1.0, then a _bias shift will be computed such that
-            softplus(0.0 + _bias) = bias.
-        min_val (scalar): minimum value of the transform.
-            default: 0.1
-    """
-
-    def __init__(self, bias: float, min_val: float = 0.01):
-        super().__init__()
-        self.bias = inv_softplus(bias - min_val)
-        self.min_val = min_val
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return torch.nn.functional.softplus(x + self.bias) + self.min_val
+__all__ = ["biased_softplus", "expln", "inv_softplus", "mappings"]


 def expln(x):
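A quick sanity check on the relocated helpers, assuming the `tensordict.nn.utils` versions keep the signatures shown in the removed code above:

```python
import torch
from tensordict.nn.utils import biased_softplus, inv_softplus

# biased_softplus(bias, min_val) is built so that a zero input maps to exactly
# `bias`: softplus(0 + inv_softplus(bias - min_val)) + min_val == bias.
mapping = biased_softplus(1.0, min_val=0.01)
assert torch.allclose(mapping(torch.zeros(3)), torch.ones(3))

# inv_softplus inverts softplus on (clamped) positive inputs.
x = torch.tensor(0.5)
assert torch.isclose(torch.nn.functional.softplus(inv_softplus(x)), x)
```

diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index 5f1ffe01618..af20007b26a 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -9,8 +9,8 @@
 from tensordict.tensordict import TensorDict, TensorDictBase
 from torch import distributions as d

-from torchrl.modules import TensorDictModule
-from torchrl.modules.tensordict_module import ProbabilisticTensorDictModule
+from torchrl.modules import SafeModule
+from torchrl.modules.tensordict_module import SafeProbabilisticModule
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import distance_loss

@@ -26,7 +26,7 @@ class A2CLoss(LossModule):
     https://arxiv.org/abs/1602.01783v2

     Args:
-        actor (ProbabilisticTensorDictModule): policy operator.
+        actor (SafeProbabilisticModule): policy operator.
         critic (ValueOperator): value operator.
         advantage_key (str): the input tensordict key where the advantage is
            expected to be written. default: "advantage"
@@ -36,13 +36,13 @@
         critic_coef (float): the weight of the critic loss.
         gamma (scalar): a discount factor for return computation.
         loss_function_type (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
-        advantage_module (nn.Module): TensorDictModule used to compute tha advantage function.
+        advantage_module (nn.Module): SafeModule used to compute the advantage function.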
""" def __init__( self, - actor: ProbabilisticTensorDictModule, - critic: TensorDictModule, + actor: SafeProbabilisticModule, + critic: SafeModule, advantage_key: str = "advantage", advantage_diff_key: str = "value_error", entropy_bonus: bool = True, diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 5be98bd3215..a7c90521a5a 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -8,8 +8,7 @@ from typing import Iterator, List, Optional, Tuple, Union import torch - -from torchrl.modules.functional_modules import FunctionalModuleWithBuffers +from tensordict.nn.functional_modules import FunctionalModuleWithBuffers _has_functorch = False try: @@ -29,7 +28,7 @@ from torch import nn, Tensor from torch.nn import Parameter -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule class LossModule(nn.Module): @@ -65,7 +64,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def convert_to_functional( self, - module: TensorDictModule, + module: SafeModule, module_name: str, expand_dim: Optional[int] = None, create_target_params: bool = False, @@ -90,7 +89,7 @@ def convert_to_functional( def _convert_to_functional_functorch( self, - module: TensorDictModule, + module: SafeModule, module_name: str, expand_dim: Optional[int] = None, create_target_params: bool = False, @@ -250,7 +249,7 @@ def _convert_to_functional_functorch( def _convert_to_functional_native( self, - module: TensorDictModule, + module: SafeModule, module_name: str, expand_dim: Optional[int] = None, create_target_params: bool = False, diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index d2a54ee81fe..5692f109739 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -10,7 +10,7 @@ import torch from tensordict.tensordict import TensorDict, TensorDictBase -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.utils import distance_loss, hold_out_params, next_state_value @@ -22,8 +22,8 @@ class DDPGLoss(LossModule): """The DDPG Loss class. Args: - actor_network (TensorDictModule): a policy operator. - value_network (TensorDictModule): a Q value operator. + actor_network (SafeModule): a policy operator. + value_network (SafeModule): a Q value operator. gamma (scalar): a discount factor for return computation. device (str, int or torch.device, optional): a device where the losses will be computed, if it can't be found via the value operator. @@ -36,8 +36,8 @@ class DDPGLoss(LossModule): def __init__( self, - actor_network: TensorDictModule, - value_network: TensorDictModule, + actor_network: SafeModule, + value_network: SafeModule, gamma: float, loss_function: str = "l2", delay_actor: bool = False, diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 9005112e7d7..555525161a4 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -9,7 +9,7 @@ from torch import Tensor from torchrl.envs.utils import set_exploration_mode, step_mdp -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule from torchrl.objectives import ( distance_loss, hold_out_params, @@ -26,8 +26,8 @@ class REDQLoss_deprecated(LossModule): train a SAC-like algorithm. 
    Args:
-        actor_network (TensorDictModule): the actor to be trained
-        qvalue_network (TensorDictModule): a single Q-value network that will be multiplicated as many times as needed.
+        actor_network (SafeModule): the actor to be trained
+        qvalue_network (SafeModule): a single Q-value network that will be replicated as many times as needed.
         num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10.
         sub_sample_len (int, optional): number of Q-value networks to be subsampled to evaluate the next state value
            Default is 2.
@@ -51,8 +51,8 @@ class REDQLoss_deprecated(LossModule):

     def __init__(
         self,
-        actor_network: TensorDictModule,
-        qvalue_network: TensorDictModule,
+        actor_network: SafeModule,
+        qvalue_network: SafeModule,
        num_qvalue_nets: int = 10,
        sub_sample_len: int = 2,
        gamma: Number = 0.99,
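The "single Q-value network replicated N times" wording is worth unpacking. Here is a sketch of the semantics of `num_qvalue_nets` and `sub_sample_len` in plain torch; the library itself expands stacks of functional parameters rather than deep-copying modules:

```python
import copy
import torch
from torch import nn

# REDQ trains `num_qvalue_nets` copies of one Q-network and evaluates the next
# state value with the min over a random subset of `sub_sample_len` of them.
num_qvalue_nets, sub_sample_len = 10, 2
qnet = nn.Linear(6, 1)  # stand-in Q(s, a) network on concat(obs, action)
ensemble = nn.ModuleList(copy.deepcopy(qnet) for _ in range(num_qvalue_nets))

obs_action = torch.randn(32, 6)
idx = torch.randperm(num_qvalue_nets)[:sub_sample_len]
next_value = torch.stack([ensemble[i](obs_action) for i in idx], 0).min(0).values
```

diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py
index c9f35b64649..8d839078154 100644
--- a/torchrl/objectives/dreamer.py
+++ b/torchrl/objectives/dreamer.py
@@ -9,7 +9,7 @@
 from torchrl.envs.model_based.dreamer import DreamerEnv
 from torchrl.envs.utils import set_exploration_mode, step_mdp

-from torchrl.modules import TensorDictModule
+from torchrl.modules import SafeModule
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import distance_loss, hold_out_net
 from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
@@ -24,7 +24,7 @@ class DreamerModelLoss(LossModule):
     Reference: https://arxiv.org/abs/1912.01603.

     Args:
-        world_model (TensorDictModule): the world model.
+        world_model (SafeModule): the world model.
         lambda_kl (float, optional): the weight of the kl divergence loss. Default: 1.0.
         lambda_reco (float, optional): the weight of the reconstruction loss. Default: 1.0.
         lambda_reward (float, optional): the weight of the reward loss. Default: 1.0.
@@ -42,7 +42,7 @@

     def __init__(
         self,
-        world_model: TensorDictModule,
+        world_model: SafeModule,
        lambda_kl: float = 1.0,
        lambda_reco: float = 1.0,
        lambda_reward: float = 1.0,
@@ -133,8 +133,8 @@ class DreamerActorLoss(LossModule):
     Reference: https://arxiv.org/abs/1912.01603.

     Args:
-        actor_model (TensorDictModule): the actor model.
-        value_model (TensorDictModule): the value model.
+        actor_model (SafeModule): the actor model.
+        value_model (SafeModule): the value model.
         model_based_env (DreamerEnv): the model based environment.
         imagination_horizon (int, optional): The number of steps to unroll the model.
            Default: 15.
@@ -147,8 +147,8 @@

     def __init__(
         self,
-        actor_model: TensorDictModule,
-        value_model: TensorDictModule,
+        actor_model: SafeModule,
+        value_model: SafeModule,
        model_based_env: DreamerEnv,
        imagination_horizon: int = 15,
        gamma: int = 0.99,
@@ -217,7 +217,7 @@ class DreamerValueLoss(LossModule):
     Reference: https://arxiv.org/abs/1912.01603.

     Args:
-        value_model (TensorDictModule): the value model.
+        value_model (SafeModule): the value model.
         value_loss (str, optional): the loss to use for the value loss. Default: "l2".
         gamma (float, optional): the gamma discount factor. Default: 0.99.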
discount_loss (bool, optional): if True, the loss is discounted with a @@ -227,7 +227,7 @@ class DreamerValueLoss(LossModule): def __init__( self, - value_model: TensorDictModule, + value_model: SafeModule, value_loss: Optional[str] = None, gamma: int = 0.99, discount_loss: bool = False, # for consistency with paper diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index b711519492e..2926e2c667a 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -10,10 +10,10 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import distributions as d -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule from torchrl.objectives.utils import distance_loss -from ..modules.tensordict_module import ProbabilisticTensorDictModule +from ..modules.tensordict_module import SafeProbabilisticModule from .common import LossModule @@ -32,7 +32,7 @@ class PPOLoss(LossModule): https://arxiv.org/abs/1707.06347 Args: - actor (ProbabilisticTensorDictModule): policy operator. + actor (SafeProbabilisticModule): policy operator. critic (ValueOperator): value operator. advantage_key (str): the input tensordict key where the advantage is expected to be written. default: "advantage" @@ -52,8 +52,8 @@ class PPOLoss(LossModule): def __init__( self, - actor: ProbabilisticTensorDictModule, - critic: TensorDictModule, + actor: SafeProbabilisticModule, + critic: SafeModule, advantage_key: str = "advantage", advantage_diff_key: str = "value_error", entropy_bonus: bool = True, @@ -171,7 +171,7 @@ class ClipPPOLoss(PPOLoss): loss = -min( weight * advantage, min(max(weight, 1-eps), 1+eps) * advantage) Args: - actor (ProbabilisticTensorDictModule): policy operator. + actor (SafeProbabilisticModule): policy operator. critic (ValueOperator): value operator. advantage_key (str): the input tensordict key where the advantage is expected to be written. default: "advantage" @@ -193,8 +193,8 @@ class ClipPPOLoss(PPOLoss): def __init__( self, - actor: ProbabilisticTensorDictModule, - critic: TensorDictModule, + actor: SafeProbabilisticModule, + critic: SafeModule, advantage_key: str = "advantage", clip_epsilon: float = 0.2, entropy_bonus: bool = True, @@ -277,7 +277,7 @@ class KLPENPPOLoss(PPOLoss): favouring a certain level of distancing between the two while still preventing them to be too much apart. Args: - actor (ProbabilisticTensorDictModule): policy operator. + actor (SafeProbabilisticModule): policy operator. critic (ValueOperator): value operator. advantage_key (str): the input tensordict key where the advantage is expected to be written. default: "advantage" @@ -304,8 +304,8 @@ class KLPENPPOLoss(PPOLoss): def __init__( self, - actor: ProbabilisticTensorDictModule, - critic: TensorDictModule, + actor: SafeProbabilisticModule, + critic: SafeModule, advantage_key="advantage", dtarg: float = 0.01, beta: float = 1.0, diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index beb7c31c51a..70b3d6a3e8d 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -13,7 +13,7 @@ from torch import Tensor from torchrl.envs.utils import set_exploration_mode, step_mdp -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule from torchrl.objectives.common import _has_functorch, LossModule from torchrl.objectives.utils import ( distance_loss, @@ -30,8 +30,8 @@ class REDQLoss(LossModule): train a SAC-like algorithm. 
    Args:
-        actor_network (TensorDictModule): the actor to be trained
-        qvalue_network (TensorDictModule): a single Q-value network that will be multiplicated as many times as needed.
+        actor_network (SafeModule): the actor to be trained
+        qvalue_network (SafeModule): a single Q-value network that will be replicated as many times as needed.
         num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10.
         sub_sample_len (int, optional): number of Q-value networks to be subsampled to evaluate the next state value
            Default is 2.
@@ -59,8 +59,8 @@

     def __init__(
         self,
-        actor_network: TensorDictModule,
-        qvalue_network: TensorDictModule,
+        actor_network: SafeModule,
+        qvalue_network: SafeModule,
        num_qvalue_nets: int = 10,
        sub_sample_len: int = 2,
        gamma: Number = 0.99,
diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py
index b68b7b981a0..294f79c50ec 100644
--- a/torchrl/objectives/reinforce.py
+++ b/torchrl/objectives/reinforce.py
@@ -4,7 +4,7 @@
 from tensordict.tensordict import TensorDict, TensorDictBase

 from torchrl.envs.utils import step_mdp
-from torchrl.modules import ProbabilisticTensorDictModule, TensorDictModule
+from torchrl.modules import SafeModule, SafeProbabilisticModule
 from torchrl.objectives import distance_loss
 from torchrl.objectives.common import LossModule

@@ -19,9 +19,9 @@ class ReinforceLoss(LossModule):

     def __init__(
         self,
-        actor_network: ProbabilisticTensorDictModule,
+        actor_network: SafeProbabilisticModule,
        advantage_module: Callable[[TensorDictBase], TensorDictBase],
-        critic: Optional[TensorDictModule] = None,
+        critic: Optional[SafeModule] = None,
        delay_value: bool = False,
        gamma: float = 0.99,
        advantage_key: str = "advantage",
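For context, the score-function (REINFORCE) estimator that `ReinforceLoss` optimizes, sketched in plain torch with illustrative shapes; this is not the module's internal code:

```python
import torch
from torch import distributions as d

# REINFORCE maximizes E[log pi(a|s) * advantage]; we minimize the negation.
# Gradients flow through log_prob only, not through the sampled action.
loc = torch.zeros(32, 4, requires_grad=True)  # stand-in policy output
dist = d.Normal(loc, torch.ones(32, 4))
action = dist.sample()                        # sampling blocks gradients
advantage = torch.randn(32, 1)                # stand-in advantage estimate
log_prob = dist.log_prob(action).sum(-1, keepdim=True)
loss_actor = -(log_prob * advantage).mean()
loss_actor.backward()                         # populates loc.grad
```

diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index 9b0685e2178..bfc5e088813 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -12,7 +12,7 @@
 from tensordict.tensordict import TensorDict, TensorDictBase
 from torch import Tensor

-from torchrl.modules import ProbabilisticActor, TensorDictModule
+from torchrl.modules import ProbabilisticActor, SafeModule
 from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
 from torchrl.objectives.utils import distance_loss, next_state_value

@@ -28,8 +28,8 @@ class SACLoss(LossModule):

     Args:
         actor_network (ProbabilisticActor): stochastic actor
-        qvalue_network (TensorDictModule): Q(s, a) parametric model
-        value_network (TensorDictModule): V(s) parametric model\
+        qvalue_network (SafeModule): Q(s, a) parametric model
+        value_network (SafeModule): V(s) parametric model
            qvalue_network_bis (ProbabilisticTDModule, optional): if required, the
                Q-value can be computed twice independently using two separate networks.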
The minimum predicted value will then be used for @@ -68,8 +68,8 @@ class SACLoss(LossModule): def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule, - value_network: TensorDictModule, + qvalue_network: SafeModule, + value_network: SafeModule, num_qvalue_nets: int = 2, gamma: Number = 0.99, priotity_key: str = "td_error", diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 467c0cb7c7f..4f2da57c93a 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -13,7 +13,7 @@ from torch.nn import functional as F from torchrl.envs.utils import step_mdp -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule class _context_manager: @@ -293,7 +293,7 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: @torch.no_grad() def next_state_value( tensordict: TensorDictBase, - operator: Optional[TensorDictModule] = None, + operator: Optional[SafeModule] = None, next_val_key: str = "state_action_value", gamma: float = 0.99, pred_next_val: Optional[Tensor] = None, diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 6996252847e..6ee6ef3503b 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -10,7 +10,7 @@ from torch import nn, Tensor from torchrl.envs.utils import step_mdp -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule from torchrl.objectives.value.functional import ( td_lambda_advantage_estimate, vec_generalized_advantage_estimate, @@ -26,7 +26,7 @@ class TDEstimate(nn.Module): Args: gamma (scalar): exponential mean discount. - value_network (TensorDictModule): value operator used to retrieve the value estimates. + value_network (SafeModule): value operator used to retrieve the value estimates. average_rewards (bool, optional): if True, rewards will be standardized before the TD is computed. gradient_mode (bool, optional): if True, gradients are propagated throught @@ -38,7 +38,7 @@ class TDEstimate(nn.Module): def __init__( self, gamma: Union[float, torch.Tensor], - value_network: TensorDictModule, + value_network: SafeModule, average_rewards: bool = False, gradient_mode: bool = False, value_key: str = "state_value", @@ -129,7 +129,7 @@ class TDLambdaEstimate(nn.Module): Args: gamma (scalar): exponential mean discount. lmbda (scalar): trajectory discount. - value_network (TensorDictModule): value operator used to retrieve the value estimates. + value_network (SafeModule): value operator used to retrieve the value estimates. average_rewards (bool, optional): if True, rewards will be standardized before the TD is computed. gradient_mode (bool, optional): if True, gradients are propagated throught @@ -144,7 +144,7 @@ def __init__( self, gamma: Union[float, torch.Tensor], lmbda: Union[float, torch.Tensor], - value_network: TensorDictModule, + value_network: SafeModule, average_rewards: bool = False, gradient_mode: bool = False, value_key: str = "state_value", @@ -251,7 +251,7 @@ class GAE(nn.Module): Args: gamma (scalar): exponential mean discount. lmbda (scalar): trajectory discount. - value_network (TensorDictModule): value operator used to retrieve the value estimates. + value_network (SafeModule): value operator used to retrieve the value estimates. average_rewards (bool): if True, rewards will be standardized before the GAE is computed. gradient_mode (bool): if True, gradients are propagated throught the computation of the value function. Default is `False`. 
@@ -262,7 +262,7 @@ def __init__( self, gamma: Union[float, torch.Tensor], lmbda: float, - value_network: TensorDictModule, + value_network: SafeModule, average_rewards: bool = False, gradient_mode: bool = False, ): diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index 9e38a00f4b0..d7facd322c1 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Type, Union +from tensordict.nn import TensorDictModuleWrapper from tensordict.tensordict import TensorDictBase from torchrl.collectors.collectors import ( @@ -17,7 +18,7 @@ from torchrl.data import MultiStep from torchrl.envs import ParallelEnv from torchrl.envs.common import EnvBase -from torchrl.modules import ProbabilisticTensorDictModule, TensorDictModuleWrapper +from torchrl.modules import SafeProbabilisticModule def sync_async_collector( @@ -248,7 +249,7 @@ def _make_collector( def make_collector_offpolicy( make_env: Callable[[], EnvBase], - actor_model_explore: Union[TensorDictModuleWrapper, ProbabilisticTensorDictModule], + actor_model_explore: Union[TensorDictModuleWrapper, SafeProbabilisticModule], cfg: "DictConfig", # noqa: F821 make_env_kwargs: Optional[Dict] = None, ) -> _DataCollector: @@ -256,7 +257,7 @@ def make_collector_offpolicy( Args: make_env (Callable): environment creator - actor_model_explore (TensorDictModule): Model instance used for evaluation and exploration update + actor_model_explore (SafeModule): Model instance used for evaluation and exploration update cfg (DictConfig): config for creating collector object make_env_kwargs (dict): kwargs for the env creator @@ -312,7 +313,7 @@ def make_collector_offpolicy( def make_collector_onpolicy( make_env: Callable[[], EnvBase], - actor_model_explore: Union[TensorDictModuleWrapper, ProbabilisticTensorDictModule], + actor_model_explore: Union[TensorDictModuleWrapper, SafeProbabilisticModule], cfg: "DictConfig", # noqa: F821 make_env_kwargs: Optional[Dict] = None, ) -> _DataCollector: @@ -320,7 +321,7 @@ def make_collector_onpolicy( Args: make_env (Callable): environment creator - actor_model_explore (TensorDictModule): Model instance used for evaluation and exploration update + actor_model_explore (SafeModule): Model instance used for evaluation and exploration update cfg (DictConfig): config for creating collector object make_env_kwargs (dict): kwargs for the env creator diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index f24771b0288..24742d62ee0 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -24,9 +24,9 @@ ActorValueOperator, NoisyLinear, NormalParamWrapper, - ProbabilisticTensorDictModule, - TensorDictModule, - TensorDictSequential, + SafeModule, + SafeProbabilisticModule, + SafeSequential, ) from torchrl.modules.distributions import ( Delta, @@ -315,7 +315,7 @@ def make_ddpg_actor( actor_net = DdpgMlpActor(**actor_net_default_kwargs) gSDE_state_key = "observation_vector" out_keys = ["param"] - actor_module = TensorDictModule(actor_net, in_keys=in_keys, out_keys=out_keys) + actor_module = SafeModule(actor_net, in_keys=in_keys, out_keys=out_keys) if cfg.gSDE: min = env_specs["action_spec"].space.minimum @@ -325,9 +325,9 @@ def make_ddpg_actor( transform = d.ComposeTransform( transform, d.AffineTransform(loc=(max + min) / 2, scale=(max - min) / 2) ) - actor_module = TensorDictSequential( + 
actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform, learn_sigma=False), in_keys=["param", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -549,7 +549,7 @@ def make_a2c_model( out_features=hidden_features, activate_last_layer=True, ) - common_operator = TensorDictModule( + common_operator = SafeModule( spec=None, module=common_module, in_keys=in_keys_actor, @@ -565,13 +565,13 @@ def make_a2c_model( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) in_keys = ["hidden"] - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys, out_keys=["loc", "scale"] ) else: in_keys = ["hidden"] gSDE_state_key = "hidden" - actor_module = TensorDictModule( + actor_module = SafeModule( policy_net, in_keys=in_keys, out_keys=["action"], # will be overwritten @@ -589,9 +589,9 @@ def make_a2c_model( else: raise RuntimeError("cannot use gSDE with discrete actions") - actor_module = TensorDictSequential( + actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform), in_keys=["action", gSDE_state_key, "_eps_gSD"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -640,13 +640,13 @@ def make_a2c_model( actor_net = NormalParamWrapper( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"] ) else: in_keys = in_keys_actor gSDE_state_key = in_keys_actor[0] - actor_module = TensorDictModule( + actor_module = SafeModule( policy_net, in_keys=in_keys, out_keys=["action"], # will be overwritten @@ -664,9 +664,9 @@ def make_a2c_model( else: raise RuntimeError("cannot use gSDE with discrete actions") - actor_module = TensorDictSequential( + actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform), in_keys=["action", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -838,7 +838,7 @@ def make_ppo_model( out_features=hidden_features, activate_last_layer=True, ) - common_operator = TensorDictModule( + common_operator = SafeModule( spec=None, module=common_module, in_keys=in_keys_actor, @@ -854,13 +854,13 @@ def make_ppo_model( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) in_keys = ["hidden"] - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys, out_keys=["loc", "scale"] ) else: in_keys = ["hidden"] gSDE_state_key = "hidden" - actor_module = TensorDictModule( + actor_module = SafeModule( policy_net, in_keys=in_keys, out_keys=["action"], # will be overwritten @@ -878,9 +878,9 @@ def make_ppo_model( else: raise RuntimeError("cannot use gSDE with discrete actions") - actor_module = TensorDictSequential( + actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform), in_keys=["action", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -929,13 +929,13 @@ def make_ppo_model( actor_net = NormalParamWrapper( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"] ) else: in_keys = in_keys_actor gSDE_state_key = in_keys_actor[0] - actor_module = TensorDictModule( + actor_module = SafeModule( policy_net, 
in_keys=in_keys, out_keys=["action"], # will be overwritten @@ -953,9 +953,9 @@ def make_ppo_model( else: raise RuntimeError("cannot use gSDE with discrete actions") - actor_module = TensorDictSequential( + actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform), in_keys=["action", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -1140,7 +1140,7 @@ def make_sac_model( scale_lb=cfg.scale_lb, ) in_keys_actor = in_keys - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys_actor, out_keys=[ @@ -1151,7 +1151,7 @@ def make_sac_model( else: gSDE_state_key = in_keys[0] - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys, out_keys=["action"], # will be overwritten @@ -1169,9 +1169,9 @@ def make_sac_model( else: raise RuntimeError("cannot use gSDE with discrete actions") - actor_module = TensorDictSequential( + actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform), in_keys=["action", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -1387,14 +1387,14 @@ def make_redq_model( scale_mapping=f"biased_softplus_{default_policy_scale}", scale_lb=cfg.scale_lb, ) - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"] + out_keys_actor[1:], ) else: - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys_actor, out_keys=["action"] + out_keys_actor[1:], # will be overwritten @@ -1412,9 +1412,9 @@ def make_redq_model( else: raise RuntimeError("cannot use gSDE with discrete actions") - actor_module = TensorDictSequential( + actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform), in_keys=["action", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -1555,7 +1555,7 @@ def _dreamer_make_world_model( ): # World Model and reward model rssm_rollout = RSSMRollout( - TensorDictModule( + SafeModule( rssm_prior, in_keys=["state", "belief", "action"], out_keys=[ @@ -1565,7 +1565,7 @@ def _dreamer_make_world_model( ("next", "belief"), ], ), - TensorDictModule( + SafeModule( rssm_posterior, in_keys=[("next", "belief"), ("next", "encoded_latents")], out_keys=[ @@ -1576,20 +1576,20 @@ def _dreamer_make_world_model( ), ) - transition_model = TensorDictSequential( - TensorDictModule( + transition_model = SafeSequential( + SafeModule( obs_encoder, in_keys=[("next", "pixels")], out_keys=[("next", "encoded_latents")], ), rssm_rollout, - TensorDictModule( + SafeModule( obs_decoder, in_keys=[("next", "state"), ("next", "belief")], out_keys=[("next", "reco_pixels")], ), ) - reward_model = TensorDictModule( + reward_model = SafeModule( reward_module, in_keys=[("next", "state"), ("next", "belief")], out_keys=["reward"], @@ -1630,8 +1630,8 @@ def _dreamer_make_actors( def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): - actor_simulator = ProbabilisticTensorDictModule( - TensorDictModule( + actor_simulator = SafeProbabilisticModule( + SafeModule( actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], @@ -1663,13 +1663,13 @@ def _dreamer_make_actor_real( # actor for real world: interacts with states ~ posterior # Out actor differs from the original paper where first they compute prior and posterior and then act on it # but we found that this approach 
worked better. - actor_realworld = TensorDictSequential( - TensorDictModule( + actor_realworld = SafeSequential( + SafeModule( obs_encoder, in_keys=["pixels"], out_keys=["encoded_latents"], ), - TensorDictModule( + SafeModule( rssm_posterior, in_keys=["belief", "encoded_latents"], out_keys=[ @@ -1678,8 +1678,8 @@ def _dreamer_make_actor_real( "state", ], ), - ProbabilisticTensorDictModule( - TensorDictModule( + SafeProbabilisticModule( + SafeModule( actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], @@ -1700,7 +1700,7 @@ def _dreamer_make_actor_real( } ), ), - TensorDictModule( + SafeModule( rssm_prior, in_keys=["state", "belief", action_key], out_keys=[ @@ -1716,7 +1716,7 @@ def _dreamer_make_actor_real( def _dreamer_make_value_model(mlp_num_units, value_key): # actor for simulator: interacts with states ~ prior - value_model = TensorDictModule( + value_model = SafeModule( MLP( out_features=1, depth=3, @@ -1740,7 +1740,7 @@ def _dreamer_make_mbenv( ): # MB environment if use_decoder_in_env: - mb_env_obs_decoder = TensorDictModule( + mb_env_obs_decoder = SafeModule( obs_decoder, in_keys=[("next", "state"), ("next", "belief")], out_keys=[("next", "reco_pixels")], @@ -1748,8 +1748,8 @@ def _dreamer_make_mbenv( else: mb_env_obs_decoder = None - transition_model = TensorDictSequential( - TensorDictModule( + transition_model = SafeSequential( + SafeModule( rssm_prior, in_keys=["state", "belief", "action"], out_keys=[ @@ -1760,7 +1760,7 @@ def _dreamer_make_mbenv( ], ), ) - reward_model = TensorDictModule( + reward_model = SafeModule( reward_module, in_keys=["state", "belief"], out_keys=["reward"], diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 0a1016c2f70..73691357c8d 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -8,13 +8,14 @@ from warnings import warn import torch +from tensordict.nn import TensorDictModuleWrapper from torch import optim from torch.optim.lr_scheduler import CosineAnnealingLR from torchrl.collectors.collectors import _DataCollector from torchrl.data import ReplayBuffer from torchrl.envs.common import EnvBase -from torchrl.modules import reset_noise, TensorDictModule, TensorDictModuleWrapper +from torchrl.modules import reset_noise, SafeModule from torchrl.objectives.common import LossModule from torchrl.objectives.utils import TargetNetUpdater from torchrl.trainers.loggers import Logger @@ -79,9 +80,7 @@ def make_trainer( loss_module: LossModule, recorder: Optional[EnvBase] = None, target_net_updater: Optional[TargetNetUpdater] = None, - policy_exploration: Optional[ - Union[TensorDictModuleWrapper, TensorDictModule] - ] = None, + policy_exploration: Optional[Union[TensorDictModuleWrapper, SafeModule]] = None, replay_buffer: Optional[ReplayBuffer] = None, logger: Optional[Logger] = None, cfg: "DictConfig" = None, # noqa: F821 @@ -113,7 +112,7 @@ def make_trainer( >>> from torchrl.collectors.collectors import SyncDataCollector >>> from torchrl.data import TensorDictReplayBuffer >>> from torchrl.envs.libs.gym import GymEnv - >>> from torchrl.modules import TensorDictModuleWrapper, TensorDictModule, ValueOperator, EGreedyWrapper + >>> from torchrl.modules import TensorDictModuleWrapper, SafeModule, ValueOperator, EGreedyWrapper >>> from torchrl.objectives.common import LossModule >>> from torchrl.objectives.utils import TargetNetUpdater >>> from torchrl.objectives import DDPGLoss @@ -123,7 +122,7 @@ def make_trainer( >>> action_spec = env_proof.action_spec >>> net 
= torch.nn.Linear(env_proof.observation_spec.shape[-1], action_spec.shape[-1]) >>> net_value = torch.nn.Linear(env_proof.observation_spec.shape[-1], 1) # for the purpose of testing - >>> policy = TensorDictModule(action_spec, net, in_keys=["observation"], out_keys=["action"]) + >>> policy = SafeModule(action_spec, net, in_keys=["observation"], out_keys=["action"]) >>> value = ValueOperator(net_value, in_keys=["observation"], out_keys=["state_action_value"]) >>> collector = SyncDataCollector(env_maker, policy, total_frames=100) >>> loss_module = DDPGLoss(policy, value, gamma=0.99) diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 1909df1370a..1b6014ccb3a 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -29,7 +29,7 @@ from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.utils import set_exploration_mode -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule from torchrl.objectives.common import LossModule from torchrl.trainers.loggers import Logger @@ -1036,7 +1036,7 @@ def __init__( record_interval: int, record_frames: int, frame_skip: int, - policy_exploration: TensorDictModule, + policy_exploration: SafeModule, recorder: EnvBase, exploration_mode: str = "random", log_keys: Optional[List[str]] = None,
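Taken together, the renames are mechanical: `TensorDictModule` becomes `SafeModule`, `TensorDictSequential` becomes `SafeSequential`, and `ProbabilisticTensorDictModule` becomes `SafeProbabilisticModule`. A minimal post-rename sketch mirroring the doctests above; the key names, sizes, and the `sample_out_key`/`distribution_class` arguments follow the signatures shown in this diff, so treat it as illustrative rather than definitive:

```python
import torch
from tensordict import TensorDict
from torchrl.data import NdUnboundedContinuousTensorSpec
from torchrl.modules import NormalParamWrapper, SafeModule, TanhNormal
from torchrl.modules.tensordict_module import SafeProbabilisticModule

# Backbone writes the distribution parameters "loc" and "scale"...
net = NormalParamWrapper(torch.nn.Linear(4, 8))
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])

# ...and the probabilistic head samples "action" from them.
policy = SafeProbabilisticModule(
    module=module,
    spec=NdUnboundedContinuousTensorSpec(4),
    dist_in_keys=["loc", "scale"],
    sample_out_key=["action"],
    distribution_class=TanhNormal,
)

td = policy(TensorDict({"observation": torch.randn(3, 4)}, batch_size=[3]))
assert "action" in td.keys()
```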