From 3cc640e398e022c17e167d1a8ee219b7f5f84a64 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 4 Mar 2024 11:18:18 +0000 Subject: [PATCH 1/7] init --- torchrl/modules/models/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index c610bb61350..1c5d1d7059e 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -423,8 +423,8 @@ def __init__( def _make_net(self, device: Optional[DEVICE_TYPING]) -> nn.Module: layers = [] - in_features = [self.in_features] + self.num_cells[: self.depth] - out_features = self.num_cells + [self.out_features] + in_features = [self.in_features] + list(self.num_cells[: self.depth]) + out_features = list(self.num_cells) + [self.out_features] kernel_sizes = self.kernel_sizes strides = self.strides paddings = self.paddings From ba4073ce0bd7a4b8172bea8f4986e0548f22ed04 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 4 Mar 2024 16:17:36 +0000 Subject: [PATCH 2/7] amend --- torchrl/modules/models/__init__.py | 1 + torchrl/modules/models/models.py | 863 ++++++++++++++++++++--------- torchrl/modules/models/utils.py | 10 +- 3 files changed, 604 insertions(+), 270 deletions(-) diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 01aa429a412..518abca1f65 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -18,6 +18,7 @@ DistributionalDQNnet, DTActor, DuelingCnnDQNet, + DuelingMlpDQNet, LSTMNet, MLP, OnlineDTActor, diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 1c5d1d7059e..5e0bfdb6af4 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -8,9 +8,10 @@ import warnings from numbers import Number -from typing import Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import torch +import torchrl.modules from tensordict.nn import dispatch, TensorDictModuleBase from torch import nn from torch.nn import functional as F @@ -41,21 +42,32 @@ class MLP(nn.Sequential): Args: in_features (int, optional): number of input features; - out_features (int, list of int): number of output features. If iterable of integers, the output is reshaped to - the desired shape; - depth (int, optional): depth of the network. A depth of 0 will produce a single linear layer network with the - desired input and output size. A length of 1 will create 2 linear layers etc. If no depth is indicated, - the depth information should be contained in the num_cells argument (see below). If num_cells is an - iterable and depth is indicated, both should match: len(num_cells) must be equal to depth. - num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If - an integer is provided, every layer will have the same number of cells. If an iterable is provided, - the linear layers out_features will match the content of num_cells. - default: 32; - activation_class (Type[nn.Module]): activation class to be used. - default: nn.Tanh - activation_kwargs (dict, optional): kwargs to be used with the activation class; - norm_class (Type, optional): normalization class, if any. - norm_kwargs (dict, optional): kwargs to be used with the normalization layers; + out_features (int, torch.Size or equivalent): number of output + features. If iterable of integers, the output is reshaped to the + desired shape. 
+        depth (int, optional): depth of the network. A depth of 0 will produce
+            a single linear layer network with the desired input and output size.
+            A length of 1 will create 2 linear layers etc. If no depth is indicated,
+            the depth information should be contained in the ``num_cells``
+            argument (see below). If ``num_cells`` is an iterable and depth is
+            indicated, both should match: ``len(num_cells)`` must be equal to
+            ``depth``.
+        num_cells (int or sequence of int, optional): number of cells of every
+            layer in between the input and output. If an integer is provided,
+            every layer will have the same number of cells. If an iterable is provided,
+            the linear layers ``out_features`` will match the content of
+            ``num_cells``. Defaults to ``32``;
+        activation_class (Type[nn.Module] or callable, optional): activation
+            class or constructor to be used.
+            Defaults to :class:`~torch.nn.Tanh`.
+        activation_kwargs (dict or list of dicts, optional): kwargs to be used
+            with the activation class. Also accepts a list of kwargs, one for
+            each layer.
+        norm_class (Type or callable, optional): normalization class or
+            constructor, if any.
+        norm_kwargs (dict or list of dicts, optional): kwargs to be used with
+            the normalization layers. Also accepts a list of kwargs, one for
+            each layer.
         dropout (float, optional): dropout probability. Defaults to ``None`` (no dropout);
         bias_last_layer (bool): if ``True``, the last Linear layer will have a bias parameter.
@@ -63,12 +75,14 @@ class MLP(nn.Sequential):
         single_bias_last_layer (bool): if ``True``, the last dimension of the bias of the last layer will be a singleton dimension.
             default: True;
-        layer_class (Type[nn.Module]): class to be used for the linear layers;
-        layer_kwargs (dict, optional): kwargs for the linear layers;
+        layer_class (Type[nn.Module] or callable, optional): class to be used
+            for the linear layers;
+        layer_kwargs (dict or list of dicts, optional): kwargs for the linear
+            layers. Also accepts a list of kwargs, one for each layer.
         activate_last_layer (bool): whether the MLP output should be activated. This is useful when the MLP output is used as the input for another module.
             default: False.
-        device (Optional[DEVICE_TYPING]): device to create the module on.
+        device (torch.device, optional): device to create the module on.
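+
+    When ``activation_kwargs``, ``norm_kwargs`` or ``layer_kwargs`` is a list
+    of dicts, one entry per layer is expected. An illustrative sketch (the
+    values below are arbitrary):
+
+        >>> from torch import nn
+        >>> mlp = MLP(
+        ...     in_features=4,
+        ...     out_features=2,
+        ...     num_cells=[8, 8],
+        ...     activation_class=nn.LeakyReLU,
+        ...     activation_kwargs=[{"negative_slope": 0.1}, {"negative_slope": 0.2}],
+        ... )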
Examples: >>> # All of the following examples provide valid, working MLPs @@ -148,21 +162,21 @@ class MLP(nn.Sequential): def __init__( self, - in_features: Optional[int] = None, - out_features: Union[int, Sequence[int]] = None, - depth: Optional[int] = None, - num_cells: Optional[Union[Sequence, int]] = None, - activation_class: Type[nn.Module] = nn.Tanh, - activation_kwargs: Optional[dict] = None, - norm_class: Optional[Type[nn.Module]] = None, - norm_kwargs: Optional[dict] = None, - dropout: Optional[float] = None, + in_features: int | None = None, + out_features: int | torch.Size = None, + depth: int | None = None, + num_cells: Sequence[int] | int | None = None, + activation_class: Type[nn.Module] | Callable = nn.Tanh, + activation_kwargs: dict | List[dict] | None = None, + norm_class: Optional[Type[nn.Module] | Callable] = None, + norm_kwargs: dict | List[dict] | None = None, + dropout: float | None = None, bias_last_layer: bool = True, single_bias_last_layer: bool = False, - layer_class: Type[nn.Module] = nn.Linear, - layer_kwargs: Optional[dict] = None, + layer_class: Type[nn.Module] | Callable = nn.Linear, + layer_kwargs: dict | None = None, activate_last_layer: bool = False, - device: Optional[DEVICE_TYPING] = None, + device: DEVICE_TYPING | None = None, ): if out_features is None: raise ValueError("out_features must be specified for MLP.") @@ -183,16 +197,19 @@ def __init__( self.out_features = out_features self._out_features_num = _out_features_num self.activation_class = activation_class - self.activation_kwargs = ( - activation_kwargs if activation_kwargs is not None else {} - ) self.norm_class = norm_class - self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {} self.dropout = dropout self.bias_last_layer = bias_last_layer self.single_bias_last_layer = single_bias_last_layer self.layer_class = layer_class - self.layer_kwargs = layer_kwargs if layer_kwargs is not None else {} + + self.activation_kwargs = activation_kwargs + self.norm_kwargs = norm_kwargs + self.layer_kwargs = layer_kwargs + self._activation_kwargs_iter = _iter_maybe_over_single(activation_kwargs) + self._norm_kwargs_iter = _iter_maybe_over_single(norm_kwargs) + self._layer_kwargs_iter = _iter_maybe_over_single(layer_kwargs) + self.activate_last_layer = activate_last_layer if single_bias_last_layer: raise NotImplementedError @@ -214,12 +231,13 @@ def __init__( layers = self._make_net(device) super().__init__(*layers) - def _make_net(self, device: Optional[DEVICE_TYPING]) -> List[nn.Module]: + def _make_net(self, device: DEVICE_TYPING | None) -> List[nn.Module]: layers = [] in_features = [self.in_features] + self.num_cells out_features = self.num_cells + [self._out_features_num] for i, (_in, _out) in enumerate(zip(in_features, out_features)): _bias = self.bias_last_layer if i == self.depth else True + layer_kwargs = next(self._layer_kwargs_iter) if _in is not None: layers.append( create_on_device( @@ -228,7 +246,7 @@ def _make_net(self, device: Optional[DEVICE_TYPING]) -> List[nn.Module]: _in, _out, bias=_bias, - **self.layer_kwargs, + **layer_kwargs, ) ) else: @@ -241,21 +259,21 @@ def _make_net(self, device: Optional[DEVICE_TYPING]) -> List[nn.Module]: ) layers.append( create_on_device( - lazy_version, device, _out, bias=_bias, **self.layer_kwargs + lazy_version, device, _out, bias=_bias, **layer_kwargs ) ) if i < self.depth or self.activate_last_layer: + norm_kwargs = next(self._norm_kwargs_iter) + activation_kwargs = next(self._activation_kwargs_iter) if self.dropout is not None: 
layers.append(create_on_device(nn.Dropout, device, p=self.dropout)) if self.norm_class is not None: layers.append( - create_on_device(self.norm_class, device, **self.norm_kwargs) + create_on_device(self.norm_class, device, **norm_kwargs) ) layers.append( - create_on_device( - self.activation_class, device, **self.activation_kwargs - ) + create_on_device(self.activation_class, device, **activation_kwargs) ) return layers @@ -274,33 +292,50 @@ class ConvNet(nn.Sequential): """A convolutional neural network. Args: - in_features (int, optional): number of input features; - depth (int, optional): depth of the network. A depth of 1 will produce a single linear layer network with the - desired input size, and with an output size equal to the last element of the num_cells argument. - If no depth is indicated, the depth information should be contained in the num_cells argument (see below). - If num_cells is an iterable and depth is indicated, both should match: len(num_cells) must be equal to - the depth. - num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If - an integer is provided, every layer will have the same number of cells. If an iterable is provided, - the linear layers out_features will match the content of num_cells. - default: [32, 32, 32]; - kernel_sizes (int, Sequence[Union[int, Sequence[int]]]): Kernel size(s) of the conv network. If iterable, the length must match the - depth, defined by the num_cells or depth arguments. - strides (int or Sequence[int]): Stride(s) of the conv network. If iterable, the length must match the - depth, defined by the num_cells or depth arguments. - activation_class (Type[nn.Module]): activation class to be used. - default: nn.Tanh - activation_kwargs (dict, optional): kwargs to be used with the activation class; - norm_class (Type, optional): normalization class, if any; - norm_kwargs (dict, optional): kwargs to be used with the normalization layers; - bias_last_layer (bool): if ``True``, the last Linear layer will have a bias parameter. - default: True; - aggregator_class (Type[nn.Module]): aggregator to use at the end of the chain. - default: SquashDims; - aggregator_kwargs (dict, optional): kwargs for the aggregator_class; - squeeze_output (bool): whether the output should be squeezed of its singleton dimensions. - default: False. - device (Optional[DEVICE_TYPING]): device to create the module on. + in_features (int, optional): number of input features. If ``None``, a + :class:`~torch.nn.LazyConv2d` module is used for the first layer.; + depth (int, optional): depth of the network. A depth of 1 will produce + a single linear layer network with the desired input size, and + with an output size equal to the last element of the num_cells + argument. + If no depth is indicated, the depth information should be contained + in the ``num_cells`` argument (see below). + If ``num_cells`` is an iterable and ``depth`` is indicated, both + should match: ``len(num_cells)`` must be equal to the ``depth``. + num_cells (int or Sequence of int, optional): number of cells of + every layer in between the input and output. If an integer is + provided, every layer will have the same number of cells. If an + iterable is provided, the linear layers ``out_features`` will match + the content of num_cells. Defaults to ``[32, 32, 32]``. + kernel_sizes (int, sequence of int, optional): Kernel size(s) of the + conv network. If iterable, the length must match the depth, + defined by the ``num_cells`` or depth arguments. 
+ Defaults to ``3``. + strides (int or sequence of int, optional): Stride(s) of the conv network. If + iterable, the length must match the depth, defined by the + ``num_cells`` or depth arguments. Defaults to ``1``. + activation_class (Type[nn.Module] or callable, optional): activation + class or constructor to be used. + Defaults to :class:`~torch.nn.Tanh`. + activation_kwargs (dict or list of dicts, optional): kwargs to be used + with the activation class. A list of kwargs can also be passed, + with one element per layer. + norm_class (Type or callable, optional): normalization class or + constructor, if any. + norm_kwargs (dict or list of dicts, optional): kwargs to be used with + the normalization layers. A list of kwargs can also be passed, + with one element per layer. + bias_last_layer (bool): if ``True``, the last Linear layer will have a + bias parameter. Defaults to ``True``. + aggregator_class (Type[nn.Module] or callable): aggregator class or + constructor to use at the end of the chain. + Defaults to :class:`torchrl.modules.utils.models.SquashDims`; + aggregator_kwargs (dict, optional): kwargs for the + ``aggregator_class``. + squeeze_output (bool): whether the output should be squeezed of its + singleton dimensions. + Defaults to ``False``. + device (torch.device, optional): device to create the module on. Examples: >>> # All of the following examples provide valid, working MLPs @@ -355,32 +390,28 @@ class ConvNet(nn.Sequential): def __init__( self, - in_features: Optional[int] = None, - depth: Optional[int] = None, - num_cells: Union[Sequence, int] = None, - kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 3, - strides: Union[Sequence, int] = 1, - paddings: Union[Sequence, int] = 0, - activation_class: Type[nn.Module] = nn.ELU, - activation_kwargs: Optional[dict] = None, - norm_class: Optional[Type[nn.Module]] = None, - norm_kwargs: Optional[dict] = None, + in_features: int | None = None, + depth: int | None = None, + num_cells: Sequence[int] | int = None, + kernel_sizes: Union[Sequence[int], int] = 3, + strides: Sequence[int] | int = 1, + paddings: Sequence[int] | int = 0, + activation_class: Type[nn.Module] | Callable = nn.ELU, + activation_kwargs: dict | List[dict] | None = None, + norm_class: Type[nn.Module] | Callable | None = None, + norm_kwargs: dict | List[dict] | None = None, bias_last_layer: bool = True, - aggregator_class: Optional[Type[nn.Module]] = SquashDims, - aggregator_kwargs: Optional[dict] = None, + aggregator_class: Type[nn.Module] | Callable | None = SquashDims, + aggregator_kwargs: dict | None = None, squeeze_output: bool = False, - device: Optional[DEVICE_TYPING] = None, + device: DEVICE_TYPING | None = None, ): if num_cells is None: num_cells = [32, 32, 32] self.in_features = in_features self.activation_class = activation_class - self.activation_kwargs = ( - activation_kwargs if activation_kwargs is not None else {} - ) self.norm_class = norm_class - self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {} self.bias_last_layer = bias_last_layer self.aggregator_class = aggregator_class self.aggregator_kwargs = ( @@ -389,6 +420,13 @@ def __init__( self.squeeze_output = squeeze_output # self.single_bias_last_layer = single_bias_last_layer + self.activation_kwargs = ( + activation_kwargs if activation_kwargs is not None else {} + ) + self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {} + self._activation_kwargs_iter = _iter_maybe_over_single(activation_kwargs) + self._norm_kwargs_iter = 
_iter_maybe_over_single(norm_kwargs) + depth = _find_depth(depth, num_cells, kernel_sizes, strides, paddings) self.depth = depth if depth == 0: @@ -421,7 +459,7 @@ def __init__( layers = self._make_net(device) super().__init__(*layers) - def _make_net(self, device: Optional[DEVICE_TYPING]) -> nn.Module: + def _make_net(self, device: DEVICE_TYPING | None) -> nn.Module: layers = [] in_features = [self.in_features] + list(self.num_cells[: self.depth]) out_features = list(self.num_cells) + [self.out_features] @@ -456,15 +494,13 @@ def _make_net(self, device: Optional[DEVICE_TYPING]) -> nn.Module: ) ) + activation_kwargs = next(self._activation_kwargs_iter) layers.append( - create_on_device( - self.activation_class, device, **self.activation_kwargs - ) + create_on_device(self.activation_class, device, **activation_kwargs) ) if self.norm_class is not None: - layers.append( - create_on_device(self.norm_class, device, **self.norm_kwargs) - ) + norm_kwargs = next(self._norm_kwargs_iter) + layers.append(create_on_device(self.norm_class, device, **norm_kwargs)) if self.aggregator_class is not None: layers.append( @@ -494,34 +530,50 @@ class Conv3dNet(nn.Sequential): """A 3D-convolutional neural network. Args: - in_features (int, optional): number of input features. A lazy implementation that automatically retrieves - the input size will be used if none is provided. - depth (int, optional): depth of the network. A depth of 1 will produce a single linear layer network with the - desired input size, and with an output size equal to the last element of the num_cells argument. - If no depth is indicated, the depth information should be contained in the num_cells argument (see below). - If num_cells is an iterable and depth is indicated, both should match: len(num_cells) must be equal to - the depth. - num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If - an integer is provided, every layer will have the same number of cells. If an iterable is provided, - the linear layers out_features will match the content of num_cells. - default: ``[32, 32, 32]`` or ``[32] * depth` is depth is not ``None``. - kernel_sizes (int, Sequence[Union[int, Sequence[int]]]): Kernel size(s) of the conv network. If iterable, the length must match the - depth, defined by the num_cells or depth arguments. - strides (int or Sequence[int]): Stride(s) of the conv network. If iterable, the length must match the - depth, defined by the num_cells or depth arguments. - activation_class (Type[nn.Module]): activation class to be used. - default: nn.Tanh - activation_kwargs (dict, optional): kwargs to be used with the activation class; - norm_class (Type, optional): normalization class, if any; - norm_kwargs (dict, optional): kwargs to be used with the normalization layers; - bias_last_layer (bool): if ``True``, the last Linear layer will have a bias parameter. - default: True; - aggregator_class (Type[nn.Module]): aggregator to use at the end of the chain. - default: SquashDims; - aggregator_kwargs (dict, optional): kwargs for the aggregator_class; - squeeze_output (bool): whether the output should be squeezed of its singleton dimensions. - default: False. - device (Optional[DEVICE_TYPING]): device to create the module on. + in_features (int, optional): number of input features. A lazy + implementation that automatically retrieves the input size will be + used if none is provided. + depth (int, optional): depth of the network. 
A depth of ``1`` will + produce a single linear layer network with the desired input size, + and with an output size equal to the last element of the + ``num_cells`` argument. If no ``depth`` is indicated, the ``depth`` + information should be contained in the ``num_cells`` argument + (see below). + If ``num_cells`` is an iterable and ``depth`` is indicated, + both should match: ``len(num_cells)`` must be equal to + the ``depth``. + num_cells (int or sequence of int, optional): number of cells of every + layer in between the input and output. If an integer is provided, + every layer will have the same number of cells and the depth will + be retrieved from ``depth``. If an iterable is + provided, the linear layers ``out_features`` will match the content + of num_cells. Defaults to ``[32, 32, 32]`` or ``[32] * depth` is + depth is not ``None``. + kernel_sizes (int, sequence of int, optional): Kernel size(s) of the + conv network. If iterable, the length must match the depth, + defined by the ``num_cells`` or depth arguments. Defaults to ``3``. + strides (int or sequence of int): Stride(s) of the conv network. + If iterable, the length must match the depth, defined by the + ``num_cells`` or depth arguments. Defaults to ``1``. + activation_class (Type[nn.Module] or callable): activation class or + constructor to be used. Defaults to :class:`~torch.nn.Tanh`. + activation_kwargs (dict or list of dicts, optional): kwargs to be used + with the activation class. A list of kwargs with one element per + layer can also be provided. + norm_class (Type or callable, optional): normalization class, if any. + norm_kwargs (dict or list of dicts, optional): kwargs to be used with + the normalization layers. A list of kwargs with one element per + layer can also be provided. + bias_last_layer (bool): if ``True``, the last Linear layer will have a + bias parameter. Defaults to ``True``. + aggregator_class (Type[nn.Module] or callable): aggregator class or + constructor to use at the end of the chain. Defaults to + :class:`~torchrl.modules.models.utils.SquashDims`. + aggregator_kwargs (dict, optional): kwargs for the ``aggregator_class`` + constructor. + squeeze_output (bool): whether the output should be squeezed of its + singleton dimensions. Defaults to ``False``. + device (torch.device, optional): device to create the module on. 
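+
+    When ``activation_kwargs`` or ``norm_kwargs`` is a list of dicts, one
+    entry per convolutional layer is expected. An illustrative sketch (the
+    values below are arbitrary):
+
+        >>> from torch import nn
+        >>> net = Conv3dNet(
+        ...     in_features=3,
+        ...     num_cells=[8, 8],
+        ...     norm_class=nn.BatchNorm3d,
+        ...     norm_kwargs=[{"num_features": 8}, {"num_features": 8}],
+        ... )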
Examples: >>> # All of the following examples provide valid, working MLPs @@ -576,21 +628,21 @@ class Conv3dNet(nn.Sequential): def __init__( self, - in_features: Optional[int] = None, - depth: Optional[int] = None, - num_cells: Union[Sequence, int] = None, - kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 3, - strides: Union[Sequence, int] = 1, - paddings: Union[Sequence, int] = 0, - activation_class: Type[nn.Module] = nn.ELU, - activation_kwargs: Optional[dict] = None, - norm_class: Optional[Type[nn.Module]] = None, - norm_kwargs: Optional[dict] = None, + in_features: int | None = None, + depth: int | None = None, + num_cells: Sequence[int] | int = None, + kernel_sizes: Sequence[int] | int = 3, + strides: Sequence[int] | int = 1, + paddings: Sequence[int] | int = 0, + activation_class: Type[nn.Module] | Callable = nn.ELU, + activation_kwargs: dict | List[dict] | None = None, + norm_class: Type[nn.Module] | Callable | None = None, + norm_kwargs: dict | List[dict] | None = None, bias_last_layer: bool = True, - aggregator_class: Optional[Type[nn.Module]] = SquashDims, - aggregator_kwargs: Optional[dict] = None, + aggregator_class: Type[nn.Module] | Callable | None = SquashDims, + aggregator_kwargs: dict | None = None, squeeze_output: bool = False, - device: Optional[DEVICE_TYPING] = None, + device: DEVICE_TYPING | None = None, ): if num_cells is None: if depth is None: @@ -600,11 +652,15 @@ def __init__( self.in_features = in_features self.activation_class = activation_class + self.norm_class = norm_class + self.activation_kwargs = ( activation_kwargs if activation_kwargs is not None else {} ) - self.norm_class = norm_class self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {} + self._activation_kwargs_iter = _iter_maybe_over_single(activation_kwargs) + self._norm_kwargs_iter = _iter_maybe_over_single(norm_kwargs) + self.bias_last_layer = bias_last_layer self.aggregator_class = aggregator_class self.aggregator_kwargs = ( @@ -640,7 +696,7 @@ def __init__( layers = self._make_net(device) super().__init__(*layers) - def _make_net(self, device: Optional[DEVICE_TYPING]) -> nn.Module: + def _make_net(self, device: DEVICE_TYPING | None) -> nn.Module: layers = [] in_features = [self.in_features] + self.num_cells[: self.depth] out_features = self.num_cells + [self.out_features] @@ -675,15 +731,13 @@ def _make_net(self, device: Optional[DEVICE_TYPING]) -> nn.Module: ) ) + activation_kwargs = next(self._activation_kwargs_iter) layers.append( - create_on_device( - self.activation_class, device, **self.activation_kwargs - ) + create_on_device(self.activation_class, device, **activation_kwargs) ) if self.norm_class is not None: - layers.append( - create_on_device(self.norm_class, device, **self.norm_kwargs) - ) + norm_kwargs = next(self._norm_kwargs_iter) + layers.append(create_on_device(self.norm_class, device, **norm_kwargs)) if self.aggregator_class is not None: layers.append( @@ -717,8 +771,9 @@ class DuelingMlpDQNet(nn.Module): Presented in https://arxiv.org/abs/1511.06581 Args: - out_features (int): number of features for the advantage network - out_features_value (int): number of features for the value network + out_features (int, torch.Size or equivalent): number of features for the advantage network + out_features_value (int): number of features for the value network. + Defaults to ``1``. mlp_kwargs_feature (dict, optional): kwargs for the feature network. Default is @@ -730,8 +785,7 @@ class DuelingMlpDQNet(nn.Module): ... 
} mlp_kwargs_output (dict, optional): kwargs for the advantage and - value networks. - Default is + value networks. Default is >>> mlp_kwargs_output = { ... "depth": 1, @@ -740,16 +794,50 @@ class DuelingMlpDQNet(nn.Module): ... "bias_last_layer": True, ... } - device (Optional[DEVICE_TYPING]): device to create the module on. + device (torch.device, optional): device to create the module on. + + Examples: + >>> import torch + >>> from torchrl.modules import DuelingMlpDQNet + >>> # we can ask for a specific output shape + >>> net = DuelingMlpDQNet(out_features=(3, 2)) + >>> print(net) + DuelingMlpDQNet( + (features): MLP( + (0): LazyLinear(in_features=0, out_features=256, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=256, out_features=256, bias=True) + (3): ELU(alpha=1.0) + (4): Linear(in_features=256, out_features=256, bias=True) + (5): ELU(alpha=1.0) + ) + (advantage): MLP( + (0): LazyLinear(in_features=0, out_features=512, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=512, out_features=6, bias=True) + ) + (value): MLP( + (0): LazyLinear(in_features=0, out_features=512, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=512, out_features=1, bias=True) + ) + ) + >>> x = torch.zeros(1, 5) + >>> y = net(x) + >>> print(y) + tensor([[[ 0.0232, -0.0477], + [-0.0226, -0.0019], + [-0.0314, 0.0069]]], grad_fn=) + """ def __init__( self, - out_features: int, + out_features: int | torch.Size, out_features_value: int = 1, - mlp_kwargs_feature: Optional[dict] = None, - mlp_kwargs_output: Optional[dict] = None, - device: Optional[DEVICE_TYPING] = None, + mlp_kwargs_feature: dict | None = None, + mlp_kwargs_output: dict | None = None, + device: DEVICE_TYPING | None = None, ): super().__init__() @@ -800,10 +888,10 @@ class DuelingCnnDQNet(nn.Module): Presented in https://arxiv.org/abs/1511.06581 Args: - out_features (int): number of features for the advantage network - out_features_value (int): number of features for the value network - cnn_kwargs (dict, optional): kwargs for the feature network. - Default is + out_features (int): number of features for the advantage network. + out_features_value (int): number of features for the value network. + cnn_kwargs (dict or list of dicts, optional): kwargs for the feature + network. Default is >>> cnn_kwargs = { ... 'num_cells': [32, 64, 64], @@ -811,8 +899,8 @@ class DuelingCnnDQNet(nn.Module): ... 'kernels': [8, 4, 3], ... } - mlp_kwargs (dict, optional): kwargs for the advantage and value network. - Default is + mlp_kwargs (dict or list of dicts, optional): kwargs for the advantage + and value network. Default is >>> mlp_kwargs = { ... "depth": 1, @@ -821,16 +909,48 @@ class DuelingCnnDQNet(nn.Module): ... "bias_last_layer": True, ... } - device (Optional[DEVICE_TYPING]): device to create the module on. + device (torch.device, optional): device to create the module on. 
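+
+    ``mlp_kwargs`` is used to build the advantage and value heads, so ``MLP``
+    keyword arguments such as ``depth`` or ``num_cells`` can be passed. A
+    minimal sketch with a smaller head (values are illustrative, ``cnn_kwargs``
+    left at its default):
+
+        >>> net = DuelingCnnDQNet(out_features=20, mlp_kwargs={"num_cells": 256, "depth": 1})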
+ + Examples: + >>> import torch + >>> from torchrl.modules import DuelingCnnDQNet + >>> net = DuelingCnnDQNet(out_features=20) + >>> print(net) + DuelingCnnDQNet( + (features): ConvNet( + (0): LazyConv2d(0, 32, kernel_size=(8, 8), stride=(4, 4)) + (1): ELU(alpha=1.0) + (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2)) + (3): ELU(alpha=1.0) + (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)) + (5): ELU(alpha=1.0) + (6): SquashDims() + ) + (advantage): MLP( + (0): LazyLinear(in_features=0, out_features=512, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=512, out_features=20, bias=True) + ) + (value): MLP( + (0): LazyLinear(in_features=0, out_features=512, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=512, out_features=1, bias=True) + ) + ) + >>> x = torch.zeros(1, 3, 64, 64) + >>> y = net(x) + >>> print(y.shape) + torch.Size([1, 20]) + """ def __init__( self, out_features: int, out_features_value: int = 1, - cnn_kwargs: Optional[dict] = None, - mlp_kwargs: Optional[dict] = None, - device: Optional[DEVICE_TYPING] = None, + cnn_kwargs: dict | None = None, + mlp_kwargs: dict | None = None, + device: DEVICE_TYPING | None = None, ): super().__init__() @@ -869,7 +989,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DistributionalDQNnet(TensorDictModuleBase): - """Distributional Deep Q-Network. + """Distributional Deep Q-Network softmax layer. + + This layer should be used in between a regular model that predicts the + action values and a distribution which acts on logits values. Args: in_keys (list of str or tuples of str): input keys to the log-softmax @@ -877,6 +1000,19 @@ class DistributionalDQNnet(TensorDictModuleBase): out_keys (list of str or tuples of str): output keys to the log-softmax operation. Defaults to ``["action_value"]``. + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> net = DistributionalDQNnet() + >>> td = TensorDict({"action_value": torch.randn(10, 5)}, batch_size=[10]) + >>> net(td) + TensorDict( + fields={ + action_value: Tensor(shape=torch.Size([10, 5]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([10]), + device=None, + is_shared=False) + """ _wrong_out_feature_dims_error = ( @@ -922,13 +1058,26 @@ def forward(self, tensordict): def ddpg_init_last_layer( module: nn.Sequential, scale: float = 6e-4, - device: Optional[DEVICE_TYPING] = None, + device: DEVICE_TYPING | None = None, ) -> None: - """Initializer for the last layer of DDPG. + """Initializer for the last layer of DDPG modules. Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf + Args: + module (nn.Module): an actor or critic to be initialized. + scale (float, optional): the noise scale. Defaults to ``6e-4``. + device (torch.device, optional): the device where the noise should be + created. Defaults to the device of the last layer's weight + parameter. 
+ + Examples: + >>> from torchrl.modules.models.models import MLP, ddpg_init_last_layer + >>> mlp = MLP(in_features=4, out_features=5, num_cells=(10, 10)) + >>> # init the last layer of the MLP + >>> ddpg_init_last_layer(mlp) + """ for last_layer in reversed(module): if isinstance(last_layer, (nn.Linear, nn.Conv2d)): @@ -951,46 +1100,86 @@ class DdpgCnnActor(nn.Module): Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf - The DDPG Convolutional Actor takes as input an observation (some simple transformation of the observed pixels) and - returns an action vector from it. - It is trained to maximise the value returned by the DDPG Q Value network. + The DDPG Convolutional Actor takes as input an observation (some simple + transformation of the observed pixels) and returns an action vector from + it, as well as an observation embedding that can be reused for a value + estimation. It should be trained to maximise the value returned by the + DDPG Q Value network. Args: action_dim (int): length of the action vector. - conv_net_kwargs (dict, optional): kwargs for the ConvNet. - default: { - 'in_features': None, - "num_cells": [32, 64, 64], - "kernel_sizes": [8, 4, 3], - "strides": [4, 2, 1], - "paddings": [0, 0, 1], - 'activation_class': nn.ELU, - 'norm_class': None, - 'aggregator_class': SquashDims, - 'aggregator_kwargs': {"ndims_in": 3}, - 'squeeze_output': True, - } + conv_net_kwargs (dict or list of dicts, optional): kwargs for the ConvNet. + Defaults to + + >>> { + ... 'in_features': None, + ... "num_cells": [32, 64, 64], + ... "kernel_sizes": [8, 4, 3], + ... "strides": [4, 2, 1], + ... "paddings": [0, 0, 1], + ... 'activation_class': torch.nn.ELU, + ... 'norm_class': None, + ... 'aggregator_class': SquashDims, + ... 'aggregator_kwargs': {"ndims_in": 3}, + ... 'squeeze_output': True, + ... } # + mlp_net_kwargs: kwargs for MLP. - Default: { - 'in_features': None, - 'out_features': action_dim, - 'depth': 2, - 'num_cells': 200, - 'activation_class': nn.ELU, - 'bias_last_layer': True, - } - use_avg_pooling (bool, optional): if ``True``, a nn.AvgPooling layer is - used to aggregate the output. Default is ``False``. - device (Optional[DEVICE_TYPING]): device to create the module on. + Defaults to: + + >>> { + ... 'in_features': None, + ... 'out_features': action_dim, + ... 'depth': 2, + ... 'num_cells': 200, + ... 'activation_class': nn.ELU, + ... 'bias_last_layer': True, + ... } + + use_avg_pooling (bool, optional): if ``True``, a + :class:`~torch.nn.AvgPooling` layer is used to aggregate the + output. Defaults to ``False``. + device (torch.device, optional): device to create the module on. 
+ + Examples: + >>> import torch + >>> from torchrl.modules import DdpgCnnActor + >>> actor = DdpgCnnActor(action_dim=4) + >>> print(actor) + DdpgCnnActor( + (convnet): ConvNet( + (0): LazyConv2d(0, 32, kernel_size=(8, 8), stride=(4, 4)) + (1): ELU(alpha=1.0) + (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2)) + (3): ELU(alpha=1.0) + (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (5): ELU(alpha=1.0) + (6): SquashDims() + ) + (mlp): MLP( + (0): LazyLinear(in_features=0, out_features=200, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=200, out_features=200, bias=True) + (3): ELU(alpha=1.0) + (4): Linear(in_features=200, out_features=4, bias=True) + ) + ) + >>> obs = torch.randn(10, 3, 64, 64) + >>> action, hidden = actor(obs) + >>> print(action.shape) + torch.Size([10, 4]) + >>> print(hidden.shape) + torch.Size([10, 2304]) + """ def __init__( self, action_dim: int, - conv_net_kwargs: Optional[dict] = None, - mlp_net_kwargs: Optional[dict] = None, + conv_net_kwargs: dict | None = None, + mlp_net_kwargs: dict | None = None, use_avg_pooling: bool = False, - device: Optional[DEVICE_TYPING] = None, + device: DEVICE_TYPING | None = None, ): super().__init__() conv_net_default_kwargs = { @@ -1043,22 +1232,45 @@ class DdpgMlpActor(nn.Module): Args: action_dim (int): length of the action vector mlp_net_kwargs (dict, optional): kwargs for MLP. - Default: { - 'in_features': None, - 'out_features': action_dim, - 'depth': 2, - 'num_cells': [400, 300], - 'activation_class': nn.ELU, - 'bias_last_layer': True, - } - device (Optional[DEVICE_TYPING]): device to create the module on. + Defaults to + + >>> { + ... 'in_features': None, + ... 'out_features': action_dim, + ... 'depth': 2, + ... 'num_cells': [400, 300], + ... 'activation_class': nn.ELU, + ... 'bias_last_layer': True, + ... } + + device (torch.device, optional): device to create the module on. + + Examples: + >>> import torch + >>> from torchrl.modules import DdpgMlpActor + >>> actor = DdpgMlpActor(action_dim=4) + >>> print(actor) + DdpgMlpActor( + (mlp): MLP( + (0): LazyLinear(in_features=0, out_features=400, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=400, out_features=300, bias=True) + (3): ELU(alpha=1.0) + (4): Linear(in_features=300, out_features=4, bias=True) + ) + ) + >>> obs = torch.zeros(10, 6) + >>> action = actor(obs) + >>> print(action.shape) + torch.Size([10, 4]) + """ def __init__( self, action_dim: int, - mlp_net_kwargs: Optional[dict] = None, - device: Optional[DEVICE_TYPING] = None, + mlp_net_kwargs: dict | None = None, + device: DEVICE_TYPING | None = None, ): super().__init__() mlp_net_default_kwargs = { @@ -1085,42 +1297,83 @@ class DdpgCnnQNet(nn.Module): Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf - The DDPG Q-value network takes as input an observation and an action, and returns a scalar from it. + The DDPG Q-value network takes as input an observation and an action, and + returns a scalar from it. Args: - conv_net_kwargs (dict, optional): kwargs for the convolutional network. - default: { - 'in_features': None, - "num_cells": [32, 64, 128], - "kernel_sizes": [8, 4, 3], - "strides": [4, 2, 1], - "paddings": [0, 0, 1], - 'activation_class': nn.ELU, - 'norm_class': None, - 'aggregator_class': nn.AdaptiveAvgPool2d, - 'aggregator_kwargs': {}, - 'squeeze_output': True, - } + conv_net_kwargs (dict, optional): kwargs for the + convolutional network. + Defaults to + + >>> { + ... 'in_features': None, + ... 
"num_cells": [32, 64, 128], + ... "kernel_sizes": [8, 4, 3], + ... "strides": [4, 2, 1], + ... "paddings": [0, 0, 1], + ... 'activation_class': nn.ELU, + ... 'norm_class': None, + ... 'aggregator_class': nn.AdaptiveAvgPool2d, + ... 'aggregator_kwargs': {}, + ... 'squeeze_output': True, + ... } + mlp_net_kwargs (dict, optional): kwargs for MLP. - Default: { - 'in_features': None, - 'out_features': 1, - 'depth': 2, - 'num_cells': 200, - 'activation_class': nn.ELU, - 'bias_last_layer': True, - } - use_avg_pooling (bool, optional): if ``True``, a nn.AvgPooling layer is - used to aggregate the output. Default is ``True``. - device (Optional[DEVICE_TYPING]): device to create the module on. + Defaults to + + >>> { + ... 'in_features': None, + ... 'out_features': 1, + ... 'depth': 2, + ... 'num_cells': 200, + ... 'activation_class': nn.ELU, + ... 'bias_last_layer': True, + ... } + + use_avg_pooling (bool, optional): if ``True``, a + :class:`~torch.nn.AvgPooling` layer is used to aggregate the + output. Default is ``True``. + device (torch.device, optional): device to create the module on. + + Examples: + >>> from torchrl.modules import DdpgCnnQNet + >>> import torch + >>> net = DdpgCnnQNet() + >>> print(net) + DdpgCnnQNet( + (convnet): ConvNet( + (0): LazyConv2d(0, 32, kernel_size=(8, 8), stride=(4, 4)) + (1): ELU(alpha=1.0) + (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2)) + (3): ELU(alpha=1.0) + (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + (5): ELU(alpha=1.0) + (6): AdaptiveAvgPool2d(output_size=(1, 1)) + (7): Squeeze2dLayer() + ) + (mlp): MLP( + (0): LazyLinear(in_features=0, out_features=200, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=200, out_features=200, bias=True) + (3): ELU(alpha=1.0) + (4): Linear(in_features=200, out_features=1, bias=True) + ) + ) + >>> obs = torch.zeros(1, 3, 64, 64) + >>> action = torch.zeros(1, 4) + >>> value = net(obs, action) + >>> print(value.shape) + torch.Size([1, 1]) + + """ def __init__( self, - conv_net_kwargs: Optional[dict] = None, - mlp_net_kwargs: Optional[dict] = None, + conv_net_kwargs: dict | None = None, + mlp_net_kwargs: dict | None = None, use_avg_pooling: bool = True, - device: Optional[DEVICE_TYPING] = None, + device: DEVICE_TYPING | None = None, ): super().__init__() conv_net_default_kwargs = { @@ -1167,37 +1420,68 @@ class DdpgMlpQNet(nn.Module): Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf - The DDPG Q-value network takes as input an observation and an action, and returns a scalar from it. - Because actions are integrated later than observations, two networks are created. + The DDPG Q-value network takes as input an observation and an action, + and returns a scalar from it. + Because actions are integrated later than observations, two networks are + created. Args: mlp_net_kwargs_net1 (dict, optional): kwargs for MLP. - Default: { - 'in_features': None, - 'out_features': 400, - 'depth': 0, - 'num_cells': [], - 'activation_class': nn.ELU, - 'bias_last_layer': True, - 'activate_last_layer': True, - } + Defaults to + + >>> { + ... 'in_features': None, + ... 'out_features': 400, + ... 'depth': 0, + ... 'num_cells': [], + ... 'activation_class': nn.ELU, + ... 'bias_last_layer': True, + ... 'activate_last_layer': True, + ... 
} + mlp_net_kwargs_net2 - Default: { - 'in_features': None, - 'out_features': 1, - 'depth': 1, - 'num_cells': [300, ], - 'activation_class': nn.ELU, - 'bias_last_layer': True, - } - device (Optional[DEVICE_TYPING]): device to create the module on. + Defaults to + + >>> { + ... 'in_features': None, + ... 'out_features': 1, + ... 'depth': 1, + ... 'num_cells': [300, ], + ... 'activation_class': nn.ELU, + ... 'bias_last_layer': True, + ... } + + device (torch.device, optional): device to create the module on. + + Examples: + >>> import torch + >>> from torchrl.modules import DdpgMlpQNet + >>> net = DdpgMlpQNet() + >>> print(net) + DdpgMlpQNet( + (mlp1): MLP( + (0): LazyLinear(in_features=0, out_features=400, bias=True) + (1): ELU(alpha=1.0) + ) + (mlp2): MLP( + (0): LazyLinear(in_features=0, out_features=300, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=300, out_features=1, bias=True) + ) + ) + >>> obs = torch.zeros(1, 32) + >>> action = torch.zeros(1, 4) + >>> value = net(obs, action) + >>> print(value.shape) + torch.Size([1, 1]) + """ def __init__( self, - mlp_net_kwargs_net1: Optional[dict] = None, - mlp_net_kwargs_net2: Optional[dict] = None, - device: Optional[DEVICE_TYPING] = None, + mlp_net_kwargs_net1: dict | None = None, + mlp_net_kwargs_net2: dict | None = None, + device: DEVICE_TYPING | None = None, ): super().__init__() mlp1_net_default_kwargs = { @@ -1239,7 +1523,8 @@ def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tens class LSTMNet(nn.Module): """An embedder for an LSTM preceded by an MLP. - The forward method returns the hidden states of the current state (input hidden states) and the output, as + The forward method returns the hidden states of the current state + (input hidden states) and the output, as the environment returns the 'observation' and 'next_observation'. Because the LSTM kernel only returns the last hidden state, hidden states @@ -1250,6 +1535,22 @@ class LSTMNet(nn.Module): with only one time step. This means that we explicitely assume that users will unsqueeze inputs of a single batch with multiple time steps. + Args: + out_features (int): number of output features. + lstm_kwargs (dict): the keyword arguments for the + :class:`~torch.nn.LSTM` layer. + mlp_kwargs (dict): the keyword arguments for the + :class:`~torchrl.modules.MLP` layer. + device (torch.device, optional): the device where the module should + be instantiated. + + Keyword Args: + lstm_backend (str, optional): one of ``"torchrl"`` or ``"torch"`` that + indeicates where the LSTM class is to be retrieved. The ``"torchrl"`` + backend (:class:`~torchrl.modules.LSTM`) is slower but works with + :func:`~torch.vmap` and should work with :func:`~torch.compile`. + Defaults to ``"torch"``. 
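+
+    A minimal sketch of selecting the torchrl backend (sizes are illustrative):
+
+        >>> net = LSTMNet(
+        ...     out_features=7,
+        ...     lstm_kwargs={"input_size": 8, "hidden_size": 12},
+        ...     mlp_kwargs={"out_features": 8, "num_cells": [8]},
+        ...     lstm_backend="torchrl",
+        ... )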
+ Examples: >>> batch = 7 >>> time_steps = 6 @@ -1274,7 +1575,9 @@ def __init__( out_features: int, lstm_kwargs: Dict, mlp_kwargs: Dict, - device: Optional[DEVICE_TYPING] = None, + device: DEVICE_TYPING | None = None, + *, + lstm_backend: str | None = None, ) -> None: warnings.warn( "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed in v0.4.0.", @@ -1283,7 +1586,14 @@ def __init__( super().__init__() lstm_kwargs.update({"batch_first": True}) self.mlp = MLP(device=device, **mlp_kwargs) - self.lstm = nn.LSTM(device=device, **lstm_kwargs) + if lstm_backend is None: + lstm_backend = "torch" + self.lstm_backend = lstm_backend + if self.lstm_backend == "torch": + LSTM = nn.LSTM + else: + from torchrl.modules.tensordict_module.rnn import LSTM + self.lstm = LSTM(device=device, **lstm_kwargs) self.linear = nn.LazyLinear(out_features, device=device) def _lstm( @@ -1369,8 +1679,11 @@ def forward( class OnlineDTActor(nn.Module): """Online Decision Transformer Actor class. - Actor class for the Online Decision Transformer to sample actions from gaussian distribution as presented inresented in `"Online Decision Transformer" `. - Returns mu and sigma for the gaussian distribution to sample actions from. + Actor class for the Online Decision Transformer to sample actions from + gaussian distribution as presented inresented in + `"Online Decision Transformer" `_. + + Returns the mean and standard deviation for the gaussian distribution to sample actions from. Args: state_dim (int): state dimension. @@ -1378,7 +1691,7 @@ class OnlineDTActor(nn.Module): transformer_config (Dict or :class:`DecisionTransformer.DTConfig`): config for the GPT2 transformer. Defaults to :meth:`~.default_config`. - device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. + device (torch.device, optional): device to use. Defaults to None. Examples: >>> model = OnlineDTActor(state_dim=4, action_dim=2, @@ -1398,7 +1711,7 @@ def __init__( state_dim: int, action_dim: int, transformer_config: Dict | DecisionTransformer.DTConfig = None, - device: Optional[DEVICE_TYPING] = None, + device: DEVICE_TYPING | None = None, ): super().__init__() if transformer_config is None: @@ -1451,7 +1764,7 @@ def forward( @classmethod def default_config(cls): - """Default configuration for :class:`~.OnlineDTActor`.""" + """Default configuration for :class:`~OnlineDTActor`.""" return DecisionTransformer.DTConfig( n_embd=512, n_layer=4, @@ -1467,7 +1780,8 @@ def default_config(cls): class DTActor(nn.Module): """Decision Transformer Actor class. - Actor class for the Decision Transformer to output deterministic action as presented in `"Decision Transformer" `. + Actor class for the Decision Transformer to output deterministic action as + presented in `"Decision Transformer" `. Returns the deterministic actions. Args: @@ -1476,7 +1790,7 @@ class DTActor(nn.Module): transformer_config (Dict or :class:`DecisionTransformer.DTConfig`, optional): config for the GPT2 transformer. Defaults to :meth:`~.default_config`. - device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. + device (torch.device, optional): device to use. Defaults to None. 
Examples: >>> model = DTActor(state_dim=4, action_dim=2, @@ -1495,7 +1809,7 @@ def __init__( state_dim: int, action_dim: int, transformer_config: Dict | DecisionTransformer.DTConfig = None, - device: Optional[DEVICE_TYPING] = None, + device: DEVICE_TYPING | None = None, ): super().__init__() if transformer_config is None: @@ -1532,7 +1846,7 @@ def forward( @classmethod def default_config(cls): - """Default configuration for :class:`~.DTActor`.""" + """Default configuration for :class:`~DTActor`.""" return DecisionTransformer.DTConfig( n_embd=512, n_layer=4, @@ -1543,3 +1857,14 @@ def default_config(cls): resid_pdrop=0.1, attn_pdrop=0.1, ) + + +def _iter_maybe_over_single(item: dict | List[dict] | None): + if item is None: + while True: + yield {} + elif isinstance(item, dict): + while True: + yield item + else: + yield from item diff --git a/torchrl/modules/models/utils.py b/torchrl/modules/models/utils.py index 392a4ec4376..5cd928f3a91 100644 --- a/torchrl/modules/models/utils.py +++ b/torchrl/modules/models/utils.py @@ -12,7 +12,7 @@ from torchrl.data.utils import DEVICE_TYPING -from .exploration import NoisyLazyLinear, NoisyLinear +from torchrl.modules.models.exploration import NoisyLazyLinear, NoisyLinear LazyMapping = { nn.Linear: nn.LazyLinear, @@ -67,6 +67,14 @@ class SquashDims(nn.Module): Args: ndims_in (int): number of dimensions to be flattened. default = 3 + + Examples: + >>> from torchrl.modules.models.utils import SquashDims + >>> import torch + >>> x = torch.randn(1, 2, 3, 4) + >>> print(SquashDims()(x).shape) + torch.Size([1, 24]) + """ def __init__(self, ndims_in: int = 3): From 9fd635e3c6a5b42ffd98f1d674901b3ea57f37be Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 4 Mar 2024 16:49:56 +0000 Subject: [PATCH 3/7] amend --- test/test_modules.py | 150 +++++++++++++++++++------------ torchrl/modules/models/models.py | 48 +++++++--- torchrl/modules/models/utils.py | 16 ++-- 3 files changed, 141 insertions(+), 73 deletions(-) diff --git a/test/test_modules.py b/test/test_modules.py index 94c8a809170..c2fd0cd35a9 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -59,65 +59,99 @@ def double_prec_fixture(): torch.set_default_dtype(dtype) -@pytest.mark.parametrize("in_features", [3, 10, None]) -@pytest.mark.parametrize("out_features", [3, (3, 10)]) -@pytest.mark.parametrize("depth, num_cells", [(3, 32), (None, (32, 32, 32))]) -@pytest.mark.parametrize( - "activation_class, activation_kwargs", - [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})], -) -@pytest.mark.parametrize( - "norm_class, norm_kwargs", - [ - (nn.LazyBatchNorm1d, {}), - (nn.BatchNorm1d, {"num_features": 32}), - (nn.LayerNorm, {"normalized_shape": 32}), - ], -) -@pytest.mark.parametrize("dropout", [0.0, 0.5]) -@pytest.mark.parametrize("bias_last_layer", [True, False]) -@pytest.mark.parametrize("single_bias_last_layer", [True, False]) -@pytest.mark.parametrize("layer_class", [nn.Linear, NoisyLinear]) -@pytest.mark.parametrize("device", get_default_devices()) -def test_mlp( - in_features, - out_features, - depth, - num_cells, - activation_class, - activation_kwargs, - dropout, - bias_last_layer, - norm_class, - norm_kwargs, - single_bias_last_layer, - layer_class, - device, - seed=0, -): - torch.manual_seed(seed) - batch = 2 - mlp = MLP( - in_features=in_features, - out_features=out_features, - depth=depth, - num_cells=num_cells, - activation_class=activation_class, - activation_kwargs=activation_kwargs, - norm_class=norm_class, - norm_kwargs=norm_kwargs, - dropout=dropout, - 
bias_last_layer=bias_last_layer, - single_bias_last_layer=False, - layer_class=layer_class, - device=device, +class TestMLP: + @pytest.mark.parametrize("in_features", [3, 10, None]) + @pytest.mark.parametrize("out_features", [3, (3, 10)]) + @pytest.mark.parametrize("depth, num_cells", [(3, 32), (None, (32, 32, 32))]) + @pytest.mark.parametrize( + "activation_class, activation_kwargs", + [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})], ) - if in_features is None: - in_features = 5 - x = torch.randn(batch, in_features, device=device) - y = mlp(x) - out_features = [out_features] if isinstance(out_features, Number) else out_features - assert y.shape == torch.Size([batch, *out_features]) + @pytest.mark.parametrize( + "norm_class, norm_kwargs", + [ + (nn.LazyBatchNorm1d, {}), + (nn.BatchNorm1d, {"num_features": 32}), + (nn.LayerNorm, {"normalized_shape": 32}), + ], + ) + @pytest.mark.parametrize("dropout", [0.0, 0.5]) + @pytest.mark.parametrize("bias_last_layer", [True, False]) + @pytest.mark.parametrize("single_bias_last_layer", [True, False]) + @pytest.mark.parametrize("layer_class", [nn.Linear, NoisyLinear]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_mlp( + self, + in_features, + out_features, + depth, + num_cells, + activation_class, + activation_kwargs, + dropout, + bias_last_layer, + norm_class, + norm_kwargs, + single_bias_last_layer, + layer_class, + device, + seed=0, + ): + torch.manual_seed(seed) + batch = 2 + mlp = MLP( + in_features=in_features, + out_features=out_features, + depth=depth, + num_cells=num_cells, + activation_class=activation_class, + activation_kwargs=activation_kwargs, + norm_class=norm_class, + norm_kwargs=norm_kwargs, + dropout=dropout, + bias_last_layer=bias_last_layer, + single_bias_last_layer=False, + layer_class=layer_class, + device=device, + ) + if in_features is None: + in_features = 5 + x = torch.randn(batch, in_features, device=device) + y = mlp(x) + out_features = ( + [out_features] if isinstance(out_features, Number) else out_features + ) + assert y.shape == torch.Size([batch, *out_features]) + + def test_kwargs(self): + def make_activation(shift): + return lambda x: x + shift + + def layer(*args, **kwargs): + linear = nn.Linear(*args, **kwargs) + linear.weight.data.copy_(torch.eye(4)) + return linear + + in_features = 4 + out_features = 4 + num_cells = [4, 4, 4] + mlp = MLP( + in_features=in_features, + out_features=out_features, + num_cells=num_cells, + activation_class=make_activation, + activation_kwargs=[{"shift": 0}, {"shift": 1}, {"shift": 2}], + layer_class=layer, + layer_kwargs=[{"bias": False}] * 4, + bias_last_layer=False, + ) + x = torch.zeros(4) + y = mlp(x) + for i, module in enumerate(mlp.modules()): + if isinstance(module, nn.Linear): + assert (module.weight == torch.eye(4)).all(), i + assert module.bias is None, i + assert (y == 3).all() @pytest.mark.parametrize("in_features", [3, 10, None]) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 5e0bfdb6af4..dda5da96a0d 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -7,11 +7,11 @@ import dataclasses import warnings +from copy import deepcopy from numbers import Number -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Callable, Dict, List, Sequence, Tuple, Type, Union import torch -import torchrl.modules from tensordict.nn import dispatch, TensorDictModuleBase from torch import nn from torch.nn import functional as F @@ 
-168,7 +168,7 @@ def __init__( num_cells: Sequence[int] | int | None = None, activation_class: Type[nn.Module] | Callable = nn.Tanh, activation_kwargs: dict | List[dict] | None = None, - norm_class: Optional[Type[nn.Module] | Callable] = None, + norm_class: Type[nn.Module] | Callable | None = None, norm_kwargs: dict | List[dict] | None = None, dropout: float | None = None, bias_last_layer: bool = True, @@ -229,6 +229,10 @@ def __init__( consider matching or specifying a constant num_cells argument together with a a desired depth" ) layers = self._make_net(device) + layers = [ + layer if isinstance(layer, nn.Module) else _ExecutableLayer(layer) + for layer in layers + ] super().__init__(*layers) def _make_net(self, device: DEVICE_TYPING | None) -> List[nn.Module]: @@ -236,8 +240,10 @@ def _make_net(self, device: DEVICE_TYPING | None) -> List[nn.Module]: in_features = [self.in_features] + self.num_cells out_features = self.num_cells + [self._out_features_num] for i, (_in, _out) in enumerate(zip(in_features, out_features)): - _bias = self.bias_last_layer if i == self.depth else True layer_kwargs = next(self._layer_kwargs_iter) + _bias = layer_kwargs.pop( + "bias", self.bias_last_layer if i == self.depth else True + ) if _in is not None: layers.append( create_on_device( @@ -457,6 +463,10 @@ def __init__( self.depth = len(self.kernel_sizes) layers = self._make_net(device) + layers = [ + layer if isinstance(layer, nn.Module) else _ExecutableLayer(layer) + for layer in layers + ] super().__init__(*layers) def _make_net(self, device: DEVICE_TYPING | None) -> nn.Module: @@ -694,6 +704,10 @@ def __init__( self.depth = len(self.kernel_sizes) layers = self._make_net(device) + layers = [ + layer if isinstance(layer, nn.Module) else _ExecutableLayer(layer) + for layer in layers + ] super().__init__(*layers) def _make_net(self, device: DEVICE_TYPING | None) -> nn.Module: @@ -1599,8 +1613,8 @@ def __init__( def _lstm( self, input: torch.Tensor, - hidden0_in: Optional[torch.Tensor] = None, - hidden1_in: Optional[torch.Tensor] = None, + hidden0_in: torch.Tensor | None = None, + hidden1_in: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: squeeze0 = False squeeze1 = False @@ -1669,8 +1683,8 @@ def _lstm( def forward( self, input: torch.Tensor, - hidden0_in: Optional[torch.Tensor] = None, - hidden1_in: Optional[torch.Tensor] = None, + hidden0_in: torch.Tensor | None = None, + hidden1_in: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: input = self.mlp(input) return self._lstm(input, hidden0_in, hidden1_in) @@ -1865,6 +1879,20 @@ def _iter_maybe_over_single(item: dict | List[dict] | None): yield {} elif isinstance(item, dict): while True: - yield item + yield deepcopy(item) else: - yield from item + yield from (deepcopy(_item) for _item in item) + + +class _ExecutableLayer(nn.Module): + """A thin wrapper around a function to be exectued as a module.""" + + def __init__(self, func): + super(_ExecutableLayer, self).__init__() + self.func = func + + def forward(self, *args, **kwargs): + return self.func(*args, **kwargs) + + def __repr__(self): + return f"{self.__class__.__name__}(func={self.func})" diff --git a/torchrl/modules/models/utils.py b/torchrl/modules/models/utils.py index 5cd928f3a91..0c650087235 100644 --- a/torchrl/modules/models/utils.py +++ b/torchrl/modules/models/utils.py @@ -2,10 +2,11 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in 
the root directory of this source tree. +from __future__ import annotations import inspect import warnings -from typing import Optional, Sequence, Type +from typing import Callable, Sequence, Type import torch from torch import nn @@ -86,7 +87,7 @@ def forward(self, value: torch.Tensor) -> torch.Tensor: return value -def _find_depth(depth: Optional[int], *list_or_ints: Sequence): +def _find_depth(depth: int | None, *list_or_ints: Sequence): """Find depth based on a sequence of inputs and a depth indicator. If the depth is None, it is inferred by the length of one (or more) matching @@ -113,7 +114,10 @@ def _find_depth(depth: Optional[int], *list_or_ints: Sequence): def create_on_device( - module_class: Type[nn.Module], device: Optional[DEVICE_TYPING], *args, **kwargs + module_class: Type[nn.Module] | Callable, + device: DEVICE_TYPING | None, + *args, + **kwargs, ) -> nn.Module: """Create a new instance of :obj:`module_class` on :obj:`device`. @@ -130,8 +134,10 @@ def create_on_device( if "device" in fullargspec.args or "device" in fullargspec.kwonlyargs: return module_class(*args, device=device, **kwargs) else: - return module_class(*args, **kwargs).to(device) - # .to() is always available for nn.Module, and does nothing if the Module contains no parameters or buffers + result = module_class(*args, **kwargs) + if hasattr(result, "to"): + result = result.to(device) + return result def _reset_parameters_recursive(module, warn_if_no_op: bool = True) -> bool: From 40953c0d166d52c15306bde6e1737f8f394e7688 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 4 Mar 2024 20:44:13 +0000 Subject: [PATCH 4/7] amend --- torchrl/modules/__init__.py | 3 +- torchrl/modules/models/__init__.py | 3 +- torchrl/modules/models/models.py | 107 +++++--------------- torchrl/modules/tensordict_module/actors.py | 3 +- torchrl/modules/tensordict_module/common.py | 74 +++++++++++++- 5 files changed, 102 insertions(+), 88 deletions(-) diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 8f9dd1b13c1..970684bc329 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from torchrl.modules.tensordict_module.common import DistributionalDQNnet + from .distributions import ( Delta, distributions_maps, @@ -24,7 +26,6 @@ DdpgMlpActor, DdpgMlpQNet, DecisionTransformer, - DistributionalDQNnet, DreamerActor, DTActor, DuelingCnnDQNet, diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 518abca1f65..7e8ace40dcd 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. 
+from torchrl.modules.tensordict_module.common import DistributionalDQNnet + from .decision_transformer import DecisionTransformer from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise from .model_based import DreamerActor, ObsDecoder, ObsEncoder, RSSMPosterior, RSSMPrior @@ -15,7 +17,6 @@ DdpgCnnQNet, DdpgMlpActor, DdpgMlpQNet, - DistributionalDQNnet, DTActor, DuelingCnnDQNet, DuelingMlpDQNet, diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index dda5da96a0d..2432e45d4cc 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -12,9 +12,7 @@ from typing import Callable, Dict, List, Sequence, Tuple, Type, Union import torch -from tensordict.nn import dispatch, TensorDictModuleBase from torch import nn -from torch.nn import functional as F from torchrl._utils import prod from torchrl.data.utils import DEVICE_TYPING @@ -27,6 +25,7 @@ Squeeze2dLayer, SqueezeLayer, ) +from torchrl.modules.tensordict_module.common import DistributionalDQNnet # noqa class MLP(nn.Sequential): @@ -206,9 +205,6 @@ def __init__( self.activation_kwargs = activation_kwargs self.norm_kwargs = norm_kwargs self.layer_kwargs = layer_kwargs - self._activation_kwargs_iter = _iter_maybe_over_single(activation_kwargs) - self._norm_kwargs_iter = _iter_maybe_over_single(norm_kwargs) - self._layer_kwargs_iter = _iter_maybe_over_single(layer_kwargs) self.activate_last_layer = activate_last_layer if single_bias_last_layer: @@ -228,6 +224,14 @@ def __init__( "depth and num_cells length conflict, \ consider matching or specifying a constant num_cells argument together with a a desired depth" ) + + self._activation_kwargs_iter = _iter_maybe_over_single( + activation_kwargs, n=self.depth + ) + self._norm_kwargs_iter = _iter_maybe_over_single(norm_kwargs, n=self.depth) + self._layer_kwargs_iter = _iter_maybe_over_single( + layer_kwargs, n=self.depth + 1 + ) layers = self._make_net(device) layers = [ layer if isinstance(layer, nn.Module) else _ExecutableLayer(layer) @@ -430,8 +434,6 @@ def __init__( activation_kwargs if activation_kwargs is not None else {} ) self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {} - self._activation_kwargs_iter = _iter_maybe_over_single(activation_kwargs) - self._norm_kwargs_iter = _iter_maybe_over_single(norm_kwargs) depth = _find_depth(depth, num_cells, kernel_sizes, strides, paddings) self.depth = depth @@ -462,6 +464,12 @@ def __init__( self.out_features = self.num_cells[-1] self.depth = len(self.kernel_sizes) + + self._activation_kwargs_iter = _iter_maybe_over_single( + activation_kwargs, n=self.depth + ) + self._norm_kwargs_iter = _iter_maybe_over_single(norm_kwargs, n=self.depth) + layers = self._make_net(device) layers = [ layer if isinstance(layer, nn.Module) else _ExecutableLayer(layer) @@ -668,8 +676,6 @@ def __init__( activation_kwargs if activation_kwargs is not None else {} ) self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {} - self._activation_kwargs_iter = _iter_maybe_over_single(activation_kwargs) - self._norm_kwargs_iter = _iter_maybe_over_single(norm_kwargs) self.bias_last_layer = bias_last_layer self.aggregator_class = aggregator_class @@ -703,6 +709,12 @@ def __init__( self.out_features = self.num_cells[-1] self.depth = len(self.kernel_sizes) + + self._activation_kwargs_iter = _iter_maybe_over_single( + activation_kwargs, n=self.depth + ) + self._norm_kwargs_iter = _iter_maybe_over_single(norm_kwargs, n=self.depth) + layers = self._make_net(device) layers = [ layer if 
isinstance(layer, nn.Module) else _ExecutableLayer(layer) @@ -1002,73 +1014,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return value + advantage - advantage.mean(dim=-1, keepdim=True) -class DistributionalDQNnet(TensorDictModuleBase): - """Distributional Deep Q-Network softmax layer. - - This layer should be used in between a regular model that predicts the - action values and a distribution which acts on logits values. - - Args: - in_keys (list of str or tuples of str): input keys to the log-softmax - operation. Defaults to ``["action_value"]``. - out_keys (list of str or tuples of str): output keys to the log-softmax - operation. Defaults to ``["action_value"]``. - - Examples: - >>> import torch - >>> from tensordict import TensorDict - >>> net = DistributionalDQNnet() - >>> td = TensorDict({"action_value": torch.randn(10, 5)}, batch_size=[10]) - >>> net(td) - TensorDict( - fields={ - action_value: Tensor(shape=torch.Size([10, 5]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([10]), - device=None, - is_shared=False) - - """ - - _wrong_out_feature_dims_error = ( - "DistributionalDQNnet requires dqn output to be at least " - "2-dimensional, with dimensions *Batch x #Atoms x #Actions. Got {0} " - "instead." - ) - - def __init__(self, *, in_keys=None, out_keys=None, DQNet: nn.Module = None): - super().__init__() - if DQNet is not None: - warnings.warn( - f"Passing a network to {type(self)} is going to be deprecated in v0.4.0.", - category=DeprecationWarning, - ) - if not ( - not isinstance(DQNet.out_features, Number) - and len(DQNet.out_features) > 1 - ): - raise RuntimeError(self._wrong_out_feature_dims_error) - self.dqn = DQNet - if in_keys is None: - in_keys = ["action_value"] - if out_keys is None: - out_keys = ["action_value"] - self.in_keys = in_keys - self.out_keys = out_keys - - @dispatch(auto_batch_size=False) - def forward(self, tensordict): - for in_key, out_key in zip(self.in_keys, self.out_keys): - q_values = tensordict.get(in_key) - if self.dqn is not None: - q_values = self.dqn(q_values) - if q_values.ndimension() < 2: - raise RuntimeError( - self._wrong_out_feature_dims_error.format(q_values.shape) - ) - tensordict.set(out_key, F.log_softmax(q_values, dim=-2)) - return tensordict - - def ddpg_init_last_layer( module: nn.Sequential, scale: float = 6e-4, @@ -1873,15 +1818,13 @@ def default_config(cls): ) -def _iter_maybe_over_single(item: dict | List[dict] | None): +def _iter_maybe_over_single(item: dict | List[dict] | None, n): if item is None: - while True: - yield {} + return iter([{} for _ in range(n)]) elif isinstance(item, dict): - while True: - yield deepcopy(item) + return iter([deepcopy(item) for _ in range(n)]) else: - yield from (deepcopy(_item) for _item in item) + return iter([deepcopy(_item) for _item in item]) class _ExecutableLayer(nn.Module): diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 8d9855283f5..c5010f8113d 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -21,8 +21,7 @@ from torchrl.data.tensor_specs import CompositeSpec, TensorSpec from torchrl.data.utils import _process_action_space_spec -from torchrl.modules.models.models import DistributionalDQNnet -from torchrl.modules.tensordict_module.common import SafeModule +from torchrl.modules.tensordict_module.common import DistributionalDQNnet, SafeModule from torchrl.modules.tensordict_module.probabilistic import ( SafeProbabilisticModule, 
SafeProbabilisticTensorDictSequential, diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 221ba3cde8d..edd3e5da97f 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -9,16 +9,19 @@ import inspect import re import warnings +from numbers import Number from typing import Iterable, List, Optional, Type, Union +import tensordict import torch -from tensordict import TensorDictBase, unravel_key_list +from tensordict import TensorDict, TensorDictBase, unravel_key_list -from tensordict.nn import TensorDictModule, TensorDictModuleBase +from tensordict.nn import dispatch, TensorDictModule, TensorDictModuleBase from tensordict.utils import NestedKey from torch import nn +from torch.nn import functional as F from torchrl.data.tensor_specs import CompositeSpec, TensorSpec @@ -466,3 +469,70 @@ def forward(self, tensordict): vmap_dim = ndim - 1 td = self._vmap(self.module, (vmap_dim,), (vmap_dim,))(tensordict) return tensordict.update(td) + + +class DistributionalDQNnet(TensorDictModuleBase): + """Distributional Deep Q-Network softmax layer. + + This layer should be used in between a regular model that predicts the + action values and a distribution which acts on logits values. + + Args: + in_keys (list of str or tuples of str): input keys to the log-softmax + operation. Defaults to ``["action_value"]``. + out_keys (list of str or tuples of str): output keys to the log-softmax + operation. Defaults to ``["action_value"]``. + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> net = DistributionalDQNnet() + >>> td = TensorDict({"action_value": torch.randn(10, 5)}, batch_size=[10]) + >>> net(td) + TensorDict( + fields={ + action_value: Tensor(shape=torch.Size([10, 5]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([10]), + device=None, + is_shared=False) + + """ + + _wrong_out_feature_dims_error = ( + "DistributionalDQNnet requires dqn output to be at least " + "2-dimensional, with dimensions *Batch x #Atoms x #Actions. Got {0} " + "instead." 
+ ) + + def __init__(self, *, in_keys=None, out_keys=None, DQNet: nn.Module = None): + super().__init__() + if DQNet is not None: + warnings.warn( + f"Passing a network to {type(self)} is going to be deprecated in v0.4.0.", + category=DeprecationWarning, + ) + if not ( + not isinstance(DQNet.out_features, Number) + and len(DQNet.out_features) > 1 + ): + raise RuntimeError(self._wrong_out_feature_dims_error) + self.dqn = DQNet + if in_keys is None: + in_keys = ["action_value"] + if out_keys is None: + out_keys = ["action_value"] + self.in_keys = in_keys + self.out_keys = out_keys + + @dispatch(auto_batch_size=False) + def forward(self, tensordict): + for in_key, out_key in zip(self.in_keys, self.out_keys): + q_values = tensordict.get(in_key) + if self.dqn is not None: + q_values = self.dqn(q_values) + if q_values.ndimension() < 2: + raise RuntimeError( + self._wrong_out_feature_dims_error.format(q_values.shape) + ) + tensordict.set(out_key, F.log_softmax(q_values, dim=-2)) + return tensordict From 17e4fe2181878a6f651f7ef4287bc30b11fce5ce Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 4 Mar 2024 20:48:32 +0000 Subject: [PATCH 5/7] amend --- torchrl/modules/tensordict_module/common.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index edd3e5da97f..8dd621c98b2 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -12,10 +12,9 @@ from numbers import Number from typing import Iterable, List, Optional, Type, Union -import tensordict import torch -from tensordict import TensorDict, TensorDictBase, unravel_key_list +from tensordict import TensorDictBase, unravel_key_list from tensordict.nn import dispatch, TensorDictModule, TensorDictModuleBase from tensordict.utils import NestedKey From 17b96c89140ffab3fc9fb38cc6b7294747611fec Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 5 Mar 2024 08:15:50 +0000 Subject: [PATCH 6/7] amend --- torchrl/modules/models/models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 2432e45d4cc..ef87dd07b8f 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -226,9 +226,11 @@ def __init__( ) self._activation_kwargs_iter = _iter_maybe_over_single( - activation_kwargs, n=self.depth + activation_kwargs, n=self.depth + self.activate_last_layer + ) + self._norm_kwargs_iter = _iter_maybe_over_single( + norm_kwargs, n=self.depth + self.activate_last_layer ) - self._norm_kwargs_iter = _iter_maybe_over_single(norm_kwargs, n=self.depth) self._layer_kwargs_iter = _iter_maybe_over_single( layer_kwargs, n=self.depth + 1 ) From 6a1dd2eb86fd1986e06a406bebe65ed55ba24646 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 5 Mar 2024 09:05:06 +0000 Subject: [PATCH 7/7] amend --- torchrl/modules/models/models.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index ef87dd07b8f..8e6fc75e12e 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -60,13 +60,13 @@ class MLP(nn.Sequential): class or constructor to be used. Defaults to :class:`~torch.nn.Tanh`. activation_kwargs (dict or list of dicts, optional): kwargs to be used - with the activation class. Aslo accepts a list of kwargs, one for - each layer. + with the activation class. 
Aslo accepts a list of kwargs, one for
-        each layer.
+        with the activation class. Also accepts a list of kwargs of length
+        ``depth + int(activate_last_layer)``.
     norm_class (Type or callable, optional): normalization class or
         constructor, if any.
     norm_kwargs (dict or list of dicts, optional): kwargs to be used with
-        the normalization layers. Aslo accepts a list of kwargs, one for
-        each layer.
+        the normalization layers. Also accepts a list of kwargs of length
+        ``depth + int(activate_last_layer)``.
     dropout (float, optional): dropout probability. Defaults to ``None`` (no
         dropout);
     bias_last_layer (bool): if ``True``, the last Linear layer will have a bias parameter.
@@ -77,7 +77,7 @@ class or constructor to be used.
     layer_class (Type[nn.Module] or callable, optional): class to be used
         for the linear layers;
     layer_kwargs (dict or list of dicts, optional): kwargs for the linear
-        layers. Aslo accepts a list of kwargs, one for each layer.
+        layers. Also accepts a list of kwargs of length ``depth + 1``.
     activate_last_layer (bool): whether the MLP output should be activated. This is
         useful when the MLP output is used as the input for another module.
         default: False.
@@ -330,13 +330,13 @@ class ConvNet(nn.Sequential):
         class or constructor to be used.
         Defaults to :class:`~torch.nn.Tanh`.
     activation_kwargs (dict or list of dicts, optional): kwargs to be used
-        with the activation class. A list of kwargs can also be passed,
-        with one element per layer.
+        with the activation class. A list of kwargs of length ``depth``
+        can also be passed, with one element per layer.
     norm_class (Type or callable, optional): normalization class or
         constructor, if any.
     norm_kwargs (dict or list of dicts, optional): kwargs to be used with
-        the normalization layers. A list of kwargs can also be passed,
-        with one element per layer.
+        the normalization layers. A list of kwargs of length ``depth`` can
+        also be passed, with one element per layer.
     bias_last_layer (bool): if ``True``, the last Linear layer will have a bias parameter.
         Defaults to ``True``.
     aggregator_class (Type[nn.Module] or callable): aggregator class or
@@ -578,12 +578,12 @@ class Conv3dNet(nn.Sequential):
     activation_class (Type[nn.Module] or callable): activation class or
         constructor to be used. Defaults to :class:`~torch.nn.Tanh`.
     activation_kwargs (dict or list of dicts, optional): kwargs to be used
-        with the activation class. A list of kwargs with one element per
-        layer can also be provided.
+        with the activation class. A list of kwargs of length ``depth``
+        with one element per layer can also be provided.
     norm_class (Type or callable, optional): normalization class, if any.
     norm_kwargs (dict or list of dicts, optional): kwargs to be used with
-        the normalization layers. A list of kwargs with one element per
-        layer can also be provided.
+        the normalization layers. A list of kwargs of length ``depth``
+        with one element per layer can also be provided.
     bias_last_layer (bool): if ``True``, the last Linear layer will have a bias parameter.
         Defaults to ``True``.
     aggregator_class (Type[nn.Module] or callable): aggregator class or