diff --git a/test/test_cost.py b/test/test_cost.py
index 60d1b1e374f..73ea5fd27b4 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -877,11 +877,11 @@ def test_dqn_reduction(self, reduction):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])

@@ -905,11 +905,11 @@ def test_distributional_dqn_reduction(self, reduction, atoms):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])

@@ -1991,7 +1991,7 @@ def test_ddpg_reduction(self, reduction):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
@@ -2700,11 +2700,11 @@ def test_td3_reduction(self, reduction):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])

@@ -3613,11 +3613,11 @@ def test_sac_reduction(self, reduction, version):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])

@@ -4202,11 +4202,11 @@ def test_discrete_sac_reduction(self, reduction):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])

@@ -5156,11 +5156,11 @@ def test_redq_reduction(self, reduction, deprecated_loss):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape[-1] == td.shape[0]
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])

@@ -6201,7 +6201,6 @@ def _create_seq_mock_data_ppo(
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
     @pytest.mark.parametrize("functional", [True, False])
-    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     def test_ppo(
         self,
         loss_class,
@@ -6210,7 +6209,6 @@ def test_ppo(
         advantage,
         td_est,
         functional,
-        reduction,
     ):
         torch.manual_seed(self.seed)
         td = self._create_seq_mock_data_ppo(device=device)
@@ -6246,7 +6244,6 @@ def test_ppo(
             value,
             loss_critic_type="l2",
             functional=functional,
-            reduction=reduction,
         )
         if advantage is not None:
             advantage(td)
@@ -6259,15 +6256,6 @@ def test_ppo(
             kl = loss.pop("kl")
             assert (kl != 0).any()

-        if reduction == "none":
-
-            def func(x):
-                if x.dtype != torch.float:
-                    return
-                return x.mean()
-
-            loss = loss.apply(func, batch_size=[])
-
         loss_critic = loss["loss_critic"]
         loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
         loss_critic.backward(retain_graph=True)
@@ -6804,6 +6792,41 @@ def test_ppo_notensordict(
         assert loss_obj == loss_val_td.get("loss_objective")
         assert loss_crit == loss_val_td.get("loss_critic")

+    @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
+    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
+    def test_ppo_reduction(self, reduction, loss_class):
+        torch.manual_seed(self.seed)
+        device = (
+            torch.device("cpu")
+            if torch.cuda.device_count() == 0
+            else torch.device("cuda")
+        )
+        td = self._create_seq_mock_data_ppo(device=device)
+        actor = self._create_mock_actor(device=device)
+        value = self._create_mock_value(device=device)
+        advantage = GAE(
+            gamma=0.9,
+            lmbda=0.9,
+            value_network=value,
+        )
+        loss_fn = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            reduction=reduction,
+        )
+        advantage(td)
+        loss = loss_fn(td)
+        if reduction == "none":
+            for key in loss.keys():
+                if key.startswith("loss_"):
+                    assert loss[key].shape == td.shape
+        else:
+            for key in loss.keys():
+                if not key.startswith("loss_"):
+                    continue
+                assert loss[key].shape == torch.Size([])
+

 class TestA2C(LossModuleTestBase):
     seed = 0
@@ -6969,8 +6992,7 @@ def _create_seq_mock_data_a2c(
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
     @pytest.mark.parametrize("functional", (True, False))
-    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
-    def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reduction):
+    def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
         torch.manual_seed(self.seed)
         td = self._create_seq_mock_data_a2c(device=device)

@@ -7005,7 +7027,6 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reducti
             value,
             loss_critic_type="l2",
             functional=functional,
-            reduction=reduction,
         )

         # Check error is raised when actions require grads
@@ -7023,14 +7044,7 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reducti
         elif td_est is not None:
             loss_fn.make_value_estimator(td_est)
         loss = loss_fn(td)
-        if reduction == "none":
-
-            def func(x):
-                if x.dtype != torch.float:
-                    return
-                return x.mean()
-            loss = loss.apply(func, batch_size=[])

         loss_critic = loss["loss_critic"]
         loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
         loss_critic.backward(retain_graph=True)
@@ -7413,6 +7427,40 @@ def test_a2c_notensordict(
         assert loss_objective == loss_val_td["loss_objective"]
         assert loss_critic == loss_val_td["loss_critic"]

+    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
+    def test_a2c_reduction(self, reduction):
+        torch.manual_seed(self.seed)
+        device = (
+            torch.device("cpu")
+            if torch.cuda.device_count() == 0
+            else torch.device("cuda")
+        )
+        td = self._create_seq_mock_data_a2c(device=device)
+        actor = self._create_mock_actor(device=device)
+        value = self._create_mock_value(device=device)
+        advantage = GAE(
+            gamma=0.9,
+            lmbda=0.9,
+            value_network=value,
+        )
+        loss_fn = A2CLoss(
+            actor,
+            value,
+            loss_critic_type="l2",
+            reduction=reduction,
+        )
+        advantage(td)
+        loss = loss_fn(td)
+        if reduction == "none":
+            for key in loss.keys():
+                if key.startswith("loss_"):
+                    assert loss[key].shape == td.shape
+        else:
+            for key in loss.keys():
+                if not key.startswith("loss_"):
+                    continue
+                assert loss[key].shape == torch.Size([])
+

 class TestReinforce(LossModuleTestBase):
     seed = 0
@@ -7659,26 +7707,16 @@ def _create_mock_common_layer_setup(
         return actor, critic, common, td

     @pytest.mark.parametrize("separate_losses", [False, True])
-    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
-    def test_reinforce_tensordict_separate_losses(self, separate_losses, reduction):
+    def test_reinforce_tensordict_separate_losses(self, separate_losses):
         torch.manual_seed(self.seed)
         actor, critic, common, td = self._create_mock_common_layer_setup()
         loss_fn = ReinforceLoss(
             actor_network=actor,
             critic_network=critic,
             separate_losses=separate_losses,
-            reduction=reduction,
         )

         loss = loss_fn(td)
-        if reduction == "none":
-
-            def func(x):
-                if x.dtype != torch.float:
-                    return
-                return x.mean()
-
-            loss = loss.apply(func, batch_size=[])

         assert all(
             (p.grad is None) or (p.grad == 0).all()
@@ -7807,6 +7845,26 @@ def test_reinforce_notensordict(
             return
         assert loss_actor == loss_val_td["loss_actor"]

+    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
+    def test_reinforce_reduction(self, reduction):
+        torch.manual_seed(self.seed)
+        actor, critic, common, td = self._create_mock_common_layer_setup()
+        loss_fn = ReinforceLoss(
+            actor_network=actor,
+            critic_network=critic,
+            reduction=reduction,
+        )
+        loss = loss_fn(td)
+        if reduction == "none":
+            for key in loss.keys():
+                if key.startswith("loss_"):
+                    assert loss[key].shape == td.shape
+        else:
+            for key in loss.keys():
+                if not key.startswith("loss_"):
+                    continue
+                assert loss[key].shape == torch.Size([])
+

 @pytest.mark.parametrize("device", get_default_devices())
 class TestDreamer(LossModuleTestBase):
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index 1846db4989a..b52b055e357 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -477,7 +477,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             loss_critic = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
         td_out = td_out.named_apply(
-            lambda name, value: _reduce(value, reduction=self.reduction)
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
             batch_size=[],
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index d871098cbb8..683e2f27553 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -5,7 +5,6 @@
 from __future__ import annotations

 import contextlib
-import functools
 import math
 import warnings

@@ -560,10 +559,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
-        td_out = td_out.apply(
-            functools.partial(_reduce, reduction=self.reduction), batch_size=[]
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
         )
-
         return td_out

     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
@@ -807,7 +808,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         td_out.set("ESS", _reduce(ess, self.reduction) / batch)

         td_out = td_out.named_apply(
-            lambda name, value: _reduce(value, reduction=self.reduction)
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
             batch_size=[],
@@ -1070,7 +1071,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             loss_critic = self.loss_critic(tensordict_copy)
             td_out.set("loss_critic", loss_critic)
         td_out = td_out.named_apply(
-            lambda name, value: _reduce(value, reduction=self.reduction)
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
             batch_size=[],
diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py
index 37b595f7cb7..aaf871c8c7f 100644
--- a/torchrl/objectives/reinforce.py
+++ b/torchrl/objectives/reinforce.py
@@ -402,7 +402,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         td_out.set("loss_value", self.loss_critic(tensordict))

         td_out = td_out.named_apply(
-            lambda name, value: _reduce(value, reduction=self.reduction)
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
             batch_size=[],
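The `squeeze(-1)` added after `_reduce` is what lets the new `test_*_reduction` tests assert `loss[key].shape == td.shape` under `reduction="none"`: without it, each `loss_*` entry would keep a trailing singleton dimension. Below is a minimal sketch of the resulting behavior against TorchRL's public API; the toy network sizes and the fake rollout are illustrative assumptions, not taken from this diff.

```python
import torch
from tensordict import TensorDict
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

n_obs, n_act, batch, time = 4, 2, 3, 5

# Toy policy: maps observations to the loc/scale of a TanhNormal.
actor = ProbabilisticActor(
    TensorDictModule(
        torch.nn.Sequential(
            torch.nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()
        ),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    ),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,  # writes "sample_log_prob", which PPO reads back
)
critic = ValueOperator(torch.nn.Linear(n_obs, 1), in_keys=["observation"])

# Fake [batch, time] rollout carrying the keys GAE and PPO expect.
td = TensorDict(
    {
        "observation": torch.randn(batch, time, n_obs),
        "next": {
            "observation": torch.randn(batch, time, n_obs),
            "reward": torch.randn(batch, time, 1),
            "done": torch.zeros(batch, time, 1, dtype=torch.bool),
            "terminated": torch.zeros(batch, time, 1, dtype=torch.bool),
        },
    },
    batch_size=[batch, time],
)
with torch.no_grad():
    actor(td)  # adds "action" and "sample_log_prob"
GAE(gamma=0.9, lmbda=0.9, value_network=critic)(td)  # adds "advantage"

loss = ClipPPOLoss(actor, critic, reduction="none")(td)
# Per-sample losses now match the sampling batch shape instead of
# carrying a trailing singleton dimension.
assert loss["loss_objective"].shape == td.shape  # torch.Size([3, 5])
```

With `reduction="mean"` or `"sum"` (or the default `None`, which falls back to mean reduction), the same `loss_*` entries come out as scalars, matching the `torch.Size([])` branch of the tests.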