From 8d278401abba61e333b425aae813cadafd9ce9aa Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 15 Nov 2025 15:04:41 +0000 Subject: [PATCH 1/3] tests --- test/test_cost.py | 634 ++++++++++++++++-- test/test_rb.py | 4 +- torchrl/_utils.py | 5 + torchrl/collectors/collectors.py | 12 +- torchrl/data/replay_buffers/replay_buffers.py | 12 +- torchrl/data/replay_buffers/samplers.py | 20 +- torchrl/objectives/common.py | 7 +- torchrl/objectives/ddpg.py | 22 +- torchrl/objectives/dqn.py | 24 +- torchrl/objectives/sac.py | 30 +- torchrl/objectives/td3.py | 26 +- torchrl/objectives/td3_bc.py | 28 +- torchrl/objectives/utils.py | 48 +- torchrl/objectives/value/advantages.py | 4 +- torchrl/trainers/trainers.py | 4 +- 15 files changed, 766 insertions(+), 114 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index bc4f98a4d9f..53a3966495a 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -44,8 +44,17 @@ from tensordict.utils import unravel_key from torch import autograd, nn -from torchrl._utils import _standardize -from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded +from torchrl._utils import _standardize, rl_warnings +from torchrl.data import ( + Bounded, + Categorical, + Composite, + LazyTensorStorage, + MultiOneHot, + OneHot, + TensorDictPrioritizedReplayBuffer, + Unbounded, +) from torchrl.data.postprocs.postprocs import MultiStep from torchrl.envs import EnvBase, GymEnv, InitTracker, SerialEnv from torchrl.envs.libs.gym import _has_gym @@ -674,7 +683,7 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est): loss_fn.make_value_estimator(td_est) with ( pytest.warns(UserWarning, match="No target network updater has been") - if delay_value + if delay_value and rl_warnings() else contextlib.nullcontext() ), _check_td_steady(td): loss = loss_fn(td) @@ -738,7 +747,7 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9): with ( pytest.warns(UserWarning, match="No target network updater has been") - if delay_value + if delay_value and rl_warnings() else contextlib.nullcontext() ), _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) @@ -897,7 +906,7 @@ def test_distributional_dqn( UserWarning, match="No target network updater has been associated with this loss module", ) - if delay_value + if delay_value and rl_warnings() else contextlib.nullcontext() ): loss = loss_fn(td) @@ -1084,6 +1093,96 @@ def test_distributional_dqn_reduction(self, reduction, atoms): continue assert loss[key].shape == torch.Size([]) + def test_dqn_prioritized_weights(self): + """Test DQN with prioritized replay buffer weighted loss reduction.""" + n_obs = 4 + n_actions = 3 + batch_size = 32 + buffer_size = 100 + + # Create DQN value network using QValueActor + module = nn.Linear(n_obs, n_actions) + action_spec = Categorical(n_actions) + value = QValueActor( + spec=Composite( + { + "action": action_spec, + "action_value": None, + "chosen_action_value": None, + }, + shape=[], + ), + action_space="categorical", + module=module, + ) + + # Create DQN loss + loss_fn = DQNLoss( + value_network=value, action_space="categorical", reduction="mean" + ) + loss_fn.make_value_estimator() + + # Create prioritized replay buffer + rb = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.9, + storage=LazyTensorStorage(buffer_size), + batch_size=batch_size, + priority_key="td_error", + ) + + # Create initial data + initial_data = TensorDict( + { + "observation": torch.randn(buffer_size, n_obs), + "action": torch.randint(0, n_actions, 
(buffer_size,)), + ("next", "observation"): torch.randn(buffer_size, n_obs), + ("next", "reward"): torch.randn(buffer_size, 1), + ("next", "done"): torch.zeros(buffer_size, 1, dtype=torch.bool), + ("next", "terminated"): torch.zeros(buffer_size, 1, dtype=torch.bool), + }, + batch_size=[buffer_size], + ) + rb.extend(initial_data) + + # Sample (weights should all be identical initially) + sample1 = rb.sample() + assert "priority_weight" in sample1.keys() + weights1 = sample1["priority_weight"] + assert torch.allclose(weights1, weights1[0], atol=1e-5) + + # Run loss to get priorities + loss_fn(sample1) + assert "td_error" in sample1.keys() + + # Update replay buffer with new priorities + rb.update_tensordict_priority(sample1) + + # Sample again - weights should now be non-equal + sample2 = rb.sample() + weights2 = sample2["priority_weight"] + assert weights2.std() > 1e-5 + + # Run loss again with varied weights + loss_out2 = loss_fn(sample2) + assert torch.isfinite(loss_out2["loss"]) + + # Verify manual weighted average matches + loss_fn_no_reduction = DQNLoss( + value_network=value, + action_space="categorical", + reduction="none", + use_prioritized_weights=False, + ) + loss_fn_no_reduction.make_value_estimator() + loss_fn_no_reduction.target_value_network_params = ( + loss_fn.target_value_network_params + ) + + loss_elements = loss_fn_no_reduction(sample2)["loss"] + manual_weighted_loss = (loss_elements * weights2).sum() / weights2.sum() + assert torch.allclose(loss_out2["loss"], manual_weighted_loss, rtol=1e-4) + class TestQMixer(LossModuleTestBase): seed = 0 @@ -1246,7 +1345,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): loss_fn.make_value_estimator(td_est) with ( pytest.warns(UserWarning, match="No target network updater has been") - if delay_value + if delay_value and rl_warnings() else contextlib.nullcontext() ), _check_td_steady(td): loss = loss_fn(td) @@ -1322,7 +1421,7 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) with ( pytest.warns(UserWarning, match="No target network updater has been") - if delay_value + if delay_value and rl_warnings() else contextlib.nullcontext() ), _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) @@ -1547,6 +1646,96 @@ def test_mixer_keys( else: loss(td) + def test_dqn_prioritized_weights(self): + """Test DQN with prioritized replay buffer weighted loss reduction.""" + n_obs = 4 + n_actions = 3 + batch_size = 32 + buffer_size = 100 + + # Create DQN value network using QValueActor + module = nn.Linear(n_obs, n_actions) + action_spec = Categorical(n_actions) + value = QValueActor( + spec=Composite( + { + "action": action_spec, + "action_value": None, + "chosen_action_value": None, + }, + shape=[], + ), + action_space="categorical", + module=module, + ) + + # Create DQN loss + loss_fn = DQNLoss( + value_network=value, action_space="categorical", reduction="mean" + ) + loss_fn.make_value_estimator() + + # Create prioritized replay buffer + rb = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.9, + storage=LazyTensorStorage(buffer_size), + batch_size=batch_size, + priority_key="td_error", + ) + + # Create initial data + initial_data = TensorDict( + { + "observation": torch.randn(buffer_size, n_obs), + "action": torch.randint(0, n_actions, (buffer_size,)), + ("next", "observation"): torch.randn(buffer_size, n_obs), + ("next", "reward"): torch.randn(buffer_size, 1), + ("next", "done"): torch.zeros(buffer_size, 1, dtype=torch.bool), + ("next", "terminated"): torch.zeros(buffer_size, 1, 
dtype=torch.bool), + }, + batch_size=[buffer_size], + ) + rb.extend(initial_data) + + # Sample (weights should all be identical initially) + sample1 = rb.sample() + assert "priority_weight" in sample1.keys() + weights1 = sample1["priority_weight"] + assert torch.allclose(weights1, weights1[0], atol=1e-5) + + # Run loss to get priorities + loss_fn(sample1) + assert "td_error" in sample1.keys() + + # Update replay buffer with new priorities + rb.update_tensordict_priority(sample1) + + # Sample again - weights should now be non-equal + sample2 = rb.sample() + weights2 = sample2["priority_weight"] + assert weights2.std() > 1e-5 + + # Run loss again with varied weights + loss_out2 = loss_fn(sample2) + assert torch.isfinite(loss_out2["loss"]) + + # Verify manual weighted average matches + loss_fn_no_reduction = DQNLoss( + value_network=value, + action_space="categorical", + reduction="none", + use_prioritized_weights=False, + ) + loss_fn_no_reduction.make_value_estimator() + loss_fn_no_reduction.target_value_network_params = ( + loss_fn.target_value_network_params + ) + + loss_elements = loss_fn_no_reduction(sample2)["loss"] + manual_weighted_loss = (loss_elements * weights2).sum() / weights2.sum() + assert torch.allclose(loss_out2["loss"], manual_weighted_loss, rtol=1e-4) + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" @@ -1758,7 +1947,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): with _check_td_steady(td), ( pytest.warns(UserWarning, match="No target network updater has been") - if (delay_actor or delay_value) + if (delay_actor or delay_value) and rl_warnings() else contextlib.nullcontext() ): loss = loss_fn(td) @@ -1882,7 +2071,9 @@ def test_ddpg_separate_losses( separate_losses=separate_losses, ) - with pytest.warns(UserWarning, match="No target network updater has been"): + with pytest.warns( + UserWarning, match="No target network updater has been" + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) # remove warning @@ -2004,7 +2195,7 @@ def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9): ms_td = ms(td.clone()) with ( pytest.warns(UserWarning, match="No target network updater has been") - if (delay_value or delay_value) + if (delay_value or delay_value) and rl_warnings() else contextlib.nullcontext() ), _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) @@ -2110,7 +2301,7 @@ def test_ddpg_tensordict_run(self, td_est): with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater has been" - ): + ) if rl_warnings() else contextlib.nullcontext(): _ = loss_fn(td) def test_ddpg_notensordict(self): @@ -2132,7 +2323,9 @@ def test_ddpg_notensordict(self): } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") - with pytest.warns(UserWarning, match="No target network updater has been"): + with pytest.warns( + UserWarning, match="No target network updater has been" + ) if rl_warnings() else contextlib.nullcontext(): loss_val_td = loss(td) loss_val = loss(**kwargs) for i, key in enumerate(loss.out_keys): @@ -2182,6 +2375,94 @@ def test_ddpg_reduction(self, reduction): continue assert loss[key].shape == torch.Size([]) + def test_ddpg_prioritized_weights(self): + """Test DDPG with prioritized replay buffer weighted loss reduction.""" + n_obs = 4 + n_act = 2 + batch_size = 32 + buffer_size = 100 + + # Actor network + actor_net = MLP(in_features=n_obs, out_features=n_act, num_cells=[64, 64]) + actor = ValueOperator( + module=actor_net, + in_keys=["observation"], + 
out_keys=["action"], + ) + + # Q-value network + qvalue_net = MLP(in_features=n_obs + n_act, out_features=1, num_cells=[64, 64]) + qvalue = ValueOperator(module=qvalue_net, in_keys=["observation", "action"]) + + # Create DDPG loss + loss_fn = DDPGLoss(actor_network=actor, value_network=qvalue) + loss_fn.make_value_estimator() + + # Create prioritized replay buffer + rb = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.9, + storage=LazyTensorStorage(buffer_size), + batch_size=batch_size, + priority_key="td_error", + ) + + # Create initial data + initial_data = TensorDict( + { + "observation": torch.randn(buffer_size, n_obs), + "action": torch.randn(buffer_size, n_act).clamp(-1, 1), + ("next", "observation"): torch.randn(buffer_size, n_obs), + ("next", "reward"): torch.randn(buffer_size, 1), + ("next", "done"): torch.zeros(buffer_size, 1, dtype=torch.bool), + ("next", "terminated"): torch.zeros(buffer_size, 1, dtype=torch.bool), + }, + batch_size=[buffer_size], + ) + rb.extend(initial_data) + + # Sample (weights should all be identical initially) + sample1 = rb.sample() + assert "priority_weight" in sample1.keys() + weights1 = sample1["priority_weight"] + assert torch.allclose(weights1, weights1[0], atol=1e-5) + + # Run loss to get priorities + loss_fn(sample1) + assert "td_error" in sample1.keys() + + # Update replay buffer with new priorities + rb.update_tensordict_priority(sample1) + + # Sample again - weights should now be non-equal + sample2 = rb.sample() + weights2 = sample2["priority_weight"] + assert weights2.std() > 1e-5 + + # Run loss again with varied weights + loss_out2 = loss_fn(sample2) + assert torch.isfinite(loss_out2["loss_value"]) + + # Verify weighted vs unweighted differ + loss_fn_no_weights = DDPGLoss( + actor_network=actor, + value_network=qvalue, + use_prioritized_weights=False, + ) + loss_fn_no_weights.make_value_estimator() + loss_fn_no_weights.value_network_params = loss_fn.value_network_params + loss_fn_no_weights.target_value_network_params = ( + loss_fn.target_value_network_params + ) + loss_fn_no_weights.actor_network_params = loss_fn.actor_network_params + loss_fn_no_weights.target_actor_network_params = ( + loss_fn.target_actor_network_params + ) + + loss_out_no_weights = loss_fn_no_weights(sample2) + # Weighted and unweighted should differ (in general) + assert torch.isfinite(loss_out_no_weights["loss_value"]) + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" @@ -2437,7 +2718,7 @@ def test_td3( UserWarning, match="No target network updater has been associated with this loss module", ) - if (delay_actor or delay_qvalue) + if (delay_actor or delay_qvalue) and rl_warnings() else contextlib.nullcontext() ): with _check_td_steady(td): @@ -2550,7 +2831,7 @@ def test_td3_deactivate_vmap( UserWarning, match="No target network updater has been associated with this loss module", ) - if (delay_actor or delay_qvalue) + if (delay_actor or delay_qvalue) and rl_warnings() else contextlib.nullcontext() ), _check_td_steady(td): torch.manual_seed(1) @@ -2579,7 +2860,7 @@ def test_td3_deactivate_vmap( UserWarning, match="No target network updater has been associated with this loss module", ) - if (delay_actor or delay_qvalue) + if (delay_actor or delay_qvalue) and rl_warnings() else contextlib.nullcontext() ), _check_td_steady(td): torch.manual_seed(1) @@ -2655,7 +2936,9 @@ def test_td3_separate_losses( loss_function="l2", separate_losses=separate_losses, ) - with pytest.warns(UserWarning, match="No target network updater has 
been"): + with pytest.warns( + UserWarning, match="No target network updater has been" + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) assert all( @@ -2747,7 +3030,7 @@ def test_td3_batcher( with ( pytest.warns(UserWarning, match="No target network updater has been") - if (delay_qvalue or delay_actor) + if (delay_qvalue or delay_actor) and rl_warnings() else contextlib.nullcontext() ), _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) @@ -2931,7 +3214,9 @@ def test_td3_notensordict( } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") - with pytest.warns(UserWarning, match="No target network updater has been"): + with pytest.warns( + UserWarning, match="No target network updater has been" + ) if rl_warnings() else contextlib.nullcontext(): torch.manual_seed(0) loss_val_td = loss(td) torch.manual_seed(0) @@ -2990,6 +3275,105 @@ def test_td3_reduction(self, reduction): continue assert loss[key].shape == torch.Size([]) + def test_td3_prioritized_weights(self): + """Test TD3 with prioritized replay buffer weighted loss reduction.""" + n_obs = 4 + n_act = 2 + batch_size = 32 + buffer_size = 100 + + # Actor network + actor_net = MLP(in_features=n_obs, out_features=n_act, num_cells=[64, 64]) + actor = ValueOperator( + module=actor_net, + in_keys=["observation"], + out_keys=["action"], + ) + + # Q-value network + qvalue_net = MLP(in_features=n_obs + n_act, out_features=1, num_cells=[64, 64]) + qvalue = ValueOperator(module=qvalue_net, in_keys=["observation", "action"]) + + # Create TD3 loss + loss_fn = TD3Loss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=2, + action_spec=Bounded( + low=-torch.ones(n_act), high=torch.ones(n_act), shape=(n_act,) + ), + ) + loss_fn.make_value_estimator() + + # Create prioritized replay buffer + rb = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.9, + storage=LazyTensorStorage(buffer_size), + batch_size=batch_size, + priority_key="td_error", + ) + + # Create initial data + initial_data = TensorDict( + { + "observation": torch.randn(buffer_size, n_obs), + "action": torch.randn(buffer_size, n_act).clamp(-1, 1), + ("next", "observation"): torch.randn(buffer_size, n_obs), + ("next", "reward"): torch.randn(buffer_size, 1), + ("next", "done"): torch.zeros(buffer_size, 1, dtype=torch.bool), + ("next", "terminated"): torch.zeros(buffer_size, 1, dtype=torch.bool), + }, + batch_size=[buffer_size], + ) + rb.extend(initial_data) + + # Sample (weights should all be identical initially) + sample1 = rb.sample() + assert "priority_weight" in sample1.keys() + weights1 = sample1["priority_weight"] + assert torch.allclose(weights1, weights1[0], atol=1e-5) + + # Run loss to get priorities + loss_fn(sample1) + assert "td_error" in sample1.keys() + + # Update replay buffer with new priorities + rb.update_tensordict_priority(sample1) + + # Sample again - weights should now be non-equal + sample2 = rb.sample() + weights2 = sample2["priority_weight"] + assert weights2.std() > 1e-5 + + # Run loss again with varied weights + loss_out2 = loss_fn(sample2) + assert torch.isfinite(loss_out2["loss_qvalue"]) + + # Verify weighted vs unweighted differ + loss_fn_no_weights = TD3Loss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=2, + action_spec=Bounded( + low=-torch.ones(n_act), high=torch.ones(n_act), shape=(n_act,) + ), + use_prioritized_weights=False, + ) + loss_fn_no_weights.make_value_estimator() + loss_fn_no_weights.qvalue_network_params = loss_fn.qvalue_network_params + 
loss_fn_no_weights.target_qvalue_network_params = ( + loss_fn.target_qvalue_network_params + ) + loss_fn_no_weights.actor_network_params = loss_fn.actor_network_params + loss_fn_no_weights.target_actor_network_params = ( + loss_fn.target_actor_network_params + ) + + loss_out_no_weights = loss_fn_no_weights(sample2) + # Weighted and unweighted should differ (in general) + assert torch.isfinite(loss_out_no_weights["loss_qvalue"]) + class TestTD3BC(LossModuleTestBase): seed = 0 @@ -3245,7 +3629,7 @@ def test_td3bc( UserWarning, match="No target network updater has been associated with this loss module", ) - if (delay_actor or delay_qvalue) + if (delay_actor or delay_qvalue) and rl_warnings() else contextlib.nullcontext() ): with _check_td_steady(td): @@ -3378,7 +3762,9 @@ def test_td3bc_separate_losses( loss_function="l2", separate_losses=separate_losses, ) - with pytest.warns(UserWarning, match="No target network updater has been"): + with pytest.warns( + UserWarning, match="No target network updater has been" + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) assert all( @@ -3480,7 +3866,7 @@ def test_td3bc_batcher( with ( pytest.warns(UserWarning, match="No target network updater has been") - if (delay_qvalue or delay_actor) + if (delay_qvalue or delay_actor) and rl_warnings() else contextlib.nullcontext() ), _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) @@ -3664,7 +4050,9 @@ def test_td3bc_notensordict( } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") - with pytest.warns(UserWarning, match="No target network updater has been"): + with pytest.warns( + UserWarning, match="No target network updater has been" + ) if rl_warnings() else contextlib.nullcontext(): torch.manual_seed(0) loss_val_td = loss(td) torch.manual_seed(0) @@ -4067,7 +4455,7 @@ def test_sac( with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -4244,7 +4632,7 @@ def test_sac_deactivate_vmap( torch.manual_seed(0) with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): loss_vmap = loss_fn_vmap(td) td = tdc torch.manual_seed(0) @@ -4272,7 +4660,7 @@ def test_sac_deactivate_vmap( ) if torch.__version__ < "2.7" else contextlib.nullcontext(): with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): loss_no_vmap = loss_fn_no_vmap(td) assert_allclose_td(loss_vmap, loss_no_vmap) @@ -4356,7 +4744,9 @@ def test_sac_separate_losses( num_qvalue_nets=1, separate_losses=separate_losses, ) - with pytest.warns(UserWarning, match="No target network updater has been"): + with pytest.warns( + UserWarning, match="No target network updater has been" + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -4488,7 +4878,7 @@ def test_sac_batcher( with pytest.warns( UserWarning, match="No target network updater has been associated with this loss module", - ): + ) if rl_warnings() else contextlib.nullcontext(): with _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() @@ -4711,7 +5101,9 @@ def test_sac_notensordict( # setting the seed for each loss so that drawing the random samples from value network # leads to same numbers for both runs torch.manual_seed(self.seed) 
- with pytest.warns(UserWarning, match="No target network updater"): + with pytest.warns( + UserWarning, match="No target network updater" + ) if rl_warnings() else contextlib.nullcontext(): loss_val = loss(**kwargs) torch.manual_seed(self.seed) @@ -5094,7 +5486,7 @@ def test_discrete_sac( with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -5214,7 +5606,7 @@ def test_discrete_sac_deactivate_vmap( tdc = td.clone() with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): torch.manual_seed(1) loss_vmap = loss_fn_vmap(td) td = tdc @@ -5245,7 +5637,7 @@ def test_discrete_sac_deactivate_vmap( ) if torch.__version__ < "2.7" else contextlib.nullcontext(): with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): torch.manual_seed(1) loss_no_vmap = loss_fn_no_vmap(td) assert_allclose_td(loss_vmap, loss_no_vmap) @@ -5343,7 +5735,7 @@ def test_discrete_sac_batcher( np.random.seed(0) with _check_td_steady(ms_td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() @@ -5524,7 +5916,9 @@ def test_discrete_sac_notensordict( } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") - with pytest.warns(UserWarning, match="No target network updater has been"): + with pytest.warns( + UserWarning, match="No target network updater has been" + ) if rl_warnings() else contextlib.nullcontext(): loss_val = loss(**kwargs) loss_val_td = loss(td) @@ -6669,7 +7063,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est): UserWarning, match="No target network updater has been associated with this loss module", ) - if delay_qvalue + if delay_qvalue and rl_warnings() else contextlib.nullcontext() ): with _check_td_steady(td): @@ -6788,7 +7182,7 @@ def test_redq_separate_losses(self, separate_losses): with pytest.warns( UserWarning, match="No target network updater has been associated with this loss module", - ): + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) # check that losses are independent @@ -6876,7 +7270,7 @@ def test_redq_deprecated_separate_losses(self, separate_losses): with pytest.warns( UserWarning, match="No target network updater has been associated with this loss module", - ): + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) SoftUpdate(loss_fn, eps=0.5) @@ -7057,7 +7451,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): torch.manual_seed(0) with ( pytest.warns(UserWarning, match="No target network updater has been") - if delay_qvalue + if delay_qvalue and rl_warnings() else contextlib.nullcontext() ): with _check_td_steady(td_clone1): @@ -7102,7 +7496,7 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9): with ( pytest.warns(UserWarning, match="No target network updater has been") - if delay_qvalue + if delay_qvalue and rl_warnings() else contextlib.nullcontext() ): with _check_td_steady(ms_td): @@ -7297,7 +7691,7 @@ def test_redq_notensordict( with pytest.warns( UserWarning, match="No target network updater has been associated with this loss module", - ): + ) if rl_warnings() else contextlib.nullcontext(): loss_val = 
loss(**kwargs) torch.manual_seed(self.seed) loss_val_td = loss(td) @@ -7373,6 +7767,118 @@ def test_redq_reduction(self, reduction, deprecated_loss): continue assert loss[key].shape == torch.Size([]) + def test_sac_prioritized_weights(self): + """Test SAC with prioritized replay buffer weighted loss reduction.""" + n_obs = 4 + n_act = 2 + batch_size = 32 + buffer_size = 100 + + # Actor network + actor_net = nn.Sequential( + nn.Linear(n_obs, 64), + nn.ReLU(), + nn.Linear(64, 2 * n_act), + NormalParamExtractor(), + ) + actor_module = TensorDictModule( + actor_net, in_keys=["observation"], out_keys=["loc", "scale"] + ) + actor = ProbabilisticActor( + module=actor_module, + in_keys=["loc", "scale"], + distribution_class=TanhNormal, + return_log_prob=True, + spec=Bounded( + low=-torch.ones(n_act), high=torch.ones(n_act), shape=(n_act,) + ), + ) + + # Q-value network + qvalue_net = MLP(in_features=n_obs + n_act, out_features=1, num_cells=[64, 64]) + qvalue = ValueOperator(module=qvalue_net, in_keys=["observation", "action"]) + + # Value network for SAC v1 + value_net = MLP(in_features=n_obs, out_features=1, num_cells=[64, 64]) + value = ValueOperator(module=value_net, in_keys=["observation"]) + + # Create SAC loss + loss_fn = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=2, + ) + loss_fn.make_value_estimator() + + # Create prioritized replay buffer + rb = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.9, + storage=LazyTensorStorage(buffer_size), + batch_size=batch_size, + priority_key="td_error", + ) + + # Create initial data + initial_data = TensorDict( + { + "observation": torch.randn(buffer_size, n_obs), + "action": torch.randn(buffer_size, n_act).clamp(-1, 1), + ("next", "observation"): torch.randn(buffer_size, n_obs), + ("next", "reward"): torch.randn(buffer_size, 1), + ("next", "done"): torch.zeros(buffer_size, 1, dtype=torch.bool), + ("next", "terminated"): torch.zeros(buffer_size, 1, dtype=torch.bool), + }, + batch_size=[buffer_size], + ) + rb.extend(initial_data) + + # Sample (weights should all be identical initially) + sample1 = rb.sample() + assert "priority_weight" in sample1.keys() + weights1 = sample1["priority_weight"] + assert torch.allclose(weights1, weights1[0], atol=1e-5) + + # Run loss to get priorities + loss_fn(sample1) + assert "td_error" in sample1.keys() + + # Update replay buffer with new priorities + rb.update_tensordict_priority(sample1) + + # Sample again - weights should now be non-equal + sample2 = rb.sample() + weights2 = sample2["priority_weight"] + assert weights2.std() > 1e-5 + + # Run loss again with varied weights + loss_out2 = loss_fn(sample2) + assert torch.isfinite(loss_out2["loss_qvalue"]) + + # Verify weighted vs unweighted differ + loss_fn_no_weights = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=2, + use_prioritized_weights=False, + ) + loss_fn_no_weights.make_value_estimator() + loss_fn_no_weights.qvalue_network_params = loss_fn.qvalue_network_params + loss_fn_no_weights.target_qvalue_network_params = ( + loss_fn.target_qvalue_network_params + ) + loss_fn_no_weights.actor_network_params = loss_fn.actor_network_params + loss_fn_no_weights.value_network_params = loss_fn.value_network_params + loss_fn_no_weights.target_value_network_params = ( + loss_fn.target_value_network_params + ) + + loss_out_no_weights = loss_fn_no_weights(sample2) + # Weighted and unweighted should differ (in general) + assert 
torch.isfinite(loss_out_no_weights["loss_qvalue"]) + class TestCQL(LossModuleTestBase): seed = 0 @@ -7532,7 +8038,7 @@ def test_cql( with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -7711,7 +8217,7 @@ def test_cql_deactivate_vmap( tdc = td.clone() with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): torch.manual_seed(1) loss_vmap = loss_fn_vmap(td) td = tdc @@ -7743,7 +8249,7 @@ def test_cql_deactivate_vmap( ) if torch.__version__ < "2.7" else contextlib.nullcontext(): with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): torch.manual_seed(1) loss_no_vmap = loss_fn_no_vmap(td) assert_allclose_td(loss_vmap, loss_no_vmap) @@ -7894,7 +8400,9 @@ def test_cql_batcher( torch.manual_seed(0) np.random.seed(0) - with pytest.warns(UserWarning, match="No target network updater"): + with pytest.warns( + UserWarning, match="No target network updater" + ) if rl_warnings() else contextlib.nullcontext(): with _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() @@ -8174,7 +8682,7 @@ def test_dcql(self, delay_value, device, action_spec_type, td_est): loss_fn.make_value_estimator(td_est) with ( pytest.warns(UserWarning, match="No target network updater has been") - if delay_value + if delay_value and rl_warnings() else contextlib.nullcontext() ), _check_td_steady(td): loss = loss_fn(td) @@ -8236,7 +8744,7 @@ def test_dcql_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) with ( pytest.warns(UserWarning, match="No target network updater has been") - if delay_value + if delay_value and rl_warnings() else contextlib.nullcontext() ), _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) @@ -10892,7 +11400,7 @@ def test_reinforce_value_net( ) with ( pytest.warns(UserWarning, match="No target network updater has been") - if delay_value + if delay_value and rl_warnings() else contextlib.nullcontext() ): if advantage is not None: @@ -12661,7 +13169,7 @@ def test_iql( with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -12790,7 +13298,7 @@ def test_iql_deactivate_vmap( with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): torch.manual_seed(1) loss_vmap = loss_fn_vmap(td) @@ -12818,7 +13326,7 @@ def test_iql_deactivate_vmap( ) if torch.__version__ < "2.7" else contextlib.nullcontext(): with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): torch.manual_seed(1) loss_no_vmap = loss_fn_no_vmap(td) assert_allclose_td(loss_vmap, loss_no_vmap) @@ -12872,7 +13380,9 @@ def test_iql_separate_losses(self, separate_losses): loss_function="l2", separate_losses=separate_losses, ) - with pytest.warns(UserWarning, match="No target network updater has been"): + with pytest.warns( + UserWarning, match="No target network updater has been" + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -13060,7 +13570,7 @@ def 
test_iql_batcher( np.random.seed(0) with _check_td_steady(ms_td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() @@ -13225,7 +13735,7 @@ def test_iql_notensordict( with pytest.warns( UserWarning, match="No target network updater has been associated with this loss module", - ): + ) if rl_warnings() else contextlib.nullcontext(): loss_val = loss(**kwargs) loss_val_td = loss(td) assert len(loss_val) == 4 @@ -13272,7 +13782,7 @@ def test_iql_reduction(self, reduction): loss_fn.make_value_estimator() with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) if reduction == "none": for key in loss.keys(): @@ -13554,7 +14064,7 @@ def test_discrete_iql( with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -13695,7 +14205,9 @@ def test_discrete_iql_separate_losses(self, separate_losses): separate_losses=separate_losses, action_space="one-hot", ) - with pytest.warns(UserWarning, match="No target network updater has been"): + with pytest.warns( + UserWarning, match="No target network updater has been" + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -13884,7 +14396,7 @@ def test_discrete_iql_batcher( np.random.seed(0) with _check_td_steady(ms_td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() @@ -14054,7 +14566,7 @@ def test_discrete_iql_notensordict( with pytest.warns( UserWarning, match="No target network updater has been associated with this loss module", - ): + ) if rl_warnings() else contextlib.nullcontext(): loss_val = loss(**kwargs) loss_val_td = loss(td) assert len(loss_val) == 4 @@ -14102,7 +14614,7 @@ def test_discrete_iql_reduction(self, reduction): loss_fn.make_value_estimator() with _check_td_steady(td), pytest.warns( UserWarning, match="No target network updater" - ): + ) if rl_warnings() else contextlib.nullcontext(): loss = loss_fn(td) if reduction == "none": for key in loss.keys(): @@ -14328,7 +14840,9 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: module = custom_module(delay_module=True) _ = module.module1_params - with pytest.warns(UserWarning, match="No target network updater has been"): + with pytest.warns( + UserWarning, match="No target network updater has been" + ) if rl_warnings() else contextlib.nullcontext(): _ = module.target_module1_params if mode == "hard": upd = HardUpdate( @@ -17151,7 +17665,7 @@ def fun(a, b, time_dim=-2): def test_updater_warning(updater, kwarg): with warnings.catch_warnings(): dqn = DQNLoss(torch.nn.Linear(3, 4), delay_value=True, action_space="one_hot") - with pytest.warns(UserWarning): + with pytest.warns(UserWarning) if rl_warnings() else contextlib.nullcontext(): dqn.target_value_network_params with warnings.catch_warnings(): updater(dqn, **kwarg) diff --git a/test/test_rb.py b/test/test_rb.py index bc6acaeb3be..15b9b9af0e5 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -1804,9 +1804,9 @@ def test_batch_errors(): @pytest.mark.skipif(not torchrl._utils.RL_WARNINGS, reason="RL_WARNINGS is not 
set") def test_add_warning(): - from torchrl._utils import RL_WARNINGS + from torchrl._utils import rl_warnings - if not RL_WARNINGS: + if not rl_warnings(): return rb = ReplayBuffer(storage=ListStorage(10), batch_size=3) with pytest.warns( diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 50cfa8af7d3..f090831d25c 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -1062,3 +1062,8 @@ def merge_ray_runtime_env(ray_init_config: dict[str, Any]) -> dict[str, Any]: runtime_env["env_vars"] = dict(runtime_env["env_vars"]) return ray_init_config + + +def rl_warnings(): + """Checks the status of the RL_WARNINGS env varioble.""" + return RL_WARNINGS diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 3686368ae71..b7be73d243f 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -47,7 +47,7 @@ compile_with_warmup, logger as torchrl_logger, prod, - RL_WARNINGS, + rl_warnings, VERBOSE, ) from torchrl.collectors.utils import split_trajectories @@ -1218,7 +1218,7 @@ def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None: total_frames = float("inf") else: remainder = total_frames % frames_per_batch - if remainder != 0 and RL_WARNINGS: + if remainder != 0 and rl_warnings(): warnings.warn( f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " f"This means {frames_per_batch - remainder} additional frames will be collected." @@ -1238,7 +1238,7 @@ def _setup_init_random_frames( if ( init_random_frames not in (-1, None, 0) and init_random_frames % frames_per_batch != 0 - and RL_WARNINGS + and rl_warnings() ): warnings.warn( f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), " @@ -1261,7 +1261,7 @@ def _setup_postproc(self, postproc: Callable | None) -> None: def _setup_frames_per_batch(self, frames_per_batch: int) -> None: """Calculate and validate frames per batch.""" - if frames_per_batch % self.n_env != 0 and RL_WARNINGS: + if frames_per_batch % self.n_env != 0 and rl_warnings(): warnings.warn( f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), " f" this results in more frames_per_batch per iteration that requested" @@ -2809,7 +2809,7 @@ def _setup_multi_total_frames( total_frames = float("inf") else: remainder = total_frames % total_frames_per_batch - if remainder != 0 and RL_WARNINGS: + if remainder != 0 and rl_warnings(): warnings.warn( f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). " f"This means {total_frames_per_batch - remainder} additional frames will be collected. " @@ -3741,7 +3741,7 @@ def update_policy_weights_( def frames_per_batch_worker(self, worker_idx: int | None) -> int: if worker_idx is not None and isinstance(self._frames_per_batch, Sequence): return self._frames_per_batch[worker_idx] - if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS: + if self.requested_frames_per_batch % self.num_workers != 0 and rl_warnings(): warnings.warn( f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers}," f" this results in more frames_per_batch per iteration that requested." 
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 179afdee8b2..da090c9fc94 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -41,7 +41,7 @@ from torch import Tensor from torch.utils._pytree import tree_map -from torchrl._utils import accept_remote_rref_udf_invocation, RL_WARNINGS +from torchrl._utils import accept_remote_rref_udf_invocation, rl_warnings from torchrl.data.replay_buffers.samplers import ( PrioritizedSampler, RandomSampler, @@ -871,7 +871,7 @@ def add(self, data: Any) -> int: data = None if data is None: return torch.zeros((0, self._storage.ndim), dtype=torch.long) - if RL_WARNINGS and is_tensor_collection(data) and data.ndim: + if rl_warnings() and is_tensor_collection(data) and data.ndim: warnings.warn( f"Using `add()` with a TensorDict that has batch_size={data.batch_size}. " f"Use `extend()` to add multiple elements, or `add()` with a single element (batch_size=torch.Size([])). " @@ -1319,14 +1319,14 @@ class PrioritizedReplayBuffer(ReplayBuffer): >>> # get the info to find what the indices are >>> sample, info = rb.sample(5, return_info=True) >>> print(sample, info) - tensor([2, 7, 4, 3, 5]) {'_weight': array([1., 1., 1., 1., 1.], dtype=float32), 'index': array([2, 7, 4, 3, 5])} + tensor([2, 7, 4, 3, 5]) {'priority_weight': array([1., 1., 1., 1., 1.], dtype=float32), 'index': array([2, 7, 4, 3, 5])} >>> # update priority >>> priority = torch.ones(5) * 5 >>> rb.update_priority(info["index"], priority) >>> # and now a new sample, the weights should be updated >>> sample, info = rb.sample(5, return_info=True) >>> print(sample, info) - tensor([2, 5, 2, 2, 5]) {'_weight': array([0.36278465, 0.36278465, 0.36278465, 0.36278465, 0.36278465], + tensor([2, 5, 2, 2, 5]) {'priority_weight': array([0.36278465, 0.36278465, 0.36278465, 0.36278465, 0.36278465], dtype=float32), 'index': array([2, 5, 2, 2, 5])} """ @@ -1861,7 +1861,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer): >>> print(sample) TensorDict( fields={ - _weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False), + priority_weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False), a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ @@ -1884,7 +1884,7 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer): >>> print(sample) TensorDict( fields={ - _weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False), + priority_weight: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False), a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False), b: TensorDict( fields={ diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index d15bc9638e4..70f9f2f4e7f 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -21,7 +21,7 @@ from tensordict.utils import NestedKey from torch.utils._pytree import tree_map from torchrl._extension import EXTENSION_WARNING -from torchrl._utils import _replace_last, logger, RL_WARNINGS +from torchrl._utils import _replace_last, logger, rl_warnings from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage from torchrl.data.replay_buffers.utils import _auto_device, _is_int, unravel_index @@ -373,7 +373,7 @@ class 
PrioritizedSampler(Sampler): device=cpu, is_shared=False) >>> print(info) - {'_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, + {'priority_weight': array([1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11, 1.e-11], dtype=float32), 'index': array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])} .. note:: Using a :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` can smoothen the @@ -423,7 +423,7 @@ def __init__( self.dtype = dtype self._max_priority_within_buffer = max_priority_within_buffer self._init() - if RL_WARNINGS and SumSegmentTreeFp32 is None: + if rl_warnings() and SumSegmentTreeFp32 is None: logger.warning(EXTENSION_WARNING) def __repr__(self): @@ -588,7 +588,7 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor: weight = torch.pow(weight / p_min, -self._beta) if storage.ndim > 1: index = unravel_index(index, storage.shape) - return index, {"_weight": weight} + return index, {"priority_weight": weight} def add(self, index: torch.Tensor | int) -> None: super().add(index) @@ -2068,7 +2068,7 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler): episode [2, 2, 2, 2, 1, 1] >>> print("steps", sample["steps"].tolist()) steps [1, 2, 0, 1, 1, 2] - >>> print("weight", info["_weight"].tolist()) + >>> print("weight", info["priority_weight"].tolist()) weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] >>> priority = torch.tensor([0,3,3,0,0,0,1,1,1]) >>> rb.update_priority(torch.arange(0,9,1), priority=priority) @@ -2077,7 +2077,7 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler): episode [2, 2, 2, 2, 2, 2] >>> print("steps", sample["steps"].tolist()) steps [1, 2, 0, 1, 0, 1] - >>> print("weight", info["_weight"].tolist()) + >>> print("weight", info["priority_weight"].tolist()) weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06] """ @@ -2294,7 +2294,9 @@ def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict] if isinstance(starts, tuple): starts = torch.stack(starts, -1) # starts = torch.as_tensor(starts, device=lengths.device) - info["_weight"] = torch.as_tensor(info["_weight"], device=lengths.device) + info["priority_weight"] = torch.as_tensor( + info["priority_weight"], device=lengths.device + ) # extends starting indices of each slice with sequence_length to get indices of all steps index = self._tensor_slices_from_startend( @@ -2302,7 +2304,9 @@ def sample(self, storage: Storage, batch_size: int) -> tuple[torch.Tensor, dict] ) # repeat the weight of each slice to match the number of steps - info["_weight"] = torch.repeat_interleave(info["_weight"], seq_length) + info["priority_weight"] = torch.repeat_interleave( + info["priority_weight"], seq_length + ) if self.truncated_key is not None: # following logics borrowed from SliceSampler diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 1f6daabced1..0e1d25c689f 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -19,7 +19,7 @@ from torch import nn from torch.nn import Parameter -from torchrl._utils import RL_WARNINGS +from torchrl._utils import rl_warnings from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules.tensordict_module.rnn import set_recurrent_mode from torchrl.objectives.utils import ValueEstimators @@ -34,7 +34,7 @@ def _updater_check_forward_prehook(module, *args, **kwargs): if ( not all(module._has_update_associated.values()) - and 
RL_WARNINGS + and rl_warnings() and not is_compiling() ): warnings.warn( @@ -128,6 +128,7 @@ class _AcceptedKeys: tensor_keys: _AcceptedKeys _vmap_randomness = None default_value_estimator: ValueEstimators = None + use_prioritized_weights: str | bool = "auto" deterministic_sampling_mode: ExplorationType = ExplorationType.DETERMINISTIC @@ -449,7 +450,7 @@ def __getattr__(self, item): params = params.data elif ( not self._has_update_associated[item[7:-7]] - and RL_WARNINGS + and rl_warnings() and not is_compiling() ): # no updater associated diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index e22cf565f72..b5d2c67ada6 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -171,6 +171,7 @@ class _AcceptedKeys: reward: NestedKey = "reward" done: NestedKey = "done" terminated: NestedKey = "terminated" + priority_weight: NestedKey = "priority_weight" tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys @@ -202,11 +203,13 @@ def __init__( gamma: float | None = None, separate_losses: bool = False, reduction: str | None = None, + use_prioritized_weights: str | bool = "auto", ) -> None: self._in_keys = None if reduction is None: reduction = "mean" super().__init__() + self.use_prioritized_weights = use_prioritized_weights self.delay_actor = delay_actor self.delay_value = delay_value @@ -268,6 +271,8 @@ def _set_in_keys(self): *self.value_network.in_keys, *[unravel_key(("next", key)) for key in self.value_network.in_keys], } + if self.use_prioritized_weights: + in_keys.add(unravel_key(self.tensor_keys.priority_weight)) self._in_keys = sorted(in_keys, key=str) @property @@ -295,8 +300,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: a tuple of 2 tensors containing the DDPG loss. """ - loss_value, metadata = self.loss_value(tensordict) - loss_actor, metadata_actor = self.loss_actor(tensordict) + # Extract weights for prioritized replay buffer + weights = None + if ( + self.use_prioritized_weights in (True, "auto") + and self.tensor_keys.priority_weight in tensordict.keys() + ): + weights = tensordict.get(self.tensor_keys.priority_weight) + loss_value, metadata = self.loss_value(tensordict, weights=weights) + loss_actor, metadata_actor = self.loss_actor(tensordict, weights=weights) metadata.update(metadata_actor) td_out = TensorDict( source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata}, @@ -315,6 +327,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: def loss_actor( self, tensordict: TensorDictBase, + weights: torch.Tensor | None = None, ) -> [torch.Tensor, dict]: td_copy = tensordict.select( *self.actor_in_keys, *self.value_exclusive_keys, strict=False @@ -325,7 +338,7 @@ def loss_actor( td_copy = self.value_network(td_copy) loss_actor = -td_copy.get(self.tensor_keys.state_action_value).squeeze(-1) metadata = {} - loss_actor = _reduce(loss_actor, self.reduction) + loss_actor = _reduce(loss_actor, self.reduction, weights=weights) self._clear_weakrefs( tensordict, loss_actor, @@ -339,6 +352,7 @@ def loss_actor( def loss_value( self, tensordict: TensorDictBase, + weights: torch.Tensor | None = None, ) -> tuple[torch.Tensor, dict]: # value loss td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach() @@ -372,7 +386,7 @@ def loss_value( "target_value_max": target_value.max(), "pred_value_max": pred_val.max(), } - loss_value = _reduce(loss_value, self.reduction) + loss_value = _reduce(loss_value, self.reduction, weights=weights) self._clear_weakrefs( tensordict, "value_network_params", 
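The DDPG changes above illustrate the intended data flow: the prioritized sampler now returns its importance-sampling weights under the "priority_weight" key, and a loss module constructed with use_prioritized_weights="auto" (the default) consumes that key whenever it is present in the input tensordict. A minimal sketch of the buffer side, assuming the patched torchrl (observation shape and buffer sizes are illustrative only):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer

# Prioritized buffer mirroring the configuration used in the tests above.
rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.9,
    storage=LazyTensorStorage(100),
    batch_size=32,
    priority_key="td_error",
)
rb.extend(TensorDict({"observation": torch.randn(100, 4)}, batch_size=[100]))

sample = rb.sample()
# With this patch the sampler writes its importance-sampling weights here;
# losses with use_prioritized_weights="auto" pick them up automatically.
weights = sample["priority_weight"]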
diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index aeccf527108..5f91804825a 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -160,6 +160,7 @@ class _AcceptedKeys: reward: NestedKey = "reward" done: NestedKey = "done" terminated: NestedKey = "terminated" + priority_weight: NestedKey = "priority_weight" tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys @@ -181,10 +182,12 @@ def __init__( action_space: str | TensorSpec = None, priority_key: str | None = None, reduction: str | None = None, + use_prioritized_weights: str | bool = "auto", ) -> None: if reduction is None: reduction = "mean" super().__init__() + self.use_prioritized_weights = use_prioritized_weights self._in_keys = None if double_dqn and not delay_value: raise ValueError("double_dqn=True requires delay_value=True.") @@ -364,7 +367,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: inplace=True, ) loss = distance_loss(pred_val_index, target_value, self.loss_function) - loss = _reduce(loss, reduction=self.reduction) + # Extract weights for prioritized replay buffer + weights = None + if ( + self.use_prioritized_weights in (True, "auto") + and self.tensor_keys.priority_weight in tensordict.keys() + ): + weights = tensordict.get(self.tensor_keys.priority_weight) + loss = _reduce(loss, reduction=self.reduction, weights=weights) td_out = TensorDict(loss=loss) self._clear_weakrefs( @@ -440,6 +450,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" steps_to_next_obs: NestedKey = "steps_to_next_obs" + priority_weight: NestedKey = "priority_weight" tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys @@ -457,10 +468,12 @@ def __init__( delay_value: bool = True, priority_key: str | None = None, reduction: str | None = None, + use_prioritized_weights: str | bool = "auto", ): if reduction is None: reduction = "mean" super().__init__() + self.use_prioritized_weights = use_prioritized_weights self._set_deprecated_ctor_keys(priority=priority_key) self.register_buffer("gamma", torch.tensor(gamma)) self.delay_value = delay_value @@ -611,7 +624,14 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: loss.detach().unsqueeze(1).to(input_tensordict.device), inplace=True, ) - loss = _reduce(loss, reduction=self.reduction) + # Extract weights for prioritized replay buffer + weights = None + if ( + self.use_prioritized_weights in (True, "auto") + and self.tensor_keys.priority_weight in tensordict.keys() + ): + weights = tensordict.get(self.tensor_keys.priority_weight) + loss = _reduce(loss, reduction=self.reduction, weights=weights) td_out = TensorDict(loss=loss) self._clear_weakrefs( tensordict, diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index a3f9fe87560..715c8d22e4e 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -292,6 +292,7 @@ class _AcceptedKeys: reward: NestedKey = "reward" done: NestedKey = "done" terminated: NestedKey = "terminated" + priority_weight: NestedKey = "priority_weight" def __post_init__(self): if self.log_prob is None: @@ -337,12 +338,14 @@ def __init__( reduction: str | None = None, skip_done_states: bool = False, deactivate_vmap: bool = False, + use_prioritized_weights: str | bool = "auto", ) -> None: self._in_keys = None self._out_keys = None if reduction is None: reduction = "mean" super().__init__() + self.use_prioritized_weights = use_prioritized_weights self._set_deprecated_ctor_keys(priority_key=priority_key) # Actor @@ -638,9 +641,18 @@ def 
forward(self, tensordict: TensorDictBase) -> TensorDictBase: } if self._version == 1: out["loss_value"] = loss_value - td_out = TensorDict(out, []) + td_out = TensorDict(out) + # Extract weights for prioritized replay buffer + weights = None + if ( + self.use_prioritized_weights in (True, "auto") + and self.tensor_keys.priority_weight in tensordict.keys() + ): + weights = tensordict.get(self.tensor_keys.priority_weight) td_out = td_out.named_apply( - lambda name, value: _reduce(value, reduction=self.reduction) + lambda name, value: _reduce( + value, reduction=self.reduction, weights=weights + ) if name.startswith("loss_") else value, ) @@ -1156,6 +1168,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" log_prob: NestedKey = "log_prob" + priority_weight: NestedKey = "priority_weight" tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys @@ -1200,11 +1213,13 @@ def __init__( reduction: str | None = None, skip_done_states: bool = False, deactivate_vmap: bool = False, + use_prioritized_weights: str | bool = "auto", ): if reduction is None: reduction = "mean" self._in_keys = None super().__init__() + self.use_prioritized_weights = use_prioritized_weights self._set_deprecated_ctor_keys(priority_key=priority_key) self.convert_to_functional( @@ -1347,8 +1362,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "entropy": entropy.detach().mean(), } td_out = TensorDict(out, []) + # Extract weights for prioritized replay buffer + weights = None + if ( + self.use_prioritized_weights in (True, "auto") + and self.tensor_keys.priority_weight in tensordict.keys() + ): + weights = tensordict.get(self.tensor_keys.priority_weight) td_out = td_out.named_apply( - lambda name, value: _reduce(value, reduction=self.reduction) + lambda name, value: _reduce( + value, reduction=self.reduction, weights=weights + ) if name.startswith("loss_") else value, ) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index a201fe5a72c..31b38816944 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -202,6 +202,7 @@ class _AcceptedKeys: reward: NestedKey = "reward" done: NestedKey = "done" terminated: NestedKey = "terminated" + priority_weight: NestedKey = "priority_weight" tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys @@ -240,10 +241,12 @@ def __init__( separate_losses: bool = False, reduction: str | None = None, deactivate_vmap: bool = False, + use_prioritized_weights: str | bool = "auto", ) -> None: if reduction is None: reduction = "mean" super().__init__() + self.use_prioritized_weights = use_prioritized_weights self._in_keys = None self._set_deprecated_ctor_keys(priority=priority_key) @@ -378,7 +381,9 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) - def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]: + def actor_loss( + self, tensordict, weights: torch.Tensor | None = None + ) -> tuple[torch.Tensor, dict]: tensordict_actor_grad = tensordict.select( *self.actor_network.in_keys, strict=False ) @@ -401,7 +406,7 @@ def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]: metadata = { "state_action_value_actor": state_action_value_actor.detach(), } - loss_actor = _reduce(loss_actor, reduction=self.reduction) + loss_actor = _reduce(loss_actor, reduction=self.reduction, weights=weights) self._clear_weakrefs( tensordict, "actor_network_params", @@ -411,7 +416,9 @@ def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]: ) return loss_actor, 
metadata - def value_loss(self, tensordict) -> tuple[torch.Tensor, dict]: + def value_loss( + self, tensordict, weights: torch.Tensor | None = None + ) -> tuple[torch.Tensor, dict]: tensordict = tensordict.clone(False) act = tensordict.get(self.tensor_keys.action) @@ -485,7 +492,7 @@ def value_loss(self, tensordict) -> tuple[torch.Tensor, dict]: "pred_value": current_qvalue.detach(), "target_value": target_value.detach(), } - loss_qval = _reduce(loss_qval, reduction=self.reduction) + loss_qval = _reduce(loss_qval, reduction=self.reduction, weights=weights) self._clear_weakrefs( tensordict, "actor_network_params", @@ -498,8 +505,15 @@ def value_loss(self, tensordict) -> tuple[torch.Tensor, dict]: @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_save = tensordict - loss_actor, metadata_actor = self.actor_loss(tensordict) - loss_qval, metadata_value = self.value_loss(tensordict_save) + # Extract weights for prioritized replay buffer + weights = None + if ( + self.use_prioritized_weights in (True, "auto") + and self.tensor_keys.priority_weight in tensordict.keys() + ): + weights = tensordict.get(self.tensor_keys.priority_weight) + loss_actor, metadata_actor = self.actor_loss(tensordict, weights=weights) + loss_qval, metadata_value = self.value_loss(tensordict_save, weights=weights) tensordict_save.set( self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0] ) diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index 796801c75d9..72095842f9d 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -215,6 +215,7 @@ class _AcceptedKeys: reward: NestedKey = "reward" done: NestedKey = "done" terminated: NestedKey = "terminated" + priority_weight: NestedKey = "priority_weight" tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys @@ -255,10 +256,12 @@ def __init__( separate_losses: bool = False, reduction: str | None = None, deactivate_vmap: bool = False, + use_prioritized_weights: str | bool = "auto", ) -> None: if reduction is None: reduction = "mean" super().__init__() + self.use_prioritized_weights = use_prioritized_weights self._in_keys = None self._set_deprecated_ctor_keys(priority=priority_key) @@ -392,7 +395,9 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) - def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]: + def actor_loss( + self, tensordict, weights: torch.Tensor | None = None + ) -> tuple[torch.Tensor, dict]: """Compute the actor loss. The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed 1-3 critic updates. @@ -400,6 +405,7 @@ def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]: Args: tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields are required for this to be computed. + weights (torch.Tensor, optional): importance sampling weights for weighted reduction. Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"bc_loss"` used in the combined actor loss as well as the detached `"state_action_value_actor"` used to calculate the lambda value, and the lambda value `"lmbd"` itself. 
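The weights argument threaded through actor_loss and qvalue_loss is ultimately consumed by _reduce (see the torchrl/objectives/utils.py hunk further down), where reduction="mean" becomes the weighted average (tensor * weights).sum() / weights.sum(). A small self-contained numeric check of that formula, plain PyTorch with made-up numbers:

import torch

loss_elements = torch.tensor([1.0, 2.0, 3.0, 4.0])
weights = torch.tensor([0.1, 0.1, 0.4, 0.4])  # importance-sampling weights

# Weighted "mean" reduction as implemented in _reduce in this patch.
weighted_mean = (loss_elements * weights).sum() / weights.sum()
plain_mean = loss_elements.mean()

print(weighted_mean.item())  # 3.1, the high-priority samples dominate
print(plain_mean.item())     # 2.5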
@@ -436,7 +442,7 @@ def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]: "bc_loss": bc_loss.detach(), "lmbd": lmbd, } - loss_actor = _reduce(loss_actor, reduction=self.reduction) + loss_actor = _reduce(loss_actor, reduction=self.reduction, weights=weights) self._clear_weakrefs( tensordict, "actor_network_params", @@ -446,7 +452,9 @@ def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]: ) return loss_actor, metadata - def qvalue_loss(self, tensordict) -> tuple[torch.Tensor, dict]: + def qvalue_loss( + self, tensordict, weights: torch.Tensor | None = None + ) -> tuple[torch.Tensor, dict]: """Compute the q-value loss. The q-value loss should be computed before the :meth:`~.actor_loss`. @@ -454,6 +462,7 @@ def qvalue_loss(self, tensordict) -> tuple[torch.Tensor, dict]: Args: tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields are required for this to be computed. + weights (torch.Tensor, optional): importance sampling weights for weighted reduction. Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing the detached `"td_error"` to be used for prioritized sampling, the detached `"next_state_value"`, the detached `"pred_value"`, and the detached `"target_value"`. """ @@ -530,7 +539,7 @@ def qvalue_loss(self, tensordict) -> tuple[torch.Tensor, dict]: "pred_value": current_qvalue.detach(), "target_value": target_value.detach(), } - loss_qval = _reduce(loss_qval, reduction=self.reduction) + loss_qval = _reduce(loss_qval, reduction=self.reduction, weights=weights) self._clear_weakrefs( tensordict, "actor_network_params", @@ -550,8 +559,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: class's `"in_keys"` and `"out_keys"` attributes. """ tensordict_save = tensordict - loss_actor, metadata_actor = self.actor_loss(tensordict) - loss_qval, metadata_value = self.qvalue_loss(tensordict_save) + # Extract weights for prioritized replay buffer + weights = None + if ( + self.use_prioritized_weights in (True, "auto") + and self.tensor_keys.priority_weight in tensordict.keys() + ): + weights = tensordict.get(self.tensor_keys.priority_weight) + loss_actor, metadata_actor = self.actor_loss(tensordict, weights=weights) + loss_qval, metadata_value = self.qvalue_loss(tensordict_save, weights=weights) tensordict_save.set( self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0] ) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 9effa91362c..b97ef5ea6a9 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -617,7 +617,10 @@ def new_func(*args, in_dims=in_dims, out_dims=out_dims, **kwargs): def _reduce( - tensor: torch.Tensor, reduction: str, mask: torch.Tensor | None = None + tensor: torch.Tensor, + reduction: str, + mask: torch.Tensor | None = None, + weights: torch.Tensor | None = None, ) -> float | torch.Tensor: """Reduces a tensor given the reduction method. @@ -625,19 +628,56 @@ def _reduce( tensor (torch.Tensor): The tensor to reduce. reduction (str): The reduction method. mask (torch.Tensor, optional): A mask to apply to the tensor before reducing. + weights (torch.Tensor, optional): Importance sampling weights for weighted reduction. + When provided with reduction="mean", computes: (tensor * weights).sum() / weights.sum() + When provided with reduction="sum", computes: (tensor * weights).sum() + This is used for proper bias correction with prioritized replay buffers. 
Returns: float | torch.Tensor: The reduced tensor. """ if reduction == "none": - result = tensor + if weights is None: + result = tensor + if mask is not None: + result = result[mask] + elif mask is not None: + masked_weight = weights[mask] + masked_tensor = tensor[mask] + result = masked_tensor * masked_weight + else: + result = tensor * weights elif reduction == "mean": - if mask is not None: + if weights is not None: + # Weighted average: (tensor * weights).sum() / weights.sum() + if mask is not None: + masked_weight = weights[mask] + masked_tensor = tensor[mask] + result = (masked_tensor * masked_weight).sum() / masked_weight.sum() + else: + if tensor.shape != weights.shape: + raise ValueError( + f"Tensor and weights shapes must match, but got {tensor.shape} and {weights.shape}" + ) + result = (tensor * weights).sum() / weights.sum() + elif mask is not None: result = tensor[mask].mean() else: result = tensor.mean() elif reduction == "sum": - if mask is not None: + if weights is not None: + # Weighted sum: (tensor * weights).sum() + if mask is not None: + masked_weight = weights[mask] + masked_tensor = tensor[mask] + result = (masked_tensor * masked_weight).sum() + else: + if tensor.shape != weights.shape: + raise ValueError( + f"Tensor and weights shapes must match, but got {tensor.shape} and {weights.shape}" + ) + result = (tensor * weights).sum() + elif mask is not None: result = tensor[mask].sum() else: result = tensor.sum() diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index abd961631cd..9f010a782b8 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -27,7 +27,7 @@ from tensordict.utils import NestedKey, unravel_key from torch import Tensor -from torchrl._utils import logger, RL_WARNINGS +from torchrl._utils import logger, rl_warnings from torchrl.envs.utils import step_mdp from torchrl.objectives.utils import ( _maybe_get_or_select, @@ -451,7 +451,7 @@ def _call_value_nets( try: ndim = list(data.names).index("time") + 1 except ValueError: - if RL_WARNINGS: + if rl_warnings(): logger.warning( "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " "This warning can be turned off by setting the environment variable RL_WARNINGS to False." diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 182e510145a..25f3ffa6357 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -28,7 +28,7 @@ _CKPT_BACKEND, KeyDependentDefaultDict, logger as torchrl_logger, - RL_WARNINGS, + rl_warnings, timeit, VERBOSE, ) @@ -2043,7 +2043,7 @@ def __call__(self, batch: TensorDictBase | None = None) -> dict: batch_size = self.trainer.collector.getattr_rb("batch_size") if not write_count: return {} - if batch_size is None and RL_WARNINGS: + if batch_size is None and rl_warnings(): warnings.warn("Batch size is not set. 
Using 1.") batch_size = 1 update_count = self.trainer._optim_count From 7db3fa9c7ef7d911c2f131ec8c4e623166156301 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 15 Nov 2025 20:53:19 +0000 Subject: [PATCH 2/3] modular-sac --- torchrl/objectives/sac.py | 215 +++++++++----------------------------- 1 file changed, 50 insertions(+), 165 deletions(-) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 715c8d22e4e..b0df9615a19 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -616,14 +616,27 @@ def out_keys(self, values): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + # Extract weights for prioritized replay buffer + weights = None + if ( + self.use_prioritized_weights in (True, "auto") + and self.tensor_keys.priority_weight in tensordict.keys() + ): + weights = tensordict.get(self.tensor_keys.priority_weight) + if self._version == 1: - loss_qvalue, value_metadata = self._qvalue_v1_loss(tensordict) - loss_value, _ = self._value_loss(tensordict) + loss_qvalue, value_metadata = self.qvalue_v1_loss( + tensordict, weights=weights + ) + loss_value, _ = self.value_loss(tensordict, weights=weights) else: - loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict) + loss_qvalue, value_metadata = self.qvalue_v2_loss( + tensordict, weights=weights + ) loss_value = None - loss_actor, metadata_actor = self._actor_loss(tensordict) + loss_actor, metadata_actor = self.actor_loss(tensordict, weights=weights) loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"]) + loss_alpha = _reduce(loss_alpha, reduction=self.reduction, weights=weights) tensordict.set(self.tensor_keys.priority, value_metadata["td_error"]) if (loss_actor.shape != loss_qvalue.shape) or ( loss_value is not None and loss_actor.shape != loss_value.shape @@ -642,20 +655,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self._version == 1: out["loss_value"] = loss_value td_out = TensorDict(out) - # Extract weights for prioritized replay buffer - weights = None - if ( - self.use_prioritized_weights in (True, "auto") - and self.tensor_keys.priority_weight in tensordict.keys() - ): - weights = tensordict.get(self.tensor_keys.priority_weight) - td_out = td_out.named_apply( - lambda name, value: _reduce( - value, reduction=self.reduction, weights=weights - ) - if name.startswith("loss_") - else value, - ) self._clear_weakrefs( tensordict, td_out, @@ -673,8 +672,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _cached_detached_qvalue_params(self): return self.qvalue_network_params.detach() - def _actor_loss( - self, tensordict: TensorDictBase + def actor_loss( + self, tensordict: TensorDictBase, weights: torch.Tensor | None = None ) -> tuple[Tensor, dict[str, Tensor]]: with set_exploration_type( ExplorationType.RANDOM @@ -697,81 +696,9 @@ def _actor_loss( raise RuntimeError( f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}" ) - return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()} - - @dispatch - def actor_loss( - self, tensordict: TensorDictBase - ) -> tuple[Tensor, dict[str, Tensor]]: - """Compute the actor loss for SAC. - - This method computes the actor loss which encourages the policy to maximize - the expected Q-value while maintaining high entropy. - - Args: - tensordict (TensorDictBase): A tensordict containing the data needed for - computing the actor loss. Should contain the observation and other - required keys for the actor network. 
- - Returns: - A tuple containing: - - The actor loss tensor - - A dictionary with metadata including the log probability of actions - """ - return self._actor_loss(tensordict) - - @dispatch - def qvalue_loss( - self, tensordict: TensorDictBase - ) -> tuple[Tensor, dict[str, Tensor]]: - """Compute the Q-value loss for SAC. - - This method computes the Q-value loss which trains the Q-networks to estimate - the expected return for state-action pairs. - - Args: - tensordict (TensorDictBase): A tensordict containing the data needed for - computing the Q-value loss. Should contain the observation, action, - reward, done, and terminated keys. - - Returns: - A tuple containing: - - The Q-value loss tensor - - A dictionary with metadata including the TD error - """ - if self._version == 1: - return self._qvalue_v1_loss(tensordict) - else: - return self._qvalue_v2_loss(tensordict) - - @dispatch - def value_loss( - self, tensordict: TensorDictBase - ) -> tuple[Tensor, dict[str, Tensor]]: - """Compute the value loss for SAC (version 1 only). - - This method computes the value loss which trains the value network to estimate - the expected return for states. This is only used in SAC version 1. - - Args: - tensordict (TensorDictBase): A tensordict containing the data needed for - computing the value loss. Should contain the observation and other - required keys for the value network. - - Returns: - A tuple containing: - - The value loss tensor - - An empty dictionary (no metadata for value loss) - - Raises: - RuntimeError: If called on SAC version 2 (which doesn't use a value network) - """ - if self._version != 1: - raise RuntimeError( - "Value loss is only available in SAC version 1. " - "SAC version 2 doesn't use a separate value network." - ) - return self._value_loss(tensordict) + loss_actor = self._alpha * log_prob - min_q_logprob + loss_actor = _reduce(loss_actor, reduction=self.reduction, weights=weights) + return loss_actor, {"log_prob": log_prob.detach()} def alpha_loss(self, log_prob: Tensor) -> Tensor: """Compute the alpha loss for SAC. @@ -808,8 +735,8 @@ def _cached_target_params_actor_value(self): torch.Size([]), ) - def _qvalue_v1_loss( - self, tensordict: TensorDictBase + def qvalue_v1_loss( + self, tensordict: TensorDictBase, weights: torch.Tensor | None = None ) -> tuple[Tensor, dict[str, Tensor]]: target_params = self._cached_target_params_actor_value with set_exploration_type(self.deterministic_sampling_mode): @@ -841,6 +768,7 @@ def _qvalue_v1_loss( loss_value = distance_loss( pred_val, target_chunks, loss_function=self.loss_function ).view(*shape) + loss_value = _reduce(loss_value, reduction=self.reduction, weights=weights) metadata = {"td_error": (pred_val - target_chunks).pow(2).flatten(0, 1)} return loss_value, metadata @@ -923,8 +851,8 @@ def _compute_target_v2(self, tensordict) -> Tensor: target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) return target_value - def _qvalue_v2_loss( - self, tensordict: TensorDictBase + def qvalue_v2_loss( + self, tensordict: TensorDictBase, weights: torch.Tensor | None = None ) -> tuple[Tensor, dict[str, Tensor]]: # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first. 
target_value = self._compute_target_v2(tensordict) @@ -942,11 +870,12 @@ def _qvalue_v2_loss( target_value.expand_as(pred_val), loss_function=self.loss_function, ).sum(0) + loss_qval = _reduce(loss_qval, reduction=self.reduction, weights=weights) metadata = {"td_error": td_error.detach().max(0)[0]} return loss_qval, metadata - def _value_loss( - self, tensordict: TensorDictBase + def value_loss( + self, tensordict: TensorDictBase, weights: torch.Tensor | None = None ) -> tuple[Tensor, dict[str, Tensor]]: # value loss td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach() @@ -979,6 +908,7 @@ def _value_loss( loss_value = distance_loss( pred_val, target_val, loss_function=self.loss_function ) + loss_value = _reduce(loss_value, reduction=self.reduction, weights=weights) return loss_value, {} def _alpha_loss(self, log_prob: Tensor) -> Tensor: @@ -1342,40 +1272,35 @@ def in_keys(self, values): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - loss_value, metadata_value = self._value_loss(tensordict) - loss_actor, metadata_actor = self._actor_loss(tensordict) + # Extract weights for prioritized replay buffer + weights = None + if ( + self.use_prioritized_weights in (True, "auto") + and self.tensor_keys.priority_weight in tensordict.keys() + ): + weights = tensordict.get(self.tensor_keys.priority_weight) + + loss_qvalue, metadata_value = self.qvalue_loss(tensordict, weights=weights) + loss_actor, metadata_actor = self.actor_loss(tensordict, weights=weights) loss_alpha = self._alpha_loss( log_prob=metadata_actor["log_prob"], ) + loss_alpha = _reduce(loss_alpha, reduction=self.reduction, weights=weights) tensordict.set(self.tensor_keys.priority, metadata_value["td_error"]) - if loss_actor.shape != loss_value.shape: + if loss_actor.shape != loss_qvalue.shape: raise RuntimeError( - f"Losses shape mismatch: {loss_actor.shape}, and {loss_value.shape}" + f"Losses shape mismatch: {loss_actor.shape}, and {loss_qvalue.shape}" ) entropy = -metadata_actor["log_prob"] out = { "loss_actor": loss_actor, - "loss_qvalue": loss_value, + "loss_qvalue": loss_qvalue, "loss_alpha": loss_alpha, "alpha": self._alpha, "entropy": entropy.detach().mean(), } td_out = TensorDict(out, []) - # Extract weights for prioritized replay buffer - weights = None - if ( - self.use_prioritized_weights in (True, "auto") - and self.tensor_keys.priority_weight in tensordict.keys() - ): - weights = tensordict.get(self.tensor_keys.priority_weight) - td_out = td_out.named_apply( - lambda name, value: _reduce( - value, reduction=self.reduction, weights=weights - ) - if name.startswith("loss_") - else value, - ) self._clear_weakrefs( tensordict, td_out, @@ -1464,50 +1389,8 @@ def _compute_target(self, tensordict) -> Tensor: target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) return target_value - @dispatch - def actor_loss( - self, tensordict: TensorDictBase - ) -> tuple[Tensor, dict[str, Tensor]]: - """Compute the actor loss for discrete SAC. - - This method computes the actor loss which encourages the policy to maximize - the expected Q-value while maintaining high entropy for discrete actions. - - Args: - tensordict (TensorDictBase): A tensordict containing the data needed for - computing the actor loss. Should contain the observation and other - required keys for the actor network. 
- - Returns: - A tuple containing: - - The actor loss tensor - - A dictionary with metadata including the log probability of actions - """ - return self._actor_loss(tensordict) - - @dispatch def qvalue_loss( - self, tensordict: TensorDictBase - ) -> tuple[Tensor, dict[str, Tensor]]: - """Compute the Q-value loss for discrete SAC. - - This method computes the Q-value loss which trains the Q-networks to estimate - the expected return for state-action pairs in discrete action spaces. - - Args: - tensordict (TensorDictBase): A tensordict containing the data needed for - computing the Q-value loss. Should contain the observation, action, - reward, done, and terminated keys. - - Returns: - A tuple containing: - - The Q-value loss tensor - - A dictionary with metadata including the TD error - """ - return self._value_loss(tensordict) - - def _value_loss( - self, tensordict: TensorDictBase + self, tensordict: TensorDictBase, weights: torch.Tensor | None = None ) -> tuple[Tensor, dict[str, Tensor]]: target_value = self._compute_target(tensordict) tensordict_expand = self._vmap_qnetworkN0( @@ -1538,14 +1421,15 @@ def _value_loss( target_value.expand_as(chosen_action_value), loss_function=self.loss_function, ).sum(0) + loss_qval = _reduce(loss_qval, reduction=self.reduction, weights=weights) metadata = { "td_error": td_error.detach().max(0)[0], } return loss_qval, metadata - def _actor_loss( - self, tensordict: TensorDictBase + def actor_loss( + self, tensordict: TensorDictBase, weights: torch.Tensor | None = None ) -> tuple[Tensor, dict[str, Tensor]]: # get probs and log probs for actions with self.actor_network_params.to_module(self.actor_network): @@ -1569,6 +1453,7 @@ def _actor_loss( loss = self._alpha * log_prob - min_q # unlike in continuous SAC, we can compute the exact expectation over all discrete actions loss = (prob * loss).sum(-1) + loss = _reduce(loss, reduction=self.reduction, weights=weights) return loss, {"log_prob": (log_prob * prob).sum(-1).detach()} From b18758e254b15363dc9215d035cd7b809729f312 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 15 Nov 2025 21:01:37 +0000 Subject: [PATCH 3/3] reusable-methods --- torchrl/objectives/common.py | 19 +++++++++++++ torchrl/objectives/ddpg.py | 15 +++-------- torchrl/objectives/sac.py | 52 ++++++++++++++---------------------- torchrl/objectives/td3.py | 21 +++++---------- torchrl/objectives/td3_bc.py | 23 +++++----------- 5 files changed, 55 insertions(+), 75 deletions(-) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 0e1d25c689f..ea50ac54f63 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -492,6 +492,25 @@ def reset(self) -> None: # mainly used for PPO with KL target pass + def _maybe_get_priority_weight( + self, tensordict: TensorDictBase + ) -> torch.Tensor | None: + """Extract priority weights from tensordict if prioritized replay is enabled. + + Args: + tensordict (TensorDictBase): The input tensordict that may contain priority weights. + + Returns: + torch.Tensor | None: The priority weights if available and enabled, None otherwise. 
+ """ + weights = None + if ( + self.use_prioritized_weights in (True, "auto") + and self.tensor_keys.priority_weight in tensordict.keys() + ): + weights = tensordict.get(self.tensor_keys.priority_weight) + return weights + def _reset_module_parameters(self, module_name, module): params_name = f"{module_name}_params" target_name = f"target_{module_name}_params" diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index b5d2c67ada6..119958fde48 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -300,15 +300,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: a tuple of 2 tensors containing the DDPG loss. """ - # Extract weights for prioritized replay buffer - weights = None - if ( - self.use_prioritized_weights in (True, "auto") - and self.tensor_keys.priority_weight in tensordict.keys() - ): - weights = tensordict.get(self.tensor_keys.priority_weight) - loss_value, metadata = self.loss_value(tensordict, weights=weights) - loss_actor, metadata_actor = self.loss_actor(tensordict, weights=weights) + loss_value, metadata = self.loss_value(tensordict) + loss_actor, metadata_actor = self.loss_actor(tensordict) metadata.update(metadata_actor) td_out = TensorDict( source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata}, @@ -327,8 +320,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: def loss_actor( self, tensordict: TensorDictBase, - weights: torch.Tensor | None = None, ) -> [torch.Tensor, dict]: + weights = self._maybe_get_priority_weight(tensordict) td_copy = tensordict.select( *self.actor_in_keys, *self.value_exclusive_keys, strict=False ).detach() @@ -352,8 +345,8 @@ def loss_actor( def loss_value( self, tensordict: TensorDictBase, - weights: torch.Tensor | None = None, ) -> tuple[torch.Tensor, dict]: + weights = self._maybe_get_priority_weight(tensordict) # value loss td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach() with self.value_network_params.to_module(self.value_network): diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index b0df9615a19..ba1530fc6e9 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -616,26 +616,15 @@ def out_keys(self, values): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - # Extract weights for prioritized replay buffer - weights = None - if ( - self.use_prioritized_weights in (True, "auto") - and self.tensor_keys.priority_weight in tensordict.keys() - ): - weights = tensordict.get(self.tensor_keys.priority_weight) - if self._version == 1: - loss_qvalue, value_metadata = self.qvalue_v1_loss( - tensordict, weights=weights - ) - loss_value, _ = self.value_loss(tensordict, weights=weights) + loss_qvalue, value_metadata = self.qvalue_v1_loss(tensordict) + loss_value, _ = self.value_loss(tensordict) else: - loss_qvalue, value_metadata = self.qvalue_v2_loss( - tensordict, weights=weights - ) + loss_qvalue, value_metadata = self.qvalue_v2_loss(tensordict) loss_value = None - loss_actor, metadata_actor = self.actor_loss(tensordict, weights=weights) + loss_actor, metadata_actor = self.actor_loss(tensordict) loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"]) + weights = self._maybe_get_priority_weight(tensordict) loss_alpha = _reduce(loss_alpha, reduction=self.reduction, weights=weights) tensordict.set(self.tensor_keys.priority, value_metadata["td_error"]) if (loss_actor.shape != loss_qvalue.shape) or ( @@ -673,8 +662,9 @@ def _cached_detached_qvalue_params(self): return 
self.qvalue_network_params.detach() def actor_loss( - self, tensordict: TensorDictBase, weights: torch.Tensor | None = None + self, tensordict: TensorDictBase ) -> tuple[Tensor, dict[str, Tensor]]: + weights = self._maybe_get_priority_weight(tensordict) with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): @@ -736,8 +726,9 @@ def _cached_target_params_actor_value(self): ) def qvalue_v1_loss( - self, tensordict: TensorDictBase, weights: torch.Tensor | None = None + self, tensordict: TensorDictBase ) -> tuple[Tensor, dict[str, Tensor]]: + weights = self._maybe_get_priority_weight(tensordict) target_params = self._cached_target_params_actor_value with set_exploration_type(self.deterministic_sampling_mode): target_value = self.value_estimator.value_estimate( @@ -852,8 +843,9 @@ def _compute_target_v2(self, tensordict) -> Tensor: return target_value def qvalue_v2_loss( - self, tensordict: TensorDictBase, weights: torch.Tensor | None = None + self, tensordict: TensorDictBase ) -> tuple[Tensor, dict[str, Tensor]]: + weights = self._maybe_get_priority_weight(tensordict) # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first. target_value = self._compute_target_v2(tensordict) @@ -875,8 +867,9 @@ def qvalue_v2_loss( return loss_qval, metadata def value_loss( - self, tensordict: TensorDictBase, weights: torch.Tensor | None = None + self, tensordict: TensorDictBase ) -> tuple[Tensor, dict[str, Tensor]]: + weights = self._maybe_get_priority_weight(tensordict) # value loss td_copy = tensordict.select(*self.value_network.in_keys, strict=False).detach() with self.value_network_params.to_module(self.value_network): @@ -1272,19 +1265,12 @@ def in_keys(self, values): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - # Extract weights for prioritized replay buffer - weights = None - if ( - self.use_prioritized_weights in (True, "auto") - and self.tensor_keys.priority_weight in tensordict.keys() - ): - weights = tensordict.get(self.tensor_keys.priority_weight) - - loss_qvalue, metadata_value = self.qvalue_loss(tensordict, weights=weights) - loss_actor, metadata_actor = self.actor_loss(tensordict, weights=weights) + loss_qvalue, metadata_value = self.qvalue_loss(tensordict) + loss_actor, metadata_actor = self.actor_loss(tensordict) loss_alpha = self._alpha_loss( log_prob=metadata_actor["log_prob"], ) + weights = self._maybe_get_priority_weight(tensordict) loss_alpha = _reduce(loss_alpha, reduction=self.reduction, weights=weights) tensordict.set(self.tensor_keys.priority, metadata_value["td_error"]) @@ -1390,8 +1376,9 @@ def _compute_target(self, tensordict) -> Tensor: return target_value def qvalue_loss( - self, tensordict: TensorDictBase, weights: torch.Tensor | None = None + self, tensordict: TensorDictBase ) -> tuple[Tensor, dict[str, Tensor]]: + weights = self._maybe_get_priority_weight(tensordict) target_value = self._compute_target(tensordict) tensordict_expand = self._vmap_qnetworkN0( tensordict.select(*self.qvalue_network.in_keys, strict=False), @@ -1429,8 +1416,9 @@ def qvalue_loss( return loss_qval, metadata def actor_loss( - self, tensordict: TensorDictBase, weights: torch.Tensor | None = None + self, tensordict: TensorDictBase ) -> tuple[Tensor, dict[str, Tensor]]: + weights = self._maybe_get_priority_weight(tensordict) # get probs and log probs for actions with self.actor_network_params.to_module(self.actor_network): dist = 
self.actor_network.get_dist(tensordict.clone(False)) diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 31b38816944..eb389198cdd 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -381,9 +381,8 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) - def actor_loss( - self, tensordict, weights: torch.Tensor | None = None - ) -> tuple[torch.Tensor, dict]: + def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]: + weights = self._maybe_get_priority_weight(tensordict) tensordict_actor_grad = tensordict.select( *self.actor_network.in_keys, strict=False ) @@ -416,9 +415,8 @@ def actor_loss( ) return loss_actor, metadata - def value_loss( - self, tensordict, weights: torch.Tensor | None = None - ) -> tuple[torch.Tensor, dict]: + def value_loss(self, tensordict) -> tuple[torch.Tensor, dict]: + weights = self._maybe_get_priority_weight(tensordict) tensordict = tensordict.clone(False) act = tensordict.get(self.tensor_keys.action) @@ -505,15 +503,8 @@ def value_loss( @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_save = tensordict - # Extract weights for prioritized replay buffer - weights = None - if ( - self.use_prioritized_weights in (True, "auto") - and self.tensor_keys.priority_weight in tensordict.keys() - ): - weights = tensordict.get(self.tensor_keys.priority_weight) - loss_actor, metadata_actor = self.actor_loss(tensordict, weights=weights) - loss_qval, metadata_value = self.value_loss(tensordict_save, weights=weights) + loss_actor, metadata_actor = self.actor_loss(tensordict) + loss_qval, metadata_value = self.value_loss(tensordict_save) tensordict_save.set( self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0] ) diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index 72095842f9d..2f9613069a8 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -395,9 +395,7 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) - def actor_loss( - self, tensordict, weights: torch.Tensor | None = None - ) -> tuple[torch.Tensor, dict]: + def actor_loss(self, tensordict) -> tuple[torch.Tensor, dict]: """Compute the actor loss. The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed 1-3 critic updates. @@ -405,11 +403,11 @@ def actor_loss( Args: tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields are required for this to be computed. - weights (torch.Tensor, optional): importance sampling weights for weighted reduction. Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"bc_loss"` used in the combined actor loss as well as the detached `"state_action_value_actor"` used to calculate the lambda value, and the lambda value `"lmbd"` itself. """ + weights = self._maybe_get_priority_weight(tensordict) tensordict_actor_grad = tensordict.select( *self.actor_network.in_keys, strict=False ) @@ -452,9 +450,7 @@ def actor_loss( ) return loss_actor, metadata - def qvalue_loss( - self, tensordict, weights: torch.Tensor | None = None - ) -> tuple[torch.Tensor, dict]: + def qvalue_loss(self, tensordict) -> tuple[torch.Tensor, dict]: """Compute the q-value loss. The q-value loss should be computed before the :meth:`~.actor_loss`. 
@@ -462,10 +458,10 @@ def qvalue_loss( Args: tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields are required for this to be computed. - weights (torch.Tensor, optional): importance sampling weights for weighted reduction. Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing the detached `"td_error"` to be used for prioritized sampling, the detached `"next_state_value"`, the detached `"pred_value"`, and the detached `"target_value"`. """ + weights = self._maybe_get_priority_weight(tensordict) tensordict = tensordict.clone(False) act = tensordict.get(self.tensor_keys.action) @@ -559,15 +555,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: class's `"in_keys"` and `"out_keys"` attributes. """ tensordict_save = tensordict - # Extract weights for prioritized replay buffer - weights = None - if ( - self.use_prioritized_weights in (True, "auto") - and self.tensor_keys.priority_weight in tensordict.keys() - ): - weights = tensordict.get(self.tensor_keys.priority_weight) - loss_actor, metadata_actor = self.actor_loss(tensordict, weights=weights) - loss_qval, metadata_value = self.qvalue_loss(tensordict_save, weights=weights) + loss_actor, metadata_actor = self.actor_loss(tensordict) + loss_qval, metadata_value = self.qvalue_loss(tensordict_save) tensordict_save.set( self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0] )
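
The `weights` argument added to `_reduce` in the torchrl/objectives/utils.py hunk above turns "mean" into a normalized importance-weighted average, "sum" into a plain weighted sum, and leaves "none" as an element-wise product. The sketch below mirrors that documented behaviour in isolation; `weighted_reduce` is an illustrative stand-in written for this note (not the library helper), and the mask branch is omitted.

import torch

def weighted_reduce(
    tensor: torch.Tensor, reduction: str, weights: torch.Tensor | None = None
) -> torch.Tensor:
    # Unweighted path: behave like the pre-existing reductions.
    if weights is None:
        return {"none": tensor, "mean": tensor.mean(), "sum": tensor.sum()}[reduction]
    if tensor.shape != weights.shape:
        raise ValueError(
            f"Tensor and weights shapes must match, got {tensor.shape} and {weights.shape}"
        )
    if reduction == "none":
        return tensor * weights  # element-wise weighted losses
    if reduction == "mean":
        return (tensor * weights).sum() / weights.sum()  # normalized weighted average
    if reduction == "sum":
        return (tensor * weights).sum()
    raise ValueError(f"Unknown reduction {reduction!r}")

loss = torch.tensor([1.0, 2.0, 3.0])
w = torch.tensor([0.2, 0.3, 0.5])
assert torch.isclose(weighted_reduce(loss, "mean", w), torch.tensor(2.3))
assert torch.isclose(weighted_reduce(loss, "mean"), loss.mean())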
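
The `_reduce` docstring states that the weighted mean "is used for proper bias correction with prioritized replay buffers". A quick numerical check of that claim, not part of the patch: with importance-sampling weights at beta = 1 over arbitrary, made-up priorities, the weighted average recovers the uniform mean of the per-sample losses.

import torch

torch.manual_seed(0)
n = 10_000
per_sample_loss = torch.randn(n)
priorities = torch.rand(n) + 1e-3
probs = priorities / priorities.sum()        # sampling distribution of a prioritized buffer
idx = torch.multinomial(probs, 100_000, replacement=True)
weights = 1.0 / (n * probs[idx])             # importance-sampling weights at beta = 1
weighted_mean = (per_sample_loss[idx] * weights).sum() / weights.sum()
# Both values are close to the plain (uniform) mean of per_sample_loss.
print(weighted_mean.item(), per_sample_loss.mean().item())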
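
The last patch routes every loss method through `LossModule._maybe_get_priority_weight`, so calling e.g. `SACLoss.actor_loss` or `TD3Loss.value_loss` on a sampled batch applies the weighted reduction without the caller passing weights explicitly. The toy class below reproduces just that lookup to show when weights are, and are not, picked up; `ToyLoss` and the hand-built batches are illustrative stand-ins, not TorchRL objects, and "priority_weight" is the default key name introduced by the patch.

import torch
from tensordict import TensorDict

class ToyLoss:
    # Stand-in for a LossModule subclass; only the weight lookup is modelled.
    priority_weight_key = "priority_weight"

    def __init__(self, use_prioritized_weights: str | bool = "auto"):
        self.use_prioritized_weights = use_prioritized_weights

    def _maybe_get_priority_weight(self, tensordict: TensorDict) -> torch.Tensor | None:
        # Weights are returned only when the flag is True/"auto" and the sampled
        # batch actually carries them (i.e. it came from a prioritized buffer).
        if (
            self.use_prioritized_weights in (True, "auto")
            and self.priority_weight_key in tensordict.keys()
        ):
            return tensordict.get(self.priority_weight_key)
        return None

prioritized_batch = TensorDict({"priority_weight": torch.rand(8)}, batch_size=[8])
uniform_batch = TensorDict({}, batch_size=[8])

assert ToyLoss()._maybe_get_priority_weight(prioritized_batch) is not None
assert ToyLoss()._maybe_get_priority_weight(uniform_batch) is None
assert ToyLoss(use_prioritized_weights=False)._maybe_get_priority_weight(prioritized_batch) is None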