diff --git a/test/test_collector.py b/test/test_collector.py index 4b8b70d8444..769ef221ae6 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -17,6 +17,7 @@ DiscreteActionVecPolicy, MockSerialEnv, ) +from tensordict.nn import TensorDictModule from tensordict.tensordict import assert_allclose_td, TensorDict from torch import nn from torchrl._utils import seed_generator @@ -980,12 +981,12 @@ def test_auto_wrap_modules(self, collector_class, multiple_outputs, env_maker): if collector_class is not SyncDataCollector: assert all( - isinstance(p, SafeModule) for p in collector._policy_dict.values() + isinstance(p, TensorDictModule) for p in collector._policy_dict.values() ) assert all(p.out_keys == out_keys for p in collector._policy_dict.values()) assert all(p.module is policy for p in collector._policy_dict.values()) else: - assert isinstance(collector.policy, SafeModule) + assert isinstance(collector.policy, TensorDictModule) assert collector.policy.out_keys == out_keys assert collector.policy.module is policy diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 93a117821de..8170e7723a4 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -17,6 +17,7 @@ import numpy as np import torch import torch.nn as nn +from tensordict.nn import TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from torch import multiprocessing as mp from torch.utils.data import IterableDataset @@ -29,7 +30,6 @@ from ..data.utils import CloudpickleWrapper, DEVICE_TYPING from ..envs.common import EnvBase from ..envs.vec_env import _BatchedEnv -from ..modules.tensordict_module import SafeModule, SafeProbabilisticModule from .utils import split_trajectories _TIMEOUT = 1.0 @@ -84,21 +84,21 @@ def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: def _policy_is_tensordict_compatible(policy: nn.Module): sig = inspect.signature(policy.forward) - if isinstance(policy, SafeModule) or ( + if isinstance(policy, TensorDictModule) or ( len(sig.parameters) == 1 and hasattr(policy, "in_keys") and hasattr(policy, "out_keys") ): - # if the policy is a SafeModule or takes a single argument and defines + # if the policy is a TensorDictModule or takes a single argument and defines # in_keys and out_keys then we assume it can already deal with TensorDict input # to forward and we return True return True elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"): - # if it's not a SafeModule, and in_keys and out_keys are not defined then + # if it's not a TensorDictModule, and in_keys and out_keys are not defined then # we assume no TensorDict compatibility and will try to wrap it. return False - # if in_keys or out_keys were defined but policy is not a SafeModule or + # if in_keys or out_keys were defined but policy is not a TensorDictModule or # accepts multiple arguments then it's likely the user is trying to do something # that will have undetermined behaviour, we raise an error raise TypeError( @@ -107,7 +107,7 @@ def _policy_is_tensordict_compatible(policy: nn.Module): "should take a single argument of type TensorDict to policy.forward and define " "both in_keys and out_keys. Alternatively, policy.forward can accept " "arbitrarily many tensor inputs and leave in_keys and out_keys undefined and " - "TorchRL will attempt to automatically wrap the policy with a SafeModule." + "TorchRL will attempt to automatically wrap the policy with a TensorDictModule." 
) @@ -116,13 +116,13 @@ def _get_policy_and_device( self, policy: Optional[ Union[ - SafeProbabilisticModule, + TensorDictModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, device: Optional[DEVICE_TYPING] = None, observation_spec: TensorSpec = None, - ) -> Tuple[SafeProbabilisticModule, torch.device, Union[None, Callable[[], dict]]]: + ) -> Tuple[TensorDictModule, torch.device, Union[None, Callable[[], dict]]]: """Util method to get a policy and its device given the collector __init__ inputs. From a policy and a device, assigns the self.device attribute to @@ -133,7 +133,7 @@ def _get_policy_and_device( create_env_fn (Callable or list of callables): an env creator function (or a list of creators) create_env_kwargs (dictionary): kwargs for the env creator - policy (SafeProbabilisticModule, optional): a policy to be used + policy (TensorDictModule, optional): a policy to be used device (int, str or torch.device, optional): device where to place the policy observation_spec (TensorSpec, optional): spec of the observations @@ -161,13 +161,13 @@ def _get_policy_and_device( # callables should be supported as policies. if not _policy_is_tensordict_compatible(policy): # policy is a nn.Module that doesn't operate on tensordicts directly - # so we attempt to auto-wrap policy with SafeModule + # so we attempt to auto-wrap policy with TensorDictModule if observation_spec is None: raise ValueError( "Unable to read observation_spec from the environment. This is " "required to check compatibility of the environment and policy " "since the policy is a nn.Module that operates on tensors " - "rather than a SafeModule or a nn.Module that accepts a " + "rather than a TensorDictModule or a nn.Module that accepts a " "TensorDict as input and defines in_keys and out_keys." ) sig = inspect.signature(policy.forward) @@ -181,18 +181,18 @@ def _get_policy_and_device( if isinstance(output, tuple): out_keys.extend(f"output{i+1}" for i in range(len(output) - 1)) - policy = SafeModule( + policy = TensorDictModule( policy, in_keys=list(sig.parameters), out_keys=out_keys ) else: raise TypeError( "Arguments to policy.forward are incompatible with entries in " "env.observation_spec. If you want TorchRL to automatically " - "wrap your policy with a SafeModule then the arguments " + "wrap your policy with a TensorDictModule then the arguments " "to policy.forward must correspond one-to-one with entries in " "env.observation_spec that are prefixed with 'next_'. For more " "complex behaviour and more control you can consider writing " - "your own SafeModule." + "your own TensorDictModule." ) try: @@ -305,7 +305,7 @@ def __init__( ], # noqa: F821 policy: Optional[ Union[ - SafeProbabilisticModule, + TensorDictModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, @@ -517,7 +517,7 @@ def iterator(self) -> Iterator[TensorDictBase]: def _cast_to_policy(self, td: TensorDictBase) -> TensorDictBase: policy_device = self.device if hasattr(self.policy, "in_keys"): - # some keys may be absent -- SafeModule is resilient to missing keys + # some keys may be absent -- TensorDictModule is resilient to missing keys td = td.select(*self.policy.in_keys, strict=False) if self._td_policy is None: self._td_policy = td.to(policy_device) @@ -717,7 +717,7 @@ class _MultiDataCollector(_DataCollector): Args: create_env_fn (list of Callabled): list of Callables, each returning an instance of EnvBase - policy (Callable, optional): Instance of SafeProbabilisticModule class. 
+ policy (Callable, optional): Instance of TensorDictModule class. Must accept TensorDictBase object as input. total_frames (int): lower bound of the total number of frames returned by the collector. In parallel settings, the actual number of frames may well be greater than this as the closing signals are sent to the @@ -776,7 +776,7 @@ def __init__( create_env_fn: Sequence[Callable[[], EnvBase]], policy: Optional[ Union[ - SafeProbabilisticModule, + TensorDictModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, @@ -1303,7 +1303,7 @@ class aSyncDataCollector(MultiaSyncDataCollector): Args: create_env_fn (Callabled): Callable returning an instance of EnvBase - policy (Callable, optional): Instance of SafeProbabilisticModule class. + policy (Callable, optional): Instance of TensorDictModule class. Must accept TensorDictBase object as input. total_frames (int): lower bound of the total number of frames returned by the collector. In parallel settings, the actual number of @@ -1358,7 +1358,7 @@ def __init__( create_env_fn: Callable[[], EnvBase], policy: Optional[ Union[ - SafeProbabilisticModule, + TensorDictModule, Callable[[TensorDictBase], TensorDictBase], ] ] = None, diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index e12af352f43..67cdaab8e4b 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -32,6 +32,12 @@ # Make all the necessary imports for training +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +# sphinx_gallery_end_ignore + from copy import deepcopy from typing import Optional @@ -40,6 +46,7 @@ import torch.cuda import tqdm from matplotlib import pyplot as plt +from tensordict.nn import TensorDictModule from torch import nn, optim from torchrl.collectors import MultiaSyncDataCollector from torchrl.data import ( @@ -64,7 +71,6 @@ MLP, OrnsteinUhlenbeckProcessWrapper, ProbabilisticActor, - TensorDictModule, ValueOperator, ) from torchrl.modules.distributions.continuous import TanhDelta diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index ef21e18f4eb..78003cbd9b2 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -42,11 +42,19 @@ # to provide a high-level illustration of TorchRL features in the context # of this algorithm. +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +# sphinx_gallery_end_ignore + import torch import tqdm +from functorch import vmap from IPython import display from matplotlib import pyplot as plt from tensordict import TensorDict +from tensordict.nn import get_functional from torch import nn from torchrl.collectors import MultiaSyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer @@ -251,18 +259,16 @@ def make_model(): print("Q-value network results:", tensordict) # make functional - factor, (_, buffers) = actor.make_functional_with_buffers(clone=True, native=True) - # making functional creates a copy of the params, which we don't want (i.e. we want the parameters from `actor` to match those in the params object), - # hence we create the params object in a second step - params = TensorDict({k: v for k, v in net.named_parameters()}, []).unflatten_keys( - "." - ) + # here's an explicit way of creating the parameters and buffer tensordict. 
+ # Alternatively, we could have used `params = make_functional(actor)` from + # tensordict.nn + params = TensorDict({k: v for k, v in actor.named_parameters()}, []) + buffers = TensorDict({k: v for k, v in actor.named_buffers()}, []) + params = params.update(buffers).unflatten_keys(".") # creates a nested TensorDict + factor = get_functional(actor) # creating the target parameters is fairly easy with tensordict: - params_target, buffers_target = ( - params.to_tensordict().detach(), - buffers.to_tensordict().detach(), - ) + (params_target,) = (params.to_tensordict().detach(),) # we wrap our actor in an EGreedyWrapper for data collection actor_explore = EGreedyWrapper( @@ -272,7 +278,7 @@ def make_model(): eps_end=eps_greedy_val_env, ) - return factor, actor, actor_explore, params, buffers, params_target, buffers_target + return factor, actor, actor_explore, params, params_target ############################################################################### @@ -286,14 +292,10 @@ def make_model(): actor, actor_explore, params, - buffers, params_target, - buffers_target, ) = make_model() params_flat = params.flatten_keys(".") -buffers_flat = buffers.flatten_keys(".") params_target_flat = params_target.flatten_keys(".") -buffers_target_flat = buffers_target.flatten_keys(".") ############################################################################### # Regular DQN @@ -393,7 +395,7 @@ def make_model(): # Compute action value (of the action actually taken) at time t sampled_data_out = sampled_data.select(*actor.in_keys) - sampled_data_out = factor(sampled_data_out, params=params, buffers=buffers) + sampled_data_out = factor(sampled_data_out, params=params) action_value = sampled_data_out["action_value"] action_value = (action_value * action.to(action_value.dtype)).sum(-1) with torch.no_grad(): @@ -402,7 +404,6 @@ def make_model(): next_value = factor( tdstep.select(*actor.in_keys), params=params_target, - buffers=buffers_target, )["chosen_action_value"].squeeze(-1) exp_value = reward + gamma * next_value * (1 - done) assert exp_value.shape == action_value.shape @@ -420,9 +421,6 @@ def make_model(): for (key, p1) in params_flat.items(): p2 = params_target_flat[key] params_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data) - for (key, p1) in buffers_flat.items(): - p2 = buffers_target_flat[key] - buffers_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data) pbar.set_description( f"error: {error: 4.4f}, value: {action_value.mean(): 4.4f}" @@ -513,7 +511,7 @@ def make_model(): "grad_vals": grad_vals, "traj_lengths_training": traj_lengths, "traj_count": traj_count, - "weights": (params, buffers), + "weights": (params,), }, "saved_results_td0.pt", ) @@ -548,14 +546,10 @@ def make_model(): actor, actor_explore, params, - buffers, params_target, - buffers_target, ) = make_model() params_flat = params.flatten_keys(".") -buffers_flat = buffers.flatten_keys(".") params_target_flat = params_target.flatten_keys(".") -buffers_target_flat = buffers_target.flatten_keys(".") ############################################################################### @@ -632,19 +626,15 @@ def make_model(): action = sampled_data["action"].clone() sampled_data_out = sampled_data.select(*actor.in_keys) - sampled_data_out = factor( - sampled_data_out, params=params, buffers=buffers, vmap=(None, None, 0) - ) + sampled_data_out = vmap(factor, (0, None))(sampled_data_out, params) action_value = sampled_data_out["action_value"] action_value = (action_value * action.to(action_value.dtype)).sum(-1, True) with 
torch.no_grad(): tdstep = step_mdp(sampled_data) - next_value = factor( - tdstep.select(*actor.in_keys), - params=params_target, - buffers=buffers_target, - vmap=(None, None, 0), - )["chosen_action_value"] + next_value = vmap(factor, (0, None))( + tdstep.select(*actor.in_keys), params + ) + next_value = next_value["chosen_action_value"] error = vec_td_lambda_advantage_estimate( gamma, lmbda, @@ -671,9 +661,6 @@ def make_model(): for (key, p1) in params_flat.items(): p2 = params_target_flat[key] params_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data) - for (key, p1) in buffers_flat.items(): - p2 = buffers_target_flat[key] - buffers_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data) pbar.set_description( f"error: {error: 4.4f}, value: {action_value.mean(): 4.4f}" @@ -765,7 +752,7 @@ def make_model(): "grad_vals": grad_vals, "traj_lengths_training": traj_lengths, "traj_count": traj_count, - "weights": (params, buffers), + "weights": (params,), }, "saved_results_tdlambda.pt", ) diff --git a/tutorials/sphinx-tutorials/multi_task.py b/tutorials/sphinx-tutorials/multi_task.py index 59aa6c76d68..9896d616fbc 100644 --- a/tutorials/sphinx-tutorials/multi_task.py +++ b/tutorials/sphinx-tutorials/multi_task.py @@ -9,14 +9,21 @@ # can compute actions in diverse settings using a distinct set of weights. # You will also be able to execute diverse environments in parallel. +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +# sphinx_gallery_end_ignore + import torch +from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn ############################################################################## from torchrl.envs import CatTensors, Compose, DoubleToFloat, ParallelEnv, TransformedEnv from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.modules import MLP, TensorDictModule, TensorDictSequential +from torchrl.modules import MLP ############################################################################### # We design two environments, one humanoid that must complete the stand task diff --git a/tutorials/sphinx-tutorials/tensordict_module.py b/tutorials/sphinx-tutorials/tensordict_module.py index 648b92909ac..c4102b35208 100644 --- a/tutorials/sphinx-tutorials/tensordict_module.py +++ b/tutorials/sphinx-tutorials/tensordict_module.py @@ -6,7 +6,7 @@ """ ############################################################################## # For a convenient usage of the ``TensorDict`` class with ``nn.Module``, -# TorchRL provides an interface between the two named ``TensorDictModule``. +# :obj:`tensordict` provides an interface between the two named ``TensorDictModule``. # The ``TensorDictModule`` class is an ``nn.Module`` that takes a # ``TensorDict`` as input when called. # It is up to the user to define the keys to be read as input and output. 
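Note (illustrative sketch, not part of the patch): the tutorial hunk above describes how ``TensorDictModule`` reads and writes user-chosen keys, and the collector changes earlier in this patch rely on the same mechanism when they auto-wrap a plain ``nn.Module`` policy. Below is a minimal, self-contained example of that wrapping; the key names ``observation``/``action`` and the layer sizes are placeholders, not values taken from the patch.

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn

# A plain nn.Module policy mapping an observation tensor to an action tensor.
policy_net = nn.Linear(4, 2)

# Manual wrapping: in_keys are read from the input TensorDict and the module
# outputs are written back under out_keys -- the same thing the collector's
# auto-wrapping does when the policy is not TensorDict-compatible.
policy = TensorDictModule(policy_net, in_keys=["observation"], out_keys=["action"])

data = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
policy(data)                 # writes data["action"] in place
print(data["action"].shape)  # torch.Size([8, 2])
```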
@@ -14,10 +14,16 @@
 # TensorDictModule by examples
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
+# sphinx_gallery_start_ignore
+import warnings
+
+warnings.filterwarnings("ignore")
+# sphinx_gallery_end_ignore
+
 import torch
 import torch.nn as nn
 from tensordict import TensorDict
-from torchrl.modules import TensorDictModule, TensorDictSequential
+from tensordict.nn import TensorDictModule, TensorDictSequential
 
 ###############################################################################
 # Example 1: Simple usage
@@ -143,10 +149,10 @@ def forward(self, x):
 ###############################################################################
 # Example 5: Compatibility with functorch
 # -----------------------------------------
-# ``TensorDictModule`` comes with its own ``make_functional_with_buffers``
-# method to make it functional (you should not be using
-# ``functorch.make_functional_with_buffers(tensordictmodule)``, that will
-# not work in general).
+# tensordict.nn is compatible with functorch. It also comes with its own functional
+# utilities. Let us have a look:
+
+import functorch
 
 tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
 
@@ -155,29 +161,39 @@ def forward(self, x):
     in_keys=["a"],
     out_keys=["output_1", "output_2"],
 )
-func, (params, buffers) = splitlinear.make_functional_with_buffers()
-func(tensordict, params=params, buffers=buffers)
+func, params, buffers = functorch.make_functional_with_buffers(splitlinear)
+print(func(params, buffers, tensordict))
+
+###############################################################################
+# This can be used with the vmap operator. For example, we use 3 replicas of the
+# params and buffers and execute a vectorized map over these for a single batch
+# of data:
+
+params_expand = [p.expand(3, *p.shape) for p in params]
+buffers_expand = [p.expand(3, *p.shape) for p in buffers]
+print(functorch.vmap(func, (0, 0, None))(params_expand, buffers_expand, tensordict))
 
 ###############################################################################
-# We can also use the ``vmap`` operator, here's an example of
-# model ensembling with it:
+# We can also use the native :obj:`make_functional()` function from tensordict.nn,
+# which modifies the module to make it accept the parameters as regular inputs:
+
+from tensordict.nn import make_functional
 
 tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
 num_models = 10
 model = TensorDictModule(nn.Linear(3, 4), in_keys=["a"], out_keys=["output"])
-fmodel, (params, buffers) = model.make_functional_with_buffers()
-params = [torch.randn(num_models, *p.shape, device=p.device) for p in params]
-buffers = [torch.randn(num_models, *b.shape, device=b.device) for b in buffers]
-result_td = fmodel(tensordict, params=params, buffers=buffers, vmap=True)
+params = make_functional(model)
+# we stack two groups of parameters to show the vmap usage:
+params = torch.stack([params, params.apply(lambda x: torch.zeros_like(x))], 0)
+result_td = functorch.vmap(model, (None, 0))(tensordict, params)
 print("the output tensordict shape is: ", result_td.shape)
 
+from tensordict.nn import ProbabilisticTensorDictModule
+
 ###############################################################################
 # Do's and don't with TensorDictModule
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-# Don't use ``nn.Module`` wrappers with ``TensorDictModule`` componants.
-# This would break some of ``TensorDictModule`` features such as ``functorch``
-# compatibility.
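Note (illustrative sketch, not part of the patch): the ``coding_dqn.py`` changes earlier in this patch build the parameter TensorDict explicitly, keep a detached copy as target-network parameters, and use ``tensordict.nn.get_functional`` so the module can be called with an explicit ``params`` argument. A compressed version of that pattern, with placeholder layer sizes and key names:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, get_functional
from torch import nn

qnet = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["action_value"])

# Explicit parameter TensorDict, nested to mirror the module hierarchy.
params = TensorDict({k: v for k, v in qnet.named_parameters()}, []).unflatten_keys(".")
# Detached copy used as the (frozen) target-network parameters.
params_target = params.to_tensordict().detach()

fqnet = get_functional(qnet)  # qnet now also accepts a ``params`` TensorDict

data = TensorDict({"obs": torch.randn(5, 3)}, batch_size=[5])
online_value = fqnet(data.clone(), params=params)["action_value"]
target_value = fqnet(data.clone(), params=params_target)["action_value"]

# Soft update of the target parameters, as in the tutorial's training loop.
tau = 0.02
params_flat = params.flatten_keys(".")
params_target_flat = params_target.flatten_keys(".")
for key, p1 in params_flat.items():
    p2 = params_target_flat[key]
    params_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data)
```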
# # Don't use ``nn.Sequence``, similar to ``nn.Module``, it would break features # such as ``functorch`` compatibility. Do use ``TensorDictSequential`` instead. @@ -189,14 +205,6 @@ def forward(self, x): # # tensordict_out = module(tensordict) # don't! # -# Don't use ``make_functional_with_buffers`` from ``functorch`` directly but -# use ``TensorDictModule.make_functional_with_buffers`` instead. -# -# TensorDictModule for RL -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# TorchRL provides a few RL-specific ``TensorDictModule`` instances that -# serves domain-specific needs. -# # ``ProbabilisticTensorDictModule`` # ---------------------------------- # ``ProbabilisticTensorDictModule`` is a special case of a ``TensorDictModule`` @@ -214,11 +222,7 @@ def forward(self, x): # One can find the parameters in the output tensordict as well as the log # probability if needed. -from torchrl.modules import ( - NormalParamWrapper, - ProbabilisticTensorDictModule, - TanhNormal, -) +from torchrl.modules import NormalParamWrapper, TanhNormal td = TensorDict( {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, diff --git a/tutorials/sphinx-tutorials/tensordict_tutorial.py b/tutorials/sphinx-tutorials/tensordict_tutorial.py index e856f754f9f..ad50c6a3d1f 100644 --- a/tutorials/sphinx-tutorials/tensordict_tutorial.py +++ b/tutorials/sphinx-tutorials/tensordict_tutorial.py @@ -84,6 +84,12 @@ # However to achieve this you would need to write a complicated collate # function that make sure that every modality is aggregated properly. +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +# sphinx_gallery_end_ignore + def collate_dict_fn(dict_list): final_dict = {} @@ -123,6 +129,7 @@ def collate_dict_fn(dict_list): # TensorDict structure # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + import torch ############################################################################### diff --git a/tutorials/sphinx-tutorials/torch_envs.py b/tutorials/sphinx-tutorials/torch_envs.py index 27512b85e20..5f4f70172b2 100644 --- a/tutorials/sphinx-tutorials/torch_envs.py +++ b/tutorials/sphinx-tutorials/torch_envs.py @@ -25,6 +25,11 @@ # will pass the arguments and keyword arguments to the root library builder. # # With gym, it means that building an environment is as easy as: +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +# sphinx_gallery_end_ignore import torch from matplotlib import pyplot as plt diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 0dd2fbb3236..f7ac96f35e1 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -124,6 +124,12 @@ # TensorDict # ------------------------------ +# sphinx_gallery_start_ignore +import warnings + +warnings.filterwarnings("ignore") +# sphinx_gallery_end_ignore + import torch from tensordict import TensorDict @@ -172,7 +178,9 @@ # Here are some other functionalities of TensorDict. 
print( - "view(-1): ", tensordict.view(-1).batch_size, tensordict.view(-1).get("key 1").shape + "view(-1): ", + tensordict.view(-1).batch_size, + tensordict.view(-1).get("key 1").shape, ) print("to device: ", tensordict.to("cpu")) @@ -348,7 +356,8 @@ from torchrl.envs import ParallelEnv base_env = ParallelEnv( - 4, lambda: GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False) + 4, + lambda: GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False), ) env = TransformedEnv( base_env, Compose(NoopResetEnv(3), ToTensorImage()) @@ -384,7 +393,10 @@ # Example of a CNN model: cnn = ConvNet( - num_cells=[32, 64], kernel_sizes=[8, 4], strides=[2, 1], aggregator_class=SquashDims + num_cells=[32, 64], + kernel_sizes=[8, 4], + strides=[2, 1], + aggregator_class=SquashDims, ) print(cnn) print(cnn(torch.randn(10, 3, 32, 32)).shape) # last tensor is squashed @@ -393,7 +405,7 @@ # TensorDictModules # ------------------------------ -from torchrl.modules import TensorDictModule +from tensordict.nn import TensorDictModule tensordict = TensorDict({"key 1": torch.randn(10, 3)}, batch_size=[10]) module = nn.Linear(3, 4) @@ -405,7 +417,7 @@ # Sequences of Modules # ------------------------------ -from torchrl.modules import TensorDictSequential +from tensordict.nn import TensorDictSequential backbone_module = nn.Linear(5, 3) backbone = TensorDictModule( @@ -446,20 +458,21 @@ # Functional Programming (Ensembling / Meta-RL) # ---------------------------------------------- -fsequence, (params, buffers) = sequence.make_functional_with_buffers() -len(list(fsequence.parameters())) # functional modules have no parameters +from tensordict.nn import make_functional + +params = make_functional(sequence) +len(list(sequence.parameters())) # functional modules have no parameters ############################################################################### -fsequence(tensordict, params=params, buffers=buffers) +sequence(tensordict, params) ############################################################################### -params_expand = [p.expand(4, *p.shape) for p in params] -buffers_expand = [b.expand(4, *b.shape) for b in buffers] -tensordict_exp = fsequence( - tensordict, params=params_expand, buffers=buffers, vmap=(0, 0, None) -) +import functorch + +params_expand = params.expand(4) +tensordict_exp = functorch.vmap(sequence, (None, 0))(tensordict, params_expand) print(tensordict_exp) ############################################################################### @@ -468,10 +481,11 @@ torch.manual_seed(0) from torchrl.data import NdBoundedTensorSpec +from torchrl.modules import SafeModule spec = NdBoundedTensorSpec(-torch.ones(3), torch.ones(3)) base_module = nn.Linear(5, 3) -module = TensorDictModule( +module = SafeModule( module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True ) tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[]) @@ -491,14 +505,12 @@ tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[]) actor(tensordict) # action is the default value +from tensordict.nn import ProbabilisticTensorDictModule + ############################################################################### # Probabilistic modules -from torchrl.modules import ( - NormalParamWrapper, - ProbabilisticTensorDictModule, - TanhNormal, -) +from torchrl.modules import NormalParamWrapper, TanhNormal td = TensorDict( {"input": torch.randn(3, 5)}, @@ -572,7 +584,7 @@ action_spec = env.action_spec actor_module = nn.Linear(3, 1) -actor = TensorDictModule( +actor = 
SafeModule( actor_module, spec=action_spec, in_keys=["observation"], out_keys=["action"] ) @@ -628,6 +640,8 @@ (tensordict_rollout == tensordicts_prealloc).all() +from tensordict.nn import TensorDictModule + # Collectors # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -637,7 +651,6 @@ from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv -from torchrl.modules import TensorDictModule # EnvCreator makes sure that we can send a lambda function from process to process parallel_env = ParallelEnv(3, EnvCreator(lambda: GymEnv("Pendulum-v1"))) @@ -666,6 +679,8 @@ print(d) # trajectories are split automatically in [6 workers x 10 steps] collector.update_policy_weights_() # make sure that our policies have the latest weights if working on multiple devices print(i) +collector.shutdown() +del collector ############################################################################### @@ -685,6 +700,7 @@ print(d) # trajectories are split automatically in [6 workers x 10 steps] collector.update_policy_weights_() # make sure that our policies have the latest weights if working on multiple devices print(i) +collector.shutdown() del collector ############################################################################### @@ -728,11 +744,11 @@ def forward(self, obs, action): ############################################################################### -loss_td +print(loss_td) ############################################################################### -tensordict +print(tensordict) ############################################################################### # State of the Library
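Note (illustrative sketch, not part of the patch): with these changes a policy defined as a plain ``tensordict.nn.TensorDictModule`` can be passed straight to a collector, without going through torchrl's ``SafeModule``. The frame counts below are arbitrary, and ``Pendulum-v1`` is only reused because it already appears in the demo above.

```python
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv

# Pendulum-v1 has a 3-dimensional observation and a 1-dimensional action.
policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)

collector = SyncDataCollector(
    lambda: GymEnv("Pendulum-v1"),  # create_env_fn
    policy,
    frames_per_batch=50,
    total_frames=200,
)
for data in collector:
    # each element is a TensorDict of collected transitions
    print(data)
collector.shutdown()
del collector
```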