From e60ce73941099a88243db63cf28c45b1d4e0959f Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Tue, 22 Nov 2022 11:59:18 +0000 Subject: [PATCH 01/14] Migrate TensorDictModule --- torchrl/modules/__init__.py | 1 - torchrl/modules/tensordict_module/__init__.py | 2 +- torchrl/modules/tensordict_module/actors.py | 6 +- torchrl/modules/tensordict_module/common.py | 380 +----------------- .../modules/tensordict_module/exploration.py | 2 +- torchrl/trainers/helpers/collectors.py | 4 +- torchrl/trainers/helpers/trainers.py | 3 +- 7 files changed, 13 insertions(+), 385 deletions(-) diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 8d1cdd8203e..b6fd254ff9f 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -50,7 +50,6 @@ ActorCriticWrapper, DistributionalQValueActor, TensorDictModule, - TensorDictModuleWrapper, EGreedyWrapper, AdditiveGaussianWrapper, OrnsteinUhlenbeckProcessWrapper, diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 47558bcdbc5..56fb5b47423 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -13,7 +13,7 @@ ActorCriticWrapper, DistributionalQValueActor, ) -from .common import TensorDictModule, TensorDictModuleWrapper +from .common import TensorDictModule from .exploration import ( EGreedyWrapper, AdditiveGaussianWrapper, diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 22384bdce53..23fe6621196 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -8,12 +8,10 @@ import torch from torch import nn +from tensordict.nn import TensorDictModuleWrapper from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, TensorSpec from torchrl.modules.models.models import DistributionalDQNnet -from torchrl.modules.tensordict_module.common import ( - TensorDictModule, - 
TensorDictModuleWrapper, -) +from torchrl.modules.tensordict_module.common import TensorDictModule from torchrl.modules.tensordict_module.probabilistic import ( ProbabilisticTensorDictModule, ) diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 81ca1c33f1b..220c24b5b0e 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -7,22 +7,11 @@ import inspect import warnings -from copy import deepcopy -from textwrap import indent -from typing import ( - Any, - Iterable, - List, - Optional, - Sequence, - Type, - Union, -) +from typing import Iterable, Optional, Type, Union import torch from torchrl.data.utils import DEVICE_TYPING -from torchrl.modules import functional_modules _has_functorch = False try: @@ -42,6 +31,7 @@ FunctionalModuleWithBuffers, ) +from tensordict.nn import TensorDictModule as _TensorDictModule from tensordict.tensordict import TensorDictBase from torch import nn, Tensor @@ -99,7 +89,7 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out): ) -class TensorDictModule(nn.Module): +class TensorDictModule(_TensorDictModule): """A TensorDictModule, is a python wrapper around a :obj:`nn.Module` that reads and writes to a TensorDict. 
Args: @@ -191,22 +181,7 @@ def __init__( spec: Optional[TensorSpec] = None, safe: bool = False, ): - - super().__init__() - - if not out_keys: - raise RuntimeError(f"out_keys were not passed to {self.__class__.__name__}") - if not in_keys: - raise RuntimeError(f"in_keys were not passed to {self.__class__.__name__}") - self.out_keys = out_keys - _check_all_str(self.out_keys) - self.in_keys = in_keys - _check_all_str(self.in_keys) - - if "_" in in_keys: - warnings.warn( - 'key "_" is for ignoring output, it should not be used in input keys' - ) + super().__init__(module, in_keys, out_keys) if spec is not None and not isinstance(spec, TensorSpec): raise TypeError("spec must be a TensorSpec subclass") @@ -247,23 +222,6 @@ def __init__( ) self.register_forward_hook(_forward_hook_safe_action) - self.module = module - - @property - def is_functional(self): - if not _has_functorch: - return isinstance( - self.module, - ( - functional_modules.FunctionalModule, - functional_modules.FunctionalModuleWithBuffers, - ), - ) - return isinstance( - self.module, - (functorch.FunctionalModule, functorch.FunctionalModuleWithBuffers), - ) - @property def spec(self) -> CompositeSpec: return self._spec @@ -276,144 +234,6 @@ def spec(self, spec: CompositeSpec) -> None: ) self._spec = spec - def _write_to_tensordict( - self, - tensordict: TensorDictBase, - tensors: List, - tensordict_out: Optional[TensorDictBase] = None, - out_keys: Optional[Iterable[str]] = None, - vmap: Optional[int] = None, - ) -> TensorDictBase: - - if out_keys is None: - out_keys = self.out_keys - if ( - (tensordict_out is None) - and vmap - and (isinstance(vmap, bool) or vmap[-1] is None) - ): - dim = tensors[0].shape[0] - tensordict_out = tensordict.expand(dim, *tensordict.batch_size).contiguous() - elif tensordict_out is None: - tensordict_out = tensordict - for _out_key, _tensor in zip(out_keys, tensors): - if _out_key != "_": - tensordict_out.set(_out_key, _tensor) - return tensordict_out - - def _make_vmap(self, 
buffers, kwargs, n_input): - if "vmap" in kwargs and kwargs["vmap"]: - if not isinstance(kwargs["vmap"], (tuple, bool)): - raise RuntimeError( - "vmap argument must be a boolean or a tuple of dim expensions." - ) - # if vmap is a tuple, we make sure the number of inputs after params and buffers match - if isinstance(kwargs["vmap"], (tuple, list)): - err_msg = f"the vmap argument had {len(kwargs['vmap'])} elements, but the module has {len(self.in_keys)} inputs" - if isinstance( - self.module, - (FunctionalModuleWithBuffers, rlFunctionalModuleWithBuffers), - ): - if len(kwargs["vmap"]) == 3: - _vmap = ( - *kwargs["vmap"][:2], - *[kwargs["vmap"][2]] * len(self.in_keys), - ) - elif len(kwargs["vmap"]) == 2 + len(self.in_keys): - _vmap = kwargs["vmap"] - else: - raise RuntimeError(err_msg) - elif isinstance(self.module, (FunctionalModule, rlFunctionalModule)): - if len(kwargs["vmap"]) == 2: - _vmap = ( - *kwargs["vmap"][:1], - *[kwargs["vmap"][1]] * len(self.in_keys), - ) - elif len(kwargs["vmap"]) == 1 + len(self.in_keys): - _vmap = kwargs["vmap"] - else: - raise RuntimeError(err_msg) - else: - raise TypeError( - f"vmap not compatible with modules of type {type(self.module)}" - ) - else: - _vmap = ( - (0, 0, *(None,) * n_input) - if buffers is not None - else (0, *(None,) * n_input) - ) - return _vmap - - def _call_module( - self, - tensors: Sequence[Tensor], - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ) -> Union[Tensor, Sequence[Tensor]]: - err_msg = "Did not find the {0} keyword argument to be used with the functional module. Check it was passed to the TensorDictModule method." 
- if isinstance( - self.module, - ( - FunctionalModule, - FunctionalModuleWithBuffers, - rlFunctionalModule, - rlFunctionalModuleWithBuffers, - ), - ): - _vmap = self._make_vmap(buffers, kwargs, len(tensors)) - if _vmap: - module = vmap(self.module, _vmap) - else: - module = self.module - - if isinstance(self.module, (FunctionalModule, rlFunctionalModule)): - if params is None: - raise KeyError(err_msg.format("params")) - kwargs_pruned = { - key: item for key, item in kwargs.items() if key not in ("vmap") - } - out = module(params, *tensors, **kwargs_pruned) - return out - - elif isinstance( - self.module, (FunctionalModuleWithBuffers, rlFunctionalModuleWithBuffers) - ): - if params is None: - raise KeyError(err_msg.format("params")) - if buffers is None: - raise KeyError(err_msg.format("buffers")) - - kwargs_pruned = { - key: item for key, item in kwargs.items() if key not in ("vmap") - } - out = module(params, buffers, *tensors, **kwargs_pruned) - return out - else: - out = self.module(*tensors, **kwargs) - return out - - def forward( - self, - tensordict: TensorDictBase, - tensordict_out: Optional[TensorDictBase] = None, - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ) -> TensorDictBase: - tensors = tuple(tensordict.get(in_key, None) for in_key in self.in_keys) - tensors = self._call_module(tensors, params=params, buffers=buffers, **kwargs) - if not isinstance(tensors, tuple): - tensors = (tensors,) - tensordict_out = self._write_to_tensordict( - tensordict, - tensors, - tensordict_out, - vmap=kwargs.get("vmap", False), - ) - return tensordict_out - def random(self, tensordict: TensorDictBase) -> TensorDictBase: """Samples a random element in the target space, irrespective of any input. 
@@ -434,204 +254,12 @@ def random_sample(self, tensordict: TensorDictBase) -> TensorDictBase: """See :obj:`TensorDictModule.random(...)`.""" return self.random(tensordict) - @property - def device(self): - for p in self.parameters(): - return p.device - return torch.device("cpu") - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> TensorDictModule: if hasattr(self, "spec") and self.spec is not None: self.spec = self.spec.to(dest) out = super().to(dest) return out - def __repr__(self) -> str: - fields = indent( - f"module={self.module}, \n" - f"device={self.device}, \n" - f"in_keys={self.in_keys}, \n" - f"out_keys={self.out_keys}", - 4 * " ", - ) - - return f"{self.__class__.__name__}(\n{fields})" - - def make_functional_with_buffers(self, clone: bool = True, native: bool = False): - """Transforms a stateful module in a functional module and returns its parameters and buffers. - - Unlike functorch.make_functional_with_buffers, this method supports lazy modules. - - Args: - clone (bool, optional): if True, a clone of the module is created before it is returned. - This is useful as it prevents the original module to be scraped off of its - parameters and buffers. - Defaults to True - native (bool, optional): if True, TorchRL's functional modules will be used. - Defaults to True - - Returns: - A tuple of parameter and buffer tuples - - Examples: - >>> from tensordict import TensorDict - >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> lazy_module = nn.LazyLinear(4) - >>> spec = NdUnboundedContinuousTensorSpec(18) - >>> td_module = TensorDictModule(lazy_module, spec, ["some_input"], - ... ["some_output"]) - >>> _, (params, buffers) = td_module.make_functional_with_buffers() - >>> print(params[0].shape) # the lazy module has been initialized - torch.Size([4, 18]) - >>> print(td_module( - ... TensorDict({'some_input': torch.randn(18)}, batch_size=[]), - ... params=params, - ... 
buffers=buffers)) - TensorDict( - fields={ - some_input: Tensor(torch.Size([18]), dtype=torch.float32), - some_output: Tensor(torch.Size([4]), dtype=torch.float32)}, - batch_size=torch.Size([]), - device=cpu, - is_shared=False) - - """ - native = native or not _has_functorch - if clone: - self_copy = deepcopy(self) - else: - self_copy = self - - if isinstance( - self_copy.module, - ( - TensorDictModule, - FunctionalModule, - FunctionalModuleWithBuffers, - rlFunctionalModule, - rlFunctionalModuleWithBuffers, - ), - ): - raise RuntimeError( - "TensorDictModule.make_functional_with_buffers requires the " - "module to be a regular nn.Module. " - f"Found type {type(self_copy.module)}" - ) - - # check if there is a non-initialized lazy module - for m in self_copy.module.modules(): - if hasattr(m, "has_uninitialized_params") and m.has_uninitialized_params(): - pseudo_input = self_copy.spec.rand() - self_copy.module(pseudo_input) - break - - module = self_copy.module - if native: - fmodule, params, buffers = rlFunctionalModuleWithBuffers._create_from( - module - ) - else: - fmodule, params, buffers = functorch.make_functional_with_buffers(module) - self_copy.module = fmodule - - # Erase meta params - for _ in fmodule.parameters(): - none_state = [None for _ in params + buffers] - if hasattr(fmodule, "all_names_map"): - # functorch >= 0.2.0 - _swap_state(fmodule.stateless_model, fmodule.all_names_map, none_state) - else: - # functorch < 0.2.0 - _swap_state(fmodule.stateless_model, fmodule.split_names, none_state) - - break - - return self_copy, (params, buffers) - - @property - def num_params(self): - if _has_functorch and isinstance( - self.module, - (functorch.FunctionalModule, functorch.FunctionalModuleWithBuffers), - ): - return len(self.module.param_names) - else: - return 0 - - @property - def num_buffers(self): - if _has_functorch and isinstance( - self.module, (functorch.FunctionalModuleWithBuffers,) - ): - return len(self.module.buffer_names) - else: - return 0 - - 
-class TensorDictModuleWrapper(nn.Module): - """Wrapper calss for TensorDictModule objects. - - Once created, a TensorDictModuleWrapper will behave exactly as the TensorDictModule it contains except for the methods that are - overwritten. - - Args: - td_module (TensorDictModule): operator to be wrapped. - - Examples: - >>> # This class can be used for exploration wrappers - >>> import functorch - >>> import torch - >>> from tensordict import TensorDict - >>> from tensordict.utils import expand_as_right - >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> from torchrl.modules import TensorDictModuleWrapper, TensorDictModule - >>> - >>> class EpsilonGreedyExploration(TensorDictModuleWrapper): - ... eps = 0.5 - ... def forward(self, tensordict, params, buffers): - ... rand_output_clone = self.random(tensordict.clone()) - ... det_output_clone = self.td_module(tensordict.clone(), params, buffers) - ... rand_output_idx = torch.rand(tensordict.shape, device=rand_output_clone.device) < self.eps - ... for key in self.out_keys: - ... _rand_output = rand_output_clone.get(key) - ... _det_output = det_output_clone.get(key) - ... rand_output_idx_expand = expand_as_right(rand_output_idx, _rand_output).to(_rand_output.dtype) - ... tensordict.set(key, - ... rand_output_idx_expand * _rand_output + (1-rand_output_idx_expand) * _det_output) - ... 
return tensordict - >>> - >>> td = TensorDict({"input": torch.zeros(10, 4)}, [10]) - >>> module = torch.nn.Linear(4, 4, bias=False) # should return a zero tensor if input is a zero tensor - >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) - >>> spec = NdUnboundedContinuousTensorSpec(4) - >>> tensordict_module = TensorDictModule(module=fmodule, spec=spec, in_keys=["input"], out_keys=["output"]) - >>> tensordict_module_wrapped = EpsilonGreedyExploration(tensordict_module) - >>> tensordict_module_wrapped(td, params=params, buffers=buffers) - >>> print(td.get("output")) - - """ - - def __init__(self, td_module: TensorDictModule): - super().__init__() - self.td_module = td_module - if len(self.td_module._forward_hooks): - for pre_hook in self.td_module._forward_hooks: - self.register_forward_hook(self.td_module._forward_hooks[pre_hook]) - - def __getattr__(self, name: str) -> Any: - try: - return super().__getattr__(name) - except AttributeError: - if name not in self.__dict__ and not name.startswith("__"): - return getattr(self._modules["td_module"], name) - else: - raise AttributeError( - f"attribute {name} not recognised in {type(self).__name__}" - ) - - def forward(self, *args, **kwargs): - return self.td_module.forward(*args, **kwargs) - def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]): """Returns `True` if a module can be used as a TensorDictModule, and False if it can't. 
diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index a07b13aeccc..ed56b5a05de 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -7,6 +7,7 @@ import numpy as np import torch +from tensordict.nn import TensorDictModuleWrapper from tensordict.tensordict import TensorDictBase from tensordict.utils import expand_as_right @@ -15,7 +16,6 @@ from torchrl.modules.tensordict_module.common import ( _forward_hook_safe_action, TensorDictModule, - TensorDictModuleWrapper, ) diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index 1d72aab2643..11717b5ac67 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -6,8 +6,10 @@ from dataclasses import dataclass, field from typing import Callable, List, Optional, Type, Union, Dict, Any +from tensordict.nn import TensorDictModuleWrapper from tensordict.tensordict import TensorDictBase + from torchrl.collectors.collectors import ( _DataCollector, SyncDataCollector, @@ -17,7 +19,7 @@ from torchrl.data import MultiStep from torchrl.envs import ParallelEnv from torchrl.envs.common import EnvBase -from torchrl.modules import TensorDictModuleWrapper, ProbabilisticTensorDictModule +from torchrl.modules import ProbabilisticTensorDictModule def sync_async_collector( diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 10333ad82f3..36662133de4 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -8,13 +8,14 @@ from warnings import warn import torch +from tensordict.nn import TensorDictModuleWrapper from torch import optim from torch.optim.lr_scheduler import CosineAnnealingLR from torchrl.collectors.collectors import _DataCollector from torchrl.data import ReplayBuffer from torchrl.envs.common import EnvBase -from torchrl.modules import 
TensorDictModule, TensorDictModuleWrapper, reset_noise +from torchrl.modules import TensorDictModule, reset_noise from torchrl.objectives.common import LossModule from torchrl.objectives.utils import TargetNetUpdater from torchrl.trainers.loggers import Logger From 5db5d408f634588b69379481e4a668f13091b973 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Tue, 22 Nov 2022 13:46:50 +0000 Subject: [PATCH 02/14] Migrate functional modules --- test/test_cost.py | 2 +- test/test_functorch.py | 6 +- test/test_modules.py | 8 +- test/test_tensordictmodules.py | 2 +- torchrl/modules/__init__.py | 12 +- torchrl/modules/functional_modules.py | 293 -------------------- torchrl/modules/tensordict_module/common.py | 8 +- torchrl/objectives/common.py | 2 +- 8 files changed, 20 insertions(+), 313 deletions(-) delete mode 100644 torchrl/modules/functional_modules.py diff --git a/test/test_cost.py b/test/test_cost.py index af513a2d55e..fc0521760a4 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6,7 +6,7 @@ import argparse from copy import deepcopy -from torchrl.modules.functional_modules import FunctionalModuleWithBuffers +from tensordict.nn.functional_modules import FunctionalModuleWithBuffers _has_functorch = True try: diff --git a/test/test_functorch.py b/test/test_functorch.py index e84b41a8679..95cea41b97a 100644 --- a/test/test_functorch.py +++ b/test/test_functorch.py @@ -10,12 +10,12 @@ except ImportError: _has_functorch = False from tensordict import TensorDict -from torch import nn -from torchrl.modules import TensorDictModule, TensorDictSequential -from torchrl.modules.functional_modules import ( +from tensordict.nn.functional_modules import ( FunctionalModule, FunctionalModuleWithBuffers, ) +from torch import nn +from torchrl.modules import TensorDictModule, TensorDictSequential @pytest.mark.skipif( diff --git a/test/test_modules.py b/test/test_modules.py index fa5305ac75b..9df0a50bb45 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -11,6 +11,10 
@@ from mocking_classes import MockBatchedUnLockedEnv from packaging import version from tensordict import TensorDict +from tensordict.nn.functional_modules import ( + FunctionalModule, + FunctionalModuleWithBuffers, +) from torch import nn from torchrl.data.tensor_specs import ( DiscreteTensorSpec, @@ -26,10 +30,6 @@ TensorDictModule, ValueOperator, ) -from torchrl.modules.functional_modules import ( - FunctionalModule, - FunctionalModuleWithBuffers, -) from torchrl.modules.models import ConvNet, MLP, NoisyLazyLinear, NoisyLinear from torchrl.modules.models.model_based import ( DreamerActor, diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 6910ec2b497..5009461aeb2 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -15,7 +15,7 @@ _has_functorch = True except ImportError: - from torchrl.modules.functional_modules import ( + from tensordict.nn.functional_modules import ( FunctionalModule, FunctionalModuleWithBuffers, ) diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index b6fd254ff9f..4ee60f359e9 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -13,12 +13,12 @@ OneHotCategorical, distributions_maps, ) -from .functional_modules import ( - FunctionalModule, - FunctionalModuleWithBuffers, - extract_weights, - extract_buffers, -) +# from .functional_modules import ( +# FunctionalModule, +# FunctionalModuleWithBuffers, +# extract_weights, +# extract_buffers, +# ) from .models import ( NoisyLinear, NoisyLazyLinear, diff --git a/torchrl/modules/functional_modules.py b/torchrl/modules/functional_modules.py deleted file mode 100644 index 7e7160e0fef..00000000000 --- a/torchrl/modules/functional_modules.py +++ /dev/null @@ -1,293 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -from copy import deepcopy - -import torch -from tensordict import TensorDict -from tensordict.tensordict import TensorDictBase -from torch import nn - -_RESET_OLD_TENSORDICT = True -try: - import functorch._src.vmap - - _has_functorch = True -except ImportError: - _has_functorch = False - -# Monky-patch functorch, mainly for cases where a "isinstance(obj, Tensor) is invoked -if _has_functorch: - from functorch._src.vmap import ( - _get_name, - tree_flatten, - _broadcast_to_and_flatten, - Tensor, - _validate_and_get_batch_size, - _add_batch_dim, - tree_unflatten, - _remove_batch_dim, - ) - - # Monkey-patches - - def _process_batched_inputs(in_dims, args, func): - if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): - raise ValueError( - f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): -expected `in_dims` to be int or a (potentially nested) tuple -matching the structure of inputs, got: {type(in_dims)}.""" - ) - if len(args) == 0: - raise ValueError( - f"""vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add -inputs, or you are trying to vmap over a function with no inputs. -The latter is unsupported.""" - ) - - flat_args, args_spec = tree_flatten(args) - flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec) - if flat_in_dims is None: - raise ValueError( - f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): -in_dims is not compatible with the structure of `inputs`. 
-in_dims has structure {tree_flatten(in_dims)[1]} but inputs -has structure {args_spec}.""" - ) - - for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)): - if not isinstance(in_dim, int) and in_dim is not None: - raise ValueError( - f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): -Got in_dim={in_dim} for an input but in_dim must be either -an integer dimension or None.""" - ) - if isinstance(in_dim, int) and not isinstance( - arg, (Tensor, TensorDictBase) - ): - raise ValueError( - f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): -Got in_dim={in_dim} for an input but the input is of type -{type(arg)}. We cannot vmap over non-Tensor arguments, -please use None as the respective in_dim""" - ) - if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()): - raise ValueError( - f"""vmap({_get_name(func)}, in_dims={in_dims}, ...)(): -Got in_dim={in_dim} for some input, but that input is a Tensor -of dimensionality {arg.dim()} so expected in_dim to satisfy --{arg.dim()} <= in_dim < {arg.dim()}.""" - ) - if in_dim is not None and in_dim < 0: - flat_in_dims[i] = in_dim % arg.dim() - - return ( - _validate_and_get_batch_size(flat_in_dims, flat_args), - flat_in_dims, - flat_args, - args_spec, - ) - - functorch._src.vmap._process_batched_inputs = _process_batched_inputs - - def _create_batched_inputs(flat_in_dims, flat_args, vmap_level: int, args_spec): - # See NOTE [Ignored _remove_batch_dim, _add_batch_dim] - # If tensordict, we remove the dim at batch_size[in_dim] such that the TensorDict can accept - # the batched tensors. 
This will be added in _unwrap_batched - batched_inputs = [ - arg - if in_dim is None - else arg.apply( - lambda _arg: _add_batch_dim(_arg, in_dim, vmap_level), - batch_size=[b for i, b in enumerate(arg.batch_size) if i != in_dim], - ) - if isinstance(arg, TensorDictBase) - else _add_batch_dim(arg, in_dim, vmap_level) - for in_dim, arg in zip(flat_in_dims, flat_args) - ] - return tree_unflatten(batched_inputs, args_spec) - - functorch._src.vmap._create_batched_inputs = _create_batched_inputs - - def _unwrap_batched( - batched_outputs, out_dims, vmap_level: int, batch_size: int, func - ): - flat_batched_outputs, output_spec = tree_flatten(batched_outputs) - - for out in flat_batched_outputs: - # Change here: - if isinstance(out, (TensorDictBase, torch.Tensor)): - continue - raise ValueError( - f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return " - f"Tensors, got type {type(out)} as a return." - ) - - def incompatible_error(): - raise ValueError( - f"vmap({_get_name(func)}, ..., out_dims={out_dims})(): " - f"out_dims is not compatible with the structure of `outputs`. " - f"out_dims has structure {tree_flatten(out_dims)[1]} but outputs " - f"has structure {output_spec}." 
- ) - - # Here: - if isinstance(batched_outputs, (TensorDictBase, torch.Tensor)): - # Some weird edge case requires us to spell out the following - # see test_out_dims_edge_case - if isinstance(out_dims, int): - flat_out_dims = [out_dims] - elif isinstance(out_dims, tuple) and len(out_dims) == 1: - flat_out_dims = out_dims - out_dims = out_dims[0] - else: - incompatible_error() - else: - flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec) - if flat_out_dims is None: - incompatible_error() - - flat_outputs = [] - for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims): - if not isinstance(batched_output, TensorDictBase): - out = _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim) - else: - out = batched_output.apply( - lambda x: _remove_batch_dim(x, vmap_level, batch_size, out_dim), - batch_size=[batch_size, *batched_output.batch_size], - ) - flat_outputs.append(out) - return tree_unflatten(flat_outputs, output_spec) - - functorch._src.vmap._unwrap_batched = _unwrap_batched - -# Tensordict-compatible Functional modules - - -class FunctionalModule(nn.Module): - """This is the callable object returned by :func:`make_functional`.""" - - def __init__(self, stateless_model): - super(FunctionalModule, self).__init__() - self.stateless_model = stateless_model - - @staticmethod - def _create_from(model, disable_autograd_tracking=False): - # TODO: We don't need to copy the model to create a stateless copy - model_copy = deepcopy(model) - param_tensordict = extract_weights(model_copy) - if disable_autograd_tracking: - param_tensordict.apply(lambda x: x.requires_grad_(False), inplace=True) - return FunctionalModule(model_copy), param_tensordict - - def forward(self, params, *args, **kwargs): - # Temporarily load the state back onto self.stateless_model - old_state = _swap_state( - self.stateless_model, params, return_old_tensordict=_RESET_OLD_TENSORDICT - ) - try: - return self.stateless_model(*args, **kwargs) - finally: - # Remove 
the loaded state on self.stateless_model - if _RESET_OLD_TENSORDICT: - _swap_state(self.stateless_model, old_state) - - -class FunctionalModuleWithBuffers(nn.Module): - """This is the callable object returned by :func:`make_functional`.""" - - def __init__(self, stateless_model): - super(FunctionalModuleWithBuffers, self).__init__() - self.stateless_model = stateless_model - - @staticmethod - def _create_from(model, disable_autograd_tracking=False): - # TODO: We don't need to copy the model to create a stateless copy - model_copy = deepcopy(model) - param_tensordict = extract_weights(model_copy) - buffers = extract_buffers(model_copy) - if buffers is None: - buffers = TensorDict( - {}, param_tensordict.batch_size, device=param_tensordict.device - ) - if disable_autograd_tracking: - param_tensordict.apply(lambda x: x.requires_grad_(False), inplace=True) - return FunctionalModuleWithBuffers(model_copy), param_tensordict, buffers - - def forward(self, params, buffers, *args, **kwargs): - # Temporarily load the state back onto self.stateless_model - old_state = _swap_state( - self.stateless_model, params, return_old_tensordict=_RESET_OLD_TENSORDICT - ) - old_state_buffers = _swap_state( - self.stateless_model, buffers, return_old_tensordict=_RESET_OLD_TENSORDICT - ) - - try: - return self.stateless_model(*args, **kwargs) - finally: - # Remove the loaded state on self.stateless_model - if _RESET_OLD_TENSORDICT: - _swap_state(self.stateless_model, old_state) - _swap_state(self.stateless_model, old_state_buffers) - - -# Some utils for these - - -def extract_weights(model: nn.Module): - """Extracts the weights of a model in a tensordict.""" - tensordict = TensorDict({}, []) - for name, param in list(model.named_parameters(recurse=False)): - setattr(model, name, None) - tensordict[name] = param - for name, module in model.named_children(): - module_tensordict = extract_weights(module) - if module_tensordict is not None: - tensordict[name] = module_tensordict - if 
len(tensordict.keys()): - return tensordict - else: - return None - - -def extract_buffers(model: nn.Module): - """Extracts the buffers of a model in a tensordict.""" - tensordict = TensorDict({}, []) - for name, param in list(model.named_buffers(recurse=False)): - setattr(model, name, None) - tensordict[name] = param - for name, module in model.named_children(): - module_tensordict = extract_buffers(module) - if module_tensordict is not None: - tensordict[name] = module_tensordict - if len(tensordict.keys()): - return tensordict - else: - return None - - -def _swap_state(model, tensordict, return_old_tensordict=False): - # if return_old_tensordict: - # old_tensordict = tensordict.clone(recurse=False) - # old_tensordict.batch_size = [] - - if return_old_tensordict: - old_tensordict = TensorDict({}, [], device=tensordict.device) - - for key, value in list(tensordict.items()): - if isinstance(value, TensorDictBase): - _swap_state(getattr(model, key), value) - else: - if return_old_tensordict: - old_attr = getattr(model, key) - if old_attr is None: - old_attr = torch.tensor([]).view(*value.shape, 0) - delattr(model, key) - setattr(model, key, value) - if return_old_tensordict: - old_tensordict.set(key, old_attr) - if return_old_tensordict: - return old_tensordict diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 220c24b5b0e..0f1980526f5 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -26,7 +26,7 @@ "functional programming should work, but functionality and performance " "may be affected. Consider installing functorch and/or upgrating pytorch." 
) - from torchrl.modules.functional_modules import ( + from tensordict.nn.functional_modules import ( FunctionalModule, FunctionalModuleWithBuffers, ) @@ -39,9 +39,9 @@ TensorSpec, CompositeSpec, ) -from torchrl.modules.functional_modules import ( - FunctionalModule as rlFunctionalModule, - FunctionalModuleWithBuffers as rlFunctionalModuleWithBuffers, +from tensordict.nn.functional_modules import ( + FunctionalModule as tdFunctionalModule, + FunctionalModuleWithBuffers as tdFunctionalModuleWithBuffers, ) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 7db260108da..727fbf20abf 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -9,7 +9,7 @@ import torch -from torchrl.modules.functional_modules import FunctionalModuleWithBuffers +from tensordict.nn.functional_modules import FunctionalModuleWithBuffers _has_functorch = False try: From d318a5802376b05452d0b22278528a2cf1b1ce82 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Tue, 22 Nov 2022 14:13:52 +0000 Subject: [PATCH 03/14] Migrate probabilistic modules --- .../tensordict_module/probabilistic.py | 230 ++---------------- 1 file changed, 22 insertions(+), 208 deletions(-) diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 0e33039f89d..44b19bb7c15 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -3,22 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import re -from copy import deepcopy -from textwrap import indent -from typing import List, Sequence, Union, Type, Optional, Tuple +from typing import Sequence, Union, Type, Optional -from tensordict.tensordict import TensorDictBase -from torch import Tensor -from torch import distributions as d +from tensordict.nn import ( + ProbabilisticTensorDictModule as _ProbabilisticTensorDictModule, +) from torchrl.data import TensorSpec -from torchrl.envs.utils import exploration_mode, set_exploration_mode -from torchrl.modules.distributions import distributions_maps, Delta -from torchrl.modules.tensordict_module.common import TensorDictModule, _check_all_str +from torchrl.modules.distributions import Delta +from torchrl.modules.tensordict_module.common import TensorDictModule -class ProbabilisticTensorDictModule(TensorDictModule): +class ProbabilisticTensorDictModule(_ProbabilisticTensorDictModule, TensorDictModule): """A probabilistic TD Module. `ProbabilisticTDModule` is a special case of a TDModule where the output is @@ -152,203 +148,21 @@ def __init__( cache_dist: bool = False, n_empirical_estimate: int = 1000, ): - in_keys = module.in_keys - - # if the module returns the sampled key we wont be sampling it again - # then ProbabilisticTensorDictModule is presumably used to return the distribution using `get_dist` - if isinstance(dist_in_keys, str): - dist_in_keys = [dist_in_keys] - if isinstance(sample_out_key, str): - sample_out_key = [sample_out_key] - if not isinstance(dist_in_keys, dict): - dist_in_keys = {param_key: param_key for param_key in dist_in_keys} - for key in dist_in_keys.values(): - if key not in module.out_keys: - raise RuntimeError( - f"The key {key} could not be found in the wrapped module `{type(module)}.out_keys`." 
- ) - module_out_keys = module.out_keys - self.sample_out_key = sample_out_key - _check_all_str(self.sample_out_key) - sample_out_key = [key for key in sample_out_key if key not in module_out_keys] - self._requires_sample = bool(len(sample_out_key)) - out_keys = sample_out_key + module_out_keys super().__init__( - module=module, spec=spec, in_keys=in_keys, out_keys=out_keys, safe=safe - ) - self.dist_in_keys = dist_in_keys - _check_all_str(self.dist_in_keys.keys()) - _check_all_str(self.dist_in_keys.values()) - - self.default_interaction_mode = default_interaction_mode - if isinstance(distribution_class, str): - distribution_class = distributions_maps.get(distribution_class.lower()) - self.distribution_class = distribution_class - self.distribution_kwargs = ( - distribution_kwargs if distribution_kwargs is not None else dict() + module=module, + dist_in_keys=dist_in_keys, + sample_out_key=sample_out_key, + default_interaction_mode=default_interaction_mode, + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=return_log_prob, + cache_dist=cache_dist, + n_empirical_estimate=n_empirical_estimate, ) - self.n_empirical_estimate = n_empirical_estimate - self._dist = None - self.cache_dist = cache_dist if hasattr(distribution_class, "update") else False - self.return_log_prob = return_log_prob - - def _call_module( - self, - tensordict: TensorDictBase, - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ) -> TensorDictBase: - return self.module(tensordict, params=params, buffers=buffers, **kwargs) - - def make_functional_with_buffers(self, clone: bool = True, native: bool = False): - module_params = self.parameters(recurse=False) - if len(list(module_params)): - raise RuntimeError( - "make_functional_with_buffers cannot be called on ProbabilisticTensorDictModule" - "that contain parameters on the outer level." 
- ) - if clone: - self_copy = deepcopy(self) - else: - self_copy = self - - self_copy.module, other = self_copy.module.make_functional_with_buffers( - clone=True, - native=native, - ) - return self_copy, other - - def get_dist( - self, - tensordict: TensorDictBase, - tensordict_out: Optional[TensorDictBase] = None, - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ) -> Tuple[d.Distribution, TensorDictBase]: - interaction_mode = exploration_mode() - if interaction_mode is None: - interaction_mode = self.default_interaction_mode - with set_exploration_mode(interaction_mode): - tensordict_out = self._call_module( - tensordict, - tensordict_out=tensordict_out, - params=params, - buffers=buffers, - **kwargs, - ) - dist = self.build_dist_from_params(tensordict_out) - return dist, tensordict_out - - def build_dist_from_params(self, tensordict_out: TensorDictBase) -> d.Distribution: - try: - selected_td_out = tensordict_out.select(*self.dist_in_keys.values()) - dist_kwargs = { - dist_key: selected_td_out[td_key] - for dist_key, td_key in self.dist_in_keys.items() - } - dist = self.distribution_class(**dist_kwargs) - except TypeError as err: - if "an unexpected keyword argument" in str(err): - raise TypeError( - "distribution keywords and tensordict keys indicated by ProbabilisticTensorDictModule.dist_in_keys must match." - f"Got this error message: \n{indent(str(err), 4 * ' ')}\nwith dist_in_keys={self.dist_in_keys}" - ) - elif re.search(r"missing.*required positional arguments", str(err)): - raise TypeError( - f"TensorDict with keys {tensordict_out.keys()} does not match the distribution {self.distribution_class} keywords." 
- ) - else: - raise err - return dist - - def forward( - self, - tensordict: TensorDictBase, - tensordict_out: Optional[TensorDictBase] = None, - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ) -> TensorDictBase: - - dist, tensordict_out = self.get_dist( - tensordict, - tensordict_out=tensordict_out, - params=params, - buffers=buffers, - **kwargs, + super(_ProbabilisticTensorDictModule, self).__init__( + module=module, + spec=spec, + in_keys=self.in_keys, + out_keys=self.out_keys, + safe=safe, ) - if self._requires_sample: - out_tensors = self._dist_sample(dist, interaction_mode=exploration_mode()) - if isinstance(out_tensors, Tensor): - out_tensors = (out_tensors,) - tensordict_out.update( - {key: value for key, value in zip(self.sample_out_key, out_tensors)} - ) - if self.return_log_prob: - log_prob = dist.log_prob(*out_tensors) - tensordict_out.set("sample_log_prob", log_prob) - elif self.return_log_prob: - out_tensors = [tensordict_out.get(key) for key in self.sample_out_key] - log_prob = dist.log_prob(*out_tensors) - tensordict_out.set("sample_log_prob", log_prob) - # raise RuntimeError( - # "ProbabilisticTensorDictModule.return_log_prob = True is incompatible with settings in which " - # "the submodule is responsible for sampling. 
To manually gather the log-probability, call first " - # "\n>>> dist, tensordict = tensordict_module.get_dist(tensordict)" - # "\n>>> tensordict.set('sample_log_prob', dist.log_prob(tensordict.get(sample_key))" - # ) - return tensordict_out - - def _dist_sample( - self, - dist: d.Distribution, - *tensors: Tensor, - interaction_mode: bool = None, - ) -> Union[Tuple[Tensor], Tensor]: - if interaction_mode is None or interaction_mode == "": - interaction_mode = self.default_interaction_mode - if not isinstance(dist, d.Distribution): - raise TypeError(f"type {type(dist)} not recognised by _dist_sample") - - if interaction_mode == "mode": - if hasattr(dist, "mode"): - return dist.mode - else: - raise NotImplementedError( - f"method {type(dist)}.mode is not implemented" - ) - - elif interaction_mode == "median": - if hasattr(dist, "median"): - return dist.median - else: - raise NotImplementedError( - f"method {type(dist)}.median is not implemented" - ) - - elif interaction_mode == "mean": - try: - return dist.mean - except (AttributeError, NotImplementedError): - if dist.has_rsample: - return dist.rsample((self.n_empirical_estimate,)).mean(0) - else: - return dist.sample((self.n_empirical_estimate,)).mean(0) - - elif interaction_mode == "random": - if dist.has_rsample: - return dist.rsample() - else: - return dist.sample() - else: - raise NotImplementedError(f"unknown interaction_mode {interaction_mode}") - - @property - def num_params(self): - return self.module.num_params - - @property - def num_buffers(self): - return self.module.num_buffers From c79714d69d5126fecb34f43d174c729949872867 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Tue, 22 Nov 2022 14:21:24 +0000 Subject: [PATCH 04/14] Patch set_exploration_mode --- torchrl/envs/utils.py | 42 ++++-------------------------------------- 1 file changed, 4 insertions(+), 38 deletions(-) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 36a5cd8df36..90f4cc72bad 100644 --- a/torchrl/envs/utils.py +++ 
b/torchrl/envs/utils.py @@ -3,11 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Union - import pkg_resources from tensordict.tensordict import TensorDictBase -from torch.autograd.grad_mode import _DecoratorContextManager +from tensordict.nn.probabilistic import ( # noqa + set_interaction_mode as set_exploration_mode, + interaction_mode as exploration_mode, +) AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set} @@ -150,38 +151,3 @@ def _check_dmlab(): # "screeps": None, # https://github.com/screeps/screeps # "ml-agents": None, } - -EXPLORATION_MODE = None - - -class set_exploration_mode(_DecoratorContextManager): - """Sets the exploration mode of all ProbabilisticTDModules to the desired mode. - - Args: - mode (str): mode to use when the policy is being called. - - Examples: - >>> policy = Actor(action_spec, module=network, default_interaction_mode="mode") - >>> env.rollout(policy=policy, max_steps=100) # rollout with the "mode" interaction mode - >>> with set_exploration_mode("random"): - >>> env.rollout(policy=policy, max_steps=100) # rollout with the "random" interaction mode - - """ - - def __init__(self, mode: str = "mode"): - super().__init__() - self.mode = mode - - def __enter__(self) -> None: - global EXPLORATION_MODE - self.prev = EXPLORATION_MODE - EXPLORATION_MODE = self.mode - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - global EXPLORATION_MODE - EXPLORATION_MODE = self.prev - - -def exploration_mode() -> Union[str, None]: - """Returns the exploration mode currently set.""" - return EXPLORATION_MODE From b5c9f1d53f1a72aa0c443f6812268458037b61fa Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Tue, 22 Nov 2022 14:25:49 +0000 Subject: [PATCH 05/14] Lint and format --- torchrl/envs/utils.py | 2 +- torchrl/modules/__init__.py | 1 + torchrl/modules/tensordict_module/actors.py | 2 +- 
torchrl/modules/tensordict_module/common.py | 17 ++++------------- torchrl/objectives/common.py | 1 - torchrl/trainers/helpers/collectors.py | 1 - 6 files changed, 7 insertions(+), 17 deletions(-) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 90f4cc72bad..d8299d02a6e 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -4,11 +4,11 @@ # LICENSE file in the root directory of this source tree. import pkg_resources -from tensordict.tensordict import TensorDictBase from tensordict.nn.probabilistic import ( # noqa set_interaction_mode as set_exploration_mode, interaction_mode as exploration_mode, ) +from tensordict.tensordict import TensorDictBase AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set} diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 4ee60f359e9..c8e89ccb8fe 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -13,6 +13,7 @@ OneHotCategorical, distributions_maps, ) + # from .functional_modules import ( # FunctionalModule, # FunctionalModuleWithBuffers, diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 23fe6621196..76f553829cd 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -6,9 +6,9 @@ from typing import Optional, Sequence, Tuple, Union import torch +from tensordict.nn import TensorDictModuleWrapper from torch import nn -from tensordict.nn import TensorDictModuleWrapper from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, TensorSpec from torchrl.modules.models.models import DistributionalDQNnet from torchrl.modules.tensordict_module.common import TensorDictModule diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 0f1980526f5..7a92ef614d3 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -15,9 +15,7 @@ 
_has_functorch = False try: - import functorch - from functorch import FunctionalModule, FunctionalModuleWithBuffers, vmap - from functorch._src.make_functional import _swap_state + from functorch import FunctionalModule, FunctionalModuleWithBuffers _has_functorch = True except ImportError: @@ -33,16 +31,9 @@ from tensordict.nn import TensorDictModule as _TensorDictModule from tensordict.tensordict import TensorDictBase -from torch import nn, Tensor - -from torchrl.data import ( - TensorSpec, - CompositeSpec, -) -from tensordict.nn.functional_modules import ( - FunctionalModule as tdFunctionalModule, - FunctionalModuleWithBuffers as tdFunctionalModuleWithBuffers, -) +from torch import nn + +from torchrl.data import CompositeSpec, TensorSpec def _check_all_str(list_of_str, first_level=True): diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 727fbf20abf..e8196bb29ea 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -8,7 +8,6 @@ from typing import Iterator, Optional, Tuple, List, Union import torch - from tensordict.nn.functional_modules import FunctionalModuleWithBuffers _has_functorch = False diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index 11717b5ac67..9786ecab5f0 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -9,7 +9,6 @@ from tensordict.nn import TensorDictModuleWrapper from tensordict.tensordict import TensorDictBase - from torchrl.collectors.collectors import ( _DataCollector, SyncDataCollector, From d6656edadad1b90b64da83dfd982f2cd34331a47 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Tue, 22 Nov 2022 15:40:47 +0000 Subject: [PATCH 06/14] Migrate sequential modules --- torchrl/modules/tensordict_module/sequence.py | 378 +----------------- 1 file changed, 10 insertions(+), 368 deletions(-) diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index 
a1f3b96f8f2..fb0dd94add8 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -5,38 +5,16 @@ from __future__ import annotations -from copy import copy, deepcopy -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, Union -_has_functorch = False -try: - import functorch - - _has_functorch = True -except ImportError: - print( - "failed to import functorch. TorchRL's features that do not require " - "functional programming should work, but functionality and performance " - "may be affected. Consider installing functorch and/or upgrating pytorch." - ) - FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality." - -import torch -from tensordict.tensordict import ( - LazyStackedTensorDict, - TensorDict, - TensorDictBase, -) -from torch import Tensor, nn +from tensordict.nn import TensorDictSequential as _TensorDictSequential +from torch import nn from torchrl.data import CompositeSpec from torchrl.modules.tensordict_module.common import TensorDictModule -from torchrl.modules.tensordict_module.probabilistic import ( - ProbabilisticTensorDictModule, -) -class TensorDictSequential(TensorDictModule): +class TensorDictSequential(_TensorDictSequential, TensorDictModule): """A sequence of TensorDictModules. 
Similarly to :obj:`nn.Sequence` which passes a tensor through a chain of mappings that read and write a single tensor @@ -135,6 +113,8 @@ def __init__( *modules: TensorDictModule, partial_tolerant: bool = False, ): + self.partial_tolerant = partial_tolerant + in_keys, out_keys = self._compute_in_and_out_keys(modules) spec = CompositeSpec() @@ -143,90 +123,14 @@ def __init__( spec.update(module.spec) else: spec.update(CompositeSpec({key: None for key in module.out_keys})) - super().__init__( + + super(_TensorDictSequential, self).__init__( spec=spec, module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys, ) - self.partial_tolerant = partial_tolerant - - def _compute_in_and_out_keys(self, modules: List[TensorDictModule]) -> Tuple[List]: - in_keys = [] - out_keys = [] - for module in modules: - # we sometimes use in_keys to select keys of a tensordict that are - # necessary to run a TensorDictModule. If a key is an intermediary in - # the chain, there is no reason why it should belong to the input - # TensorDict. 
- for in_key in module.in_keys: - if in_key not in (out_keys + in_keys): - in_keys.append(in_key) - out_keys += module.out_keys - - out_keys = [ - out_key - for i, out_key in enumerate(out_keys) - if out_key not in out_keys[i + 1 :] - ] - return in_keys, out_keys - - @staticmethod - def _find_functional_module(module: TensorDictModule) -> nn.Module: - if not _has_functorch: - raise ImportError(FUNCTORCH_ERROR) - fmodule = module - while not isinstance( - fmodule, (functorch.FunctionalModule, functorch.FunctionalModuleWithBuffers) - ): - try: - fmodule = fmodule.module - except AttributeError: - raise AttributeError( - f"couldn't find a functional module in module of type {type(module)}" - ) - return fmodule - - @property - def num_params(self): - return self.param_len[-1] - - @property - def num_buffers(self): - return self.buffer_len[-1] - - @property - def param_len(self) -> List[int]: - param_list = [] - prev = 0 - for module in self.module: - param_list.append(module.num_params + prev) - prev = param_list[-1] - return param_list - - @property - def buffer_len(self) -> List[int]: - buffer_list = [] - prev = 0 - for module in self.module: - buffer_list.append(module.num_buffers + prev) - prev = buffer_list[-1] - return buffer_list - - def _split_param( - self, param_list: Iterable[Tensor], params_or_buffers: str - ) -> Iterable[Iterable[Tensor]]: - if params_or_buffers == "params": - list_out = self.param_len - elif params_or_buffers == "buffers": - list_out = self.buffer_len - list_in = [0] + list_out[:-1] - out = [] - for a, b in zip(list_in, list_out): - out.append(param_list[a:b]) - return out - def select_subsequence( self, in_keys: Iterable[str] = None, out_keys: Iterable[str] = None ) -> "TensorDictSequential": @@ -239,273 +143,11 @@ def select_subsequence( Returns: A new TensorDictSequential with only the modules that are necessary acording to the given input and output keys. 
""" - if in_keys is None: - in_keys = deepcopy(self.in_keys) - if out_keys is None: - out_keys = deepcopy(self.out_keys) - id_to_keep = {i for i in range(len(self.module))} - for i, module in enumerate(self.module): - if all(key in in_keys for key in module.in_keys): - in_keys.extend(module.out_keys) - else: - id_to_keep.remove(i) - for i, module in reversed(list(enumerate(self.module))): - if i in id_to_keep: - if any(key in out_keys for key in module.out_keys): - out_keys.extend(module.in_keys) - else: - id_to_keep.remove(i) - id_to_keep = sorted(list(id_to_keep)) - - modules = [self.module[i] for i in id_to_keep] - - if modules == []: - raise ValueError( - "No modules left after selection. Make sure that in_keys and out_keys are coherent." - ) - - return TensorDictSequential(*modules) - - def _run_module( - self, - module, - tensordict, - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ): - tensordict_keys = set(tensordict.keys()) - if not self.partial_tolerant or all( - key in tensordict_keys for key in module.in_keys - ): - if params is not None or buffers is not None: - tensordict = module( - tensordict, params=params, buffers=buffers, **kwargs - ) - else: - tensordict = module(tensordict, **kwargs) - elif self.partial_tolerant and isinstance(tensordict, LazyStackedTensorDict): - for sub_td in tensordict.tensordicts: - tensordict_keys = set(sub_td.keys()) - if all(key in tensordict_keys for key in module.in_keys): - if params is not None or buffers is not None: - module(sub_td, params=params, buffers=buffers, **kwargs) - else: - module(sub_td, **kwargs) - tensordict._update_valid_keys() - return tensordict - - def forward( - self, - tensordict: TensorDictBase, - tensordict_out=None, - params: Optional[Union[TensorDictBase, List[Tensor]]] = None, - buffers: Optional[Union[TensorDictBase, List[Tensor]]] = None, - **kwargs, - ) -> TensorDictBase: - if params is not None 
and buffers is not None: - if isinstance(params, TensorDictBase): - # TODO: implement sorted values and items - param_splits = list(zip(*sorted(list(params.items()))))[1] - buffer_splits = list(zip(*sorted(list(buffers.items()))))[1] - else: - param_splits = self._split_param(params, "params") - buffer_splits = self._split_param(buffers, "buffers") - for i, (module, param, buffer) in enumerate( - zip(self.module, param_splits, buffer_splits) - ): - if "vmap" in kwargs and i > 0: - # the tensordict is already expended - if not isinstance(kwargs["vmap"], tuple): - kwargs["vmap"] = (0, 0, *(0,) * len(module.in_keys)) - else: - kwargs["vmap"] = ( - *kwargs["vmap"][:2], - *(0,) * len(module.in_keys), - ) - tensordict = self._run_module( - module, tensordict, params=param, buffers=buffer, **kwargs - ) - - elif params is not None: - if isinstance(params, TensorDictBase): - # TODO: implement sorted values and items - param_splits = list(zip(*sorted(list(params.items()))))[1] - else: - param_splits = self._split_param(params, "params") - for i, (module, param) in enumerate(zip(self.module, param_splits)): - if "vmap" in kwargs and i > 0: - # the tensordict is already expended - if not isinstance(kwargs["vmap"], tuple): - kwargs["vmap"] = (0, *(0,) * len(module.in_keys)) - else: - kwargs["vmap"] = ( - *kwargs["vmap"][:1], - *(0,) * len(module.in_keys), - ) - tensordict = self._run_module( - module, tensordict, params=param, **kwargs - ) - - elif not len(kwargs): - for module in self.module: - tensordict = self._run_module(module, tensordict, **kwargs) - else: - raise RuntimeError( - "TensorDictSequential does not support keyword arguments other than 'tensordict_out', 'in_keys', 'out_keys' 'params', 'buffers' and 'vmap'" - ) - if tensordict_out is not None: - tensordict_out.update(tensordict, inplace=True) - return tensordict_out - return tensordict - - def __len__(self): - return len(self.module) + td_sequential = super().select_subsequence(in_keys=in_keys, 
out_keys=out_keys) + return TensorDictSequential(*td_sequential.module) def __getitem__(self, index: Union[int, slice]) -> TensorDictModule: if isinstance(index, int): return self.module.__getitem__(index) else: return TensorDictSequential(*self.module.__getitem__(index)) - - def __setitem__(self, index: int, tensordict_module: TensorDictModule) -> None: - return self.module.__setitem__(idx=index, module=tensordict_module) - - def __delitem__(self, index: Union[int, slice]) -> None: - self.module.__delitem__(idx=index) - - def make_functional_with_buffers(self, clone: bool = True, native: bool = False): - """Transforms a stateful module in a functional module and returns its parameters and buffers. - - Unlike functorch.make_functional_with_buffers, this method supports lazy modules. - - Args: - clone (bool, optional): if True, a clone of the module is created before it is returned. - This is useful as it prevents the original module to be scraped off of its - parameters and buffers. - Defaults to True - native (bool, optional): if True, TorchRL's functional modules will be used. - Defaults to True - - Returns: - A tuple of parameter and buffer tuples - - Examples: - >>> from tensordict import TensorDict - >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> lazy_module1 = nn.LazyLinear(4) - >>> lazy_module2 = nn.LazyLinear(3) - >>> spec1 = NdUnboundedContinuousTensorSpec(18) - >>> spec2 = NdUnboundedContinuousTensorSpec(4) - >>> td_module1 = TensorDictModule(spec=spec1, module=lazy_module1, in_keys=["some_input"], out_keys=["hidden"]) - >>> td_module2 = TensorDictModule(spec=spec2, module=lazy_module2, in_keys=["hidden"], out_keys=["some_output"]) - >>> td_module = TensorDictSequential(td_module1, td_module2) - >>> _, (params, buffers) = td_module.make_functional_with_buffers() - >>> print(params[0].shape) # the lazy module has been initialized - torch.Size([4, 18]) - >>> print(td_module( - ... 
TensorDict({'some_input': torch.randn(18)}, batch_size=[]), - ... params=params, - ... buffers=buffers)) - TensorDict( - fields={ - some_input: Tensor(torch.Size([18]), dtype=torch.float32), - hidden: Tensor(torch.Size([4]), dtype=torch.float32), - some_output: Tensor(torch.Size([3]), dtype=torch.float32)}, - batch_size=torch.Size([]), - device=cpu, - is_shared=False) - - """ - native = native or not _has_functorch - if clone: - self_copy = deepcopy(self) - self_copy.module = copy(self_copy.module) - else: - self_copy = self - params = [] if not native else TensorDict({}, []) - buffers = [] if not native else TensorDict({}, []) - for i, module in enumerate(self.module): - self_copy.module[i], ( - _params, - _buffers, - ) = module.make_functional_with_buffers(clone=True, native=native) - if native or not _has_functorch: - params[str(i)] = _params - buffers[str(i)] = _buffers - else: - params.extend(_params) - buffers.extend(_buffers) - return self_copy, (params, buffers) - - def get_dist( - self, - tensordict: TensorDictBase, - **kwargs, - ) -> Tuple[torch.distributions.Distribution, ...]: - L = len(self.module) - - if isinstance(self.module[-1], ProbabilisticTensorDictModule): - if "params" in kwargs and "buffers" in kwargs: - params = kwargs["params"] - buffers = kwargs["buffers"] - if isinstance(params, TensorDictBase): - param_splits = list(zip(*sorted(list(params.items()))))[1] - buffer_splits = list(zip(*sorted(list(buffers.items()))))[1] - else: - param_splits = self._split_param(kwargs["params"], "params") - buffer_splits = self._split_param(kwargs["buffers"], "buffers") - kwargs_pruned = { - key: item - for key, item in kwargs.items() - if key not in ("params", "buffers") - } - for i, (module, param, buffer) in enumerate( - zip(self.module, param_splits, buffer_splits) - ): - if "vmap" in kwargs_pruned and i > 0: - # the tensordict is already expended - kwargs_pruned["vmap"] = (0, 0, *(0,) * len(module.in_keys)) - if i < L - 1: - tensordict = module( - 
tensordict, params=param, buffers=buffer, **kwargs_pruned - ) - else: - out = module.get_dist( - tensordict, params=param, buffers=buffer, **kwargs_pruned - ) - - elif "params" in kwargs: - params = kwargs["params"] - if isinstance(params, TensorDictBase): - param_splits = list(zip(*sorted(list(params.items()))))[1] - else: - param_splits = self._split_param(kwargs["params"], "params") - kwargs_pruned = { - key: item for key, item in kwargs.items() if key not in ("params",) - } - for i, (module, param) in enumerate(zip(self.module, param_splits)): - if "vmap" in kwargs_pruned and i > 0: - # the tensordict is already expended - kwargs_pruned["vmap"] = (0, *(0,) * len(module.in_keys)) - if i < L - 1: - tensordict = module(tensordict, params=param, **kwargs_pruned) - else: - out = module.get_dist(tensordict, params=param, **kwargs_pruned) - - elif not len(kwargs): - for i, module in enumerate(self.module): - if i < L - 1: - tensordict = module(tensordict) - else: - out = module.get_dist(tensordict) - else: - raise RuntimeError( - "TensorDictSequential does not support keyword arguments other than 'params', 'buffers' and 'vmap'" - ) - - return out - else: - raise RuntimeError( - "Cannot call get_dist on a sequence of tensordicts that does not end with a probabilistic TensorDict" - ) From 1998fdebc8b57ceb68397c8a0591791eab2fd1f2 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Tue, 22 Nov 2022 15:45:12 +0000 Subject: [PATCH 07/14] Adopt tensordict.nn.utils where possible --- torchrl/modules/utils/__init__.py | 4 ++- torchrl/modules/utils/mappings.py | 42 +++---------------------------- 2 files changed, 6 insertions(+), 40 deletions(-) diff --git a/torchrl/modules/utils/__init__.py b/torchrl/modules/utils/__init__.py index 0bbf5182e08..94ca930f536 100644 --- a/torchrl/modules/utils/__init__.py +++ b/torchrl/modules/utils/__init__.py @@ -3,4 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from .mappings import mappings, inv_softplus, biased_softplus +from .mappings import biased_softplus, inv_softplus, mappings + +__all__ = ["biased_softplus", "inv_softplus", "mappings"] diff --git a/torchrl/modules/utils/mappings.py b/torchrl/modules/utils/mappings.py index 1af962e857f..406c79b304b 100644 --- a/torchrl/modules/utils/mappings.py +++ b/torchrl/modules/utils/mappings.py @@ -3,48 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable, Union +from typing import Callable import torch -from torch import nn +from tensordict.nn.utils import biased_softplus, inv_softplus - -def inv_softplus(bias: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: - """Inverse softplus function. - - Args: - bias (float or tensor): the value to be softplus-inverted. - """ - is_tensor = True - if not isinstance(bias, torch.Tensor): - is_tensor = False - bias = torch.tensor(bias) - out = bias.expm1().clamp_min(1e-6).log() - if not is_tensor and out.numel() == 1: - return out.item() - return out - - -class biased_softplus(nn.Module): - """A biased softplus module. - - The bias indicates the value that is to be returned when a zero-tensor is - passed through the transform. - - Args: - bias (scalar): 'bias' of the softplus transform. If bias=1.0, then a _bias shift will be computed such that - softplus(0.0 + _bias) = bias. - min_val (scalar): minimum value of the transform. 
- default: 0.1 - """ - - def __init__(self, bias: float, min_val: float = 0.01): - super().__init__() - self.bias = inv_softplus(bias - min_val) - self.min_val = min_val - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.softplus(x + self.bias) + self.min_val +__all__ = ["biased_softplus", "expln", "inv_softplus", "mappings"] def expln(x): From 6e2f833351f06857919796a22dcbddbf7db4d6a2 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Wed, 23 Nov 2022 13:39:57 +0000 Subject: [PATCH 08/14] Rerun CI From 905f55c5626300ab0f15790498b2b54224254f37 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Thu, 24 Nov 2022 11:21:27 +0000 Subject: [PATCH 09/14] Delete tests duplicated from tensordict --- test/test_tensordictmodules.py | 96 ++-------------------------------- 1 file changed, 5 insertions(+), 91 deletions(-) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 5009461aeb2..b9e3f3863bf 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -873,20 +873,6 @@ def test_in_key_warning(self): nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"] ) - def test_key_exclusion(self): - module1 = TensorDictModule( - nn.Linear(3, 4), in_keys=["key1", "key2"], out_keys=["foo1"] - ) - module2 = TensorDictModule( - nn.Linear(3, 4), in_keys=["key1", "key3"], out_keys=["key1"] - ) - module3 = TensorDictModule( - nn.Linear(3, 4), in_keys=["foo1", "key3"], out_keys=["key2"] - ) - seq = TensorDictSequential(module1, module2, module3) - assert set(seq.in_keys) == {"key1", "key2", "key3"} - assert set(seq.out_keys) == {"foo1", "key1", "key2"} - @pytest.mark.parametrize("safe", [True, False]) @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) @pytest.mark.parametrize("lazy", [True, False]) @@ -1656,53 +1642,6 @@ def test_vmap_probabilistic(self, safe, spec_type): elif safe and spec_type == "bounded": assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - 
@pytest.mark.parametrize("functional", [True, False]) - def test_submodule_sequence(self, functional): - td_module_1 = TensorDictModule( - nn.Linear(3, 2), - in_keys=["in"], - out_keys=["hidden"], - ) - td_module_2 = TensorDictModule( - nn.Linear(2, 4), - in_keys=["hidden"], - out_keys=["out"], - ) - td_module = TensorDictSequential(td_module_1, td_module_2) - - if functional: - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - sub_seq_1, (params, buffers) = sub_seq_1.make_functional_with_buffers() - sub_seq_1( - td_1, - params=params, - buffers=buffers, - ) - assert "hidden" in td_1.keys() - assert "out" not in td_1.keys() - td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) - sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) - sub_seq_2, (params, buffers) = sub_seq_2.make_functional_with_buffers() - sub_seq_2( - td_2, - params=params, - buffers=buffers, - ) - assert "out" in td_2.keys() - assert td_2.get("out").shape == torch.Size([5, 4]) - else: - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - sub_seq_1(td_1) - assert "hidden" in td_1.keys() - assert "out" not in td_1.keys() - td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) - sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) - sub_seq_2(td_2) - assert "out" in td_2.keys() - assert td_2.get("out").shape == torch.Size([5, 4]) - @pytest.mark.parametrize("stack", [True, False]) @pytest.mark.parametrize("functional", [True, False]) def test_sequential_partial(self, stack, functional): @@ -1817,36 +1756,6 @@ def test_sequential_partial(self, stack, functional): assert "out" in td.keys() assert "b" in td.keys() - def test_subsequence_weight_update(self): - td_module_1 = TensorDictModule( - nn.Linear(3, 2), - in_keys=["in"], - out_keys=["hidden"], - ) - td_module_2 = TensorDictModule( - nn.Linear(2, 4), - in_keys=["hidden"], - out_keys=["out"], - ) - 
td_module = TensorDictSequential(td_module_1, td_module_2) - - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - copy = sub_seq_1[0].module.weight.clone() - - opt = torch.optim.SGD(td_module.parameters(), lr=0.1) - opt.zero_grad() - td_1 = td_module(td_1) - td_1["out"].mean().backward() - opt.step() - - assert not torch.allclose(copy, sub_seq_1[0].module.weight) - assert torch.allclose(td_module[0].module.weight, sub_seq_1[0].module.weight) - - if __name__ == "__main__": - args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) - def test_is_tensordict_compatible(): class MultiHeadLinear(nn.Module): @@ -1953,3 +1862,8 @@ def forward(self, in_1, in_2): ) assert set(ensured_module.in_keys) == {"x"} assert isinstance(ensured_module, TensorDictModule) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From aa8ac958b457eae87da845fc88cea027e58ed336 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 24 Nov 2022 13:24:25 +0000 Subject: [PATCH 10/14] minor --- torchrl/envs/utils.py | 2 +- torchrl/modules/__init__.py | 1 - torchrl/modules/tensordict_module/probabilistic.py | 2 +- torchrl/trainers/helpers/trainers.py | 2 +- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index d8299d02a6e..ef86b45b0c4 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -5,8 +5,8 @@ import pkg_resources from tensordict.nn.probabilistic import ( # noqa - set_interaction_mode as set_exploration_mode, interaction_mode as exploration_mode, + set_interaction_mode as set_exploration_mode, ) from tensordict.tensordict import TensorDictBase diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 8dd879f0c3c..b2f1f55c496 100644 --- a/torchrl/modules/__init__.py 
+++ b/torchrl/modules/__init__.py @@ -48,7 +48,6 @@ ActorValueOperator, AdditiveGaussianWrapper, DistributionalQValueActor, - TensorDictModule, EGreedyWrapper, OrnsteinUhlenbeckProcessWrapper, ProbabilisticActor, diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 44b19bb7c15..c0893c6bcea 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Sequence, Union, Type, Optional +from typing import Optional, Sequence, Type, Union from tensordict.nn import ( ProbabilisticTensorDictModule as _ProbabilisticTensorDictModule, diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 75c471f53d6..7ec5f982a32 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -15,7 +15,7 @@ from torchrl.collectors.collectors import _DataCollector from torchrl.data import ReplayBuffer from torchrl.envs.common import EnvBase -from torchrl.modules import TensorDictModule, reset_noise +from torchrl.modules import reset_noise, TensorDictModule from torchrl.objectives.common import LossModule from torchrl.objectives.utils import TargetNetUpdater from torchrl.trainers.loggers import Logger From 8742ce9d852344c56b13613edffda74835be92b9 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Fri, 25 Nov 2022 09:23:55 +0000 Subject: [PATCH 11/14] Remove references to torchrl.modules.TensorDictWrapper --- docs/source/reference/modules.rst | 1 - torchrl/modules/__init__.py | 1 - 2 files changed, 2 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 5915c000f96..d2a2dbe6839 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -14,7 +14,6 @@ TensorDict 
modules TensorDictModule ProbabilisticTensorDictModule TensorDictSequential - TensorDictModuleWrapper Actor ProbabilisticActor ValueOperator diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index b2f1f55c496..0689b5e79e4 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -54,7 +54,6 @@ ProbabilisticTensorDictModule, QValueActor, TensorDictModule, - TensorDictModuleWrapper, TensorDictSequential, ValueOperator, WorldModelWrapper, From 62e35a348a6f83716a0e25b6a17e11339e78b9e3 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Fri, 25 Nov 2022 10:56:25 +0000 Subject: [PATCH 12/14] Rename TensorDictModule -> SafeModule --- README.md | 30 +-- docs/source/reference/envs.rst | 2 +- docs/source/reference/modules.rst | 6 +- test/smoke_test.py | 2 +- test/test_collector.py | 15 +- test/test_cost.py | 64 +++---- test/test_env.py | 20 +- test/test_exploration.py | 18 +- test/test_functorch.py | 34 ++-- test/test_modules.py | 10 +- test/test_tensordictmodules.py | 178 +++++++++--------- torchrl/collectors/collectors.py | 42 ++--- torchrl/envs/model_based/common.py | 8 +- torchrl/envs/model_based/dreamer.py | 6 +- torchrl/modules/__init__.py | 6 +- torchrl/modules/models/exploration.py | 8 +- torchrl/modules/models/model_based.py | 12 +- torchrl/modules/planners/cem.py | 6 +- torchrl/modules/planners/common.py | 6 +- torchrl/modules/tensordict_module/__init__.py | 6 +- torchrl/modules/tensordict_module/actors.py | 68 ++++--- torchrl/modules/tensordict_module/common.py | 54 +++--- .../modules/tensordict_module/exploration.py | 14 +- .../tensordict_module/probabilistic.py | 26 ++- torchrl/modules/tensordict_module/sequence.py | 42 ++--- .../modules/tensordict_module/world_models.py | 25 +-- torchrl/objectives/a2c.py | 12 +- torchrl/objectives/common.py | 8 +- torchrl/objectives/ddpg.py | 10 +- torchrl/objectives/deprecated.py | 10 +- torchrl/objectives/dreamer.py | 18 +- torchrl/objectives/ppo.py | 22 +-- torchrl/objectives/redq.py | 10 
+- torchrl/objectives/reinforce.py | 6 +- torchrl/objectives/sac.py | 10 +- torchrl/objectives/utils.py | 4 +- torchrl/objectives/value/advantages.py | 14 +- torchrl/trainers/helpers/collectors.py | 10 +- torchrl/trainers/helpers/models.py | 102 +++++----- torchrl/trainers/helpers/trainers.py | 10 +- torchrl/trainers/trainers.py | 4 +- 41 files changed, 454 insertions(+), 504 deletions(-) diff --git a/README.md b/README.md index 194472a3736..9ae69a24cda 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ algorithms. For instance, here's how to code a rollout in TorchRL: ```diff - obs, done = env.reset() + tensordict = env.reset() - policy = TensorDictModule( + policy = SafeModule( model, in_keys=["observation_pixels", "observation_vector"], out_keys=["action"], @@ -106,14 +106,14 @@ Here's another example of an off-policy training loop in TorchRL (assuming that Check our TorchRL-specific [TensorDict tutorial](tutorials/tensordict.ipynb) for more information. -The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/common.py) which is [functorch](https://github.com/pytorch/functorch)-compatible! - +The associated [`SafeModule` class](torchrl/modules/tensordict_module/common.py) which is [functorch](https://github.com/pytorch/functorch)-compatible! +
Code ```diff transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12) - + td_module = TensorDictModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"]) + + td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"]) src = torch.rand((10, 32, 512)) tgt = torch.rand((20, 32, 512)) + tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[20, 32]) @@ -122,19 +122,19 @@ The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/comm + out = tensordict["out"] ``` - The `TensorDictSequential` class allows to branch sequences of `nn.Module` instances in a highly modular way. + The `SafeSequential` class allows to branch sequences of `nn.Module` instances in a highly modular way. For instance, here is an implementation of a transformer using the encoder and decoder blocks: ```python encoder_module = TransformerEncoder(...) - encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"]) + encoder = SafeModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"]) decoder_module = TransformerDecoder(...) 
- decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"]) - transformer = TensorDictSequential(encoder, decoder) + decoder = SafeModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"]) + transformer = SafeSequential(encoder, decoder) assert transformer.in_keys == ["src", "src_mask", "tgt"] assert transformer.out_keys == ["memory", "output"] ``` - `TensorDictSequential` allows to isolate subgraphs by querying a set of desired input / output keys: + `SafeSequential` allows to isolate subgraphs by querying a set of desired input / output keys: ```python transformer.select_subsequence(out_keys=["memory"]) # returns the encoder transformer.select_subsequence(in_keys=["tgt", "memory"]) # returns the decoder @@ -261,9 +261,9 @@ The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/comm kernel_sizes=[8, 4, 3], strides=[4, 2, 1], ) - # Wrap it in a TensorDictModule, indicating what key to read in and where to + # Wrap it in a SafeModule, indicating what key to read in and where to # write out the output - common_module = TensorDictModule( + common_module = SafeModule( common_module, in_keys=["pixels"], out_keys=["hidden"], @@ -277,10 +277,10 @@ The associated [`TensorDictModule` class](torchrl/modules/tensordict_module/comm activation=nn.ELU, ) ) - # Wrap the nn.Module in a ProbabilisticTensorDictModule, indicating how + # Wrap the nn.Module in a SafeProbabilisticModule, indicating how # to build the torch.distribution.Distribution object and what to do with it - policy_module = ProbabilisticTensorDictModule( # stochastic policy - TensorDictModule( + policy_module = SafeProbabilisticModule( # stochastic policy + SafeModule( policy_module, in_keys=["hidden"], out_keys=["loc", "scale"], @@ -409,7 +409,7 @@ pip3 install torchrl This should work on linux and MacOs (not M1). For Windows and M1/M2 machines, one should install the library locally (see below). 
-The **nightly build** can be installed via +The **nightly build** can be installed via ``` pip install torchrl-nightly ``` diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index ad84d38bdc1..748f65d6b68 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -50,7 +50,7 @@ With these, the following methods are implemented: having reproducible results. - :obj:`env.rollout(max_steps, policy)`: executes a rollout in the environment for a maximum number of steps :obj:`max_steps` and using a policy :obj:`policy`. - The policy should be coded using a :obj:`TensorDictModule` (or any other + The policy should be coded using a :obj:`SafeModule` (or any other :obj:`TensorDict`-compatible module). diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index d2a2dbe6839..bf0992be77b 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -11,9 +11,9 @@ TensorDict modules :toctree: generated/ :template: rl_template_noinherit.rst - TensorDictModule - ProbabilisticTensorDictModule - TensorDictSequential + SafeModule + SafeProbabilisticModule + SafeSequential Actor ProbabilisticActor ValueOperator diff --git a/test/smoke_test.py b/test/smoke_test.py index 630171d4082..f0db69def86 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -6,5 +6,5 @@ def test_imports(): ) # noqa: F401 from torchrl.envs import Transform, TransformedEnv # noqa: F401 from torchrl.envs.gym_like import GymLikeEnv # noqa: F401 - from torchrl.modules import TensorDictModule # noqa: F401 + from torchrl.modules import SafeModule # noqa: F401 from torchrl.objectives.common import LossModule # noqa: F401 diff --git a/test/test_collector.py b/test/test_collector.py index 9d384665106..4b8b70d8444 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -35,12 +35,7 @@ from torchrl.envs import EnvCreator, ParallelEnv, SerialEnv from torchrl.envs.libs.gym import _has_gym, GymEnv from 
torchrl.envs.transforms import TransformedEnv, VecNorm -from torchrl.modules import ( - Actor, - LSTMNet, - OrnsteinUhlenbeckProcessWrapper, - TensorDictModule, -) +from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule # torch.set_default_dtype(torch.double) @@ -754,7 +749,7 @@ def create_env(): return ContinuousActionVecMockEnv() n_actions = ContinuousActionVecMockEnv().action_spec.shape[-1] - policy = TensorDictModule( + policy = SafeModule( torch.nn.LazyLinear(n_actions), in_keys=["observation"], out_keys=["action"] ) policy(create_env().reset()) @@ -898,7 +893,7 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe next=CompositeSpec(hidden1=hidden_spec, hidden2=hidden_spec), ) - policy = TensorDictModule(**policy_kwargs) + policy = SafeModule(**policy_kwargs) env_maker = lambda: GymEnv(PENDULUM_VERSIONED) @@ -985,12 +980,12 @@ def test_auto_wrap_modules(self, collector_class, multiple_outputs, env_maker): if collector_class is not SyncDataCollector: assert all( - isinstance(p, TensorDictModule) for p in collector._policy_dict.values() + isinstance(p, SafeModule) for p in collector._policy_dict.values() ) assert all(p.out_keys == out_keys for p in collector._policy_dict.values()) assert all(p.module is policy for p in collector._policy_dict.values()) else: - assert isinstance(collector.policy, TensorDictModule) + assert isinstance(collector.policy, SafeModule) assert collector.policy.out_keys == out_keys assert collector.policy.module is policy diff --git a/test/test_cost.py b/test/test_cost.py index 820f90a6da5..83bf01f49cb 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -41,10 +41,10 @@ from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv from torchrl.modules import ( DistributionalQValueActor, - ProbabilisticTensorDictModule, QValueActor, - TensorDictModule, - TensorDictSequential, + SafeModule, + SafeProbabilisticModule, + SafeSequential, WorldModelWrapper, ) from 
torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal @@ -777,9 +777,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) - module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] - ) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) actor = ProbabilisticActor( spec=CompositeSpec(action=action_spec, loc=None, scale=None), module=module, @@ -1096,9 +1094,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) - module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] - ) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) actor = ProbabilisticActor( module=module, distribution_class=TanhNormal, @@ -1151,13 +1147,9 @@ def __init__(self): def forward(self, hidden, act): return self.linear(torch.cat([hidden, act], -1)) - common = TensorDictModule( - CommonClass(), in_keys=["observation"], out_keys=["hidden"] - ) + common = SafeModule(CommonClass(), in_keys=["observation"], out_keys=["hidden"]) actor_subnet = ProbabilisticActor( - TensorDictModule( - ActorClass(), in_keys=["hidden"], out_keys=["loc", "scale"] - ), + SafeModule(ActorClass(), in_keys=["hidden"], out_keys=["loc", "scale"]), dist_in_keys=["loc", "scale"], distribution_class=TanhNormal, return_log_prob=True, @@ -1528,9 +1520,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) - module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] - ) + module = SafeModule(net, in_keys=["observation"], 
out_keys=["loc", "scale"]) actor = ProbabilisticActor( module=module, distribution_class=TanhNormal, @@ -1763,9 +1753,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) - module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] - ) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) actor = ProbabilisticActor( module=module, distribution_class=TanhNormal, @@ -1989,9 +1977,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value): gamma = 0.9 value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"]) net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) - module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] - ) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) actor_net = ProbabilisticActor( module, distribution_class=TanhNormal, @@ -2138,7 +2124,7 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20 # World Model and reward model rssm_rollout = RSSMRollout( - TensorDictModule( + SafeModule( rssm_prior, in_keys=["state", "belief", "action"], out_keys=[ @@ -2148,7 +2134,7 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20 ("next", "belief"), ], ), - TensorDictModule( + SafeModule( rssm_posterior, in_keys=[("next", "belief"), ("next", "encoded_latents")], out_keys=[ @@ -2162,20 +2148,20 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20 out_features=1, depth=2, num_cells=mlp_num_units, activation_class=nn.ELU ) # World Model and reward model - world_modeler = TensorDictSequential( - TensorDictModule( + world_modeler = SafeSequential( + SafeModule( obs_encoder, in_keys=[("next", "pixels")], out_keys=[("next", "encoded_latents")], ), rssm_rollout, - TensorDictModule( + 
SafeModule( obs_decoder, in_keys=[("next", "state"), ("next", "belief")], out_keys=[("next", "reco_pixels")], ), ) - reward_module = TensorDictModule( + reward_module = SafeModule( reward_module, in_keys=[("next", "state"), ("next", "belief")], out_keys=["reward"], @@ -2209,8 +2195,8 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200): reward_module = MLP( out_features=1, depth=2, num_cells=mlp_num_units, activation_class=nn.ELU ) - transition_model = TensorDictSequential( - TensorDictModule( + transition_model = SafeSequential( + SafeModule( rssm_prior, in_keys=["state", "belief", "action"], out_keys=[ @@ -2221,7 +2207,7 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200): ], ), ) - reward_model = TensorDictModule( + reward_model = SafeModule( reward_module, in_keys=["state", "belief"], out_keys=["reward"], @@ -2255,8 +2241,8 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200): num_cells=mlp_num_units, activation_class=nn.ELU, ) - actor_model = ProbabilisticTensorDictModule( - TensorDictModule( + actor_model = SafeProbabilisticModule( + SafeModule( actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], @@ -2278,7 +2264,7 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200): return actor_model def _create_value_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200): - value_model = TensorDictModule( + value_model = SafeModule( MLP( out_features=1, depth=3, @@ -2380,7 +2366,7 @@ def test_dreamer_env(self, device, imagination_horizon, discount_loss): # test reconstruction with pytest.raises(ValueError, match="No observation decoder provided"): mb_env.decode_obs(rollout) - mb_env.obs_decoder = TensorDictModule( + mb_env.obs_decoder = SafeModule( nn.LazyLinear(4, device=device), in_keys=["state"], out_keys=["reco_observation"], @@ -2896,13 +2882,13 @@ def test_shared_params(dest, expected_dtype, expected_device): if torch.cuda.device_count() == 0 and 
dest == "cuda": pytest.skip("no cuda device available") module_hidden = torch.nn.Linear(4, 4) - td_module_hidden = TensorDictModule( + td_module_hidden = SafeModule( module=module_hidden, spec=None, in_keys=["observation"], out_keys=["hidden"], ) - module_action = TensorDictModule( + module_action = SafeModule( NormalParamWrapper(torch.nn.Linear(4, 8)), in_keys=["hidden"], out_keys=["loc", "scale"], diff --git a/test/test_env.py b/test/test_env.py index 97b1cd5f8e8..c4379ec203d 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -46,13 +46,7 @@ ) from torchrl.envs.utils import step_mdp from torchrl.envs.vec_env import ParallelEnv, SerialEnv -from torchrl.modules import ( - Actor, - ActorCriticOperator, - MLP, - TensorDictModule, - ValueOperator, -) +from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator from torchrl.modules.tensordict_module import WorldModelWrapper gym_version = None @@ -305,12 +299,12 @@ def test_mb_rollout(self, device, seed=0): torch.manual_seed(seed) np.random.seed(seed) world_model = WorldModelWrapper( - TensorDictModule( + SafeModule( ActionObsMergeLinear(5, 4), in_keys=["hidden_observation", "action"], out_keys=["hidden_observation"], ), - TensorDictModule( + SafeModule( nn.Linear(4, 1), in_keys=["hidden_observation"], out_keys=["reward"], @@ -331,12 +325,12 @@ def test_mb_env_batch_lock(self, device, seed=0): torch.manual_seed(seed) np.random.seed(seed) world_model = WorldModelWrapper( - TensorDictModule( + SafeModule( ActionObsMergeLinear(5, 4), in_keys=["hidden_observation", "action"], out_keys=["hidden_observation"], ), - TensorDictModule( + SafeModule( nn.Linear(4, 1), in_keys=["hidden_observation"], out_keys=["reward"], @@ -551,13 +545,13 @@ def test_parallel_env_with_policy( ) policy = ActorCriticOperator( - TensorDictModule( + SafeModule( spec=None, module=nn.LazyLinear(12), in_keys=["observation"], out_keys=["hidden"], ), - TensorDictModule( + SafeModule( spec=None, 
module=nn.LazyLinear(env0.action_spec.shape[-1]), in_keys=["hidden"], diff --git a/test/test_exploration.py b/test/test_exploration.py index 40c3be7d78b..17f64983865 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -14,7 +14,7 @@ from torchrl.data import CompositeSpec, NdBoundedTensorSpec from torchrl.envs.transforms.transforms import gSDENoise from torchrl.envs.utils import set_exploration_mode -from torchrl.modules import TensorDictModule, TensorDictSequential +from torchrl.modules import SafeModule, SafeSequential from torchrl.modules.distributions import TanhNormal from torchrl.modules.distributions.continuous import ( IndependentNormal, @@ -60,7 +60,7 @@ def test_ou(device, seed=0): def test_ou_wrapper(device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0): torch.manual_seed(seed) net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) - module = TensorDictModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) action_spec = NdBoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,)) policy = ProbabilisticActor( spec=action_spec, @@ -112,7 +112,7 @@ def test_additivegaussian_sd( (d_act,), device=device, ) - module = TensorDictModule( + module = SafeModule( net, in_keys=["observation"], out_keys=["loc", "scale"], @@ -172,9 +172,7 @@ def test_additivegaussian_wrapper( ): torch.manual_seed(seed) net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) - module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] - ) + module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) action_spec = NdBoundedTensorSpec( -torch.ones(d_act, device=device), torch.ones(d_act, device=device), @@ -229,9 +227,9 @@ def test_gsde( if gSDE: model = torch.nn.LazyLinear(action_dim, device=device) in_keys = ["observation"] - module = TensorDictSequential( - TensorDictModule(model, in_keys=in_keys, 
out_keys=["action"]), - TensorDictModule( + module = SafeSequential( + SafeModule(model, in_keys=in_keys, out_keys=["action"]), + SafeModule( LazygSDEModule(device=device), in_keys=["action", "observation", "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -243,7 +241,7 @@ def test_gsde( in_keys = ["observation"] model = torch.nn.LazyLinear(action_dim * 2, device=device) wrapper = NormalParamWrapper(model) - module = TensorDictModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"]) + module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"]) distribution_class = TanhNormal distribution_kwargs = {"min": -bound, "max": bound} spec = NdBoundedTensorSpec( diff --git a/test/test_functorch.py b/test/test_functorch.py index 95cea41b97a..7b043968afb 100644 --- a/test/test_functorch.py +++ b/test/test_functorch.py @@ -15,7 +15,7 @@ FunctionalModuleWithBuffers, ) from torch import nn -from torchrl.modules import TensorDictModule, TensorDictSequential +from torchrl.modules import SafeModule, SafeSequential @pytest.mark.skipif( @@ -77,7 +77,7 @@ def test_vmap_tdmodule(moduletype, batch_params): raise NotImplementedError if moduletype == "linear": fmodule, params = FunctionalModule._create_from(module) - tdmodule = TensorDictModule(fmodule, in_keys=["x"], out_keys=["y"]) + tdmodule = SafeModule(fmodule, in_keys=["x"], out_keys=["y"]) x = torch.randn(10, 1, 3) td = TensorDict({"x": x}, [10]) if batch_params: @@ -89,7 +89,7 @@ def test_vmap_tdmodule(moduletype, batch_params): assert y.shape == torch.Size([10, 1, 4]) elif moduletype == "bn1": fmodule, params, buffers = FunctionalModuleWithBuffers._create_from(module) - tdmodule = TensorDictModule(fmodule, in_keys=["x"], out_keys=["y"]) + tdmodule = SafeModule(fmodule, in_keys=["x"], out_keys=["y"]) x = torch.randn(10, 2, 3) td = TensorDict({"x": x}, [10]) if batch_params: @@ -121,7 +121,7 @@ def test_vmap_tdmodule_nativebuilt(moduletype, batch_params): else: raise NotImplementedError if 
moduletype == "linear": - tdmodule = TensorDictModule(module, in_keys=["x"], out_keys=["y"]) + tdmodule = SafeModule(module, in_keys=["x"], out_keys=["y"]) tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) x = torch.randn(10, 1, 3) td = TensorDict({"x": x}, [10]) @@ -134,7 +134,7 @@ def test_vmap_tdmodule_nativebuilt(moduletype, batch_params): y = td["y"] assert y.shape == torch.Size([10, 1, 4]) elif moduletype == "bn1": - tdmodule = TensorDictModule(module, in_keys=["x"], out_keys=["y"]) + tdmodule = SafeModule(module, in_keys=["x"], out_keys=["y"]) tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) x = torch.randn(10, 2, 3) td = TensorDict({"x": x}, [10]) @@ -173,10 +173,10 @@ def test_vmap_tdsequence(moduletype, batch_params): else: raise NotImplementedError if moduletype == "linear": - tdmodule1 = TensorDictModule(fmodule1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = TensorDictModule(fmodule2, in_keys=["y"], out_keys=["z"]) + tdmodule1 = SafeModule(fmodule1, in_keys=["x"], out_keys=["y"]) + tdmodule2 = SafeModule(fmodule2, in_keys=["y"], out_keys=["z"]) params = TensorDict({"0": params1, "1": params2}, []) - tdmodule = TensorDictSequential(tdmodule1, tdmodule2) + tdmodule = SafeSequential(tdmodule1, tdmodule2) assert {"0", "1"} == set(params.keys()) x = torch.randn(10, 1, 3) td = TensorDict({"x": x}, [10]) @@ -188,11 +188,11 @@ def test_vmap_tdsequence(moduletype, batch_params): z = td["z"] assert z.shape == torch.Size([10, 1, 5]) elif moduletype == "bn1": - tdmodule1 = TensorDictModule(fmodule1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = TensorDictModule(fmodule2, in_keys=["y"], out_keys=["z"]) + tdmodule1 = SafeModule(fmodule1, in_keys=["x"], out_keys=["y"]) + tdmodule2 = SafeModule(fmodule2, in_keys=["y"], out_keys=["z"]) params = TensorDict({"0": params1, "1": params2}, []) buffers = TensorDict({"0": buffers1, "1": buffers2}, []) - tdmodule = TensorDictSequential(tdmodule1, tdmodule2) + 
tdmodule = SafeSequential(tdmodule1, tdmodule2) assert {"0", "1"} == set(params.keys()) assert {"0", "1"} == set(buffers.keys()) x = torch.randn(10, 2, 3) @@ -228,9 +228,9 @@ def test_vmap_tdsequence_nativebuilt(moduletype, batch_params): else: raise NotImplementedError if moduletype == "linear": - tdmodule1 = TensorDictModule(module1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = TensorDictModule(module2, in_keys=["y"], out_keys=["z"]) - tdmodule = TensorDictSequential(tdmodule1, tdmodule2) + tdmodule1 = SafeModule(module1, in_keys=["x"], out_keys=["y"]) + tdmodule2 = SafeModule(module2, in_keys=["y"], out_keys=["z"]) + tdmodule = SafeSequential(tdmodule1, tdmodule2) tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) assert {"0", "1"} == set(params.keys()) x = torch.randn(10, 1, 3) @@ -244,9 +244,9 @@ def test_vmap_tdsequence_nativebuilt(moduletype, batch_params): z = td["z"] assert z.shape == torch.Size([10, 1, 5]) elif moduletype == "bn1": - tdmodule1 = TensorDictModule(module1, in_keys=["x"], out_keys=["y"]) - tdmodule2 = TensorDictModule(module2, in_keys=["y"], out_keys=["z"]) - tdmodule = TensorDictSequential(tdmodule1, tdmodule2) + tdmodule1 = SafeModule(module1, in_keys=["x"], out_keys=["y"]) + tdmodule2 = SafeModule(module2, in_keys=["y"], out_keys=["z"]) + tdmodule = SafeSequential(tdmodule1, tdmodule2) tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers(native=True) assert {"0", "1"} == set(params.keys()) assert {"0", "1"} == set(buffers.keys()) diff --git a/test/test_modules.py b/test/test_modules.py index 6284c91fab6..3a83f48c18c 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -27,7 +27,7 @@ LSTMNet, ProbabilisticActor, QValueActor, - TensorDictModule, + SafeModule, ValueOperator, ) from torchrl.modules.models import ConvNet, MLP, NoisyLazyLinear, NoisyLinear @@ -259,10 +259,10 @@ def make_net(): @pytest.mark.parametrize("device", get_available_devices()) def test_actorcritic(device): - 
common_module = TensorDictModule( + common_module = SafeModule( spec=None, module=nn.Linear(3, 4), in_keys=["obs"], out_keys=["hidden"] ).to(device) - module = TensorDictModule(nn.Linear(4, 5), in_keys=["hidden"], out_keys=["param"]) + module = SafeModule(nn.Linear(4, 5), in_keys=["hidden"], out_keys=["param"]) policy_operator = ProbabilisticActor( spec=None, module=module, dist_in_keys=["param"], return_log_prob=True ).to(device) @@ -613,7 +613,7 @@ def test_rssm_rollout( ).to(device) rssm_rollout = RSSMRollout( - TensorDictModule( + SafeModule( rssm_prior, in_keys=["state", "belief", "action"], out_keys=[ @@ -623,7 +623,7 @@ def test_rssm_rollout( ("next", "belief"), ], ), - TensorDictModule( + SafeModule( rssm_posterior, in_keys=[("next", "belief"), ("next", "encoded_latents")], out_keys=[ diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 919fa3b73be..60f513ae213 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -30,15 +30,13 @@ NdUnboundedContinuousTensorSpec, ) from torchrl.envs.utils import set_exploration_mode -from torchrl.modules import NormalParamWrapper, TanhNormal, TensorDictModule +from torchrl.modules import NormalParamWrapper, SafeModule, TanhNormal from torchrl.modules.tensordict_module.common import ( ensure_tensordict_compatible, is_tensordict_compatible, ) -from torchrl.modules.tensordict_module.probabilistic import ( - ProbabilisticTensorDictModule, -) -from torchrl.modules.tensordict_module.sequence import TensorDictSequential +from torchrl.modules.tensordict_module.probabilistic import SafeProbabilisticModule +from torchrl.modules.tensordict_module.sequence import SafeSequential class TestTDModule: @@ -53,7 +51,7 @@ def __init__(self, in_1, out_1, out_2, out_3): def forward(self, x): return self.linear_1(x), self.linear_2(x), self.linear_3(x) - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( MultiHeadLinear(5, 4, 3, 2), in_keys=["input"], 
out_keys=["out_1", "out_2", "out_3"], @@ -68,7 +66,7 @@ def forward(self, x): assert td.get("out_3").shape == torch.Size([3, 2]) # Using "_" key to ignore some output - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( MultiHeadLinear(5, 4, 3, 2), in_keys=["input"], out_keys=["_", "_", "out_3"], @@ -98,7 +96,7 @@ def forward(self, x): # warning due to "_" in spec keys with pytest.warns(UserWarning, match='got a spec with key "_"'): - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( MultiHeadLinear(5, 4, 3), in_keys=["input"], out_keys=["_", "out_2"], @@ -129,7 +127,7 @@ def test_stateful(self, safe, spec_type, lazy): match="is not a valid configuration as the tensor specs are not " "specified", ): - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( module=net, spec=spec, in_keys=["in"], @@ -138,7 +136,7 @@ def test_stateful(self, safe, spec_type, lazy): ) return else: - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( module=net, spec=spec, in_keys=["in"], @@ -171,7 +169,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) net = nn.Linear(3, 4 * param_multiplier) in_keys = ["in"] - net = TensorDictModule( + net = SafeModule( module=NormalParamWrapper(net), spec=None, in_keys=in_keys, @@ -206,7 +204,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) match="is not a valid configuration as the tensor specs are not " "specified", ): - tensordict_module = ProbabilisticTensorDictModule( + tensordict_module = SafeProbabilisticModule( module=net, spec=spec, dist_in_keys=dist_in_keys, @@ -216,7 +214,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) ) return else: - tensordict_module = ProbabilisticTensorDictModule( + tensordict_module = SafeProbabilisticModule( module=net, spec=spec, dist_in_keys=dist_in_keys, @@ -260,7 +258,7 @@ def test_functional(self, safe, spec_type): match="is 
not a valid configuration as the tensor specs are not " "specified", ): - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( spec=spec, module=fnet, in_keys=["in"], @@ -269,7 +267,7 @@ def test_functional(self, safe, spec_type): ) return else: - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( spec=spec, module=fnet, in_keys=["in"], @@ -298,7 +296,7 @@ def test_functional_probabilistic(self, safe, spec_type): in_keys = ["in"] net = NormalParamWrapper(net) fnet, params = make_functional(net) - tdnet = TensorDictModule( + tdnet = SafeModule( module=fnet, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] ) @@ -322,7 +320,7 @@ def test_functional_probabilistic(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tensordict_module = ProbabilisticTensorDictModule( + tensordict_module = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -332,7 +330,7 @@ def test_functional_probabilistic(self, safe, spec_type): ) return else: - tensordict_module = ProbabilisticTensorDictModule( + tensordict_module = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -361,7 +359,7 @@ def test_functional_probabilistic_laterconstruct(self, safe, spec_type): net = nn.Linear(3, 4 * param_multiplier) in_keys = ["in"] net = NormalParamWrapper(net) - tdnet = TensorDictModule( + tdnet = SafeModule( module=net, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] ) @@ -385,7 +383,7 @@ def test_functional_probabilistic_laterconstruct(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tensordict_module = ProbabilisticTensorDictModule( + tensordict_module = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -395,7 +393,7 @@ def test_functional_probabilistic_laterconstruct(self, safe, spec_type): ) return else: - tensordict_module = 
ProbabilisticTensorDictModule( + tensordict_module = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -442,7 +440,7 @@ def test_functional_with_buffer(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tdmodule = TensorDictModule( + tdmodule = SafeModule( spec=spec, module=fnet, in_keys=["in"], @@ -451,7 +449,7 @@ def test_functional_with_buffer(self, safe, spec_type): ) return else: - tdmodule = TensorDictModule( + tdmodule = SafeModule( spec=spec, module=fnet, in_keys=["in"], @@ -480,7 +478,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): in_keys = ["in"] net = NormalParamWrapper(net) fnet, params, buffers = make_functional_with_buffers(net) - tdnet = TensorDictModule( + tdnet = SafeModule( module=fnet, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] ) @@ -504,7 +502,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tdmodule = ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -514,7 +512,7 @@ def test_functional_with_buffer_probabilistic(self, safe, spec_type): ) return else: - tdmodule = ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -543,7 +541,7 @@ def test_functional_with_buffer_probabilistic_laterconstruct(self, safe, spec_ty net = nn.BatchNorm1d(32 * param_multiplier) in_keys = ["in"] net = NormalParamWrapper(net) - tdnet = TensorDictModule( + tdnet = SafeModule( module=net, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] ) @@ -567,7 +565,7 @@ def test_functional_with_buffer_probabilistic_laterconstruct(self, safe, spec_ty match="is not a valid configuration as the tensor specs are not " "specified", ): - ProbabilisticTensorDictModule( + SafeProbabilisticModule( 
module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -577,7 +575,7 @@ def test_functional_with_buffer_probabilistic_laterconstruct(self, safe, spec_ty ) return else: - tdmodule = ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -624,7 +622,7 @@ def test_vmap(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tdmodule = TensorDictModule( + tdmodule = SafeModule( spec=spec, module=fnet, in_keys=["in"], @@ -633,7 +631,7 @@ def test_vmap(self, safe, spec_type): ) return else: - tdmodule = TensorDictModule( + tdmodule = SafeModule( spec=spec, module=fnet, in_keys=["in"], @@ -690,7 +688,7 @@ def test_vmap_probabilistic(self, safe, spec_type): net = NormalParamWrapper(net) in_keys = ["in"] fnet, params = make_functional(net) - tdnet = TensorDictModule( + tdnet = SafeModule( module=fnet, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] ) @@ -714,7 +712,7 @@ def test_vmap_probabilistic(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tdmodule = ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -724,7 +722,7 @@ def test_vmap_probabilistic(self, safe, spec_type): ) return else: - tdmodule = ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -781,7 +779,7 @@ def test_vmap_probabilistic_laterconstruct(self, safe, spec_type): net = nn.Linear(3, 4 * param_multiplier) net = NormalParamWrapper(net) in_keys = ["in"] - tdnet = TensorDictModule( + tdnet = SafeModule( module=net, spec=None, in_keys=in_keys, out_keys=["loc", "scale"] ) @@ -805,7 +803,7 @@ def test_vmap_probabilistic_laterconstruct(self, safe, spec_type): match="is not a valid configuration as the tensor specs are not " "specified", ): - tdmodule = 
ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -815,7 +813,7 @@ def test_vmap_probabilistic_laterconstruct(self, safe, spec_type): ) return else: - tdmodule = ProbabilisticTensorDictModule( + tdmodule = SafeProbabilisticModule( module=tdnet, spec=spec, dist_in_keys=["loc", "scale"], @@ -865,11 +863,11 @@ def test_vmap_probabilistic_laterconstruct(self, safe, spec_type): class TestTDSequence: def test_in_key_warning(self): with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"] ) with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - tensordict_module = TensorDictModule( + tensordict_module = SafeModule( nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"] ) @@ -900,21 +898,21 @@ def test_stateful(self, safe, spec_type, lazy): if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False, ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = TensorDictModule( + tdmodule2 = SafeModule( spec=spec, module=net2, in_keys=["hidden"], @@ -922,7 +920,7 @@ def test_stateful(self, safe, spec_type, lazy): safe=False, **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -967,9 +965,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy): dummy_net = nn.Linear(4, 4) net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) - net2 = TensorDictModule( - module=net2, in_keys=["hidden"], 
out_keys=["loc", "scale"] - ) + net2 = SafeModule(module=net2, in_keys=["hidden"], out_keys=["loc", "scale"]) if spec_type is None: spec = None @@ -988,21 +984,21 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy): if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False, ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( dummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = ProbabilisticTensorDictModule( + tdmodule2 = SafeProbabilisticModule( spec=spec, module=net2, dist_in_keys=["loc", "scale"], @@ -1010,7 +1006,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy): safe=False, **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -1068,24 +1064,24 @@ def test_functional(self, safe, spec_type): if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( fdummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = TensorDictModule( + tdmodule2 = SafeModule( fnet2, spec=spec, in_keys=["hidden"], out_keys=["out"], safe=safe, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -1129,9 +1125,7 @@ def test_functional_probabilistic(self, safe, spec_type): fnet1, params1 = make_functional(net1) fdummy_net, _ = make_functional(dummy_net) fnet2, params2 = 
make_functional(net2) - fnet2 = TensorDictModule( - module=fnet2, in_keys=["hidden"], out_keys=["loc", "scale"] - ) + fnet2 = SafeModule(module=fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) if isinstance(params1, TensorDictBase): params = TensorDict({"0": params1, "1": params2}, []) else: @@ -1154,17 +1148,17 @@ def test_functional_probabilistic(self, safe, spec_type): if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( fdummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = ProbabilisticTensorDictModule( + tdmodule2 = SafeProbabilisticModule( fnet2, spec=spec, dist_in_keys=["loc", "scale"], @@ -1172,7 +1166,7 @@ def test_functional_probabilistic(self, safe, spec_type): safe=safe, **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -1241,24 +1235,24 @@ def test_functional_with_buffer( if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( fdummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = TensorDictModule( + tdmodule2 = SafeModule( fnet2, spec=spec, in_keys=["hidden"], out_keys=["out"], safe=safe, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -1309,8 +1303,8 @@ def 
test_functional_with_buffer_probabilistic( fnet1, params1, buffers1 = make_functional_with_buffers(net1) fdummy_net, _, _ = make_functional_with_buffers(dummy_net) # fnet2, params2, buffers2 = make_functional_with_buffers(net2) - # fnet2 = TensorDictModule(fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) - net2 = TensorDictModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) + # fnet2 = SafeModule(fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) + net2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) fnet2, (params2, buffers2) = net2.make_functional_with_buffers() if isinstance(params1, TensorDictBase): @@ -1339,17 +1333,17 @@ def test_functional_with_buffer_probabilistic( if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( fdummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = ProbabilisticTensorDictModule( + tdmodule2 = SafeProbabilisticModule( fnet2, spec=spec, dist_in_keys=["loc", "scale"], @@ -1357,7 +1351,7 @@ def test_functional_with_buffer_probabilistic( safe=safe, **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -1403,7 +1397,7 @@ def test_functional_with_buffer_probabilistic_laterconstruct( nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) ) net2 = NormalParamWrapper(net2) - net2 = TensorDictModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) + net2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) if spec_type is None: spec = None @@ -1422,10 +1416,10 @@ def test_functional_with_buffer_probabilistic_laterconstruct( if safe and spec is None: 
pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False ) - tdmodule2 = ProbabilisticTensorDictModule( + tdmodule2 = SafeProbabilisticModule( net2, spec=spec, dist_in_keys=["loc", "scale"], @@ -1433,7 +1427,7 @@ def test_functional_with_buffer_probabilistic_laterconstruct( safe=safe, **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, tdmodule2) + tdmodule = SafeSequential(tdmodule1, tdmodule2) tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers() @@ -1480,28 +1474,28 @@ def test_vmap(self, safe, spec_type): if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False, ) - dummy_tdmodule = TensorDictModule( + dummy_tdmodule = SafeModule( fdummy_net, spec=None, in_keys=["hidden"], out_keys=["hidden"], safe=False, ) - tdmodule2 = TensorDictModule( + tdmodule2 = SafeModule( fnet2, spec=spec, in_keys=["hidden"], out_keys=["out"], safe=safe, ) - tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) assert hasattr(tdmodule, "__setitem__") assert len(tdmodule) == 3 @@ -1568,7 +1562,7 @@ def test_vmap_probabilistic(self, safe, spec_type): net2 = nn.Linear(4, 4 * param_multiplier) net2 = NormalParamWrapper(net2) fnet2, params2 = make_functional(net2) - fnet2 = TensorDictModule(fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) + fnet2 = SafeModule(fnet2, in_keys=["hidden"], out_keys=["loc", "scale"]) params = params1 + params2 @@ -1589,14 +1583,14 @@ def test_vmap_probabilistic(self, safe, spec_type): if safe and spec is None: pytest.skip("safe and spec is None is checked elsewhere") else: - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["in"], 
out_keys=["hidden"], safe=False, ) - tdmodule2 = ProbabilisticTensorDictModule( + tdmodule2 = SafeProbabilisticModule( fnet2, spec=spec, sample_out_key=["out"], @@ -1604,7 +1598,7 @@ def test_vmap_probabilistic(self, safe, spec_type): safe=safe, **kwargs, ) - tdmodule = TensorDictSequential(tdmodule1, tdmodule2) + tdmodule = SafeSequential(tdmodule1, tdmodule2) # vmap = True params = [p.repeat(10, *[1 for _ in p.shape]) for p in params] @@ -1662,7 +1656,7 @@ def test_sequential_partial(self, stack, functional): else: fnet2 = net2 params2 = None - fnet2 = TensorDictModule(fnet2, in_keys=["b"], out_keys=["loc", "scale"]) + fnet2 = SafeModule(fnet2, in_keys=["b"], out_keys=["loc", "scale"]) net3 = nn.Linear(4, 4 * param_multiplier) net3 = NormalParamWrapper(net3) @@ -1671,21 +1665,21 @@ def test_sequential_partial(self, stack, functional): else: fnet3 = net3 params3 = None - fnet3 = TensorDictModule(fnet3, in_keys=["c"], out_keys=["loc", "scale"]) + fnet3 = SafeModule(fnet3, in_keys=["c"], out_keys=["loc", "scale"]) spec = NdBoundedTensorSpec(-0.1, 0.1, 4) spec = CompositeSpec(out=spec, loc=None, scale=None) kwargs = {"distribution_class": TanhNormal} - tdmodule1 = TensorDictModule( + tdmodule1 = SafeModule( fnet1, spec=None, in_keys=["a"], out_keys=["hidden"], safe=False, ) - tdmodule2 = ProbabilisticTensorDictModule( + tdmodule2 = SafeProbabilisticModule( fnet2, spec=spec, sample_out_key=["out"], @@ -1693,7 +1687,7 @@ def test_sequential_partial(self, stack, functional): safe=True, **kwargs, ) - tdmodule3 = ProbabilisticTensorDictModule( + tdmodule3 = SafeProbabilisticModule( fnet3, spec=spec, sample_out_key=["out"], @@ -1701,7 +1695,7 @@ def test_sequential_partial(self, stack, functional): safe=True, **kwargs, ) - tdmodule = TensorDictSequential( + tdmodule = SafeSequential( tdmodule1, tdmodule2, tdmodule3, partial_tolerant=True ) @@ -1768,7 +1762,7 @@ def __init__(self, in_1, out_1, out_2, out_3): def forward(self, x): return self.linear_1(x), self.linear_2(x), 
self.linear_3(x) - td_module = TensorDictModule( + td_module = SafeModule( MultiHeadLinear(5, 4, 3, 2), in_keys=["in_1", "in_2"], out_keys=["out_1", "out_2"], @@ -1823,7 +1817,7 @@ def __init__(self, in_1, out_1, out_2, out_3): def forward(self, x): return self.linear_1(x), self.linear_2(x), self.linear_3(x) - td_module = TensorDictModule( + td_module = SafeModule( MultiHeadLinear(5, 4, 3, 2), in_keys=["in_1", "in_2"], out_keys=["out_1", "out_2"], @@ -1861,7 +1855,7 @@ def forward(self, in_1, in_2): out_keys=["out_1", "out_2", "out_3"], ) assert set(ensured_module.in_keys) == {"x"} - assert isinstance(ensured_module, TensorDictModule) + assert isinstance(ensured_module, SafeModule) if __name__ == "__main__": diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 72d63332ca9..82dd40cd7d9 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -29,7 +29,7 @@ from ..data.utils import CloudpickleWrapper, DEVICE_TYPING from ..envs.common import EnvBase from ..envs.vec_env import _BatchedEnv -from ..modules.tensordict_module import ProbabilisticTensorDictModule, TensorDictModule +from ..modules.tensordict_module import SafeModule, SafeProbabilisticModule from .utils import split_trajectories _TIMEOUT = 1.0 @@ -84,21 +84,21 @@ def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: def _policy_is_tensordict_compatible(policy: nn.Module): sig = inspect.signature(policy.forward) - if isinstance(policy, TensorDictModule) or ( + if isinstance(policy, SafeModule) or ( len(sig.parameters) == 1 and hasattr(policy, "in_keys") and hasattr(policy, "out_keys") ): - # if the policy is a TensorDictModule or takes a single argument and defines + # if the policy is a SafeModule or takes a single argument and defines # in_keys and out_keys then we assume it can already deal with TensorDict input # to forward and we return True return True elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"): - # 
if it's not a TensorDictModule, and in_keys and out_keys are not defined then + # if it's not a SafeModule, and in_keys and out_keys are not defined then # we assume no TensorDict compatibility and will try to wrap it. return False - # if in_keys or out_keys were defined but policy is not a TensorDictModule or + # if in_keys or out_keys were defined but policy is not a SafeModule or # accepts multiple arguments then it's likely the user is trying to do something # that will have undetermined behaviour, we raise an error raise TypeError( @@ -107,7 +107,7 @@ def _policy_is_tensordict_compatible(policy: nn.Module): "should take a single argument of type TensorDict to policy.forward and define " "both in_keys and out_keys. Alternatively, policy.forward can accept " "arbitrarily many tensor inputs and leave in_keys and out_keys undefined and " - "TorchRL will attempt to automatically wrap the policy with a TensorDictModule." + "TorchRL will attempt to automatically wrap the policy with a SafeModule." ) @@ -116,15 +116,13 @@ def _get_policy_and_device( self, policy: Optional[ Union[ - ProbabilisticTensorDictModule, + SafeProbabilisticModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, device: Optional[DEVICE_TYPING] = None, observation_spec: TensorSpec = None, - ) -> Tuple[ - ProbabilisticTensorDictModule, torch.device, Union[None, Callable[[], dict]] - ]: + ) -> Tuple[SafeProbabilisticModule, torch.device, Union[None, Callable[[], dict]]]: """Util method to get a policy and its device given the collector __init__ inputs. 
From a policy and a device, assigns the self.device attribute to @@ -135,7 +133,7 @@ def _get_policy_and_device( create_env_fn (Callable or list of callables): an env creator function (or a list of creators) create_env_kwargs (dictionary): kwargs for the env creator - policy (ProbabilisticTensorDictModule, optional): a policy to be used + policy (SafeProbabilisticModule, optional): a policy to be used device (int, str or torch.device, optional): device where to place the policy observation_spec (TensorSpec, optional): spec of the observations @@ -163,13 +161,13 @@ def _get_policy_and_device( # callables should be supported as policies. if not _policy_is_tensordict_compatible(policy): # policy is a nn.Module that doesn't operate on tensordicts directly - # so we attempt to auto-wrap policy with TensorDictModule + # so we attempt to auto-wrap policy with SafeModule if observation_spec is None: raise ValueError( "Unable to read observation_spec from the environment. This is " "required to check compatibility of the environment and policy " "since the policy is a nn.Module that operates on tensors " - "rather than a TensorDictModule or a nn.Module that accepts a " + "rather than a SafeModule or a nn.Module that accepts a " "TensorDict as input and defines in_keys and out_keys." ) sig = inspect.signature(policy.forward) @@ -183,18 +181,18 @@ def _get_policy_and_device( if isinstance(output, tuple): out_keys.extend(f"output{i+1}" for i in range(len(output) - 1)) - policy = TensorDictModule( + policy = SafeModule( policy, in_keys=list(sig.parameters), out_keys=out_keys ) else: raise TypeError( "Arguments to policy.forward are incompatible with entries in " "env.observation_spec. If you want TorchRL to automatically " - "wrap your policy with a TensorDictModule then the arguments " + "wrap your policy with a SafeModule then the arguments " "to policy.forward must correspond one-to-one with entries in " "env.observation_spec that are prefixed with 'next_'. 
For more " "complex behaviour and more control you can consider writing " - "your own TensorDictModule." + "your own SafeModule." ) try: @@ -307,7 +305,7 @@ def __init__( ], # noqa: F821 policy: Optional[ Union[ - ProbabilisticTensorDictModule, + SafeProbabilisticModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, @@ -520,7 +518,7 @@ def iterator(self) -> Iterator[TensorDictBase]: def _cast_to_policy(self, td: TensorDictBase) -> TensorDictBase: policy_device = self.device if hasattr(self.policy, "in_keys"): - # some keys may be absent -- TensorDictModule is resilient to missing keys + # some keys may be absent -- SafeModule is resilient to missing keys td = td.select(*self.policy.in_keys, strict=False) if self._td_policy is None: self._td_policy = td.to(policy_device) @@ -720,7 +718,7 @@ class _MultiDataCollector(_DataCollector): Args: create_env_fn (list of Callabled): list of Callables, each returning an instance of EnvBase - policy (Callable, optional): Instance of ProbabilisticTensorDictModule class. + policy (Callable, optional): Instance of SafeProbabilisticModule class. Must accept TensorDictBase object as input. total_frames (int): lower bound of the total number of frames returned by the collector. In parallel settings, the actual number of frames may well be greater than this as the closing signals are sent to the @@ -779,7 +777,7 @@ def __init__( create_env_fn: Sequence[Callable[[], EnvBase]], policy: Optional[ Union[ - ProbabilisticTensorDictModule, + SafeProbabilisticModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, @@ -1306,7 +1304,7 @@ class aSyncDataCollector(MultiaSyncDataCollector): Args: create_env_fn (Callabled): Callable returning an instance of EnvBase - policy (Callable, optional): Instance of ProbabilisticTensorDictModule class. + policy (Callable, optional): Instance of SafeProbabilisticModule class. Must accept TensorDictBase object as input. 
total_frames (int): lower bound of the total number of frames returned by the collector. In parallel settings, the actual number of @@ -1361,7 +1359,7 @@ def __init__( create_env_fn: Callable[[], EnvBase], policy: Optional[ Union[ - ProbabilisticTensorDictModule, + SafeProbabilisticModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index 128408dd6f2..1ff0cd03712 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -13,7 +13,7 @@ from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase -from torchrl.modules.tensordict_module import TensorDictModule +from torchrl.modules.tensordict_module import SafeModule class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): @@ -53,12 +53,12 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): >>> import torch.nn as nn >>> from torchrl.modules import MLP, WorldModelWrapper >>> world_model = WorldModelWrapper( - ... TensorDictModule( + ... SafeModule( ... MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0), ... in_keys=["hidden_observation", "action"], ... out_keys=["hidden_observation"], ... ), - ... TensorDictModule( + ... SafeModule( ... nn.Linear(4, 1), ... in_keys=["hidden_observation"], ... 
out_keys=["reward"], @@ -114,7 +114,7 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta): def __init__( self, - world_model: TensorDictModule, + world_model: SafeModule, params: Optional[List[torch.Tensor]] = None, buffers: Optional[List[torch.Tensor]] = None, device: DEVICE_TYPING = "cpu", diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index 432682812b2..fb902d692f7 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -13,7 +13,7 @@ from torchrl.data.utils import DEVICE_TYPING from torchrl.envs import EnvBase from torchrl.envs.model_based import ModelBasedEnvBase -from torchrl.modules.tensordict_module import TensorDictModule +from torchrl.modules.tensordict_module import SafeModule class DreamerEnv(ModelBasedEnvBase): @@ -21,10 +21,10 @@ class DreamerEnv(ModelBasedEnvBase): def __init__( self, - world_model: TensorDictModule, + world_model: SafeModule, prior_shape: Tuple[int, ...], belief_shape: Tuple[int, ...], - obs_decoder: TensorDictModule = None, + obs_decoder: SafeModule = None, device: DEVICE_TYPING = "cpu", dtype: Optional[Union[torch.dtype, np.dtype]] = None, batch_size: Optional[torch.Size] = None, diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 0689b5e79e4..8c6ca5d8593 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -51,10 +51,10 @@ EGreedyWrapper, OrnsteinUhlenbeckProcessWrapper, ProbabilisticActor, - ProbabilisticTensorDictModule, QValueActor, - TensorDictModule, - TensorDictSequential, + SafeModule, + SafeProbabilisticModule, + SafeSequential, ValueOperator, WorldModelWrapper, ) diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index a18454fbbdf..7de0a7c9caa 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -264,18 +264,18 @@ class gSDEModule(nn.Module): Examples: >>> from tensordict import TensorDict - >>> from 
torchrl.modules import TensorDictModule, TensorDictSequential, ProbabilisticActor, TanhNormal + >>> from torchrl.modules import SafeModule, SafeSequential, ProbabilisticActor, TanhNormal >>> batch, state_dim, action_dim = 3, 7, 5 >>> model = nn.Linear(state_dim, action_dim) - >>> deterministic_policy = TensorDictModule(model, in_keys=["obs"], out_keys=["action"]) - >>> stochatstic_part = TensorDictModule( + >>> deterministic_policy = SafeModule(model, in_keys=["obs"], out_keys=["action"]) + >>> stochatstic_part = SafeModule( ... gSDEModule(action_dim, state_dim), ... in_keys=["action", "obs", "_eps_gSDE"], ... out_keys=["loc", "scale", "action", "_eps_gSDE"]) >>> stochatstic_part = ProbabilisticActor(stochatstic_part, ... dist_in_keys=["loc", "scale"], ... distribution_class=TanhNormal) - >>> stochatstic_policy = TensorDictSequential(deterministic_policy, stochatstic_part) + >>> stochatstic_policy = SafeSequential(deterministic_policy, stochatstic_part) >>> tensordict = TensorDict({'obs': torch.randn(state_dim), '_epx_gSDE': torch.zeros(1)}, []) >>> _ = stochatstic_policy(tensordict) >>> print(tensordict) diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 9b94ed4912b..064565ccc79 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -11,8 +11,8 @@ from torchrl.envs.utils import step_mdp from torchrl.modules.distributions import NormalParamWrapper from torchrl.modules.models.models import MLP -from torchrl.modules.tensordict_module.common import TensorDictModule -from torchrl.modules.tensordict_module.sequence import TensorDictSequential +from torchrl.modules.tensordict_module.common import SafeModule +from torchrl.modules.tensordict_module.sequence import SafeSequential class DreamerActor(nn.Module): @@ -151,15 +151,15 @@ class RSSMRollout(nn.Module): Reference: https://arxiv.org/abs/1811.04551 Args: - rssm_prior (TensorDictModule): Prior network. 
- rssm_posterior (TensorDictModule): Posterior network. + rssm_prior (SafeModule): Prior network. + rssm_posterior (SafeModule): Posterior network. """ - def __init__(self, rssm_prior: TensorDictModule, rssm_posterior: TensorDictModule): + def __init__(self, rssm_prior: SafeModule, rssm_posterior: SafeModule): super().__init__() - _module = TensorDictSequential(rssm_prior, rssm_posterior) + _module = SafeSequential(rssm_prior, rssm_posterior) self.in_keys = _module.in_keys self.out_keys = _module.out_keys self.rssm_prior = rssm_prior diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index d11c9ab12fd..dd69c8b4e16 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -47,7 +47,7 @@ class CEMPlanner(MPCPlannerBase): >>> from tensordict import TensorDict >>> from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec >>> from torchrl.envs.model_based import ModelBasedEnvBase - >>> from torchrl.modules import TensorDictModule + >>> from torchrl.modules import SafeModule >>> class MyMBEnv(ModelBasedEnvBase): ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) @@ -71,12 +71,12 @@ class CEMPlanner(MPCPlannerBase): >>> from torchrl.modules import MLP, WorldModelWrapper >>> import torch.nn as nn >>> world_model = WorldModelWrapper( - ... TensorDictModule( + ... SafeModule( ... MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0), ... in_keys=["hidden_observation", "action"], ... out_keys=["hidden_observation"], ... ), - ... TensorDictModule( + ... SafeModule( ... nn.Linear(4, 1), ... in_keys=["hidden_observation"], ... 
out_keys=["reward"], diff --git a/torchrl/modules/planners/common.py b/torchrl/modules/planners/common.py index 63ecba7991c..a9d1e4ca942 100644 --- a/torchrl/modules/planners/common.py +++ b/torchrl/modules/planners/common.py @@ -9,13 +9,13 @@ from tensordict.tensordict import TensorDictBase from torchrl.envs import EnvBase -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule -class MPCPlannerBase(TensorDictModule, metaclass=abc.ABCMeta): +class MPCPlannerBase(SafeModule, metaclass=abc.ABCMeta): """MPCPlannerBase abstract Module. - This class inherits from :obj:`TensorDictModule`. Provided a :obj:`TensorDict`, this module will perform a Model Predictive Control (MPC) planning step. + This class inherits from :obj:`SafeModule`. Provided a :obj:`TensorDict`, this module will perform a Model Predictive Control (MPC) planning step. At the end of the planning step, the :obj:`MPCPlanner` will return a proposed action. Args: diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index eab0bfc2760..a94b8eeb12b 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -13,12 +13,12 @@ QValueActor, ValueOperator, ) -from .common import TensorDictModule +from .common import SafeModule from .exploration import ( AdditiveGaussianWrapper, EGreedyWrapper, OrnsteinUhlenbeckProcessWrapper, ) -from .probabilistic import ProbabilisticTensorDictModule -from .sequence import TensorDictSequential +from .probabilistic import SafeProbabilisticModule +from .sequence import SafeSequential from .world_models import WorldModelWrapper diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 3744acc2aa1..dba80fc67a5 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -11,14 +11,12 @@ from torchrl.data import CompositeSpec, TensorSpec, 
UnboundedContinuousTensorSpec from torchrl.modules.models.models import DistributionalDQNnet -from torchrl.modules.tensordict_module.common import TensorDictModule -from torchrl.modules.tensordict_module.probabilistic import ( - ProbabilisticTensorDictModule, -) -from torchrl.modules.tensordict_module.sequence import TensorDictSequential +from torchrl.modules.tensordict_module.common import SafeModule +from torchrl.modules.tensordict_module.probabilistic import SafeProbabilisticModule +from torchrl.modules.tensordict_module.sequence import SafeSequential -class Actor(TensorDictModule): +class Actor(SafeModule): """General class for deterministic actors in RL. The Actor class comes with default values for the out_keys (["action"]) @@ -70,7 +68,7 @@ def __init__( ) -class ProbabilisticActor(ProbabilisticTensorDictModule): +class ProbabilisticActor(SafeProbabilisticModule): """General class for probabilistic actors in RL. The Actor class comes with default values for the out_keys (["action"]) @@ -89,7 +87,7 @@ class ProbabilisticActor(ProbabilisticTensorDictModule): >>> module = NormalParamWrapper(torch.nn.Linear(4, 8)) >>> fmodule, params, buffers = functorch.make_functional_with_buffers( ... module) - >>> tensordict_module = TensorDictModule(fmodule, in_keys=["observation"], out_keys=["loc", "scale"]) + >>> tensordict_module = SafeModule(fmodule, in_keys=["observation"], out_keys=["loc", "scale"]) >>> td_module = ProbabilisticActor( ... module=tensordict_module, ... spec=action_spec, @@ -112,7 +110,7 @@ class ProbabilisticActor(ProbabilisticTensorDictModule): def __init__( self, - module: TensorDictModule, + module: SafeModule, dist_in_keys: Union[str, Sequence[str]], sample_out_key: Optional[Sequence[str]] = None, spec: Optional[TensorSpec] = None, @@ -136,7 +134,7 @@ def __init__( ) -class ValueOperator(TensorDictModule): +class ValueOperator(SafeModule): """General class for value functions in RL. 
The ValueOperator class comes with default values for the in_keys and @@ -529,7 +527,7 @@ def __init__( ) -class ActorValueOperator(TensorDictSequential): +class ActorValueOperator(SafeSequential): """Actor-value operator. This class wraps together an actor and a value model that share a common observation embedding network: @@ -561,9 +559,9 @@ class ActorValueOperator(TensorDictSequential): will both return a stand-alone TDModule with the dedicated functionality. Args: - common_operator (TensorDictModule): a common operator that reads observations and produces a hidden variable - policy_operator (TensorDictModule): a policy operator that reads the hidden variable and returns an action - value_operator (TensorDictModule): a value operator, that reads the hidden variable and returns a value + common_operator (SafeModule): a common operator that reads observations and produces a hidden variable + policy_operator (SafeModule): a policy operator that reads the hidden variable and returns an action + value_operator (SafeModule): a value operator, that reads the hidden variable and returns a value Examples: >>> import torch @@ -573,14 +571,14 @@ class ActorValueOperator(TensorDictSequential): >>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamWrapper >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) >>> module_hidden = torch.nn.Linear(4, 4) - >>> td_module_hidden = TensorDictModule( + >>> td_module_hidden = SafeModule( ... module=module_hidden, ... spec=spec_hidden, ... in_keys=["observation"], ... out_keys=["hidden"], ... ) >>> spec_action = NdBoundedTensorSpec(-1, 1, torch.Size([8])) - >>> module_action = TensorDictModule( + >>> module_action = SafeModule( ... NormalParamWrapper(torch.nn.Linear(4, 8)), ... in_keys=["hidden"], ... 
out_keys=["loc", "scale"], @@ -636,9 +634,9 @@ class ActorValueOperator(TensorDictSequential): def __init__( self, - common_operator: TensorDictModule, - policy_operator: TensorDictModule, - value_operator: TensorDictModule, + common_operator: SafeModule, + policy_operator: SafeModule, + value_operator: SafeModule, ): super().__init__( common_operator, @@ -646,13 +644,13 @@ def __init__( value_operator, ) - def get_policy_operator(self) -> TensorDictSequential: + def get_policy_operator(self) -> SafeSequential: """Returns a stand-alone policy operator that maps an observation to an action.""" - return TensorDictSequential(self.module[0], self.module[1]) + return SafeSequential(self.module[0], self.module[1]) - def get_value_operator(self) -> TensorDictSequential: + def get_value_operator(self) -> SafeSequential: """Returns a stand-alone value network operator that maps an observation to a value estimate.""" - return TensorDictSequential(self.module[0], self.module[2]) + return SafeSequential(self.module[0], self.module[2]) class ActorCriticOperator(ActorValueOperator): @@ -687,9 +685,9 @@ class ActorCriticOperator(ActorValueOperator): parent object, as the value is computed based on the policy output. 
Args: - common_operator (TensorDictModule): a common operator that reads observations and produces a hidden variable - policy_operator (TensorDictModule): a policy operator that reads the hidden variable and returns an action - value_operator (TensorDictModule): a value operator, that reads the hidden variable and returns a value + common_operator (SafeModule): a common operator that reads observations and produces a hidden variable + policy_operator (SafeModule): a policy operator that reads the hidden variable and returns an action + value_operator (SafeModule): a value operator, that reads the hidden variable and returns a value Examples: >>> import torch @@ -699,7 +697,7 @@ class ActorCriticOperator(ActorValueOperator): >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticOperator, NormalParamWrapper, MLP >>> spec_hidden = NdUnboundedContinuousTensorSpec(4) >>> module_hidden = torch.nn.Linear(4, 4) - >>> td_module_hidden = TensorDictModule( + >>> td_module_hidden = SafeModule( ... module=module_hidden, ... spec=spec_hidden, ... in_keys=["observation"], @@ -707,7 +705,7 @@ class ActorCriticOperator(ActorValueOperator): ... ) >>> spec_action = NdBoundedTensorSpec(-1, 1, torch.Size([8])) >>> module_action = NormalParamWrapper(torch.nn.Linear(4, 8)) - >>> module_action = TensorDictModule(module_action, in_keys=["hidden"], out_keys=["loc", "scale"]) + >>> module_action = SafeModule(module_action, in_keys=["hidden"], out_keys=["loc", "scale"]) >>> td_module_action = ProbabilisticActor( ... module=module_action, ... spec=spec_action, @@ -791,7 +789,7 @@ def get_value_operator(self) -> TensorDictModuleWrapper: ) -class ActorCriticWrapper(TensorDictSequential): +class ActorCriticWrapper(SafeSequential): """Actor-value operator without common module. 
This class wraps together an actor and a value model that do not share a common observation embedding network: @@ -818,8 +816,8 @@ class ActorCriticWrapper(TensorDictSequential): will both return a stand-alone TDModule with the dedicated functionality. Args: - policy_operator (TensorDictModule): a policy operator that reads the hidden variable and returns an action - value_operator (TensorDictModule): a value operator, that reads the hidden variable and returns a value + policy_operator (SafeModule): a policy operator that reads the hidden variable and returns an action + value_operator (SafeModule): a value operator, that reads the hidden variable and returns a value Examples: >>> import torch @@ -875,18 +873,18 @@ class ActorCriticWrapper(TensorDictSequential): def __init__( self, - policy_operator: TensorDictModule, - value_operator: TensorDictModule, + policy_operator: SafeModule, + value_operator: SafeModule, ): super().__init__( policy_operator, value_operator, ) - def get_policy_operator(self) -> TensorDictSequential: + def get_policy_operator(self) -> SafeSequential: """Returns a stand-alone policy operator that maps an observation to an action.""" return self.module[0] - def get_value_operator(self) -> TensorDictSequential: + def get_value_operator(self) -> SafeSequential: """Returns a stand-alone value network operator that maps an observation to a value estimate.""" return self.module[1] diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 7a92ef614d3..e1807f66e69 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -29,7 +29,7 @@ FunctionalModuleWithBuffers, ) -from tensordict.nn import TensorDictModule as _TensorDictModule +from tensordict.nn import TensorDictModule from tensordict.tensordict import TensorDictBase from torch import nn @@ -54,7 +54,7 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out): spec = module.spec 
if len(module.out_keys) > 1 and not isinstance(spec, CompositeSpec): raise RuntimeError( - "safe TensorDictModules with multiple out_keys require a CompositeSpec with matching keys. Got " + "safe SafeModules with multiple out_keys require a CompositeSpec with matching keys. Got " f"keys {module.out_keys}." ) elif not isinstance(spec, CompositeSpec): @@ -80,8 +80,8 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out): ) -class TensorDictModule(_TensorDictModule): - """A TensorDictModule, is a python wrapper around a :obj:`nn.Module` that reads and writes to a TensorDict. +class SafeModule(TensorDictModule): + """A SafeModule, is a python wrapper around a :obj:`nn.Module` that reads and writes to a TensorDict. Args: module (nn.Module): a nn.Module used to map the input to the output parameter space. Can be a functional @@ -98,8 +98,8 @@ class TensorDictModule(_TensorDictModule): If this value is out of bounds, it is projected back onto the desired space using the :obj:`TensorSpec.project` method. Default is :obj:`False`. - Embedding a neural network in a TensorDictModule only requires to specify the input and output keys. The domain spec can - be passed along if needed. TensorDictModule support functional and regular :obj:`nn.Module` objects. In the functional + Embedding a neural network in a SafeModule only requires to specify the input and output keys. The domain spec can + be passed along if needed. SafeModule support functional and regular :obj:`nn.Module` objects. 
In the functional case, the 'params' (and 'buffers') keyword argument must be specified: Examples: @@ -107,12 +107,12 @@ class TensorDictModule(_TensorDictModule): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> from torchrl.modules import TensorDictModule + >>> from torchrl.modules import SafeModule >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) >>> spec = NdUnboundedContinuousTensorSpec(8) >>> module = torch.nn.GRUCell(4, 8) >>> fmodule, params, buffers = functorch.make_functional_with_buffers(module) - >>> td_fmodule = TensorDictModule( + >>> td_fmodule = SafeModule( ... module=fmodule, ... spec=spec, ... in_keys=["input", "hidden"], @@ -129,7 +129,7 @@ class TensorDictModule(_TensorDictModule): device=cpu) In the stateful case: - >>> td_module = TensorDictModule( + >>> td_module = SafeModule( ... module=module, ... spec=spec, ... in_keys=["input", "hidden"], @@ -165,7 +165,7 @@ class TensorDictModule(_TensorDictModule): def __init__( self, module: Union[ - FunctionalModule, FunctionalModuleWithBuffers, TensorDictModule, nn.Module + FunctionalModule, FunctionalModuleWithBuffers, SafeModule, nn.Module ], in_keys: Iterable[str], out_keys: Iterable[str], @@ -179,7 +179,7 @@ def __init__( elif spec is not None and not isinstance(spec, CompositeSpec): if len(self.out_keys) > 1: raise RuntimeError( - f"got more than one out_key for the TensorDictModule: {self.out_keys},\nbut only one spec. " + f"got more than one out_key for the SafeModule: {self.out_keys},\nbut only one spec. " "Consider using a CompositeSpec object or no spec at all." 
) spec = CompositeSpec(**{self.out_keys[0]: spec}) @@ -208,7 +208,7 @@ def __init__( and all(_spec is None for _spec in spec.values()) ): raise RuntimeError( - "`TensorDictModule(spec=None, safe=True)` is not a valid configuration as the tensor " + "`SafeModule(spec=None, safe=True)` is not a valid configuration as the tensor " "specs are not specified" ) self.register_forward_hook(_forward_hook_safe_action) @@ -242,18 +242,18 @@ def random(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict def random_sample(self, tensordict: TensorDictBase) -> TensorDictBase: - """See :obj:`TensorDictModule.random(...)`.""" + """See :obj:`SafeModule.random(...)`.""" return self.random(tensordict) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> TensorDictModule: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> SafeModule: if hasattr(self, "spec") and self.spec is not None: self.spec = self.spec.to(dest) out = super().to(dest) return out -def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]): - """Returns `True` if a module can be used as a TensorDictModule, and False if it can't. +def is_tensordict_compatible(module: Union[SafeModule, nn.Module]): + """Returns `True` if a module can be used as a SafeModule, and False if it can't. If the signature is misleading an error is raised. 
@@ -291,21 +291,21 @@ def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]): """ sig = inspect.signature(module.forward) - if isinstance(module, TensorDictModule) or ( + if isinstance(module, SafeModule) or ( len(sig.parameters) == 1 and hasattr(module, "in_keys") and hasattr(module, "out_keys") ): - # if the module is a TensorDictModule or takes a single argument and defines + # if the module is a SafeModule or takes a single argument and defines # in_keys and out_keys then we assume it can already deal with TensorDict input # to forward and we return True return True elif not hasattr(module, "in_keys") and not hasattr(module, "out_keys"): - # if it's not a TensorDictModule, and in_keys and out_keys are not defined then + # if it's not a SafeModule, and in_keys and out_keys are not defined then # we assume no TensorDict compatibility and will try to wrap it. return False - # if in_keys or out_keys were defined but module is not a TensorDictModule or + # if in_keys or out_keys were defined but module is not a SafeModule or # accepts multiple arguments then it's likely the user is trying to do something # that will have undetermined behaviour, we raise an error raise TypeError( @@ -314,18 +314,16 @@ def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]): "should take a single argument of type TensorDict to module.forward and define " "both in_keys and out_keys. Alternatively, module.forward can accept " "arbitrarily many tensor inputs and leave in_keys and out_keys undefined and " - "TorchRL will attempt to automatically wrap the module with a TensorDictModule." + "TorchRL will attempt to automatically wrap the module with a SafeModule." 
) def ensure_tensordict_compatible( - module: Union[ - FunctionalModule, FunctionalModuleWithBuffers, TensorDictModule, nn.Module - ], + module: Union[FunctionalModule, FunctionalModuleWithBuffers, SafeModule, nn.Module], in_keys: Optional[Iterable[str]] = None, out_keys: Optional[Iterable[str]] = None, safe: bool = False, - wrapper_type: Optional[Type] = TensorDictModule, + wrapper_type: Optional[Type] = SafeModule, ): """Checks and ensures an object with forward method is TensorDict compatible.""" if is_tensordict_compatible(module): @@ -345,7 +343,7 @@ def ensure_tensordict_compatible( if not isinstance(module, nn.Module): raise TypeError( "Argument to ensure_tensordict_compatible should be either " - "a TensorDictModule or an nn.Module" + "a SafeModule or an nn.Module" ) sig = inspect.signature(module.forward) @@ -353,10 +351,10 @@ def ensure_tensordict_compatible( raise TypeError( "Arguments to module.forward are incompatible with entries in " "env.observation_spec. If you want TorchRL to automatically " - "wrap your module with a TensorDictModule then the arguments " + "wrap your module with a SafeModule then the arguments " "to module must correspond one-to-one with entries in " "in_keys. For more complex behaviour and more control you can " - "consider writing your own TensorDictModule." + "consider writing your own SafeModule." ) # TODO: Check whether out_keys match (at least in number) if they are provided. diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index ed56b5a05de..fe3aac62df9 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -15,7 +15,7 @@ from torchrl.envs.utils import exploration_mode from torchrl.modules.tensordict_module.common import ( _forward_hook_safe_action, - TensorDictModule, + SafeModule, ) @@ -30,7 +30,7 @@ class EGreedyWrapper(TensorDictModuleWrapper): """Epsilon-Greedy PO wrapper. 
Args: - policy (TensorDictModule): a deterministic policy. + policy (SafeModule): a deterministic policy. eps_init (scalar, optional): initial epsilon value. default: 1.0 eps_end (scalar, optional): final epsilon value. @@ -71,7 +71,7 @@ class EGreedyWrapper(TensorDictModuleWrapper): def __init__( self, - policy: TensorDictModule, + policy: SafeModule, eps_init: float = 1.0, eps_end: float = 0.1, annealing_num_steps: int = 1000, @@ -139,7 +139,7 @@ class AdditiveGaussianWrapper(TensorDictModuleWrapper): """Additive Gaussian PO wrapper. Args: - policy (TensorDictModule): a policy. + policy (SafeModule): a policy. sigma_init (scalar, optional): initial epsilon value. default: 1.0 sigma_end (scalar, optional): final epsilon value. @@ -162,7 +162,7 @@ class AdditiveGaussianWrapper(TensorDictModuleWrapper): def __init__( self, - policy: TensorDictModule, + policy: SafeModule, sigma_init: float = 1.0, sigma_end: float = 0.1, annealing_num_steps: int = 1000, @@ -250,7 +250,7 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): zeroing the tensordict at reset time. Args: - policy (TensorDictModule): a policy + policy (SafeModule): a policy eps_init (scalar): initial epsilon value, determining the amount of noise to be added. default: 1.0 eps_end (scalar): final epsilon value, determining the amount of noise to be added. 
@@ -293,7 +293,7 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): def __init__( self, - policy: TensorDictModule, + policy: SafeModule, eps_init: float = 1.0, eps_end: float = 0.1, annealing_num_steps: int = 1000, diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index c0893c6bcea..121be87c2f2 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -5,16 +5,14 @@ from typing import Optional, Sequence, Type, Union -from tensordict.nn import ( - ProbabilisticTensorDictModule as _ProbabilisticTensorDictModule, -) +from tensordict.nn import ProbabilisticTensorDictModule from torchrl.data import TensorSpec from torchrl.modules.distributions import Delta -from torchrl.modules.tensordict_module.common import TensorDictModule +from torchrl.modules.tensordict_module.common import SafeModule -class ProbabilisticTensorDictModule(_ProbabilisticTensorDictModule, TensorDictModule): +class SafeProbabilisticModule(ProbabilisticTensorDictModule, SafeModule): """A probabilistic TD Module. `ProbabilisticTDModule` is a special case of a TDModule where the output is @@ -22,12 +20,12 @@ class ProbabilisticTensorDictModule(_ProbabilisticTensorDictModule, TensorDictMo argument and the :obj:`exploration_mode()` global function. It consists in a wrapper around another TDModule that returns a tensordict - updated with the distribution parameters. :obj:`ProbabilisticTensorDictModule` is + updated with the distribution parameters. :obj:`SafeProbabilisticModule` is responsible for constructing the distribution (through the :obj:`get_dist()` method) and/or sampling from this distribution (through a regular :obj:`__call__()` to the module). 
- A :obj:`ProbabilisticTensorDictModule` instance has two main features: + A :obj:`SafeProbabilisticModule` instance has two main features: - It reads and writes TensorDict objects - It uses a real mapping R^n -> R^m to create a distribution in R^d from which values can be sampled or computed. @@ -36,8 +34,8 @@ class ProbabilisticTensorDictModule(_ProbabilisticTensorDictModule, TensorDictMo the 'rsample', 'sample' method). The sampling step is skipped if the inner TDModule has already created the desired key-value pair. - By default, ProbabilisticTensorDictModule distribution class is a Delta - distribution, making ProbabilisticTensorDictModule a simple wrapper around + By default, SafeProbabilisticModule distribution class is a Delta + distribution, making SafeProbabilisticModule a simple wrapper around a deterministic mapping function. Args: @@ -87,13 +85,13 @@ class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for the Normal distribut >>> import torch >>> from tensordict import TensorDict >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> from torchrl.modules import ProbabilisticTensorDictModule, TanhNormal, NormalParamWrapper + >>> from torchrl.modules import SafeProbabilisticModule, TanhNormal, NormalParamWrapper >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) >>> spec = NdUnboundedContinuousTensorSpec(4) >>> net = NormalParamWrapper(torch.nn.GRUCell(4, 8)) >>> fnet, params, buffers = functorch.make_functional_with_buffers(net) - >>> module = TensorDictModule(fnet, in_keys=["input", "hidden"], out_keys=["loc", "scale"]) - >>> td_module = ProbabilisticTensorDictModule( + >>> module = SafeModule(fnet, in_keys=["input", "hidden"], out_keys=["loc", "scale"]) + >>> td_module = SafeProbabilisticModule( ... module=module, ... spec=spec, ... dist_in_keys=["loc", "scale"], @@ -136,7 +134,7 @@ class of interest, e.g. 
:obj:`"loc"` and :obj:`"scale"` for the Normal distribut def __init__( self, - module: TensorDictModule, + module: SafeModule, dist_in_keys: Union[str, Sequence[str], dict], sample_out_key: Union[str, Sequence[str]], spec: Optional[TensorSpec] = None, @@ -159,7 +157,7 @@ def __init__( cache_dist=cache_dist, n_empirical_estimate=n_empirical_estimate, ) - super(_ProbabilisticTensorDictModule, self).__init__( + super(ProbabilisticTensorDictModule, self).__init__( module=module, spec=spec, in_keys=self.in_keys, diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index fb0dd94add8..954a1a5faf7 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -7,15 +7,15 @@ from typing import Iterable, Union -from tensordict.nn import TensorDictSequential as _TensorDictSequential +from tensordict.nn import TensorDictSequential from torch import nn from torchrl.data import CompositeSpec -from torchrl.modules.tensordict_module.common import TensorDictModule +from torchrl.modules.tensordict_module.common import SafeModule -class TensorDictSequential(_TensorDictSequential, TensorDictModule): - """A sequence of TensorDictModules. +class SafeSequential(TensorDictSequential, SafeModule): + """A sequence of SafeModules. Similarly to :obj:`nn.Sequence` which passes a tensor through a chain of mappings that read and write a single tensor each, this module will read and write over a tensordict by querying each of the input modules. @@ -23,12 +23,12 @@ class TensorDictSequential(_TensorDictSequential, TensorDictModule): buffers) will be concatenated in a single list. Args: - modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially. + modules (iterable of SafeModules): ordered sequence of SafeModule instances to be run sequentially. partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys. 
If so, the only module that will be executed are those who can be executed given the keys that are present. Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the - stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts + stack does not have the required keys, then SafeSequential will scan through the sub-tensordicts looking for those that have the required keys, if any. TensorDictSequence supports functional, modular and vmap coding: @@ -37,14 +37,14 @@ class TensorDictSequential(_TensorDictSequential, TensorDictModule): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.data import NdUnboundedContinuousTensorSpec - >>> from torchrl.modules import TanhNormal, TensorDictSequential, NormalParamWrapper - >>> from torchrl.modules.tensordict_module import ProbabilisticTensorDictModule + >>> from torchrl.modules import TanhNormal, SafeSequential, NormalParamWrapper + >>> from torchrl.modules.tensordict_module import SafeProbabilisticModule >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,]) >>> spec1 = NdUnboundedContinuousTensorSpec(4) >>> net1 = NormalParamWrapper(torch.nn.Linear(4, 8)) >>> fnet1, params1, buffers1 = functorch.make_functional_with_buffers(net1) - >>> fmodule1 = TensorDictModule(fnet1, in_keys=["input"], out_keys=["loc", "scale"]) - >>> td_module1 = ProbabilisticTensorDictModule( + >>> fmodule1 = SafeModule(fnet1, in_keys=["input"], out_keys=["loc", "scale"]) + >>> td_module1 = SafeProbabilisticModule( ... module=fmodule1, ... spec=spec1, ... dist_in_keys=["loc", "scale"], @@ -55,13 +55,13 @@ class TensorDictSequential(_TensorDictSequential, TensorDictModule): >>> spec2 = NdUnboundedContinuousTensorSpec(8) >>> module2 = torch.nn.Linear(4, 8) >>> fmodule2, params2, buffers2 = functorch.make_functional_with_buffers(module2) - >>> td_module2 = TensorDictModule( + >>> td_module2 = SafeModule( ... module=fmodule2, ... 
spec=spec2, ... in_keys=["hidden"], ... out_keys=["output"], ... ) - >>> td_module = TensorDictSequential(td_module1, td_module2) + >>> td_module = SafeSequential(td_module1, td_module2) >>> params = params1 + params2 >>> buffers = buffers1 + buffers2 >>> _ = td_module(td, params=params, buffers=buffers) @@ -110,7 +110,7 @@ class TensorDictSequential(_TensorDictSequential, TensorDictModule): def __init__( self, - *modules: TensorDictModule, + *modules: SafeModule, partial_tolerant: bool = False, ): self.partial_tolerant = partial_tolerant @@ -119,12 +119,12 @@ def __init__( spec = CompositeSpec() for module in modules: - if isinstance(module, TensorDictModule) or hasattr(module, "spec"): + if isinstance(module, SafeModule) or hasattr(module, "spec"): spec.update(module.spec) else: spec.update(CompositeSpec({key: None for key in module.out_keys})) - super(_TensorDictSequential, self).__init__( + super(TensorDictSequential, self).__init__( spec=spec, module=nn.ModuleList(list(modules)), in_keys=in_keys, @@ -133,21 +133,21 @@ def __init__( def select_subsequence( self, in_keys: Iterable[str] = None, out_keys: Iterable[str] = None - ) -> "TensorDictSequential": - """Returns a new TensorDictSequential with only the modules that are necessary to compute the given output keys with the given input keys. + ) -> "SafeSequential": + """Returns a new SafeSequential with only the modules that are necessary to compute the given output keys with the given input keys. Args: in_keys: input keys of the subsequence we want to select out_keys: output keys of the subsequence we want to select Returns: - A new TensorDictSequential with only the modules that are necessary acording to the given input and output keys. + A new SafeSequential with only the modules that are necessary acording to the given input and output keys. 
""" td_sequential = super().select_subsequence(in_keys=in_keys, out_keys=out_keys) - return TensorDictSequential(*td_sequential.module) + return SafeSequential(*td_sequential.module) - def __getitem__(self, index: Union[int, slice]) -> TensorDictModule: + def __getitem__(self, index: Union[int, slice]) -> SafeModule: if isinstance(index, int): return self.module.__getitem__(index) else: - return TensorDictSequential(*self.module.__getitem__(index)) + return SafeSequential(*self.module.__getitem__(index)) diff --git a/torchrl/modules/tensordict_module/world_models.py b/torchrl/modules/tensordict_module/world_models.py index 10b8e5b9a5a..0243c3806a3 100644 --- a/torchrl/modules/tensordict_module/world_models.py +++ b/torchrl/modules/tensordict_module/world_models.py @@ -4,10 +4,10 @@ # LICENSE file in the root directory of this source tree. -from torchrl.modules.tensordict_module import TensorDictModule, TensorDictSequential +from torchrl.modules.tensordict_module import SafeModule, SafeSequential -class WorldModelWrapper(TensorDictSequential): +class WorldModelWrapper(SafeSequential): """World model wrapper. This module wraps together a transition model and a reward model. @@ -15,25 +15,18 @@ class WorldModelWrapper(TensorDictSequential): The reward model is used to predict the reward of the imagined transition. Args: - transition_model (TensorDictModule): a transition model that generates a new world states. - reward_model (TensorDictModule): a reward model, that reads the world state and returns a reward. + transition_model (SafeModule): a transition model that generates a new world states. + reward_model (SafeModule): a reward model, that reads the world state and returns a reward. 
""" - def __init__( - self, - transition_model: TensorDictModule, - reward_model: TensorDictModule, - ): - super().__init__( - transition_model, - reward_model, - ) - - def get_transition_model_operator(self) -> TensorDictSequential: + def __init__(self, transition_model: SafeModule, reward_model: SafeModule): + super().__init__(transition_model, reward_model) + + def get_transition_model_operator(self) -> SafeSequential: """Returns a transition operator that maps either an observation to a world state or a world state to the next world state.""" return self.module[0] - def get_reward_operator(self) -> TensorDictSequential: + def get_reward_operator(self) -> SafeSequential: """Returns a reward operator that maps a world state to a reward.""" return self.module[1] diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 5f1ffe01618..af20007b26a 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -9,8 +9,8 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import distributions as d -from torchrl.modules import TensorDictModule -from torchrl.modules.tensordict_module import ProbabilisticTensorDictModule +from torchrl.modules import SafeModule +from torchrl.modules.tensordict_module import SafeProbabilisticModule from torchrl.objectives.common import LossModule from torchrl.objectives.utils import distance_loss @@ -26,7 +26,7 @@ class A2CLoss(LossModule): https://arxiv.org/abs/1602.01783v2 Args: - actor (ProbabilisticTensorDictModule): policy operator. + actor (SafeProbabilisticModule): policy operator. critic (ValueOperator): value operator. advantage_key (str): the input tensordict key where the advantage is expected to be written. default: "advantage" @@ -36,13 +36,13 @@ class A2CLoss(LossModule): critic_coef (float): the weight of the critic loss. gamma (scalar): a discount factor for return computation. loss_function_type (str): loss function for the value discrepancy. 
Can be one of "l1", "l2" or "smooth_l1". - advantage_module (nn.Module): TensorDictModule used to compute tha advantage function. + advantage_module (nn.Module): SafeModule used to compute tha advantage function. """ def __init__( self, - actor: ProbabilisticTensorDictModule, - critic: TensorDictModule, + actor: SafeProbabilisticModule, + critic: SafeModule, advantage_key: str = "advantage", advantage_diff_key: str = "value_error", entropy_bonus: bool = True, diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 14d33e09bce..a7c90521a5a 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -28,7 +28,7 @@ from torch import nn, Tensor from torch.nn import Parameter -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule class LossModule(nn.Module): @@ -64,7 +64,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def convert_to_functional( self, - module: TensorDictModule, + module: SafeModule, module_name: str, expand_dim: Optional[int] = None, create_target_params: bool = False, @@ -89,7 +89,7 @@ def convert_to_functional( def _convert_to_functional_functorch( self, - module: TensorDictModule, + module: SafeModule, module_name: str, expand_dim: Optional[int] = None, create_target_params: bool = False, @@ -249,7 +249,7 @@ def _convert_to_functional_functorch( def _convert_to_functional_native( self, - module: TensorDictModule, + module: SafeModule, module_name: str, expand_dim: Optional[int] = None, create_target_params: bool = False, diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index d2a54ee81fe..5692f109739 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -10,7 +10,7 @@ import torch from tensordict.tensordict import TensorDict, TensorDictBase -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from 
torchrl.objectives.utils import distance_loss, hold_out_params, next_state_value @@ -22,8 +22,8 @@ class DDPGLoss(LossModule): """The DDPG Loss class. Args: - actor_network (TensorDictModule): a policy operator. - value_network (TensorDictModule): a Q value operator. + actor_network (SafeModule): a policy operator. + value_network (SafeModule): a Q value operator. gamma (scalar): a discount factor for return computation. device (str, int or torch.device, optional): a device where the losses will be computed, if it can't be found via the value operator. @@ -36,8 +36,8 @@ class DDPGLoss(LossModule): def __init__( self, - actor_network: TensorDictModule, - value_network: TensorDictModule, + actor_network: SafeModule, + value_network: SafeModule, gamma: float, loss_function: str = "l2", delay_actor: bool = False, diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 9005112e7d7..555525161a4 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -9,7 +9,7 @@ from torch import Tensor from torchrl.envs.utils import set_exploration_mode, step_mdp -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule from torchrl.objectives import ( distance_loss, hold_out_params, @@ -26,8 +26,8 @@ class REDQLoss_deprecated(LossModule): train a SAC-like algorithm. Args: - actor_network (TensorDictModule): the actor to be trained - qvalue_network (TensorDictModule): a single Q-value network that will be multiplicated as many times as needed. + actor_network (SafeModule): the actor to be trained + qvalue_network (SafeModule): a single Q-value network that will be multiplicated as many times as needed. num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10. sub_sample_len (int, optional): number of Q-value networks to be subsampled to evaluate the next state value Default is 2. 
@@ -51,8 +51,8 @@ class REDQLoss_deprecated(LossModule): def __init__( self, - actor_network: TensorDictModule, - qvalue_network: TensorDictModule, + actor_network: SafeModule, + qvalue_network: SafeModule, num_qvalue_nets: int = 10, sub_sample_len: int = 2, gamma: Number = 0.99, diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index c9f35b64649..8d839078154 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -9,7 +9,7 @@ from torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.utils import set_exploration_mode, step_mdp -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule from torchrl.objectives.common import LossModule from torchrl.objectives.utils import distance_loss, hold_out_net from torchrl.objectives.value.functional import vec_td_lambda_return_estimate @@ -24,7 +24,7 @@ class DreamerModelLoss(LossModule): Reference: https://arxiv.org/abs/1912.01603. Args: - world_model (TensorDictModule): the world model. + world_model (SafeModule): the world model. lambda_kl (float, optional): the weight of the kl divergence loss. Default: 1.0. lambda_reco (float, optional): the weight of the reconstruction loss. Default: 1.0. lambda_reward (float, optional): the weight of the reward loss. Default: 1.0. @@ -42,7 +42,7 @@ class DreamerModelLoss(LossModule): def __init__( self, - world_model: TensorDictModule, + world_model: SafeModule, lambda_kl: float = 1.0, lambda_reco: float = 1.0, lambda_reward: float = 1.0, @@ -133,8 +133,8 @@ class DreamerActorLoss(LossModule): Reference: https://arxiv.org/abs/1912.01603. Args: - actor_model (TensorDictModule): the actor model. - value_model (TensorDictModule): the value model. + actor_model (SafeModule): the actor model. + value_model (SafeModule): the value model. model_based_env (DreamerEnv): the model based environment. imagination_horizon (int, optional): The number of steps to unroll the model. Default: 15. 
@@ -147,8 +147,8 @@ class DreamerActorLoss(LossModule): def __init__( self, - actor_model: TensorDictModule, - value_model: TensorDictModule, + actor_model: SafeModule, + value_model: SafeModule, model_based_env: DreamerEnv, imagination_horizon: int = 15, gamma: int = 0.99, @@ -217,7 +217,7 @@ class DreamerValueLoss(LossModule): Reference: https://arxiv.org/abs/1912.01603. Args: - value_model (TensorDictModule): the value model. + value_model (SafeModule): the value model. value_loss (str, optional): the loss to use for the value loss. Default: "l2". gamma (float, optional): the gamma discount factor. Default: 0.99. discount_loss (bool, optional): if True, the loss is discounted with a @@ -227,7 +227,7 @@ class DreamerValueLoss(LossModule): def __init__( self, - value_model: TensorDictModule, + value_model: SafeModule, value_loss: Optional[str] = None, gamma: int = 0.99, discount_loss: bool = False, # for consistency with paper diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index b711519492e..2926e2c667a 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -10,10 +10,10 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import distributions as d -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule from torchrl.objectives.utils import distance_loss -from ..modules.tensordict_module import ProbabilisticTensorDictModule +from ..modules.tensordict_module import SafeProbabilisticModule from .common import LossModule @@ -32,7 +32,7 @@ class PPOLoss(LossModule): https://arxiv.org/abs/1707.06347 Args: - actor (ProbabilisticTensorDictModule): policy operator. + actor (SafeProbabilisticModule): policy operator. critic (ValueOperator): value operator. advantage_key (str): the input tensordict key where the advantage is expected to be written. 
default: "advantage" @@ -52,8 +52,8 @@ class PPOLoss(LossModule): def __init__( self, - actor: ProbabilisticTensorDictModule, - critic: TensorDictModule, + actor: SafeProbabilisticModule, + critic: SafeModule, advantage_key: str = "advantage", advantage_diff_key: str = "value_error", entropy_bonus: bool = True, @@ -171,7 +171,7 @@ class ClipPPOLoss(PPOLoss): loss = -min( weight * advantage, min(max(weight, 1-eps), 1+eps) * advantage) Args: - actor (ProbabilisticTensorDictModule): policy operator. + actor (SafeProbabilisticModule): policy operator. critic (ValueOperator): value operator. advantage_key (str): the input tensordict key where the advantage is expected to be written. default: "advantage" @@ -193,8 +193,8 @@ class ClipPPOLoss(PPOLoss): def __init__( self, - actor: ProbabilisticTensorDictModule, - critic: TensorDictModule, + actor: SafeProbabilisticModule, + critic: SafeModule, advantage_key: str = "advantage", clip_epsilon: float = 0.2, entropy_bonus: bool = True, @@ -277,7 +277,7 @@ class KLPENPPOLoss(PPOLoss): favouring a certain level of distancing between the two while still preventing them to be too much apart. Args: - actor (ProbabilisticTensorDictModule): policy operator. + actor (SafeProbabilisticModule): policy operator. critic (ValueOperator): value operator. advantage_key (str): the input tensordict key where the advantage is expected to be written. 
default: "advantage" @@ -304,8 +304,8 @@ class KLPENPPOLoss(PPOLoss): def __init__( self, - actor: ProbabilisticTensorDictModule, - critic: TensorDictModule, + actor: SafeProbabilisticModule, + critic: SafeModule, advantage_key="advantage", dtarg: float = 0.01, beta: float = 1.0, diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index beb7c31c51a..70b3d6a3e8d 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -13,7 +13,7 @@ from torch import Tensor from torchrl.envs.utils import set_exploration_mode, step_mdp -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule from torchrl.objectives.common import _has_functorch, LossModule from torchrl.objectives.utils import ( distance_loss, @@ -30,8 +30,8 @@ class REDQLoss(LossModule): train a SAC-like algorithm. Args: - actor_network (TensorDictModule): the actor to be trained - qvalue_network (TensorDictModule): a single Q-value network that will be multiplicated as many times as needed. + actor_network (SafeModule): the actor to be trained + qvalue_network (SafeModule): a single Q-value network that will be multiplicated as many times as needed. num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 10. sub_sample_len (int, optional): number of Q-value networks to be subsampled to evaluate the next state value Default is 2. 
@@ -59,8 +59,8 @@ class REDQLoss(LossModule): def __init__( self, - actor_network: TensorDictModule, - qvalue_network: TensorDictModule, + actor_network: SafeModule, + qvalue_network: SafeModule, num_qvalue_nets: int = 10, sub_sample_len: int = 2, gamma: Number = 0.99, diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index b68b7b981a0..294f79c50ec 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -4,7 +4,7 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.envs.utils import step_mdp -from torchrl.modules import ProbabilisticTensorDictModule, TensorDictModule +from torchrl.modules import SafeModule, SafeProbabilisticModule from torchrl.objectives import distance_loss from torchrl.objectives.common import LossModule @@ -19,9 +19,9 @@ class ReinforceLoss(LossModule): def __init__( self, - actor_network: ProbabilisticTensorDictModule, + actor_network: SafeProbabilisticModule, advantage_module: Callable[[TensorDictBase], TensorDictBase], - critic: Optional[TensorDictModule] = None, + critic: Optional[SafeModule] = None, delay_value: bool = False, gamma: float = 0.99, advantage_key: str = "advantage", diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 9b0685e2178..bfc5e088813 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -12,7 +12,7 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import Tensor -from torchrl.modules import ProbabilisticActor, TensorDictModule +from torchrl.modules import ProbabilisticActor, SafeModule from torchrl.modules.tensordict_module.actors import ActorCriticWrapper from torchrl.objectives.utils import distance_loss, next_state_value @@ -28,8 +28,8 @@ class SACLoss(LossModule): Args: actor_network (ProbabilisticActor): stochastic actor - qvalue_network (TensorDictModule): Q(s, a) parametric model - value_network (TensorDictModule): V(s) parametric model\ + qvalue_network 
(SafeModule): Q(s, a) parametric model + value_network (SafeModule): V(s) parametric model\ qvalue_network_bis (ProbabilisticTDModule, optional): if required, the Q-value can be computed twice independently using two separate networks. The minimum predicted value will then be used for @@ -68,8 +68,8 @@ class SACLoss(LossModule): def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule, - value_network: TensorDictModule, + qvalue_network: SafeModule, + value_network: SafeModule, num_qvalue_nets: int = 2, gamma: Number = 0.99, priotity_key: str = "td_error", diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 467c0cb7c7f..4f2da57c93a 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -13,7 +13,7 @@ from torch.nn import functional as F from torchrl.envs.utils import step_mdp -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule class _context_manager: @@ -293,7 +293,7 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: @torch.no_grad() def next_state_value( tensordict: TensorDictBase, - operator: Optional[TensorDictModule] = None, + operator: Optional[SafeModule] = None, next_val_key: str = "state_action_value", gamma: float = 0.99, pred_next_val: Optional[Tensor] = None, diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 6996252847e..6ee6ef3503b 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -10,7 +10,7 @@ from torch import nn, Tensor from torchrl.envs.utils import step_mdp -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule from torchrl.objectives.value.functional import ( td_lambda_advantage_estimate, vec_generalized_advantage_estimate, @@ -26,7 +26,7 @@ class TDEstimate(nn.Module): Args: gamma (scalar): exponential mean discount. 
- value_network (TensorDictModule): value operator used to retrieve the value estimates. + value_network (SafeModule): value operator used to retrieve the value estimates. average_rewards (bool, optional): if True, rewards will be standardized before the TD is computed. gradient_mode (bool, optional): if True, gradients are propagated throught @@ -38,7 +38,7 @@ class TDEstimate(nn.Module): def __init__( self, gamma: Union[float, torch.Tensor], - value_network: TensorDictModule, + value_network: SafeModule, average_rewards: bool = False, gradient_mode: bool = False, value_key: str = "state_value", @@ -129,7 +129,7 @@ class TDLambdaEstimate(nn.Module): Args: gamma (scalar): exponential mean discount. lmbda (scalar): trajectory discount. - value_network (TensorDictModule): value operator used to retrieve the value estimates. + value_network (SafeModule): value operator used to retrieve the value estimates. average_rewards (bool, optional): if True, rewards will be standardized before the TD is computed. gradient_mode (bool, optional): if True, gradients are propagated throught @@ -144,7 +144,7 @@ def __init__( self, gamma: Union[float, torch.Tensor], lmbda: Union[float, torch.Tensor], - value_network: TensorDictModule, + value_network: SafeModule, average_rewards: bool = False, gradient_mode: bool = False, value_key: str = "state_value", @@ -251,7 +251,7 @@ class GAE(nn.Module): Args: gamma (scalar): exponential mean discount. lmbda (scalar): trajectory discount. - value_network (TensorDictModule): value operator used to retrieve the value estimates. + value_network (SafeModule): value operator used to retrieve the value estimates. average_rewards (bool): if True, rewards will be standardized before the GAE is computed. gradient_mode (bool): if True, gradients are propagated throught the computation of the value function. Default is `False`. 
@@ -262,7 +262,7 @@ def __init__( self, gamma: Union[float, torch.Tensor], lmbda: float, - value_network: TensorDictModule, + value_network: SafeModule, average_rewards: bool = False, gradient_mode: bool = False, ): diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index 0729ca65a7f..d7facd322c1 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -18,7 +18,7 @@ from torchrl.data import MultiStep from torchrl.envs import ParallelEnv from torchrl.envs.common import EnvBase -from torchrl.modules import ProbabilisticTensorDictModule +from torchrl.modules import SafeProbabilisticModule def sync_async_collector( @@ -249,7 +249,7 @@ def _make_collector( def make_collector_offpolicy( make_env: Callable[[], EnvBase], - actor_model_explore: Union[TensorDictModuleWrapper, ProbabilisticTensorDictModule], + actor_model_explore: Union[TensorDictModuleWrapper, SafeProbabilisticModule], cfg: "DictConfig", # noqa: F821 make_env_kwargs: Optional[Dict] = None, ) -> _DataCollector: @@ -257,7 +257,7 @@ def make_collector_offpolicy( Args: make_env (Callable): environment creator - actor_model_explore (TensorDictModule): Model instance used for evaluation and exploration update + actor_model_explore (SafeModule): Model instance used for evaluation and exploration update cfg (DictConfig): config for creating collector object make_env_kwargs (dict): kwargs for the env creator @@ -313,7 +313,7 @@ def make_collector_offpolicy( def make_collector_onpolicy( make_env: Callable[[], EnvBase], - actor_model_explore: Union[TensorDictModuleWrapper, ProbabilisticTensorDictModule], + actor_model_explore: Union[TensorDictModuleWrapper, SafeProbabilisticModule], cfg: "DictConfig", # noqa: F821 make_env_kwargs: Optional[Dict] = None, ) -> _DataCollector: @@ -321,7 +321,7 @@ def make_collector_onpolicy( Args: make_env (Callable): environment creator - actor_model_explore (TensorDictModule): Model instance used for 
evaluation and exploration update + actor_model_explore (SafeModule): Model instance used for evaluation and exploration update cfg (DictConfig): config for creating collector object make_env_kwargs (dict): kwargs for the env creator diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index f24771b0288..24742d62ee0 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -24,9 +24,9 @@ ActorValueOperator, NoisyLinear, NormalParamWrapper, - ProbabilisticTensorDictModule, - TensorDictModule, - TensorDictSequential, + SafeModule, + SafeProbabilisticModule, + SafeSequential, ) from torchrl.modules.distributions import ( Delta, @@ -315,7 +315,7 @@ def make_ddpg_actor( actor_net = DdpgMlpActor(**actor_net_default_kwargs) gSDE_state_key = "observation_vector" out_keys = ["param"] - actor_module = TensorDictModule(actor_net, in_keys=in_keys, out_keys=out_keys) + actor_module = SafeModule(actor_net, in_keys=in_keys, out_keys=out_keys) if cfg.gSDE: min = env_specs["action_spec"].space.minimum @@ -325,9 +325,9 @@ def make_ddpg_actor( transform = d.ComposeTransform( transform, d.AffineTransform(loc=(max + min) / 2, scale=(max - min) / 2) ) - actor_module = TensorDictSequential( + actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform, learn_sigma=False), in_keys=["param", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -549,7 +549,7 @@ def make_a2c_model( out_features=hidden_features, activate_last_layer=True, ) - common_operator = TensorDictModule( + common_operator = SafeModule( spec=None, module=common_module, in_keys=in_keys_actor, @@ -565,13 +565,13 @@ def make_a2c_model( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) in_keys = ["hidden"] - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys, out_keys=["loc", "scale"] ) else: in_keys = ["hidden"] 
gSDE_state_key = "hidden" - actor_module = TensorDictModule( + actor_module = SafeModule( policy_net, in_keys=in_keys, out_keys=["action"], # will be overwritten @@ -589,9 +589,9 @@ def make_a2c_model( else: raise RuntimeError("cannot use gSDE with discrete actions") - actor_module = TensorDictSequential( + actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform), in_keys=["action", gSDE_state_key, "_eps_gSD"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -640,13 +640,13 @@ def make_a2c_model( actor_net = NormalParamWrapper( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"] ) else: in_keys = in_keys_actor gSDE_state_key = in_keys_actor[0] - actor_module = TensorDictModule( + actor_module = SafeModule( policy_net, in_keys=in_keys, out_keys=["action"], # will be overwritten @@ -664,9 +664,9 @@ def make_a2c_model( else: raise RuntimeError("cannot use gSDE with discrete actions") - actor_module = TensorDictSequential( + actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform), in_keys=["action", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -838,7 +838,7 @@ def make_ppo_model( out_features=hidden_features, activate_last_layer=True, ) - common_operator = TensorDictModule( + common_operator = SafeModule( spec=None, module=common_module, in_keys=in_keys_actor, @@ -854,13 +854,13 @@ def make_ppo_model( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) in_keys = ["hidden"] - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys, out_keys=["loc", "scale"] ) else: in_keys = ["hidden"] gSDE_state_key = "hidden" - actor_module = TensorDictModule( + actor_module = SafeModule( policy_net, in_keys=in_keys, 
out_keys=["action"], # will be overwritten @@ -878,9 +878,9 @@ def make_ppo_model( else: raise RuntimeError("cannot use gSDE with discrete actions") - actor_module = TensorDictSequential( + actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform), in_keys=["action", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -929,13 +929,13 @@ def make_ppo_model( actor_net = NormalParamWrapper( policy_net, scale_mapping=f"biased_softplus_{cfg.default_policy_scale}" ) - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"] ) else: in_keys = in_keys_actor gSDE_state_key = in_keys_actor[0] - actor_module = TensorDictModule( + actor_module = SafeModule( policy_net, in_keys=in_keys, out_keys=["action"], # will be overwritten @@ -953,9 +953,9 @@ def make_ppo_model( else: raise RuntimeError("cannot use gSDE with discrete actions") - actor_module = TensorDictSequential( + actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform), in_keys=["action", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -1140,7 +1140,7 @@ def make_sac_model( scale_lb=cfg.scale_lb, ) in_keys_actor = in_keys - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys_actor, out_keys=[ @@ -1151,7 +1151,7 @@ def make_sac_model( else: gSDE_state_key = in_keys[0] - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys, out_keys=["action"], # will be overwritten @@ -1169,9 +1169,9 @@ def make_sac_model( else: raise RuntimeError("cannot use gSDE with discrete actions") - actor_module = TensorDictSequential( + actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform), in_keys=["action", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", 
"action", "_eps_gSDE"], @@ -1387,14 +1387,14 @@ def make_redq_model( scale_mapping=f"biased_softplus_{default_policy_scale}", scale_lb=cfg.scale_lb, ) - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys_actor, out_keys=["loc", "scale"] + out_keys_actor[1:], ) else: - actor_module = TensorDictModule( + actor_module = SafeModule( actor_net, in_keys=in_keys_actor, out_keys=["action"] + out_keys_actor[1:], # will be overwritten @@ -1412,9 +1412,9 @@ def make_redq_model( else: raise RuntimeError("cannot use gSDE with discrete actions") - actor_module = TensorDictSequential( + actor_module = SafeSequential( actor_module, - TensorDictModule( + SafeModule( LazygSDEModule(transform=transform), in_keys=["action", gSDE_state_key, "_eps_gSDE"], out_keys=["loc", "scale", "action", "_eps_gSDE"], @@ -1555,7 +1555,7 @@ def _dreamer_make_world_model( ): # World Model and reward model rssm_rollout = RSSMRollout( - TensorDictModule( + SafeModule( rssm_prior, in_keys=["state", "belief", "action"], out_keys=[ @@ -1565,7 +1565,7 @@ def _dreamer_make_world_model( ("next", "belief"), ], ), - TensorDictModule( + SafeModule( rssm_posterior, in_keys=[("next", "belief"), ("next", "encoded_latents")], out_keys=[ @@ -1576,20 +1576,20 @@ def _dreamer_make_world_model( ), ) - transition_model = TensorDictSequential( - TensorDictModule( + transition_model = SafeSequential( + SafeModule( obs_encoder, in_keys=[("next", "pixels")], out_keys=[("next", "encoded_latents")], ), rssm_rollout, - TensorDictModule( + SafeModule( obs_decoder, in_keys=[("next", "state"), ("next", "belief")], out_keys=[("next", "reco_pixels")], ), ) - reward_model = TensorDictModule( + reward_model = SafeModule( reward_module, in_keys=[("next", "state"), ("next", "belief")], out_keys=["reward"], @@ -1630,8 +1630,8 @@ def _dreamer_make_actors( def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): - actor_simulator = ProbabilisticTensorDictModule( - TensorDictModule( + 
actor_simulator = SafeProbabilisticModule( + SafeModule( actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], @@ -1663,13 +1663,13 @@ def _dreamer_make_actor_real( # actor for real world: interacts with states ~ posterior # Out actor differs from the original paper where first they compute prior and posterior and then act on it # but we found that this approach worked better. - actor_realworld = TensorDictSequential( - TensorDictModule( + actor_realworld = SafeSequential( + SafeModule( obs_encoder, in_keys=["pixels"], out_keys=["encoded_latents"], ), - TensorDictModule( + SafeModule( rssm_posterior, in_keys=["belief", "encoded_latents"], out_keys=[ @@ -1678,8 +1678,8 @@ def _dreamer_make_actor_real( "state", ], ), - ProbabilisticTensorDictModule( - TensorDictModule( + SafeProbabilisticModule( + SafeModule( actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], @@ -1700,7 +1700,7 @@ def _dreamer_make_actor_real( } ), ), - TensorDictModule( + SafeModule( rssm_prior, in_keys=["state", "belief", action_key], out_keys=[ @@ -1716,7 +1716,7 @@ def _dreamer_make_actor_real( def _dreamer_make_value_model(mlp_num_units, value_key): # actor for simulator: interacts with states ~ prior - value_model = TensorDictModule( + value_model = SafeModule( MLP( out_features=1, depth=3, @@ -1740,7 +1740,7 @@ def _dreamer_make_mbenv( ): # MB environment if use_decoder_in_env: - mb_env_obs_decoder = TensorDictModule( + mb_env_obs_decoder = SafeModule( obs_decoder, in_keys=[("next", "state"), ("next", "belief")], out_keys=[("next", "reco_pixels")], @@ -1748,8 +1748,8 @@ def _dreamer_make_mbenv( else: mb_env_obs_decoder = None - transition_model = TensorDictSequential( - TensorDictModule( + transition_model = SafeSequential( + SafeModule( rssm_prior, in_keys=["state", "belief", "action"], out_keys=[ @@ -1760,7 +1760,7 @@ def _dreamer_make_mbenv( ], ), ) - reward_model = TensorDictModule( + reward_model = SafeModule( reward_module, in_keys=["state", "belief"], 
out_keys=["reward"], diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 7ec5f982a32..73691357c8d 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -15,7 +15,7 @@ from torchrl.collectors.collectors import _DataCollector from torchrl.data import ReplayBuffer from torchrl.envs.common import EnvBase -from torchrl.modules import reset_noise, TensorDictModule +from torchrl.modules import reset_noise, SafeModule from torchrl.objectives.common import LossModule from torchrl.objectives.utils import TargetNetUpdater from torchrl.trainers.loggers import Logger @@ -80,9 +80,7 @@ def make_trainer( loss_module: LossModule, recorder: Optional[EnvBase] = None, target_net_updater: Optional[TargetNetUpdater] = None, - policy_exploration: Optional[ - Union[TensorDictModuleWrapper, TensorDictModule] - ] = None, + policy_exploration: Optional[Union[TensorDictModuleWrapper, SafeModule]] = None, replay_buffer: Optional[ReplayBuffer] = None, logger: Optional[Logger] = None, cfg: "DictConfig" = None, # noqa: F821 @@ -114,7 +112,7 @@ def make_trainer( >>> from torchrl.collectors.collectors import SyncDataCollector >>> from torchrl.data import TensorDictReplayBuffer >>> from torchrl.envs.libs.gym import GymEnv - >>> from torchrl.modules import TensorDictModuleWrapper, TensorDictModule, ValueOperator, EGreedyWrapper + >>> from torchrl.modules import TensorDictModuleWrapper, SafeModule, ValueOperator, EGreedyWrapper >>> from torchrl.objectives.common import LossModule >>> from torchrl.objectives.utils import TargetNetUpdater >>> from torchrl.objectives import DDPGLoss @@ -124,7 +122,7 @@ def make_trainer( >>> action_spec = env_proof.action_spec >>> net = torch.nn.Linear(env_proof.observation_spec.shape[-1], action_spec.shape[-1]) >>> net_value = torch.nn.Linear(env_proof.observation_spec.shape[-1], 1) # for the purpose of testing - >>> policy = TensorDictModule(action_spec, net, in_keys=["observation"], 
out_keys=["action"]) + >>> policy = SafeModule(action_spec, net, in_keys=["observation"], out_keys=["action"]) >>> value = ValueOperator(net_value, in_keys=["observation"], out_keys=["state_action_value"]) >>> collector = SyncDataCollector(env_maker, policy, total_frames=100) >>> loss_module = DDPGLoss(policy, value, gamma=0.99) diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 1909df1370a..1b6014ccb3a 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -29,7 +29,7 @@ from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.utils import set_exploration_mode -from torchrl.modules import TensorDictModule +from torchrl.modules import SafeModule from torchrl.objectives.common import LossModule from torchrl.trainers.loggers import Logger @@ -1036,7 +1036,7 @@ def __init__( record_interval: int, record_frames: int, frame_skip: int, - policy_exploration: TensorDictModule, + policy_exploration: SafeModule, recorder: EnvBase, exploration_mode: str = "random", log_keys: Optional[List[str]] = None, From 21ea2ee6c4f0380bc1c92d8a274440f117dd9c81 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Fri, 25 Nov 2022 12:02:42 +0000 Subject: [PATCH 13/14] Delete redundant methods following inheritance fixes --- torchrl/modules/tensordict_module/sequence.py | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index 954a1a5faf7..bbc3323630f 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -5,8 +5,6 @@ from __future__ import annotations -from typing import Iterable, Union - from tensordict.nn import TensorDictSequential from torch import nn @@ -130,24 +128,3 @@ def __init__( in_keys=in_keys, out_keys=out_keys, ) - - def select_subsequence( - self, in_keys: Iterable[str] = None, out_keys: Iterable[str] = None - ) -> 
"SafeSequential": - """Returns a new SafeSequential with only the modules that are necessary to compute the given output keys with the given input keys. - - Args: - in_keys: input keys of the subsequence we want to select - out_keys: output keys of the subsequence we want to select - - Returns: - A new SafeSequential with only the modules that are necessary acording to the given input and output keys. - """ - td_sequential = super().select_subsequence(in_keys=in_keys, out_keys=out_keys) - return SafeSequential(*td_sequential.module) - - def __getitem__(self, index: Union[int, slice]) -> SafeModule: - if isinstance(index, int): - return self.module.__getitem__(index) - else: - return SafeSequential(*self.module.__getitem__(index)) From a814bdf89595c90378edf393423be916fdddc4b8 Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Fri, 25 Nov 2022 12:07:03 +0000 Subject: [PATCH 14/14] Some docstring improvements --- torchrl/modules/tensordict_module/common.py | 2 +- torchrl/modules/tensordict_module/probabilistic.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index e1807f66e69..c092197eb7c 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -81,7 +81,7 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out): class SafeModule(TensorDictModule): - """A SafeModule, is a python wrapper around a :obj:`nn.Module` that reads and writes to a TensorDict. + """An :obj:``SafeModule`` is a :obj:``tensordict.nn.TensorDictModule`` subclass that accepts a :obj:``TensorSpec`` as argument to control the output domain. Args: module (nn.Module): a nn.Module used to map the input to the output parameter space. 
Can be a functional diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 121be87c2f2..3061d1017fa 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -13,11 +13,7 @@ class SafeProbabilisticModule(ProbabilisticTensorDictModule, SafeModule): - """A probabilistic TD Module. - - `ProbabilisticTDModule` is a special case of a TDModule where the output is - sampled given some rule, specified by the input :obj:`default_interaction_mode` - argument and the :obj:`exploration_mode()` global function. + """A :obj:`SafeProbabilisticModule` is a :obj:`tensordict.nn.ProbabilisticTensorDictModule` subclass that accepts a :obj:`TensorSpec` as argument to control the output domain. It consists in a wrapper around another TDModule that returns a tensordict updated with the distribution parameters. :obj:`SafeProbabilisticModule` is