[BugFix] Add scalar_output_mode to loss modules for reduction='none'#3426
Merged
[BugFix] Add scalar_output_mode to loss modules for reduction='none'#3426
Conversation
When using SACLoss or DiscreteSACLoss with reduction="none", the output TensorDict now correctly preserves the input batch_size (e.g., [B, T] for memory-based models). Previously, the output TensorDict always had batch_size=[] regardless of the reduction setting. This is important for RNN/Transformer-based SAC agents where the time dimension [T] needs to be preserved for proper gradient computation. Fixes #2338
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3426
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Contributor
|
| Prefix | Label Applied | Example |
|---|---|---|
[BugFix] |
BugFix | [BugFix] Fix memory leak in collector |
[Feature] |
Feature | [Feature] Add new optimizer |
[Doc] or [Docs] |
Documentation | [Doc] Update installation guide |
[Refactor] |
Refactoring | [Refactor] Clean up module imports |
[CI] |
CI | [CI] Fix workflow permissions |
[Test] or [Tests] |
Tests | [Tests] Add unit tests for buffer |
[Environment] or [Environments] |
Environments | [Environments] Add Gymnasium support |
[Data] |
Data | [Data] Fix replay buffer sampling |
[Performance] or [Perf] |
Performance | [Performance] Optimize tensor ops |
[BC-Breaking] |
bc breaking | [BC-Breaking] Remove deprecated API |
[Deprecation] |
Deprecation | [Deprecation] Mark old function |
Note: Common variations like singular/plural are supported (e.g., [Doc] or [Docs]).
Contributor
|
| Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
|---|---|---|---|---|---|
| test_tensor_to_bytestream_speed[pickle] | 81.0677μs | 80.2526μs | 12.4606 KOps/s | 12.3926 KOps/s | |
| test_tensor_to_bytestream_speed[torch.save] | 0.1410ms | 0.1404ms | 7.1229 KOps/s | 7.2206 KOps/s | |
| test_tensor_to_bytestream_speed[untyped_storage] | 0.1328s | 0.1323s | 7.5567 Ops/s | 8.9377 Ops/s | |
| test_tensor_to_bytestream_speed[numpy] | 2.7951μs | 2.7931μs | 358.0289 KOps/s | 367.0601 KOps/s | |
| test_tensor_to_bytestream_speed[safetensors] | 37.1928μs | 36.6705μs | 27.2699 KOps/s | 25.8992 KOps/s | |
| test_simple | 0.6781s | 0.5792s | 1.7266 Ops/s | 1.7324 Ops/s | |
| test_transformed | 1.2612s | 1.1687s | 0.8556 Ops/s | 0.8692 Ops/s | |
| test_serial | 1.6996s | 1.6957s | 0.5897 Ops/s | 0.5969 Ops/s | |
| test_parallel | 1.2167s | 1.1261s | 0.8880 Ops/s | 0.8998 Ops/s | |
| test_step_mdp_speed[True-True-True-True-True] | 0.2610ms | 44.1385μs | 22.6560 KOps/s | 21.6725 KOps/s | |
| test_step_mdp_speed[True-True-True-True-False] | 0.4340ms | 25.2971μs | 39.5302 KOps/s | 39.0230 KOps/s | |
| test_step_mdp_speed[True-True-True-False-True] | 0.4403ms | 25.2572μs | 39.5927 KOps/s | 39.4978 KOps/s | |
| test_step_mdp_speed[True-True-True-False-False] | 42.6410μs | 13.9027μs | 71.9283 KOps/s | 71.2420 KOps/s | |
| test_step_mdp_speed[True-True-False-True-True] | 0.4734ms | 48.5110μs | 20.6139 KOps/s | 20.7045 KOps/s | |
| test_step_mdp_speed[True-True-False-True-False] | 0.4457ms | 28.1562μs | 35.5162 KOps/s | 35.3908 KOps/s | |
| test_step_mdp_speed[True-True-False-False-True] | 64.7910μs | 27.8334μs | 35.9281 KOps/s | 36.0704 KOps/s | |
| test_step_mdp_speed[True-True-False-False-False] | 0.4325ms | 16.7203μs | 59.8075 KOps/s | 59.8983 KOps/s | |
| test_step_mdp_speed[True-False-True-True-True] | 0.4659ms | 52.0497μs | 19.2124 KOps/s | 19.8153 KOps/s | |
| test_step_mdp_speed[True-False-True-True-False] | 67.1710μs | 30.5138μs | 32.7721 KOps/s | 32.4763 KOps/s | |
| test_step_mdp_speed[True-False-True-False-True] | 0.4502ms | 28.1357μs | 35.5421 KOps/s | 35.5239 KOps/s | |
| test_step_mdp_speed[True-False-True-False-False] | 0.4328ms | 16.3999μs | 60.9760 KOps/s | 60.6191 KOps/s | |
| test_step_mdp_speed[True-False-False-True-True] | 0.4685ms | 53.4312μs | 18.7157 KOps/s | 18.7696 KOps/s | |
| test_step_mdp_speed[True-False-False-True-False] | 0.4479ms | 33.4629μs | 29.8838 KOps/s | 29.8722 KOps/s | |
| test_step_mdp_speed[True-False-False-False-True] | 78.2310μs | 29.9479μs | 33.3913 KOps/s | 33.4750 KOps/s | |
| test_step_mdp_speed[True-False-False-False-False] | 0.4329ms | 19.5182μs | 51.2342 KOps/s | 51.4986 KOps/s | |
| test_step_mdp_speed[False-True-True-True-True] | 0.4888ms | 51.0384μs | 19.5931 KOps/s | 19.7396 KOps/s | |
| test_step_mdp_speed[False-True-True-True-False] | 67.0910μs | 31.1972μs | 32.0542 KOps/s | 32.5423 KOps/s | |
| test_step_mdp_speed[False-True-True-False-True] | 0.4496ms | 31.8169μs | 31.4298 KOps/s | 31.5726 KOps/s | |
| test_step_mdp_speed[False-True-True-False-False] | 0.4363ms | 18.4529μs | 54.1920 KOps/s | 54.7203 KOps/s | |
| test_step_mdp_speed[False-True-False-True-True] | 2.7189ms | 53.5954μs | 18.6583 KOps/s | 18.6113 KOps/s | |
| test_step_mdp_speed[False-True-False-True-False] | 65.1110μs | 33.7763μs | 29.6065 KOps/s | 29.7463 KOps/s | |
| test_step_mdp_speed[False-True-False-False-True] | 0.4529ms | 34.9940μs | 28.5763 KOps/s | 28.6532 KOps/s | |
| test_step_mdp_speed[False-True-False-False-False] | 0.4365ms | 21.5244μs | 46.4588 KOps/s | 47.0304 KOps/s | |
| test_step_mdp_speed[False-False-True-True-True] | 0.4707ms | 56.8145μs | 17.6011 KOps/s | 17.7830 KOps/s | |
| test_step_mdp_speed[False-False-True-True-False] | 72.5610μs | 36.5766μs | 27.3399 KOps/s | 27.6110 KOps/s | |
| test_step_mdp_speed[False-False-True-False-True] | 0.1010ms | 34.4744μs | 29.0071 KOps/s | 28.9772 KOps/s | |
| test_step_mdp_speed[False-False-True-False-False] | 0.4372ms | 21.2642μs | 47.0275 KOps/s | 47.0445 KOps/s | |
| test_step_mdp_speed[False-False-False-True-True] | 0.4653ms | 57.3387μs | 17.4402 KOps/s | 17.0695 KOps/s | |
| test_step_mdp_speed[False-False-False-True-False] | 0.4714ms | 39.0917μs | 25.5809 KOps/s | 26.0039 KOps/s | |
| test_step_mdp_speed[False-False-False-False-True] | 76.0110μs | 36.5276μs | 27.3765 KOps/s | 27.1920 KOps/s | |
| test_step_mdp_speed[False-False-False-False-False] | 0.4376ms | 23.8280μs | 41.9675 KOps/s | 41.1680 KOps/s | |
| test_non_tensor_env_rollout_speed[1000-single-True] | 0.8723s | 0.7781s | 1.2851 Ops/s | 1.2982 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-single-False] | 0.7265s | 0.6394s | 1.5640 Ops/s | 1.5693 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-serial-no-buffers-True] | 1.7680s | 1.6959s | 0.5897 Ops/s | 0.5973 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-serial-no-buffers-False] | 1.5436s | 1.4666s | 0.6819 Ops/s | 0.6852 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-serial-buffers-True] | 2.0186s | 1.9461s | 0.5139 Ops/s | 0.5197 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-serial-buffers-False] | 1.8087s | 1.7229s | 0.5804 Ops/s | 0.5890 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-parallel-no-buffers-True] | 4.8120s | 4.7395s | 0.2110 Ops/s | 0.2138 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-parallel-no-buffers-False] | 4.6176s | 4.5351s | 0.2205 Ops/s | 0.2234 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-parallel-buffers-True] | 2.0513s | 1.9805s | 0.5049 Ops/s | 0.5096 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-parallel-buffers-False] | 1.8625s | 1.7121s | 0.5841 Ops/s | 0.5982 Ops/s | |
| test_values[generalized_advantage_estimate-True-True] | 11.6427ms | 11.1177ms | 89.9465 Ops/s | 95.4496 Ops/s | |
| test_values[vec_generalized_advantage_estimate-True-True] | 15.5232ms | 11.3447ms | 88.1467 Ops/s | 56.0230 Ops/s | |
| test_values[td0_return_estimate-False-False] | 0.2244ms | 0.1289ms | 7.7596 KOps/s | 7.1703 KOps/s | |
| test_values[td1_return_estimate-False-False] | 31.8873ms | 30.7693ms | 32.4999 Ops/s | 34.5702 Ops/s | |
| test_values[vec_td1_return_estimate-False-False] | 11.7150ms | 11.3693ms | 87.9561 Ops/s | 55.7930 Ops/s | |
| test_values[td_lambda_return_estimate-True-False] | 47.5158ms | 46.2693ms | 21.6126 Ops/s | 23.3810 Ops/s | |
| test_values[vec_td_lambda_return_estimate-True-False] | 11.6030ms | 11.3645ms | 87.9931 Ops/s | 56.0376 Ops/s | |
| test_gae_speed[generalized_advantage_estimate-False-1-512] | 10.2417ms | 9.9923ms | 100.0771 Ops/s | 107.4006 Ops/s | |
| test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.7482ms | 1.5251ms | 655.7003 Ops/s | 659.3994 Ops/s | |
| test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.5105ms | 0.4492ms | 2.2262 KOps/s | 2.3369 KOps/s | |
| test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 30.6374ms | 27.7111ms | 36.0866 Ops/s | 32.8286 Ops/s | |
| test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 2.2064ms | 1.7558ms | 569.5409 Ops/s | 572.0552 Ops/s | |
| test_dqn_speed[False-None] | 1.7940ms | 1.4534ms | 688.0343 Ops/s | 708.4516 Ops/s | |
| test_dqn_speed[False-backward] | 2.0874ms | 1.9719ms | 507.1242 Ops/s | 519.7104 Ops/s | |
| test_dqn_speed[True-None] | 0.6824ms | 0.5466ms | 1.8296 KOps/s | 1.8087 KOps/s | |
| test_dqn_speed[True-backward] | 1.0590ms | 1.0038ms | 996.2082 Ops/s | 843.9049 Ops/s | |
| test_dqn_speed[reduce-overhead-None] | 0.7908ms | 0.5265ms | 1.8995 KOps/s | 1.8318 KOps/s | |
| test_ddpg_speed[False-None] | 3.1810ms | 2.8777ms | 347.4979 Ops/s | 349.0553 Ops/s | |
| test_ddpg_speed[False-backward] | 4.2960ms | 4.1681ms | 239.9202 Ops/s | 245.2833 Ops/s | |
| test_ddpg_speed[True-None] | 1.4825ms | 1.4015ms | 713.5035 Ops/s | 711.3958 Ops/s | |
| test_ddpg_speed[True-backward] | 2.4698ms | 2.3764ms | 420.7988 Ops/s | 377.2803 Ops/s | |
| test_ddpg_speed[reduce-overhead-None] | 1.8338ms | 1.3919ms | 718.4300 Ops/s | 705.3667 Ops/s | |
| test_sac_speed[False-None] | 8.7233ms | 8.0566ms | 124.1215 Ops/s | 122.2155 Ops/s | |
| test_sac_speed[False-backward] | 11.7142ms | 11.2906ms | 88.5696 Ops/s | 86.9919 Ops/s | |
| test_sac_speed[True-None] | 2.3994ms | 2.1245ms | 470.7014 Ops/s | 468.6859 Ops/s | |
| test_sac_speed[True-backward] | 4.1255ms | 3.9926ms | 250.4645 Ops/s | 205.3645 Ops/s | |
| test_sac_speed[reduce-overhead-None] | 2.5300ms | 2.1087ms | 474.2208 Ops/s | 470.5566 Ops/s | |
| test_redq_speed[False-None] | 11.0803ms | 10.4972ms | 95.2631 Ops/s | 92.4359 Ops/s | |
| test_redq_speed[False-backward] | 18.8301ms | 17.9486ms | 55.7147 Ops/s | 56.4075 Ops/s | |
| test_redq_speed[True-None] | 4.6697ms | 4.3641ms | 229.1412 Ops/s | 224.0284 Ops/s | |
| test_redq_speed[True-backward] | 10.0800ms | 9.7899ms | 102.1466 Ops/s | 96.7839 Ops/s | |
| test_redq_speed[reduce-overhead-None] | 4.6252ms | 4.3312ms | 230.8836 Ops/s | 230.0371 Ops/s | |
| test_redq_deprec_speed[False-None] | 11.7465ms | 11.2132ms | 89.1805 Ops/s | 89.4756 Ops/s | |
| test_redq_deprec_speed[False-backward] | 16.3294ms | 16.0305ms | 62.3809 Ops/s | 62.5030 Ops/s | |
| test_redq_deprec_speed[True-None] | 4.0792ms | 3.6439ms | 274.4283 Ops/s | 276.2486 Ops/s | |
| test_redq_deprec_speed[True-backward] | 7.7694ms | 7.4662ms | 133.9361 Ops/s | 128.6390 Ops/s | |
| test_redq_deprec_speed[reduce-overhead-None] | 3.8458ms | 3.6147ms | 276.6453 Ops/s | 275.8054 Ops/s | |
| test_td3_speed[False-None] | 8.3504ms | 8.1260ms | 123.0620 Ops/s | 123.6638 Ops/s | |
| test_td3_speed[False-backward] | 11.5222ms | 11.0641ms | 90.3823 Ops/s | 91.4851 Ops/s | |
| test_td3_speed[True-None] | 1.8770ms | 1.8150ms | 550.9503 Ops/s | 550.4961 Ops/s | |
| test_td3_speed[True-backward] | 3.8556ms | 3.6684ms | 272.6004 Ops/s | 223.2378 Ops/s | |
| test_td3_speed[reduce-overhead-None] | 1.8304ms | 1.7804ms | 561.6691 Ops/s | 559.8646 Ops/s | |
| test_cql_speed[False-None] | 29.7653ms | 26.4711ms | 37.7770 Ops/s | 38.4082 Ops/s | |
| test_cql_speed[False-backward] | 38.0011ms | 35.3632ms | 28.2780 Ops/s | 28.2731 Ops/s | |
| test_cql_speed[True-None] | 12.4899ms | 12.1721ms | 82.1550 Ops/s | 79.2077 Ops/s | |
| test_cql_speed[True-backward] | 18.5374ms | 18.1012ms | 55.2449 Ops/s | 55.9906 Ops/s | |
| test_cql_speed[reduce-overhead-None] | 12.5619ms | 12.2613ms | 81.5575 Ops/s | 82.0955 Ops/s | |
| test_a2c_speed[False-None] | 5.8035ms | 5.5193ms | 181.1822 Ops/s | 186.4814 Ops/s | |
| test_a2c_speed[False-backward] | 12.2962ms | 11.8666ms | 84.2703 Ops/s | 85.3034 Ops/s | |
| test_a2c_speed[True-None] | 4.1773ms | 3.7639ms | 265.6843 Ops/s | 261.1623 Ops/s | |
| test_a2c_speed[True-backward] | 8.7886ms | 8.5475ms | 116.9928 Ops/s | 116.5408 Ops/s | |
| test_a2c_speed[reduce-overhead-None] | 4.0475ms | 3.6879ms | 271.1546 Ops/s | 271.1638 Ops/s | |
| test_ppo_speed[False-None] | 6.4333ms | 5.9961ms | 166.7739 Ops/s | 169.5787 Ops/s | |
| test_ppo_speed[False-backward] | 12.7719ms | 12.4927ms | 80.0466 Ops/s | 80.4968 Ops/s | |
| test_ppo_speed[True-None] | 4.4617ms | 3.6592ms | 273.2810 Ops/s | 269.0759 Ops/s | |
| test_ppo_speed[True-backward] | 8.8434ms | 8.4299ms | 118.6250 Ops/s | 118.7687 Ops/s | |
| test_ppo_speed[reduce-overhead-None] | 3.9528ms | 3.5598ms | 280.9158 Ops/s | 278.8193 Ops/s | |
| test_reinforce_speed[False-None] | 5.1265ms | 4.5640ms | 219.1039 Ops/s | 224.0004 Ops/s | |
| test_reinforce_speed[False-backward] | 7.6117ms | 7.3472ms | 136.1062 Ops/s | 137.4509 Ops/s | |
| test_reinforce_speed[True-None] | 3.2401ms | 2.8298ms | 353.3880 Ops/s | 352.2056 Ops/s | |
| test_reinforce_speed[True-backward] | 7.9913ms | 7.7149ms | 129.6193 Ops/s | 133.4870 Ops/s | |
| test_reinforce_speed[reduce-overhead-None] | 3.3008ms | 2.8266ms | 353.7807 Ops/s | 355.9270 Ops/s | |
| test_iql_speed[False-None] | 20.8413ms | 19.8757ms | 50.3127 Ops/s | 49.5248 Ops/s | |
| test_iql_speed[False-backward] | 30.8272ms | 30.2681ms | 33.0381 Ops/s | 33.1179 Ops/s | |
| test_iql_speed[True-None] | 9.0281ms | 8.4236ms | 118.7139 Ops/s | 118.9675 Ops/s | |
| test_iql_speed[True-backward] | 16.9173ms | 16.5787ms | 60.3184 Ops/s | 61.3577 Ops/s | |
| test_iql_speed[reduce-overhead-None] | 8.9473ms | 8.5146ms | 117.4458 Ops/s | 118.2833 Ops/s | |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.1418ms | 6.0378ms | 165.6239 Ops/s | 165.8023 Ops/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.9606ms | 0.3382ms | 2.9570 KOps/s | 2.9508 KOps/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6439ms | 0.3073ms | 3.2545 KOps/s | 3.3440 KOps/s | |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.1389ms | 5.8655ms | 170.4876 Ops/s | 170.4120 Ops/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.6870ms | 0.3015ms | 3.3168 KOps/s | 3.5765 KOps/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5832ms | 0.2734ms | 3.6579 KOps/s | 3.4827 KOps/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.5372ms | 1.3040ms | 766.8478 Ops/s | 761.3026 Ops/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.4343ms | 1.2155ms | 822.7389 Ops/s | 834.7921 Ops/s | |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 9.7865ms | 6.1487ms | 162.6370 Ops/s | 164.3890 Ops/s | |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.2706ms | 0.4616ms | 2.1662 KOps/s | 2.0360 KOps/s | |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7018ms | 0.4854ms | 2.0602 KOps/s | 2.1252 KOps/s | |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.3158ms | 5.9105ms | 169.1904 Ops/s | 169.5594 Ops/s | |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.8272ms | 0.3198ms | 3.1272 KOps/s | 2.4176 KOps/s | |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.4989ms | 0.3036ms | 3.2936 KOps/s | 2.8643 KOps/s | |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.1027ms | 5.8576ms | 170.7186 Ops/s | 170.7336 Ops/s | |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6701ms | 0.3190ms | 3.1349 KOps/s | 3.5798 KOps/s | |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5530ms | 0.3363ms | 2.9737 KOps/s | 3.6581 KOps/s | |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 8.2595ms | 6.0924ms | 164.1389 Ops/s | 166.3469 Ops/s | |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.2978ms | 0.4789ms | 2.0880 KOps/s | 1.9925 KOps/s | |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7304ms | 0.4998ms | 2.0007 KOps/s | 2.3317 KOps/s | |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.6419s | 17.8726ms | 55.9517 Ops/s | 52.8337 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 10.1979ms | 1.8917ms | 528.6205 Ops/s | 497.4611 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 13.6275ms | 1.4948ms | 668.9724 Ops/s | 834.2537 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 7.0701ms | 5.1223ms | 195.2235 Ops/s | 197.9887 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 3.8670ms | 1.7406ms | 574.5088 Ops/s | 558.6145 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 9.0394ms | 1.2409ms | 805.8692 Ops/s | 897.9243 Ops/s | |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.5582s | 16.4011ms | 60.9714 Ops/s | 59.5842 Ops/s | |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 11.7327ms | 2.1179ms | 472.1569 Ops/s | 497.8688 Ops/s | |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 1.2390ms | 1.0285ms | 972.3365 Ops/s | 688.9940 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 39.5476ms | 36.5381ms | 27.3687 Ops/s | 27.2935 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 20.8097ms | 18.7261ms | 53.4015 Ops/s | 52.6009 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 41.3808ms | 37.7693ms | 26.4765 Ops/s | 26.2788 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 20.3552ms | 18.7757ms | 53.2602 Ops/s | 52.5721 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 41.6627ms | 39.5820ms | 25.2640 Ops/s | 25.2115 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 22.7237ms | 20.7372ms | 48.2225 Ops/s | 48.6903 Ops/s |
When reduction='none', scalar values (alpha, entropy) cannot be included in the output TensorDict without changing their shape. This commit adds a `scalar_output_mode` parameter to SACLoss and DiscreteSACLoss: - None (default): Warn and exclude scalars from output - "exclude": Explicitly exclude scalars (no warning) - "non_tensor": Include scalars as non-tensor data via set_non_tensor() This gives users explicit control over how scalars are handled while making it clear this is a known limitation we're working on.
Contributor
|
| Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
|---|---|---|---|---|---|
| test_tensor_to_bytestream_speed[pickle] | 82.9288μs | 82.1688μs | 12.1701 KOps/s | 12.3075 KOps/s | |
| test_tensor_to_bytestream_speed[torch.save] | 0.1428ms | 0.1423ms | 7.0259 KOps/s | 7.0929 KOps/s | |
| test_tensor_to_bytestream_speed[untyped_storage] | 0.1369s | 0.1364s | 7.3316 Ops/s | 8.4807 Ops/s | |
| test_tensor_to_bytestream_speed[numpy] | 2.9157μs | 2.9080μs | 343.8767 KOps/s | 362.0471 KOps/s | |
| test_tensor_to_bytestream_speed[safetensors] | 37.3850μs | 37.2388μs | 26.8537 KOps/s | 26.5350 KOps/s | |
| test_simple | 0.8074s | 0.8060s | 1.2407 Ops/s | 1.2375 Ops/s | |
| test_transformed | 1.5705s | 1.4768s | 0.6772 Ops/s | 0.6958 Ops/s | |
| test_serial | 2.4573s | 2.3628s | 0.4232 Ops/s | 0.4282 Ops/s | |
| test_parallel | 2.0288s | 1.9602s | 0.5102 Ops/s | 0.5136 Ops/s | |
| test_step_mdp_speed[True-True-True-True-True] | 0.2556ms | 45.9674μs | 21.7545 KOps/s | 22.3073 KOps/s | |
| test_step_mdp_speed[True-True-True-True-False] | 56.2110μs | 25.2160μs | 39.6573 KOps/s | 39.9305 KOps/s | |
| test_step_mdp_speed[True-True-True-False-True] | 55.1310μs | 25.2932μs | 39.5363 KOps/s | 40.0049 KOps/s | |
| test_step_mdp_speed[True-True-True-False-False] | 48.8210μs | 13.8400μs | 72.2543 KOps/s | 73.0272 KOps/s | |
| test_step_mdp_speed[True-True-False-True-True] | 79.7410μs | 48.0310μs | 20.8199 KOps/s | 20.8704 KOps/s | |
| test_step_mdp_speed[True-True-False-True-False] | 58.8720μs | 27.7066μs | 36.0925 KOps/s | 35.9662 KOps/s | |
| test_step_mdp_speed[True-True-False-False-True] | 62.8110μs | 27.7607μs | 36.0221 KOps/s | 35.8869 KOps/s | |
| test_step_mdp_speed[True-True-False-False-False] | 49.2210μs | 16.8418μs | 59.3759 KOps/s | 60.0296 KOps/s | |
| test_step_mdp_speed[True-False-True-True-True] | 98.8820μs | 52.5147μs | 19.0423 KOps/s | 19.5770 KOps/s | |
| test_step_mdp_speed[True-False-True-True-False] | 61.1410μs | 31.1934μs | 32.0580 KOps/s | 32.5540 KOps/s | |
| test_step_mdp_speed[True-False-True-False-True] | 64.0910μs | 28.5601μs | 35.0139 KOps/s | 35.7795 KOps/s | |
| test_step_mdp_speed[True-False-True-False-False] | 43.7910μs | 16.6995μs | 59.8821 KOps/s | 59.9775 KOps/s | |
| test_step_mdp_speed[True-False-False-True-True] | 86.4420μs | 53.5134μs | 18.6869 KOps/s | 18.7359 KOps/s | |
| test_step_mdp_speed[True-False-False-True-False] | 67.8310μs | 33.1859μs | 30.1333 KOps/s | 29.5706 KOps/s | |
| test_step_mdp_speed[True-False-False-False-True] | 79.7320μs | 30.5535μs | 32.7294 KOps/s | 32.6970 KOps/s | |
| test_step_mdp_speed[True-False-False-False-False] | 47.0810μs | 19.5869μs | 51.0546 KOps/s | 51.8129 KOps/s | |
| test_step_mdp_speed[False-True-True-True-True] | 0.1312ms | 49.8182μs | 20.0730 KOps/s | 19.9259 KOps/s | |
| test_step_mdp_speed[False-True-True-True-False] | 62.9510μs | 30.3480μs | 32.9511 KOps/s | 33.0256 KOps/s | |
| test_step_mdp_speed[False-True-True-False-True] | 60.5710μs | 31.7401μs | 31.5058 KOps/s | 31.3100 KOps/s | |
| test_step_mdp_speed[False-True-True-False-False] | 53.1210μs | 18.4432μs | 54.2206 KOps/s | 54.9057 KOps/s | |
| test_step_mdp_speed[False-True-False-True-True] | 2.7287ms | 53.5686μs | 18.6677 KOps/s | 18.8207 KOps/s | |
| test_step_mdp_speed[False-True-False-True-False] | 65.6120μs | 33.2452μs | 30.0795 KOps/s | 30.0952 KOps/s | |
| test_step_mdp_speed[False-True-False-False-True] | 66.9410μs | 34.0607μs | 29.3593 KOps/s | 29.3841 KOps/s | |
| test_step_mdp_speed[False-True-False-False-False] | 50.3310μs | 21.2899μs | 46.9707 KOps/s | 47.2584 KOps/s | |
| test_step_mdp_speed[False-False-True-True-True] | 96.5520μs | 56.4106μs | 17.7272 KOps/s | 17.7863 KOps/s | |
| test_step_mdp_speed[False-False-True-True-False] | 73.2010μs | 36.3712μs | 27.4943 KOps/s | 27.1886 KOps/s | |
| test_step_mdp_speed[False-False-True-False-True] | 70.1810μs | 34.2664μs | 29.1831 KOps/s | 29.4001 KOps/s | |
| test_step_mdp_speed[False-False-True-False-False] | 57.3410μs | 21.1330μs | 47.3194 KOps/s | 47.4388 KOps/s | |
| test_step_mdp_speed[False-False-False-True-True] | 90.5420μs | 58.2399μs | 17.1704 KOps/s | 17.1469 KOps/s | |
| test_step_mdp_speed[False-False-False-True-False] | 69.2410μs | 38.4613μs | 26.0001 KOps/s | 25.9704 KOps/s | |
| test_step_mdp_speed[False-False-False-False-True] | 85.4320μs | 36.4056μs | 27.4683 KOps/s | 27.3289 KOps/s | |
| test_step_mdp_speed[False-False-False-False-False] | 51.8410μs | 23.7257μs | 42.1485 KOps/s | 42.3624 KOps/s | |
| test_non_tensor_env_rollout_speed[1000-single-True] | 0.8781s | 0.7832s | 1.2768 Ops/s | 1.2680 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-single-False] | 0.7377s | 0.6451s | 1.5502 Ops/s | 1.5392 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-serial-no-buffers-True] | 1.7790s | 1.7008s | 0.5880 Ops/s | 0.5813 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-serial-no-buffers-False] | 1.5436s | 1.4714s | 0.6796 Ops/s | 0.6707 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-serial-buffers-True] | 2.0358s | 1.9590s | 0.5105 Ops/s | 0.5082 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-serial-buffers-False] | 1.8053s | 1.7275s | 0.5789 Ops/s | 0.5690 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-parallel-no-buffers-True] | 4.8234s | 4.7208s | 0.2118 Ops/s | 0.2126 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-parallel-no-buffers-False] | 4.6162s | 4.4760s | 0.2234 Ops/s | 0.2218 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-parallel-buffers-True] | 2.1480s | 2.0172s | 0.4957 Ops/s | 0.5040 Ops/s | |
| test_non_tensor_env_rollout_speed[1000-parallel-buffers-False] | 1.7636s | 1.6868s | 0.5928 Ops/s | 0.5949 Ops/s | |
| test_values[generalized_advantage_estimate-True-True] | 20.9307ms | 20.5342ms | 48.6992 Ops/s | 49.3771 Ops/s | |
| test_values[vec_generalized_advantage_estimate-True-True] | 0.1401s | 3.7320ms | 267.9526 Ops/s | 275.8289 Ops/s | |
| test_values[td0_return_estimate-False-False] | 0.1071ms | 84.3974μs | 11.8487 KOps/s | 11.6273 KOps/s | |
| test_values[td1_return_estimate-False-False] | 49.3693ms | 48.9912ms | 20.4118 Ops/s | 20.4879 Ops/s | |
| test_values[vec_td1_return_estimate-False-False] | 1.3344ms | 1.1053ms | 904.7073 Ops/s | 907.5230 Ops/s | |
| test_values[td_lambda_return_estimate-True-False] | 82.2597ms | 80.7451ms | 12.3846 Ops/s | 12.3021 Ops/s | |
| test_values[vec_td_lambda_return_estimate-True-False] | 1.3366ms | 1.1031ms | 906.5296 Ops/s | 908.2247 Ops/s | |
| test_gae_speed[generalized_advantage_estimate-False-1-512] | 21.3157ms | 20.9697ms | 47.6877 Ops/s | 48.0665 Ops/s | |
| test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0454ms | 0.7736ms | 1.2927 KOps/s | 1.2922 KOps/s | |
| test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.8097ms | 0.7201ms | 1.3887 KOps/s | 1.4466 KOps/s | |
| test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.6695ms | 1.5155ms | 659.8475 Ops/s | 663.0389 Ops/s | |
| test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.7756ms | 0.7340ms | 1.3623 KOps/s | 1.4090 KOps/s | |
| test_dqn_speed[False-None] | 1.7077ms | 1.6165ms | 618.6283 Ops/s | 639.0615 Ops/s | |
| test_dqn_speed[False-backward] | 2.2665ms | 2.2089ms | 452.7057 Ops/s | 448.3268 Ops/s | |
| test_dqn_speed[True-None] | 0.7070ms | 0.5653ms | 1.7689 KOps/s | 1.8079 KOps/s | |
| test_dqn_speed[True-backward] | 1.1429ms | 1.0811ms | 925.0030 Ops/s | 925.0637 Ops/s | |
| test_dqn_speed[reduce-overhead-None] | 0.7597ms | 0.5702ms | 1.7539 KOps/s | 1.6983 KOps/s | |
| test_ddpg_speed[False-None] | 3.3407ms | 2.9459ms | 339.4510 Ops/s | 340.9832 Ops/s | |
| test_ddpg_speed[False-backward] | 4.7264ms | 4.2498ms | 235.3068 Ops/s | 235.8825 Ops/s | |
| test_ddpg_speed[True-None] | 1.3591ms | 1.2986ms | 770.0486 Ops/s | 757.0522 Ops/s | |
| test_ddpg_speed[True-backward] | 2.4081ms | 2.3521ms | 425.1579 Ops/s | 398.1341 Ops/s | |
| test_ddpg_speed[reduce-overhead-None] | 1.4494ms | 1.3359ms | 748.5789 Ops/s | 749.7741 Ops/s | |
| test_sac_speed[False-None] | 9.4988ms | 8.5330ms | 117.1927 Ops/s | 117.5097 Ops/s | |
| test_sac_speed[False-backward] | 12.0454ms | 11.5320ms | 86.7152 Ops/s | 84.5912 Ops/s | |
| test_sac_speed[True-None] | 2.2461ms | 1.8244ms | 548.1163 Ops/s | 552.5021 Ops/s | |
| test_sac_speed[True-backward] | 3.5632ms | 3.4392ms | 290.7635 Ops/s | 279.1652 Ops/s | |
| test_sac_speed[reduce-overhead-None] | 18.8413ms | 10.5855ms | 94.4690 Ops/s | 84.0775 Ops/s | |
| test_redq_deprec_speed[False-None] | 10.1864ms | 9.5454ms | 104.7626 Ops/s | 104.9881 Ops/s | |
| test_redq_deprec_speed[False-backward] | 13.1107ms | 12.6864ms | 78.8246 Ops/s | 76.9409 Ops/s | |
| test_redq_deprec_speed[True-None] | 2.8802ms | 2.5297ms | 395.3055 Ops/s | 396.0598 Ops/s | |
| test_redq_deprec_speed[True-backward] | 4.1739ms | 4.1039ms | 243.6720 Ops/s | 230.7351 Ops/s | |
| test_redq_deprec_speed[reduce-overhead-None] | 15.7287ms | 9.5940ms | 104.2314 Ops/s | 106.2285 Ops/s | |
| test_td3_speed[False-None] | 8.7394ms | 8.3799ms | 119.3336 Ops/s | 119.2485 Ops/s | |
| test_td3_speed[False-backward] | 11.2937ms | 10.8412ms | 92.2404 Ops/s | 90.0499 Ops/s | |
| test_td3_speed[True-None] | 1.6478ms | 1.6172ms | 618.3431 Ops/s | 592.1757 Ops/s | |
| test_td3_speed[True-backward] | 3.3080ms | 3.0808ms | 324.5940 Ops/s | 307.8017 Ops/s | |
| test_td3_speed[reduce-overhead-None] | 66.0018ms | 23.5396ms | 42.4815 Ops/s | 42.1865 Ops/s | |
| test_cql_speed[False-None] | 17.7906ms | 17.5739ms | 56.9024 Ops/s | 56.6832 Ops/s | |
| test_cql_speed[False-backward] | 23.6429ms | 22.8918ms | 43.6838 Ops/s | 42.7699 Ops/s | |
| test_cql_speed[True-None] | 3.3354ms | 3.2291ms | 309.6869 Ops/s | 300.6280 Ops/s | |
| test_cql_speed[True-backward] | 5.9135ms | 5.5033ms | 181.7088 Ops/s | 186.4180 Ops/s | |
| test_cql_speed[reduce-overhead-None] | 18.1373ms | 11.4192ms | 87.5719 Ops/s | 86.0952 Ops/s | |
| test_a2c_speed[False-None] | 4.0342ms | 3.3274ms | 300.5312 Ops/s | 299.5772 Ops/s | |
| test_a2c_speed[False-backward] | 6.9537ms | 6.5242ms | 153.2762 Ops/s | 157.8241 Ops/s | |
| test_a2c_speed[True-None] | 1.4261ms | 1.3378ms | 747.4801 Ops/s | 747.2724 Ops/s | |
| test_a2c_speed[True-backward] | 3.2192ms | 3.0900ms | 323.6233 Ops/s | 337.4161 Ops/s | |
| test_a2c_speed[reduce-overhead-None] | 1.0503ms | 0.9562ms | 1.0458 KOps/s | 1.0295 KOps/s | |
| test_ppo_speed[False-None] | 4.0703ms | 3.9206ms | 255.0652 Ops/s | 255.2688 Ops/s | |
| test_ppo_speed[False-backward] | 7.7558ms | 7.3193ms | 136.6251 Ops/s | 141.8497 Ops/s | |
| test_ppo_speed[True-None] | 1.4360ms | 1.3825ms | 723.3528 Ops/s | 709.9293 Ops/s | |
| test_ppo_speed[True-backward] | 3.7613ms | 3.3447ms | 298.9837 Ops/s | 305.7054 Ops/s | |
| test_ppo_speed[reduce-overhead-None] | 1.1230ms | 1.0319ms | 969.0418 Ops/s | 944.1230 Ops/s | |
| test_reinforce_speed[False-None] | 2.4305ms | 2.3309ms | 429.0170 Ops/s | 428.0117 Ops/s | |
| test_reinforce_speed[False-backward] | 3.8419ms | 3.3638ms | 297.2859 Ops/s | 296.4836 Ops/s | |
| test_reinforce_speed[True-None] | 1.3346ms | 1.2526ms | 798.3574 Ops/s | 795.7089 Ops/s | |
| test_reinforce_speed[True-backward] | 3.0230ms | 2.9241ms | 341.9875 Ops/s | 338.4312 Ops/s | |
| test_reinforce_speed[reduce-overhead-None] | 16.5701ms | 9.1892ms | 108.8232 Ops/s | 97.4003 Ops/s | |
| test_iql_speed[False-None] | 10.8901ms | 9.6781ms | 103.3259 Ops/s | 101.3533 Ops/s | |
| test_iql_speed[False-backward] | 13.8663ms | 13.3753ms | 74.7649 Ops/s | 73.5896 Ops/s | |
| test_iql_speed[True-None] | 2.2341ms | 2.1578ms | 463.4368 Ops/s | 459.3275 Ops/s | |
| test_iql_speed[True-backward] | 5.0286ms | 4.6853ms | 213.4325 Ops/s | 211.2299 Ops/s | |
| test_iql_speed[reduce-overhead-None] | 16.9947ms | 10.0929ms | 99.0793 Ops/s | 96.6072 Ops/s | |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.5022ms | 6.0157ms | 166.2306 Ops/s | 164.9426 Ops/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.1228ms | 0.3473ms | 2.8793 KOps/s | 2.9723 KOps/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.5838ms | 0.3147ms | 3.1772 KOps/s | 3.3120 KOps/s | |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.0314ms | 5.8078ms | 172.1810 Ops/s | 171.1193 Ops/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.0166ms | 0.2841ms | 3.5198 KOps/s | 3.5278 KOps/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5101ms | 0.2644ms | 3.7822 KOps/s | 3.7726 KOps/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.7175ms | 1.3153ms | 760.3059 Ops/s | 677.9889 Ops/s | |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.6017ms | 1.3853ms | 721.8739 Ops/s | 775.9428 Ops/s | |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.0867ms | 5.9957ms | 166.7854 Ops/s | 165.6913 Ops/s | |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.9617ms | 0.4470ms | 2.2374 KOps/s | 2.0084 KOps/s | |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6480ms | 0.4247ms | 2.3545 KOps/s | 2.3570 KOps/s | |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.9855ms | 5.9035ms | 169.3916 Ops/s | 170.1087 Ops/s | |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.9184ms | 0.2866ms | 3.4891 KOps/s | 2.6288 KOps/s | |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.8469ms | 0.3525ms | 2.8367 KOps/s | 2.7143 KOps/s | |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.1724ms | 5.8371ms | 171.3189 Ops/s | 171.2603 Ops/s | |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6747ms | 0.3186ms | 3.1386 KOps/s | 2.7805 KOps/s | |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7212ms | 0.3387ms | 2.9522 KOps/s | 3.7418 KOps/s | |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.0949ms | 6.0124ms | 166.3231 Ops/s | 165.8559 Ops/s | |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 0.8985ms | 0.4766ms | 2.0982 KOps/s | 1.8620 KOps/s | |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7365ms | 0.4737ms | 2.1108 KOps/s | 2.2409 KOps/s | |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 6.5237ms | 5.0463ms | 198.1647 Ops/s | 49.0221 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 10.1477ms | 2.2359ms | 447.2426 Ops/s | 536.5059 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 3.6150ms | 0.9793ms | 1.0211 KOps/s | 806.2569 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.5821s | 16.7268ms | 59.7842 Ops/s | 194.2999 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 11.3420ms | 2.0693ms | 483.2614 Ops/s | 553.7465 Ops/s | |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 7.1725ms | 1.2798ms | 781.3744 Ops/s | 799.4051 Ops/s | |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 6.9145ms | 5.3161ms | 188.1078 Ops/s | 188.5766 Ops/s | |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 14.0600ms | 2.1629ms | 462.3511 Ops/s | 482.7998 Ops/s | |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 1.3616ms | 1.0842ms | 922.3617 Ops/s | 903.3826 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 39.5448ms | 36.8508ms | 27.1364 Ops/s | 27.5806 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 20.2644ms | 18.7557ms | 53.3173 Ops/s | 53.5157 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 41.7546ms | 38.3718ms | 26.0608 Ops/s | 26.4591 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 20.8975ms | 18.9447ms | 52.7851 Ops/s | 52.8135 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 41.7467ms | 39.6869ms | 25.1973 Ops/s | 25.3699 Ops/s | |
| test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 21.8190ms | 20.4161ms | 48.9810 Ops/s | 49.1957 Ops/s |
The warning about scalar handling when reduction='none' now appears at instantiation time rather than during every forward() call. This is more appropriate for configuration warnings.
Apply the same scalar_output_mode pattern to: - CrossQLoss - CQLLoss - REDQLoss - IQLLoss / DiscreteIQLLoss - OnlineDTLoss When reduction='none', scalar values (alpha, entropy) cannot be included in the batched TensorDict output. This parameter allows users to choose between excluding them (default with warning) or including them as non-tensor data.
- Fix REDQLoss forward to handle its different batch structure (losses have extra dimension for Q-network ensemble) - Update reduction tests to pass scalar_output_mode='exclude' when reduction='none' to suppress warnings
PRs with the user-facing label should NOT be included in minor releases, only in major releases.
vmoens
added a commit
that referenced
this pull request
Feb 4, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes #2338
When using
SACLoss(or other loss modules) withreduction='none', the outputTensorDictshould preserve the input batch dimensions. Previously, the output had an empty batch size (torch.Size([])), which caused issues when users expected the loss values to maintain per-sample granularity.The Problem
Scalar values like
alpha(the temperature parameter) andentropycannot be included in aTensorDictwith non-empty batch dimensions without changing their shape. This created a conflict: either output an empty-batch TensorDict (breaking expectations) or expand the scalars (changing their semantics).The Solution
Added a new
scalar_output_modeparameter to the following loss module constructors:SACLoss/DiscreteSACLossCrossQLossCQLLossREDQLossIQLLoss/DiscreteIQLLossOnlineDTLossOptions:
None(default): Issues a warning and excludes scalars from output"exclude": Silently excludes scalars (suppresses warning)"non_tensor": Includes scalars as non-tensor data viaset_non_tensor()The warning is raised at instantiation time (not during forward), making it clear upfront that the behavior differs.
Example
Test plan
pytest test/test_objectives.py -k "test_sac")