211 changes: 52 additions & 159 deletions docs/source/reference/modules.rst
@@ -92,8 +92,7 @@ Some algorithms such as PPO require a probabilistic policy to be implemented.
In TorchRL, these policies take the form of a model, followed by a distribution
constructor.

.. note:: The choice of a probabilistic or regular actor class depends on the algorithm
that is being implemented. On-policy algorithms usually require a probabilistic
actor, while off-policy algorithms usually have a deterministic actor with an extra
exploration strategy. There are, however, many exceptions to this rule.
@@ -103,8 +102,12 @@ and outputs the parameters of a distribution, while the distribution constructor
reads these parameters and gets a random sample from the distribution and/or
provides a :class:`torch.distributions.Distribution` object.

>>> from tensordict.nn import NormalParamExtractor, TensorDictSequential, TensorDictModule
>>> from torchrl.modules import SafeProbabilisticModule
>>> from torchrl.envs import GymEnv
>>> from torch.distributions import Normal
>>> from torch import nn
>>>
>>> env = GymEnv("Pendulum-v1")
>>> action_spec = env.action_spec
>>> model = nn.Sequential(nn.LazyLinear(action_spec.shape[-1] * 2), NormalParamExtractor())
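
The rest of this example is elided by the diff view. Continuing from the snippet above, a minimal
sketch of how these pieces are typically composed (the key names and the choice of a plain
:class:`~torch.distributions.Normal` distribution are assumptions, not the original code) could be:

>>> td_module = TensorDictModule(model, in_keys=["observation"], out_keys=["loc", "scale"])
>>> prob_module = SafeProbabilisticModule(
...     in_keys=["loc", "scale"],
...     out_keys=["action"],
...     distribution_class=Normal,
... )
>>> policy = TensorDictSequential(td_module, prob_module)
>>> # running the policy writes "loc", "scale" and "action" into the tensordict
>>> td = policy(env.reset())
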
@@ -125,6 +128,7 @@ provides a :class:`torch.distributions.Distribution` object.
To facilitate the construction of probabilistic policies, we provide a dedicated
:class:`~torchrl.modules.tensordict_module.ProbabilisticActor`:

>>> from torchrl.modules import ProbabilisticActor
>>> policy = ProbabilisticActor(
... model,
... in_keys=["loc", "scale"],
@@ -154,69 +158,31 @@ of this action.
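
The diff also cuts the :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` example
short. Reusing the hypothetical ``td_module`` wrapper from the sketch above (the
:class:`~torchrl.modules.TanhNormal` distribution and the ``return_log_prob`` flag are
illustrative assumptions, not the original snippet), a complete construction might read:

>>> from torchrl.modules import TanhNormal
>>> policy = ProbabilisticActor(
...     td_module,
...     in_keys=["loc", "scale"],
...     distribution_class=TanhNormal,
...     return_log_prob=True,
... )
>>> td = policy(env.reset())
>>> # td now holds "loc", "scale", the sampled "action" and its "sample_log_prob"
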
Q-Value actors
~~~~~~~~~~~~~~

Q-Value actors are a type of policy that, rather than directly predicting an action
from an observation, selects the action that maximises the value (or *quality*) of a
state-action pair. This value can be represented as a table or a function. For discrete
action spaces with continuous (or near-continuous, such as pixel-based) states, it is
customary to use a non-linear model such as a neural network to represent this function.

QValueActor
^^^^^^^^^^^

The :class:`~torchrl.modules.QValueActor` class takes in a module and an action
specification, and outputs the selected action and its corresponding value.

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import QValueActor
>>> # Create a tensordict with an observation of dimension 3
>>> td = TensorDict({'observation': torch.randn(5, 3)}, [5])
>>> # Define the action space: 4 discrete actions to choose from
>>> action_spec = OneHot(4)
>>> # The model reads a state of dimension 3 and outputs one value per available action
>>> module = nn.Linear(3, 4)
>>> # Create a QValueActor instance
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
>>> # Run the actor on the tensordict: the action, the action values and the chosen
>>> # action value are written in place
>>> qvalue_actor(td)
>>> print(td)
TensorDict(
@@ -229,122 +195,48 @@ resulting action in the input tensordict along with the list of action values.
device=None,
is_shared=False)

This will output a tensordict with the selected action and its corresponding value.
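
As a quick sanity check (a sketch added here, not part of the original example), the written
``action`` is simply the one-hot encoding of the argmax over the predicted action values, and
``chosen_action_value`` is the corresponding maximum:

>>> assert (td["action"].argmax(-1) == td["action_value"].argmax(-1)).all()
>>> assert torch.allclose(td["chosen_action_value"], td["action_value"].max(-1, keepdim=True)[0])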

Distributional Q-Learning
^^^^^^^^^^^^^^^^^^^^^^^^^

Distributional Q-learning is a variant of Q-learning that represents the value function as a
probability distribution over possible values, rather than a single scalar value. This allows
the agent to learn about the uncertainty in the environment and make more informed decisions.
The value space is divided into an arbitrary number of "bins", and the value network outputs
the probability that the state-action value belongs to each bin. Hence, for a state space of
dimension M, an action space of dimension N and a number of bins B, the value network encodes
a :math:`\mathbb{R}^{M} \rightarrow \mathbb{R}^{N \times B}` map.
In TorchRL, distributional Q-learning is implemented using the :class:`~torchrl.modules.DistributionalQValueActor`
class. This class takes in a module, an action specification, and a support vector, and outputs the selected
action and its corresponding value distribution.


>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import OneHot
>>> from torchrl.modules import DistributionalQValueActor, MLP
>>> # Create a tensordict with an observation of dimension 4
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> # Define the action space: 4 discrete actions to choose from
>>> action_spec = OneHot(4)
>>> # Define the number of bins for the value distribution
>>> nbins = 3
>>> # Create an MLP that outputs, for each of the 4 actions, nbins logits
>>> module = MLP(out_features=(nbins, 4), depth=2)
>>> # Create a DistributionalQValueActor instance
>>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins))
>>> # Run the actor on the tensordict
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
fields={
action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)



This will output a tensordict with the selected action and its corresponding value distribution.
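
To relate this distribution back to a scalar estimate (a sketch added here; it assumes that
``action_value`` stores log-probabilities over the bins, which is the actor's default
log-softmax behaviour), the expected value of each action is the support-weighted sum of the
bin probabilities, and the selected action is its argmax:

>>> support = torch.arange(nbins)
>>> probs = td["action_value"].exp()                       # [5, nbins, 4]
>>> expected_q = (probs * support.view(1, -1, 1)).sum(-2)  # [5, 4]
>>> print((td["action"].argmax(-1) == expected_q.argmax(-1)).all())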

.. currentmodule:: torchrl.modules.tensordict_module

@@ -403,11 +295,10 @@ without shared parameters. It is mainly intended as a replacement for

Domain-specific TensorDict modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. currentmodule:: torchrl.modules.tensordict_module

These modules include dedicated solutions for MBRL or RLHF pipelines.

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst
@@ -558,9 +449,11 @@ Some distributions are typically used in RL scripts.
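
The distributions section itself is elided in this diff. As a brief, hedged illustration (the
choice of :class:`~torchrl.modules.TanhNormal` and its parameters below are assumptions), one of
these distributions can be used directly:

>>> import torch
>>> from torchrl.modules import TanhNormal
>>> dist = TanhNormal(loc=torch.zeros(3), scale=torch.ones(3))
>>> sample = dist.rsample()          # reparameterised sample, squashed into (-1, 1)
>>> log_prob = dist.log_prob(sample)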

Utils
-----

.. currentmodule:: torchrl.modules.utils

The module utils include functionals used to perform custom mappings, as well as a tool to
build :class:`~torchrl.envs.TensorDictPrimer` instances from a given module.
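
For instance (a hedged sketch: the :class:`~torchrl.modules.LSTMModule` configuration below is
an assumption, not taken from the original text), the primer utility can be used to prepare an
environment for a recurrent policy:

>>> from torchrl.envs import GymEnv, TransformedEnv
>>> from torchrl.modules import LSTMModule
>>> from torchrl.modules.utils import get_primers_from_module
>>> lstm = LSTMModule(input_size=3, hidden_size=8, in_key="observation", out_key="embed")
>>> primer = get_primers_from_module(lstm)   # TensorDictPrimer for the recurrent states
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), primer)
>>> td = env.reset()                         # the reset tensordict now carries the hidden states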

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/probabilistic.py
@@ -128,7 +128,7 @@ def __init__(
if spec is not None and not isinstance(spec, TensorSpec):
raise TypeError("spec must be a TensorSpec subclass")
elif spec is not None and not isinstance(spec, Composite):
if len(self.out_keys) > 1:
if len(self.out_keys) - return_log_prob > 1:
raise RuntimeError(
f"got more than one out_key for the SafeModule: {self.out_keys},\nbut only one spec. "
"Consider using a Composite object or no spec at all."