From cd4a943f2d9932e0cd015c294ffb99a5bfdcfacd Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 27 Feb 2024 11:58:20 +0100 Subject: [PATCH 01/12] dedicated onpolicy tests --- test/test_cost.py | 117 +++++++++++++++++++++++++++++++++------------- 1 file changed, 84 insertions(+), 33 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 349dfbeb835..41a8dcb5a80 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6158,7 +6158,6 @@ def _create_seq_mock_data_ppo( @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) @pytest.mark.parametrize("functional", [True, False]) - @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) def test_ppo( self, loss_class, @@ -6167,7 +6166,6 @@ def test_ppo( advantage, td_est, functional, - reduction, ): torch.manual_seed(self.seed) td = self._create_seq_mock_data_ppo(device=device) @@ -6203,7 +6201,6 @@ def test_ppo( value, loss_critic_type="l2", functional=functional, - reduction=reduction, ) if advantage is not None: advantage(td) @@ -6216,15 +6213,6 @@ def test_ppo( kl = loss.pop("kl") assert (kl != 0).any() - if reduction == "none": - - def func(x): - if x.dtype != torch.float: - return - return x.mean() - - loss = loss.apply(func, batch_size=[]) - loss_critic = loss["loss_critic"] loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) loss_critic.backward(retain_graph=True) @@ -6761,6 +6749,39 @@ def test_ppo_notensordict( assert loss_obj == loss_val_td.get("loss_objective") assert loss_crit == loss_val_td.get("loss_critic") + @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_ppo_reduction(self, reduction, loss_class): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + td = self._create_seq_mock_data_ppo(device=device) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + advantage = GAE( + gamma=0.9, + lmbda=0.9, + value_network=value, + ) + loss_fn = loss_class( + actor, + value, + loss_critic_type="l2", + reduction=reduction, + ) + advantage(td) + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + assert loss[key].shape == torch.Size([]) + class TestA2C(LossModuleTestBase): seed = 0 @@ -6926,8 +6947,7 @@ def _create_seq_mock_data_a2c( @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) @pytest.mark.parametrize("functional", (True, False)) - @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reduction): + def test_a2c(self, device, gradient_mode, advantage, td_est, functional): torch.manual_seed(self.seed) td = self._create_seq_mock_data_a2c(device=device) @@ -6962,7 +6982,6 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reducti value, loss_critic_type="l2", functional=functional, - reduction=reduction, ) # Check error is raised when actions require grads @@ -6980,14 +6999,7 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional, reducti elif td_est is not None: loss_fn.make_value_estimator(td_est) loss = loss_fn(td) - if reduction == "none": - def func(x): - if 
x.dtype != torch.float: - return - return x.mean() - - loss = loss.apply(func, batch_size=[]) loss_critic = loss["loss_critic"] loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) loss_critic.backward(retain_graph=True) @@ -7370,6 +7382,37 @@ def test_a2c_notensordict( assert loss_objective == loss_val_td["loss_objective"] assert loss_critic == loss_val_td["loss_critic"] + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_a2c_reduction(self, reduction): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + td = self._create_seq_mock_data_a2c(device=device) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + advantage = GAE( + gamma=0.9, + lmbda=0.9, + value_network=value, + ) + loss_fn = A2CLoss( + actor, + value, + loss_critic_type="l2", + ) + advantage(td) + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + assert loss[key].shape == torch.Size([]) + class TestReinforce(LossModuleTestBase): seed = 0 @@ -7616,26 +7659,16 @@ def _create_mock_common_layer_setup( return actor, critic, common, td @pytest.mark.parametrize("separate_losses", [False, True]) - @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_reinforce_tensordict_separate_losses(self, separate_losses, reduction): + def test_reinforce_tensordict_separate_losses(self, separate_losses): torch.manual_seed(self.seed) actor, critic, common, td = self._create_mock_common_layer_setup() loss_fn = ReinforceLoss( actor_network=actor, critic_network=critic, separate_losses=separate_losses, - reduction=reduction, ) loss = loss_fn(td) - if reduction == "none": - - def func(x): - if x.dtype != torch.float: - return - return x.mean() - - loss = loss.apply(func, batch_size=[]) assert all( (p.grad is None) or (p.grad == 0).all() @@ -7764,6 +7797,24 @@ def test_reinforce_notensordict( return assert loss_actor == loss_val_td["loss_actor"] + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_reinforce_reduction(self, reduction): + torch.manual_seed(self.seed) + actor, critic, common, td = self._create_mock_common_layer_setup() + loss_fn = ReinforceLoss( + actor_network=actor, + critic_network=critic, + reduction=reduction, + ) + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + assert loss[key].shape == torch.Size([]) + @pytest.mark.parametrize("device", get_default_devices()) class TestDreamer(LossModuleTestBase): From 00208ef38c7df16bf951b58610b8a0e6b70e65cf Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 27 Feb 2024 12:19:46 +0100 Subject: [PATCH 02/12] dedicated onpolicy tests --- test/test_cost.py | 38 ++++++++++++++++++++++++-------------- torchrl/objectives/ppo.py | 13 +++++++------ 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 41a8dcb5a80..6276be3b316 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -875,7 +875,8 @@ def test_dqn_reduction(self, reduction): assert loss[key].shape == td.shape else: for key in loss.keys(): - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) @pytest.mark.parametrize("atoms", range(4, 10)) 
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) @@ -901,7 +902,8 @@ def test_distributional_dqn_reduction(self, reduction, atoms): assert loss[key].shape == td.shape else: for key in loss.keys(): - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) class TestQMixer(LossModuleTestBase): @@ -1979,7 +1981,8 @@ def test_ddpg_reduction(self, reduction): assert loss[key].shape == td.shape else: for key in loss.keys(): - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) @pytest.mark.skipif( @@ -2685,7 +2688,8 @@ def test_td3_reduction(self, reduction): assert loss[key].shape == td.shape else: for key in loss.keys(): - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) @pytest.mark.skipif( @@ -3592,7 +3596,8 @@ def test_sac_reduction(self, reduction, version): assert loss[key].shape == td.shape else: for key in loss.keys(): - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) @pytest.mark.skipif( @@ -4175,7 +4180,8 @@ def test_discrete_sac_reduction(self, reduction): assert loss[key].shape == td.shape else: for key in loss.keys(): - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) @pytest.mark.skipif( @@ -5120,10 +5126,11 @@ def test_redq_reduction(self, reduction, deprecated_loss): if reduction == "none": for key in loss.keys(): if key.startswith("loss"): - assert loss[key].shape[-1] == td.shape[0] + assert loss[key].shape == td.shape else: for key in loss.keys(): - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) class TestCQL(LossModuleTestBase): @@ -6778,8 +6785,9 @@ def test_ppo_reduction(self, reduction, loss_class): for key in loss.keys(): if key.startswith("loss"): assert loss[key].shape == td.shape - else: - for key in loss.keys(): + else: + for key in loss.keys(): + if key.startswith("loss"): assert loss[key].shape == torch.Size([]) @@ -7409,8 +7417,9 @@ def test_a2c_reduction(self, reduction): for key in loss.keys(): if key.startswith("loss"): assert loss[key].shape == td.shape - else: - for key in loss.keys(): + else: + for key in loss.keys(): + if key.startswith("loss"): assert loss[key].shape == torch.Size([]) @@ -7811,8 +7820,9 @@ def test_reinforce_reduction(self, reduction): for key in loss.keys(): if key.startswith("loss"): assert loss[key].shape == td.shape - else: - for key in loss.keys(): + else: + for key in loss.keys(): + if key.startswith("loss"): assert loss[key].shape == torch.Size([]) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 763403603ee..7b0ee0f70b3 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib -import functools import math import warnings @@ -560,10 +559,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self.critic_coef: loss_critic = self.loss_critic(tensordict) td_out.set("loss_critic", loss_critic) - td_out = td_out.apply( - functools.partial(_reduce, reduction=self.reduction), batch_size=[] + td_out = td_out.named_apply( + lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1) + if name.startswith("loss_") + else value, + batch_size=[], ) - return td_out def make_value_estimator(self, value_type: 
ValueEstimators = None, **hyperparams): @@ -807,7 +808,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: td_out.set("ESS", _reduce(ess, self.reduction) / batch) td_out = td_out.named_apply( - lambda name, value: _reduce(value, reduction=self.reduction) + lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1) if name.startswith("loss_") else value, batch_size=[], @@ -1070,7 +1071,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: loss_critic = self.loss_critic(tensordict_copy) td_out.set("loss_critic", loss_critic) td_out = td_out.named_apply( - lambda name, value: _reduce(value, reduction=self.reduction) + lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1) if name.startswith("loss_") else value, batch_size=[], From de1f674e8fbac9a6f3763b12682ea55e8c4cf6f6 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 27 Feb 2024 12:31:16 +0100 Subject: [PATCH 03/12] dedicated onpolicy tests --- test/test_cost.py | 1 + torchrl/objectives/a2c.py | 2 +- torchrl/objectives/reinforce.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 6276be3b316..63632755657 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -7410,6 +7410,7 @@ def test_a2c_reduction(self, reduction): actor, value, loss_critic_type="l2", + reduction=reduction, ) advantage(td) loss = loss_fn(td) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index c39ff78a7f0..d55fb145919 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -477,7 +477,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_critic = self.loss_critic(tensordict) td_out.set("loss_critic", loss_critic) td_out = td_out.named_apply( - lambda name, value: _reduce(value, reduction=self.reduction) + lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1) if name.startswith("loss_") else value, batch_size=[], diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 37b595f7cb7..aaf871c8c7f 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -402,7 +402,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: td_out.set("loss_value", self.loss_critic(tensordict)) td_out = td_out.named_apply( - lambda name, value: _reduce(value, reduction=self.reduction) + lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1) if name.startswith("loss_") else value, batch_size=[], From 7921860d9254b5f03ff7a79efe5dccda2deed74d Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 27 Feb 2024 16:21:45 +0100 Subject: [PATCH 04/12] format --- test/test_cost.py | 121 +++++++++++++++++++++++++++++++--------------- 1 file changed, 83 insertions(+), 38 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 63632755657..33712c8af1a 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -514,7 +514,7 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est): assert loss_fn.tensor_keys.priority in td.keys() - sum([item for _, item in loss.items()]).backward() + sum([item for name, item in loss.items() if name.startswith("loss")]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 # Check param update effect on targets @@ -581,15 +581,21 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9): loss = loss_fn(td) if n == 0: assert_allclose_td(td, ms_td.select(*td.keys(True, True))) - _loss = 
sum([item for _, item in loss.items()]) - _loss_ms = sum([item for _, item in loss_ms.items()]) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss")] + ) assert ( abs(_loss - _loss_ms) < 1e-3 ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" else: with pytest.raises(AssertionError): assert_allclose_td(loss, loss_ms) - sum([item for _, item in loss_ms.items()]).backward() + sum( + [item for name, item in loss_ms.items() if name.startswith("loss")] + ).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 # Check param update effect on targets @@ -727,7 +733,7 @@ def test_distributional_dqn( assert loss_fn.tensor_keys.priority in td.keys() - sum([item for _, item in loss.items()]).backward() + sum([item for name, item in loss.items() if name.startswith("loss")]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 if delay_value: @@ -1067,7 +1073,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() - sum([item for _, item in loss.items()]).backward() + sum([item for name, item in loss.items() if name.startswith("loss")]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 if delay_value: @@ -1152,15 +1158,21 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) loss = loss_fn(td) if n == 0: assert_allclose_td(td, ms_td.select(*td.keys(True, True))) - _loss = sum([item for _, item in loss.items()]) - _loss_ms = sum([item for _, item in loss_ms.items()]) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss")] + ) assert ( abs(_loss - _loss_ms) < 1e-3 ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" else: with pytest.raises(AssertionError): assert_allclose_td(loss, loss_ms) - sum([item for _, item in loss_ms.items()]).backward() + sum( + [item for name, item in loss_ms.items() if name.startswith("loss")] + ).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 # Check param update effect on targets @@ -1833,7 +1845,7 @@ def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9): with pytest.raises(AssertionError): assert_allclose_td(loss, loss_ms) sum( - [item for name, item in loss.items() if name.startswith("loss_")] + [item for name, item in loss_ms.items() if name.startswith("loss_")] ).backward() parameters = list(actor.parameters()) + list(value.parameters()) for p in parameters: @@ -2466,8 +2478,12 @@ def test_td3_batcher( if n == 0: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) - _loss = sum([item for _, item in loss.items()]) - _loss_ms = sum([item for _, item in loss_ms.items()]) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) assert ( abs(_loss - _loss_ms) < 1e-3 ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" @@ -2476,7 +2492,7 @@ def test_td3_batcher( assert_allclose_td(loss, loss_ms) sum( - [item for name, item in loss.items() if name.startswith("loss_")] + [item for name, item in loss_ms.items() if name.startswith("loss_")] ).backward() named_parameters = 
loss_fn.named_parameters() @@ -2635,11 +2651,8 @@ def test_td3_notensordict( loss_val_td = loss(td) torch.manual_seed(0) loss_val = loss(**kwargs) - for i in loss_val: - assert i in loss_val_td.values(), f"{i} not in {loss_val_td.values()}" - - for i, key in enumerate(loss.out_keys): - torch.testing.assert_close(loss_val_td.get(key), loss_val[i]) + loss_val_reconstruct = TensorDict(dict(zip(loss.out_keys, loss_val)), []) + assert_allclose_td(loss_val_reconstruct, loss_val_td) # test select loss.select_out_keys("loss_actor", "loss_qvalue") @@ -3277,8 +3290,12 @@ def test_sac_batcher( loss = loss_fn(td) if n == 0: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) - _loss = sum([item for _, item in loss.items()]) - _loss_ms = sum([item for _, item in loss_ms.items()]) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) assert ( abs(_loss - _loss_ms) < 1e-3 ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" @@ -3286,7 +3303,7 @@ def test_sac_batcher( with pytest.raises(AssertionError): assert_allclose_td(loss, loss_ms) sum( - [item for name, item in loss.items() if name.startswith("loss_")] + [item for name, item in loss_ms.items() if name.startswith("loss_")] ).backward() named_parameters = loss_fn.named_parameters() for name, p in named_parameters: @@ -3963,8 +3980,12 @@ def test_discrete_sac_batcher( loss = loss_fn(td) if n == 0: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) - _loss = sum([item for _, item in loss.items()]) - _loss_ms = sum([item for _, item in loss_ms.items()]) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) assert ( abs(_loss - _loss_ms) < 1e-3 ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" @@ -3972,7 +3993,7 @@ def test_discrete_sac_batcher( with pytest.raises(AssertionError): assert_allclose_td(loss, loss_ms) sum( - [item for name, item in loss.items() if name.startswith("loss_")] + [item for name, item in loss_ms.items() if name.startswith("loss_")] ).backward() named_parameters = loss_fn.named_parameters() for name, p in named_parameters: @@ -4889,8 +4910,12 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9): loss = loss_fn(td) if n == 0: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) - _loss = sum([item for _, item in loss.items()]) - _loss_ms = sum([item for _, item in loss_ms.items()]) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) assert ( abs(_loss - _loss_ms) < 1e-3 ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" @@ -4898,7 +4923,7 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9): with pytest.raises(AssertionError): assert_allclose_td(loss, loss_ms) sum( - [item for name, item in loss.items() if name.startswith("loss_")] + [item for name, item in loss_ms.items() if name.startswith("loss_")] ).backward() named_parameters = loss_fn.named_parameters() for name, p in named_parameters: @@ -5507,8 +5532,12 @@ def test_cql_batcher( loss = loss_fn(td) if n == 0: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) - _loss = sum([item for _, item in loss.items()]) - _loss_ms = sum([item for _, item in 
loss_ms.items()]) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) assert ( abs(_loss - _loss_ms) < 1e-3 ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" @@ -5516,7 +5545,7 @@ def test_cql_batcher( with pytest.raises(AssertionError): assert_allclose_td(loss, loss_ms) sum( - [item for name, item in loss.items() if name.startswith("loss_")] + [item for name, item in loss_ms.items() if name.startswith("loss_")] ).backward() named_parameters = loss_fn.named_parameters() for name, p in named_parameters: @@ -9004,7 +9033,9 @@ def test_iql( raise NotImplementedError(k) loss_fn.zero_grad() - sum([item for _, item in loss.items()]).backward() + sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ).backward() named_parameters = list(loss_fn.named_parameters()) named_buffers = list(loss_fn.named_buffers()) @@ -9271,15 +9302,21 @@ def test_iql_batcher( loss = loss_fn(td) if n == 0: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) - _loss = sum([item for _, item in loss.items()]) - _loss_ms = sum([item for _, item in loss_ms.items()]) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) assert ( abs(_loss - _loss_ms) < 1e-3 ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" else: with pytest.raises(AssertionError): assert_allclose_td(loss, loss_ms) - sum([item for _, item in loss_ms.items()]).backward() + sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ).backward() named_parameters = loss_fn.named_parameters() for name, p in named_parameters: if not name.startswith("target_"): @@ -9769,7 +9806,9 @@ def test_discrete_iql( raise NotImplementedError(k) loss_fn.zero_grad() - sum([item for _, item in loss.items()]).backward() + sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ).backward() named_parameters = list(loss_fn.named_parameters()) named_buffers = list(loss_fn.named_buffers()) @@ -10039,15 +10078,21 @@ def test_discrete_iql_batcher( loss = loss_fn(td) if n == 0: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) - _loss = sum([item for _, item in loss.items()]) - _loss_ms = sum([item for _, item in loss_ms.items()]) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) assert ( abs(_loss - _loss_ms) < 1e-3 ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" else: with pytest.raises(AssertionError): assert_allclose_td(loss, loss_ms) - sum([item for _, item in loss_ms.items()]).backward() + sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ).backward() named_parameters = loss_fn.named_parameters() for name, p in named_parameters: if not name.startswith("target_"): From 4c9e0e6fff7867051b667963c0ca14d689b0255e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 27 Feb 2024 14:35:28 +0000 Subject: [PATCH 05/12] [Refactor] Remove remnant legacy functional calls (#1973) --- test/test_tensordictmodules.py | 1256 +++----------------- torchrl/envs/utils.py | 1 - tutorials/sphinx-tutorials/coding_ddpg.py | 17 +- tutorials/sphinx-tutorials/torchrl_demo.py | 21 +- 4 files changed, 190 insertions(+), 1105 deletions(-) diff --git 
a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index c2df40be012..7e0fef99786 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -9,12 +9,7 @@ import torch from mocking_classes import DiscreteActionVecMockEnv from tensordict import pad, TensorDict, unravel_key_list -from tensordict.nn import ( - InteractionType, - make_functional, - TensorDictModule, - TensorDictSequential, -) +from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential from torch import nn from torchrl.data.tensor_specs import ( BoundedTensorSpec, @@ -255,15 +250,33 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) elif safe and spec_type == "bounded": assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + +class TestTDSequence: + # Temporarily disabling this test until 473 is merged in tensordict + # def test_in_key_warning(self): + # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): + # tensordict_module = SafeModule( + # nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"] + # ) + # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): + # tensordict_module = SafeModule( + # nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"] + # ) + @pytest.mark.parametrize("safe", [True, False]) @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional(self, safe, spec_type): + @pytest.mark.parametrize("lazy", [True, False]) + def test_stateful(self, safe, spec_type, lazy): torch.manual_seed(0) param_multiplier = 1 - - net = nn.Linear(3, 4 * param_multiplier) - - params = make_functional(net) + if lazy: + net1 = nn.LazyLinear(4) + dummy_net = nn.LazyLinear(4) + net2 = nn.LazyLinear(4 * param_multiplier) + else: + net1 = nn.Linear(3, 4) + dummy_net = nn.Linear(4, 4) + net2 = nn.Linear(4, 4 * param_multiplier) if spec_type is None: spec = None @@ -272,31 +285,51 @@ def test_functional(self, safe, spec_type): elif spec_type == "unbounded": spec = UnboundedContinuousTensorSpec(4) + kwargs = {} + if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - tensordict_module = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return + pytest.skip("safe and spec is None is checked elsewhere") else: - tensordict_module = SafeModule( - spec=spec, - module=net, + tdmodule1 = SafeModule( + net1, + spec=None, in_keys=["in"], + out_keys=["hidden"], + safe=False, + ) + dummy_tdmodule = SafeModule( + dummy_net, + spec=None, + in_keys=["hidden"], + out_keys=["hidden"], + safe=False, + ) + tdmodule2 = SafeModule( + spec=spec, + module=net2, + in_keys=["hidden"], out_keys=["out"], - safe=safe, + safe=False, + **kwargs, ) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + + assert hasattr(tdmodule, "__setitem__") + assert len(tdmodule) == 3 + tdmodule[1] = tdmodule2 + assert len(tdmodule) == 3 + + assert hasattr(tdmodule, "__delitem__") + assert len(tdmodule) == 3 + del tdmodule[2] + assert len(tdmodule) == 2 + + assert hasattr(tdmodule, "__getitem__") + assert tdmodule[0] is tdmodule1 + assert tdmodule[1] is tdmodule2 td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td, params=TensorDict({"module": params}, [])) + tdmodule(td) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) @@ -308,16 +341,19 @@ def test_functional(self, safe, spec_type): 
@pytest.mark.parametrize("safe", [True, False]) @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_probabilistic(self, safe, spec_type): + @pytest.mark.parametrize("lazy", [True, False]) + def test_stateful_probabilistic(self, safe, spec_type, lazy): torch.manual_seed(0) param_multiplier = 2 - - tdnet = SafeModule( - module=NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)), - spec=None, - in_keys=["in"], - out_keys=["loc", "scale"], - ) + if lazy: + net1 = nn.LazyLinear(4) + dummy_net = nn.LazyLinear(4) + net2 = nn.LazyLinear(4 * param_multiplier) + else: + net1 = nn.Linear(3, 4) + dummy_net = nn.Linear(4, 4) + net2 = nn.Linear(4, 4 * param_multiplier) + net2 = NormalParamWrapper(net2) if spec_type is None: spec = None @@ -331,1075 +367,128 @@ def test_functional_probabilistic(self, safe, spec_type): kwargs = {"distribution_class": TanhNormal} if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - return + pytest.skip("safe and spec is None is checked elsewhere") else: + tdmodule1 = SafeModule( + net1, + in_keys=["in"], + out_keys=["hidden"], + spec=None, + safe=False, + ) + dummy_tdmodule = SafeModule( + dummy_net, + in_keys=["hidden"], + out_keys=["hidden"], + spec=None, + safe=False, + ) + tdmodule2 = SafeModule( + module=net2, + in_keys=["hidden"], + out_keys=["loc", "scale"], + spec=None, + safe=False, + ) + prob_module = SafeProbabilisticModule( + spec=spec, in_keys=["loc", "scale"], out_keys=["out"], - spec=spec, - safe=safe, + safe=False, **kwargs, ) + tdmodule = SafeProbabilisticTensorDictSequential( + tdmodule1, dummy_tdmodule, tdmodule2, prob_module + ) + + assert hasattr(tdmodule, "__setitem__") + assert len(tdmodule) == 4 + tdmodule[1] = tdmodule2 + tdmodule[2] = prob_module + assert len(tdmodule) == 4 - tensordict_module = SafeProbabilisticTensorDictSequential(tdnet, prob_module) - params = make_functional(tensordict_module) + assert hasattr(tdmodule, "__delitem__") + assert len(tdmodule) == 4 + del tdmodule[3] + assert len(tdmodule) == 3 + + assert hasattr(tdmodule, "__getitem__") + assert tdmodule[0] is tdmodule1 + assert tdmodule[1] is tdmodule2 + assert tdmodule[2] is prob_module td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td, params=params) + tdmodule(td) assert td.shape == torch.Size([3]) assert td.get("out").shape == torch.Size([3, 4]) + dist = tdmodule.get_dist(td) + assert dist.rsample().shape[: td.ndimension()] == td.shape + # test bounds if not safe and spec_type == "bounded": assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() elif safe and spec_type == "bounded": assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net = nn.BatchNorm1d(32 * param_multiplier) - params = make_functional(net) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 32) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(32) - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - 
): - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return - else: - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - - td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=TensorDict({"module": params}, [])) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 32]) + def test_submodule_sequence(self): + td_module_1 = SafeModule( + nn.Linear(3, 2), + in_keys=["in"], + out_keys=["hidden"], + ) + td_module_2 = SafeModule( + nn.Linear(2, 4), + in_keys=["hidden"], + out_keys=["out"], + ) + td_module = SafeSequential(td_module_1, td_module_2) - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) + sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) + sub_seq_1(td_1) + assert "hidden" in td_1.keys() + assert "out" not in td_1.keys() + td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) + sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) + sub_seq_2(td_2) + assert "out" in td_2.keys() + assert td_2.get("out").shape == torch.Size([5, 4]) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer_probabilistic(self, safe, spec_type): + @pytest.mark.parametrize("stack", [True, False]) + def test_sequential_partial(self, stack): torch.manual_seed(0) param_multiplier = 2 - tdnet = SafeModule( - module=NormalParamWrapper(nn.BatchNorm1d(32 * param_multiplier)), - spec=None, - in_keys=["in"], - out_keys=["loc", "scale"], - ) + net1 = nn.Linear(3, 4) - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 32) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(32) - else: - raise NotImplementedError + net2 = nn.Linear(4, 4 * param_multiplier) + net2 = NormalParamWrapper(net2) + net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"]) - kwargs = {"distribution_class": TanhNormal} + net3 = nn.Linear(4, 4 * param_multiplier) + net3 = NormalParamWrapper(net3) + net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) + spec = BoundedTensorSpec(-0.1, 0.1, 4) - return - else: - prob_module = SafeProbabilisticModule( + kwargs = {"distribution_class": TanhNormal} + + tdmodule1 = SafeModule( + net1, + in_keys=["a"], + out_keys=["hidden"], + spec=None, + safe=False, + ) + tdmodule2 = SafeProbabilisticTensorDictSequential( + net2, + SafeProbabilisticModule( in_keys=["loc", "scale"], out_keys=["out"], spec=spec, - safe=safe, - **kwargs, - ) - - tdmodule = SafeProbabilisticTensorDictSequential(tdnet, prob_module) - params = make_functional(tdmodule) - - td = TensorDict({"in": torch.randn(3, 32 * param_multiplier)}, [3]) - tdmodule(td, params=params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 32]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | 
(td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net = nn.Linear(3, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return - else: - tdmodule = SafeModule( - spec=spec, - module=net, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - - params = make_functional(tdmodule) - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net = NormalParamWrapper(nn.Linear(3, 4 * param_multiplier)) - tdnet = SafeModule( - module=net, in_keys=["in"], out_keys=["loc", "scale"], spec=None - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - return - else: - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - - tdmodule = 
SafeProbabilisticTensorDictSequential(tdnet, prob_module) - params = make_functional(tdmodule) - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - -class TestTDSequence: - # Temporarily disabling this test until 473 is merged in tensordict - # def test_in_key_warning(self): - # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - # tensordict_module = SafeModule( - # nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"] - # ) - # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - # tensordict_module = SafeModule( - # nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"] - # ) - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - @pytest.mark.parametrize("lazy", [True, False]) - def test_stateful(self, safe, spec_type, lazy): - torch.manual_seed(0) - param_multiplier = 1 - if lazy: - net1 = nn.LazyLinear(4) - dummy_net = nn.LazyLinear(4) - net2 = nn.LazyLinear(4 * param_multiplier) - else: - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - kwargs = {} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - spec=None, - in_keys=["in"], - out_keys=["hidden"], - safe=False, - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - spec=spec, - module=net2, - in_keys=["hidden"], - out_keys=["out"], - safe=False, - **kwargs, - ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # 
test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - @pytest.mark.parametrize("lazy", [True, False]) - def test_stateful_probabilistic(self, safe, spec_type, lazy): - torch.manual_seed(0) - param_multiplier = 2 - if lazy: - net1 = nn.LazyLinear(4) - dummy_net = nn.LazyLinear(4) - net2 = nn.LazyLinear(4 * param_multiplier) - else: - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - in_keys=["in"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - dummy_tdmodule = SafeModule( - dummy_net, - in_keys=["hidden"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - tdmodule2 = SafeModule( - module=net2, - in_keys=["hidden"], - out_keys=["loc", "scale"], - spec=None, - safe=False, - ) - - prob_module = SafeProbabilisticModule( - spec=spec, - in_keys=["loc", "scale"], - out_keys=["out"], - safe=False, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, dummy_tdmodule, tdmodule2, prob_module - ) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 4 - tdmodule[1] = tdmodule2 - tdmodule[2] = prob_module - assert len(tdmodule) == 4 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 4 - del tdmodule[3] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - assert tdmodule[2] is prob_module - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - dist = tdmodule.get_dist(td) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - net2, - spec=spec, - in_keys=["hidden"], - out_keys=["out"], - safe=safe, - ) - tdmodule = 
SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - params = make_functional(tdmodule) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - with params.unlock_(): - params["module", "1"] = params["module", "2"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - with params.unlock_(): - del params["module", "2"] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td, params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - module=net2, in_keys=["hidden"], out_keys=["loc", "scale"] - ) - - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, dummy_tdmodule, tdmodule2, prob_module - ) - - params = make_functional(tdmodule, funs_to_decorate=["forward", "get_dist"]) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 4 - tdmodule[1] = tdmodule2 - tdmodule[2] = prob_module - with params.unlock_(): - params["module", "1"] = params["module", "2"] - params["module", "2"] = params["module", "3"] - assert len(tdmodule) == 4 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 4 - del tdmodule[3] - with params.unlock_(): - del params["module", "3"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - assert tdmodule[2] is prob_module - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td, params=params) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - dist = tdmodule.get_dist(td, params=params) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) 
- def test_functional_with_buffer(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net1 = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - dummy_net = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - net2 = nn.Sequential( - nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 7) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(7) - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, spec=None, in_keys=["in"], out_keys=["hidden"], safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - net2, - spec=spec, - in_keys=["hidden"], - out_keys=["out"], - safe=safe, - ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - params = make_functional(tdmodule) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - with params.unlock_(): - params["module", "1"] = params["module", "2"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - with params.unlock_(): - del params["module", "2"] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params) - - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 7]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_functional_with_buffer_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - dummy_net = nn.Sequential(nn.Linear(7, 7), nn.BatchNorm1d(7)) - net2 = nn.Sequential( - nn.Linear(7, 7 * param_multiplier), nn.BatchNorm1d(7 * param_multiplier) - ) - net2 = NormalParamWrapper(net2) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 7) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(7) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, in_keys=["in"], out_keys=["hidden"], spec=None, safe=False - ) - dummy_tdmodule = SafeModule( - dummy_net, - in_keys=["hidden"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - tdmodule2 = SafeModule( - net2, - in_keys=["hidden"], - out_keys=["loc", "scale"], - spec=None, - safe=False, - ) - - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, dummy_tdmodule, tdmodule2, prob_module - ) - - params = make_functional(tdmodule, ["forward", "get_dist"]) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 4 - tdmodule[1] = tdmodule2 - tdmodule[2] = 
prob_module - with params.unlock_(): - params["module", "1"] = params["module", "2"] - params["module", "2"] = params["module", "3"] - assert len(tdmodule) == 4 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 4 - del tdmodule[3] - with params.unlock_(): - del params["module", "3"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - assert tdmodule[2] is prob_module - - td = TensorDict({"in": torch.randn(3, 7)}, [3]) - tdmodule(td, params=params) - - dist = tdmodule.get_dist(td, params=params) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 7]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 1 - - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - spec=None, - in_keys=["in"], - out_keys=["hidden"], - safe=False, - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - net2, - spec=spec, - in_keys=["hidden"], - out_keys=["out"], - safe=safe, - ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - params = make_functional(tdmodule) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - with params.unlock_(): - params["module", "1"] = params["module", "2"] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - with params.unlock_(): - del params["module", "2"] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td_repeat - assert td_out.shape == torch.Size([10, 3]) 
- assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - @pytest.mark.skipif( - not _has_functorch, reason="vmap can only be used with functorch" - ) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - def test_vmap_probabilistic(self, safe, spec_type): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Linear(3, 4) - - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = BoundedTensorSpec(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = UnboundedContinuousTensorSpec(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - spec=None, - in_keys=["in"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule(net2, in_keys=["hidden"], out_keys=["loc", "scale"]) - prob_module = SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, tdmodule2, prob_module - ) - - params = make_functional(tdmodule) - - # vmap = True - params = params.expand(10) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - if safe and spec_type == "bounded": - with pytest.raises( - RuntimeError, match="vmap cannot be used with safe=True" - ): - td_out = vmap(tdmodule, (None, 0))(td, params) - return - else: - td_out = vmap(tdmodule, (None, 0))(td, params) - assert td_out is not td - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - # vmap = (0, 0) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - td_repeat = td.expand(10, *td.batch_size) - td_out = vmap(tdmodule, (0, 0))(td_repeat, params) - assert td_out is not td_repeat - assert td_out.shape == torch.Size([10, 3]) - assert td_out.get("out").shape == torch.Size([10, 3, 4]) - # test bounds - if not safe and spec_type == "bounded": - assert ((td_out.get("out") > 0.1) | (td_out.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td_out.get("out") < 0.1) | (td_out.get("out") > -0.1)).all() - - @pytest.mark.parametrize("functional", [True, False]) - def test_submodule_sequence(self, functional): - td_module_1 = SafeModule( - nn.Linear(3, 2), - in_keys=["in"], - out_keys=["hidden"], - ) - td_module_2 = SafeModule( - nn.Linear(2, 4), - in_keys=["hidden"], - out_keys=["out"], - ) - td_module = SafeSequential(td_module_1, td_module_2) - - if functional: - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - params = make_functional(sub_seq_1) - sub_seq_1(td_1, params=params) - assert "hidden" in td_1.keys() - assert "out" not in td_1.keys() - td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) - sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) - params = 
-    @pytest.mark.parametrize("functional", [True, False])
-    def test_submodule_sequence(self, functional):
-        td_module_1 = SafeModule(
-            nn.Linear(3, 2),
-            in_keys=["in"],
-            out_keys=["hidden"],
-        )
-        td_module_2 = SafeModule(
-            nn.Linear(2, 4),
-            in_keys=["hidden"],
-            out_keys=["out"],
-        )
-        td_module = SafeSequential(td_module_1, td_module_2)
-
-        if functional:
-            td_1 = TensorDict({"in": torch.randn(5, 3)}, [5])
-            sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"])
-            params = make_functional(sub_seq_1)
-            sub_seq_1(td_1, params=params)
-            assert "hidden" in td_1.keys()
-            assert "out" not in td_1.keys()
-            td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5])
-            sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"])
-            params = make_functional(sub_seq_2)
-            sub_seq_2(td_2, params=params)
-            assert "out" in td_2.keys()
-            assert td_2.get("out").shape == torch.Size([5, 4])
-        else:
-            td_1 = TensorDict({"in": torch.randn(5, 3)}, [5])
-            sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"])
-            sub_seq_1(td_1)
-            assert "hidden" in td_1.keys()
-            assert "out" not in td_1.keys()
-            td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5])
-            sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"])
-            sub_seq_2(td_2)
-            assert "out" in td_2.keys()
-            assert td_2.get("out").shape == torch.Size([5, 4])
-
-    @pytest.mark.parametrize("stack", [True, False])
-    @pytest.mark.parametrize("functional", [True, False])
-    def test_sequential_partial(self, stack, functional):
-        torch.manual_seed(0)
-        param_multiplier = 2
-
-        net1 = nn.Linear(3, 4)
-
-        net2 = nn.Linear(4, 4 * param_multiplier)
-        net2 = NormalParamWrapper(net2)
-        net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"])
-
-        net3 = nn.Linear(4, 4 * param_multiplier)
-        net3 = NormalParamWrapper(net3)
-        net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"])
-
-        spec = BoundedTensorSpec(-0.1, 0.1, 4)
-
-        kwargs = {"distribution_class": TanhNormal}
-
-        tdmodule1 = SafeModule(
-            net1,
-            in_keys=["a"],
-            out_keys=["hidden"],
-            spec=None,
-            safe=False,
-        )
-        tdmodule2 = SafeProbabilisticTensorDictSequential(
-            net2,
-            SafeProbabilisticModule(
-                in_keys=["loc", "scale"],
-                out_keys=["out"],
-                spec=spec,
-                safe=True,
+                safe=True,
                 **kwargs,
             ),
         )
@@ -1417,11 +506,6 @@ def test_sequential_partial(self, stack, functional):
             tdmodule1, tdmodule2, tdmodule3, partial_tolerant=True
         )
 
-        if functional:
-            params = make_functional(tdmodule)
-        else:
-            params = None
-
         if stack:
             td = torch.stack(
                 [
@@ -1430,10 +514,7 @@ def test_sequential_partial(self, stack, functional):
                 ],
                 0,
             )
-            if functional:
-                tdmodule(td, params=params)
-            else:
-                tdmodule(td)
+            tdmodule(td)
             assert "loc" in td.keys()
             assert "scale" in td.keys()
             assert "out" in td.keys()
@@ -1444,10 +525,7 @@ def test_sequential_partial(self, stack, functional):
             assert "b" in td[0].keys()
         else:
             td = TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, [])
-            if functional:
-                tdmodule(td, params=params)
-            else:
-                tdmodule(td)
+            tdmodule(td)
             assert "loc" in td.keys()
             assert "scale" in td.keys()
             assert "out" in td.keys()
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index fa3d28848a8..e779bfc165d 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -1145,7 +1145,6 @@ def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False)
     )
 
     try:
-        # signature modified by make_functional
        sig = policy.forward.__signature__
     except AttributeError:
         sig = inspect.signature(policy.forward)
diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py
index 252b4fd2146..4a818474985 100644
--- a/tutorials/sphinx-tutorials/coding_ddpg.py
+++ b/tutorials/sphinx-tutorials/coding_ddpg.py
@@ -297,12 +297,11 @@ def _loss_actor(
 ) -> torch.Tensor:
     td_copy = tensordict.select(*self.actor_in_keys)
     # Get an action from the actor network: since we made it functional, we need to pass the params
-    td_copy = self.actor_network(td_copy, params=self.actor_network_params)
+    with self.actor_network_params.to_module(self.actor_network):
+        td_copy = self.actor_network(td_copy)
     # get the value associated with that action
-    td_copy = self.value_network(
-        td_copy,
-        params=self.value_network_params.detach(),
-    )
+    with self.value_network_params.detach().to_module(self.value_network):
+        td_copy = self.value_network(td_copy)
     return -td_copy.get("state_action_value")
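Note on the pattern introduced here: tensordict's functional API extracts parameters once with `TensorDict.from_module` and temporarily binds any parameter set (for instance a detached copy) with the `to_module` context manager. A minimal sketch with a throwaway module:

    import torch
    from tensordict import TensorDict

    module = torch.nn.Linear(3, 4)
    params = TensorDict.from_module(module)

    # Run the module under a detached copy of its parameters; the original
    # parameters are re-attached when the context exits.
    with params.detach().to_module(module):
        out = module(torch.randn(2, 3))
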
@@ -324,7 +323,8 @@ def _loss_value(
     td_copy = tensordict.clone()
 
     # V(s, a)
-    self.value_network(td_copy, params=self.value_network_params)
+    with self.value_network_params.to_module(self.value_network):
+        self.value_network(td_copy)
     pred_val = td_copy.get("state_action_value").squeeze(-1)
 
     # we manually reconstruct the parameters of the actor-critic, where the first
@@ -339,9 +339,8 @@ def _loss_value(
         batch_size=self.target_actor_network_params.batch_size,
         device=self.target_actor_network_params.device,
     )
-    target_value = self.value_estimator.value_estimate(
-        tensordict, target_params=target_params
-    ).squeeze(-1)
+    with target_params.to_module(self.value_estimator):
+        target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
 
     # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
     loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function)
diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py
index ce3f0bb4b98..25213503e19 100644
--- a/tutorials/sphinx-tutorials/torchrl_demo.py
+++ b/tutorials/sphinx-tutorials/torchrl_demo.py
@@ -543,21 +543,30 @@
 # Functional Programming (Ensembling / Meta-RL)
 # ----------------------------------------------
 
-from tensordict.nn import make_functional
+from tensordict import TensorDict
 
-params = make_functional(sequence)
-len(list(sequence.parameters()))  # functional modules have no parameters
+params = TensorDict.from_module(sequence)
+print("extracted params", params)
 
 ###############################################################################
+# functional call using tensordict:
 
-sequence(tensordict, params)
+with params.to_module(sequence):
+    sequence(tensordict)
 
 ###############################################################################
-
+# Using vectorized map for model ensembling
 from torch import vmap
 
 params_expand = params.expand(4)
-tensordict_exp = vmap(sequence, (None, 0))(tensordict, params_expand)
+
+
+def exec_sequence(params, data):
+    with params.to_module(sequence):
+        return sequence(data)
+
+
+tensordict_exp = vmap(exec_sequence, (0, None))(params_expand, tensordict)
 print(tensordict_exp)
 
 ###############################################################################
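Note: `params.expand(4)` above adds a leading dimension to every parameter leaf, which is what lets `vmap` treat the stack as four models (and why `in_dims` flips to `(0, None)` once the params become the mapped first argument). A standalone sketch — the two-layer `sequence` is a stand-in for the tutorial's module, and `.clone()` materializes independent copies:

    import torch
    from tensordict import TensorDict

    sequence = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.ReLU())
    params = TensorDict.from_module(sequence)

    stacked = params.expand(4).clone()   # four independent parameter sets
    print(stacked["0", "weight"].shape)  # torch.Size([4, 4, 3])
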
From b749e50df69c23ced38cdb43427d79e18deabca4 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Tue, 27 Feb 2024 09:40:55 -0500
Subject: [PATCH 06/12] amend

---
 torchrl/objectives/a2c.py                  | 2 +-
 torchrl/objectives/decision_transformer.py | 2 +-
 torchrl/objectives/deprecated.py           | 2 +-
 torchrl/objectives/ppo.py                  | 6 +++---
 torchrl/objectives/redq.py                 | 2 +-
 torchrl/objectives/sac.py                  | 4 ++--
 6 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index d55fb145919..b52b055e357 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -471,7 +471,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         td_out = TensorDict({"loss_objective": loss}, batch_size=[])
         if self.entropy_bonus:
             entropy = self.get_entropy_bonus(dist)
-            td_out.set("entropy", entropy.detach())  # for logging
+            td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict)
diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py
index 954bd0b9a42..ec6ed4f5252 100644
--- a/torchrl/objectives/decision_transformer.py
+++ b/torchrl/objectives/decision_transformer.py
@@ -220,7 +220,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             "loss_log_likelihood": -log_likelihood,
             "loss_entropy": -entropy_bonus,
             "loss_alpha": loss_alpha,
-            "entropy": entropy.detach(),
+            "entropy": entropy.detach().mean(),
             "alpha": self.alpha.detach(),
         }
         return TensorDict(out, [])
diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py
index 3c95fd4cd11..9eb7d9a07e3 100644
--- a/torchrl/objectives/deprecated.py
+++ b/torchrl/objectives/deprecated.py
@@ -311,7 +311,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                 "loss_qvalue": loss_qval,
                 "loss_alpha": loss_alpha,
                 "alpha": self.alpha,
-                "entropy": -sample_log_prob.detach(),
+                "entropy": -sample_log_prob.detach().mean(),
             },
             [],
         )
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 7b0ee0f70b3..683e2f27553 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -554,7 +554,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[])
         if self.entropy_bonus:
             entropy = self.get_entropy_bonus(dist)
-            td_out.set("entropy", entropy.detach())  # for logging
+            td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict)
@@ -800,7 +800,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
 
         if self.entropy_bonus:
             entropy = self.get_entropy_bonus(dist)
-            td_out.set("entropy", entropy.detach())  # for logging
+            td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict)
@@ -1065,7 +1065,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
 
         if self.entropy_bonus:
             entropy = self.get_entropy_bonus(dist)
-            td_out.set("entropy", entropy.detach())  # for logging
+            td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict_copy)
diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py
index ecd112a58c9..817483a0269 100644
--- a/torchrl/objectives/redq.py
+++ b/torchrl/objectives/redq.py
@@ -563,7 +563,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             "loss_qvalue": loss_qval,
             "loss_alpha": loss_alpha,
             "alpha": self.alpha.detach(),
-            "entropy": -sample_log_prob.detach(),
+            "entropy": -sample_log_prob.detach().mean(),
             "state_action_value_actor": state_action_value_actor.detach(),
             "action_log_prob_actor": action_log_prob_actor.detach(),
             "next.state_value": next_state_value.detach(),
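Note on the recurring change in this patch: entries that exist only for logging are reduced to detached scalars, so per-sample tensors never leak into the output TensorDict or the autograd graph. In miniature, with a placeholder `log_prob`:

    import torch
    from tensordict import TensorDict

    log_prob = torch.randn(256, requires_grad=True)  # placeholder per-sample values
    td_out = TensorDict(
        {
            "loss_actor": (-log_prob).mean(),      # differentiable, reduced
            "entropy": -log_prob.detach().mean(),  # detached scalar: logging only
        },
        batch_size=[],
    )
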
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index e51c10aeef1..277d068ca3e 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -572,7 +572,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             "loss_qvalue": loss_qvalue,
             "loss_alpha": loss_alpha,
             "alpha": self._alpha,
-            "entropy": entropy,
+            "entropy": entropy.detach().mean(),
         }
         if self._version == 1:
             out["loss_value"] = loss_value
@@ -1136,7 +1136,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             "loss_qvalue": loss_value,
             "loss_alpha": loss_alpha,
             "alpha": self._alpha,
-            "entropy": entropy,
+            "entropy": entropy.detach().mean(),
         }
         td_out = TensorDict(out, [])
         td_out = td_out.named_apply(

From 3d8dbcf08dd0f8ba135bc8c0094dde4e3ba0ee7e Mon Sep 17 00:00:00 2001
From: albert bou
Date: Tue, 27 Feb 2024 11:58:20 +0100
Subject: [PATCH 07/12] dedicated onpolicy tests

---
 test/test_cost.py | 117 +++++++++++++++++++++++++++++++++-------------
 1 file changed, 84 insertions(+), 33 deletions(-)

diff --git a/test/test_cost.py b/test/test_cost.py
index 60d1b1e374f..67c35400969 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -6201,7 +6201,6 @@ def _create_seq_mock_data_ppo(
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
     @pytest.mark.parametrize("functional", [True, False])
-    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     def test_ppo(
         self,
         loss_class,
@@ -6210,7 +6209,6 @@ def test_ppo(
         advantage,
         td_est,
         functional,
-        reduction,
     ):
         torch.manual_seed(self.seed)
         td = self._create_seq_mock_data_ppo(device=device)
@@ -6246,7 +6244,6 @@ def test_ppo(
             value,
             loss_critic_type="l2",
             functional=functional,
-            reduction=reduction,
         )
         if advantage is not None:
             advantage(td)
@@ -6259,15 +6256,6 @@ def test_ppo(
             kl = loss.pop("kl")
             assert (kl != 0).any()
 
-        if reduction == "none":
-
-            def func(x):
-                if x.dtype != torch.float:
-                    return
-                return x.mean()
-
-            loss = loss.apply(func, batch_size=[])
-
         loss_critic = loss["loss_critic"]
         loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
         loss_critic.backward(retain_graph=True)
@@ -6804,6 +6792,39 @@ def test_ppo_notensordict(
         assert loss_obj == loss_val_td.get("loss_objective")
         assert loss_crit == loss_val_td.get("loss_critic")
 
+    @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
+    @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
+    def test_ppo_reduction(self, reduction, loss_class):
+        torch.manual_seed(self.seed)
+        device = (
+            torch.device("cpu")
+            if torch.cuda.device_count() == 0
+            else torch.device("cuda")
+        )
+        td = self._create_seq_mock_data_ppo(device=device)
+        actor = self._create_mock_actor(device=device)
+        value = self._create_mock_value(device=device)
+        advantage = GAE(
+            gamma=0.9,
+            lmbda=0.9,
+            value_network=value,
+        )
+        loss_fn = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            reduction=reduction,
+        )
+        advantage(td)
+        loss = loss_fn(td)
+        if reduction == "none":
+            for key in loss.keys():
+                if key.startswith("loss"):
+                    assert loss[key].shape == td.shape
+        else:
+            for key in loss.keys():
+                assert loss[key].shape == torch.Size([])
+
 
 class TestA2C(LossModuleTestBase):
     seed = 0
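Note on what `reduction="none"` buys: the loss keeps its batch shape, so callers can reweight per sample (e.g. for prioritized replay) before reducing — the same contract as `torch.nn.functional` losses:

    import torch
    import torch.nn.functional as F

    pred, target = torch.randn(8, 5), torch.randn(8, 5)

    per_sample = F.mse_loss(pred, target, reduction="none").mean(-1)  # shape [8]
    scalar = F.mse_loss(pred, target, reduction="mean")               # shape []
    assert per_sample.shape == torch.Size([8])
    assert scalar.shape == torch.Size([])
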
"none": - def func(x): - if x.dtype != torch.float: - return - return x.mean() - - loss = loss.apply(func, batch_size=[]) loss_critic = loss["loss_critic"] loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0) loss_critic.backward(retain_graph=True) @@ -7413,6 +7425,37 @@ def test_a2c_notensordict( assert loss_objective == loss_val_td["loss_objective"] assert loss_critic == loss_val_td["loss_critic"] + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_a2c_reduction(self, reduction): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + td = self._create_seq_mock_data_a2c(device=device) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + advantage = GAE( + gamma=0.9, + lmbda=0.9, + value_network=value, + ) + loss_fn = A2CLoss( + actor, + value, + loss_critic_type="l2", + ) + advantage(td) + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + assert loss[key].shape == torch.Size([]) + class TestReinforce(LossModuleTestBase): seed = 0 @@ -7659,26 +7702,16 @@ def _create_mock_common_layer_setup( return actor, critic, common, td @pytest.mark.parametrize("separate_losses", [False, True]) - @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_reinforce_tensordict_separate_losses(self, separate_losses, reduction): + def test_reinforce_tensordict_separate_losses(self, separate_losses): torch.manual_seed(self.seed) actor, critic, common, td = self._create_mock_common_layer_setup() loss_fn = ReinforceLoss( actor_network=actor, critic_network=critic, separate_losses=separate_losses, - reduction=reduction, ) loss = loss_fn(td) - if reduction == "none": - - def func(x): - if x.dtype != torch.float: - return - return x.mean() - - loss = loss.apply(func, batch_size=[]) assert all( (p.grad is None) or (p.grad == 0).all() @@ -7807,6 +7840,24 @@ def test_reinforce_notensordict( return assert loss_actor == loss_val_td["loss_actor"] + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_reinforce_reduction(self, reduction): + torch.manual_seed(self.seed) + actor, critic, common, td = self._create_mock_common_layer_setup() + loss_fn = ReinforceLoss( + actor_network=actor, + critic_network=critic, + reduction=reduction, + ) + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + assert loss[key].shape == torch.Size([]) + @pytest.mark.parametrize("device", get_default_devices()) class TestDreamer(LossModuleTestBase): From e59652782b7f20d5cdb48bf23cfc87961148b896 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 27 Feb 2024 16:28:00 +0100 Subject: [PATCH 08/12] merge origin --- test/test_cost.py | 52 ++++++++++++++++++--------------------- torchrl/objectives/ppo.py | 13 +++++----- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 67c35400969..49e5cbcab79 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -881,9 +881,8 @@ def test_dqn_reduction(self, reduction): assert loss[key].shape == td.shape else: for key in loss.keys(): - if not key.startswith("loss"): - continue - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) 
@pytest.mark.parametrize("atoms", range(4, 10)) @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) @@ -909,9 +908,8 @@ def test_distributional_dqn_reduction(self, reduction, atoms): assert loss[key].shape == td.shape else: for key in loss.keys(): - if not key.startswith("loss"): - continue - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) class TestQMixer(LossModuleTestBase): @@ -1995,9 +1993,8 @@ def test_ddpg_reduction(self, reduction): assert loss[key].shape == td.shape else: for key in loss.keys(): - if not key.startswith("loss_"): - continue - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) @pytest.mark.skipif( @@ -2704,9 +2701,8 @@ def test_td3_reduction(self, reduction): assert loss[key].shape == td.shape else: for key in loss.keys(): - if not key.startswith("loss"): - continue - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) @pytest.mark.skipif( @@ -3617,9 +3613,8 @@ def test_sac_reduction(self, reduction, version): assert loss[key].shape == td.shape else: for key in loss.keys(): - if not key.startswith("loss"): - continue - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) @pytest.mark.skipif( @@ -4206,9 +4201,8 @@ def test_discrete_sac_reduction(self, reduction): assert loss[key].shape == td.shape else: for key in loss.keys(): - if not key.startswith("loss"): - continue - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) @pytest.mark.skipif( @@ -5157,12 +5151,11 @@ def test_redq_reduction(self, reduction, deprecated_loss): if reduction == "none": for key in loss.keys(): if key.startswith("loss"): - assert loss[key].shape[-1] == td.shape[0] + assert loss[key].shape == td.shape else: for key in loss.keys(): - if not key.startswith("loss"): - continue - assert loss[key].shape == torch.Size([]) + if key.startswith("loss"): + assert loss[key].shape == torch.Size([]) class TestCQL(LossModuleTestBase): @@ -6821,8 +6814,9 @@ def test_ppo_reduction(self, reduction, loss_class): for key in loss.keys(): if key.startswith("loss"): assert loss[key].shape == td.shape - else: - for key in loss.keys(): + else: + for key in loss.keys(): + if key.startswith("loss"): assert loss[key].shape == torch.Size([]) @@ -7452,8 +7446,9 @@ def test_a2c_reduction(self, reduction): for key in loss.keys(): if key.startswith("loss"): assert loss[key].shape == td.shape - else: - for key in loss.keys(): + else: + for key in loss.keys(): + if key.startswith("loss"): assert loss[key].shape == torch.Size([]) @@ -7854,8 +7849,9 @@ def test_reinforce_reduction(self, reduction): for key in loss.keys(): if key.startswith("loss"): assert loss[key].shape == td.shape - else: - for key in loss.keys(): + else: + for key in loss.keys(): + if key.startswith("loss"): assert loss[key].shape == torch.Size([]) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index d871098cbb8..683e2f27553 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib -import functools import math import warnings @@ -560,10 +559,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self.critic_coef: loss_critic = self.loss_critic(tensordict) td_out.set("loss_critic", loss_critic) - td_out = td_out.apply( 
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index d871098cbb8..683e2f27553 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -5,7 +5,6 @@
 from __future__ import annotations
 
 import contextlib
-import functools
 import math
 import warnings
 
@@ -560,10 +559,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
-        td_out = td_out.apply(
-            functools.partial(_reduce, reduction=self.reduction), batch_size=[]
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
         )
-
         return td_out
 
     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
@@ -807,7 +808,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         td_out.set("ESS", _reduce(ess, self.reduction) / batch)
         td_out = td_out.named_apply(
-            lambda name, value: _reduce(value, reduction=self.reduction)
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
             batch_size=[],
@@ -1070,7 +1071,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             loss_critic = self.loss_critic(tensordict_copy)
             td_out.set("loss_critic", loss_critic)
         td_out = td_out.named_apply(
-            lambda name, value: _reduce(value, reduction=self.reduction)
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
             batch_size=[],

From 42e96a44fce15ebc4a69fec9071cc452426a2caa Mon Sep 17 00:00:00 2001
From: albert bou
Date: Tue, 27 Feb 2024 12:31:16 +0100
Subject: [PATCH 09/12] dedicated onpolicy tests

---
 test/test_cost.py               | 1 +
 torchrl/objectives/a2c.py       | 2 +-
 torchrl/objectives/reinforce.py | 2 +-
 3 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/test/test_cost.py b/test/test_cost.py
index 49e5cbcab79..33712c8af1a 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -7439,6 +7439,7 @@ def test_a2c_reduction(self, reduction):
             actor,
             value,
             loss_critic_type="l2",
+            reduction=reduction,
         )
         advantage(td)
         loss = loss_fn(td)
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index 1846db4989a..b52b055e357 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -477,7 +477,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             loss_critic = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
         td_out = td_out.named_apply(
-            lambda name, value: _reduce(value, reduction=self.reduction)
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
             batch_size=[],
diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py
index 37b595f7cb7..aaf871c8c7f 100644
--- a/torchrl/objectives/reinforce.py
+++ b/torchrl/objectives/reinforce.py
@@ -402,7 +402,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         td_out.set("loss_value", self.loss_critic(tensordict))
 
         td_out = td_out.named_apply(
-            lambda name, value: _reduce(value, reduction=self.reduction)
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
             else value,
             batch_size=[],
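Note: the `.squeeze(-1)` added across these objectives presumably guards against reductions that keep a trailing singleton dimension, which the new shape assertions would reject. The shape arithmetic in isolation:

    import torch

    scalar = torch.randn(())  # what the tests expect for "mean"/"sum"
    almost = torch.randn(1)   # a reduction that kept a trailing singleton
    assert scalar.shape == torch.Size([])
    assert almost.shape != torch.Size([])
    assert almost.squeeze(-1).shape == torch.Size([])
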
From 38e9c1fa98f54eb51a1c13b4f935b9369f49a962 Mon Sep 17 00:00:00 2001
From: albert bou
Date: Tue, 27 Feb 2024 16:42:35 +0100
Subject: [PATCH 10/12] merge origin

---
 test/test_cost.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/test/test_cost.py b/test/test_cost.py
index c11a6b02dba..dbd94204d90 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -6823,9 +6823,9 @@ def test_ppo_reduction(self, reduction, loss_class):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if key.startswith("loss"):
-                    assert loss[key].shape == torch.Size([])
-
+                if not key.startswith("loss"):
+                    continue
+                assert loss[key].shape == torch.Size([])
 
 class TestA2C(LossModuleTestBase):
     seed = 0
@@ -7456,8 +7456,9 @@ def test_a2c_reduction(self, reduction):
                 assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if key.startswith("loss"):
-                    assert loss[key].shape == torch.Size([])
+                if not key.startswith("loss"):
+                    continue
+                assert loss[key].shape == torch.Size([])
 
 
 class TestReinforce(LossModuleTestBase):
@@ -7859,8 +7860,9 @@ def test_reinforce_reduction(self, reduction):
                 assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if key.startswith("loss"):
-                    assert loss[key].shape == torch.Size([])
+                if not key.startswith("loss"):
+                    continue
+                assert loss[key].shape == torch.Size([])
 
 
 @pytest.mark.parametrize("device", get_default_devices())
 class TestDreamer(LossModuleTestBase):

From d2046347f55c4175de3aaa7b8aaf694cd2280522 Mon Sep 17 00:00:00 2001
From: albert bou
Date: Tue, 27 Feb 2024 16:45:12 +0100
Subject: [PATCH 11/12] format

---
 test/test_cost.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/test_cost.py b/test/test_cost.py
index dbd94204d90..efb54e57340 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -6827,6 +6827,7 @@ def test_ppo_reduction(self, reduction, loss_class):
                     continue
                 assert loss[key].shape == torch.Size([])
 
+
 class TestA2C(LossModuleTestBase):
     seed = 0
From 76d24e1ef3f6c43cdcd6f4c12273c9811dddec04 Mon Sep 17 00:00:00 2001
From: albert bou
Date: Tue, 27 Feb 2024 16:54:13 +0100
Subject: [PATCH 12/12] minor change

---
 test/test_cost.py | 42 +++++++++++++++++++++---------------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/test/test_cost.py b/test/test_cost.py
index efb54e57340..73ea5fd27b4 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -877,11 +877,11 @@ def test_dqn_reduction(self, reduction):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])
 
@@ -905,11 +905,11 @@ def test_distributional_dqn_reduction(self, reduction, atoms):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])
 
@@ -1991,11 +1991,11 @@ def test_ddpg_reduction(self, reduction):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])
 
@@ -2700,11 +2700,11 @@ def test_td3_reduction(self, reduction):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])
 
@@ -3613,11 +3613,11 @@ def test_sac_reduction(self, reduction, version):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])
 
@@ -4202,11 +4202,11 @@ def test_discrete_sac_reduction(self, reduction):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])
 
@@ -5156,11 +5156,11 @@ def test_redq_reduction(self, reduction, deprecated_loss):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
-                    assert loss[key].shape == td.shape
+                if key.startswith("loss_"):
+                    assert loss[key].shape[-1] == td.shape[0]
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])
 
@@ -6819,11 +6819,11 @@ def test_ppo_reduction(self, reduction, loss_class):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])
 
@@ -7453,11 +7453,11 @@ def test_a2c_reduction(self, reduction):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])
 
@@ -7857,11 +7857,11 @@ def test_reinforce_reduction(self, reduction):
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
-                if key.startswith("loss"):
+                if key.startswith("loss_"):
                     assert loss[key].shape == td.shape
         else:
             for key in loss.keys():
-                if not key.startswith("loss"):
+                if not key.startswith("loss_"):
                     continue
                 assert loss[key].shape == torch.Size([])
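Closing note: every `*_reduction` test now repeats the same four-line shape assertion; if the key-prefix convention keeps shifting, one option would be a shared helper along these lines (`assert_loss_shapes` is hypothetical, not in the codebase):

    import torch

    def assert_loss_shapes(loss, batch_shape, reduction):
        # Hypothetical helper factoring the assertion repeated across the tests.
        for key in loss.keys():
            if not key.startswith("loss_"):
                continue
            expected = batch_shape if reduction == "none" else torch.Size([])
            assert loss[key].shape == expected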