From 925b601d8cf2ade62900cf47a51a1391df7653e5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 14 Dec 2022 15:08:13 +0000 Subject: [PATCH 1/6] init --- examples/ppo/ppo.py | 3 +- test/test_cost.py | 107 ++++++++-- torchrl/objectives/ppo.py | 48 ++--- torchrl/objectives/value/advantages.py | 266 +++++++++++++------------ 4 files changed, 244 insertions(+), 180 deletions(-) diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py index 846ec759bbd..70505ae0f47 100644 --- a/examples/ppo/ppo.py +++ b/examples/ppo/ppo.py @@ -175,8 +175,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg.gamma, cfg.lmbda, value_network=critic_model, - average_rewards=True, - gradient_mode=False, + average_gae=True, ) trainer.register_op( "process_optim_batch", diff --git a/test/test_cost.py b/test/test_cost.py index c22efec6aff..e249e0efce1 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import re from copy import deepcopy _has_functorch = True @@ -21,7 +22,7 @@ import torch from _utils_internal import dtype_fixture, get_available_devices # noqa from mocking_classes import ContinuousActionConvMockEnv -from tensordict.nn import get_functional +from tensordict.nn import get_functional, TensorDictModule # from torchrl.data.postprocs.utils import expand_as_right from tensordict.tensordict import assert_allclose_td, TensorDict @@ -1597,23 +1598,25 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage): value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( - gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode + gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) elif advantage == "td": advantage = TDEstimate( - gamma=0.9, value_network=value, gradient_mode=gradient_mode + gamma=0.9, value_network=value, differentiable=gradient_mode ) elif advantage == "td_lambda": advantage = TDLambdaEstimate( - gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode + gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) else: raise NotImplementedError - loss_fn = loss_class( - actor, value, advantage_module=advantage, gamma=0.9, loss_critic_type="l2" - ) - + loss_fn = loss_class(actor, value, gamma=0.9, loss_critic_type="l2") + with pytest.raises( + KeyError, match=re.escape('key "advantage" not found in TensorDict with') + ): + _ = loss_fn(td) + advantage(td) loss = loss_fn(td) loss_critic = loss["loss_critic"] loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) @@ -1659,20 +1662,17 @@ def test_ppo_shared(self, loss_class, device, advantage): gamma=0.9, lmbda=0.9, value_network=value, - gradient_mode=False, ) elif advantage == "td": advantage = TDEstimate( gamma=0.9, value_network=value, - gradient_mode=False, ) elif advantage == "td_lambda": advantage = TDLambdaEstimate( gamma=0.9, lmbda=0.9, value_network=value, - gradient_mode=False, ) else: raise NotImplementedError @@ -1681,9 +1681,13 @@ def test_ppo_shared(self, loss_class, device, advantage): value, gamma=0.9, loss_critic_type="l2", - advantage_module=advantage, ) + with pytest.raises( + KeyError, match=re.escape('key "advantage" not found in TensorDict with') + ): + _ = loss_fn(td) + advantage(td) loss = loss_fn(td) loss_critic = loss["loss_critic"] loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) @@ -1731,29 +1735,33 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): value = 
self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( - gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode + gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) elif advantage == "td": advantage = TDEstimate( - gamma=0.9, value_network=value, gradient_mode=gradient_mode + gamma=0.9, value_network=value, differentiable=gradient_mode ) elif advantage == "td_lambda": advantage = TDLambdaEstimate( - gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode + gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) else: raise NotImplementedError - loss_fn = loss_class( - actor, value, advantage_module=advantage, gamma=0.9, loss_critic_type="l2" - ) + loss_fn = loss_class(actor, value, gamma=0.9, loss_critic_type="l2") floss_fn, params, buffers = make_functional_with_buffers(loss_fn) # fill params with zero for p in params: p.data.zero_() # assert len(list(floss_fn.parameters())) == 0 + with pytest.raises( + KeyError, match=re.escape('key "advantage" not found in TensorDict with') + ): + _ = floss_fn(params, buffers, td) + advantage(td) loss = floss_fn(params, buffers, td) + loss_critic = loss["loss_critic"] loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) loss_critic.backward(retain_graph=True) @@ -2948,6 +2956,69 @@ def __init__(self, actor_network, qvalue_network): break +class TestAdv: + @pytest.mark.parametrize( + "adv,kwargs", + [[GAE, {"lmbda": 0.95}], [TDEstimate, {}], [TDLambdaEstimate, {"lmbda": 0.95}]], + ) + def test_diff_reward( + self, + adv, + kwargs, + ): + value_net = TensorDictModule( + nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + ) + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=True, + **kwargs, + ) + td = TensorDict( + { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next": {"obs": torch.randn(1, 10, 3)}, + }, + [1, 10], + ) + td = module(td.clone(False)) + # check that the advantage can't backprop to the value params + td["advantage"].sum().backward() + for p in value_net.parameters(): + assert p.grad is None or (p.grad == 0).all() + # check that rewards have a grad + assert td["reward"].grad.norm() > 0 + + @pytest.mark.parametrize( + "adv,kwargs", + [[GAE, {"lmbda": 0.95}], [TDEstimate, {}], [TDLambdaEstimate, {"lmbda": 0.95}]], + ) + def test_non_differentiable(self, adv, kwargs): + value_net = TensorDictModule( + nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + ) + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=False, + **kwargs, + ) + td = TensorDict( + { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next": {"obs": torch.randn(1, 10, 3)}, + }, + [1, 10], + ) + td = module(td.clone(False)) + assert td["advantage"].is_leaf + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 903e54a7810..26b73e70612 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. 
import math -from typing import Callable, Optional, Tuple +from typing import Tuple import torch from tensordict.tensordict import TensorDict, TensorDictBase @@ -55,14 +55,13 @@ def __init__( actor: SafeProbabilisticSequential, critic: SafeModule, advantage_key: str = "advantage", - advantage_diff_key: str = "value_error", + value_target_key: str = "value_target", entropy_bonus: bool = True, samples_mc_entropy: int = 1, entropy_coef: float = 0.01, critic_coef: float = 1.0, gamma: float = 0.99, loss_critic_type: str = "smooth_l1", - advantage_module: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, ): super().__init__() self.convert_to_functional( @@ -72,7 +71,7 @@ def __init__( # params of critic must be refs to actor if they're shared self.convert_to_functional(critic, "critic", compare_against=self.actor_params) self.advantage_key = advantage_key - self.advantage_diff_key = advantage_diff_key + self.value_target_key = value_target_key self.samples_mc_entropy = samples_mc_entropy self.entropy_bonus = entropy_bonus and entropy_coef self.register_buffer( @@ -83,9 +82,6 @@ def __init__( ) self.register_buffer("gamma", torch.tensor(gamma, device=self.device)) self.loss_critic_type = loss_critic_type - self.advantage_module = advantage_module - if self.advantage_module is not None: - self.advantage_module = advantage_module.to(self.device) def reset(self) -> None: pass @@ -119,35 +115,29 @@ def _log_weight( return log_weight, dist def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: - if self.advantage_diff_key in tensordict.keys(): - advantage_diff = tensordict.get(self.advantage_diff_key) - if not advantage_diff.requires_grad: - raise RuntimeError( - "value_target retrieved from tensordict does not requires grad." - ) - loss_value = distance_loss( - advantage_diff, - torch.zeros_like(advantage_diff), - loss_function=self.loss_critic_type, - ) - else: - advantage = tensordict.get(self.advantage_key) + try: + target_return = tensordict.get(self.value_target_key) tensordict_select = tensordict.select(*self.critic.in_keys) - value = self.critic( + state_value = self.critic( tensordict_select, params=self.critic_params, ).get("state_value") - value_target = advantage + value.detach() loss_value = distance_loss( - value, value_target, loss_function=self.loss_critic_type + target_return, + state_value, + loss_function=self.loss_critic_type, + ) + except KeyError: + raise KeyError( + f"the key {self.value_target_key} was not found in the input tensordict. " + f"Make sure you provided the right key and the value_target (i.e. the target " + f"return) has been retrieved accordingly. Advantage classes such as GAE, " + f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that " + f"can be used for the value loss." 
) return self.critic_coef * loss_value def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - if self.advantage_module is not None: - tensordict = self.advantage_module( - tensordict, - ) tensordict = tensordict.clone() advantage = tensordict.get(self.advantage_key) log_weight, dist = self._log_weight(tensordict) @@ -226,8 +216,6 @@ def _clip_bounds(self): ) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - if self.advantage_module is not None: - tensordict = self.advantage_module(tensordict) tensordict = tensordict.clone() advantage = tensordict.get(self.advantage_key) log_weight, dist = self._log_weight(tensordict) @@ -349,8 +337,6 @@ def __init__( self.samples_mc_kl = samples_mc_kl def forward(self, tensordict: TensorDictBase) -> TensorDict: - if self.advantage_module is not None: - tensordict = self.advantage_module(tensordict) tensordict = tensordict.clone() advantage = tensordict.get(self.advantage_key) log_weight, dist = self._log_weight(tensordict) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index a4ecd994a54..32d06f2cdff 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from functools import wraps from typing import List, Optional, Union import torch @@ -21,6 +22,15 @@ from .functional import td_advantage_estimate +def _self_set_grad_enabled(fun): + @wraps(fun) + def new_fun(self, *args, **kwargs): + with torch.set_grad_enabled(self.differentiable): + return fun(self, *args, **kwargs) + + return new_fun + + class TDEstimate(nn.Module): """Temporal Difference estimate of advantage function. @@ -29,7 +39,7 @@ class TDEstimate(nn.Module): value_network (SafeModule): value operator used to retrieve the value estimates. average_rewards (bool, optional): if True, rewards will be standardized before the TD is computed. - gradient_mode (bool, optional): if True, gradients are propagated throught + differentiable (bool, optional): if True, gradients are propagated throught the computation of the value function. Default is :obj:`False`. value_key (str, optional): key pointing to the state value. Default is `"state_value"`. 
@@ -40,7 +50,7 @@ def __init__( gamma: Union[float, torch.Tensor], value_network: SafeModule, average_rewards: bool = False, - gradient_mode: bool = False, + differentiable: bool = False, value_key: str = "state_value", ): super().__init__() @@ -48,7 +58,7 @@ def __init__( self.value_network = value_network self.average_rewards = average_rewards - self.gradient_mode = gradient_mode + self.differentiable = differentiable self.value_key = value_key @property @@ -58,6 +68,7 @@ def is_functional(self): and self.value_network.__dict__["_is_stateless"] ) + @_self_set_grad_enabled def forward( self, tensordict: TensorDictBase, @@ -75,49 +86,47 @@ def forward( An updated TensorDict with an "advantage" and a "value_error" keys """ - with torch.set_grad_enabled(self.gradient_mode): - if tensordict.batch_dims < 1: - raise RuntimeError( - "Expected input tensordict to have at least one dimensions, got" - f"tensordict.batch_size = {tensordict.batch_size}" - ) - reward = tensordict.get("reward") - if self.average_rewards: - reward = reward - reward.mean() - reward = reward / reward.std().clamp_min(1e-4) - tensordict.set( - "reward", reward - ) # we must update the rewards if they are used later in the code - - gamma = self.gamma - kwargs = {} - if self.is_functional and params is None: - raise RuntimeError( - "Expected params to be passed to advantage module but got none." - ) - if params is not None: - kwargs["params"] = params + if tensordict.batch_dims < 1: + raise RuntimeError( + "Expected input tensordict to have at least one dimensions, got" + f"tensordict.batch_size = {tensordict.batch_size}" + ) + reward = tensordict.get("reward") + if self.average_rewards: + reward = reward - reward.mean() + reward = reward / reward.std().clamp_min(1e-4) + tensordict.set( + "reward", reward + ) # we must update the rewards if they are used later in the code + + gamma = self.gamma + kwargs = {} + if self.is_functional and params is None: + raise RuntimeError( + "Expected params to be passed to advantage module but got none." 
+ ) + if params is not None: + kwargs["params"] = params + with hold_out_net(self.value_network): self.value_network(tensordict, **kwargs) value = tensordict.get(self.value_key) + # we may still need to pass gradient, but we don't want to assign grads to + # value net params + step_td = step_mdp(tensordict) + if target_params is not None: + # we assume that target parameters are not differentiable + kwargs["params"] = target_params + elif "params" in kwargs: + kwargs["params"] = [param.detach() for param in kwargs["params"]] with hold_out_net(self.value_network): - # we may still need to pass gradient, but we don't want to assign grads to - # value net params - step_td = step_mdp(tensordict) - if target_params is not None: - # we assume that target parameters are not differentiable - kwargs["params"] = target_params - elif "params" in kwargs: - kwargs["params"] = [param.detach() for param in kwargs["params"]] self.value_network(step_td, **kwargs) next_value = step_td.get(self.value_key) done = tensordict.get("done") - with torch.set_grad_enabled(self.gradient_mode): - adv = td_advantage_estimate(gamma, value, next_value, reward, done) - tensordict.set("advantage", adv.detach()) - if self.gradient_mode: - tensordict.set("value_error", adv) + adv = td_advantage_estimate(gamma, value, next_value, reward, done) + tensordict.set("advantage", adv) + tensordict.set("value_target", adv + value) return tensordict @@ -130,8 +139,8 @@ class TDLambdaEstimate(nn.Module): value_network (SafeModule): value operator used to retrieve the value estimates. average_rewards (bool, optional): if True, rewards will be standardized before the TD is computed. - gradient_mode (bool, optional): if True, gradients are propagated throught - the computation of the value function. Default is `False`. + differentiable (bool, optional): if True, gradients are propagated throught + the computation of the value function. Default is :obj:`False`. value_key (str, optional): key pointing to the state value. Default is `"state_value"`. vectorized (bool, optional): whether to use the vectorized version of the @@ -144,7 +153,7 @@ def __init__( lmbda: Union[float, torch.Tensor], value_network: SafeModule, average_rewards: bool = False, - gradient_mode: bool = False, + differentiable: bool = False, value_key: str = "state_value", vectorized: bool = True, ): @@ -155,7 +164,7 @@ def __init__( self.vectorized = vectorized self.average_rewards = average_rewards - self.gradient_mode = gradient_mode + self.differentiable = differentiable self.value_key = value_key @property @@ -165,6 +174,7 @@ def is_functional(self): and self.value_network.__dict__["_is_stateless"] ) + @_self_set_grad_enabled def forward( self, tensordict: TensorDictBase, @@ -182,59 +192,57 @@ def forward( An updated TensorDict with an "advantage" and a "value_error" keys """ - with torch.set_grad_enabled(self.gradient_mode): - if tensordict.batch_dims < 1: - raise RuntimeError( - "Expected input tensordict to have at least one dimensions, got" - f"tensordict.batch_size = {tensordict.batch_size}" - ) - reward = tensordict.get("reward") - if self.average_rewards: - reward = reward - reward.mean() - reward = reward / reward.std().clamp_min(1e-4) - tensordict.set( - "reward", reward - ) # we must update the rewards if they are used later in the code - - gamma = self.gamma - lmbda = self.lmbda - - kwargs = {} - if self.is_functional and params is None: - raise RuntimeError( - "Expected params to be passed to advantage module but got none." 
- ) - if params is not None: - kwargs["params"] = params + if tensordict.batch_dims < 1: + raise RuntimeError( + "Expected input tensordict to have at least one dimensions, got" + f"tensordict.batch_size = {tensordict.batch_size}" + ) + reward = tensordict.get("reward") + if self.average_rewards: + reward = reward - reward.mean() + reward = reward / reward.std().clamp_min(1e-4) + tensordict.set( + "reward", reward + ) # we must update the rewards if they are used later in the code + + gamma = self.gamma + lmbda = self.lmbda + + kwargs = {} + if self.is_functional and params is None: + raise RuntimeError( + "Expected params to be passed to advantage module but got none." + ) + if params is not None: + kwargs["params"] = params + with hold_out_net(self.value_network): self.value_network(tensordict, **kwargs) value = tensordict.get(self.value_key) + step_td = step_mdp(tensordict) + if target_params is not None: + # we assume that target parameters are not differentiable + kwargs["params"] = target_params + elif "params" in kwargs: + kwargs["params"] = [param.detach() for param in kwargs["params"]] with hold_out_net(self.value_network): # we may still need to pass gradient, but we don't want to assign grads to # value net params - step_td = step_mdp(tensordict) - if target_params is not None: - # we assume that target parameters are not differentiable - kwargs["params"] = target_params - elif "params" in kwargs: - kwargs["params"] = [param.detach() for param in kwargs["params"]] self.value_network(step_td, **kwargs) next_value = step_td.get(self.value_key) done = tensordict.get("done") - with torch.set_grad_enabled(self.gradient_mode): - if self.vectorized: - adv = vec_td_lambda_advantage_estimate( - gamma, lmbda, value, next_value, reward, done - ) - else: - adv = td_lambda_advantage_estimate( - gamma, lmbda, value, next_value, reward, done - ) - - tensordict.set("advantage", adv.detach()) - if self.gradient_mode: - tensordict.set("value_error", adv) + if self.vectorized: + adv = vec_td_lambda_advantage_estimate( + gamma, lmbda, value, next_value, reward, done + ) + else: + adv = td_lambda_advantage_estimate( + gamma, lmbda, value, next_value, reward, done + ) + + tensordict.set("advantage", adv) + tensordict.set("value_target", adv + value) return tensordict @@ -248,9 +256,10 @@ class GAE(nn.Module): gamma (scalar): exponential mean discount. lmbda (scalar): trajectory discount. value_network (SafeModule): value operator used to retrieve the value estimates. - average_rewards (bool): if True, rewards will be standardized before the GAE is computed. - gradient_mode (bool): if True, gradients are propagated throught the computation of the value function. - Default is `False`. + average_gae (bool): if True, the resulting GAE values will be standardized. + Default is :obj:`False`. + differentiable (bool, optional): if True, gradients are propagated throught + the computation of the value function. Default is :obj:`False`. GAE will return an :obj:`"advantage"` entry containing the advange value. 
It will also return a :obj:`"value_target"` entry with the return value that is to be used @@ -266,16 +275,16 @@ def __init__( gamma: Union[float, torch.Tensor], lmbda: float, value_network: SafeModule, - average_rewards: bool = False, - gradient_mode: bool = False, + average_gae: bool = False, + differentiable: bool = False, ): super().__init__() self.register_buffer("gamma", torch.tensor(gamma)) self.register_buffer("lmbda", torch.tensor(lmbda)) self.value_network = value_network - self.average_rewards = average_rewards - self.gradient_mode = gradient_mode + self.average_gae = average_gae + self.differentiable = differentiable @property def is_functional(self): @@ -284,6 +293,7 @@ def is_functional(self): and self.value_network.__dict__["_is_stateless"] ) + @_self_set_grad_enabled def forward( self, tensordict: TensorDictBase, @@ -301,50 +311,48 @@ def forward( An updated TensorDict with an "advantage" and a "value_error" keys """ - with torch.set_grad_enabled(self.gradient_mode): - if tensordict.batch_dims < 1: - raise RuntimeError( - "Expected input tensordict to have at least one dimensions, got" - f"tensordict.batch_size = {tensordict.batch_size}" - ) - reward = tensordict.get("reward") - if self.average_rewards: - reward = reward - reward.mean() - reward = reward / reward.std().clamp_min(1e-4) - tensordict.set( - "reward", reward - ) # we must update the rewards if they are used later in the code - - gamma, lmbda = self.gamma, self.lmbda - kwargs = {} - if self.is_functional and params is None: - raise RuntimeError( - "Expected params to be passed to advantage module but got none." - ) - if params is not None: - kwargs["params"] = params + if tensordict.batch_dims < 1: + raise RuntimeError( + "Expected input tensordict to have at least one dimensions, got" + f"tensordict.batch_size = {tensordict.batch_size}" + ) + reward = tensordict.get("reward") + gamma, lmbda = self.gamma, self.lmbda + kwargs = {} + if self.is_functional and params is None: + raise RuntimeError( + "Expected params to be passed to advantage module but got none." 
+ ) + if params is not None: + kwargs["params"] = params + with hold_out_net(self.value_network): + # we may still need to pass gradient, but we don't want to assign grads to + # value net params self.value_network(tensordict, **kwargs) - value = tensordict.get("state_value") + value = tensordict.get("state_value") + + step_td = step_mdp(tensordict) + if target_params is not None: + # we assume that target parameters are not differentiable + kwargs["params"] = target_params + elif "params" in kwargs: + kwargs["params"] = [param.detach() for param in kwargs["params"]] with hold_out_net(self.value_network): # we may still need to pass gradient, but we don't want to assign grads to # value net params - step_td = step_mdp(tensordict) - if target_params is not None: - # we assume that target parameters are not differentiable - kwargs["params"] = target_params - elif "params" in kwargs: - kwargs["params"] = [param.detach() for param in kwargs["params"]] self.value_network(step_td, **kwargs) - next_value = step_td.get("state_value") - done = tensordict.get("done") - adv, value_target = vec_generalized_advantage_estimate( - gamma, lmbda, value, next_value, reward, done - ) + next_value = step_td.get("state_value") + done = tensordict.get("done") + adv, value_target = vec_generalized_advantage_estimate( + gamma, lmbda, value, next_value, reward, done + ) + + if self.average_gae: + adv = adv - adv.mean() + adv = adv / adv.std().clamp_min(1e-4) - tensordict.set("advantage", adv.detach()) + tensordict.set("advantage", adv) tensordict.set("value_target", value_target) - if self.gradient_mode: - tensordict.set("value_error", value_target - value) return tensordict From a946fc6f08f5ca3b7374cc66200cdfc22cd819fc Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 14 Dec 2022 17:21:09 +0000 Subject: [PATCH 2/6] amend --- examples/a2c/a2c.py | 24 +++++----- examples/a2c/config.yaml | 1 - examples/ppo/config.yaml | 1 - examples/ppo/ppo.py | 31 ++++++------- test/test_cost.py | 57 ++++++++++++----------- test/test_helpers.py | 4 -- torchrl/objectives/a2c.py | 44 +++++++----------- torchrl/objectives/reinforce.py | 63 ++++++++------------------ torchrl/objectives/value/advantages.py | 22 +++++---- torchrl/trainers/helpers/losses.py | 27 ----------- 10 files changed, 104 insertions(+), 170 deletions(-) diff --git a/examples/a2c/a2c.py b/examples/a2c/a2c.py index f29d62e3444..aa5253f65dd 100644 --- a/examples/a2c/a2c.py +++ b/examples/a2c/a2c.py @@ -143,19 +143,17 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg=cfg, ) - if not cfg.advantage_in_loss: - critic_model = model.get_value_operator() - advantage = TDEstimate( - cfg.gamma, - value_network=critic_model, - average_rewards=True, - gradient_mode=False, - ) - advantage = advantage.to(device) - trainer.register_op( - "process_optim_batch", - advantage, - ) + critic_model = model.get_value_operator() + advantage = TDEstimate( + cfg.gamma, + value_network=critic_model, + average_rewards=True, + gradient_mode=False, + ) + trainer.register_op( + "process_optim_batch", + advantage, + ) final_seed = collector.set_seed(cfg.seed) print(f"init seed: {cfg.seed}, final seed: {final_seed}") diff --git a/examples/a2c/config.yaml b/examples/a2c/config.yaml index f3c05c95c60..780b22ccb44 100644 --- a/examples/a2c/config.yaml +++ b/examples/a2c/config.yaml @@ -26,7 +26,6 @@ gamma: 0.99 entropy_coef: 0.01 # Entropy factor for the A2C loss critic_coef: 0.25 # Critic factor for the A2C loss critic_loss_function: l2 # loss function for the value network. 
Either one of l1, l2 or smooth_l1 (default). -advantage_in_loss: False # if True, the advantage is computed on the sub-batch # Trainer optim_steps_per_batch: 1 # Number of optimization steps in between two collection of data. diff --git a/examples/ppo/config.yaml b/examples/ppo/config.yaml index ea93e8f51d4..ca0bd6a5684 100644 --- a/examples/ppo/config.yaml +++ b/examples/ppo/config.yaml @@ -28,4 +28,3 @@ loss_function: smooth_l1 batch_transform: 1 entropy_coef: 0.1 default_policy_scale: 1.0 -advantage_in_loss: 1 diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py index 70505ae0f47..b19a178fe0e 100644 --- a/examples/ppo/ppo.py +++ b/examples/ppo/ppo.py @@ -169,22 +169,21 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.loss == "kl": trainer.register_op("pre_optim_steps", loss_module.reset) - if not cfg.advantage_in_loss: - critic_model = model.get_value_operator() - advantage = GAE( - cfg.gamma, - cfg.lmbda, - value_network=critic_model, - average_gae=True, - ) - trainer.register_op( - "process_optim_batch", - advantage, - ) - trainer._process_optim_batch_ops = [ - trainer._process_optim_batch_ops[-1], - *trainer._process_optim_batch_ops[:-1], - ] + critic_model = model.get_value_operator() + advantage = GAE( + cfg.gamma, + cfg.lmbda, + value_network=critic_model, + average_gae=True, + ) + trainer.register_op( + "process_optim_batch", + advantage, + ) + trainer._process_optim_batch_ops = [ + trainer._process_optim_batch_ops[-1], + *trainer._process_optim_batch_ops[:-1], + ] final_seed = collector.set_seed(cfg.seed) print(f"init seed: {cfg.seed}, final seed: {final_seed}") diff --git a/test/test_cost.py b/test/test_cost.py index e249e0efce1..9b8c2e15c0e 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -1868,22 +1868,20 @@ def test_a2c(self, device, gradient_mode, advantage): value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( - gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode + gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) elif advantage == "td": advantage = TDEstimate( - gamma=0.9, value_network=value, gradient_mode=gradient_mode + gamma=0.9, value_network=value, differentiable=gradient_mode ) elif advantage == "td_lambda": advantage = TDLambdaEstimate( - gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode + gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) else: raise NotImplementedError - loss_fn = A2CLoss( - actor, value, advantage_module=advantage, gamma=0.9, loss_critic_type="l2" - ) + loss_fn = A2CLoss(actor, value, gamma=0.9, loss_critic_type="l2") # Check error is raised when actions require grads td["action"].requires_grad = True @@ -1894,16 +1892,13 @@ def test_a2c(self, device, gradient_mode, advantage): _ = loss_fn._log_probs(td) td["action"].requires_grad = False - # Check error is raised when advantage_diff_key present and does not required grad - td[loss_fn.advantage_diff_key] = torch.randn_like(td["reward"]) + td = td.exclude(loss_fn.value_target_key) + with pytest.raises( - RuntimeError, - match="value_target retrieved from tensordict does not require grad.", + KeyError, match=re.escape('key "advantage" not found in TensorDict with') ): - loss = loss_fn.loss_critic(td) - td = td.exclude(loss_fn.advantage_diff_key) - assert loss_fn.advantage_diff_key not in td.keys() - + _ = loss_fn(td) + advantage(td) loss = loss_fn(td) loss_critic = loss["loss_critic"] loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) @@ 
-1947,25 +1942,28 @@ def test_a2c_diff(self, device, gradient_mode, advantage): value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( - gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode + gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) elif advantage == "td": advantage = TDEstimate( - gamma=0.9, value_network=value, gradient_mode=gradient_mode + gamma=0.9, value_network=value, differentiable=gradient_mode ) elif advantage == "td_lambda": advantage = TDLambdaEstimate( - gamma=0.9, lmbda=0.9, value_network=value, gradient_mode=gradient_mode + gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) else: raise NotImplementedError - loss_fn = A2CLoss( - actor, value, advantage_module=advantage, gamma=0.9, loss_critic_type="l2" - ) + loss_fn = A2CLoss(actor, value, gamma=0.9, loss_critic_type="l2") floss_fn, params, buffers = make_functional_with_buffers(loss_fn) + with pytest.raises( + KeyError, match=re.escape('key "advantage" not found in TensorDict with') + ): + _ = floss_fn(params, buffers, td) + advantage(td) loss = floss_fn(params, buffers, td) loss_critic = loss["loss_critic"] loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) @@ -2015,24 +2013,24 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value): spec=NdUnboundedContinuousTensorSpec(n_act), ) if advantage == "gae": - advantage_module = GAE( + advantage = GAE( gamma=gamma, lmbda=0.9, value_network=get_functional(value_net), - gradient_mode=gradient_mode, + differentiable=gradient_mode, ) elif advantage == "td": - advantage_module = TDEstimate( + advantage = TDEstimate( gamma=gamma, value_network=get_functional(value_net), - gradient_mode=gradient_mode, + differentiable=gradient_mode, ) elif advantage == "td_lambda": - advantage_module = TDLambdaEstimate( + advantage = TDLambdaEstimate( gamma=0.9, lmbda=0.9, value_network=get_functional(value_net), - gradient_mode=gradient_mode, + differentiable=gradient_mode, ) else: raise NotImplementedError @@ -2041,7 +2039,6 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value): actor_net, critic=value_net, gamma=gamma, - advantage_module=advantage_module, delay_value=delay_value, ) @@ -2056,6 +2053,12 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value): [batch], ) + with pytest.raises( + KeyError, match=re.escape('key "advantage" not found in TensorDict with') + ): + _ = loss_fn(td) + params = TensorDict(value_net.state_dict(), []).unflatten_keys(".") + advantage(td, params=params) loss_td = loss_fn(td) autograd.grad( loss_td.get("loss_actor"), diff --git a/test/test_helpers.py b/test/test_helpers.py index 42c9f64403a..c2df97d8486 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -415,7 +415,6 @@ def test_ppo_maker( def test_a2c_maker( device, from_pixels, shared_mapping, gsde, exploration, action_space ): - A2CModelConfig.advantage_in_loss = False if not gsde and exploration != "random": pytest.skip("no need to test this setting") flags = list(from_pixels + shared_mapping + gsde) @@ -558,9 +557,6 @@ def test_a2c_maker( proof_environment.close() del proof_environment - cfg.advantage_in_loss = False - loss_fn = make_a2c_loss(actor_value, cfg) - cfg.advantage_in_loss = True loss_fn = make_a2c_loss(actor_value, cfg) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 886a6b437d7..fd99e8abcff 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -3,7 +3,7 @@ 
# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable, Optional, Tuple +from typing import Tuple import torch from tensordict.tensordict import TensorDict, TensorDictBase @@ -44,14 +44,13 @@ def __init__( actor: SafeProbabilisticSequential, critic: SafeModule, advantage_key: str = "advantage", - advantage_diff_key: str = "value_error", + value_target_key: str = "value_target", entropy_bonus: bool = True, samples_mc_entropy: int = 1, entropy_coef: float = 0.01, critic_coef: float = 1.0, gamma: float = 0.99, loss_critic_type: str = "smooth_l1", - advantage_module: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, ): super().__init__() self.convert_to_functional( @@ -59,7 +58,7 @@ def __init__( ) self.convert_to_functional(critic, "critic", compare_against=self.actor_params) self.advantage_key = advantage_key - self.advantage_diff_key = advantage_diff_key + self.value_target_key = value_target_key self.samples_mc_entropy = samples_mc_entropy self.entropy_bonus = entropy_bonus and entropy_coef self.register_buffer( @@ -70,9 +69,6 @@ def __init__( ) self.register_buffer("gamma", torch.tensor(gamma, device=self.device)) self.loss_critic_type = loss_critic_type - self.advantage_module = advantage_module - if advantage_module: - self.advantage_module = advantage_module.to(self.device) def reset(self) -> None: pass @@ -100,35 +96,29 @@ def _log_probs( return log_prob, dist def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: - if self.advantage_diff_key in tensordict.keys(): - advantage_diff = tensordict.get(self.advantage_diff_key) - if not advantage_diff.requires_grad: - raise RuntimeError( - "value_target retrieved from tensordict does not require grad." - ) - loss_value = distance_loss( - advantage_diff, - torch.zeros_like(advantage_diff), - loss_function=self.loss_critic_type, - ) - else: - advantage = tensordict.get(self.advantage_key) + try: + target_return = tensordict.get(self.value_target_key) tensordict_select = tensordict.select(*self.critic.in_keys) - value = self.critic( + state_value = self.critic( tensordict_select, params=self.critic_params, ).get("state_value") - value_target = advantage + value.detach() loss_value = distance_loss( - value, value_target, loss_function=self.loss_critic_type + target_return, + state_value, + loss_function=self.loss_critic_type, + ) + except KeyError: + raise KeyError( + f"the key {self.value_target_key} was not found in the input tensordict. " + f"Make sure you provided the right key and the value_target (i.e. the target " + f"return) has been retrieved accordingly. Advantage classes such as GAE, " + f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that " + f"can be used for the value loss." 
) return self.critic_coef * loss_value def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - if self.advantage_module is not None: - tensordict = self.advantage_module( - tensordict, - ) tensordict = tensordict.clone() advantage = tensordict.get(self.advantage_key) log_probs, dist = self._log_probs(tensordict) diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 827ade168fb..f7330d7d52b 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -1,9 +1,8 @@ -from typing import Callable, Optional +from typing import Optional import torch from tensordict.tensordict import TensorDict, TensorDictBase -from torchrl.envs.utils import step_mdp from torchrl.modules import SafeModule, SafeProbabilisticSequential from torchrl.objectives import distance_loss from torchrl.objectives.common import LossModule @@ -20,30 +19,21 @@ class ReinforceLoss(LossModule): def __init__( self, actor_network: SafeProbabilisticSequential, - advantage_module: Callable[[TensorDictBase], TensorDictBase], critic: Optional[SafeModule] = None, delay_value: bool = False, gamma: float = 0.99, advantage_key: str = "advantage", - advantage_diff_key: str = "value_error", + value_target_key: str = "value_target", loss_critic_type: str = "smooth_l1", ) -> None: super().__init__() self.delay_value = delay_value self.advantage_key = advantage_key - self.advantage_diff_key = advantage_diff_key + self.value_target_key = value_target_key self.loss_critic_type = loss_critic_type self.register_buffer("gamma", torch.tensor(gamma)) - if ( - hasattr(advantage_module, "is_functional") - and not advantage_module.is_functional - ): - raise RuntimeError( - "The advantage module must be functional, as it must support params and target params arguments" - ) - # Actor self.convert_to_functional( actor_network, @@ -60,15 +50,7 @@ def __init__( compare_against=list(actor_network.parameters()), ) - self.advantage_module = advantage_module - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - # get advantage - tensordict = self.advantage_module( - tensordict, - params=self.critic_params, - target_params=self.target_critic_params, - ) advantage = tensordict.get(self.advantage_key) # compute log-prob @@ -87,33 +69,24 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: return td_out def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: - if self.advantage_diff_key in tensordict.keys(): - advantage_diff = tensordict.get(self.advantage_diff_key) - if not advantage_diff.requires_grad: - raise RuntimeError( - "value_target retrieved from tensordict does not requires grad." 
- ) - loss_value = distance_loss( - advantage_diff, - torch.zeros_like(advantage_diff), - loss_function=self.loss_critic_type, - ) - else: - with torch.no_grad(): - reward = tensordict.get("reward") - next_td = step_mdp(tensordict) - next_value = self.critic( - next_td, - params=self.critic_params, - ).get("state_value") - value_target = reward + next_value * self.gamma - tensordict_select = tensordict.select(*self.critic.in_keys).clone() - value = self.critic( + try: + target_return = tensordict.get(self.value_target_key) + tensordict_select = tensordict.select(*self.critic.in_keys) + state_value = self.critic( tensordict_select, params=self.critic_params, ).get("state_value") - loss_value = distance_loss( - value, value_target, loss_function=self.loss_critic_type + target_return, + state_value, + loss_function=self.loss_critic_type, + ) + except KeyError: + raise KeyError( + f"the key {self.value_target_key} was not found in the input tensordict. " + f"Make sure you provided the right key and the value_target (i.e. the target " + f"return) has been retrieved accordingly. Advantage classes such as GAE, " + f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that " + f"can be used for the value loss." ) return loss_value diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 32d06f2cdff..7061c9369a1 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -79,8 +79,9 @@ def forward( """Computes the GAE given the data in tensordict. Args: - tensordict (TensorDictBase): A TensorDict containing the data (observation, action, reward, done state) - necessary to compute the value estimates and the GAE. + tensordict (TensorDictBase): A TensorDict containing the data + (an observation key, "action", "reward", "done" and "next" tensordict state + as returned by the environment) necessary to compute the value estimates and the TDEstimate. Returns: An updated TensorDict with an "advantage" and a "value_error" keys @@ -118,7 +119,7 @@ def forward( # we assume that target parameters are not differentiable kwargs["params"] = target_params elif "params" in kwargs: - kwargs["params"] = [param.detach() for param in kwargs["params"]] + kwargs["params"] = kwargs["params"].detach() with hold_out_net(self.value_network): self.value_network(step_td, **kwargs) next_value = step_td.get(self.value_key) @@ -185,8 +186,10 @@ def forward( """Computes the GAE given the data in tensordict. Args: - tensordict (TensorDictBase): A TensorDict containing the data (observation, action, reward, done state) - necessary to compute the value estimates and the GAE. + tensordict (TensorDictBase): A TensorDict containing the data + (an observation key, "action", "reward", "done" and "next" tensordict state + as returned by the environment) necessary to compute the value + estimates and the TDLambdaEstimate. Returns: An updated TensorDict with an "advantage" and a "value_error" keys @@ -224,7 +227,7 @@ def forward( # we assume that target parameters are not differentiable kwargs["params"] = target_params elif "params" in kwargs: - kwargs["params"] = [param.detach() for param in kwargs["params"]] + kwargs["params"] = kwargs["params"].detach() with hold_out_net(self.value_network): # we may still need to pass gradient, but we don't want to assign grads to # value net params @@ -304,8 +307,9 @@ def forward( """Computes the GAE given the data in tensordict. 
Args: - tensordict (TensorDictBase): A TensorDict containing the data (observation, action, reward, done state) - necessary to compute the value estimates and the GAE. + tensordict (TensorDictBase): A TensorDict containing the data + (an observation key, "action", "reward", "done" and "next" tensordict state + as returned by the environment) necessary to compute the value estimates and the GAE. Returns: An updated TensorDict with an "advantage" and a "value_error" keys @@ -337,7 +341,7 @@ def forward( # we assume that target parameters are not differentiable kwargs["params"] = target_params elif "params" in kwargs: - kwargs["params"] = [param.detach() for param in kwargs["params"]] + kwargs["params"] = kwargs["params"].detach() with hold_out_net(self.value_network): # we may still need to pass gradient, but we don't want to assign grads to # value net params diff --git a/torchrl/trainers/helpers/losses.py b/torchrl/trainers/helpers/losses.py index f4ae7a9a842..5c2e823b346 100644 --- a/torchrl/trainers/helpers/losses.py +++ b/torchrl/trainers/helpers/losses.py @@ -178,22 +178,11 @@ def make_a2c_loss(model, cfg) -> A2CLoss: actor_model = model.get_policy_operator() critic_model = model.get_value_operator() - if cfg.advantage_in_loss: - advantage = TDEstimate( - gamma=cfg.gamma, - value_network=critic_model, - average_rewards=True, - gradient_mode=False, - ) - else: - advantage = None - kwargs = { "actor": actor_model, "critic": critic_model, "loss_critic_type": cfg.critic_loss_function, "entropy_coef": cfg.entropy_coef, - "advantage_module": advantage, } loss_module = A2CLoss(**kwargs) @@ -212,21 +201,9 @@ def make_ppo_loss(model, cfg) -> PPOLoss: actor_model = model.get_policy_operator() critic_model = model.get_value_operator() - if cfg.advantage_in_loss: - advantage = GAE( - cfg.gamma, - cfg.lmbda, - value_network=critic_model, - average_rewards=True, - gradient_mode=False, - ) - else: - advantage = None - kwargs = { "actor": actor_model, "critic": critic_model, - "advantage_module": advantage, "loss_critic_type": cfg.loss_function, "entropy_coef": cfg.entropy_coef, } @@ -289,8 +266,6 @@ class A2CLossConfig: # Critic factor for the A2C loss critic_loss_function: str = "smooth_l1" # loss function for the value network. Either one of l1, l2 or smooth_l1 (default). - advantage_in_loss: bool = False - # if True, the advantage is computed on the sub-batch. @dataclass @@ -313,8 +288,6 @@ class PPOLossConfig: # Number of samples to use for a Monte-Carlo estimate if the policy distribution has not closed formula. loss_function: str = "smooth_l1" # loss function for the value network. Either one of l1, l2 or smooth_l1 (default). - advantage_in_loss: bool = False - # if True, the advantage is computed on the sub-batch., critic_coef: float = 1.0 # Critic loss multiplier when computing the total loss. 
From b7fee6666892213372d67502449ad854fa2236b1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 14 Dec 2022 18:06:31 +0000 Subject: [PATCH 3/6] amend --- examples/a2c/a2c.py | 1 - torchrl/trainers/helpers/losses.py | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/a2c/a2c.py b/examples/a2c/a2c.py index aa5253f65dd..52ce390441b 100644 --- a/examples/a2c/a2c.py +++ b/examples/a2c/a2c.py @@ -148,7 +148,6 @@ def main(cfg: "DictConfig"): # noqa: F821 cfg.gamma, value_network=critic_model, average_rewards=True, - gradient_mode=False, ) trainer.register_op( "process_optim_batch", diff --git a/torchrl/trainers/helpers/losses.py b/torchrl/trainers/helpers/losses.py index 5c2e823b346..d627380eca6 100644 --- a/torchrl/trainers/helpers/losses.py +++ b/torchrl/trainers/helpers/losses.py @@ -25,7 +25,6 @@ # from torchrl.objectives.redq import REDQLoss from torchrl.objectives.utils import TargetNetUpdater -from torchrl.objectives.value.advantages import GAE, TDEstimate def make_target_updater( From 0a6ac6488323b44918ff8f366529754d48d999ca Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 14 Dec 2022 20:53:10 +0000 Subject: [PATCH 4/6] amend --- examples/ppo/ppo.py | 2 +- torchrl/objectives/value/advantages.py | 22 +++++++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py index b19a178fe0e..6ddb1a34d7b 100644 --- a/examples/ppo/ppo.py +++ b/examples/ppo/ppo.py @@ -178,7 +178,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) trainer.register_op( "process_optim_batch", - advantage, + lambda tensordict: advantage(tensordict.to(device)), ) trainer._process_optim_batch_ops = [ trainer._process_optim_batch_ops[-1], diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 7061c9369a1..8fd344b1b71 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -54,7 +54,11 @@ def __init__( value_key: str = "state_value", ): super().__init__() - self.register_buffer("gamma", torch.tensor(gamma)) + try: + device = next(value_network.parameters()).device + except StopIteration: + device = torch.device("cpu") + self.register_buffer("gamma", torch.tensor(gamma, device=device)) self.value_network = value_network self.average_rewards = average_rewards @@ -159,8 +163,12 @@ def __init__( vectorized: bool = True, ): super().__init__() - self.register_buffer("gamma", torch.tensor(gamma)) - self.register_buffer("lmbda", torch.tensor(lmbda)) + try: + device = next(value_network.parameters()).device + except StopIteration: + device = torch.device("cpu") + self.register_buffer("gamma", torch.tensor(gamma, device=device)) + self.register_buffer("lmbda", torch.tensor(lmbda, device=device)) self.value_network = value_network self.vectorized = vectorized @@ -282,8 +290,12 @@ def __init__( differentiable: bool = False, ): super().__init__() - self.register_buffer("gamma", torch.tensor(gamma)) - self.register_buffer("lmbda", torch.tensor(lmbda)) + try: + device = next(value_network.parameters()).device + except StopIteration: + device = torch.device("cpu") + self.register_buffer("gamma", torch.tensor(gamma, device=device)) + self.register_buffer("lmbda", torch.tensor(lmbda, device=device)) self.value_network = value_network self.average_gae = average_gae From 96bdc2ab2fcc0842938d455497cbbe7a60392dd8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 14 Dec 2022 21:19:34 +0000 Subject: [PATCH 5/6] amend --- torchrl/trainers/trainers.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 48de6a77719..07ad793668e 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -1008,7 +1008,7 @@ def __call__(self, batch: TensorDictBase) -> TensorDictBase: * batch.shape[1] ) len_mask = traj_len >= sub_traj_len - valid_trajectories = torch.arange(batch.shape[0])[len_mask] + valid_trajectories = torch.arange(batch.shape[0], device=batch.device)[len_mask] batch_size = self.batch_size // sub_traj_len if batch_size == 0: From 18c0691729f5636dd7c4ee242fa5fea01222bda5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 15 Dec 2022 12:01:21 +0000 Subject: [PATCH 6/6] init --- test/test_cost.py | 28 +++ torchrl/objectives/value/advantages.py | 236 +++++++++++++++++++++---- 2 files changed, 234 insertions(+), 30 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 9b8c2e15c0e..fbba7ac4c5a 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -2960,6 +2960,34 @@ def __init__(self, actor_network, qvalue_network): class TestAdv: + @pytest.mark.parametrize( + "adv,kwargs", + [[GAE, {"lmbda": 0.95}], [TDEstimate, {}], [TDLambdaEstimate, {"lmbda": 0.95}]], + ) + def test_dispatch( + self, + adv, + kwargs, + ): + value_net = TensorDictModule( + nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + ) + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=False, + **kwargs, + ) + kwargs = { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next_obs": torch.randn(1, 10, 3), + } + advantage, value_target = module(**kwargs) + assert advantage.shape == torch.Size([1, 10, 1]) + assert value_target.shape == torch.Size([1, 10, 1]) + @pytest.mark.parametrize( "adv,kwargs", [[GAE, {"lmbda": 0.95}], [TDEstimate, {}], [TDLambdaEstimate, {"lmbda": 0.95}]], diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 8fd344b1b71..d8d7c128857 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -4,23 +4,24 @@ # LICENSE file in the root directory of this source tree. from functools import wraps -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import torch +from tensordict.nn import dispatch_kwargs from tensordict.tensordict import TensorDictBase from torch import nn, Tensor from torchrl.envs.utils import step_mdp from torchrl.modules import SafeModule + +from torchrl.objectives.utils import hold_out_net from torchrl.objectives.value.functional import ( + td_advantage_estimate, td_lambda_advantage_estimate, vec_generalized_advantage_estimate, vec_td_lambda_advantage_estimate, ) -from ..utils import hold_out_net -from .functional import td_advantage_estimate - def _self_set_grad_enabled(fun): @wraps(fun) @@ -41,8 +42,11 @@ class TDEstimate(nn.Module): before the TD is computed. differentiable (bool, optional): if True, gradients are propagated throught the computation of the value function. Default is :obj:`False`. - value_key (str, optional): key pointing to the state value. Default is - `"state_value"`. + advantage_key (str or tuple of str, optional): the key of the advantage entry. + Defaults to "advantage". + value_target_key (str or tuple of str, optional): the key of the advantage entry. + Defaults to "value_target". 
+    """

     def __init__(
@@ -51,7 +55,8 @@ def __init__(
         value_network: SafeModule,
         average_rewards: bool = False,
         differentiable: bool = False,
-        value_key: str = "state_value",
+        advantage_key: Union[str, Tuple] = "advantage",
+        value_target_key: Union[str, Tuple] = "value_target",
     ):
         super().__init__()
         try:
@@ -63,7 +68,17 @@ def __init__(

         self.average_rewards = average_rewards
         self.differentiable = differentiable
-        self.value_key = value_key
+        self.value_key = value_network.out_keys[0]
+
+        self.advantage_key = advantage_key
+        self.value_target_key = value_target_key
+
+        self.in_keys = (
+            value_network.in_keys
+            + ["reward", "done"]
+            + [("next", in_key) for in_key in value_network.in_keys]
+        )
+        self.out_keys = [self.advantage_key, self.value_target_key]

     @property
     def is_functional(self):
@@ -73,22 +88,64 @@ def is_functional(self):
         )

     @_self_set_grad_enabled
+    @dispatch_kwargs
     def forward(
         self,
         tensordict: TensorDictBase,
-        *unused_args,
-        params: Optional[List[Tensor]] = None,
-        target_params: Optional[List[Tensor]] = None,
+        params: Optional[TensorDictBase] = None,
+        target_params: Optional[TensorDictBase] = None,
     ) -> TensorDictBase:
-        """Computes the GAE given the data in tensordict.
+        """Computes the TDEstimate given the data in tensordict.
+
+        If a functional module is provided, a nested TensorDict containing the parameters
+        (and if relevant the target parameters) can be passed to the module.

         Args:
             tensordict (TensorDictBase): A TensorDict containing the data
                 (an observation key, "action", "reward", "done" and "next" tensordict state
                 as returned by the environment) necessary to compute the value estimates and the TDEstimate.
+                The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` is
+                the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
+            params (TensorDictBase, optional): A nested TensorDict containing the params
+                to be passed to the functional value network module.
+            target_params (TensorDictBase, optional): A nested TensorDict containing the
+                target params to be passed to the functional value network module.

         Returns:
-            An updated TensorDict with an "advantage" and a "value_error" keys
+            An updated TensorDict with advantage and value_target entries, under the keys defined in the constructor.
+
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> value_net = SafeModule(
+            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
+            ... )
+            >>> module = TDEstimate(
+            ...     gamma=0.98,
+            ...     value_network=value_net,
+            ...     differentiable=False,
+            ... )
+            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+            >>> reward = torch.randn(1, 10, 1)
+            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward}, [1, 10])
+            >>> _ = module(tensordict)
+            >>> assert "advantage" in tensordict.keys()
+
+        The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
+
+        Examples:
+            >>> value_net = SafeModule(
+            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
+            ... )
+            >>> module = TDEstimate(
+            ...     gamma=0.98,
+            ...     value_network=value_net,
+            ...     differentiable=False,
+            ... )
+            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+            >>> reward = torch.randn(1, 10, 1)
+            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs)

         """
         if tensordict.batch_dims < 1:
@@ -111,7 +168,7 @@ def forward(
                 "Expected params to be passed to advantage module but got none."
             )
         if params is not None:
-            kwargs["params"] = params
+            kwargs["params"] = params.detach()
         with hold_out_net(self.value_network):
             self.value_network(tensordict, **kwargs)
         value = tensordict.get(self.value_key)
@@ -146,10 +203,13 @@ class TDLambdaEstimate(nn.Module):
             before the TD is computed.
         differentiable (bool, optional): if True, gradients are propagated throught
             the computation of the value function. Default is :obj:`False`.
-        value_key (str, optional): key pointing to the state value. Default is
-            `"state_value"`.
         vectorized (bool, optional): whether to use the vectorized version of the
             lambda return. Default is `True`.
+        advantage_key (str or tuple of str, optional): the key of the advantage entry.
+            Defaults to "advantage".
+        value_target_key (str or tuple of str, optional): the key of the value target entry.
+            Defaults to "value_target".
+
     """

     def __init__(
@@ -159,8 +219,9 @@ def __init__(
         value_network: SafeModule,
         average_rewards: bool = False,
         differentiable: bool = False,
-        value_key: str = "state_value",
         vectorized: bool = True,
+        advantage_key: Union[str, Tuple] = "advantage",
+        value_target_key: Union[str, Tuple] = "value_target",
     ):
         super().__init__()
         try:
@@ -174,7 +235,17 @@ def __init__(

         self.average_rewards = average_rewards
         self.differentiable = differentiable
-        self.value_key = value_key
+        self.value_key = self.value_network.out_keys[0]
+
+        self.advantage_key = advantage_key
+        self.value_target_key = value_target_key
+
+        self.in_keys = (
+            value_network.in_keys
+            + ["reward", "done"]
+            + [("next", in_key) for in_key in value_network.in_keys]
+        )
+        self.out_keys = [self.advantage_key, self.value_target_key]

     @property
     def is_functional(self):
@@ -184,23 +255,66 @@ def is_functional(self):
         )

     @_self_set_grad_enabled
+    @dispatch_kwargs
     def forward(
         self,
         tensordict: TensorDictBase,
-        *unused_args,
         params: Optional[List[Tensor]] = None,
         target_params: Optional[List[Tensor]] = None,
     ) -> TensorDictBase:
-        """Computes the GAE given the data in tensordict.
+        """Computes the TDLambdaEstimate given the data in tensordict.
+
+        If a functional module is provided, a nested TensorDict containing the parameters
+        (and if relevant the target parameters) can be passed to the module.

         Args:
             tensordict (TensorDictBase): A TensorDict containing the data
                 (an observation key, "action", "reward", "done" and "next" tensordict state
-                as returned by the environment) necessary to compute the value
-                estimates and the TDLambdaEstimate.
+                as returned by the environment) necessary to compute the value estimates and the TDLambdaEstimate.
+                The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` is
+                the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
+            params (TensorDictBase, optional): A nested TensorDict containing the params
+                to be passed to the functional value network module.
+            target_params (TensorDictBase, optional): A nested TensorDict containing the
+                target params to be passed to the functional value network module.

         Returns:
-            An updated TensorDict with an "advantage" and a "value_error" keys
+            An updated TensorDict with advantage and value_target entries, under the keys defined in the constructor.
+
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> value_net = SafeModule(
+            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
+            ... )
+            >>> module = TDLambdaEstimate(
+            ...     gamma=0.98,
+            ...     lmbda=0.94,
+            ...     value_network=value_net,
+            ...     differentiable=False,
+            ... )
+            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+            >>> reward = torch.randn(1, 10, 1)
+            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward}, [1, 10])
+            >>> _ = module(tensordict)
+            >>> assert "advantage" in tensordict.keys()
+
+        The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
+
+        Examples:
+            >>> value_net = SafeModule(
+            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
+            ... )
+            >>> module = TDLambdaEstimate(
+            ...     gamma=0.98,
+            ...     lmbda=0.94,
+            ...     value_network=value_net,
+            ...     differentiable=False,
+            ... )
+            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+            >>> reward = torch.randn(1, 10, 1)
+            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs)

         """
         if tensordict.batch_dims < 1:
@@ -252,8 +366,8 @@ def forward(
             gamma, lmbda, value, next_value, reward, done
         )

-        tensordict.set("advantage", adv)
-        tensordict.set("value_target", adv + value)
+        tensordict.set(self.advantage_key, adv)
+        tensordict.set(self.value_target_key, adv + value)

         return tensordict

@@ -271,6 +385,10 @@ class GAE(nn.Module):
             Default is :obj:`False`.
         differentiable (bool, optional): if True, gradients are propagated throught
             the computation of the value function. Default is :obj:`False`.
+        advantage_key (str or tuple of str, optional): the key of the advantage entry.
+            Defaults to "advantage".
+        value_target_key (str or tuple of str, optional): the key of the value target entry.
+            Defaults to "value_target".

     GAE will return an :obj:`"advantage"` entry containing the advange value. It will also return a
     :obj:`"value_target"` entry with the return value that is to be used
@@ -288,6 +406,8 @@ def __init__(
         value_network: SafeModule,
         average_gae: bool = False,
         differentiable: bool = False,
+        advantage_key: Union[str, Tuple] = "advantage",
+        value_target_key: Union[str, Tuple] = "value_target",
     ):
         super().__init__()
         try:
@@ -297,10 +417,21 @@ def __init__(
         self.register_buffer("gamma", torch.tensor(gamma, device=device))
         self.register_buffer("lmbda", torch.tensor(lmbda, device=device))
         self.value_network = value_network
+        self.value_key = self.value_network.out_keys[0]

         self.average_gae = average_gae
         self.differentiable = differentiable

+        self.advantage_key = advantage_key
+        self.value_target_key = value_target_key
+
+        self.in_keys = (
+            value_network.in_keys
+            + ["reward", "done"]
+            + [("next", in_key) for in_key in value_network.in_keys]
+        )
+        self.out_keys = [self.advantage_key, self.value_target_key]
+
     @property
     def is_functional(self):
@@ -309,6 +440,7 @@ def is_functional(self):
         )

     @_self_set_grad_enabled
+    @dispatch_kwargs
     def forward(
         self,
         tensordict: TensorDictBase,
@@ -318,13 +450,57 @@ def forward(
     ) -> TensorDictBase:
         """Computes the GAE given the data in tensordict.

+        If a functional module is provided, a nested TensorDict containing the parameters
+        (and if relevant the target parameters) can be passed to the module.
+
         Args:
             tensordict (TensorDictBase): A TensorDict containing the data
                 (an observation key, "action", "reward", "done" and "next" tensordict state
                 as returned by the environment) necessary to compute the value estimates and the GAE.
+                The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` is
+                the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
+            params (TensorDictBase, optional): A nested TensorDict containing the params
+                to be passed to the functional value network module.
+            target_params (TensorDictBase, optional): A nested TensorDict containing the
+                target params to be passed to the functional value network module.

         Returns:
-            An updated TensorDict with an "advantage" and a "value_error" keys
+            An updated TensorDict with advantage and value_target entries, under the keys defined in the constructor.
+
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> value_net = SafeModule(
+            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
+            ... )
+            >>> module = GAE(
+            ...     gamma=0.98,
+            ...     lmbda=0.94,
+            ...     value_network=value_net,
+            ...     differentiable=False,
+            ... )
+            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+            >>> reward = torch.randn(1, 10, 1)
+            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward}, [1, 10])
+            >>> _ = module(tensordict)
+            >>> assert "advantage" in tensordict.keys()
+
+        The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
+
+        Examples:
+            >>> value_net = SafeModule(
+            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
+            ... )
+            >>> module = GAE(
+            ...     gamma=0.98,
+            ...     lmbda=0.94,
+            ...     value_network=value_net,
+            ...     differentiable=False,
+            ... )
+            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
+            >>> reward = torch.randn(1, 10, 1)
+            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
+            >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs)

         """
         if tensordict.batch_dims < 1:
@@ -346,7 +522,7 @@ def forward(
             # value net params
             self.value_network(tensordict, **kwargs)

-        value = tensordict.get("state_value")
+        value = tensordict.get(self.value_key)
         step_td = step_mdp(tensordict)
         if target_params is not None:
@@ -358,7 +534,7 @@ def forward(
             # we may still need to pass gradient, but we don't want to assign grads to
             # value net params
             self.value_network(step_td, **kwargs)
-        next_value = step_td.get("state_value")
+        next_value = step_td.get(self.value_key)

         done = tensordict.get("done")
         adv, value_target = vec_generalized_advantage_estimate(
             gamma, lmbda, value, next_value, reward, done
         )
@@ -368,7 +544,7 @@ def forward(
             adv = adv - adv.mean()
             adv = adv / adv.std().clamp_min(1e-4)

-        tensordict.set("advantage", adv)
-        tensordict.set("value_target", value_target)
+        tensordict.set(self.advantage_key, adv)
+        tensordict.set(self.value_target_key, value_target)

         return tensordict
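
Usage note on the new output keys (a minimal sketch, not taken from the patch itself): the advantage_key / value_target_key arguments introduced above control where the estimators write their results, and out_keys mirrors them. The sketch below assumes GAE is importable from torchrl.objectives.value and uses tensordict.nn.TensorDictModule as the value network; names and shapes are illustrative and follow the docstring examples.

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.objectives.value import GAE  # assumed import path

# value network reading "obs" and writing "state_value"
value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
module = GAE(
    gamma=0.98,
    lmbda=0.94,
    value_network=value_net,
    differentiable=False,
    advantage_key="custom_advantage",
    value_target_key="custom_value_target",
)
# data laid out as [*B, T, F]: batch of 1 trajectory, 10 steps, 3 features
data = TensorDict(
    {
        "obs": torch.randn(1, 10, 3),
        "reward": torch.randn(1, 10, 1),
        "done": torch.zeros(1, 10, 1, dtype=torch.bool),
        "next": {"obs": torch.randn(1, 10, 3)},
    },
    [1, 10],
)
data = module(data)
# outputs land under the custom keys requested in the constructor
assert "custom_advantage" in data.keys()
assert module.out_keys == ["custom_advantage", "custom_value_target"]

Any downstream consumer that reads the advantage (e.g. a loss module expecting the default "advantage" entry) would need to be pointed at the same custom key.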
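
The in_keys lists built in the constructors record everything these estimators read: the value network's own input keys, "reward" and "done", and the same input keys nested under "next"; out_keys records the two entries written back. The keyword-style call shown in the docstring examples (e.g. next_obs=...) maps onto exactly these keys. A short introspection sketch, under the same import-path assumptions as the previous snippet:

from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.objectives.value import TDLambdaEstimate  # assumed import path

value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
module = TDLambdaEstimate(gamma=0.98, lmbda=0.94, value_network=value_net)

# in_keys = value-net inputs + ["reward", "done"] + the same inputs under "next"
print(module.in_keys)   # ['obs', 'reward', 'done', ('next', 'obs')]
print(module.out_keys)  # ['advantage', 'value_target'] with the default constructor keys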