Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Fix DType casting lazy init #1589

Merged
merged 5 commits into from
Oct 2, 2023
Merged

[Feature] Fix DType casting lazy init #1589

merged 5 commits into from
Oct 2, 2023

Conversation

vmoens
Copy link
Contributor

@vmoens vmoens commented Oct 2, 2023

Description

The init of DType casting is broken when the first access to the transform is made through reset. This PR fixes that.

This PR also uncouples in_keys and out_keys in the transform base class. This is needed to avoid having to patch transforms with a logic that is, actually, not as generic as expected. From a user perspective nothing changes so this isn't bc-breaking.

The checks on dtypes for the DTypeTransform are not a bit more stringent.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 2, 2023
@vmoens vmoens added the bug Something isn't working label Oct 2, 2023
@github-actions
Copy link

github-actions bot commented Oct 2, 2023

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 89. Improved: $\large\color{#35bf28}6$. Worsened: $\large\color{#d91a1a}10$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_single 0.1020s 0.1012s 9.8772 Ops/s 9.7235 Ops/s $\color{#35bf28}+1.58\%$
test_sync 71.3060ms 55.2376ms 18.1036 Ops/s 18.5475 Ops/s $\color{#d91a1a}-2.39\%$
test_async 0.1099s 53.0560ms 18.8480 Ops/s 18.9476 Ops/s $\color{#d91a1a}-0.53\%$
test_simple 0.9034s 0.8375s 1.1940 Ops/s 1.2088 Ops/s $\color{#d91a1a}-1.23\%$
test_transformed 1.1081s 1.0355s 0.9657 Ops/s 0.9396 Ops/s $\color{#35bf28}+2.78\%$
test_serial 2.2578s 2.1880s 0.4570 Ops/s 0.4498 Ops/s $\color{#35bf28}+1.60\%$
test_parallel 1.9228s 1.8388s 0.5438 Ops/s 0.5403 Ops/s $\color{#35bf28}+0.66\%$
test_step_mdp_speed[True-True-True-True-True] 0.1172ms 44.1052μs 22.6731 KOps/s 22.3370 KOps/s $\color{#35bf28}+1.50\%$
test_step_mdp_speed[True-True-True-True-False] 50.8010μs 25.0407μs 39.9350 KOps/s 39.1018 KOps/s $\color{#35bf28}+2.13\%$
test_step_mdp_speed[True-True-True-False-True] 99.1010μs 30.9689μs 32.2905 KOps/s 31.5152 KOps/s $\color{#35bf28}+2.46\%$
test_step_mdp_speed[True-True-True-False-False] 45.7010μs 17.4843μs 57.1942 KOps/s 55.3396 KOps/s $\color{#35bf28}+3.35\%$
test_step_mdp_speed[True-True-False-True-True] 72.6000μs 45.4201μs 22.0167 KOps/s 21.2773 KOps/s $\color{#35bf28}+3.47\%$
test_step_mdp_speed[True-True-False-True-False] 0.2172ms 26.8009μs 37.3121 KOps/s 36.5678 KOps/s $\color{#35bf28}+2.04\%$
test_step_mdp_speed[True-True-False-False-True] 0.1509ms 33.4134μs 29.9281 KOps/s 29.4044 KOps/s $\color{#35bf28}+1.78\%$
test_step_mdp_speed[True-True-False-False-False] 41.9010μs 19.3499μs 51.6799 KOps/s 50.6732 KOps/s $\color{#35bf28}+1.99\%$
test_step_mdp_speed[True-False-True-True-True] 94.6010μs 47.5697μs 21.0218 KOps/s 20.4034 KOps/s $\color{#35bf28}+3.03\%$
test_step_mdp_speed[True-False-True-True-False] 50.4010μs 28.5705μs 35.0011 KOps/s 34.4857 KOps/s $\color{#35bf28}+1.49\%$
test_step_mdp_speed[True-False-True-False-True] 82.5010μs 33.0465μs 30.2604 KOps/s 28.9168 KOps/s $\color{#35bf28}+4.65\%$
test_step_mdp_speed[True-False-True-False-False] 69.2010μs 19.3080μs 51.7921 KOps/s 50.6332 KOps/s $\color{#35bf28}+2.29\%$
test_step_mdp_speed[True-False-False-True-True] 72.3010μs 48.9085μs 20.4463 KOps/s 19.6030 KOps/s $\color{#35bf28}+4.30\%$
test_step_mdp_speed[True-False-False-True-False] 96.7020μs 30.3080μs 32.9946 KOps/s 32.5247 KOps/s $\color{#35bf28}+1.44\%$
test_step_mdp_speed[True-False-False-False-True] 60.0000μs 34.5472μs 28.9459 KOps/s 28.0584 KOps/s $\color{#35bf28}+3.16\%$
test_step_mdp_speed[True-False-False-False-False] 48.5000μs 20.8510μs 47.9592 KOps/s 47.1766 KOps/s $\color{#35bf28}+1.66\%$
test_step_mdp_speed[False-True-True-True-True] 76.5010μs 47.4274μs 21.0848 KOps/s 20.5057 KOps/s $\color{#35bf28}+2.82\%$
test_step_mdp_speed[False-True-True-True-False] 97.7010μs 28.9193μs 34.5789 KOps/s 34.4202 KOps/s $\color{#35bf28}+0.46\%$
test_step_mdp_speed[False-True-True-False-True] 62.6000μs 36.8010μs 27.1732 KOps/s 26.6355 KOps/s $\color{#35bf28}+2.02\%$
test_step_mdp_speed[False-True-True-False-False] 95.2010μs 21.5680μs 46.3650 KOps/s 45.9046 KOps/s $\color{#35bf28}+1.00\%$
test_step_mdp_speed[False-True-False-True-True] 71.0010μs 49.1076μs 20.3635 KOps/s 20.0593 KOps/s $\color{#35bf28}+1.52\%$
test_step_mdp_speed[False-True-False-True-False] 0.1030ms 30.3878μs 32.9079 KOps/s 32.4323 KOps/s $\color{#35bf28}+1.47\%$
test_step_mdp_speed[False-True-False-False-True] 68.0010μs 38.3092μs 26.1034 KOps/s 25.6120 KOps/s $\color{#35bf28}+1.92\%$
test_step_mdp_speed[False-True-False-False-False] 78.0000μs 23.1456μs 43.2047 KOps/s 42.1841 KOps/s $\color{#35bf28}+2.42\%$
test_step_mdp_speed[False-False-True-True-True] 75.9010μs 50.8511μs 19.6653 KOps/s 19.3066 KOps/s $\color{#35bf28}+1.86\%$
test_step_mdp_speed[False-False-True-True-False] 58.0010μs 32.1092μs 31.1437 KOps/s 30.4425 KOps/s $\color{#35bf28}+2.30\%$
test_step_mdp_speed[False-False-True-False-True] 60.3000μs 38.2437μs 26.1481 KOps/s 25.4646 KOps/s $\color{#35bf28}+2.68\%$
test_step_mdp_speed[False-False-True-False-False] 85.7010μs 22.9867μs 43.5034 KOps/s 41.6587 KOps/s $\color{#35bf28}+4.43\%$
test_step_mdp_speed[False-False-False-True-True] 81.8010μs 52.1988μs 19.1575 KOps/s 18.8277 KOps/s $\color{#35bf28}+1.75\%$
test_step_mdp_speed[False-False-False-True-False] 57.1000μs 33.5501μs 29.8062 KOps/s 29.0832 KOps/s $\color{#35bf28}+2.49\%$
test_step_mdp_speed[False-False-False-False-True] 89.5010μs 39.7300μs 25.1699 KOps/s 24.9614 KOps/s $\color{#35bf28}+0.84\%$
test_step_mdp_speed[False-False-False-False-False] 48.5000μs 24.2076μs 41.3094 KOps/s 40.3569 KOps/s $\color{#35bf28}+2.36\%$
test_values[generalized_advantage_estimate-True-True] 17.9546ms 13.5865ms 73.6025 Ops/s 72.8320 Ops/s $\color{#35bf28}+1.06\%$
test_values[vec_generalized_advantage_estimate-True-True] 55.7943ms 47.1087ms 21.2275 Ops/s 22.0185 Ops/s $\color{#d91a1a}-3.59\%$
test_values[td0_return_estimate-False-False] 1.4800ms 0.5030ms 1.9879 KOps/s 2.4628 KOps/s $\textbf{\color{#d91a1a}-19.28\%}$
test_values[td1_return_estimate-False-False] 13.5768ms 12.8526ms 77.8051 Ops/s 77.4109 Ops/s $\color{#35bf28}+0.51\%$
test_values[vec_td1_return_estimate-False-False] 57.4910ms 46.7409ms 21.3945 Ops/s 22.1987 Ops/s $\color{#d91a1a}-3.62\%$
test_values[td_lambda_return_estimate-True-False] 32.7305ms 31.1147ms 32.1392 Ops/s 31.9401 Ops/s $\color{#35bf28}+0.62\%$
test_values[vec_td_lambda_return_estimate-True-False] 57.7460ms 45.8122ms 21.8283 Ops/s 22.1009 Ops/s $\color{#d91a1a}-1.23\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 11.4268ms 11.3130ms 88.3938 Ops/s 87.1885 Ops/s $\color{#35bf28}+1.38\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 11.8742ms 3.9360ms 254.0669 Ops/s 262.2090 Ops/s $\color{#d91a1a}-3.11\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 2.2153ms 0.5596ms 1.7869 KOps/s 1.7872 KOps/s $\color{#d91a1a}-0.02\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 67.3170ms 60.7516ms 16.4605 Ops/s 16.2907 Ops/s $\color{#35bf28}+1.04\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 12.0237ms 3.3281ms 300.4745 Ops/s 280.8198 Ops/s $\textbf{\color{#35bf28}+7.00\%}$
test_dqn_speed 5.7761ms 2.4214ms 412.9878 Ops/s 413.5649 Ops/s $\color{#d91a1a}-0.14\%$
test_ddpg_speed 12.4854ms 4.4146ms 226.5200 Ops/s 226.3251 Ops/s $\color{#35bf28}+0.09\%$
test_sac_speed 17.9294ms 11.8382ms 84.4720 Ops/s 82.9609 Ops/s $\color{#35bf28}+1.82\%$
test_redq_speed 29.0052ms 20.1602ms 49.6027 Ops/s 49.0197 Ops/s $\color{#35bf28}+1.19\%$
test_redq_deprec_speed 25.6553ms 17.7443ms 56.3560 Ops/s 56.5359 Ops/s $\color{#d91a1a}-0.32\%$
test_td3_speed 13.3291ms 12.5128ms 79.9183 Ops/s 79.3599 Ops/s $\color{#35bf28}+0.70\%$
test_cql_speed 38.1271ms 34.2003ms 29.2395 Ops/s 26.5101 Ops/s $\textbf{\color{#35bf28}+10.30\%}$
test_a2c_speed 10.8097ms 7.4903ms 133.5064 Ops/s 126.9756 Ops/s $\textbf{\color{#35bf28}+5.14\%}$
test_ppo_speed 15.0940ms 8.2247ms 121.5854 Ops/s 129.6301 Ops/s $\textbf{\color{#d91a1a}-6.21\%}$
test_reinforce_speed 10.2680ms 6.3614ms 157.1978 Ops/s 171.5532 Ops/s $\textbf{\color{#d91a1a}-8.37\%}$
test_iql_speed 41.7187ms 29.5321ms 33.8615 Ops/s 32.7840 Ops/s $\color{#35bf28}+3.29\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 2.8977ms 2.6027ms 384.2174 Ops/s 377.4746 Ops/s $\color{#35bf28}+1.79\%$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 4.9115ms 2.8116ms 355.6651 Ops/s 349.9659 Ops/s $\color{#35bf28}+1.63\%$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 4.6001ms 2.8349ms 352.7477 Ops/s 350.2921 Ops/s $\color{#35bf28}+0.70\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.6175ms 2.6648ms 375.2649 Ops/s 380.4885 Ops/s $\color{#d91a1a}-1.37\%$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 5.6863ms 2.8518ms 350.6601 Ops/s 359.9382 Ops/s $\color{#d91a1a}-2.58\%$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 4.5353ms 2.8291ms 353.4714 Ops/s 350.9852 Ops/s $\color{#35bf28}+0.71\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.6150ms 2.7522ms 363.3513 Ops/s 382.8796 Ops/s $\textbf{\color{#d91a1a}-5.10\%}$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 0.1656s 3.3561ms 297.9649 Ops/s 356.8357 Ops/s $\textbf{\color{#d91a1a}-16.50\%}$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.1565s 3.3736ms 296.4170 Ops/s 350.8716 Ops/s $\textbf{\color{#d91a1a}-15.52\%}$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.2399ms 2.6401ms 378.7740 Ops/s 380.2489 Ops/s $\color{#d91a1a}-0.39\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 4.5655ms 2.8829ms 346.8774 Ops/s 355.4661 Ops/s $\color{#d91a1a}-2.42\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 5.6847ms 2.8860ms 346.5031 Ops/s 356.4826 Ops/s $\color{#d91a1a}-2.80\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.4813ms 2.6328ms 379.8291 Ops/s 381.9431 Ops/s $\color{#d91a1a}-0.55\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 5.5208ms 2.8810ms 347.1064 Ops/s 351.5149 Ops/s $\color{#d91a1a}-1.25\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.1851s 3.4173ms 292.6320 Ops/s 347.3600 Ops/s $\textbf{\color{#d91a1a}-15.76\%}$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.4054ms 2.6425ms 378.4364 Ops/s 378.8293 Ops/s $\color{#d91a1a}-0.10\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 4.9741ms 2.8455ms 351.4339 Ops/s 350.4280 Ops/s $\color{#35bf28}+0.29\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 4.5815ms 2.8350ms 352.7297 Ops/s 303.7465 Ops/s $\textbf{\color{#35bf28}+16.13\%}$
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.3014s 31.4351ms 31.8116 Ops/s 31.4770 Ops/s $\color{#35bf28}+1.06\%$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 0.1638s 31.3760ms 31.8715 Ops/s 35.4739 Ops/s $\textbf{\color{#d91a1a}-10.15\%}$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 0.1568s 27.9721ms 35.7499 Ops/s 32.4081 Ops/s $\textbf{\color{#35bf28}+10.31\%}$
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1566s 30.7057ms 32.5672 Ops/s 35.7892 Ops/s $\textbf{\color{#d91a1a}-9.00\%}$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 0.1582s 27.9205ms 35.8160 Ops/s 32.4310 Ops/s $\textbf{\color{#35bf28}+10.44\%}$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 0.1670s 30.9185ms 32.3431 Ops/s 32.3124 Ops/s $\color{#35bf28}+0.10\%$
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.1538s 27.7947ms 35.9781 Ops/s 35.5205 Ops/s $\color{#35bf28}+1.29\%$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 0.1562s 30.6340ms 32.6435 Ops/s 32.2745 Ops/s $\color{#35bf28}+1.14\%$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 0.1578s 30.7295ms 32.5421 Ops/s 35.6273 Ops/s $\textbf{\color{#d91a1a}-8.66\%}$

Copy link
Contributor

@matteobettini matteobettini left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@vmoens vmoens merged commit 59d29b8 into main Oct 2, 2023
57 of 59 checks passed
@vmoens vmoens deleted the fix_double2float branch October 2, 2023 12:30
vmoens added a commit to hyerra/rl that referenced this pull request Oct 10, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants