diff --git a/test/test_cost.py b/test/test_cost.py
index 9b8c2e15c0e..fbba7ac4c5a 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -2960,6 +2960,34 @@ def __init__(self, actor_network, qvalue_network):
 
 
 class TestAdv:
+    @pytest.mark.parametrize(
+        "adv,kwargs",
+        [[GAE, {"lmbda": 0.95}], [TDEstimate, {}], [TDLambdaEstimate, {"lmbda": 0.95}]],
+    )
+    def test_dispatch(
+        self,
+        adv,
+        kwargs,
+    ):
+        value_net = TensorDictModule(
+            nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
+        )
+        module = adv(
+            gamma=0.98,
+            value_network=value_net,
+            differentiable=False,
+            **kwargs,
+        )
+        kwargs = {
+            "obs": torch.randn(1, 10, 3),
+            "reward": torch.randn(1, 10, 1, requires_grad=True),
+            "done": torch.zeros(1, 10, 1, dtype=torch.bool),
+            "next_obs": torch.randn(1, 10, 3),
+        }
+        advantage, value_target = module(**kwargs)
+        assert advantage.shape == torch.Size([1, 10, 1])
+        assert value_target.shape == torch.Size([1, 10, 1])
+
     @pytest.mark.parametrize(
         "adv,kwargs",
         [[GAE, {"lmbda": 0.95}], [TDEstimate, {}], [TDLambdaEstimate, {"lmbda": 0.95}]],
diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index 8fd344b1b71..d8d7c128857 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -4,23 +4,24 @@
 # LICENSE file in the root directory of this source tree.
 
 from functools import wraps
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
+from tensordict.nn import dispatch_kwargs
 from tensordict.tensordict import TensorDictBase
 from torch import nn, Tensor
 
 from torchrl.envs.utils import step_mdp
 from torchrl.modules import SafeModule
+
+from torchrl.objectives.utils import hold_out_net
 from torchrl.objectives.value.functional import (
+    td_advantage_estimate,
     td_lambda_advantage_estimate,
     vec_generalized_advantage_estimate,
     vec_td_lambda_advantage_estimate,
 )
-
-from ..utils import hold_out_net
-from .functional import td_advantage_estimate
-
 
 def _self_set_grad_enabled(fun):
     @wraps(fun)
@@ -41,8 +42,11 @@ class TDEstimate(nn.Module):
         before the TD is computed.
         differentiable (bool, optional): if True, gradients are propagated throught
             the computation of the value function. Default is :obj:`False`.
-        value_key (str, optional): key pointing to the state value. Default is
-            `"state_value"`.
+        advantage_key (str or tuple of str, optional): the key of the advantage entry.
+            Defaults to "advantage".
+        value_target_key (str or tuple of str, optional): the key of the value target entry.
+            Defaults to "value_target".
+ """ def __init__( @@ -51,7 +55,8 @@ def __init__( value_network: SafeModule, average_rewards: bool = False, differentiable: bool = False, - value_key: str = "state_value", + advantage_key: Union[str, Tuple] = "advantage", + value_target_key: Union[str, Tuple] = "value_target", ): super().__init__() try: @@ -63,7 +68,17 @@ def __init__( self.average_rewards = average_rewards self.differentiable = differentiable - self.value_key = value_key + self.value_key = value_network.out_keys[0] + + self.advantage_key = advantage_key + self.value_target_key = value_target_key + + self.in_keys = ( + value_network.in_keys + + ["reward", "done"] + + [("next", in_key) for in_key in value_network.in_keys] + ) + self.out_keys = [self.advantage_key, self.value_target_key] @property def is_functional(self): @@ -73,22 +88,64 @@ def is_functional(self): ) @_self_set_grad_enabled + @dispatch_kwargs def forward( self, tensordict: TensorDictBase, - *unused_args, - params: Optional[List[Tensor]] = None, - target_params: Optional[List[Tensor]] = None, + params: Optional[TensorDictBase] = None, + target_params: Optional[TensorDictBase] = None, ) -> TensorDictBase: - """Computes the GAE given the data in tensordict. + """Computes the TDEstimate given the data in tensordict. + + If a functional module is provided, a nested TensorDict containing the parameters + (and if relevant the target parameters) can be passed to the module. Args: tensordict (TensorDictBase): A TensorDict containing the data (an observation key, "action", "reward", "done" and "next" tensordict state as returned by the environment) necessary to compute the value estimates and the TDEstimate. + The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are + the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). + params (TensorDictBase, optional): A nested TensorDict containing the params + to be passed to the functional value network module. + target_params (TensorDictBase, optional): A nested TensorDict containing the + target params to be passed to the functional value network module. Returns: - An updated TensorDict with an "advantage" and a "value_error" keys + An updated TensorDict with an advantage and a value_error keys as defined in the constructor. + + Examples: + >>> from tensordict import TensorDict + >>> value_net = SafeModule( + ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + ... ) + >>> module = TDEstimate( + ... gamma=0.98, + ... value_network=value_net, + ... differentiable=False, + ... ) + >>> obs, next_obs = torch.randn(2, 1, 10, 3) + >>> reward = torch.randn(1, 10, 1) + >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward}, [1, 10]) + >>> _ = module(tensordict) + >>> assert "advantage" in tensordict.keys() + + The module supports non-tensordict (i.e. unpacked tensordict) inputs too: + + Examples: + >>> value_net = SafeModule( + ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + ... ) + >>> module = TDEstimate( + ... gamma=0.98, + ... value_network=value_net, + ... differentiable=False, + ... 
+            ... )
+            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+            >>> reward = torch.randn(1, 10, 1)
+            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs)
 
         """
         if tensordict.batch_dims < 1:
@@ -111,7 +168,7 @@ def forward(
                 "Expected params to be passed to advantage module but got none."
             )
         if params is not None:
-            kwargs["params"] = params
+            kwargs["params"] = params.detach()
         with hold_out_net(self.value_network):
             self.value_network(tensordict, **kwargs)
         value = tensordict.get(self.value_key)
@@ -146,10 +203,13 @@ class TDLambdaEstimate(nn.Module):
         before the TD is computed.
         differentiable (bool, optional): if True, gradients are propagated throught
             the computation of the value function. Default is :obj:`False`.
-        value_key (str, optional): key pointing to the state value. Default is
-            `"state_value"`.
         vectorized (bool, optional): whether to use the vectorized version of the
             lambda return. Default is `True`.
+        advantage_key (str or tuple of str, optional): the key of the advantage entry.
+            Defaults to "advantage".
+        value_target_key (str or tuple of str, optional): the key of the value target entry.
+            Defaults to "value_target".
+
     """
 
     def __init__(
@@ -159,8 +219,9 @@ def __init__(
         value_network: SafeModule,
         average_rewards: bool = False,
         differentiable: bool = False,
-        value_key: str = "state_value",
         vectorized: bool = True,
+        advantage_key: Union[str, Tuple] = "advantage",
+        value_target_key: Union[str, Tuple] = "value_target",
     ):
         super().__init__()
         try:
@@ -174,7 +235,17 @@ def __init__(
 
         self.average_rewards = average_rewards
         self.differentiable = differentiable
-        self.value_key = value_key
+        self.value_key = self.value_network.out_keys[0]
+
+        self.advantage_key = advantage_key
+        self.value_target_key = value_target_key
+
+        self.in_keys = (
+            value_network.in_keys
+            + ["reward", "done"]
+            + [("next", in_key) for in_key in value_network.in_keys]
+        )
+        self.out_keys = [self.advantage_key, self.value_target_key]
 
     @property
     def is_functional(self):
@@ -184,23 +255,66 @@ def is_functional(self):
         )
 
     @_self_set_grad_enabled
+    @dispatch_kwargs
     def forward(
         self,
         tensordict: TensorDictBase,
-        *unused_args,
         params: Optional[List[Tensor]] = None,
         target_params: Optional[List[Tensor]] = None,
     ) -> TensorDictBase:
-        """Computes the GAE given the data in tensordict.
+        """Computes the TDLambdaEstimate given the data in tensordict.
+
+        If a functional module is provided, a nested TensorDict containing the parameters
+        (and if relevant the target parameters) can be passed to the module.
 
         Args:
             tensordict (TensorDictBase): A TensorDict containing the data
                 (an observation key, "action", "reward", "done" and "next" tensordict state
-                as returned by the environment) necessary to compute the value
-                estimates and the TDLambdaEstimate.
+                as returned by the environment) necessary to compute the value estimates and the TDLambdaEstimate.
+                The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`*B` are
+                the batch dimensions, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
+            params (TensorDictBase, optional): A nested TensorDict containing the params
+                to be passed to the functional value network module.
+            target_params (TensorDictBase, optional): A nested TensorDict containing the
+                target params to be passed to the functional value network module.
 
         Returns:
-            An updated TensorDict with an "advantage" and a "value_error" keys
+            An updated TensorDict with the advantage and value_target entries, as defined in the constructor.
+
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> value_net = SafeModule(
+            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
+            ... )
+            >>> module = TDLambdaEstimate(
+            ...     gamma=0.98,
+            ...     lmbda=0.94,
+            ...     value_network=value_net,
+            ...     differentiable=False,
+            ... )
+            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+            >>> reward = torch.randn(1, 10, 1)
+            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward}, [1, 10])
+            >>> _ = module(tensordict)
+            >>> assert "advantage" in tensordict.keys()
+
+        The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
+
+        Examples:
+            >>> value_net = SafeModule(
+            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
+            ... )
+            >>> module = TDLambdaEstimate(
+            ...     gamma=0.98,
+            ...     lmbda=0.94,
+            ...     value_network=value_net,
+            ...     differentiable=False,
+            ... )
+            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+            >>> reward = torch.randn(1, 10, 1)
+            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs)
 
         """
         if tensordict.batch_dims < 1:
@@ -252,8 +366,8 @@ def forward(
                 gamma, lmbda, value, next_value, reward, done
             )
 
-        tensordict.set("advantage", adv)
-        tensordict.set("value_target", adv + value)
+        tensordict.set(self.advantage_key, adv)
+        tensordict.set(self.value_target_key, adv + value)
 
         return tensordict
 
@@ -271,6 +385,10 @@ class GAE(nn.Module):
         Default is :obj:`False`.
         differentiable (bool, optional): if True, gradients are propagated throught
             the computation of the value function. Default is :obj:`False`.
+        advantage_key (str or tuple of str, optional): the key of the advantage entry.
+            Defaults to "advantage".
+        value_target_key (str or tuple of str, optional): the key of the value target entry.
+            Defaults to "value_target".
 
     GAE will return an :obj:`"advantage"` entry containing the advange value. It will also return a
     :obj:`"value_target"` entry with the return value that is to be used
@@ -288,6 +406,8 @@ def __init__(
         value_network: SafeModule,
         average_gae: bool = False,
         differentiable: bool = False,
+        advantage_key: Union[str, Tuple] = "advantage",
+        value_target_key: Union[str, Tuple] = "value_target",
     ):
         super().__init__()
         try:
@@ -297,10 +417,21 @@ def __init__(
         self.register_buffer("gamma", torch.tensor(gamma, device=device))
         self.register_buffer("lmbda", torch.tensor(lmbda, device=device))
         self.value_network = value_network
+        self.value_key = self.value_network.out_keys[0]
 
         self.average_gae = average_gae
         self.differentiable = differentiable
 
+        self.advantage_key = advantage_key
+        self.value_target_key = value_target_key
+
+        self.in_keys = (
+            value_network.in_keys
+            + ["reward", "done"]
+            + [("next", in_key) for in_key in value_network.in_keys]
+        )
+        self.out_keys = [self.advantage_key, self.value_target_key]
+
     @property
     def is_functional(self):
         return (
@@ -309,6 +440,7 @@ def is_functional(self):
         )
 
     @_self_set_grad_enabled
+    @dispatch_kwargs
     def forward(
         self,
         tensordict: TensorDictBase,
@@ -318,13 +450,57 @@ def forward(
     ) -> TensorDictBase:
         """Computes the GAE given the data in tensordict.
 
+        If a functional module is provided, a nested TensorDict containing the parameters
+        (and if relevant the target parameters) can be passed to the module.
+
         Args:
             tensordict (TensorDictBase): A TensorDict containing the data
                 (an observation key, "action", "reward", "done" and "next" tensordict state
                 as returned by the environment) necessary to compute the value
                 estimates and the GAE.
+                The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`*B` are
+                the batch dimensions, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
+            params (TensorDictBase, optional): A nested TensorDict containing the params
+                to be passed to the functional value network module.
+            target_params (TensorDictBase, optional): A nested TensorDict containing the
+                target params to be passed to the functional value network module.
 
         Returns:
-            An updated TensorDict with an "advantage" and a "value_error" keys
+            An updated TensorDict with the advantage and value_target entries, as defined in the constructor.
+
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> value_net = SafeModule(
+            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
+            ... )
+            >>> module = GAE(
+            ...     gamma=0.98,
+            ...     lmbda=0.94,
+            ...     value_network=value_net,
+            ...     differentiable=False,
+            ... )
+            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+            >>> reward = torch.randn(1, 10, 1)
+            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward}, [1, 10])
+            >>> _ = module(tensordict)
+            >>> assert "advantage" in tensordict.keys()
+
+        The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
+
+        Examples:
+            >>> value_net = SafeModule(
+            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
+            ... )
+            >>> module = GAE(
+            ...     gamma=0.98,
+            ...     lmbda=0.94,
+            ...     value_network=value_net,
+            ...     differentiable=False,
+            ... )
+            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+            >>> reward = torch.randn(1, 10, 1)
+            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs)
 
         """
         if tensordict.batch_dims < 1:
@@ -346,7 +522,7 @@ def forward(
             # value net params
             self.value_network(tensordict, **kwargs)
 
-        value = tensordict.get("state_value")
+        value = tensordict.get(self.value_key)
 
         step_td = step_mdp(tensordict)
         if target_params is not None:
@@ -358,7 +534,7 @@ def forward(
             # we may still need to pass gradient, but we don't want to assign grads to
            # value net params
             self.value_network(step_td, **kwargs)
-        next_value = step_td.get("state_value")
+        next_value = step_td.get(self.value_key)
         done = tensordict.get("done")
         adv, value_target = vec_generalized_advantage_estimate(
             gamma, lmbda, value, next_value, reward, done
@@ -368,7 +544,7 @@ def forward(
             adv = adv - adv.mean()
             adv = adv / adv.std().clamp_min(1e-4)
 
-        tensordict.set("advantage", adv)
-        tensordict.set("value_target", value_target)
+        tensordict.set(self.advantage_key, adv)
+        tensordict.set(self.value_target_key, value_target)
 
         return tensordict
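
For reviewers, a minimal end-to-end sketch of the behavior this patch introduces (not part of the diff itself). It assumes this branch of torchrl plus a tensordict build exposing `dispatch_kwargs`; the module choice, key names and shapes are taken from the new test and docstrings above.

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torch import nn

    from torchrl.objectives.value.advantages import GAE

    # A value network reading "obs" and writing "state_value"; the advantage
    # module now derives its in_keys/out_keys from it instead of a value_key arg.
    value_net = TensorDictModule(
        nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
    )
    module = GAE(gamma=0.98, lmbda=0.94, value_network=value_net)

    obs, next_obs = torch.randn(2, 1, 10, 3)  # [*B, T, F] = [1, 10, 3]
    reward = torch.randn(1, 10, 1)
    done = torch.zeros(1, 10, 1, dtype=torch.bool)

    # 1) TensorDict input: results are written in-place under the configurable
    #    advantage_key / value_target_key (defaults: "advantage", "value_target").
    td = TensorDict(
        {"obs": obs, "next": {"obs": next_obs}, "reward": reward, "done": done},
        [1, 10],
    )
    module(td)
    assert "advantage" in td.keys() and "value_target" in td.keys()

    # 2) Keyword input: dispatch_kwargs matches keyword arguments against
    #    in_keys (the nested ("next", "obs") entry becomes "next_obs") and
    #    returns the out_keys as a tuple of tensors.
    advantage, value_target = module(
        obs=obs, reward=reward, done=done, next_obs=next_obs
    )
    assert advantage.shape == torch.Size([1, 10, 1])

The same two calling conventions apply to TDEstimate and TDLambdaEstimate, since all three modules gain the same `in_keys`/`out_keys` attributes and the `dispatch_kwargs` decorator in this patch.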